Skip to content

Commit 52cdeeb

Browse files
Add scalable vector shuffle support for ARM SVE2 (#8898)
* Introduce helpers for scalable vector shuffles in DecomposeVectorShuffle (kept ARM-specific for now)
* Implement shuffle lowering via decomposition into native SVE2 TBL/TBL2 operations, with peephole optimizations (e.g. WHILELT, concat padding, redundant broadcast removal)
* Improve SVE2 broadcast performance by emitting TBL instead of insert sequences
* Handle edge cases in shuffle decomposition (e.g. undef lanes) and add validation/assertions
* Update tests for wider vector sizes and adjust SVE2 expectations; skip known LLVM <22 failures

---------

Co-authored-by: Alex Reinking <areinking@adobe.com>
1 parent 455b34b commit 52cdeeb

15 files changed

Lines changed: 922 additions & 96 deletions

Makefile

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,7 @@ SOURCE_FILES = \
491491
Debug.cpp \
492492
DebugArguments.cpp \
493493
DebugToFile.cpp \
494+
DecomposeVectorShuffle.cpp \
494495
Definition.cpp \
495496
Deinterleave.cpp \
496497
Derivative.cpp \
@@ -687,6 +688,7 @@ HEADER_FILES = \
687688
Debug.h \
688689
DebugArguments.h \
689690
DebugToFile.h \
691+
DecomposeVectorShuffle.h \
690692
Definition.h \
691693
Deinterleave.h \
692694
Derivative.h \

src/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ target_sources(
9595
Debug.h
9696
DebugArguments.h
9797
DebugToFile.h
98+
DecomposeVectorShuffle.h
9899
Definition.h
99100
Deinterleave.h
100101
Derivative.h
@@ -279,6 +280,7 @@ target_sources(
279280
Debug.cpp
280281
DebugArguments.cpp
281282
DebugToFile.cpp
283+
DecomposeVectorShuffle.cpp
282284
Definition.cpp
283285
Deinterleave.cpp
284286
Derivative.cpp

src/CodeGen_ARM.cpp

Lines changed: 377 additions & 41 deletions
Large diffs are not rendered by default.

src/CodeGen_LLVM.cpp

Lines changed: 8 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4155,7 +4155,9 @@ void CodeGen_LLVM::visit(const Shuffle *op) {
41554155
} else {
41564156
internal_assert(op->indices[0] == 0);
41574157
}
4158-
value = create_broadcast(value, op->indices.size());
4158+
if (op->indices.size() > 1) {
4159+
value = create_broadcast(value, op->indices.size());
4160+
}
41594161
return;
41604162
}
41614163
}
@@ -5445,6 +5447,10 @@ int CodeGen_LLVM::get_vector_num_elements(const llvm::Type *t) {
54455447
}
54465448
}
54475449

