Skip to content

Commit 23505b3

Browse files
committed
cont_engine: view ToTxToT (outer Contraction, inner Hadamard) via arena
Add the (outer Contraction, inner Hadamard) case to init_inner_tile_op's view-cell branch. Mirrors the owning-tile path in init_inner_tile_op_owning_: arena_plan_ uses the `left_range` plan to shape each result cell from a non-empty left inner cell, and the per-cell op accumulates `r += l * rr` -- or `r += (l * rr) * factor_` when scaled -- via fused_hadamard_inplace into the pre-shaped view cell. No value-returning per-cell op is needed, so this works for view cells (e.g. ArenaTensor); non-identity inner result permutation is rejected (the owning fallback that materializes a permuted return cell cannot run for views). Previously this case threw "nested non-contraction product on view inner tiles is not yet supported", aborting expressions such as `C(i_3,i_4;a<...>) = A(i_3;a<...>) * B(i_4;a<...>)` over ArenaTensor inner cells -- the typical sub-product inside einsum's generalized contraction loop for ToTxToT with Hadamard outer-Hadamard inner shapes.
1 parent a98accc commit 23505b3

1 file changed

Lines changed: 58 additions & 9 deletions

File tree

src/TiledArray/expressions/cont_engine.h

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -547,20 +547,69 @@ class ContEngine : public BinaryEngine<Derived> {
547547
TiledArray::is_tensor_view_v<result_tile_element_type>) {
548548
// ToT x ToT with non-owning view inner cells (e.g. ArenaTensor). A
549549
// view cell cannot host a value-returning inner op, so the
550-
// owning-cell inner-op builder cannot be used. Two nested products
551-
// are supported here:
552-
// - the elementwise pure Hadamard, where the inner element op is
553-
// unused anyway -- MultEngine::make_tile_op passes none and the
554-
// outer Mult tile op recurses through Tensor<view>::mult -- so
555-
// element_*_op_ is left null;
556-
// - the inner contraction (incl. inner outer-product), routed
557-
// through the arena fast path: it writes results in place into
558-
// pre-shaped view cells, so only element_nonreturn_op_ is needed.
550+
// owning-cell inner-op builder cannot be used. The supported nested
551+
// products are:
552+
// - the elementwise pure Hadamard (outer Hadamard, inner Hadamard),
553+
// where the inner element op is unused anyway -- MultEngine::
554+
// make_tile_op passes none and the outer Mult tile op recurses
555+
// through Tensor<view>::mult -- so element_*_op_ is left null;
556+
// - inner Hadamard under outer Contraction, routed through the
557+
// arena fast path with a left_range plan and a per-cell
558+
// `r += l * rr` (optionally scaled) op: result cells are
559+
// pre-shaped from non-empty left cells, then accumulated in
560+
// place over the K-panel;
561+
// - inner Contraction (incl. inner outer-product) under either
562+
// outer regime, routed through the arena fast path: it writes
563+
// results in place into pre-shaped view cells, so only
564+
// element_nonreturn_op_ is needed.
559565
// Every other nested product is deferred.
560566
const auto inner_prod = this->inner_product_type();
561567
if (inner_prod == TensorProduct::Hadamard &&
562568
this->product_type() == TensorProduct::Hadamard) {
563569
// pure Hadamard: element_*_op_ left null
570+
} else if (inner_prod == TensorProduct::Hadamard &&
571+
this->product_type() == TensorProduct::Contraction) {
572+
// outer Contraction + inner Hadamard on view inner tiles.
573+
// Mirror the owning-tile path (init_inner_tile_op_owning_): the
574+
// SUMMA shapes each result cell from a non-empty left inner cell
575+
// (left_range plan), and the per-cell op accumulates `r += l * rr`
576+
// -- or `r += (l * rr) * factor_` when scaled -- via
577+
// fused_hadamard_inplace into the pre-shaped view cell. No
578+
// value-returning per-cell op is needed, so this works for view
579+
// cells; non-identity inner result permutation is rejected here
580+
// (the owning fallback that materializes a permuted return cell
581+
// cannot run for views).
582+
constexpr bool arena_eligible_h_view =
583+
TiledArray::detail::is_contraction_arena_tot_v<
584+
result_tile_type, left_tile_type, right_tile_type>;
585+
if constexpr (!arena_eligible_h_view) {
586+
TA_EXCEPTION(
587+
"nested Hadamard on view inner tiles is supported only for "
588+
"arena-backed tensors-of-tensors");
589+
} else {
590+
this->arena_plan_ = TiledArray::detail::make_contraction_arena_plan<
591+
result_tile_type, left_tile_type, right_tile_type>(
592+
TiledArray::detail::ArenaInnerShapeKind::left_range,
593+
std::nullopt, inner(this->perm_));
594+
if (!bool(this->arena_plan_))
595+
TA_EXCEPTION(
596+
"nested Hadamard on view inner tiles: the arena fast path "
597+
"was inactive (arena disabled, or a non-identity inner "
598+
"result permutation -- not yet supported on view cells)");
599+
if (this->factor_ == scalar_type{1}) {
600+
this->element_nonreturn_op_ =
601+
TiledArray::detail::make_fused_hadamard_lambda<
602+
result_tile_element_type, left_tile_element_type,
603+
right_tile_element_type>();
604+
} else {
605+
this->element_nonreturn_op_ =
606+
TiledArray::detail::make_fused_hadamard_scaled_lambda<
607+
result_tile_element_type, left_tile_element_type,
608+
right_tile_element_type>(this->factor_);
609+
}
610+
}
611+
// element_return_op_ left null: a view cell cannot be
612+
// value-returned (see the init_struct precondition check).
564613
} else if (inner_prod == TensorProduct::Contraction) {
565614
using op_type = TiledArray::detail::ContractReduce<
566615
result_tile_element_type, left_tile_element_type,

0 commit comments

Comments
 (0)