Skip to content

Commit 3eef5db

Browse files
committed
Use Catanzaro's algorithm for non-power-of-two interleaves
1 parent 23b79ba commit 3eef5db

3 files changed

Lines changed: 241 additions & 39 deletions

File tree

src/CodeGen_LLVM.cpp

Lines changed: 81 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -2211,6 +2211,8 @@ Value *CodeGen_LLVM::interleave_vectors(const std::vector<Value *> &vecs) {
22112211
}
22122212
int vec_elements = get_vector_num_elements(vecs[0]->getType());
22132213

2214+
int factor = gcd(vec_elements, (int)vecs.size());
2215+
22142216
if (vecs.size() == 1) {
22152217
return vecs[0];
22162218
} else if (vecs.size() == 2) {
@@ -2221,57 +2223,97 @@ Value *CodeGen_LLVM::interleave_vectors(const std::vector<Value *> &vecs) {
22212223
indices[i] = i % 2 == 0 ? i / 2 : i / 2 + vec_elements;
22222224
}
22232225
return optimization_fence(shuffle_vectors(a, b, indices));
2224-
} else {
2225-
// Grab the even and odd elements of vecs.
2226-
vector<Value *> even_vecs;
2227-
vector<Value *> odd_vecs;
2228-
for (size_t i = 0; i < vecs.size(); i++) {
2229-
if (i % 2 == 0) {
2230-
even_vecs.push_back(vecs[i]);
2231-
} else {
2232-
odd_vecs.push_back(vecs[i]);
2226+
} else if (factor == 1) {
2227+
// The number of vectors and the vector length are
2228+
// coprime. (E.g. interleaving an odd number of vectors of some
2229+
// power-of-two length). Use the algorithm from "A Decomposition for
2230+
// In-place Matrix Transposition" by Catanzaro et al.
2231+
std::vector<Value *> v = vecs;
2232+
2233+
// Using unary shuffles, get each element into the right ultimate
2234+
// lane. This works out without collisions because the number of vectors
2235+
// and the length of each vector are coprime.
2236+
const int num_vecs = (int)v.size();
2237+
std::vector<int> shuffle(vec_elements);
2238+
for (int i = 0; i < num_vecs; i++) {
2239+
for (int j = 0; j < vec_elements; j++) {
2240+
int k = j * num_vecs + i;
2241+
shuffle[k % vec_elements] = j;
22332242
}
2243+
v[i] = shuffle_vectors(v[i], v[i], shuffle);
22342244
}
22352245

2236-
// If the number of vecs is odd, save the last one for later.
2237-
Value *last = nullptr;
2238-
if (even_vecs.size() > odd_vecs.size()) {
2239-
last = even_vecs.back();
2240-
even_vecs.pop_back();
2246+
// We intentionally don't put an optimization fence after the unary
2247+
// shuffles, because some architectures have a two-way shuffle, so it
2248+
// helps to fuse the unary shuffle into the first layer of two-way
2249+
// blends below.
2250+
2251+
// Now we need to transfer the elements across the vectors. If we
2252+
// reorder the vectors, this becomes a rotation across the vectors of a
2253+
// different amount per lane.
2254+
std::vector<Value *> new_v(v.size());
2255+
for (int i = 0; i < num_vecs; i++) {
2256+
int j = (i * vec_elements) % num_vecs;
2257+
new_v[i] = v[j];
22412258
}
2242-
internal_assert(even_vecs.size() == odd_vecs.size());
2259+
v.swap(new_v);
22432260

2244-
// Interleave the even and odd parts.
2245-
Value *even = interleave_vectors(even_vecs);
2246-
Value *odd = interleave_vectors(odd_vecs);
2261+
std::vector<int> rotation(vec_elements, 0);
2262+
for (int i = 0; i < vec_elements; i++) {
2263+
int k = (i * num_vecs) % vec_elements;
2264+
rotation[k] = (i * num_vecs) / vec_elements;
2265+
}
2266+
internal_assert(rotation[0] == 0);
22472267

2248-
if (last) {
2249-
int result_elements = vec_elements * vecs.size();
2268+
// We'll handle each bit of the rotation one at a time with a two-way
2269+
// shuffle.
2270+
int d = 1;
2271+
while (d < num_vecs) {
22502272

2251-
// Interleave even and odd, leaving a space for the last element.
2252-
vector<int> indices(result_elements, -1);
2253-
for (int i = 0, idx = 0; i < result_elements; i++) {
2254-
if (i % vecs.size() < vecs.size() - 1) {
2255-
indices[i] = idx % 2 == 0 ? idx / 2 : idx / 2 + vec_elements * even_vecs.size();
2256-
idx++;
2257-
}
2273+
for (int i = 0; i < vec_elements; i++) {
2274+
shuffle[i] = ((rotation[i] & d) == 0) ? i : (i + vec_elements);
22582275
}
2259-
Value *even_odd = shuffle_vectors(even, odd, indices);
22602276

2261-
// Interleave the last vector into the result.
2262-
last = slice_vector(last, 0, result_elements);
2263-
for (int i = 0; i < result_elements; i++) {
2264-
if (i % vecs.size() < vecs.size() - 1) {
2265-
indices[i] = i;
2266-
} else {
2267-
indices[i] = i / vecs.size() + result_elements;
2268-
}
2277+
for (int i = 0; i < num_vecs; i++) {
2278+
int j = (i + num_vecs - d) % num_vecs;
2279+
new_v[i] = shuffle_vectors(v[i], v[j], shuffle);
22692280
}
22702281

2271-
return shuffle_vectors(even_odd, last, indices);
2272-
} else {
2273-
return interleave_vectors({even, odd});
2282+
v.swap(new_v);
2283+
2284+
d *= 2;
22742285
}
2286+
2287+
return concat_vectors(v);
2288+
2289+
} else {
2290+
// The number of vectors shares a factor with the length of the
2291+
// vectors. Pick the smallest nontrivial factor of the number of vectors, interleave
2292+
// in separate groups, and then interleave the results.
2293+
const int n = (int)vecs.size();
2294+
int f = 1;
2295+
for (int i = 2; i < n; i++) {
2296+
if (n % i == 0) {
2297+
f = i;
2298+
break;
2299+
}
2300+
}
2301+
2302+
internal_assert(f > 1 && f < n);
2303+
2304+
vector<vector<Value *>> groups(f);
2305+
for (size_t i = 0; i < vecs.size(); i++) {
2306+
groups[i % f].push_back(vecs[i]);
2307+
}
2308+
2309+
// Interleave each group
2310+
vector<Value *> interleaved(f);
2311+
for (int i = 0; i < f; i++) {
2312+
interleaved[i] = optimization_fence(interleave_vectors(groups[i]));
2313+
}
2314+
2315+
// Interleave the result
2316+
return interleave_vectors(interleaved);
22752317
}
22762318
}
22772319

test/performance/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ tests(GROUPS performance
1616
fast_pow.cpp
1717
fast_sine_cosine.cpp
1818
gpu_half_throughput.cpp
19+
interleave.cpp
1920
jit_stress.cpp
2021
lots_of_inputs.cpp
2122
memcpy.cpp

test/performance/interleave.cpp

Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
#include "Halide.h"
2+
#include "halide_benchmark.h"
3+
#include "halide_test_dirs.h"
4+
5+
#include <cstdio>
6+
7+
using namespace Halide;
8+
using namespace Halide::Tools;
9+
10+
// One row of the benchmark table: the element size in bytes, the
// interleave factor that was measured, and the measured bandwidth.
// Members are zero-initialized so that a default-constructed Result never
// carries indeterminate values into the printed table.
struct Result {
    int type_size = 0;       // sizeof the element type, in bytes
    int factor = 0;          // interleave/deinterleave factor used
    double bandwidth = 0.0;  // output bytes / (1e9 * time), as computed by the tests
};
14+
15+
template<typename T>
16+
Result test_interleave(int factor, const Target &t) {
17+
const int N = 8192;
18+
Buffer<T> in(N, factor), out(N * factor);
19+
20+
for (int y = 0; y < factor; y++) {
21+
for (int x = 0; x < N; x++) {
22+
in(x, y) = (T)(x * factor + y);
23+
}
24+
}
25+
26+
Func output;
27+
Var x, y;
28+
29+
output(x) = in(x / factor, x % factor);
30+
31+
Var xi, yi;
32+
output.unroll(x, factor, TailStrategy::RoundUp).vectorize(x, t.natural_vector_size<T>(), TailStrategy::RoundUp);
33+
output.output_buffer().dim(0).set_min(0);
34+
35+
output.compile_jit();
36+
37+
output.realize(out);
38+
39+
double time = benchmark(20, 20, [&]() {
40+
output.realize(out);
41+
});
42+
43+
for (int y = 0; y < factor; y++) {
44+
for (int x = 0; x < N; x++) {
45+
uint64_t actual = out(x * factor + y), correct = in(x, y);
46+
if (actual != correct) {
47+
std::cerr << "For factor " << factor
48+
<< "out(" << x << " * " << factor << " + " << y << ") = "
49+
<< actual << " instead of " << correct << "\n";
50+
exit(1);
51+
}
52+
}
53+
}
54+
55+
// Uncomment to dump asm for inspection
56+
// output.compile_to_assembly("/dev/stdout",
57+
// std::vector<Argument>{in}, "interleave", t);
58+
59+
return Result{(int)sizeof(T), factor, out.size_in_bytes() / (1.0e9 * time)};
60+
}
61+
62+
template<typename T>
63+
Result test_deinterleave(int factor, const Target &t) {
64+
const int N = 8192;
65+
Buffer<T> in(N * factor), out(N, factor);
66+
67+
for (int x = 0; x < N; x++) {
68+
for (int y = 0; y < factor; y++) {
69+
in(x * factor + y) = (T)(x + y * N);
70+
}
71+
}
72+
73+
Func output;
74+
Var x, y;
75+
76+
output(x, y) = in(x * factor + y);
77+
78+
Var xi, yi;
79+
output.reorder(y, x).bound(y, 0, factor).unroll(y).vectorize(x, t.natural_vector_size<T>(), TailStrategy::RoundUp);
80+
// output.output_buffer().dim(0).set_min(0);
81+
82+
output.compile_jit();
83+
84+
output.realize(out);
85+
86+
double time = benchmark(20, 20, [&]() {
87+
output.realize(out);
88+
});
89+
90+
for (int y = 0; y < factor; y++) {
91+
for (int x = 0; x < N; x++) {
92+
uint64_t actual = out(x, y), correct = in(x * factor + y);
93+
if (actual != correct) {
94+
std::cerr << "For factor " << factor
95+
<< "out(" << x << ", " << y << ") = "
96+
<< actual << " instead of " << correct << "\n";
97+
exit(1);
98+
}
99+
}
100+
}
101+
102+
// Uncomment to dump asm for inspection
103+
output.compile_to_assembly("/dev/stdout",
104+
std::vector<Argument>{in}, "interleave", t);
105+
106+
return Result{(int)sizeof(T), factor, out.size_in_bytes() / (1.0e9 * time)};
107+
}
108+
109+
int main(int argc, char **argv) {
110+
Target target = get_jit_target_from_environment();
111+
if (target.arch == Target::WebAssembly) {
112+
printf("[SKIP] Performance tests are meaningless and/or misleading under WebAssembly interpreter.\n");
113+
return 0;
114+
}
115+
116+
// Set the target features to use for dumping to assembly
117+
target.set_features({Target::NoRuntime, Target::NoAsserts, Target::NoBoundsQuery});
118+
119+
std::cout << "\nbytes, interleave factor, interleave bandwidth (GB/s), deinterleave bandwidth (GB/s):\n";
120+
#if 0
121+
for (int t : {1, 2, 4, 8}) {
122+
for (int f = 2; f < 16; f++) {
123+
#else
124+
{
125+
{
126+
int t = 1, f = 4;
127+
#endif
128+
Result r1, r2;
129+
switch (t) {
130+
case 1:
131+
r1 = test_interleave<uint8_t>(f, target);
132+
r2 = test_deinterleave<uint8_t>(f, target);
133+
break;
134+
case 2:
135+
r1 = test_interleave<uint16_t>(f, target);
136+
r2 = test_deinterleave<uint16_t>(f, target);
137+
break;
138+
case 4:
139+
r1 = test_interleave<uint32_t>(f, target);
140+
r2 = test_deinterleave<uint32_t>(f, target);
141+
break;
142+
case 8:
143+
r1 = test_interleave<uint64_t>(f, target);
144+
r2 = test_deinterleave<uint64_t>(f, target);
145+
break;
146+
default:
147+
break;
148+
}
149+
std::cout << r1.type_size << " "
150+
<< r1.factor << " "
151+
<< r1.bandwidth << " "
152+
<< r2.bandwidth << "\n";
153+
154+
}
155+
}
156+
157+
printf("Success!\n");
158+
return 0;
159+
}

0 commit comments

Comments
 (0)