@@ -1518,9 +1518,7 @@ void CodeGen_ARM::visit(const Store *op) {
15181518 if (ramp && is_const_one (ramp->stride ) &&
15191519 shuffle && shuffle->is_interleave () &&
15201520 type_ok_for_vst &&
1521- 2 <= shuffle->vectors .size () && shuffle->vectors .size () <= 4 &&
1522- // TODO: we could handle predicated_store once shuffle_vector gets robust for scalable vectors
1523- !is_predicated_store) {
1521+ 2 <= shuffle->vectors .size () && shuffle->vectors .size () <= 4 ) {
15241522
15251523 const int num_vecs = shuffle->vectors .size ();
15261524 vector<Value *> args (num_vecs);
@@ -1587,6 +1585,48 @@ void CodeGen_ARM::visit(const Store *op) {
15871585 // Scalable vector supports predication for smaller than whole vector size.
15881586 internal_assert (target_vscale () > 0 || (t.lanes () >= intrin_type.lanes ()));
15891587
1588+ Value *vpred_predicated_store_val = nullptr ;
1589+ vector<pair<string, Expr>> lets_pred;
1590+ if (is_sve && is_predicated_store) {
1591+ // Note the predicate asked by Store op is set as interleaved vectors,
1592+ // but what we want is the original one,
1593+ // so we need to either deinterleave or get the vector from the input of Shuffle.
1594+ // And we make sure the deinterleaved predicates are all the same.
1595+
1596+ // Dig through let expressions
1597+ Expr rhs = op->predicate ;
1598+ while (const Let *let = rhs.as <Let>()) {
1599+ rhs = let->body ;
1600+ lets_pred.emplace_back (let->name , let->value );
1601+ }
1602+
1603+ Expr vpred_predicated_store;
1604+ bool predicates_are_same = true ;
1605+ const Shuffle *shuffle = rhs.as <Shuffle>();
1606+ if (shuffle && shuffle->is_interleave () && shuffle->vectors .size () == static_cast <size_t >(num_vecs)) {
1607+ vpred_predicated_store = shuffle->vectors [0 ];
1608+ for (int i = 1 ; i < num_vecs; ++i) {
1609+ predicates_are_same &= can_prove (vpred_predicated_store == shuffle->vectors [i]);
1610+ }
1611+ } else {
1612+ vpred_predicated_store = Shuffle::make_slice (op->predicate , 0 , num_vecs, t.lanes ());
1613+ for (int i = 1 ; i < num_vecs; ++i) {
1614+ predicates_are_same &= can_prove (vpred_predicated_store == Shuffle::make_slice (op->predicate , i, num_vecs, t.lanes ()));
1615+ }
1616+ }
1617+
1618+ if (predicates_are_same) {
1619+ // Codegen the lets
1620+ for (auto &let : lets_pred) {
1621+ sym_push (let.first , codegen (let.second ));
1622+ }
1623+ vpred_predicated_store_val = codegen (vpred_predicated_store);
1624+ } else {
1625+ CodeGen_Posix::visit (op);
1626+ return ;
1627+ }
1628+ }
1629+
15901630 for (int i = 0 ; i < t.lanes (); i += intrin_type.lanes ()) {
15911631 Expr slice_base = simplify (ramp->base + i * num_vecs);
15921632 Expr slice_ramp = Ramp::make (slice_base, ramp->stride , intrin_type.lanes () * num_vecs);
@@ -1606,11 +1646,22 @@ void CodeGen_ARM::visit(const Store *op) {
16061646 slice_args.push_back (ConstantInt::get (i32_t , alignment));
16071647 } else {
16081648 if (is_sve) {
1609- // Set the predicate argument to mask active lanes
1649+ // Set the predicate argument
1650+ // Use predicate to deactivate tail if t.lanes() is not the multiple of intrin_type.lanes()
16101651 auto active_lanes = std::min (t.lanes () - i, intrin_type.lanes ());
1611- Expr vpred = make_vector_predicate_1s_0s (active_lanes, intrin_type.lanes () - active_lanes);
1612- Value *vpred_val = codegen (vpred);
1613- slice_args.push_back (vpred_val);
1652+ auto inactive_lanes = intrin_type.lanes () - active_lanes;
1653+ Value *vpred;
1654+ if (is_predicated_store) {
1655+ vpred = slice_vector (vpred_predicated_store_val, i, active_lanes);
1656+ if (inactive_lanes > 0 ) {
1657+ Value *tail = codegen (const_false (inactive_lanes));
1658+ vpred = concat_vectors ({vpred, tail});
1659+ }
1660+ } else {
1661+ vpred = codegen (make_vector_predicate_1s_0s (active_lanes, inactive_lanes));
1662+ }
1663+
1664+ slice_args.push_back (vpred);
16141665 }
16151666 // Set the pointer argument
16161667 slice_args.push_back (ptr);
@@ -1632,6 +1683,9 @@ void CodeGen_ARM::visit(const Store *op) {
16321683 for (auto &let : lets) {
16331684 sym_pop (let.first );
16341685 }
1686+ for (auto &let : lets_pred) {
1687+ sym_pop (let.first );
1688+ }
16351689
16361690 return ;
16371691 }
0 commit comments