Skip to content

Commit d327b82

Browse files
authored
Merge pull request #550 from ValeevGroup/evaleev/fix/arena-tot-fence-and-view-hadamard-outer-contract
arena ToT: einsum sub-World fence guard + (outer-Contraction, inner-Hadamard) view-cell case
2 parents 60bb88d + 31800a9 commit d327b82

2 files changed

Lines changed: 90 additions & 9 deletions

File tree

src/TiledArray/einsum/tiledarray.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,38 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
653653
// dead World (e.g. while unwinding an exception thrown mid-contraction).
654654
std::vector<std::shared_ptr<World>> worlds;
655655

656+
// RAII fencer: on normal exit and (critically) on exception unwind,
657+
// fence every live sub-World before it is destroyed. ~DistArray ->
658+
// lazy_deleter calls world.gop.lazy_sync(...) which enqueues a
659+
// lazy_sync_children task onto the sub-World's taskq; without a fence
660+
// those tasks survive into the global ThreadPool past the sub-World's
661+
// ~World, then trip ~WorldObject's `World::exists(&world)` assertion
662+
// when some later fence (e.g. an enclosing scope's fence run during
663+
// unwind) picks them up. Declared *after* `worlds` so it destructs
664+
// *before* `worlds` (LIFO); destructs *after* AB/C so it sees the
665+
// tasks they scheduled via lazy_deleter.
666+
//
667+
// One fence per sub-World is sufficient: lazy_deleter's fast path
668+
// skips lazy_sync when invoked from inside fence_impl's do_cleanup
669+
// (gated by `world.gop.is_in_do_cleanup()`), so the deferred-cleanup
670+
// path performs direct deletes rather than scheduling cross-rank
671+
// tasks. Tasks scheduled by *non*-deferred ~DistArray calls (e.g. AB
672+
// during exception unwind) are drained by this fence's drain loop;
673+
// all participating ranks of a sub-World reach this RAII guard in
674+
// lockstep at function exit, so their lazy_sync handshakes match up.
675+
// Scope guard: when destroyed (on normal scope exit or during stack
// unwinding) it calls gop.fence() on every non-null sub-World in the
// referenced vector.
struct FenceSubWorldsOnExit {
676+
// Non-owning reference; the referenced vector must outlive this guard.
std::vector<std::shared_ptr<World>> &worlds_;
677+
~FenceSubWorldsOnExit() {
678+
for (auto &w : worlds_) {
679+
if (!w) continue;  // skip null entries; nothing to fence
680+
try {
681+
w->gop.fence();
682+
} catch (...) {
// Deliberately discard any failure: destructors are implicitly
// noexcept, so an exception escaping here would call std::terminate.
// Swallowing it also lets the loop go on to fence the remaining
// sub-Worlds.
683+
}
684+
}
685+
}
686+
} fence_subworlds_on_exit{worlds};  // guard instance bound to the local `worlds`
687+
656688
std::tuple<ArrayTerm<ArrayA>, ArrayTerm<ArrayB>> AB{{A.array(), a},
657689
{B.array(), b}};
658690

