#include "CodeGen_Posix.h"
#include "ConciseCasts.h"
#include "Debug.h"
+ #include "DecomposeVectorShuffle.h"
#include "DistributeShifts.h"
#include "IREquality.h"
#include "IRMatch.h"
namespace Halide {
namespace Internal {

+ using std::optional;
using std::ostringstream;
using std::pair;
using std::string;
@@ -217,6 +219,9 @@ class CodeGen_ARM : public CodeGen_Posix {

    Value *interleave_vectors(const std::vector<Value *> &) override;
    Value *shuffle_vectors(Value *a, Value *b, const std::vector<int> &indices) override;
+     Value *shuffle_scalable_vectors_general(Value *a, Value *b, const std::vector<int> &indices);
+     Value *codegen_shuffle_indices(int bits, const std::vector<int> &indices);
+     Value *codegen_whilelt(int total_lanes, int start, int end);
    void codegen_vector_reduce(const VectorReduce *, const Expr &) override;
    bool codegen_dot_product_vector_reduce(const VectorReduce *, const Expr &);
    bool codegen_pairwise_vector_reduce(const VectorReduce *, const Expr &);
@@ -237,6 +242,7 @@ class CodeGen_ARM : public CodeGen_Posix {
    };
    vector<Pattern> casts, calls, negations;

+     int natural_vector_size(const Halide::Type &t) const;
    string mcpu_target() const override;
    string mcpu_tune() const override;
    string mattrs() const override;
@@ -267,6 +273,37 @@ class CodeGen_ARM : public CodeGen_Posix {
            return Shuffle::make_concat({const_true(true_lanes), const_false(false_lanes)});
        }
    }
+
+     /** Handle a general shuffle of vectors. See DecomposeVectorShuffle.h for how it works. */
+     struct VectorShuffler : public DecomposeVectorShuffle<VectorShuffler, Value *> {
+         VectorShuffler(Value *src_a, Value *src_b, const vector<int> &indices, int vl, CodeGen_ARM &codegen)
+             : DecomposeVectorShuffle(src_a, src_b, indices, vl), codegen(codegen) {
+         }
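+
+         // The methods below are the callbacks DecomposeVectorShuffle invokes to slice, pad,
+         // and recombine vectors while it breaks an arbitrary shuffle into vl-aligned native steps.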
+
+         int get_vec_length(Value *v) {
+             return codegen.get_vector_num_elements(v->getType());
+         }
+
+         Value *align_up_vector(Value *v, int align) {
+             size_t org_len = get_vec_length(v);
+             return codegen.slice_vector(v, 0, align_up(org_len, align));
+         }
+
+         Value *slice_vec(Value *v, int start, size_t lanes) {
+             return codegen.slice_vector(v, start, lanes);
+         }
+
+         Value *concat_vecs(const vector<Value *> &vecs) {
+             return codegen.concat_vectors(vecs);
+         }
+
+         Value *shuffle_vl_aligned(Value *a, optional<Value *> &b, const vector<int> &indices, int vl) {
+             return codegen.shuffle_scalable_vectors_general(a, b.value_or(nullptr), indices);
+         }
+
+     private:
+         CodeGen_ARM &codegen;
+     };
};

CodeGen_ARM::CodeGen_ARM(const Target &target)
@@ -1905,9 +1942,71 @@ void CodeGen_ARM::visit(const Shuffle *op) {

        value = codegen_dense_vector_load(load, nullptr, /*slice_to_native*/ false);
        value = CodeGen_Posix::shuffle_vectors(value, op->indices);
-     } else {
+         return;
+     }
+
+     if (target_vscale() == 0) {
        CodeGen_Posix::visit(op);
+         return;
+     }
+
+     const int total_lanes = op->type.lanes();
+     if (op->type.bits() == 1) {
+         // Peephole pattern that matches the SVE "whilelt" instruction, which represents a
+         // particular vector predicate pattern, e.g. 11100000 (active_lanes=3, all_lanes=8).
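+         // codegen_whilelt() below lowers this pattern to the @llvm.aarch64.sve.whilelt intrinsic.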
+         if (op->is_concat() && op->vectors.size() == 2 &&
+             op->type.is_int_or_uint() &&
+             is_power_of_two(total_lanes) &&
+             total_lanes >= 2 * target_vscale() && total_lanes <= 16 * target_vscale() &&
+             is_const_one(op->vectors[0]) && is_const_zero(op->vectors[1])) {
+
+             int active_lanes = op->vectors[0].type().lanes();
+             value = codegen_whilelt(op->type.lanes(), 0, active_lanes);
+             return;
+         } else {
+             // Rewrite to process the 1-bit vector as an 8-bit vector, then cast back.
+             std::vector<Expr> vecs_i8;
+             vecs_i8.reserve(op->vectors.size());
+             for (const auto &vec_i1 : op->vectors) {
+                 Type upgraded_type = vec_i1.type().with_bits(8);
+                 vecs_i8.emplace_back(Cast::make(upgraded_type, vec_i1));
+             }
+             Expr equiv = Shuffle::make(vecs_i8, op->indices);
+             equiv = Cast::make(op->type, equiv);
+             equiv = common_subexpression_elimination(equiv);
+             value = codegen(equiv);
+             return;
+         }
+     } else if (op->is_concat() && op->vectors.size() == 2) {
+         // Here we deal with some specific patterns of concat(a, b).
+         // Others are decomposed by CodeGen_LLVM first, which in turn
+         // calls CodeGen_ARM::concat_vectors().
+
+         if (const Broadcast *bc_1 = op->vectors[1].as<Broadcast>()) {
+             // Common pattern where padding is appended to align lanes.
+             // Create a broadcast of the padding value with dst lanes, then insert vec[0] at lane 0.
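+             // e.g. an 8-lane result formed as concat(v0, broadcast(pad, 4)): splat pad across
+             // all 8 lanes, then write the 4-lane v0 into lanes 0..3.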
+             Value *val_0 = codegen(op->vectors[0]);
+             Value *val_1_scalar = codegen(bc_1->value);
+             Value *padding = builder->CreateVectorSplat(llvm::ElementCount::getScalable(total_lanes / target_vscale()), val_1_scalar);
+             value = insert_scalable_vector(padding, val_0, 0);
+             return;
+         }
+     } else if (op->is_broadcast()) {
+         // Undo simplification to avoid arbitrary-indexed shuffle
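+         // e.g. a broadcast with factor 2 of a 4-lane vector is rebuilt as concat(v, v)
+         // rather than a shuffle with indices {0, 1, 2, 3, 0, 1, 2, 3}.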
+         Expr equiv;
+         for (int f = 0; f < op->broadcast_factor(); ++f) {
+             if (equiv.defined()) {
+                 equiv = Shuffle::make_concat({equiv, op->vectors[0]});
+             } else {
+                 equiv = op->vectors[0];
+             }
+         }
+         equiv = common_subexpression_elimination(equiv);
+         value = codegen(equiv);
+         return;
    }
+
+     CodeGen_Posix::visit(op);
}

llvm::Type *CodeGen_ARM::get_vector_type_from_value(Value *vec_or_scalar, int n) {
@@ -2110,52 +2209,139 @@ Value *CodeGen_ARM::shuffle_vectors(Value *a, Value *b, const std::vector<int> &
    }

    internal_assert(a->getType() == b->getType());
+     llvm::Type *src_type = a->getType();
+     llvm::Type *elt = get_vector_element_type(src_type);
+     const int bits = elt->getScalarSizeInBits();
+     // note: lanes are multiplied by vscale
+     const int natural_lanes = natural_vector_size(Int(bits));
+     const int src_lanes = get_vector_num_elements(src_type);
+     const int dst_lanes = indices.size();
+
+     if (src_type->isVectorTy()) {
+         // i1 -> shuffle with i8 -> i1
+         if (src_type->getScalarSizeInBits() == 1) {
+             internal_assert(src_type->isIntegerTy()) << "1 bit floating point type is unexpected\n";
+             a = builder->CreateIntCast(a, VectorType::get(i8_t, dyn_cast<llvm::VectorType>(src_type)), false);
+             b = builder->CreateIntCast(b, VectorType::get(i8_t, dyn_cast<llvm::VectorType>(src_type)), false);
+             Value *v = shuffle_vectors(a, b, indices);
+             return builder->CreateIntCast(v, VectorType::get(i1_t, dyn_cast<llvm::VectorType>(v->getType())), false);
+         }
+
+         // Check if deinterleaved slice
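+         // e.g. indices {1, 3, 5, 7} taken from an 8-lane source form a stride-2 slice starting at lane 1.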
+         {
+             // Get the stride of slice
+             int slice_stride = 0;
+             const int start_index = indices[0];
+             if (dst_lanes > 1) {
+                 const int stride = indices[1] - start_index;
+                 bool stride_equal = true;
+                 for (int i = 2; i < dst_lanes; ++i) {
+                     stride_equal &= (indices[i] == start_index + i * stride);
+                 }
+                 slice_stride = stride_equal ? stride : 0;
+             }
+
+             // Lower slice with stride into llvm.vector.deinterleave intrinsic
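+             // That intrinsic splits the source into slice_stride sub-vectors, each holding every
+             // slice_stride-th element, returned as a struct; member indices.front() is the slice we want.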
+             const std::set<int> supported_strides{2, 3, 4, 8};
+             if (supported_strides.find(slice_stride) != supported_strides.end() &&
+                 dst_lanes * slice_stride == src_lanes &&
+                 indices.front() < slice_stride &&  // Start position cannot be larger than stride
+                 is_power_of_two(dst_lanes) &&
+                 dst_lanes % target_vscale() == 0 &&
+                 dst_lanes / target_vscale() > 1) {
+
+                 std::string instr = concat_strings("llvm.vector.deinterleave", slice_stride, mangle_llvm_type(a->getType()));
+
+                 // We cannot mix FixedVector and ScalableVector, so dst_type must be scalable
+                 llvm::Type *dst_type = get_vector_type(elt, dst_lanes / target_vscale(), VectorTypeConstraint::VScale);
+                 StructType *sret_type = StructType::get(*context, std::vector(slice_stride, dst_type));
+                 std::vector<llvm::Type *> arg_types{a->getType()};
+                 llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false);
+                 FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
+
+                 CallInst *deinterleave = builder->CreateCall(fn, {a});
+                 // extract one element out of the returned struct
+                 Value *extracted = builder->CreateExtractValue(deinterleave, indices.front());
+
+                 return extracted;
+             }
+         }
+     }
+
+     // Perform the vector shuffle by decomposing the operation into multiple native-width shuffle
+     // steps, each of which calls shuffle_scalable_vectors_general() to emit a TBL/TBL2 instruction.
+     VectorShuffler shuffler(a, b, indices, natural_lanes, *this);
+     Value *v = shuffler.shuffle();
+     return v;
+ }

+ Value *CodeGen_ARM::shuffle_scalable_vectors_general(Value *a, Value *b, const std::vector<int> &indices) {
    llvm::Type *elt = get_vector_element_type(a->getType());
+     const int bits = elt->getScalarSizeInBits();
+     const int natural_lanes = natural_vector_size(Int(bits));
    const int src_lanes = get_vector_num_elements(a->getType());
    const int dst_lanes = indices.size();
+     llvm::Type *dst_type = get_vector_type(elt, dst_lanes);

-     // Check if deinterleaved slice
-     {
-         // Get the stride of slice
-         int slice_stride = 0;
-         const int start_index = indices[0];
-         if (dst_lanes > 1) {
-             const int stride = indices[1] - start_index;
-             bool stride_equal = true;
-             for (int i = 2; i < dst_lanes; ++i) {
-                 stride_equal &= (indices[i] == start_index + i * stride);
-             }
-             slice_stride = stride_equal ? stride : 0;
-         }
+     internal_assert(target_vscale() > 0 && is_scalable_vector(a)) << "Only deal with scalable vectors\n";
+     internal_assert(src_lanes == natural_lanes && dst_lanes == natural_lanes)
+         << "Only deal with vector with natural_lanes\n";

-         // Lower slice with stride into llvm.vector.deinterleave intrinsic
-         const std::set<int> supported_strides{2, 3, 4, 8};
-         if (supported_strides.find(slice_stride) != supported_strides.end() &&
-             dst_lanes * slice_stride == src_lanes &&
-             indices.front() < slice_stride &&  // Start position cannot be larger than stride
-             is_power_of_two(dst_lanes) &&
-             dst_lanes % target_vscale() == 0 &&
-             dst_lanes / target_vscale() > 1) {
+     // We select the TBL or TBL2 intrinsic depending on the index range
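+     // sve.tbl gathers elements from a single source vector by per-lane index; sve.tbl2 looks up
+     // the concatenation of two source vectors, so it is needed whenever an index reaches into 'b'.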
+     bool use_tbl = *std::max_element(indices.begin(), indices.end()) < src_lanes;
+     internal_assert(use_tbl || b) << "'b' must be valid in case of tbl2\n";

-             std::string instr = concat_strings("llvm.vector.deinterleave", slice_stride, mangle_llvm_type(a->getType()));
+     auto instr = concat_strings("llvm.aarch64.sve.", use_tbl ? "tbl" : "tbl2", mangle_llvm_type(dst_type));

-             // We cannot mix FixedVector and ScalableVector, so dst_type must be scalable
-             llvm::Type *dst_type = get_vector_type(elt, dst_lanes / target_vscale(), VectorTypeConstraint::VScale);
-             StructType *sret_type = StructType::get(*context, std::vector(slice_stride, dst_type));
-             std::vector<llvm::Type *> arg_types{a->getType()};
-             llvm::FunctionType *fn_type = FunctionType::get(sret_type, arg_types, false);
-             FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
+     Value *val_indices = codegen_shuffle_indices(bits, indices);
+     llvm::Type *vt_natural = get_vector_type(elt, natural_lanes);
+     std::vector<llvm::Type *> llvm_arg_types;
+     std::vector<llvm::Value *> llvm_arg_vals;
+     if (use_tbl) {
+         llvm_arg_types = {vt_natural, val_indices->getType()};
+         llvm_arg_vals = {a, val_indices};
+     } else {
+         llvm_arg_types = {vt_natural, vt_natural, val_indices->getType()};
+         llvm_arg_vals = {a, b, val_indices};
+     }
+     llvm::FunctionType *fn_type = FunctionType::get(vt_natural, llvm_arg_types, false);
+     FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);

-             CallInst *deinterleave = builder->CreateCall(fn, {a});
-             // extract one element out of the returned struct
-             Value *extracted = builder->CreateExtractValue(deinterleave, indices.front());
+     Value *v = builder->CreateCall(fn, llvm_arg_vals);
+     return v;
+ }

-             return extracted;
-         }
+ Value *CodeGen_ARM::codegen_shuffle_indices(int bits, const std::vector<int> &indices) {
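+     // Materialize the constant shuffle indices as a vector value for TBL/TBL2: build a
+     // fixed-length constant vector (negative indices become undef lanes) and insert it at
+     // position 0 of an otherwise-undef scalable vector.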
+     const int lanes = indices.size();
+     llvm::Type *index_type = IntegerType::get(module->getContext(), bits);
+     llvm::Type *index_vec_type = get_vector_type(index_type, lanes);
+
+     std::vector<Constant *> llvm_indices(lanes);
+     for (int i = 0; i < lanes; i++) {
+         int idx = indices[i];
+         llvm_indices[i] = idx >= 0 ? ConstantInt::get(index_type, idx) : UndefValue::get(index_type);
    }

-     return CodeGen_Posix::shuffle_vectors(a, b, indices);
+     Value *v = ConstantVector::get(llvm_indices);
+     v = builder->CreateInsertVector(index_vec_type, UndefValue::get(index_vec_type),
+                                     v, ConstantInt::get(i64_t, 0));
+     return v;
+ }
+
+ Value *CodeGen_ARM::codegen_whilelt(int total_lanes, int start, int end) {
+     // Generates the SVE "whilelt" instruction, which represents a vector predicate pattern,
+     // e.g. 11100000 (total_lanes = 8, start = 0, end = 3)
+     // -> @llvm.aarch64.sve.whilelt.nxv8i1.i32(i32 0, i32 3)
+     internal_assert(target_vscale() > 0);
+     internal_assert(total_lanes % target_vscale() == 0);
+     std::string instr = concat_strings("llvm.aarch64.sve.whilelt.nxv", total_lanes / target_vscale(), "i1.i32");
+
+     llvm::Type *pred_type = get_vector_type(llvm_type_of(Int(1)), total_lanes);
+     llvm::FunctionType *fn_type = FunctionType::get(pred_type, {i32_t, i32_t}, false);
+     FunctionCallee fn = module->getOrInsertFunction(instr, fn_type);
+
+     value = builder->CreateCall(fn, {ConstantInt::get(i32_t, start), ConstantInt::get(i32_t, end)});
+     return value;
}

void CodeGen_ARM::visit(const Ramp *op) {
@@ -2579,6 +2765,11 @@ Type CodeGen_ARM::upgrade_type_for_storage(const Type &t) const {
    return CodeGen_Posix::upgrade_type_for_storage(t);
}

+ int CodeGen_ARM::natural_vector_size(const Halide::Type &t) const {
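+     // As noted where this is used in shuffle_vectors(), the returned lane count is in terms of
+     // the full vector width, i.e. it already includes the vscale factor.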
+     internal_assert(t.bits() > 1) << "natural_vector_size requested with 1 bits\n";
+     return native_vector_bits() / t.bits();
+ }
+
string CodeGen_ARM::mcpu_target() const {
    if (target.bits == 32) {
        if (target.has_feature(Target::ARMv7s)) {