#include "../stdexec/execution.hpp"

#include <concepts>
#include <utility>

#include "stream_context.cuh"

STDEXEC_PRAGMA_PUSH()
STDEXEC_PRAGMA_IGNORE_EDG(cuda_compile)
3128namespace nvexec {
3229 namespace _strm {
33- template <sender Sender, std::integral Shape, class Fun >
34- using multi_gpu_bulk_sender_th =
35- stdexec::__t <multi_gpu_bulk_sender_t <stdexec::__id<__decay_t <Sender>>, Shape, Fun>>;
36-
37- struct multi_gpu_stream_scheduler {
30+ struct multi_gpu_stream_scheduler : private stream_scheduler_env {
3831 using __t = multi_gpu_stream_scheduler;
3932 using __id = multi_gpu_stream_scheduler;
40- friend stream_context;
4133
42- template <sender Sender>
43- using schedule_from_sender_th =
44- stdexec::__t <schedule_from_sender_t <stream_scheduler, stdexec::__id<__decay_t <Sender>>>>;
34+ multi_gpu_stream_scheduler (int num_devices, context_state_t context_state)
35+ : num_devices_(num_devices)
36+ , context_state_(context_state) {
37+ }
4538
46- template < class RId >
47- struct operation_state_t : stream_op_state_base {
48- using R = stdexec:: __t <RId>;
39+ auto operator ==( const multi_gpu_stream_scheduler& other) const noexcept -> bool {
40+ return context_state_. hub_ == other. context_state_ . hub_ ;
41+ }
4942
50- R rec_;
51- cudaStream_t stream_{nullptr };
52- cudaError_t status_{cudaSuccess};
43+ [[nodiscard]]
44+ STDEXEC_ATTRIBUTE ((host, device)) auto schedule () const noexcept {
45+ return sender_t {num_devices_, context_state_};
46+ }
47+
48+ using stream_scheduler_env::query;
5349
54- template <__decays_to<R> Receiver>
55- operation_state_t (Receiver&& rec)
56- : rec_(static_cast <Receiver&&>(rec)) {
50+ private:
51+ template <class ReceiverId >
52+ struct operation_state_t : stream_op_state_base {
53+ using Receiver = stdexec::__t <ReceiverId>;
54+
55+ explicit operation_state_t (Receiver rcvr)
56+ : rcvr_(static_cast <Receiver&&>(rcvr)) {
5757 status_ = STDEXEC_DBG_ERR (cudaStreamCreate (&stream_));
5858 }
5959
6060 ~operation_state_t () {
6161 STDEXEC_DBG_ERR (cudaStreamDestroy (stream_));
6262 }
6363
64+ [[nodiscard]]
6465 auto get_stream () -> cudaStream_t {
6566 return stream_;
6667 }
6768
6869 void start () & noexcept {
69- if constexpr (stream_receiver<R >) {
70+ if constexpr (stream_receiver<Receiver >) {
7071 if (status_ == cudaSuccess) {
71- stdexec::set_value (static_cast <R &&>(rec_ ));
72+ stdexec::set_value (static_cast <Receiver &&>(rcvr_ ));
7273 } else {
73- stdexec::set_error (static_cast <R &&>(rec_ ), std::move (status_));
74+ stdexec::set_error (static_cast <Receiver &&>(rcvr_ ), std::move (status_));
7475 }
7576 } else {
7677 if (status_ == cudaSuccess) {
77- continuation_kernel<<<1 , 1 , 0 , stream_>>> (std::move (rec_ ), stdexec::set_value);
78+ continuation_kernel<<<1 , 1 , 0 , stream_>>> (std::move (rcvr_ ), stdexec::set_value);
7879 } else {
7980 continuation_kernel<<<1 , 1 , 0 , stream_>>> (
80- std::move (rec_ ), stdexec::set_error, std::move (status_));
81+ std::move (rcvr_ ), stdexec::set_error, std::move (status_));
8182 }
8283 }
8384 }
84- };
8585
86- struct sender_t : stream_sender_base {
86+ private:
87+ friend stream_context;
8788
88- struct env {
89- int num_devices_;
90- context_state_t context_state_;
89+ Receiver rcvr_;
90+ cudaStream_t stream_{};
91+ cudaError_t status_{cudaSuccess};
92+ };
9193
92- template <class CPO >
93- auto query (get_completion_scheduler_t <CPO>) const noexcept -> multi_gpu_stream_scheduler {
94- return multi_gpu_stream_scheduler{num_devices_, context_state_};
95- }
96- };
94+ struct sender_t : stream_sender_base {
95+ using __t = sender_t ;
96+ using __id = sender_t ;
9797
9898 using completion_signatures =
99- completion_signatures<set_value_t (), set_error_t (cudaError_t)>;
99+ stdexec:: completion_signatures<set_value_t (), set_error_t (cudaError_t)>;
100100
101- template <class R >
102- auto connect (R rec) const & noexcept (__nothrow_move_constructible<R>) //
103- -> operation_state_t<stdexec::__id<__decay_t<R>>> {
104- return operation_state_t <stdexec::__id<__decay_t <R>>>(static_cast <R&&>(rec));
101+ STDEXEC_ATTRIBUTE ((host, device)) explicit sender_t (int num_devices, context_state_t context_state) noexcept
102+ : env_{.num_devices_ = num_devices, .context_state_ = context_state} {
105103 }
106104
105+ template <class Receiver >
107106 [[nodiscard]]
108- auto get_env () const noexcept -> const env& {
109- return env_;
107+ auto connect (Receiver rcvr) const & noexcept (__nothrow_move_constructible<Receiver>) //
108+ -> operation_state_t<stdexec::__id<Receiver>> {
109+ return operation_state_t <stdexec::__id<Receiver>>(static_cast <Receiver&&>(rcvr));
110110 }
111111
112- sender_t (int num_devices, context_state_t context_state) noexcept
113- : env_{.num_devices_ = num_devices, .context_state_ = context_state} {
112+ [[nodiscard]]
113+ auto get_env () const noexcept -> decltype(auto ) {
114+ return (env_);
114115 }
115116
116- env env_;
117- };
118-
119- template <sender S>
120- STDEXEC_MEMFN_DECL (schedule_from_sender_th<S> schedule_from)(
121- this const multi_gpu_stream_scheduler& sch,
122- S&& sndr) //
123- noexcept {
124- return schedule_from_sender_th<S>(sch.context_state_ , static_cast <S&&>(sndr));
125- }
126-
127- template <sender S, std::integral Shape, class Fn >
128- STDEXEC_MEMFN_DECL (multi_gpu_bulk_sender_th<S, Shape, Fn> bulk)(
129- this const multi_gpu_stream_scheduler& sch, //
130- S&& sndr, //
131- Shape shape, //
132- Fn fun) //
133- noexcept {
134- return multi_gpu_bulk_sender_th<S, Shape, Fn>{
135- {}, sch.num_devices_ , static_cast <S&&>(sndr), shape, static_cast <Fn&&>(fun)};
136- }
137-
138- template <sender S, class Fn >
139- STDEXEC_MEMFN_DECL (then_sender_th<S, Fn> then)(
140- this const multi_gpu_stream_scheduler& sch,
141- S&& sndr,
142- Fn fun) //
143- noexcept {
144- return then_sender_th<S, Fn>{{}, static_cast <S&&>(sndr), static_cast <Fn&&>(fun)};
145- }
146-
147- template <__one_of<let_value_t , let_stopped_t , let_error_t > Let, sender S, class Fn >
148- friend auto tag_invoke (Let, const multi_gpu_stream_scheduler& sch, S&& sndr, Fn fun) noexcept
149- -> let_xxx_th<Let, S, Fn> {
150- return let_xxx_th<Let, S, Fn>{{}, static_cast <S&&>(sndr), static_cast <Fn&&>(fun)};
151- }
152-
153- template <sender S, class Fn >
154- STDEXEC_MEMFN_DECL (upon_error_sender_th<S, Fn> upon_error)(
155- this const multi_gpu_stream_scheduler& sch,
156- S&& sndr,
157- Fn fun) noexcept {
158- return upon_error_sender_th<S, Fn>{{}, static_cast <S&&>(sndr), static_cast <Fn&&>(fun)};
159- }
160-
161- template <sender S, class Fn >
162- STDEXEC_MEMFN_DECL (upon_stopped_sender_th<S, Fn> upon_stopped)(
163- this const multi_gpu_stream_scheduler& sch,
164- S&& sndr,
165- Fn fun) noexcept {
166- return upon_stopped_sender_th<S, Fn>{{}, static_cast <S&&>(sndr), static_cast <Fn&&>(fun)};
167- }
168-
169- template <stream_completing_sender... Senders>
170- STDEXEC_MEMFN_DECL (auto transfer_when_all)(
171- this const multi_gpu_stream_scheduler& sch, //
172- Senders&&... sndrs) noexcept {
173- return transfer_when_all_sender_th<multi_gpu_stream_scheduler, Senders...>(
174- sch.context_state_ , static_cast <Senders&&>(sndrs)...);
175- }
176-
177- template <stream_completing_sender... Senders>
178- STDEXEC_MEMFN_DECL (auto transfer_when_all_with_variant)(
179- this const multi_gpu_stream_scheduler& sch, //
180- Senders&&... sndrs) noexcept {
181- return transfer_when_all_sender_th<
182- multi_gpu_stream_scheduler,
183- __result_of<into_variant, Senders>...>(
184- sch.context_state_ , into_variant (static_cast <Senders&&>(sndrs))...);
185- }
186-
187- template <sender S, scheduler Sch>
188- STDEXEC_MEMFN_DECL (auto continues_on)(
189- this const multi_gpu_stream_scheduler& sch, //
190- S&& sndr, //
191- Sch&& scheduler) noexcept {
192- return schedule_from (
193- static_cast <Sch&&>(scheduler),
194- continues_on_sender_th<S>(sch.context_state_ , static_cast <S&&>(sndr)));
195- }
196-
197- template <sender S>
198- STDEXEC_MEMFN_DECL (
199- split_sender_th<S> split)(this const multi_gpu_stream_scheduler& sch, S&& sndr) noexcept {
200- return split_sender_th<S>(static_cast <S&&>(sndr), sch.context_state_ );
201- }
202-
203- template <sender S>
204- STDEXEC_MEMFN_DECL (ensure_started_th<S> ensure_started)(
205- this const multi_gpu_stream_scheduler& sch,
206- S&& sndr) //
207- noexcept {
208- return ensure_started_th<S>(static_cast <S&&>(sndr), sch.context_state_ );
209- }
210-
211- [[nodiscard]]
212- auto schedule () const noexcept -> sender_t {
213- return {num_devices_, context_state_};
214- }
215-
216- template <sender S>
217- STDEXEC_MEMFN_DECL (auto sync_wait)(this const multi_gpu_stream_scheduler& self, S&& sndr) {
218- return _sync_wait::sync_wait_t {}(self.context_state_ , static_cast <S&&>(sndr));
219- }
117+ private:
118+ struct env {
119+ using __t = env;
120+ using __id = env;
220121
221- [[nodiscard]]
222- auto query (get_forward_progress_guarantee_t ) const noexcept -> forward_progress_guarantee {
223- return forward_progress_guarantee::weakly_parallel;
224- }
122+ int num_devices_;
123+ context_state_t context_state_;
225124
226- auto operator ==(const multi_gpu_stream_scheduler& other) const noexcept -> bool {
227- return context_state_.hub_ == other.context_state_ .hub_ ;
228- }
125+ template <class CPO >
126+ [[nodiscard]]
127+ auto query (get_completion_scheduler_t <CPO>) const noexcept -> multi_gpu_stream_scheduler {
128+ return multi_gpu_stream_scheduler{num_devices_, context_state_};
129+ }
130+ };
229131
230- multi_gpu_stream_scheduler (int num_devices, context_state_t context_state)
231- : num_devices_(num_devices)
232- , context_state_(context_state) {
233- }
132+ env env_;
133+ };
234134
135+ public:
235136 // private: TODO
236137 int num_devices_{};
237138 context_state_t context_state_;
@@ -241,23 +142,8 @@ namespace nvexec {
241142 using _strm::multi_gpu_stream_scheduler;
242143
243144 struct multi_gpu_stream_context {
244- int num_devices_{};
245-
246- _strm::resource_storage<_strm::pinned_resource> pinned_resource_{};
247- _strm::resource_storage<_strm::managed_resource> managed_resource_{};
248- _strm::stream_pools_t stream_pools_{};
249-
250- int dev_id_{};
251- _strm::queue::task_hub_t hub_;
252-
253- static auto get_device () -> int {
254- int dev_id{};
255- cudaGetDevice (&dev_id);
256- return dev_id;
257- }
258-
259145 multi_gpu_stream_context ()
260- : dev_id_(get_device ())
146+ : dev_id_(_get_device ())
261147 , hub_(dev_id_, pinned_resource_.get()) {
262148 // TODO Manage errors
263149 cudaGetDeviceCount (&num_devices_);
@@ -278,13 +164,30 @@ namespace nvexec {
278164 cudaSetDevice (dev_id_);
279165 }
280166
167+ [[nodiscard]]
281168 auto get_scheduler (stream_priority priority = stream_priority::normal)
282169 -> multi_gpu_stream_scheduler {
283170 return {
284171 num_devices_,
285172 _strm::context_state_t (
286173 pinned_resource_.get (), managed_resource_.get (), &stream_pools_, &hub_, priority)};
287174 }
175+
176+ private:
177+ static auto _get_device () -> int {
178+ int dev_id{};
179+ cudaGetDevice (&dev_id);
180+ return dev_id;
181+ }
182+
183+ int num_devices_{};
184+
185+ _strm::resource_storage<_strm::pinned_resource> pinned_resource_{};
186+ _strm::resource_storage<_strm::managed_resource> managed_resource_{};
187+ _strm::stream_pools_t stream_pools_{};
188+
189+ int dev_id_{};
190+ _strm::queue::task_hub_t hub_;
288191 };
289192} // namespace nvexec
290193
0 commit comments