Skip to content

Commit 4e4a1c1

Browse files
committed
arena: don't drop the inner result perm for owning ToT Scale contractions
Commit 34711c8 made the ToT x scalar Scale contraction always hand an identity inner perm to make_contraction_arena_plan, so the plan is built and the perm-free fused scale op is selected, with the inner result perm applied downstream by op_'s post-processing permute. That is correct for view (arena) inner cells -- make_total_perm carries the full perm for them -- but is_contraction_arena_tot_v is also true for owning legacy TA::Tensor ToT inner cells, and for those make_total_perm carries only the outer perm. So an owning ToT Scale contraction with a non-identity inner result permutation lost the inner perm entirely: identity plan + perm-free op + outer-only total_perm, producing a wrong-inner-layout result (and, under distributed eval, a malformed result whose deferred destruction aborted at a later fence). Mirror the make_total_perm view/owning split here: pass an identity inner perm only for view cells; for owning cells pass inner(perm_) so the plan bails (nullopt) on a non-identity inner perm and the per-cell op applies it, exactly as before 34711c8. Restores einsum_manual/ different_nested_ranks and einsum_tot_t/ilkj_nm_eq_ij_mn_times_kl.
1 parent 34711c8 commit 4e4a1c1

1 file changed

Lines changed: 21 additions & 7 deletions

File tree

src/TiledArray/expressions/cont_engine.h

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -967,16 +967,30 @@ class ContEngine : public BinaryEngine<Derived> {
967967
result_tile_type, left_tile_type, right_tile_type>;
968968
if constexpr (arena_eligible_scale) {
969969
if (this->product_type() == TensorProduct::Contraction) {
970-
// Pass an identity inner perm: a non-identity inner *result*
971-
// permutation is applied downstream by op_'s post-processing
972-
// permute (carried in make_total_perm for view cells), not by
973-
// the per-cell op -- so the plan must not bail on it. The plan
974-
// pre-shapes result cells in the unpermuted (operand) inner
975-
// layout; the perm-free fused scale op accumulates into them.
970+
// The inner perm handed to the plan must match how the inner
971+
// *result* permutation is applied for this result cell type --
972+
// and the two cell types apply it in different places:
973+
//
974+
// * View (arena) cells: pass an identity inner perm so the
975+
// plan is always built (pre-shaping result cells in the
976+
// unpermuted operand inner layout) and the perm-free fused
977+
// scale op is selected; the inner result perm is applied
978+
// downstream by op_'s post-processing permute (carried in
979+
// make_total_perm for view cells).
980+
//
981+
// * Owning cells: pass the inner result perm so the plan bails
982+
// (nullopt) on a non-identity inner perm, falling back to the
983+
// per-cell op that applies the inner perm itself -- matching
984+
// the outer-only total_perm make_total_perm carries here.
985+
// (A trivial inner perm still lets the plan + fused op run.)
986+
Permutation plan_inner_perm;
987+
if constexpr (!TiledArray::is_tensor_view_v<
988+
result_tile_element_type>)
989+
plan_inner_perm = inner(this->perm_);
976990
this->arena_plan_ =
977991
TiledArray::detail::make_contraction_arena_plan<
978992
result_tile_type, left_tile_type, right_tile_type>(
979-
kind, std::nullopt, Permutation{});
993+
kind, std::nullopt, plan_inner_perm);
980994
}
981995
}
982996
// Fallback per-element op for the scale inner-product when no

0 commit comments

Comments
 (0)