@@ -29,6 +29,11 @@ struct ExtraBuffers : protected Matrix<T, D> {
2929 pika::execution::thread_priority::high)));
3030 }
3131
/// Forwards to Matrix<T, D>::read_sender for the buffer tile selected by @p index.
///
/// @param index buffer selector; it is reduced modulo `nbuffers_` before use, and the tile
///              requested from the base matrix is LocalTileIndex{index, 0}.
/// @return whatever Matrix<T, D>::read_sender returns for that tile.
///
/// NOTE(review): if SizeType is a signed type, a negative @p index leaves a negative remainder
/// after `%=` — confirm callers only pass non-negative values.
32+ auto read_sender (SizeType index) {
// Wrap the requested index into the range covered by nbuffers_.
33+ index %= nbuffers_;
34+ return Matrix<T, D>::read_sender (LocalTileIndex{index, 0 });
35+ }
36+
3237 auto readwrite_sender (SizeType index) {
3338 index %= nbuffers_;
3439 return Matrix<T, D>::readwrite_sender (LocalTileIndex{index, 0 });
@@ -38,17 +43,32 @@ struct ExtraBuffers : protected Matrix<T, D> {
/// Builds a sender that adds every extra buffer into the tile delivered by @p tile.
///
/// This span is a diff hunk: lines marked `-` are the previous implementation, lines marked `+`
/// are its replacement.
///
/// @param tile sender whose value is passed to the callback as `const matrix::Tile<T, D>&`.
/// @return the composed sender; it is marked [[nodiscard]].
///
/// NOTE(review): the previous implementation called tile::internal::set0 on the tile before
/// accumulating; the replacement has no such call, so presumably the buffers are now added on
/// top of the tile's existing content — confirm this is intended.
3843 [[nodiscard]] auto reduce (TileSender tile) {
3944 namespace ex = pika::execution::experimental;
4045
// --- previous implementation (removed) ---
// Collected one future per local tile of the base matrix via Matrix<T, D>::operator().
41- std::vector<pika::future<matrix::Tile<T, D>>> buffers;
42- for (const auto & ij : common::iterate_range2d (this ->distribution ().localNrTiles ()))
43- buffers.emplace_back (Matrix<T, D>::operator ()(ij));
44- auto all_buffers = ex::when_all_vector (std::move (buffers));
45-
// Zeroed the destination tile, then added each buffer to it with factor T(1).
46- return ex::when_all (std::move (tile), std::move (all_buffers)) |
47- ex::then ([](const matrix::Tile<T, D>& tile, const std::vector<matrix::Tile<T, D>>& buffers) {
48- tile::internal::set0 (tile);
49- for (auto & buffer : buffers)
50- dlaf::tile::internal::add (T (1 ), buffer, tile);
51- });
// --- replacement implementation (added) ---
// Collect one read_sender per buffer index in [0, nbuffers_).
46+ std::vector<ex::any_sender<pika::shared_future<matrix::Tile<const T, D>>>> buffers;
47+ for (SizeType index = 0 ; index < nbuffers_; ++index)
48+ buffers.emplace_back (read_sender (index));
49+
// Combine the destination tile with all buffers and hand them to dlaf::internal::transform,
// using the policy for the default backend of device D.
50+ return ex::when_all (std::move (tile), ex::when_all_vector (std::move (buffers))) |
51+ dlaf::internal::transform (dlaf::internal::Policy<DefaultBackend_v<D>>(),
52+ [](const matrix::Tile<T, D>& tile,
53+ const std::vector<pika::shared_future<matrix::Tile<const T, D>>>&
54+ buffers,
// `ts` are any trailing arguments passed to the callback after the tile and the buffers
// (presumably backend handles supplied by transform on GPU — confirm).
55+ auto &&... ts) {
56+ for (const auto & buffer : buffers) {
// CPU: no trailing arguments are expected; add the buffer to the tile with factor T(1).
57+ if constexpr (D == Device::CPU) {
58+ static_assert (sizeof ...(ts) == 0 ,
59+ " Parameter pack should be empty for MC." );
60+ dlaf::tile::internal::add (T (1 ), buffer.get (), tile);
61+ }
// GPU branch exists only when DLAF_WITH_GPU is defined; the trailing arguments are
// forwarded to add.
62+ #ifdef DLAF_WITH_GPU
63+ else if constexpr (D == Device::GPU) {
64+ dlaf::tile::internal::add (T (1 ), buffer.get (), tile, ts...);
65+ }
66+ #endif
// Any other device falls through to the DLAF_STATIC_UNIMPLEMENTED macro.
67+ else {
68+ DLAF_STATIC_UNIMPLEMENTED (T);
69+ }
70+ }
71+ });
5272 }
5373
5474protected:
0 commit comments