Skip to content

Commit 17105f2

Browse files
committed
einsum: parallelize hadamard-reduction outer h-tile loop
1 parent 7f76cda commit 17105f2

1 file changed

Lines changed: 28 additions & 4 deletions

File tree

src/TiledArray/einsum/tiledarray.h

Lines changed: 28 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -687,10 +687,24 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
687687

688688
auto pa = A.permutation;
689689
auto pb = B.permutation;
690-
for (Index h : H.tiles) {
690+
691+
// Each H-tile iteration produces an independent output tile, so the
692+
// loop is parallel-safe. Dispatch per-H-tile work to the MADNESS task
693+
// queue; pre-size a per-slot result vector so tasks write their own
694+
// slot without synchronization, and gather before exiting scope so
695+
// captured references stay alive for the task lifetime.
696+
std::vector<Index> local_hs;
697+
{
698+
auto const pc = C.permutation;
699+
for (Index h : H.tiles) {
700+
if (C.array.is_local(apply(pc, h))) local_hs.push_back(h);
701+
}
702+
}
703+
std::vector<std::pair<Index, ResultTensor>> h_results(local_hs.size());
704+
705+
auto per_h_work = [&, pa, pb](Index h, size_t slot) -> bool {
691706
auto const pc = C.permutation;
692707
auto const c = apply(pc, h);
693-
if (!C.array.is_local(c)) continue;
694708
size_t batch = 1;
695709
for (size_t i = 0; i < h.size(); ++i) {
696710
batch *= H.batch[i].at(h[i]);
@@ -752,8 +766,18 @@ auto einsum(expressions::TsrExpr<ArrayA_> A, expressions::TsrExpr<ArrayB_> B,
752766
tile = tile.reshape(shape);
753767
// then permute to target C layout c = (c1 c2 ...)
754768
if (pc) tile = tile.permute(pc);
755-
// and move to C_local_tiles
756-
C_local_tiles.emplace_back(std::move(c), std::move(tile));
769+
h_results[slot] = {c, std::move(tile)};
770+
return true;
771+
};
772+
773+
std::vector<madness::Future<bool>> h_futures;
774+
h_futures.reserve(local_hs.size());
775+
for (size_t slot = 0; slot < local_hs.size(); ++slot) {
776+
h_futures.push_back(world.taskq.add(per_h_work, local_hs[slot], slot));
777+
}
778+
for (auto &fut : h_futures) fut.get();
779+
for (auto &r : h_results) {
780+
C_local_tiles.emplace_back(std::move(r.first), std::move(r.second));
757781
}
758782

759783
build_C_array();

0 commit comments

Comments
 (0)