@@ -531,26 +531,21 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
531531 " ranks match on the arguments" );
532532
533533 //
534- // Illustration of steps by an example.
534+ // Strategy. Consider A(ijpab;xy) * B(jiqba;yx) -> C(ipjq), inner xy fully
535+ // contracted. We reduce the contracted-outer indices ab and the contracted-
536+ // inner indices xy together with a single ToT x ToT -> ToT contraction
537+ // whose inner product is annotated to leave a *phantom unit* inner mode
538+ // (⊗₁) on the result, so the inner cell is a genuine (≥ order-1) unit
539+ // tensor rather than the unsupported order-0:
535540 //
536- // Consider the evaluation: A(ijpab;xy) * B(jiqba;yx) -> C(ipjq).
541+ // C0(ipjq; ⊗₁) = A(ijpab; xy) * B(jiqba; yx,⊗₁)
537542 //
538- // Note for the outer indices:
539- // - Hadamard: 'ij'
540- // - External A: 'p'
541- // - External B: 'q'
542- // - Contracted: 'ab'
543- //
544- // Now C is evaluated in the following steps.
545- // Step I: A(ijpab;xy) * B(jiqba;yx) -> C0(ijpqab;xy)
546- // Step II: C0(ijpqab;xy) -> C1(ijpqab)
547- // Step III: C1(ijpqab) -> C2(ijpq)
548- // Step IV: C2(ijpq) -> C(ipjq)
549-
550- // Build a "denested" tile: one scalar per outer index, summed over the
551- // inner tile. The result tile's outer type is TA::Tensor (inner tile
552- // types like btas::Tensor are only valid as the innermost tile and don't
553- // expose the range+lambda ctor used here).
543+ // ⊗₁ is appended to B's inner annotation only; B's inner *tensor* is
544+ // unchanged (⊗₁ is phantom unit -- ContEngine recognizes it and realizes
545+ // the inner product as a flat dot into a [1] cell, never requiring B to
546+ // physically carry the extra mode). Each [1] inner cell is then unwrapped
547+ // to a scalar. This never materializes the uncontracted product, and is
548+ // correct when an inner extent depends on a contracted-outer index.
554549 auto sum_tot_2_tos = [](auto const &tot) {
555550 using tot_t = std::remove_reference_t <decltype (tot)>;
556551 using numeric_type = typename tot_t ::numeric_type;
@@ -566,30 +561,23 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
566561 return result;
567562 };
568563
569- auto const oixs = TensorOpIndices (a, b, c);
564+ // U+2297 CIRCLED TIMES + U+2081 SUBSCRIPT ONE: a reserved phantom-unit
565+ // inner annotator (see is_phantom_unit_label).
566+ const std::string phantom_unit = " ⊗₁" ;
570567
571- struct {
572- std::string C0, C1, C2;
573- } const Cn_annot{
574- std::string (oixs.ix_C_canon () + oixs.contracted ()) + inner.a ,
575- {oixs.ix_C_canon () + oixs.contracted ()},
576- {oixs.ix_C_canon ()}};
568+ auto a_annot = std::string (a) + inner.a ; // e.g. "ijpab;xy"
569+ auto b_annot =
570+ std::string (b) + inner.b + " ," + phantom_unit; // e.g. "jiqba;yx,⊗₁"
571+ auto c_annot = std::string (c) + " ;" + phantom_unit; // e.g. "ipjq;⊗₁"
577572
578- // Step I: A(ijpab;xy ) * B(jiqba;yx) -> C0(ijpqab;xy )
579- auto C0 = einsum (A, B, Cn_annot. C0 );
573+ // C0(c; ⊗₁) = A(a; inner.a) * B(b; inner.b,⊗₁)
574+ auto C0 = einsum (A. array ()(a_annot) , B. array ()(b_annot), c_annot );
580575
581- // Step II: C0(ijpqab;xy) -> C1(ijpqab)
582- auto C1 = TA::foreach<typename ArrayC::value_type>(
576+ // unwrap unit-extent inner cells to scalars
577+ ArrayC C = TA::foreach<typename ArrayC::value_type>(
583578 C0, [sum_tot_2_tos](auto &out_tile, auto const &in_tile) {
584579 out_tile = sum_tot_2_tos (in_tile);
585580 });
586-
587- // Step III: C1(ijpqab) -> C2(ijpq)
588- auto C2 = reduce_modes (C1, oixs.contracted ().size ());
589-
590- // Step IV: C2(ijpq) -> C(ipjq)
591- ArrayC C;
592- C (c) = C2 (Cn_annot.C2 );
593581 return C;
594582
595583 } else {
0 commit comments