Skip to content

Commit 9372a80

Browse files
committed
codegen: fix ARM SVE2 lowering and i1 vector handling
1 parent 613d127 commit 9372a80

6 files changed

Lines changed: 153 additions & 4 deletions

File tree

src/CodeGen_ARM.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,10 @@ Target complete_arm_target(Target t) {
5656
}
5757
};
5858

59+
// ARMFp16 implies ARMv8.2-A; we don't know of any devices where
60+
// that doesn't hold. The cascade loop below will set ARMv81a and ARMv8a.
61+
add_implied_feature_if_supported(t, Target::ARMFp16, Target::ARMv82a);
62+
5963
constexpr int num_arm_v8_features = 10;
6064
static const Target::Feature arm_v8_features[num_arm_v8_features] = {
6165
Target::ARMv89a,
@@ -1681,6 +1685,7 @@ void CodeGen_ARM::visit(const Store *op) {
16811685
vpred_val = convert_fixed_or_scalable_vector_type(vpred_val, pred_type);
16821686
if (is_predicated_store) {
16831687
Value *sliced_store_vpred_val = slice_vector(store_pred_val, i, natural_lanes);
1688+
sliced_store_vpred_val = convert_fixed_or_scalable_vector_type(sliced_store_vpred_val, pred_type);
16841689
vpred_val = builder->CreateAnd(vpred_val, sliced_store_vpred_val);
16851690
}
16861691

@@ -1854,6 +1859,7 @@ void CodeGen_ARM::visit(const Load *op) {
18541859
Value *vpred_val = codegen(vpred);
18551860
if (is_predicated_load) {
18561861
Value *sliced_load_vpred_val = slice_vector(load_pred_val, i, natural_lanes);
1862+
sliced_load_vpred_val = convert_fixed_or_scalable_vector_type(sliced_load_vpred_val, vpred_val->getType());
18571863
vpred_val = builder->CreateAnd(vpred_val, sliced_load_vpred_val);
18581864
}
18591865

@@ -1904,8 +1910,14 @@ Value *CodeGen_ARM::interleave_vectors(const std::vector<Value *> &vecs) {
19041910
return CodeGen_Posix::interleave_vectors(vecs);
19051911
}
19061912

1907-
// Lower into llvm.vector.interleave intrinsic
1913+
// Lower into llvm.vector.interleave intrinsic.
1914+
// LLVM only supports non-power-of-2 strides (e.g. 3) for scalable
1915+
// vectors starting in LLVM 22.
1916+
#if LLVM_VERSION >= 220
19081917
const std::set<int> supported_strides{2, 3, 4, 8};
1918+
#else
1919+
const std::set<int> supported_strides{2, 4, 8};
1920+
#endif
19091921
const int stride = vecs.size();
19101922
const int src_lanes = get_vector_num_elements(vecs[0]->getType());
19111923

@@ -1957,7 +1969,11 @@ Value *CodeGen_ARM::shuffle_vectors(Value *a, Value *b, const std::vector<int> &
19571969
}
19581970

19591971
// Lower slice with stride into llvm.vector.deinterleave intrinsic
1972+
#if LLVM_VERSION >= 220
19601973
const std::set<int> supported_strides{2, 3, 4, 8};
1974+
#else
1975+
const std::set<int> supported_strides{2, 4, 8};
1976+
#endif
19611977
if (supported_strides.find(slice_stride) != supported_strides.end() &&
19621978
dst_lanes * slice_stride == src_lanes &&
19631979
indices.front() < slice_stride && // Start position cannot be larger than stride
@@ -2410,6 +2426,10 @@ string CodeGen_ARM::mcpu_target() const {
24102426
if (target.bits == 32) {
24112427
if (target.has_feature(Target::ARMv7s)) {
24122428
return "swift";
2429+
} else if (target.has_feature(Target::ARMv82a)) {
2430+
return "cortex-a55";
2431+
} else if (target.has_feature(Target::ARMv8a)) {
2432+
return "cortex-a32";
24132433
} else {
24142434
return "cortex-a9";
24152435
}
@@ -2436,7 +2456,10 @@ string CodeGen_ARM::mattrs() const {
24362456
attrs.emplace_back("+fullfp16");
24372457
}
24382458
if (target.has_feature(Target::ARMv8a)) {
2439-
attrs.emplace_back("+v8a");
2459+
// The ARM (32-bit) backend calls this feature "v8"; the AArch64
2460+
// backend calls it "v8a". The dotted sub-versions (v8.1a, v8.2a,
2461+
// etc.) use the same names in both backends.
2462+
attrs.emplace_back(target.bits == 32 ? "+v8" : "+v8a");
24402463
}
24412464
if (target.has_feature(Target::ARMv81a)) {
24422465
attrs.emplace_back("+v8.1a");

src/CodeGen_LLVM.cpp

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1515,6 +1515,17 @@ void CodeGen_LLVM::visit(const Reinterpret *op) {
15151515
llvm::Type *llvm_dst_fixed = get_vector_type(llvm_type_of(dst.element_of()), dst.lanes(), VectorTypeConstraint::Fixed);
15161516
value = builder->CreateBitOrPointerCast(value, llvm_dst_fixed);
15171517
value = fixed_to_scalable_vector_type(value);
1518+
} else if (isa<FixedVectorType>(value->getType()) && isa<ScalableVectorType>(llvm_dst)) {
1519+
// Cannot bitcast/ptrtoint directly between fixed and scalable vectors.
1520+
// First cast to a fixed vector of the destination element type, then convert to scalable.
1521+
llvm::Type *llvm_dst_fixed = get_vector_type(llvm_dst->getScalarType(), dst.lanes(), VectorTypeConstraint::Fixed);
1522+
value = builder->CreateBitOrPointerCast(value, llvm_dst_fixed);
1523+
value = fixed_to_scalable_vector_type(value);
1524+
} else if (isa<ScalableVectorType>(value->getType()) && isa<FixedVectorType>(llvm_dst)) {
1525+
// Cannot bitcast/ptrtoint directly between scalable and fixed vectors.
1526+
// First convert to a fixed vector of the source element type, then cast.
1527+
value = scalable_to_fixed_vector_type(value);
1528+
value = builder->CreateBitOrPointerCast(value, llvm_dst);
15181529
} else {
15191530
// Our `Reinterpret` expr directly maps to LLVM IR bitcast/ptrtoint/inttoptr
15201531
// instructions with no additional handling required:
@@ -4314,10 +4325,12 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini
43144325
const int input_lanes = val.type().lanes();
43154326
const int input_bytes = input_lanes * val.type().bytes();
43164327
const int vscale = std::max(effective_vscale, 1);
4328+
// LLVM added VECREDUCE_MUL/FMUL lowering for SVE in LLVM 22.
4329+
const bool mul_ok = LLVM_VERSION >= 220 || effective_vscale == 0;
43174330
const bool llvm_has_intrinsic =
43184331
// Must be one of these ops
43194332
((op->op == VectorReduce::Add ||
4320-
op->op == VectorReduce::Mul ||
4333+
(op->op == VectorReduce::Mul && mul_ok) ||
43214334
op->op == VectorReduce::Min ||
43224335
op->op == VectorReduce::Max) &&
43234336
(use_llvm_vp_intrinsics ||
@@ -4920,6 +4933,13 @@ Value *CodeGen_LLVM::slice_vector(Value *vec, int start, int size) {
49204933
// otherwise.
49214934
llvm::Type *scalar_type = vec->getType()->getScalarType();
49224935

4936+
if (scalar_type->isIntegerTy(1)) {
4937+
auto *result_type = cast<VectorType>(get_vector_type(scalar_type, size / effective_vscale, VectorTypeConstraint::VScale));
4938+
return handle_bool_as_i8(vec, result_type, [&](Value *v) {
4939+
return slice_vector(v, start, size);
4940+
});
4941+
}
4942+
49234943
int intermediate_lanes = std::min(size, vec_lanes - start);
49244944
llvm::Type *intermediate_type = get_vector_type(scalar_type, intermediate_lanes, VectorTypeConstraint::Fixed);
49254945

@@ -5190,6 +5210,18 @@ llvm::Value *CodeGen_LLVM::match_vector_type_scalable(llvm::Value *value, llvm::
51905210
return match_vector_type_scalable(value, guide->getType());
51915211
}
51925212

5213+
// Runs fn on an i8-widened copy of the i1 vector `arg` and narrows fn's
// result to result_i1_type.
//   arg            - value to widen; the casts below require its type to be
//                    a fixed or scalable vector.
//   result_i1_type - vector type produced by the final truncation.
//   fn             - callback invoked once with the widened value.
// Returns the truncation of fn's result to result_i1_type.
llvm::Value *CodeGen_LLVM::handle_bool_as_i8(llvm::Value *arg, llvm::VectorType *result_i1_type,
5214+
const std::function<llvm::Value *(llvm::Value *)> &fn) {
5215+
auto *arg_vty = cast<llvm::VectorType>(arg->getType());
5216+
// The widened type keeps the input's kind (fixed vs. scalable) and its
// element count (the minimum element count for a scalable vector).
bool scalable = isa<llvm::ScalableVectorType>(arg_vty);
5217+
int min_elts = scalable ? cast<llvm::ScalableVectorType>(arg_vty)->getMinNumElements() : cast<llvm::FixedVectorType>(arg_vty)->getNumElements();
5218+
auto constraint = scalable ? VectorTypeConstraint::VScale : VectorTypeConstraint::Fixed;
5219+
llvm::Type *arg_i8 = get_vector_type(i8_t, min_elts, constraint);
5220+
// Zero-extend each lane to the i8_t element type so fn never receives an
// i1 vector.
llvm::Value *widened = builder->CreateZExt(arg, arg_i8);
5221+
llvm::Value *result = fn(widened);
5222+
// NOTE(review): fn presumably returns an integer vector with as many lanes
// as result_i1_type, which the truncation needs — confirm at call sites.
return builder->CreateTrunc(result, result_i1_type);
5223+
}
5224+
51935225
llvm::Value *CodeGen_LLVM::convert_fixed_or_scalable_vector_type(llvm::Value *arg,
51945226
llvm::Type *desired_type) {
51955227
llvm::Type *arg_type = arg->getType();
@@ -5199,6 +5231,18 @@ llvm::Value *CodeGen_LLVM::convert_fixed_or_scalable_vector_type(llvm::Value *ar
51995231
}
52005232

52015233
internal_assert(arg_type->getScalarType() == desired_type->getScalarType());
5234+
5235+
if (arg_type->isVectorTy() && desired_type->isVectorTy() &&
5236+
arg_type->getScalarType()->isIntegerTy(1)) {
5237+
bool dst_scalable = isa<llvm::ScalableVectorType>(desired_type);
5238+
int dst_elts = get_vector_num_elements(desired_type);
5239+
llvm::Type *dst_i8 = get_vector_type(i8_t, dst_scalable ? dst_elts / effective_vscale : dst_elts,
5240+
dst_scalable ? VectorTypeConstraint::VScale : VectorTypeConstraint::Fixed);
5241+
return handle_bool_as_i8(arg, cast<VectorType>(desired_type), [&](Value *v) {
5242+
return convert_fixed_or_scalable_vector_type(v, dst_i8);
5243+
});
5244+
}
5245+
52025246
if (!arg_type->isVectorTy()) {
52035247
arg = create_broadcast(arg, 1);
52045248
arg_type = arg->getType();
@@ -5280,6 +5324,12 @@ llvm::Value *CodeGen_LLVM::fixed_to_scalable_vector_type(llvm::Value *fixed_arg)
52805324
internal_assert(fixed_type->getElementType() == scalable_type->getElementType());
52815325
internal_assert(lanes == (scalable_type->getMinNumElements() * effective_vscale));
52825326

5327+
if (fixed_type->getElementType()->isIntegerTy(1)) {
5328+
return handle_bool_as_i8(fixed_arg, scalable_type, [&](Value *v) {
5329+
return fixed_to_scalable_vector_type(v);
5330+
});
5331+
}
5332+
52835333
// E.g. <vscale x 2 x i64> llvm.vector.insert.nxv2i64.v4i64(<vscale x 2 x i64>, <4 x i64>, i64)
52845334
const char *type_designator;
52855335
if (fixed_type->getElementType()->isIntegerTy()) {
@@ -5297,7 +5347,7 @@ llvm::Value *CodeGen_LLVM::fixed_to_scalable_vector_type(llvm::Value *fixed_arg)
52975347

52985348
std::vector<llvm::Value *> args;
52995349
args.push_back(result_vec);
5300-
args.push_back(value);
5350+
args.push_back(fixed_arg);
53015351
args.push_back(ConstantInt::get(i64_t, 0));
53025352

53035353
return simple_call_intrin(intrin, args, scalable_type);
@@ -5316,6 +5366,12 @@ llvm::Value *CodeGen_LLVM::scalable_to_fixed_vector_type(llvm::Value *scalable_a
53165366
internal_assert(fixed_type->getElementType() == scalable_type->getElementType());
53175367
internal_assert(fixed_type->getNumElements() == (scalable_type->getMinNumElements() * effective_vscale));
53185368

5369+
if (scalable_type->getElementType()->isIntegerTy(1)) {
5370+
return handle_bool_as_i8(scalable_arg, fixed_type, [&](Value *v) {
5371+
return scalable_to_fixed_vector_type(v);
5372+
});
5373+
}
5374+
53195375
// E.g. <64 x i8> @llvm.vector.extract.v64i8.nxv8i8(<vscale x 8 x i8> %vresult, i64 0)
53205376
const char *type_designator;
53215377
if (scalable_type->getElementType()->isIntegerTy()) {

src/CodeGen_LLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ class NamedMDNode;
3131
class DataLayout;
3232
class BasicBlock;
3333
class GlobalVariable;
34+
class VectorType;
3435
} // namespace llvm
3536

37+
#include <functional>
3638
#include <map>
3739
#include <memory>
3840
#include <optional>
@@ -589,6 +591,14 @@ class CodeGen_LLVM : public IRVisitor {
589591
/** Convert an LLVM vscale vector value to the corresponding fixed vector value. */
590592
llvm::Value *scalable_to_fixed_vector_type(llvm::Value *scalable);
591593

594+
/** Work around LLVM's inability to lower vector insert/extract for i1
595+
* element types (getVectorSubVecPointer computes byte offsets via integer
596+
* division, truncating for i1: 1/8=0). Widens the i1 vector arg to i8,
597+
* applies fn to the widened value, and truncates the result back to
598+
* result_i1_type. */
599+
llvm::Value *handle_bool_as_i8(llvm::Value *arg, llvm::VectorType *result_i1_type,
600+
const std::function<llvm::Value *(llvm::Value *)> &fn);
601+
592602
/** Get number of vector elements, taking into account scalable vectors. Returns 1 for scalars. */
593603
int get_vector_num_elements(const llvm::Type *t);
594604

test/correctness/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,8 @@ tests(GROUPS correctness
319319
strict_float.cpp
320320
strict_float_bounds.cpp
321321
strided_load.cpp
322+
sve_codegen_predicated.cpp
323+
sve_codegen_reinterpret.cpp
322324
target.cpp
323325
target_query.cpp
324326
tiled_matmul.cpp

test/correctness/sve_codegen_predicated.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#include "Halide.h"
2+
#include "halide_test_dirs.h"
3+
4+
#include <cstdio>
5+
#include <string>
6+
7+
using namespace Halide;
8+
9+
// Compile-only check: defines one vectorized Func and emits an object file
// for an SVE2 target. The pipeline is never run here, so reaching the final
// printf only shows that compile_to_object returned.
int main(int argc, char **argv) {
10+
const Target sve2("arm-64-linux-arm_dot_prod-arm_fp16-sve2-vector_bits_128");
11+
std::string tmpdir = Internal::get_test_tmp_dir();
12+
13+
// Dense stores with non-natural lane counts force predicate tail masking.
14+
// The predicate is a boolean (i1) vector that must be converted from fixed
15+
// to scalable, which previously triggered an LLVM assertion in
16+
// getVectorSubVecPointer ("Converting bits to bytes lost precision")
17+
// because the byte offset computation truncates for i1 (1/8=0).
18+
Func f("dense_pred_store");
19+
Var x("x");
20+
f(x) = cast<uint8_t>(x * 2);
21+
f.vectorize(x, 24); // 24 is not a multiple of 16 (natural for uint8 @ 128-bit SVE)
22+
// Cross-compiles for the sve2 target above; the object lands in the test
// temp dir. The empty {} is presumably the pipeline's argument list.
f.compile_to_object(tmpdir + "sve_dense_pred_store.o", {}, "dense_pred_store", sve2);
23+
24+
printf("Success!\n");
25+
return 0;
26+
}

test/correctness/sve_codegen_reinterpret.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#include "Halide.h"
2+
#include "halide_test_dirs.h"
3+
4+
#include <cstdio>
5+
#include <string>
6+
7+
using namespace Halide;
8+
9+
// Compile-only check: builds a three-stage pipeline whose first stage holds
// handle (pointer) values and emits an object file for an SVE2 target. The
// pipeline is never run here, so reaching the final printf only shows that
// compile_to_object returned.
int main(int argc, char **argv) {
10+
const Target sve2("arm-64-linux-arm_dot_prod-arm_fp16-sve2-vector_bits_128");
11+
std::string tmpdir = Internal::get_test_tmp_dir();
12+
13+
// Reinterpret between Handle (pointer) and integer types with vectorization.
14+
// Pointers produce fixed vectors (<4 x ptr>) while the integer destination
15+
// may be scalable (<vscale x 4 x i64>), requiring conversion before the
16+
// cast. Previously triggered ConstantExpr::getCast ("Invalid constantexpr
17+
// cast!") because CreateBitOrPointerCast cannot operate across fixed and
18+
// scalable vector types, and fixed_to_scalable_vector_type passed the wrong
19+
// value to the llvm.vector.insert intrinsic.
20+
std::string msg = "hello!\n";
21+
// f produces the handle values and is vectorized by 4; g copies f at root;
// h copies g and is the stage that gets compiled.
Func f("handle_cast"), g("copy"), h("out");
22+
Var x("x");
23+
f(x) = cast<char *>(msg);
24+
f.compute_root().vectorize(x, 4);
25+
g(x) = f(x);
26+
g.compute_root();
27+
h(x) = g(x);
28+
// Cross-compiles for the sve2 target above; the object lands in the test
// temp dir. The empty {} is presumably the pipeline's argument list.
h.compile_to_object(tmpdir + "sve_handle_cast.o", {}, "handle_cast", sve2);
29+
30+
printf("Success!\n");
31+
return 0;
32+
}

0 commit comments

Comments
 (0)