3333
3434STDEXEC_PRAGMA_PUSH ()
3535STDEXEC_PRAGMA_IGNORE_EDG(cuda_compile)
36+ STDEXEC_PRAGMA_IGNORE_GNU(" -Wmissing-braces" )
3637
3738namespace nvexec::_strm {
3839
@@ -83,7 +84,7 @@ namespace nvexec::_strm {
8384 template <class ... As, class TupleT >
8485 __launch_bounds__ (1 ) __global__ void copy_kernel (TupleT* tpl, As... as) {
8586 static_assert (trivially_copyable<As...>);
86- *tpl = __decayed_tuple<As...>{{ static_cast <As&&>(as)} ...};
87+ *tpl = __decayed_tuple<As...>{static_cast <As&&>(as)...};
8788 }
8889
8990 template <class ... Env, class ... Senders>
@@ -110,6 +111,14 @@ namespace nvexec::_strm {
110111 __minvoke<__mpush_back<__q<completion_signatures>>, non_values, values>,
111112 non_values>;
112113 };
114+
115+ inline constexpr auto _sync_op = []<class OpT >(OpT& op) noexcept {
116+ if constexpr (STDEXEC_IS_BASE_OF (stream_op_state_base, OpT)) {
117+ if (op.stream_provider_ .status_ == cudaSuccess) {
118+ op.stream_provider_ .status_ = STDEXEC_DBG_ERR (cudaStreamSynchronize (op.get_stream ()));
119+ }
120+ }
121+ };
113122 } // namespace _when_all
114123
115124 template <bool WithCompletionScheduler, class Scheduler , class ... SenderIds>
@@ -131,7 +140,7 @@ namespace nvexec::_strm {
131140 template <class ... Sndrs>
132141 explicit __t (context_state_t context_state, Sndrs&&... __sndrs)
133142 : env_{context_state}
134- , sndrs_{{ static_cast <Sndrs&&>(__sndrs)} ...} {
143+ , sndrs_{static_cast <Sndrs&&>(__sndrs)...} {
135144 }
136145
137146 private:
@@ -159,76 +168,31 @@ namespace nvexec::_strm {
159168 struct receiver_t {
160169 using WhenAll = __copy_cvref_t <CvrefReceiverId, stdexec::__t <when_all_sender_t >>;
161170 using Receiver = stdexec::__t <__decay_t <CvrefReceiverId>>;
171+ using SenderId = __m_at_c<Index, SenderIds...>;
172+ using Completions = completion_sigs<env_of_t <Receiver>, CvrefReceiverId>;
162173 using Env = //
163174 make_terminal_stream_env_t <
164175 exec::make_env_t <env_of_t <Receiver>, stdexec::prop<get_stop_token_t , inplace_stop_token>>>;
165176
166177 struct __t : stream_receiver_base {
167178 using receiver_concept = stdexec::receiver_t ;
168179 using __id = receiver_t ;
169- using SenderId = nvexec::detail::nth_type<Index, SenderIds...>;
170- using Completions = completion_sigs<env_of_t <Receiver>, CvrefReceiverId>;
171-
172- template <class Error >
173- void _set_error_impl (Error&& err) noexcept {
174- // Transition to the "error" state and switch on the prior state.
175- // TODO: What memory orderings are actually needed here?
176- switch (op_state_->state_ .exchange (_when_all::error)) {
177- case _when_all::started:
178- // We must request stop. When the previous state is "error" or "stopped", then stop
179- // has already been requested.
180- op_state_->stop_source_ .request_stop ();
181- [[fallthrough]];
182- case _when_all::stopped:
183- // We are the first child to complete with an error, so we must save the error.
184- // (Any subsequent errors are ignored.)
185- static_assert (__nothrow_constructible_from<__decay_t <Error>, Error>);
186- op_state_->errors_ .template emplace <__decay_t <Error>>(static_cast <Error&&>(err));
187- break ;
188- case _when_all::error:; // We're already in the "error" state. Ignore the error.
189- }
190- }
191180
192181 template <class ... Values>
193- void set_value (Values&&... vals) && noexcept {
194- if constexpr (__v<sends_values<Completions>>) {
195- // We only need to bother recording the completion values
196- // if we're not already in the "error" or "stopped" state.
197- if (op_state_->state_ == _when_all::started) {
198- cudaStream_t stream = __tup::get<Index>(op_state_->child_states_ ).get_stream ();
199- if constexpr (sizeof ...(Values)) {
200- _when_all::copy_kernel<Values&&...><<<1 , 1 , 0 , stream>>> (
201- &__tup::get<Index>(*op_state_->values_ ), static_cast <Values&&>(vals)...);
202- op_state_->statuses_ [Index] = cudaGetLastError ();
203- }
204-
205- if constexpr (stream_receiver<Receiver>) {
206- if (op_state_->statuses_ [Index] == cudaSuccess) {
207- op_state_->statuses_ [Index] = op_state_->events_ [Index].try_record (stream);
208- }
209- }
210- }
211- }
212- op_state_->arrive ();
182+ STDEXEC_ATTRIBUTE ((always_inline)) void set_value (Values&&... vals) && noexcept {
183+ op_state_->template _set_value <Index>(static_cast <Values&&>(vals)...);
213184 }
214185
215186 template <class Error >
216- requires tag_invocable<set_error_t , Receiver, Error>
217- void set_error (Error&& err) && noexcept {
218- _set_error_impl (static_cast <Error&&>(err));
219- op_state_->arrive ();
187+ STDEXEC_ATTRIBUTE ((always_inline)) void set_error (Error&& err) && noexcept {
188+ op_state_->_set_error (static_cast <Error&&>(err));
220189 }
221190
222- void set_stopped () && noexcept {
223- _when_all::state_t expected = _when_all::started;
224- // Transition to the "stopped" state if and only if we're in the "started" state. (If
225- // this fails, it's because we're in the "error" state, which trumps cancellation.)
226- if (op_state_->state_ .compare_exchange_strong (expected, _when_all::stopped)) {
227- op_state_->stop_source_ .request_stop ();
228- }
229- op_state_->arrive ();
191+ STDEXEC_ATTRIBUTE ((always_inline)) void set_stopped () && noexcept {
192+ op_state_->_set_stopped ();
230193 }
231194
195+ [[nodiscard]]
232196 auto get_env () const noexcept -> Env {
233197 auto env = make_terminal_stream_env (
234198 exec::make_env (
@@ -264,11 +228,11 @@ namespace nvexec::_strm {
264228 using __child_ops_t = __tuple_for<child_op_state_t <SenderIds, Is>...>;
265229 return when_all.sndrs_ .apply (
266230 [parent_op]<class ... Children>(Children&&... children) -> __child_ops_t {
267- return __child_ops_t {{ _strm::exit_op_state (
231+ return __child_ops_t {_strm::exit_op_state (
268232 static_cast <Children&&>(children),
269233 stdexec::__t <receiver_t <CvrefReceiverId, Is>>{{}, parent_op},
270234 stdexec::get_completion_scheduler<set_value_t >(stdexec::get_env (children))
271- .context_state_ )} ...};
235+ .context_state_ )...};
272236 },
273237 static_cast <WhenAll&&>(when_all).sndrs_ );
274238 }
@@ -277,20 +241,11 @@ namespace nvexec::_strm {
277241 decltype (operation_t ::connect_children_({}, __declval<WhenAll>(), Indices{}));
278242
279243 void arrive () noexcept {
280- if (0 == -- count_) {
244+ if (1 == count_. fetch_sub ( 1 ) ) {
281245 complete ();
282246 }
283247 }
284248
285- template <class OpT >
286- static void sync (OpT& op) noexcept {
287- if constexpr (STDEXEC_IS_BASE_OF (stream_op_state_base, OpT)) {
288- if (op.stream_provider_ .status_ == cudaSuccess) {
289- op.stream_provider_ .status_ = STDEXEC_DBG_ERR (cudaStreamSynchronize (op.get_stream ()));
290- }
291- }
292- }
293-
294249 void complete () noexcept {
295250 // Stop callback is no longer needed. Destroy it.
296251 on_stop_.reset ();
@@ -316,7 +271,8 @@ namespace nvexec::_strm {
316271 }
317272 }
318273 } else {
319- child_states_.apply ([](auto &... ops) { (sync (ops), ...); }, child_states_);
274+ // Synchronize the streams of all the child operations
275+ child_states_.for_each (_when_all::_sync_op, child_states_);
320276 }
321277 }
322278
@@ -327,22 +283,18 @@ namespace nvexec::_strm {
327283 if constexpr (__v<sends_values<Completions>>) {
328284 // All child operations completed successfully:
329285 values_->apply (
330- [this ]( auto & ... opt_vals ) -> void {
286+ [this ]< class ... Tuples>(Tuples&& ... value_tupls ) -> void {
331287 __tup::__cat_apply (
332- [ this ]( auto &... all_vals) -> void {
333- stdexec::set_value ( static_cast <Receiver &&>(rcvr_), std::move (all_vals )...);
288+ __mk_completion_fn (stdexec::set_value, rcvr_),
289+ static_cast <Tuples &&>(value_tupls )...);
334290 },
335- opt_vals...);
336- },
337- *values_);
291+ static_cast <child_values_tuple_t &&>(*values_));
338292 }
339293 break ;
340294 case _when_all::error:
341295 errors_.visit (
342- [this ](auto & err) noexcept {
343- stdexec::set_error (static_cast <Receiver&&>(rcvr_), std::move (err));
344- },
345- errors_);
296+ __mk_completion_fn (stdexec::set_error, rcvr_),
297+ static_cast <errors_variant_t &&>(errors_));
346298 break ;
347299 case _when_all::stopped:
348300 stdexec::set_stopped (static_cast <Receiver&&>(rcvr_));
@@ -389,24 +341,83 @@ namespace nvexec::_strm {
389341 // the child operations.
390342 stdexec::set_stopped (static_cast <Receiver&&>(rcvr_));
391343 } else {
344+ child_states_.for_each (stdexec::start, child_states_);
392345 if constexpr (sizeof ...(SenderIds) == 0 ) {
393346 complete ();
394- } else {
395- child_states_.for_each (stdexec::start, child_states_);
396347 }
397348 }
398349 }
399350
351+ template <class Error >
352+ void _set_error_impl (Error&& err) noexcept {
353+ // Transition to the "error" state and switch on the prior state.
354+ // TODO: What memory orderings are actually needed here?
355+ switch (state_.exchange (_when_all::error)) {
356+ case _when_all::started:
357+ // We must request stop. When the previous state is "error" or "stopped", then stop
358+ // has already been requested.
359+ stop_source_.request_stop ();
360+ [[fallthrough]];
361+ case _when_all::stopped:
362+ // We are the first child to complete with an error, so we must save the error.
363+ // (Any subsequent errors are ignored.)
364+ static_assert (__nothrow_constructible_from<__decay_t <Error>, Error>);
365+ errors_.template emplace <__decay_t <Error>>(static_cast <Error&&>(err));
366+ break ;
367+ case _when_all::error:; // We're already in the "error" state. Ignore the error.
368+ }
369+ }
370+
371+ template <std::size_t Index, class ... Args>
372+ void _set_value (Args&&... args) noexcept {
373+ if constexpr (__v<sends_values<Completions>>) {
374+ // We only need to bother recording the completion values
375+ // if we're not already in the "error" or "stopped" state.
376+ if (state_.load () == _when_all::started) {
377+ cudaStream_t stream = child_states_.template __get <Index>(child_states_).get_stream ();
378+ if constexpr (sizeof ...(Args)) {
379+ _when_all::copy_kernel<Args&&...><<<1 , 1 , 0 , stream>>> (
380+ &(values_->template __get <Index>(*values_)), static_cast <Args&&>(args)...);
381+ statuses_[Index] = cudaGetLastError ();
382+ }
383+
384+ if constexpr (stream_receiver<Receiver>) {
385+ if (statuses_[Index] == cudaSuccess) {
386+ statuses_[Index] = events_[Index].try_record (stream);
387+ }
388+ }
389+ }
390+ }
391+ arrive ();
392+ }
393+
394+ template <class Error >
395+ void _set_error (Error&& err) noexcept {
396+ _set_error_impl (static_cast <Error&&>(err));
397+ arrive ();
398+ }
399+
400+ void _set_stopped () noexcept {
401+ auto expected = _when_all::started;
402+ // Transition to the "stopped" state if and only if we're in the
403+ // "started" state. (If this fails, it's because we're in an
404+ // error state, which trumps cancellation.)
405+ if (state_.compare_exchange_strong (expected, _when_all::stopped)) {
406+ stop_source_.request_stop ();
407+ }
408+ arrive ();
409+ }
410+
400411 // tuple<tuple<Vs1...>, tuple<Vs2...>, ...>
401412 using child_values_tuple_t = //
402413 __if<
403414 sends_values<Completions>,
404415 __minvoke<
405- __q <__tuple_for>,
416+ __qq <__tuple_for>,
406417 __value_types_of_t <
407418 stdexec::__t <SenderIds>,
408419 _when_all::env_t <Env>,
409- __q <__decayed_tuple>,
420+ __qq <__decayed_tuple>,
410421 __msingle_or<void >>...>,
411422 __>;
412423
@@ -447,6 +458,7 @@ namespace nvexec::_strm {
447458 return {};
448459 }
449460
461+ [[nodiscard]]
450462 auto get_env () const noexcept -> const env& {
451463 return env_;
452464 }
0 commit comments