Skip to content

Commit 7cdfd64

Browse files
stevesuzuki-arm authored and alexreinking committed
Shuffle scalable vector in CodeGen_ARM
By design, LLVM shufflevector doesn't accept scalable vectors. So, we try to use llvm.vector.xx intrinsics where possible. However, those are not enough to cover the wide usage of shuffles in Halide. To handle arbitrary index patterns, we decompose a shuffle operation into a sequence of multiple native shuffles, which are lowered to the Arm SVE2 intrinsics TBL or TBL2. Another approach could be to perform the shuffle in a fixed-sized vector by adding conversion between scalable vectors and fixed vectors. However, that seems to be possible only via load/store through memory, which would presumably have poor performance.

This change also includes:
- Peep-hole the particular predicate pattern to emit the WHILELT instruction
- Shuffle 1-bit type scalable vectors as 8-bit with type casts
- Peep-hole concat_vectors for padding to align up a vector
- Fix redundant broadcast in CodeGen_LLVM
1 parent 263f6c6 commit 7cdfd64

6 files changed

Lines changed: 636 additions & 36 deletions

File tree

src/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,7 @@ target_sources(
276276
Debug.cpp
277277
DebugArguments.cpp
278278
DebugToFile.cpp
279+
DecomposeVectorShuffle.cpp
279280
Definition.cpp
280281
Deinterleave.cpp
281282
Derivative.cpp

src/CodeGen_ARM.cpp

Lines changed: 226 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "CodeGen_Posix.h"
77
#include "ConciseCasts.h"
88
#include "Debug.h"
9+
#include "DecomposeVectorShuffle.h"
910
#include "DistributeShifts.h"
1011
#include "IREquality.h"
1112
#include "IRMatch.h"
@@ -20,6 +21,7 @@
2021
namespace Halide {
2122
namespace Internal {
2223

24+
using std::optional;
2325
using std::ostringstream;
2426
using std::pair;
2527
using std::string;
@@ -217,6 +219,9 @@ class CodeGen_ARM : public CodeGen_Posix {
217219

218220
Value *interleave_vectors(const std::vector<Value *> &) override;
219221
Value *shuffle_vectors(Value *a, Value *b, const std::vector<int> &indices) override;
222+
Value *shuffle_scalable_vectors_general(Value *a, Value *b, const std::vector<int> &indices);
223+
Value *codegen_shuffle_indices(int bits, const std::vector<int> &indices);
224+
Value *codegen_whilelt(int total_lanes, int start, int end);
220225
void codegen_vector_reduce(const VectorReduce *, const Expr &) override;
221226
bool codegen_dot_product_vector_reduce(const VectorReduce *, const Expr &);
222227
bool codegen_pairwise_vector_reduce(const VectorReduce *, const Expr &);
@@ -237,6 +242,7 @@ class CodeGen_ARM : public CodeGen_Posix {
237242
};
238243
vector<Pattern> casts, calls, negations;
239244

245+
int natural_vector_size(const Halide::Type &t) const;
240246
string mcpu_target() const override;
241247
string mcpu_tune() const override;
242248
string mattrs() const override;
@@ -267,6 +273,37 @@ class CodeGen_ARM : public CodeGen_Posix {
267273
return Shuffle::make_concat({const_true(true_lanes), const_false(false_lanes)});
268274
}
269275
}
276+
277+
/** Handle general shuffle of vectors. See DecomposeVectorShuffle.h about how it works */
278+
struct VectorShuffler : public DecomposeVectorShuffle<VectorShuffler, Value *> {
279+
VectorShuffler(Value *src_a, Value *src_b, const vector<int> &indices, int vl, CodeGen_ARM &codegen)
280+
: DecomposeVectorShuffle(src_a, src_b, indices, vl), codegen(codegen) {
281+
}
282+
283+
int get_vec_length(Value *v) {
284+
return codegen.get_vector_num_elements(v->getType());
285+
}
286+
287+
Value *align_up_vector(Value *v, int align) {
288+
size_t org_len = get_vec_length(v);
289+
return codegen.slice_vector(v, 0, align_up(org_len, align));
290+
}
291+
292+
Value *slice_vec(Value *v, int start, size_t lanes) {
293+
return codegen.slice_vector(v, start, lanes);
294+
}
295+
296+
Value *concat_vecs(const vector<Value *> &vecs) {
297+
return codegen.concat_vectors(vecs);
298+
}
299+
300+
Value *shuffle_vl_aligned(Value *a, optional<Value *> &b, const vector<int> &indices, int vl) {
301+
return codegen.shuffle_scalable_vectors_general(a, b.value_or(nullptr), indices);
302+
}
303+
304+
private:
305+
CodeGen_ARM &codegen;
306+
};
270307
};
271308

272309
CodeGen_ARM::CodeGen_ARM(const Target &target)
@@ -1905,9 +1942,71 @@ void CodeGen_ARM::visit(const Shuffle *op) {
19051942

19061943
value = codegen_dense_vector_load(load, nullptr, /* slice_to_native */ false);
19071944
value = CodeGen_Posix::shuffle_vectors(value, op->indices);
1908-
} else {
1945+
return;
1946+
}
1947+
1948+
if (target_vscale() == 0) {
19091949
CodeGen_Posix::visit(op);
1950+
return;
1951+
}
1952+
1953+
const int total_lanes = op->type.lanes();
1954+
if (op->type.bits() == 1) {
1955+
// Peep-hole pattern that matches SVE "whilelt" which represents particular pattern of
1956+
// vector predicate. e.g. 11100000 (active_lanes=3, all_lanes=8)
1957+
if (op->is_concat() && op->vectors.size() == 2 &&
1958+
op->type.is_int_or_uint() &&
1959+
is_power_of_two(total_lanes) &&
1960+
total_lanes >= 2 * target_vscale() && total_lanes <= 16 * target_vscale() &&
1961+
is_const_one(op->vectors[0]) && is_const_zero(op->vectors[1])) {
1962+
1963+
int active_lanes = op->vectors[0].type().lanes();
1964+
value = codegen_whilelt(op->type.lanes(), 0, active_lanes);
1965+
return;
1966+
} else {
1967+
// Rewrite to process 1bit type vector as 8 bit vector, and then cast back
1968+
std::vector<Expr> vecs_i8;
1969+
vecs_i8.reserve(op->vectors.size());
1970+
for (const auto &vec_i1 : op->vectors) {
1971+
Type upgraded_type = vec_i1.type().with_bits(8);
1972+
vecs_i8.emplace_back(Cast::make(upgraded_type, vec_i1));
1973+
}
1974+
Expr equiv = Shuffle::make(vecs_i8, op->indices);
1975+
equiv = Cast::make(op->type, equiv);
1976+
equiv = common_subexpression_elimination(equiv);
1977+
value = codegen(equiv);
1978+
return;
1979+
}
1980+
} else if (op->is_concat() && op->vectors.size() == 2) {
1981+
// Here, we deal with some specific patterns of concat(a, b).
1982+
// Others are decomposed by CodeGen_LLVM at first,
1983+
// which in turn calls CodeGen_ARM::concat_vectors().
1984+
1985+
if (const Broadcast *bc_1 = op->vectors[1].as<Broadcast>()) {
1986+
// Common pattern where padding is appended to align lanes.
1987+
// Create broadcast of padding with dst lanes, then insert vec[0] at lane 0.
1988+
Value *val_0 = codegen(op->vectors[0]);
1989+
Value *val_1_scalar = codegen(bc_1->value);
1990+
Value *padding = builder->CreateVectorSplat(llvm::ElementCount::getScalable(total_lanes / target_vscale()), val_1_scalar);
1991+
value = insert_scalable_vector(padding, val_0, 0);
1992+
return;
1993+
}
1994+
} else if (op->is_broadcast()) {
1995+
// Undo simplification to avoid arbitrary-indexed shuffle
1996+
Expr equiv;
1997+
for (int f = 0; f < op->broadcast_factor(); ++f) {
1998+
if (equiv.defined()) {
1999+
equiv = Shuffle::make_concat({equiv, op->vectors[0]});
2000+
} else {
2001+
equiv = op->vectors[0];
2002+
}
2003+
}
2004+
equiv = common_subexpression_elimination(equiv);
2005+
value = codegen(equiv);
2006+
return;
19102007
}
2008+
2009+
CodeGen_Posix::visit(op);
19112010
}
19122011

19132012
llvm::Type *CodeGen_ARM::get_vector_type_from_value(Value *vec_or_scalar, int n) {
@@ -2110,52 +2209,139 @@ Value *CodeGen_ARM::shuffle_vectors(Value *a, Value *b, const std::vector<int> &
21102209
}
21112210

21122211
internal_assert(a->getType() == b->getType());
2212+
llvm::Type *src_type = a->getType();
2213+
llvm::Type *elt = get_vector_element_type(src_type);
2214+
const int bits = elt->getScalarSizeInBits();
2215+
// note: lanes are multiplied by vscale
2216+
const int natural_lanes = natural_vector_size(Int(bits));
2217+
const int src_lanes = get_vector_num_elements(src_type);
2218+
const int dst_lanes = indices.size();
2219+
2220+
if (src_type->isVectorTy()) {
2221+
// i1 -> shuffle with i8 -> i1
2222+
if (src_type->getScalarSizeInBits() == 1) {
2223+
internal_assert(src_type->isIntegerTy()) << "1 bit floating point type is unexpected\n";
2224+
a = builder->CreateIntCast(a, VectorType::get(i8_t, dyn_cast<llvm::VectorType>(src_type)), false);
2225+
b = builder->CreateIntCast(b, VectorType::get(i8_t, dyn_cast<llvm::VectorType>(src_type)), false);
2226+
Value *v = shuffle_vectors(a, b, indices);
2227+
return builder->CreateIntCast(v, VectorType::get(i1_t, dyn_cast<llvm::VectorType>(v->getType())), false);
2228+
}
2229+
2230+
// Check if deinterleaved slice
2231+
{
2232+
// Get the stride of slice
2233+
int slice_stride = 0;
2234+
const int start_index = indices[0];
2235+
if (dst_lanes > 1) {
2236+
const int stride = indices[1] - start_index;
2237+
bool stride_equal = true;
2238+
for (int i = 2; i < dst_lanes; ++i) {
2239+
stride_equal &= (indices[i] == start_index + i * stride);
2240+
}
2241+
slice_stride = stride_equal ? stride : 0;
2242+
}
2243+
2244+
// Lower slice with stride into llvm.vector.deinterleave intrinsic
2245+
const std::set<int> supported_strides{2, 3, 4, 8};
2246+
if (supported_strides.find(slice_stride) != supported_strides.end() &&
2247+
dst_lanes * slice_stride == src_lanes &&
2248+
indices.front() < slice_stride && // Start position cannot be larger than stride
2249+
is_power_of_two(dst_lanes) &&
2250+
dst_lanes % target_vscale() == 0 &&
2251+
dst_lanes / target_vscale() > 1) {
2252+
2253+
std::string instr = concat_strings("llvm.vector.deinterleave", slice_stride, mangle_llvm_type(a->getType()));
2254+
2255+
// We cannot mix FixedVector and ScalableVector, so dst_type must be scalable
2256+
llvm::Type *dst_type = get_vector_type(elt, dst_lanes / target_vscale(), VectorTypeConstraint::VScale);
2257+
StructType *sret_type = StructType::get(*context, std::vector(slice_stride, dst_type));
2258+
std::vector<llvm::Type *> arg_types{a->getType()};
2259+
llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false);
2260+
FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
2261+
2262+
CallInst *deinterleave = builder->CreateCall(fn, {a});
2263+
// extract one element out of the returned struct
2264+
Value *extracted = builder->CreateExtractValue(deinterleave, indices.front());
2265+
2266+
return extracted;
2267+
}
2268+
}
2269+
}
2270+
2271+
// Perform vector shuffle by decomposing the operation to multiple native shuffle steps
2272+
// which calls shuffle_scalable_vectors_general() which emits TBL/TBL2 instruction
2273+
VectorShuffler shuffler(a, b, indices, natural_lanes, *this);
2274+
Value *v = shuffler.shuffle();
2275+
return v;
2276+
}
21132277

