Skip to content

Commit 1a0e8d3

Browse files
Support strided store with predicate in SVE2 (#9085)
* Support strided store with predicate in SVE2: handle the tail of a vector by using a predicate, while taking predicated stores into account
* Update strided store test in simd_op_check_sve2
* Ensure all the deinterleaved predicates are the same
1 parent 59553d7 commit 1a0e8d3

2 files changed

Lines changed: 64 additions & 10 deletions

File tree

src/CodeGen_ARM.cpp

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1518,9 +1518,7 @@ void CodeGen_ARM::visit(const Store *op) {
15181518
if (ramp && is_const_one(ramp->stride) &&
15191519
shuffle && shuffle->is_interleave() &&
15201520
type_ok_for_vst &&
1521-
2 <= shuffle->vectors.size() && shuffle->vectors.size() <= 4 &&
1522-
// TODO: we could handle predicated_store once shuffle_vector gets robust for scalable vectors
1523-
!is_predicated_store) {
1521+
2 <= shuffle->vectors.size() && shuffle->vectors.size() <= 4) {
15241522

15251523
const int num_vecs = shuffle->vectors.size();
15261524
vector<Value *> args(num_vecs);
@@ -1587,6 +1585,48 @@ void CodeGen_ARM::visit(const Store *op) {
15871585
// Scalable vector supports predication for smaller than whole vector size.
15881586
internal_assert(target_vscale() > 0 || (t.lanes() >= intrin_type.lanes()));
15891587

1588+
Value *vpred_predicated_store_val = nullptr;
1589+
vector<pair<string, Expr>> lets_pred;
1590+
if (is_sve && is_predicated_store) {
1591+
// Note the predicate asked by Store op is set as interleaved vectors,
1592+
// but what we want is the original one,
1593+
// so we need to either deinterleave or get the vector from the input of Shuffle.
1594+
// And we make sure the deinterleaved predicates are all the same.
1595+
1596+
// Dig through let expressions
1597+
Expr rhs = op->predicate;
1598+
while (const Let *let = rhs.as<Let>()) {
1599+
rhs = let->body;
1600+
lets_pred.emplace_back(let->name, let->value);
1601+
}
1602+
1603+
Expr vpred_predicated_store;
1604+
bool predicates_are_same = true;
1605+
const Shuffle *shuffle = rhs.as<Shuffle>();
1606+
if (shuffle && shuffle->is_interleave() && shuffle->vectors.size() == static_cast<size_t>(num_vecs)) {
1607+
vpred_predicated_store = shuffle->vectors[0];
1608+
for (int i = 1; i < num_vecs; ++i) {
1609+
predicates_are_same &= can_prove(vpred_predicated_store == shuffle->vectors[i]);
1610+
}
1611+
} else {
1612+
vpred_predicated_store = Shuffle::make_slice(op->predicate, 0, num_vecs, t.lanes());
1613+
for (int i = 1; i < num_vecs; ++i) {
1614+
predicates_are_same &= can_prove(vpred_predicated_store == Shuffle::make_slice(op->predicate, i, num_vecs, t.lanes()));
1615+
}
1616+
}
1617+
1618+
if (predicates_are_same) {
1619+
// Codegen the lets
1620+
for (auto &let : lets_pred) {
1621+
sym_push(let.first, codegen(let.second));
1622+
}
1623+
vpred_predicated_store_val = codegen(vpred_predicated_store);
1624+
} else {
1625+
CodeGen_Posix::visit(op);
1626+
return;
1627+
}
1628+
}
1629+
15901630
for (int i = 0; i < t.lanes(); i += intrin_type.lanes()) {
15911631
Expr slice_base = simplify(ramp->base + i * num_vecs);
15921632
Expr slice_ramp = Ramp::make(slice_base, ramp->stride, intrin_type.lanes() * num_vecs);
@@ -1606,11 +1646,22 @@ void CodeGen_ARM::visit(const Store *op) {
16061646
slice_args.push_back(ConstantInt::get(i32_t, alignment));
16071647
} else {
16081648
if (is_sve) {
1609-
// Set the predicate argument to mask active lanes
1649+
// Set the predicate argument
1650+
// Use predicate to deactivate tail if t.lanes() is not the multiple of intrin_type.lanes()
16101651
auto active_lanes = std::min(t.lanes() - i, intrin_type.lanes());
1611-
Expr vpred = make_vector_predicate_1s_0s(active_lanes, intrin_type.lanes() - active_lanes);
1612-
Value *vpred_val = codegen(vpred);
1613-
slice_args.push_back(vpred_val);
1652+
auto inactive_lanes = intrin_type.lanes() - active_lanes;
1653+
Value *vpred;
1654+
if (is_predicated_store) {
1655+
vpred = slice_vector(vpred_predicated_store_val, i, active_lanes);
1656+
if (inactive_lanes > 0) {
1657+
Value *tail = codegen(const_false(inactive_lanes));
1658+
vpred = concat_vectors({vpred, tail});
1659+
}
1660+
} else {
1661+
vpred = codegen(make_vector_predicate_1s_0s(active_lanes, inactive_lanes));
1662+
}
1663+
1664+
slice_args.push_back(vpred);
16141665
}
16151666
// Set the pointer argument
16161667
slice_args.push_back(ptr);
@@ -1632,6 +1683,9 @@ void CodeGen_ARM::visit(const Store *op) {
16321683
for (auto &let : lets) {
16331684
sym_pop(let.first);
16341685
}
1686+
for (auto &let : lets_pred) {
1687+
sym_pop(let.first);
1688+
}
16351689

16361690
return;
16371691
}

test/correctness/simd_op_check_sve2.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
765765

766766
// Also check when the two expressions interleaved have a common
767767
// subexpression, which results in a vector var being lifted out.
768-
for (int factor : {1, 2}) {
768+
for (float factor : {0.5f, 1.f, 2.f}) {
769769
const int width = base_vec_bits * 2 * factor;
770770
const int total_lanes = width / bits;
771771
const int vector_lanes = total_lanes / 2;
@@ -790,7 +790,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
790790
}
791791

792792
// ST3 - Store three-element structures
793-
for (int factor : {1, 2}) {
793+
for (float factor : {0.5f, 1.f, 2.f}) {
794794
const int width = base_vec_bits * 3 * factor;
795795
const int total_lanes = width / bits;
796796
const int vector_lanes = total_lanes / 3;
@@ -818,7 +818,7 @@ class SimdOpCheckArmSve : public SimdOpCheckTest {
818818
}
819819

820820
// ST4 - Store four-element structures
821-
for (int factor : {1, 2}) {
821+
for (float factor : {0.5f, 1.f, 2.f}) {
822822
const int width = base_vec_bits * 4 * factor;
823823
const int total_lanes = width / bits;
824824
const int vector_lanes = total_lanes / 4;

0 commit comments

Comments
 (0)