Commit 34711c8

arena: support inner result permutation for ToT x scalar Scale contraction

A mixed inner-Scale product (Tensor<ArenaTensor> ToT x plain Tensor -> ToT) under an outer Contraction with a non-identity inner *result* permutation crashed in ArenaTensor::axpy_to(other, factor, perm), which rejects in-place permutation of view cells.

The Scale path pushed the inner perm onto the per-cell op (via the fallback axpy_to(..., perm)) and dropped it from total_perm, while make_contraction_arena_plan bailed on any non-identity inner perm -- leaving the view result cell both unshaped and asked to permute itself.

Mirror the inner-Contraction view handling instead:
- For view (arena) result cells, carry the full perm in total_perm so op_'s post-processing permute applies the inner result perm as a slab-level rewrite.
- Pass an identity inner perm to make_contraction_arena_plan so it builds the plan (pre-shaping result cells unpermuted) and selects the perm-free fused scale op.

Owning inner cells keep applying the inner perm in the per-cell scale op (outer-only total_perm), unchanged.
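The selection rule described in the commit message can be sketched in isolation. This is a minimal illustration, not TiledArray code: `InnerProduct`, `OwningCell`, `ViewCell`, `is_view_v` and `BiPerm` are hypothetical stand-ins for the real types, and an empty `inner` vector is assumed here to mean "no inner permutation".

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <vector>

// Hypothetical stand-ins; these are NOT TiledArray's real types.
enum class InnerProduct { Scale, Contraction };

struct OwningCell {};  // inner cell that owns its storage
struct ViewCell {};    // inner cell that is a view into an arena slab

template <typename Cell>
inline constexpr bool is_view_v = std::is_same_v<Cell, ViewCell>;

// A bipartite permutation modelled as an outer and an inner index map.
// Assumption of this sketch: an empty `inner` means "no inner perm".
struct BiPerm {
  std::vector<std::size_t> outer;
  std::vector<std::size_t> inner;
};

// Chooses the permutation that is left for the post-processing step.
template <typename Cell>
BiPerm make_total_perm(const BiPerm& perm, InnerProduct inner_kind) {
  if (inner_kind == InnerProduct::Scale) {
    if constexpr (!is_view_v<Cell>) {
      // Owning cells: the per-cell scale op applies the inner perm,
      // so only the outer part remains for post-processing.
      return BiPerm{perm.outer, {}};
    }
  }
  // View cells (and inner Contraction): post-processing applies the
  // full permutation, inner part included.
  return perm;
}
```

With a permutation whose outer and inner parts are both `{1, 0}`, `make_total_perm<OwningCell>(p, InnerProduct::Scale)` keeps only the outer part, while `make_total_perm<ViewCell>(p, InnerProduct::Scale)` returns `p` unchanged. The compile-time branch means the owning-cell return is never instantiated for view cells.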
1 parent 6330596 commit 34711c8

1 file changed

Lines changed: 31 additions & 7 deletions

src/TiledArray/expressions/cont_engine.h

@@ -305,9 +305,18 @@ class ContEngine : public BinaryEngine<Derived> {
         // this->product_type() is Tensor::Contraction, and,
         // this->implicit_permute_inner_ is false

-        return this->inner_product_type() == TensorProduct::Scale
-                   ? BipartitePermutation(outer(this->perm_))
-                   : this->perm_;
+        if (this->inner_product_type() == TensorProduct::Scale) {
+          // Owning inner cells apply the inner result permutation in the
+          // per-cell scale op, so they carry only the outer perm here. View
+          // (arena) cells instead use a perm-free per-cell op + an unpermuted
+          // arena plan and rely on op_'s post-processing permute for the
+          // inner perm -- so they carry the full perm, like inner
+          // Contraction.
+          if constexpr (!TiledArray::is_tensor_view_v<
+                            result_tile_element_type>)
+            return BipartitePermutation(outer(this->perm_));
+        }
+        return this->perm_;
       };

       auto total_perm = make_total_perm();
@@ -341,9 +350,18 @@ class ContEngine : public BinaryEngine<Derived> {
         // this->product_type() is Tensor::Contraction, and,
         // this->implicit_permute_inner_ is false

-        return this->inner_product_type() == TensorProduct::Scale
-                   ? BipartitePermutation(outer(this->perm_))
-                   : this->perm_;
+        if (this->inner_product_type() == TensorProduct::Scale) {
+          // Owning inner cells apply the inner result permutation in the
+          // per-cell scale op, so they carry only the outer perm here. View
+          // (arena) cells instead use a perm-free per-cell op + an unpermuted
+          // arena plan and rely on op_'s post-processing permute for the
+          // inner perm -- so they carry the full perm, like inner
+          // Contraction.
+          if constexpr (!TiledArray::is_tensor_view_v<
+                            result_tile_element_type>)
+            return BipartitePermutation(outer(this->perm_));
+        }
+        return this->perm_;
       };

       auto total_perm = make_total_perm();
@@ -949,10 +967,16 @@ class ContEngine : public BinaryEngine<Derived> {
               result_tile_type, left_tile_type, right_tile_type>;
       if constexpr (arena_eligible_scale) {
         if (this->product_type() == TensorProduct::Contraction) {
+          // Pass an identity inner perm: a non-identity inner *result*
+          // permutation is applied downstream by op_'s post-processing
+          // permute (carried in make_total_perm for view cells), not by
+          // the per-cell op -- so the plan must not bail on it. The plan
+          // pre-shapes result cells in the unpermuted (operand) inner
+          // layout; the perm-free fused scale op accumulates into them.
           this->arena_plan_ =
               TiledArray::detail::make_contraction_arena_plan<
                   result_tile_type, left_tile_type, right_tile_type>(
-                  kind, std::nullopt, inner(this->perm_));
+                  kind, std::nullopt, Permutation{});
         }
       }
       // Fallback per-element op for the scale inner-product when no
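
The change relies on two routes giving the same result: permuting inside each per-cell scale-accumulate, or accumulating into an unpermuted cell and permuting once afterwards. The toy below illustrates that equivalence for a 2-D cell with a transpose as the inner permutation. It is a sketch under stated assumptions, not TiledArray's implementation: `Cell`, `axpy`, `axpy_transposed`, `transpose` and the two `accumulate_*` helpers are hypothetical names, and `terms` is assumed to be non-empty with identically shaped cells.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy cell: a dense row-major rows x cols matrix.
struct Cell {
  std::size_t rows, cols;
  std::vector<double> data;
  Cell(std::size_t r, std::size_t c) : rows(r), cols(c), data(r * c, 0.0) {}
  double& at(std::size_t i, std::size_t j) { return data[i * cols + j]; }
  double at(std::size_t i, std::size_t j) const { return data[i * cols + j]; }
};

// Perm-free fused scale: result += factor * src, same layout on both sides.
void axpy(Cell& result, const Cell& src, double factor) {
  for (std::size_t k = 0; k < src.data.size(); ++k)
    result.data[k] += factor * src.data[k];
}

// Per-cell permuted scale: result(j, i) += factor * src(i, j).
void axpy_transposed(Cell& result, const Cell& src, double factor) {
  for (std::size_t i = 0; i < src.rows; ++i)
    for (std::size_t j = 0; j < src.cols; ++j)
      result.at(j, i) += factor * src.at(i, j);
}

// Post-processing permute: writes a new, transposed cell.
Cell transpose(const Cell& c) {
  Cell out(c.cols, c.rows);
  for (std::size_t i = 0; i < c.rows; ++i)
    for (std::size_t j = 0; j < c.cols; ++j) out.at(j, i) = c.at(i, j);
  return out;
}

// Route A (owning-cell style): permute inside every per-cell op.
Cell accumulate_per_cell_permuted(const std::vector<Cell>& terms,
                                  const std::vector<double>& factors) {
  Cell result(terms.front().cols, terms.front().rows);
  for (std::size_t t = 0; t < terms.size(); ++t)
    axpy_transposed(result, terms[t], factors[t]);
  return result;
}

// Route B (view-cell style): accumulate unpermuted, then permute once.
Cell accumulate_then_permute(const std::vector<Cell>& terms,
                             const std::vector<double>& factors) {
  Cell result(terms.front().rows, terms.front().cols);  // pre-shaped, unpermuted
  for (std::size_t t = 0; t < terms.size(); ++t)
    axpy(result, terms[t], factors[t]);
  return transpose(result);
}
```

Route B never asks a result cell to permute itself in place, which is the operation the commit message reports `ArenaTensor::axpy_to(other, factor, perm)` rejecting for view cells. The permutation instead produces a fresh output after all contributions have been accumulated.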