2278+
Value *CodeGen_ARM::shuffle_scalable_vectors_general(Value *a, Value *b, const std::vector<int> &indices) {
21142279
llvm::Type *elt = get_vector_element_type(a->getType());
2280+
const int bits = elt->getScalarSizeInBits();
2281+
const int natural_lanes = natural_vector_size(Int(bits));
21152282
const int src_lanes = get_vector_num_elements(a->getType());
21162283
const int dst_lanes = indices.size();
2284+
llvm::Type *dst_type = get_vector_type(elt, dst_lanes);
21172285

2118-
// Check if deinterleaved slice
2119-
{
2120-
// Get the stride of slice
2121-
int slice_stride = 0;
2122-
const int start_index = indices[0];
2123-
if (dst_lanes > 1) {
2124-
const int stride = indices[1] - start_index;
2125-
bool stride_equal = true;
2126-
for (int i = 2; i < dst_lanes; ++i) {
2127-
stride_equal &= (indices[i] == start_index + i * stride);
2128-
}
2129-
slice_stride = stride_equal ? stride : 0;
2130-
}
2286+
internal_assert(target_vscale() > 0 && is_scalable_vector(a)) << "Only deal with scalable vectors\n";
2287+
internal_assert(src_lanes == natural_lanes && dst_lanes == natural_lanes)
2288+
<< "Only deal with vector with natural_lanes\n";
21312289

2132-
// Lower slice with stride into llvm.vector.deinterleave intrinsic
2133-
const std::set<int> supported_strides{2, 3, 4, 8};
2134-
if (supported_strides.find(slice_stride) != supported_strides.end() &&
2135-
dst_lanes * slice_stride == src_lanes &&
2136-
indices.front() < slice_stride && // Start position cannot be larger than stride
2137-
is_power_of_two(dst_lanes) &&
2138-
dst_lanes % target_vscale() == 0 &&
2139-
dst_lanes / target_vscale() > 1) {
2290+
// We select TBL or TBL2 intrinsic depending on indices range
2291+
bool use_tbl = *std::max_element(indices.begin(), indices.end()) < src_lanes;
2292+
internal_assert(use_tbl || b) << "'b' must be valid in case of tbl2\n";
21402293

2141-
std::string instr = concat_strings("llvm.vector.deinterleave", slice_stride, mangle_llvm_type(a->getType()));
2294+
auto instr = concat_strings("llvm.aarch64.sve.", use_tbl ? "tbl" : "tbl2", mangle_llvm_type(dst_type));
21422295

2143-
// We cannot mix FixedVector and ScalableVector, so dst_type must be scalable
2144-
llvm::Type *dst_type = get_vector_type(elt, dst_lanes / target_vscale(), VectorTypeConstraint::VScale);
2145-
StructType *sret_type = StructType::get(*context, std::vector(slice_stride, dst_type));
2146-
std::vector<llvm::Type *> arg_types{a->getType()};
2147-
llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false);
2148-
FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
2296+
Value *val_indices = codegen_shuffle_indices(bits, indices);
2297+
llvm::Type *vt_natural = get_vector_type(elt, natural_lanes);
2298+
std::vector<llvm::Type *> llvm_arg_types;
2299+
std::vector<llvm::Value *> llvm_arg_vals;
2300+
if (use_tbl) {
2301+
llvm_arg_types = {vt_natural, val_indices->getType()};
2302+
llvm_arg_vals = {a, val_indices};
2303+
} else {
2304+
llvm_arg_types = {vt_natural, vt_natural, val_indices->getType()};
2305+
llvm_arg_vals = {a, b, val_indices};
2306+
}
2307+
llvm::FunctionType *fn_type = FunctionType::get(vt_natural, llvm_arg_types, false);
2308+
FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
21492309

2150-
CallInst *deinterleave = builder->CreateCall(fn, {a});
2151-
// extract one element out of the returned struct
2152-
Value *extracted = builder->CreateExtractValue(deinterleave, indices.front());
2310+
Value *v = builder->CreateCall(fn, llvm_arg_vals);
2311+
return v;
2312+
}
21532313

2154-
return extracted;
2155-
}
2314+
Value *CodeGen_ARM::codegen_shuffle_indices(int bits, const std::vector<int> &indices) {
2315+
const int lanes = indices.size();
2316+
llvm::Type *index_type = IntegerType::get(module->getContext(), bits);
2317+
llvm::Type *index_vec_type = get_vector_type(index_type, lanes);
2318+
2319+
std::vector<Constant *> llvm_indices(lanes);
2320+
for (int i = 0; i < lanes; i++) {
2321+
int idx = indices[i];
2322+
llvm_indices[i] = idx >= 0 ? ConstantInt::get(index_type, idx) : UndefValue::get(index_type);
21562323
}
21572324

2158-
return CodeGen_Posix::shuffle_vectors(a, b, indices);
2325+
Value *v = ConstantVector::get(llvm_indices);
2326+
v = builder->CreateInsertVector(index_vec_type, UndefValue::get(index_vec_type),
2327+
v, ConstantInt::get(i64_t, 0));
2328+
return v;
2329+
}
2330+
2331+
Value *CodeGen_ARM::codegen_whilelt(int total_lanes, int start, int end) {
2332+
// Generates SVE "whilelt" instruction which represents vector predicate pattern of
2333+
// e.g. 11100000 (total_lanes = 8 , start = 0, end = 3)
2334+
// -> @llvm.aarch64.sve.whilelt.nxv8i1.i32(i32 0, i32 3)
2335+
internal_assert(target_vscale() > 0);
2336+
internal_assert(total_lanes % target_vscale() == 0);
2337+
std::string instr = concat_strings("llvm.aarch64.sve.whilelt.nxv", total_lanes / target_vscale(), "i1.i32");
2338+
2339+
llvm::Type *pred_type = get_vector_type(llvm_type_of(Int(1)), total_lanes);
2340+
llvm::FunctionType *fn_type = FunctionType::get(pred_type, {i32_t, i32_t}, false);
2341+
FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
2342+
2343+
value = builder->CreateCall(fn, {ConstantInt::get(i32_t, start), ConstantInt::get(i32_t, end)});
2344+
return value;
21592345
}
21602346

21612347
void CodeGen_ARM::visit(const Ramp *op) {
@@ -2579,6 +2765,11 @@ Type CodeGen_ARM::upgrade_type_for_storage(const Type &t) const {
25792765
return CodeGen_Posix::upgrade_type_for_storage(t);
25802766
}
25812767

2768+
int CodeGen_ARM::natural_vector_size(const Halide::Type &t) const {
2769+
internal_assert(t.bits() > 1) << "natural_vector_size requested with 1 bits\n";
2770+
return native_vector_bits() / t.bits();
2771+
}
2772+
25822773
string CodeGen_ARM::mcpu_target() const {
25832774
if (target.bits == 32) {
25842775
if (target.has_feature(Target::ARMv7s)) {

src/CodeGen_LLVM.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4128,7 +4128,9 @@ void CodeGen_LLVM::visit(const Shuffle *op) {
41284128
} else {
41294129
internal_assert(op->indices[0] == 0);
41304130
}
4131-
value = create_broadcast(value, op->indices.size());
4131+
if (op->indices.size() > 1) {
4132+
value = create_broadcast(value, op->indices.size());
4133+
}
41324134
return;
41334135
}
41344136
}

0 commit comments

Comments
 (0)