Skip to content

Commit e22e652

Browse files
authored
Merge pull request #558 from ValeevGroup/evaleev/feature/denest-fused-contraction
einsum: ToT×ToT→T denest via phantom-unit dot (avoid C0 blowup)
2 parents db0bff5 + eb92aa9 commit e22e652

5 files changed

Lines changed: 353 additions & 143 deletions

File tree

src/TiledArray/einsum/tiledarray.h

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -531,26 +531,21 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
531531
"ranks match on the arguments");
532532

533533
//
534-
// Illustration of steps by an example.
534+
// Strategy. Consider A(ijpab;xy) * B(jiqba;yx) -> C(ipjq), inner xy fully
535+
// contracted. We reduce the contracted-outer indices ab and the contracted-
536+
// inner indices xy together with a single ToT x ToT -> ToT contraction
537+
// whose inner product is annotated to leave a *phantom unit* inner mode
538+
// (⊗₁) on the result, so the inner cell is a genuine (≥ order-1) unit
539+
// tensor rather than the unsupported order-0:
535540
//
536-
// Consider the evaluation: A(ijpab;xy) * B(jiqba;yx) -> C(ipjq).
541+
// C0(ipjq; ⊗₁) = A(ijpab; xy) * B(jiqba; yx,⊗₁)
537542
//
538-
// Note for the outer indices:
539-
// - Hadamard: 'ij'
540-
// - External A: 'p'
541-
// - External B: 'q'
542-
// - Contracted: 'ab'
543-
//
544-
// Now C is evaluated in the following steps.
545-
// Step I: A(ijpab;xy) * B(jiqba;yx) -> C0(ijpqab;xy)
546-
// Step II: C0(ijpqab;xy) -> C1(ijpqab)
547-
// Step III: C1(ijpqab) -> C2(ijpq)
548-
// Step IV: C2(ijpq) -> C(ipjq)
549-
550-
// Build a "denested" tile: one scalar per outer index, summed over the
551-
// inner tile. The result tile's outer type is TA::Tensor (inner tile
552-
// types like btas::Tensor are only valid as the innermost tile and don't
553-
// expose the range+lambda ctor used here).
543+
// ⊗₁ is appended to B's inner annotation only; B's inner *tensor* is
544+
// unchanged (⊗₁ is a phantom unit -- ContEngine recognizes it and realizes
545+
// the inner product as a flat dot into a [1] cell, never requiring B to
546+
// physically carry the extra mode). Each [1] inner cell is then unwrapped
547+
// to a scalar. This never materializes the uncontracted product, and is
548+
// correct when an inner extent depends on a contracted-outer index.
554549
auto sum_tot_2_tos = [](auto const &tot) {
555550
using tot_t = std::remove_reference_t<decltype(tot)>;
556551
using numeric_type = typename tot_t::numeric_type;
@@ -566,30 +561,23 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
566561
return result;
567562
};
568563

569-
auto const oixs = TensorOpIndices(a, b, c);
564+
// U+2297 CIRCLED TIMES + U+2081 SUBSCRIPT ONE: a reserved phantom-unit
565+
// inner annotator (see is_phantom_unit_label).
566+
const std::string phantom_unit = "⊗₁";
570567

571-
struct {
572-
std::string C0, C1, C2;
573-
} const Cn_annot{
574-
std::string(oixs.ix_C_canon() + oixs.contracted()) + inner.a,
575-
{oixs.ix_C_canon() + oixs.contracted()},
576-
{oixs.ix_C_canon()}};
568+
auto a_annot = std::string(a) + inner.a; // e.g. "ijpab;xy"
569+
auto b_annot =
570+
std::string(b) + inner.b + "," + phantom_unit; // e.g. "jiqba;yx,⊗₁"
571+
auto c_annot = std::string(c) + ";" + phantom_unit; // e.g. "ipjq;⊗₁"
577572

578-
// Step I: A(ijpab;xy) * B(jiqba;yx) -> C0(ijpqab;xy)
579-
auto C0 = einsum(A, B, Cn_annot.C0);
573+
// C0(c; ⊗₁) = A(a; inner.a) * B(b; inner.b,⊗₁)
574+
auto C0 = einsum(A.array()(a_annot), B.array()(b_annot), c_annot);
580575

581-
// Step II: C0(ijpqab;xy) -> C1(ijpqab)
582-
auto C1 = TA::foreach<typename ArrayC::value_type>(
576+
// unwrap unit-extent inner cells to scalars
577+
ArrayC C = TA::foreach<typename ArrayC::value_type>(
583578
C0, [sum_tot_2_tos](auto &out_tile, auto const &in_tile) {
584579
out_tile = sum_tot_2_tos(in_tile);
585580
});
586-
587-
// Step III: C1(ijpqab) -> C2(ijpq)
588-
auto C2 = reduce_modes(C1, oixs.contracted().size());
589-
590-
// Step IV: C2(ijpq) -> C(ipjq)
591-
ArrayC C;
592-
C(c) = C2(Cn_annot.C2);
593581
return C;
594582

595583
} else {

0 commit comments

Comments
 (0)