@@ -1953,33 +1953,28 @@ class Tensor {
19531953 return !static_cast <bool >(perm) || perm.is_identity ();
19541954 }
19551955
1956- // / Permuted add for `Tensor<ArenaTensor>` ToT operands. A non-trivial
1957- // / permutation of arena ToT tiles is not yet supported; an identity (or
1958- // / null) permutation falls through to the plain element-wise add.
1956+ // / Permuted add for `Tensor<ArenaTensor>` ToT operands. The operands are
1957+ // / congruent by the time a permuted product reaches a tile op, so the
1958+ // / elementwise `add(right)` is valid and `perm` is the result permutation;
1959+ // / `permute` applies it (shallow outer reindex + inner-slab rewrite).
19591960 template <typename Right, typename Perm>
19601961 requires (is_arena_tensor_v<value_type> &&
19611962 is_arena_tensor_v<typename Right::value_type> &&
19621963 detail::is_permutation_v<Perm>)
19631964 Tensor add (const Right& right, const Perm& perm) const {
1964- if (!arena_perm_is_trivial (perm))
1965- TA_EXCEPTION (
1966- " TA::Tensor<ArenaTensor>::add: permuted add of a tensor-of-tensors "
1967- " is not yet supported" );
1968- return add (right);
1965+ auto result = add (right);
1966+ return arena_perm_is_trivial (perm) ? result : result.permute (perm);
19691967 }
19701968
19711969 // / Permuted scaled add for `Tensor<ArenaTensor>` ToT operands; see the
1972- // / permuted-add overload above for the permutation restriction .
1970+ // / permuted-add overload above for the congruent-operand rationale .
19731971 template <typename Right, typename Scalar, typename Perm>
19741972 requires (is_arena_tensor_v<value_type> &&
19751973 is_arena_tensor_v<typename Right::value_type> &&
19761974 detail::is_numeric_v<Scalar> && detail::is_permutation_v<Perm>)
19771975 Tensor add (const Right& right, const Scalar factor, const Perm& perm) const {
1978- if (!arena_perm_is_trivial (perm))
1979- TA_EXCEPTION (
1980- " TA::Tensor<ArenaTensor>::add: permuted scaled add of a "
1981- " tensor-of-tensors is not yet supported" );
1982- return add (right, factor);
1976+ auto result = add (right, factor);
1977+ return arena_perm_is_trivial (perm) ? result : result.permute (perm);
19831978 }
19841979
19851980 // / Add this and \c other to construct a new tensor
@@ -2382,8 +2377,15 @@ class Tensor {
23822377 typename std::enable_if<is_tensor<Right>::value &&
23832378 detail::is_permutation_v<Perm>>::type* = nullptr >
23842379 Tensor subt (const Right& right, const Perm& perm) const {
2385- if constexpr (is_tensor_view_v<value_type>) {
2386- // Permutation isn't supported for view inner cells (fixed storage
2380+ if constexpr (is_arena_tensor_v<value_type> &&
2381+ is_arena_tensor_v<typename Right::value_type>) {
2382+ // arena ToT x arena ToT: operands are congruent at tile-op time, so the
2383+ // elementwise `subt(right)` is valid; apply the result permutation as a
2384+ // post-pass (shallow outer reindex + inner-slab rewrite).
2385+ auto result = subt (right);
2386+ return arena_perm_is_trivial (perm) ? result : result.permute (perm);
2387+ } else if constexpr (is_tensor_view_v<value_type>) {
2388+ // Permutation isn't supported for other view inner cells (fixed storage
23872389 // layout). Subt+permute would require materialization.
23882390 TA_EXCEPTION (
23892391 " Tensor<View>::subt(right, perm): permutation is not "
@@ -2443,11 +2445,10 @@ class Tensor {
24432445 Tensor subt (const Right& right, const Scalar factor, const Perm& perm) const {
24442446 if constexpr (is_arena_tensor_v<value_type> &&
24452447 is_arena_tensor_v<typename Right::value_type>) {
2446- if (!arena_perm_is_trivial (perm))
2447- TA_EXCEPTION (
2448- " TA::Tensor<ArenaTensor>::subt: permuted scaled subt of a "
2449- " tensor-of-tensors is not yet supported" );
2450- return subt (right, factor);
2448+ // arena ToT x arena ToT scaled subtraction; see the unscaled permuted
2449+ // subt overload above for the congruent-operand rationale.
2450+ auto result = subt (right, factor);
2451+ return arena_perm_is_trivial (perm) ? result : result.permute (perm);
24512452 } else {
24522453 return binary (
24532454 right,
@@ -2622,11 +2623,15 @@ class Tensor {
26222623 decltype (auto ) mult(const Right& right, const Perm& perm) const {
26232624 if constexpr (is_arena_tensor_v<value_type> &&
26242625 is_arena_tensor_v<typename Right::value_type>) {
2625- if (!arena_perm_is_trivial (perm))
2626- TA_EXCEPTION (
2627- " TA::Tensor<ArenaTensor>::mult: permuted mult of a "
2628- " tensor-of-tensors is not yet supported" );
2629- return mult (right);
2626+ // arena ToT x arena ToT Hadamard product. By the time a permuted product
2627+ // reaches a tile op, the engine has already brought both operands to a
2628+ // common (congruent) layout, so the elementwise `mult(right)` is valid;
2629+ // `perm` is the result permutation (common layout -> target). Apply it
2630+ // as a post-pass: `permute` reindexes the outer cells shallowly
2631+ // (arena_permute_shallow) and rewrites the inner slab if the inner part
2632+ // of the permutation is non-trivial (arena_inner_permute).
2633+ auto result = mult (right);
2634+ return arena_perm_is_trivial (perm) ? result : result.permute (perm);
26302635 } else if constexpr (detail::is_numeric_v<value_type> &&
26312636 is_arena_tensor_v<typename Right::value_type>) {
26322637 // t x tot: a plain scalar tile times an arena ToT tile. The 2-arg
@@ -2697,11 +2702,11 @@ class Tensor {
26972702 const Perm& perm) const {
26982703 if constexpr (is_arena_tensor_v<value_type> &&
26992704 is_arena_tensor_v<typename Right::value_type>) {
2700- if (! arena_perm_is_trivial (perm))
2701- TA_EXCEPTION (
2702- " TA::Tensor<ArenaTensor>::mult: permuted scaled mult of a "
2703- " tensor-of-tensors is not yet supported " );
2704- return mult (right, factor );
2705+ // arena ToT x arena ToT scaled Hadamard product; see the unscaled
2706+ // permuted mult overload above for the congruent-operand rationale.
2707+ // Scale during the elementwise product, then permute the result.
2708+ auto result = mult (right, factor );
2709+ return arena_perm_is_trivial (perm) ? result : result. permute (perm );
27052710 } else {
27062711 return binary (
27072712 right,
0 commit comments