Skip to content

Commit 6330596

Browse files
committed
arena: support permuted Hadamard add/subt/mult on Tensor<ArenaTensor> ToT
The permuted arena ToT x arena ToT overloads of add, subt, and mult (scaled and unscaled) previously threw "permuted ... of a tensor-of-tensors is not yet supported". This blocked CSV/PNO-based coupled-cluster, whose residual evaluates permuted ToT Hadamard products at the tile-op level (a binary Mult/Add op calling left.mult(right, perm) etc.). By the time a permuted product, sum, or difference reaches a tile op, the expression engine has already brought both operands to a common (congruent) layout, so the elementwise product/sum/difference is valid and perm is purely the result permutation. Compute the unpermuted result, then apply perm as a post-pass via permute(), which already handles arena ToT: a shallow outer-cell reindex (arena_permute_shallow) plus an inner-slab rewrite (arena_inner_permute) when the bipartite permutation's inner part is non-trivial. This mirrors the existing numeric x arena permuted-mult branches.
1 parent 00996ce commit 6330596

1 file changed

Lines changed: 36 additions & 31 deletions

File tree

src/TiledArray/tensor/tensor.h

Lines changed: 36 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1953,33 +1953,28 @@ class Tensor {
19531953
return !static_cast<bool>(perm) || perm.is_identity();
19541954
}
19551955

1956-
/// Permuted add for `Tensor<ArenaTensor>` ToT operands. A non-trivial
1957-
/// permutation of arena ToT tiles is not yet supported; an identity (or
1958-
/// null) permutation falls through to the plain element-wise add.
1956+
/// Permuted add for `Tensor<ArenaTensor>` ToT operands. The operands are
1957+
/// congruent by the time a permuted add reaches a tile op, so the
1958+
/// elementwise `add(right)` is valid and `perm` is the result permutation;
1959+
/// `permute` applies it (shallow outer reindex + inner-slab rewrite).
19591960
template <typename Right, typename Perm>
19601961
requires(is_arena_tensor_v<value_type> &&
19611962
is_arena_tensor_v<typename Right::value_type> &&
19621963
detail::is_permutation_v<Perm>)
19631964
Tensor add(const Right& right, const Perm& perm) const {
1964-
if (!arena_perm_is_trivial(perm))
1965-
TA_EXCEPTION(
1966-
"TA::Tensor<ArenaTensor>::add: permuted add of a tensor-of-tensors "
1967-
"is not yet supported");
1968-
return add(right);
1965+
auto result = add(right);
1966+
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
19691967
}
19701968

19711969
/// Permuted scaled add for `Tensor<ArenaTensor>` ToT operands; see the
1972-
/// permuted-add overload above for the permutation restriction.
1970+
/// permuted-add overload above for the congruent-operand rationale.
19731971
template <typename Right, typename Scalar, typename Perm>
19741972
requires(is_arena_tensor_v<value_type> &&
19751973
is_arena_tensor_v<typename Right::value_type> &&
19761974
detail::is_numeric_v<Scalar> && detail::is_permutation_v<Perm>)
19771975
Tensor add(const Right& right, const Scalar factor, const Perm& perm) const {
1978-
if (!arena_perm_is_trivial(perm))
1979-
TA_EXCEPTION(
1980-
"TA::Tensor<ArenaTensor>::add: permuted scaled add of a "
1981-
"tensor-of-tensors is not yet supported");
1982-
return add(right, factor);
1976+
auto result = add(right, factor);
1977+
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
19831978
}
19841979

