Skip to content

Commit 75573cf

Browse files
authored
Merge pull request #555 from ValeevGroup/evaleev/fix/arena-tot-permuted-binary-ops
arena: support permuted Hadamard add/subt/mult on Tensor<ArenaTensor> ToT
2 parents 00996ce + 4e4a1c1 commit 75573cf

2 files changed

Lines changed: 81 additions & 38 deletions

File tree

src/TiledArray/expressions/cont_engine.h

Lines changed: 45 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -305,9 +305,18 @@ class ContEngine : public BinaryEngine<Derived> {
305305
// this->product_type() is Tensor::Contraction, and,
306306
// this->implicit_permute_inner_ is false
307307

308-
return this->inner_product_type() == TensorProduct::Scale
309-
? BipartitePermutation(outer(this->perm_))
310-
: this->perm_;
308+
if (this->inner_product_type() == TensorProduct::Scale) {
309+
// Owning inner cells apply the inner result permutation in the
310+
// per-cell scale op, so they carry only the outer perm here. View
311+
// (arena) cells instead use a perm-free per-cell op + an unpermuted
312+
// arena plan and rely on op_'s post-processing permute for the
313+
// inner perm -- so they carry the full perm, like inner
314+
// Contraction.
315+
if constexpr (!TiledArray::is_tensor_view_v<
316+
result_tile_element_type>)
317+
return BipartitePermutation(outer(this->perm_));
318+
}
319+
return this->perm_;
311320
};
312321

313322
auto total_perm = make_total_perm();
@@ -341,9 +350,18 @@ class ContEngine : public BinaryEngine<Derived> {
341350
// this->product_type() is Tensor::Contraction, and,
342351
// this->implicit_permute_inner_ is false
343352

344-
return this->inner_product_type() == TensorProduct::Scale
345-
? BipartitePermutation(outer(this->perm_))
346-
: this->perm_;
353+
if (this->inner_product_type() == TensorProduct::Scale) {
354+
// Owning inner cells apply the inner result permutation in the
355+
// per-cell scale op, so they carry only the outer perm here. View
356+
// (arena) cells instead use a perm-free per-cell op + an unpermuted
357+
// arena plan and rely on op_'s post-processing permute for the
358+
// inner perm -- so they carry the full perm, like inner
359+
// Contraction.
360+
if constexpr (!TiledArray::is_tensor_view_v<
361+
result_tile_element_type>)
362+
return BipartitePermutation(outer(this->perm_));
363+
}
364+
return this->perm_;
347365
};
348366

349367
auto total_perm = make_total_perm();
@@ -949,10 +967,30 @@ class ContEngine : public BinaryEngine<Derived> {
949967
result_tile_type, left_tile_type, right_tile_type>;
950968
if constexpr (arena_eligible_scale) {
951969
if (this->product_type() == TensorProduct::Contraction) {
970+
// The inner perm handed to the plan must match how the inner
971+
// *result* permutation is applied for this result cell type --
972+
// and the two cell types apply it in different places:
973+
//
974+
// * View (arena) cells: pass an identity inner perm so the
975+
// plan is always built (pre-shaping result cells in the
976+
// unpermuted operand inner layout) and the perm-free fused
977+
// scale op is selected; the inner result perm is applied
978+
// downstream by op_'s post-processing permute (carried in
979+
// make_total_perm for view cells).
980+
//
981+
// * Owning cells: pass the inner result perm so the plan bails
982+
// (nullopt) on a non-identity inner perm, falling back to the
983+
// per-cell op that applies the inner perm itself -- matching
984+
// the outer-only total_perm make_total_perm carries here.
985+
// (A trivial inner perm still lets the plan + fused op run.)
986+
Permutation plan_inner_perm;
987+
if constexpr (!TiledArray::is_tensor_view_v<
988+
result_tile_element_type>)
989+
plan_inner_perm = inner(this->perm_);
952990
this->arena_plan_ =
953991
TiledArray::detail::make_contraction_arena_plan<
954992
result_tile_type, left_tile_type, right_tile_type>(
955-
kind, std::nullopt, inner(this->perm_));
993+
kind, std::nullopt, plan_inner_perm);
956994
}
957995
}
958996
// Fallback per-element op for the scale inner-product when no

src/TiledArray/tensor/tensor.h

Lines changed: 36 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -1953,33 +1953,28 @@ class Tensor {
19531953
return !static_cast<bool>(perm) || perm.is_identity();
19541954
}
19551955

1956-
/// Permuted add for `Tensor<ArenaTensor>` ToT operands. A non-trivial
1957-
/// permutation of arena ToT tiles is not yet supported; an identity (or
1958-
/// null) permutation falls through to the plain element-wise add.
1956+
/// Permuted add for `Tensor<ArenaTensor>` ToT operands. The operands are
1957+
/// congruent by the time a permuted product reaches a tile op, so the
1958+
/// elementwise `add(right)` is valid and `perm` is the result permutation;
1959+
/// `permute` applies it (shallow outer reindex + inner-slab rewrite).
19591960
template <typename Right, typename Perm>
19601961
requires(is_arena_tensor_v<value_type> &&
19611962
is_arena_tensor_v<typename Right::value_type> &&
19621963
detail::is_permutation_v<Perm>)
19631964
Tensor add(const Right& right, const Perm& perm) const {
1964-
if (!arena_perm_is_trivial(perm))
1965-
TA_EXCEPTION(
1966-
"TA::Tensor<ArenaTensor>::add: permuted add of a tensor-of-tensors "
1967-
"is not yet supported");
1968-
return add(right);
1965+
auto result = add(right);
1966+
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
19691967
}
19701968

