Commit 157c320

Merge branch 'main' into syssched-R6-work

2 parents 1ca1f3f + dc8f168

21 files changed: 555 additions & 588 deletions

include/nvexec/multi_gpu_context.cuh

Lines changed: 79 additions & 176 deletions
@@ -20,218 +20,119 @@
 
 #include "../stdexec/execution.hpp"
 
-#include <concepts>
-#include <utility>
-
 #include "stream_context.cuh"
 
 STDEXEC_PRAGMA_PUSH()
 STDEXEC_PRAGMA_IGNORE_EDG(cuda_compile)
 
 namespace nvexec {
   namespace _strm {
-    template <sender Sender, std::integral Shape, class Fun>
-    using multi_gpu_bulk_sender_th =
-      stdexec::__t<multi_gpu_bulk_sender_t<stdexec::__id<__decay_t<Sender>>, Shape, Fun>>;
-
-    struct multi_gpu_stream_scheduler {
+    struct multi_gpu_stream_scheduler : private stream_scheduler_env {
       using __t = multi_gpu_stream_scheduler;
       using __id = multi_gpu_stream_scheduler;
-      friend stream_context;
 
-      template <sender Sender>
-      using schedule_from_sender_th =
-        stdexec::__t<schedule_from_sender_t<stream_scheduler, stdexec::__id<__decay_t<Sender>>>>;
+      multi_gpu_stream_scheduler(int num_devices, context_state_t context_state)
+        : num_devices_(num_devices)
+        , context_state_(context_state) {
+      }
 
-      template <class RId>
-      struct operation_state_t : stream_op_state_base {
-        using R = stdexec::__t<RId>;
+      auto operator==(const multi_gpu_stream_scheduler& other) const noexcept -> bool {
+        return context_state_.hub_ == other.context_state_.hub_;
+      }
 
-        R rec_;
-        cudaStream_t stream_{nullptr};
-        cudaError_t status_{cudaSuccess};
+      [[nodiscard]]
+      STDEXEC_ATTRIBUTE((host, device)) auto schedule() const noexcept {
+        return sender_t{num_devices_, context_state_};
+      }
+
+      using stream_scheduler_env::query;
 
-        template <__decays_to<R> Receiver>
-        operation_state_t(Receiver&& rec)
-          : rec_(static_cast<Receiver&&>(rec)) {
+     private:
+      template <class ReceiverId>
+      struct operation_state_t : stream_op_state_base {
+        using Receiver = stdexec::__t<ReceiverId>;
+
+        explicit operation_state_t(Receiver rcvr)
+          : rcvr_(static_cast<Receiver&&>(rcvr)) {
           status_ = STDEXEC_DBG_ERR(cudaStreamCreate(&stream_));
         }
 
         ~operation_state_t() {
           STDEXEC_DBG_ERR(cudaStreamDestroy(stream_));
         }
 
+        [[nodiscard]]
        auto get_stream() -> cudaStream_t {
          return stream_;
        }
 
        void start() & noexcept {
-          if constexpr (stream_receiver<R>) {
+          if constexpr (stream_receiver<Receiver>) {
            if (status_ == cudaSuccess) {
-              stdexec::set_value(static_cast<R&&>(rec_));
+              stdexec::set_value(static_cast<Receiver&&>(rcvr_));
            } else {
-              stdexec::set_error(static_cast<R&&>(rec_), std::move(status_));
+              stdexec::set_error(static_cast<Receiver&&>(rcvr_), std::move(status_));
            }
          } else {
            if (status_ == cudaSuccess) {
-              continuation_kernel<<<1, 1, 0, stream_>>>(std::move(rec_), stdexec::set_value);
+              continuation_kernel<<<1, 1, 0, stream_>>>(std::move(rcvr_), stdexec::set_value);
            } else {
              continuation_kernel<<<1, 1, 0, stream_>>>(
-                std::move(rec_), stdexec::set_error, std::move(status_));
+                std::move(rcvr_), stdexec::set_error, std::move(status_));
            }
          }
        }
-      };
 
-    struct sender_t : stream_sender_base {
+       private:
+        friend stream_context;
 
-      struct env {
-        int num_devices_;
-        context_state_t context_state_;
+        Receiver rcvr_;
+        cudaStream_t stream_{};
+        cudaError_t status_{cudaSuccess};
+      };
 
-        template <class CPO>
-        auto query(get_completion_scheduler_t<CPO>) const noexcept -> multi_gpu_stream_scheduler {
-          return multi_gpu_stream_scheduler{num_devices_, context_state_};
-        }
-      };
+      struct sender_t : stream_sender_base {
+        using __t = sender_t;
+        using __id = sender_t;
 
        using completion_signatures =
-          completion_signatures<set_value_t(), set_error_t(cudaError_t)>;
+          stdexec::completion_signatures<set_value_t(), set_error_t(cudaError_t)>;
 
-        template <class R>
-        auto connect(R rec) const & noexcept(__nothrow_move_constructible<R>) //
-          -> operation_state_t<stdexec::__id<__decay_t<R>>> {
-          return operation_state_t<stdexec::__id<__decay_t<R>>>(static_cast<R&&>(rec));
+        STDEXEC_ATTRIBUTE((host, device)) explicit sender_t(int num_devices, context_state_t context_state) noexcept
+          : env_{.num_devices_ = num_devices, .context_state_ = context_state} {
        }
 
+        template <class Receiver>
        [[nodiscard]]
-        auto get_env() const noexcept -> const env& {
-          return env_;
+        auto connect(Receiver rcvr) const & noexcept(__nothrow_move_constructible<Receiver>) //
+          -> operation_state_t<stdexec::__id<Receiver>> {
+          return operation_state_t<stdexec::__id<Receiver>>(static_cast<Receiver&&>(rcvr));
        }
 
-        sender_t(int num_devices, context_state_t context_state) noexcept
-          : env_{.num_devices_ = num_devices, .context_state_ = context_state} {
+        [[nodiscard]]
+        auto get_env() const noexcept -> decltype(auto) {
+          return (env_);
        }
 
-        env env_;
-      };
-
-      template <sender S>
-      STDEXEC_MEMFN_DECL(schedule_from_sender_th<S> schedule_from)(
-        this const multi_gpu_stream_scheduler& sch,
-        S&& sndr) //
-        noexcept {
-        return schedule_from_sender_th<S>(sch.context_state_, static_cast<S&&>(sndr));
-      }
-
-      template <sender S, std::integral Shape, class Fn>
-      STDEXEC_MEMFN_DECL(multi_gpu_bulk_sender_th<S, Shape, Fn> bulk)(
-        this const multi_gpu_stream_scheduler& sch, //
-        S&& sndr,                                   //
-        Shape shape,                                //
-        Fn fun)                                     //
-        noexcept {
-        return multi_gpu_bulk_sender_th<S, Shape, Fn>{
-          {}, sch.num_devices_, static_cast<S&&>(sndr), shape, static_cast<Fn&&>(fun)};
-      }
-
-      template <sender S, class Fn>
-      STDEXEC_MEMFN_DECL(then_sender_th<S, Fn> then)(
-        this const multi_gpu_stream_scheduler& sch,
-        S&& sndr,
-        Fn fun) //
-        noexcept {
-        return then_sender_th<S, Fn>{{}, static_cast<S&&>(sndr), static_cast<Fn&&>(fun)};
-      }
-
-      template <__one_of<let_value_t, let_stopped_t, let_error_t> Let, sender S, class Fn>
-      friend auto tag_invoke(Let, const multi_gpu_stream_scheduler& sch, S&& sndr, Fn fun) noexcept
-        -> let_xxx_th<Let, S, Fn> {
-        return let_xxx_th<Let, S, Fn>{{}, static_cast<S&&>(sndr), static_cast<Fn&&>(fun)};
-      }
-
-      template <sender S, class Fn>
-      STDEXEC_MEMFN_DECL(upon_error_sender_th<S, Fn> upon_error)(
-        this const multi_gpu_stream_scheduler& sch,
-        S&& sndr,
-        Fn fun) noexcept {
-        return upon_error_sender_th<S, Fn>{{}, static_cast<S&&>(sndr), static_cast<Fn&&>(fun)};
-      }
-
-      template <sender S, class Fn>
-      STDEXEC_MEMFN_DECL(upon_stopped_sender_th<S, Fn> upon_stopped)(
-        this const multi_gpu_stream_scheduler& sch,
-        S&& sndr,
-        Fn fun) noexcept {
-        return upon_stopped_sender_th<S, Fn>{{}, static_cast<S&&>(sndr), static_cast<Fn&&>(fun)};
-      }
-
-      template <stream_completing_sender... Senders>
-      STDEXEC_MEMFN_DECL(auto transfer_when_all)(
-        this const multi_gpu_stream_scheduler& sch, //
-        Senders&&... sndrs) noexcept {
-        return transfer_when_all_sender_th<multi_gpu_stream_scheduler, Senders...>(
-          sch.context_state_, static_cast<Senders&&>(sndrs)...);
-      }
-
-      template <stream_completing_sender... Senders>
-      STDEXEC_MEMFN_DECL(auto transfer_when_all_with_variant)(
-        this const multi_gpu_stream_scheduler& sch, //
-        Senders&&... sndrs) noexcept {
-        return transfer_when_all_sender_th<
-          multi_gpu_stream_scheduler,
-          __result_of<into_variant, Senders>...>(
-          sch.context_state_, into_variant(static_cast<Senders&&>(sndrs))...);
-      }
-
-      template <sender S, scheduler Sch>
-      STDEXEC_MEMFN_DECL(auto continues_on)(
-        this const multi_gpu_stream_scheduler& sch, //
-        S&& sndr,                                   //
-        Sch&& scheduler) noexcept {
-        return schedule_from(
-          static_cast<Sch&&>(scheduler),
-          continues_on_sender_th<S>(sch.context_state_, static_cast<S&&>(sndr)));
-      }
-
-      template <sender S>
-      STDEXEC_MEMFN_DECL(
-        split_sender_th<S> split)(this const multi_gpu_stream_scheduler& sch, S&& sndr) noexcept {
-        return split_sender_th<S>(static_cast<S&&>(sndr), sch.context_state_);
-      }
-
-      template <sender S>
-      STDEXEC_MEMFN_DECL(ensure_started_th<S> ensure_started)(
-        this const multi_gpu_stream_scheduler& sch,
-        S&& sndr) //
-        noexcept {
-        return ensure_started_th<S>(static_cast<S&&>(sndr), sch.context_state_);
-      }
-
-      [[nodiscard]]
-      auto schedule() const noexcept -> sender_t {
-        return {num_devices_, context_state_};
-      }
-
-      template <sender S>
-      STDEXEC_MEMFN_DECL(auto sync_wait)(this const multi_gpu_stream_scheduler& self, S&& sndr) {
-        return _sync_wait::sync_wait_t{}(self.context_state_, static_cast<S&&>(sndr));
-      }
+       private:
+        struct env {
+          using __t = env;
+          using __id = env;
 
-      [[nodiscard]]
-      auto query(get_forward_progress_guarantee_t) const noexcept -> forward_progress_guarantee {
-        return forward_progress_guarantee::weakly_parallel;
-      }
+          int num_devices_;
+          context_state_t context_state_;
 
-      auto operator==(const multi_gpu_stream_scheduler& other) const noexcept -> bool {
-        return context_state_.hub_ == other.context_state_.hub_;
-      }
+          template <class CPO>
+          [[nodiscard]]
+          auto query(get_completion_scheduler_t<CPO>) const noexcept -> multi_gpu_stream_scheduler {
+            return multi_gpu_stream_scheduler{num_devices_, context_state_};
+          }
+        };
 
-      multi_gpu_stream_scheduler(int num_devices, context_state_t context_state)
-        : num_devices_(num_devices)
-        , context_state_(context_state) {
-      }
+        env env_;
+      };
 
+     public:
      // private: TODO
      int num_devices_{};
      context_state_t context_state_;
@@ -241,23 +142,8 @@ namespace nvexec {
  using _strm::multi_gpu_stream_scheduler;
 
  struct multi_gpu_stream_context {
-    int num_devices_{};
-
-    _strm::resource_storage<_strm::pinned_resource> pinned_resource_{};
-    _strm::resource_storage<_strm::managed_resource> managed_resource_{};
-    _strm::stream_pools_t stream_pools_{};
-
-    int dev_id_{};
-    _strm::queue::task_hub_t hub_;
-
-    static auto get_device() -> int {
-      int dev_id{};
-      cudaGetDevice(&dev_id);
-      return dev_id;
-    }
-
    multi_gpu_stream_context()
-      : dev_id_(get_device())
+      : dev_id_(_get_device())
      , hub_(dev_id_, pinned_resource_.get()) {
      // TODO Manage errors
      cudaGetDeviceCount(&num_devices_);
@@ -278,13 +164,30 @@
      cudaSetDevice(dev_id_);
    }
 
+    [[nodiscard]]
    auto get_scheduler(stream_priority priority = stream_priority::normal)
      -> multi_gpu_stream_scheduler {
      return {
        num_devices_,
        _strm::context_state_t(
          pinned_resource_.get(), managed_resource_.get(), &stream_pools_, &hub_, priority)};
    }
+
+   private:
+    static auto _get_device() -> int {
+      int dev_id{};
+      cudaGetDevice(&dev_id);
+      return dev_id;
+    }
+
+    int num_devices_{};
+
+    _strm::resource_storage<_strm::pinned_resource> pinned_resource_{};
+    _strm::resource_storage<_strm::managed_resource> managed_resource_{};
+    _strm::stream_pools_t stream_pools_{};
+
+    int dev_id_{};
+    _strm::queue::task_hub_t hub_;
  };
} // namespace nvexec
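
For orientation, a minimal usage sketch of the API this file defines after the change. It is not part of the commit: the headers, the nvc++/CUDA toolchain, and the then/sync_wait pipeline are assumptions about typical nvexec usage; only multi_gpu_stream_context, get_scheduler(), and schedule() come from the diff above.

// Hypothetical sketch (not from the commit): obtain the multi-GPU scheduler
// and run a trivial sender chain on it.
#include <nvexec/multi_gpu_context.cuh>
#include <stdexec/execution.hpp>
#include <cstdio>

int main() {
  nvexec::multi_gpu_stream_context ctx; // ctor queries the device count (see diff)
  auto sched = ctx.get_scheduler();     // multi_gpu_stream_scheduler

  // schedule() returns sender_t; connecting it creates a CUDA stream and
  // the resulting operation state completes on that stream (start() above).
  auto work = stdexec::schedule(sched)
            | stdexec::then([] { printf("hello from the GPU stream\n"); });

  stdexec::sync_wait(std::move(work));  // block until the stream work finishes
}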

include/nvexec/stream/bulk.cuh

Lines changed: 21 additions & 0 deletions
@@ -382,6 +382,27 @@ namespace nvexec::_strm {
      }
    };
  };
+
+  template <>
+  struct transform_sender_for<stdexec::bulk_t> {
+    template <class Data, stream_completing_sender Sender>
+    auto operator()(__ignore, Data data, Sender&& sndr) const {
+      auto [shape, fun] = static_cast<Data&&>(data);
+      using Shape = decltype(shape);
+      using Fn = decltype(fun);
+      auto sched = get_completion_scheduler<set_value_t>(get_env(sndr));
+      if constexpr (same_as<decltype(sched), stream_scheduler>) {
+        // Use the bulk sender for a single GPU
+        using _sender_t = __t<bulk_sender_t<__id<__decay_t<Sender>>, Shape, Fn>>;
+        return _sender_t{{}, static_cast<Sender&&>(sndr), shape, static_cast<Fn&&>(fun)};
+      } else {
+        // Use the bulk sender for multiple GPUs
+        using _sender_t = __t<multi_gpu_bulk_sender_t<__id<__decay_t<Sender>>, Shape, Fn>>;
+        return _sender_t{
+          {}, sched.num_devices_, static_cast<Sender&&>(sndr), shape, static_cast<Fn&&>(fun)};
+      }
+    }
+  };
 } // namespace nvexec::_strm
 
 namespace stdexec::__detail {
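
The practical effect of this specialization: a plain stdexec::bulk is now lowered during sender transformation, picking the single-GPU bulk sender when the child completes on a stream_scheduler and the multi-GPU bulk sender (which spreads the shape across devices) otherwise. A hedged caller-side sketch follows; apart from the two scheduler types and stdexec::bulk from this diff, the details are illustrative assumptions.

// Hypothetical sketch (not from the commit): the same bulk expression
// dispatches to one or many GPUs depending on the originating scheduler.
#include <nvexec/multi_gpu_context.cuh>
#include <nvexec/stream_context.cuh>
#include <stdexec/execution.hpp>

int main() {
  nvexec::stream_context single;           // completion scheduler: stream_scheduler
  nvexec::multi_gpu_stream_context multi;  // completion scheduler: multi_gpu_stream_scheduler

  auto one_gpu = stdexec::schedule(single.get_scheduler())
               | stdexec::bulk(1024, [](int i) { /* runs on a single device */ });

  auto all_gpus = stdexec::schedule(multi.get_scheduler())
                | stdexec::bulk(1024, [](int i) { /* shape split across devices */ });

  stdexec::sync_wait(std::move(one_gpu));
  stdexec::sync_wait(std::move(all_gpus));
}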