19851980
/// Add this and \c other to construct a new tensor
@@ -2382,8 +2377,15 @@ class Tensor {
23822377
typename std::enable_if<is_tensor<Right>::value &&
23832378
detail::is_permutation_v<Perm>>::type* = nullptr>
23842379
Tensor subt(const Right& right, const Perm& perm) const {
2385-
if constexpr (is_tensor_view_v<value_type>) {
2386-
// Permutation isn't supported for view inner cells (fixed storage
2380+
if constexpr (is_arena_tensor_v<value_type> &&
2381+
is_arena_tensor_v<typename Right::value_type>) {
2382+
// arena ToT x arena ToT: operands are congruent at tile-op time, so the
2383+
// elementwise `subt(right)` is valid; apply the result permutation as a
2384+
// post-pass (shallow outer reindex + inner-slab rewrite).
2385+
auto result = subt(right);
2386+
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
2387+
} else if constexpr (is_tensor_view_v<value_type>) {
2388+
// Permutation isn't supported for other view inner cells (fixed storage
23872389
// layout). Subt+permute would require materialization.
23882390
TA_EXCEPTION(
23892391
"Tensor<View>::subt(right, perm): permutation is not "
@@ -2443,11 +2445,10 @@ class Tensor {
24432445
Tensor subt(const Right& right, const Scalar factor, const Perm& perm) const {
24442446
if constexpr (is_arena_tensor_v<value_type> &&
24452447
is_arena_tensor_v<typename Right::value_type>) {
2446-
if (!arena_perm_is_trivial(perm))
2447-
TA_EXCEPTION(
2448-
"TA::Tensor<ArenaTensor>::subt: permuted scaled subt of a "
2449-
"tensor-of-tensors is not yet supported");
2450-
return subt(right, factor);
2448+
// arena ToT x arena ToT scaled subtraction; see the unscaled permuted
2449+
// subt overload above for the congruent-operand rationale.
2450+
auto result = subt(right, factor);
2451+
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
24512452
} else {
24522453
return binary(
24532454
right,
@@ -2622,11 +2623,15 @@ class Tensor {
26222623
decltype(auto) mult(const Right& right, const Perm& perm) const {
26232624
if constexpr (is_arena_tensor_v<value_type> &&
26242625
is_arena_tensor_v<typename Right::value_type>) {
2625-
if (!arena_perm_is_trivial(perm))
2626-
TA_EXCEPTION(
2627-
"TA::Tensor<ArenaTensor>::mult: permuted mult of a "
2628-
"tensor-of-tensors is not yet supported");
2629-
return mult(right);
2626+
// arena ToT x arena ToT Hadamard product. By the time a permuted product
2627+
// reaches a tile op, the engine has already brought both operands to a
2628+
// common (congruent) layout, so the elementwise `mult(right)` is valid;
2629+
// `perm` is the result permutation (common layout -> target). Apply it
2630+
// as a post-pass: `permute` reindexes the outer cells shallowly
2631+
// (arena_permute_shallow) and rewrites the inner slab if the inner part
2632+
// of the permutation is non-trivial (arena_inner_permute).
2633+
auto result = mult(right);
2634+
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
26302635
} else if constexpr (detail::is_numeric_v<value_type> &&
26312636
is_arena_tensor_v<typename Right::value_type>) {
26322637
// t x tot: a plain scalar tile times an arena ToT tile. The 2-arg
@@ -2697,11 +2702,11 @@ class Tensor {
26972702
const Perm& perm) const {
26982703
if constexpr (is_arena_tensor_v<value_type> &&
26992704
is_arena_tensor_v<typename Right::value_type>) {
2700-
if (!arena_perm_is_trivial(perm))
2701-
TA_EXCEPTION(
2702-
"TA::Tensor<ArenaTensor>::mult: permuted scaled mult of a "
2703-
"tensor-of-tensors is not yet supported");
2704-
return mult(right, factor);
2705+
// arena ToT x arena ToT scaled Hadamard product; see the unscaled
2706+
// permuted mult overload above for the congruent-operand rationale.
2707+
// Scale during the elementwise product, then permute the result.
2708+
auto result = mult(right, factor);
2709+
return arena_perm_is_trivial(perm) ? result : result.permute(perm);
27052710
} else {
27062711
return binary(
27072712
right,

0 commit comments

Comments
 (0)