Skip to content

Commit 4e52fcf

Browse files
committed
Fix the GPU path and do not implicitly set0 the result tile in reduce()
(the caller now zeroes the tile explicitly, but not yet as a continuation)
1 parent 8063d7b commit 4e52fcf

2 files changed

Lines changed: 32 additions & 11 deletions

File tree

include/dlaf/eigensolver/reduction_to_band/impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,7 @@ void gemmComputeW2(matrix::Matrix<T, D>& w2, matrix::Panel<Coord::Col, const T,
474474
tile::gemm(dlaf::internal::Policy<B>(thread_priority::high)));
475475
}
476476

477+
ex::start_detached(tile::set0(dlaf::internal::Policy<B>(), w2.readwrite_sender(LocalTileIndex(0, 0))));
477478
ex::start_detached(buffers.reduce(w2.readwrite_sender(LocalTileIndex(0, 0))));
478479
}
479480

include/dlaf/matrix/extra_buffers.h

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@ struct ExtraBuffers : protected Matrix<T, D> {
2929
pika::execution::thread_priority::high)));
3030
}
3131

32+
auto read_sender(SizeType index) {
33+
index %= nbuffers_;
34+
return Matrix<T, D>::read_sender(LocalTileIndex{index, 0});
35+
}
36+
3237
auto readwrite_sender(SizeType index) {
3338
index %= nbuffers_;
3439
return Matrix<T, D>::readwrite_sender(LocalTileIndex{index, 0});
@@ -38,17 +43,32 @@ struct ExtraBuffers : protected Matrix<T, D> {
3843
[[nodiscard]] auto reduce(TileSender tile) {
3944
namespace ex = pika::execution::experimental;
4045

41-
std::vector<pika::future<matrix::Tile<T, D>>> buffers;
42-
for (const auto& ij : common::iterate_range2d(this->distribution().localNrTiles()))
43-
buffers.emplace_back(Matrix<T, D>::operator()(ij));
44-
auto all_buffers = ex::when_all_vector(std::move(buffers));
45-
46-
return ex::when_all(std::move(tile), std::move(all_buffers)) |
47-
ex::then([](const matrix::Tile<T, D>& tile, const std::vector<matrix::Tile<T, D>>& buffers) {
48-
tile::internal::set0(tile);
49-
for (auto& buffer : buffers)
50-
dlaf::tile::internal::add(T(1), buffer, tile);
51-
});
46+
std::vector<ex::any_sender<pika::shared_future<matrix::Tile<const T, D>>>> buffers;
47+
for (SizeType index = 0; index < nbuffers_; ++index)
48+
buffers.emplace_back(read_sender(index));
49+
50+
return ex::when_all(std::move(tile), ex::when_all_vector(std::move(buffers))) |
51+
dlaf::internal::transform(dlaf::internal::Policy<DefaultBackend_v<D>>(),
52+
[](const matrix::Tile<T, D>& tile,
53+
const std::vector<pika::shared_future<matrix::Tile<const T, D>>>&
54+
buffers,
55+
auto&&... ts) {
56+
for (const auto& buffer : buffers) {
57+
if constexpr (D == Device::CPU) {
58+
static_assert(sizeof...(ts) == 0,
59+
"Parameter pack should be empty for MC.");
60+
dlaf::tile::internal::add(T(1), buffer.get(), tile);
61+
}
62+
#ifdef DLAF_WITH_GPU
63+
else if constexpr (D == Device::GPU) {
64+
dlaf::tile::internal::add(T(1), buffer.get(), tile, ts...);
65+
}
66+
#endif
67+
else {
68+
DLAF_STATIC_UNIMPLEMENTED(T);
69+
}
70+
}
71+
});
5272
}
5373

5474
protected:

0 commit comments

Comments
 (0)