src/TiledArray/expressions/cont_engine.h

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -547,20 +547,69 @@ class ContEngine : public BinaryEngine<Derived> {
547547
TiledArray::is_tensor_view_v<result_tile_element_type>) {
548548
// ToT x ToT with non-owning view inner cells (e.g. ArenaTensor). A
549549
// view cell cannot host a value-returning inner op, so the
550-
// owning-cell inner-op builder cannot be used. Two nested products
551-
// are supported here:
552-
// - the elementwise pure Hadamard, where the inner element op is
553-
// unused anyway -- MultEngine::make_tile_op passes none and the
554-
// outer Mult tile op recurses through Tensor<view>::mult -- so
555-
// element_*_op_ is left null;
556-
// - the inner contraction (incl. inner outer-product), routed
557-
// through the arena fast path: it writes results in place into
558-
// pre-shaped view cells, so only element_nonreturn_op_ is needed.
550+
// owning-cell inner-op builder cannot be used. The supported nested
551+
// products are:
552+
// - the elementwise pure Hadamard (outer Hadamard, inner Hadamard),
553+
// where the inner element op is unused anyway -- MultEngine::
554+
// make_tile_op passes none and the outer Mult tile op recurses
555+
// through Tensor<view>::mult -- so element_*_op_ is left null;
556+
// - inner Hadamard under outer Contraction, routed through the
557+
// arena fast path with a left_range plan and a per-cell
558+
// `r += l * rr` (optionally scaled) op: result cells are
559+
// pre-shaped from non-empty left cells, then accumulated in
560+
// place over the K-panel;
561+
// - inner Contraction (incl. inner outer-product) under either
562+
// outer regime, routed through the arena fast path: it writes
563+
// results in place into pre-shaped view cells, so only
564+
// element_nonreturn_op_ is needed.
559565
// Every other nested product is deferred.
560566
const auto inner_prod = this->inner_product_type();
561567
if (inner_prod == TensorProduct::Hadamard &&
562568
this->product_type() == TensorProduct::Hadamard) {
563569
// pure Hadamard: element_*_op_ left null
570+
} else if (inner_prod == TensorProduct::Hadamard &&
571+
this->product_type() == TensorProduct::Contraction) {
572+
// outer Contraction + inner Hadamard on view inner tiles.
573+
// Mirror the owning-tile path (init_inner_tile_op_owning_): the
574+
// SUMMA shapes each result cell from a non-empty left inner cell
575+
// (left_range plan), and the per-cell op accumulates `r += l * rr`
576+
// -- or `r += (l * rr) * factor_` when scaled -- via
577+
// fused_hadamard_inplace into the pre-shaped view cell. No
578+
// value-returning per-cell op is needed, so this works for view
579+
// cells; non-identity inner result permutation is rejected here
580+
// (the owning fallback that materializes a permuted return cell
581+
// cannot run for views).
582+
constexpr bool arena_eligible_h_view =
583+
TiledArray::detail::is_contraction_arena_tot_v<
584+
result_tile_type, left_tile_type, right_tile_type>;
585+
if constexpr (!arena_eligible_h_view) {
586+
TA_EXCEPTION(
587+
"nested Hadamard on view inner tiles is supported only for "
588+
"arena-backed tensors-of-tensors");
589+
} else {
590+
this->arena_plan_ = TiledArray::detail::make_contraction_arena_plan<
591+
result_tile_type, left_tile_type, right_tile_type>(
592+
TiledArray::detail::ArenaInnerShapeKind::left_range,
593+
std::nullopt, inner(this->perm_));
594+
if (!bool(this->arena_plan_))
595+
TA_EXCEPTION(
596+
"nested Hadamard on view inner tiles: the arena fast path "
597+
"was inactive (arena disabled, or a non-identity inner "
598+
"result permutation -- not yet supported on view cells)");
599+
if (this->factor_ == scalar_type{1}) {
600+
this->element_nonreturn_op_ =
601+
TiledArray::detail::make_fused_hadamard_lambda<
602+
result_tile_element_type, left_tile_element_type,
603+
right_tile_element_type>();
604+
} else {
605+
this->element_nonreturn_op_ =
606+
TiledArray::detail::make_fused_hadamard_scaled_lambda<
607+
result_tile_element_type, left_tile_element_type,
608+
right_tile_element_type>(this->factor_);
609+
}
610+
}
611+
// element_return_op_ left null: a view cell cannot be
612+
// value-returned (see the init_struct precondition check).
564613
} else if (inner_prod == TensorProduct::Contraction) {
565614
using op_type = TiledArray::detail::ContractReduce<
566615
result_tile_element_type, left_tile_element_type,

0 commit comments

Comments
 (0)