5450+
// Convenience overload: reports the number of elements of the type of `v`
// by forwarding to the llvm::Type overload.
int CodeGen_LLVM::get_vector_num_elements(const llvm::Value *v) {
    const llvm::Type *value_type = v->getType();
    return get_vector_num_elements(value_type);
}
5453+
54485454
llvm::Type *CodeGen_LLVM::llvm_type_of(LLVMContext *c, Halide::Type t,
54495455
int effective_vscale) const {
54505456
if (t.lanes() == 1) {
@@ -5481,23 +5487,7 @@ llvm::Type *CodeGen_LLVM::get_vector_type(llvm::Type *t, int n,
54815487
switch (type_constraint) {
54825488
case VectorTypeConstraint::None:
54835489
if (effective_vscale > 0) {
5484-
bool wide_enough = true;
5485-
// TODO(https://github.com/halide/Halide/issues/8119): Architecture
5486-
// specific code should not go here. Ideally part of this can go
5487-
// away via LLVM fixes and modifying intrinsic selection to handle
5488-
// scalable vs. fixed vectors. Making this method virtual is
5489-
// possibly expensive.
5490-
if (target.arch == Target::ARM) {
5491-
if (!target.has_feature(Target::NoNEON)) {
5492-
// force booleans into bytes. TODO(https://github.com/halide/Halide/issues/8119): figure out a better way to do this.
5493-
int bit_size = std::max((int)t->getScalarSizeInBits(), 8);
5494-
wide_enough = (bit_size * n) > 128;
5495-
} else {
5496-
// TODO(https://github.com/halide/Halide/issues/8119): AArch64 SVE2 support is crashy with scalable vectors of min size 1.
5497-
wide_enough = (n / effective_vscale) > 1;
5498-
}
5499-
}
5500-
scalable = wide_enough && ((n % effective_vscale) == 0);
5490+
scalable = (n % effective_vscale) == 0;
55015491
if (scalable) {
55025492
n = n / effective_vscale;
55035493
}

src/CodeGen_LLVM.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,10 @@ class CodeGen_LLVM : public IRVisitor {
605605
const std::function<llvm::Value *(llvm::Value *)> &fn);
606606

607607
/** Get number of vector elements, taking into account scalable vectors. Returns 1 for scalars. */
608+
// @{
608609
int get_vector_num_elements(const llvm::Type *t);
610+
int get_vector_num_elements(const llvm::Value *v);
611+
// @}
609612

610613
/** Interface to abstract vector code generation as LLVM is now
611614
* providing multiple options to express even simple vector

src/DecomposeVectorShuffle.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
#include "DecomposeVectorShuffle.h"
2+
3+
#include <unordered_map>
4+
5+
namespace Halide::Internal {
6+
7+
std::vector<std::vector<NativeShuffle>> decompose_to_native_shuffles(
8+
int src_lanes, const std::vector<int> &indices, int vl) {
9+
10+
int dst_lanes = static_cast<int>(indices.size());
11+
int src_lanes_aligned = align_up(src_lanes, vl);
12+
13+
// Adjust indices so that src vectors are aligned up to multiple of vl
14+
std::vector<int> aligned_indices = indices;
15+
for (int &idx : aligned_indices) {
16+
if (idx >= src_lanes) {
17+
idx += src_lanes_aligned - src_lanes;
18+
}
19+
}
20+
21+
const int num_dst_slices = align_up(dst_lanes, vl) / vl;
22+
std::vector<std::vector<NativeShuffle>> all_steps(num_dst_slices);
23+
24+
for (int dst_slice = 0; dst_slice < num_dst_slices; dst_slice++) {
25+
std::unordered_map<int, int> slice_to_step;
26+
auto &steps = all_steps[dst_slice];
27+
const int dst_start = dst_slice * vl;
28+
29+
for (int dst_index = dst_start; dst_index < dst_start + vl && dst_index < dst_lanes; ++dst_index) {
30+
const int src_index = aligned_indices[dst_index];
31+
if (src_index < 0) {
32+
continue;
33+
}
34+
35+
const int src_slice = src_index / vl;
36+
const int lane_in_src_slice = src_index % vl;
37+
const int lane_in_dst_slice = dst_index - dst_start;
38+
39+
if (steps.empty()) {
40+
// first slice in this block
41+
slice_to_step[src_slice] = 0;
42+
steps.emplace_back(vl, src_slice, SliceIndexNone);
43+
steps.back().lane_map[lane_in_dst_slice] = lane_in_src_slice;
44+
45+
} else if (auto itr = slice_to_step.find(src_slice); itr != slice_to_step.end()) {
46+
// slice already seen
47+
NativeShuffle &step = steps[itr->second];
48+
bool is_a = (step.slice_a != SliceIndexCarryPrevResult && step.slice_a == src_slice);
49+
int offset = is_a ? 0 : vl;
50+
step.lane_map[lane_in_dst_slice] = lane_in_src_slice + offset;
51+
52+
} else if (steps[0].slice_b == SliceIndexNone) {
53+
// add as 'b' of first step if b is unused
54+
slice_to_step[src_slice] = 0;
55+
steps[0].slice_b = src_slice;
56+
steps[0].lane_map[lane_in_dst_slice] = lane_in_src_slice + vl;
57+
58+
} else {
59+
// otherwise chain a new step
60+
slice_to_step[src_slice] = static_cast<int>(steps.size());
61+
// new step uses previous result as 'a', so we use 'b' for this one
62+
steps.emplace_back(vl, SliceIndexCarryPrevResult, src_slice);
63+
64+
// Except for the first step, we need to arrange indices
65+
// so that the output carried from the previous step is kept
66+
auto &lane_map = steps.back().lane_map;
67+
// initialize lane_map as identical copy
68+
for (size_t lane_idx = 0; lane_idx < lane_map.size(); ++lane_idx) {
69+
lane_map[lane_idx] = lane_idx;
70+
}
71+
// update for this index
72+
lane_map[lane_in_dst_slice] = lane_in_src_slice + vl;
73+
}
74+
}
75+
}
76+
77+
return all_steps;
78+
}
79+
80+
} // namespace Halide::Internal

src/DecomposeVectorShuffle.h

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
#ifndef HALIDE_DECOMPOSE_VECTOR_SHUFFLE_H
2+
#define HALIDE_DECOMPOSE_VECTOR_SHUFFLE_H
3+
4+
/** \file
5+
*
6+
* Perform a vector shuffle by decomposing the operation into
7+
* a sequence of sub-shuffle steps, where each step is a shuffle that:
8+
* - Takes one or two slices as input (slice_a and slice_b)
9+
* - Produces one slice (dst slice)
10+
* - Uses slices that all have the same length as the target native vector (vl)
11+
*
12+
* The structure of the sequence of steps consists of:
13+
* 1. Outer loop to iterate the slices of dst vector.
14+
* 2. Inner loop to iterate the native shuffle steps to complete a single dst slice.
15+
* This can be multiple steps because a single native shuffle can take
16+
* only 2 slices (native vector length x 2) at most, while we may need
17+
* to fetch from wider location in the src vector.
18+
*
19+
* The following example, taken from the log of the test code, illustrates how it works.
20+
*
21+
* src_lanes: 17, dst_lanes: 7, vl: 4
22+
* input a: [0, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160, ]
23+
* input b: [170, 180, 190, 200, 210, 220, 230, 240, 250, 260, 270, 280, 290, 300, 310, 320, 330, ]
24+
* indices: [6, 13, 24, 14, 7, 11, 5, ]
25+
*
26+
* slice a:[40, 50, 60, 70, ], slice b:[120, 130, 140, 150, ], indices:[2, 5, -1, 6, ]
27+
* => slice output:[60, 130, -559038801, 140, ]
28+
* slice a:[60, 130, -559038801, 140, ], slice b:[210, 220, 230, 240, ], indices:[0, 1, 7, 3, ]
29+
* => slice output:[60, 130, 240, 140, ]
30+
* slice a:[40, 50, 60, 70, ], slice b:[80, 90, 100, 110, ], indices:[3, 7, 1, -1, ]
31+
* => slice output:[70, 110, 50, -559038801, ]
32+
*
33+
* output: [60, 130, 240, 140, 70, 110, 50, ]
34+
*
35+
*/
36+
37+
#include "Error.h"
38+
#include "Util.h"
39+
40+
#include <optional>
41+
#include <vector>
42+
43+
namespace Halide {
44+
namespace Internal {
45+
46+
/** Special values a slice index can take in a NativeShuffle. Real slice
 * indices are non-negative, so these never collide with them. */
enum {
    /** The operand is unused. */
    SliceIndexNone = -1,
    /** The operand is the result produced by the previous step. */
    SliceIndexCarryPrevResult = -2,
};

/** A single native-width shuffle step: one or two vl-wide input slices
 * combined into one vl-wide output slice.
 *
 * lane_map has one entry per output lane. An entry in [0, vl) selects that
 * lane of slice a, an entry in [vl, 2 * vl) selects lane (entry - vl) of
 * slice b, and SliceIndexNone (-1) marks an output lane that is not
 * assigned by this step.
 */
struct NativeShuffle {
    /** Index of the first input slice, or SliceIndexCarryPrevResult. */
    int slice_a;
    /** Index of the second input slice, or SliceIndexNone if unused. */
    int slice_b;
    /** Source lane for each output lane; see the struct comment. */
    std::vector<int> lane_map;

    /** Create a step over slices `a` and `b` with all `vl` output lanes
     * initially unassigned. */
    NativeShuffle(int vl, int a, int b)
        : slice_a(a), slice_b(b), lane_map(vl, SliceIndexNone) {
    }
};
62+
63+
std::vector<std::vector<NativeShuffle>> decompose_to_native_shuffles(
64+
int src_lanes, const std::vector<int> &indices, int vl);
65+
66+
/** Algorithm logic for shuffle decomposition, parameterized on vector type
67+
* and a codegen-like class that provides primitive vector operations.
68+
*/
69+
template<typename CodeGenTy, typename VecTy>
70+
struct DecomposeVectorShuffle {
71+
// TODO: when upgrading to C++20, replace with a concept.
72+
// get_vector_num_elements may be overloaded (e.g. on Type* and Value*), so use
73+
// expression SFINAE rather than a method pointer to handle overload resolution.
74+
static_assert(std::is_convertible_v<decltype(std::declval<CodeGenTy &>().get_vector_num_elements(std::declval<VecTy>())), int>,
75+
"CodeGenTy must provide: int get_vector_num_elements(VecTy)");
76+
static_assert(std::is_invocable_r_v<VecTy, decltype(&CodeGenTy::slice_vector), CodeGenTy &, const VecTy &, int, int>,
77+
"CodeGenTy must provide: VecTy slice_vector(const VecTy &, int, int)");
78+
static_assert(std::is_invocable_r_v<VecTy, decltype(&CodeGenTy::concat_vectors), CodeGenTy &, const std::vector<VecTy> &>,
79+
"CodeGenTy must provide: VecTy concat_vectors(const std::vector<VecTy> &)");
80+
static_assert(std::is_invocable_r_v<VecTy, decltype(&CodeGenTy::shuffle_scalable_vectors_general), CodeGenTy &,
81+
const VecTy &, const VecTy &, const std::vector<int> &>,
82+
"CodeGenTy must provide: VecTy shuffle_scalable_vectors_general(const VecTy &, const VecTy &, const std::vector<int> &)");
83+
static_assert(std::is_invocable_r_v<VecTy, decltype(&CodeGenTy::create_undef_vector_like), CodeGenTy &, const VecTy &, int>,
84+
"CodeGenTy must provide: VecTy create_undef_vector_like(const VecTy &, int)");
85+
86+
DecomposeVectorShuffle(CodeGenTy &codegen, const VecTy &src_a, const VecTy &src_b, int src_lanes, int vl)
87+
: codegen(codegen),
88+
vl(vl),
89+
src_a(align_up_vector(src_a, vl)),
90+
src_b(align_up_vector(src_b, vl)),
91+
src_lanes(src_lanes),
92+
src_lanes_aligned(align_up(src_lanes, vl)) {
93+
}
94+
95+
VecTy run(const std::vector<int> &indices) {
96+
auto shuffle_plan = decompose_to_native_shuffles(src_lanes, indices, vl);
97+
int dst_lanes = static_cast<int>(indices.size());
98+
99+
// process each block divided by vl
100+
std::vector<VecTy> shuffled_dst_slices;
101+
shuffled_dst_slices.reserve(shuffle_plan.size());
102+
103+
for (const auto &steps_for_dst_slice : shuffle_plan) {
104+
std::optional<VecTy> dst_slice = std::nullopt;
105+
for (const auto &step : steps_for_dst_slice) {
106+
// Obtain 1st slice a
107+
VecTy a;
108+
if (step.slice_a == SliceIndexCarryPrevResult) {
109+
internal_assert(dst_slice.has_value()) << "Tried to carry from undefined previous result";
110+
a = *dst_slice;
111+
} else {
112+
a = get_vl_slice(step.slice_a);
113+
}
114+
// Obtain 2nd slice b
115+
std::optional<VecTy> b;
116+
if (step.slice_b == SliceIndexNone) {
117+
b = std::nullopt;
118+
} else {
119+
b = std::optional<VecTy>(get_vl_slice(step.slice_b));
120+
}
121+
// Perform shuffle where vector length is aligned
122+
dst_slice = codegen.shuffle_scalable_vectors_general(a, b.value_or(VecTy{}), step.lane_map);
123+
}
124+
if (!dst_slice.has_value()) {
125+
// No shuffle step for this slice, i.e. all the indices are -1
126+
dst_slice = codegen.create_undef_vector_like(src_a, vl);
127+
}
128+
shuffled_dst_slices.push_back(*dst_slice);
129+
}
130+
131+
return codegen.slice_vector(codegen.concat_vectors(shuffled_dst_slices), 0, dst_lanes);
132+
}
133+
134+
private:
135+
// Helper to extract slice with lanes=vl
136+
VecTy get_vl_slice(int slice_index) {
137+
const int num_slices_a = src_lanes_aligned / vl;
138+
int start_index = slice_index * vl;
139+
if (slice_index < num_slices_a) {
140+
return codegen.slice_vector(src_a, start_index, vl);
141+
} else {
142+
start_index -= src_lanes_aligned;
143+
return codegen.slice_vector(src_b, start_index, vl);
144+
}
145+
}
146+
147+
VecTy align_up_vector(const VecTy &v, int align) {
148+
int len = codegen.get_vector_num_elements(v);
149+
return codegen.slice_vector(v, 0, align_up(len, align));
150+
}
151+
152+
CodeGenTy &codegen;
153+
int vl;
154+
VecTy src_a;
155+
VecTy src_b;
156+
int src_lanes;
157+
int src_lanes_aligned;
158+
};
159+
160+
} // namespace Internal
161+
} // namespace Halide
162+
163+
#endif

test/correctness/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ tests(GROUPS correctness
8080
debug_to_file.cpp
8181
debug_to_file_multiple_outputs.cpp
8282
debug_to_file_reorder.cpp
83+
decompose_vector_shuffle.cpp
8384
deferred_loop_level.cpp
8485
deinterleave4.cpp
8586
device_buffer_copies_with_profile.cpp

0 commit comments

Comments
 (0)