19711969
/// Permuted scaled add for `Tensor<ArenaTensor>` ToT operands; see the
1972-
/// permuted-add overload above for the permutation restriction.
1970+
/// permuted-add overload above for the congruent-operand rationale.
19731971
template <typename Right, typename Scalar, typename Perm>
19741972
requires(is_arena_tensor_v<value_type> &&
19751973
is_arena_tensor_v<typename Right::value_type> &&
19761974
detail::is_numeric_v<Scalar> && detail::is_permutation_v<Perm>)
19771975
Tensor add(const Right& right, const Scalar factor, const Perm& perm) const {
1978-
if (!arena_perm_is_trivial(perm))
1979-
TA_EXCEPTION(
1980-
"TA::Tensor<ArenaTensor>::add: permuted scaled add of a "
1981-
"tensor-of-tensors is not yet supported");
1982-
return add(right, factor);
1976+
auto result = add(right, factor);
1977+
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
19831978
}
19841979

19851980
/// Add this and \c other to construct a new tensor
@@ -2382,8 +2377,15 @@ class Tensor {
23822377
typename std::enable_if<is_tensor<Right>::value &&
23832378
detail::is_permutation_v<Perm>>::type* = nullptr>
23842379
Tensor subt(const Right& right, const Perm& perm) const {
2385-
if constexpr (is_tensor_view_v<value_type>) {
2386-
// Permutation isn't supported for view inner cells (fixed storage
2380+
if constexpr (is_arena_tensor_v<value_type> &&
2381+
is_arena_tensor_v<typename Right::value_type>) {
2382+
// arena ToT x arena ToT: operands are congruent at tile-op time, so the
2383+
// elementwise `subt(right)` is valid; apply the result permutation as a
2384+
// post-pass (shallow outer reindex + inner-slab rewrite).
2385+
auto result = subt(right);
2386+
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
2387+
} else if constexpr (is_tensor_view_v<value_type>) {
2388+
// Permutation isn't supported for other view inner cells (fixed storage
23872389
// layout). Subt+permute would require materialization.
23882390
TA_EXCEPTION(
23892391
"Tensor<View>::subt(right, perm): permutation is not "
@@ -2443,11 +2445,10 @@ class Tensor {
24432445
Tensor subt(const Right& right, const Scalar factor, const Perm& perm) const {
24442446
if constexpr (is_arena_tensor_v<value_type> &&
24452447
is_arena_tensor_v<typename Right::value_type>) {
2446-
if (!arena_perm_is_trivial(perm))
2447-
TA_EXCEPTION(
2448-
"TA::Tensor<ArenaTensor>::subt: permuted scaled subt of a "
2449-
"tensor-of-tensors is not yet supported");
2450-
return subt(right, factor);
2448+
// arena ToT x arena ToT scaled subtraction; see the unscaled permuted
2449+
// subt overload above for the congruent-operand rationale.
2450+
auto result = subt(right, factor);
2451+
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
24512452
} else {
24522453
return binary(
24532454
right,
@@ -2622,11 +2623,15 @@ class Tensor {
26222623
decltype(auto) mult(const Right& right, const Perm& perm) const {
26232624
if constexpr (is_arena_tensor_v<value_type> &&
26242625
is_arena_tensor_v<typename Right::value_type>) {
2625-
if (!arena_perm_is_trivial(perm))
2626-
TA_EXCEPTION(
2627-
"TA::Tensor<ArenaTensor>::mult: permuted mult of a "
2628-
"tensor-of-tensors is not yet supported");
2629-
return mult(right);
2626+
// arena ToT x arena ToT Hadamard product. By the time a permuted product
2627+
// reaches a tile op, the engine has already brought both operands to a
2628+
// common (congruent) layout, so the elementwise `mult(right)` is valid;
2629+
// `perm` is the result permutation (common layout -> target). Apply it
2630+
// as a post-pass: `permute` reindexes the outer cells shallowly
2631+
// (arena_permute_shallow) and rewrites the inner slab if the inner part
2632+
// of the permutation is non-trivial (arena_inner_permute).
2633+
auto result = mult(right);
2634+
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
26302635
} else if constexpr (detail::is_numeric_v<value_type> &&
26312636
is_arena_tensor_v<typename Right::value_type>) {
26322637
// t x tot: a plain scalar tile times an arena ToT tile. The 2-arg
@@ -2697,11 +2702,11 @@ class Tensor {
26972702
const Perm& perm) const {
26982703
if constexpr (is_arena_tensor_v<value_type> &&
26992704
is_arena_tensor_v<typename Right::value_type>) {
2700-
if (!arena_perm_is_trivial(perm))
2701-
TA_EXCEPTION(
2702-
"TA::Tensor<ArenaTensor>::mult: permuted scaled mult of a "
2703-
"tensor-of-tensors is not yet supported");
2704-
return mult(right, factor);
2705+
// arena ToT x arena ToT scaled Hadamard product; see the unscaled
2706+
// permuted mult overload above for the congruent-operand rationale.
2707+
// Scale during the elementwise product, then permute the result.
2708+
auto result = mult(right, factor);
2709+
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
27052710
} else {
27062711
return binary(
27072712
right,

0 commit comments

Comments
 (0)