Commit eb92aa9

einsum: ToT x ToT -> T denest via phantom-unit dot, not expand-then-reduce
A denested ToT x ToT contraction (inner indices fully contracted, plain-T result) was evaluated by einsum<DeNest::True> as expand-then-reduce: it formed the full uncontracted product C0 (external x contracted-outer x inner) before reducing. With a large contracted-outer index (e.g. a DF/RI index) this materializes an enormous intermediate -- ~20 GB for one CSV-CC term in C8H18 PNO-CCSD (256 s, 26.6 GB peak) for a 412 KB result.

Reformulate as a single contraction whose inner product is a Frobenius dot. The inner reduction is expressed with a phantom unit-extent result mode (reserved label prefix U+2297, is_phantom_unit_label) so the result inner cell is a genuine order >= 1 tensor (TA has no order-0), and the dot reads operand cells flat: no operand carries the phantom mode, no inner GEMM rank match, no order-0, and C0 is never built. Correct even when an inner extent depends on a contracted-outer index.

- util/annotation.h: phantom-unit label prefix + is_phantom_unit_label.
- einsum/tiledarray.h: DeNest::True builds C(c;U) = A(..) * B(..;..,U) then unwraps the unit-extent inner cells to scalars.
- tensor/arena_einsum.h: RegimeAInnerKind::phantom_dot, ArenaInnerShapeKind::unit_range, and arena_hadamard_phantom_dot (view cells).
- expressions/cont_engine.h: inner phantom-dot op in the owning and view-cell inner-op paths, for both outer-Contraction and outer-Hadamard regimes.
- tests/einsum.cpp: external-index (e-present) denest case.

C8H18 PNO-CCSD: the motivating term drops 256 s/26.6 GB -> 0.25 s/2.9 GB and the run now completes; c4h10 converges as before.
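The difference between the two evaluation strategies can be sketched without TiledArray. The block below is a minimal illustration in plain C++, not the code from this commit: `Cell`, `Nested`, `expand_then_reduce`, and `phantom_unit_dot` are hypothetical names, `A[e][k]` stands for the inner cell at external index `e` and contracted-outer index `k`, and each inner cell is assumed to be stored flat with `A[e][k].size() == B[k].size()`. The inner extent is allowed to vary with `k`.

```cpp
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical stand-ins, not TiledArray types.
using Cell = std::vector<double>;                // one flattened inner tensor
using Nested = std::vector<std::vector<Cell>>;   // cells[e][k]

// Old route: materialize the uncontracted product C0[e][k][x] first, then
// reduce over the inner index x and the contracted-outer index k.
// intermediate_elems reports how many scalars C0 holds.
std::vector<double> expand_then_reduce(const Nested& A,
                                       const std::vector<Cell>& B,
                                       std::size_t& intermediate_elems) {
  intermediate_elems = 0;
  Nested C0(A.size());
  for (std::size_t e = 0; e < A.size(); ++e) {
    C0[e].resize(B.size());
    for (std::size_t k = 0; k < B.size(); ++k) {
      C0[e][k].resize(B[k].size());
      for (std::size_t x = 0; x < B[k].size(); ++x)
        C0[e][k][x] = A[e][k][x] * B[k][x];
      intermediate_elems += B[k].size();
    }
  }
  std::vector<double> C(A.size(), 0.0);
  for (std::size_t e = 0; e < A.size(); ++e)
    for (const Cell& cell : C0[e])
      C[e] += std::accumulate(cell.begin(), cell.end(), 0.0);
  return C;
}

// New route: accumulate flat dot products straight into a unit-extent
// inner cell per external index, then unwrap that cell to a scalar.
std::vector<double> phantom_unit_dot(const Nested& A,
                                     const std::vector<Cell>& B) {
  std::vector<Cell> C0(A.size(), Cell(1, 0.0));  // one [1] cell per e
  for (std::size_t e = 0; e < A.size(); ++e)
    for (std::size_t k = 0; k < B.size(); ++k)
      C0[e][0] += std::inner_product(A[e][k].begin(), A[e][k].end(),
                                     B[k].begin(), 0.0);
  std::vector<double> C(A.size());
  for (std::size_t e = 0; e < A.size(); ++e) C[e] = C0[e][0];  // unwrap
  return C;
}
```

In this toy model the first function allocates one scalar per (external, contracted-outer, inner) triple, while the second only ever holds one scalar per external index, which is the same scaling argument the commit message makes for the DF/RI case.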
1 parent db0bff5 commit eb92aa9

5 files changed

Lines changed: 353 additions & 143 deletions

src/TiledArray/einsum/tiledarray.h

Lines changed: 24 additions & 36 deletions
@@ -531,26 +531,21 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
       "ranks match on the arguments");
 
     //
-    // Illustration of steps by an example.
+    // Strategy. Consider A(ijpab;xy) * B(jiqba;yx) -> C(ipjq), inner xy fully
+    // contracted. We reduce the contracted-outer indices ab and the contracted-
+    // inner indices xy together with a single ToT x ToT -> ToT contraction
+    // whose inner product is annotated to leave a *phantom unit* inner mode
+    // (⊗₁) on the result, so the inner cell is a genuine (≥ order-1) unit
+    // tensor rather than the unsupported order-0:
     //
-    // Consider the evaluation: A(ijpab;xy) * B(jiqba;yx) -> C(ipjq).
+    //   C0(ipjq; ⊗₁) = A(ijpab; xy) * B(jiqba; xy,⊗₁)
     //
-    // Note for the outer indices:
-    //  - Hadamard: 'ij'
-    //  - External A: 'p'
-    //  - External B: 'q'
-    //  - Contracted: 'ab'
-    //
-    // Now C is evaluated in the following steps.
-    // Step I: A(ijpab;xy) * B(jiqba;yx) -> C0(ijpqab;xy)
-    // Step II: C0(ijpqab;xy) -> C1(ijpqab)
-    // Step III: C1(ijpqab) -> C2(ijpq)
-    // Step IV: C2(ijpq) -> C(ipjq)
-
-    // Build a "denested" tile: one scalar per outer index, summed over the
-    // inner tile. The result tile's outer type is TA::Tensor (inner tile
-    // types like btas::Tensor are only valid as the innermost tile and don't
-    // expose the range+lambda ctor used here).
+    // ⊗₁ is appended to B's inner annotation only; B's inner *tensor* is
+    // unchanged (⊗₁ is phantom unit -- ContEngine recognizes it and realizes
+    // the inner product as a flat dot into a [1] cell, never requiring B to
+    // physically carry the extra mode). Each [1] inner cell is then unwrapped
+    // to a scalar. This never materializes the uncontracted product, and is
+    // correct when an inner extent depends on a contracted-outer index.
     auto sum_tot_2_tos = [](auto const &tot) {
       using tot_t = std::remove_reference_t<decltype(tot)>;
       using numeric_type = typename tot_t::numeric_type;
@@ -566,30 +561,23 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
       return result;
     };
 
-    auto const oixs = TensorOpIndices(a, b, c);
+    // U+2297 CIRCLED TIMES + U+2081 SUBSCRIPT ONE: a reserved phantom-unit
+    // inner annotator (see is_phantom_unit_label).
+    const std::string phantom_unit = "⊗₁";
 
-    struct {
-      std::string C0, C1, C2;
-    } const Cn_annot{
-        std::string(oixs.ix_C_canon() + oixs.contracted()) + inner.a,
-        {oixs.ix_C_canon() + oixs.contracted()},
-        {oixs.ix_C_canon()}};
+    auto a_annot = std::string(a) + inner.a;  // e.g. "ijpab;xy"
+    auto b_annot =
+        std::string(b) + inner.b + "," + phantom_unit;  // e.g. "jiqba;yx,⊗₁"
+    auto c_annot = std::string(c) + ";" + phantom_unit;  // e.g. "ipjq;⊗₁"
 
-    // Step I: A(ijpab;xy) * B(jiqba;yx) -> C0(ijpqab;xy)
-    auto C0 = einsum(A, B, Cn_annot.C0);
+    // C0(c; ⊗₁) = A(a; inner.A) * B(b; inner.B,⊗₁)
+    auto C0 = einsum(A.array()(a_annot), B.array()(b_annot), c_annot);
 
-    // Step II: C0(ijpqab;xy) -> C1(ijpqab)
-    auto C1 = TA::foreach<typename ArrayC::value_type>(
+    // unwrap unit-extent inner cells to scalars
+    ArrayC C = TA::foreach<typename ArrayC::value_type>(
         C0, [sum_tot_2_tos](auto &out_tile, auto const &in_tile) {
           out_tile = sum_tot_2_tos(in_tile);
         });
-
-    // Step III: C1(ijpqab) -> C2(ijpq)
-    auto C2 = reduce_modes(C1, oixs.contracted().size());
-
-    // Step IV: C2(ijpq) -> C(ipjq)
-    ArrayC C;
-    C(c) = C2(Cn_annot.C2);
     return C;
 
   } else {
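
For reference, the example used in the comments above can be written out as an index expression. This is a reading of the annotations under the usual einsum convention (indices present in both operands and in the result are Hadamard, indices absent from the result are summed), not something stated in the commit itself. The single-valued index $u$ stands for the phantom unit mode ⊗₁.

$$
C_{ipjq} \;=\; \sum_{a,b}\,\sum_{x,y} A_{ijpab;\,xy}\; B_{jiqba;\,yx}
$$

The removed code reached this in stages, with the first stage holding every term of the sum:

$$
C^{(0)}_{ijpqab;\,xy} = A_{ijpab;\,xy}\, B_{jiqba;\,yx}, \qquad
C^{(1)}_{ijpqab} = \sum_{x,y} C^{(0)}_{ijpqab;\,xy}, \qquad
C^{(2)}_{ijpq} = \sum_{a,b} C^{(1)}_{ijpqab}
$$

The added code instead forms a result whose inner cell has a single element and then unwraps it:

$$
C^{(0)}_{ipjq;\,u} = \sum_{a,b}\,\sum_{x,y} A_{ijpab;\,xy}\; B_{jiqba;\,yx}, \qquad
C_{ipjq} = C^{(0)}_{ipjq;\,u}, \quad u \in \{1\}
$$

The ranges of $x$ and $y$ may depend on the outer indices, including $a$ and $b$, so the inner sum has to be evaluated per outer cell, not hoisted out of the outer sum.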
