2525#include < utility>
2626
2727#include " common.cuh"
28+ #include " ../detail/event.cuh"
2829#include " ../detail/throw_on_cuda_error.cuh"
2930
3031STDEXEC_PRAGMA_PUSH ()
@@ -166,9 +167,17 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
166167 using Completions = completion_sigs<env_of_t <Receiver>, CvrefReceiverId>;
167168
168169 template <class Error >
169- void _set_error_impl (Error&& err, _when_all:: state_t expected ) noexcept {
170+ void _set_error_impl (Error&& err) noexcept {
170171 // TODO: What memory orderings are actually needed here?
171- if (op_state_->state_ .compare_exchange_strong (expected, _when_all::error)) {
172+ auto old_state = op_state_->__state_ .exchange (_when_all::error);
173+ // If the previous state was __error or __stopped, then we have already requested
174+ // stop on the stop source. Otherwise, request stop.
175+ if (old_state == _when_all::started) {
176+ op_state_->__stop_source_ .request_stop ();
177+ }
178+ // If we are the first child to complete with an error, we must save the error.
179+ // (Any subsequent errors are ignored.)
180+ if (old_state != _when_all::error) {
172181 op_state_->stop_source_ .request_stop ();
173182 // We won the race, free to write the error into the operation
174183 // state without worry.
@@ -187,12 +196,12 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
187196 if constexpr (sizeof ...(Values)) {
188197 _when_all::copy_kernel<Values&&...><<<1 , 1 , 0 , stream>>> (
189198 &__tup::get<Index>(*op_state_->values_ ), static_cast <Values&&>(vals)...);
199+ op_state_->statuses_ [Index] = cudaGetLastError ();
190200 }
191201
192202 if constexpr (stream_receiver<Receiver>) {
193- if (op_state_->status_ == cudaSuccess) {
194- op_state_->status_ =
195- STDEXEC_DBG_ERR (cudaEventRecord (op_state_->events_ [Index], stream));
203+ if (op_state_->statuses_ [Index] == cudaSuccess) {
204+ op_state_->statuses_ [Index] = op_state_->events_ [Index].try_record (stream);
196205 }
197206 }
198207 }
@@ -203,7 +212,7 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
203212 template <class Error >
204213 requires tag_invocable<set_error_t , Receiver, Error>
205214 void set_error (Error&& err) && noexcept {
206- _set_error_impl (static_cast <Error&&>(err), _when_all::started );
215+ _set_error_impl (static_cast <Error&&>(err));
207216 }
208217
209218 void set_stopped () && noexcept {
@@ -238,8 +247,6 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
238247 using Env = env_of_t <Receiver>;
239248 using Completions = completion_sigs<Env, CvrefReceiverId>;
240249
241- cudaError_t status_{cudaSuccess};
242-
243250 template <class SenderId , std::size_t Index>
244251 using child_op_state_t = exit_operation_state_t <
245252 __copy_cvref_t <WhenAll, stdexec::__t <SenderId>>,
@@ -285,6 +292,14 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
285292 // Stop callback is no longer needed. Destroy it.
286293 on_stop_.reset ();
287294
295+ // See if any child operations completed with an error status:
296+ for (auto status: statuses_) {
297+ if (status != cudaSuccess) {
298+ status_ = status;
299+ break ;
300+ }
301+ }
302+
288303 // Synchronize streams
289304 if (status_ == cudaSuccess) {
290305 if constexpr (stream_receiver<Receiver>) {
@@ -294,7 +309,7 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
294309
295310 for (int i = 0 ; i < sizeof ...(SenderIds); i++) {
296311 if (status_ == cudaSuccess) {
297- status_ = STDEXEC_DBG_ERR ( cudaStreamWaitEvent (stream, events_[i], 0 ) );
312+ status_ = events_[i]. try_wait (stream );
298313 }
299314 }
300315 } else {
@@ -354,19 +369,10 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
354369 , child_states_{
355370 operation_t::connect_children_ (this , static_cast <WhenAll&&>(when_all), Indices{})} {
356371 status_ = STDEXEC_DBG_ERR (cudaMallocManaged (&values_, sizeof (child_values_tuple_t )));
357- for (std::size_t i = 0 ; i < sizeof ...(SenderIds); ++i) {
358- if (status_ == cudaSuccess) {
359- status_ = STDEXEC_DBG_ERR (cudaEventCreate (&events_[i], cudaEventDisableTiming));
360- }
361- }
362372 }
363373
364374 ~operation_t () {
365375 STDEXEC_DBG_ERR (cudaFree (values_));
366-
367- for (int i = 0 ; i < sizeof ...(SenderIds); i++) {
368- STDEXEC_DBG_ERR (cudaEventDestroy (events_[i]));
369- }
370376 }
371377
372378 STDEXEC_IMMOVABLE (operation_t );
@@ -388,7 +394,7 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
388394 }
389395 }
390396
391- // tuple<optional< tuple<Vs1...>>, optional< tuple<Vs2...> >, ...>
397+ // tuple<tuple<Vs1...>, tuple<Vs2...>, ...>
392398 using child_values_tuple_t = //
393399 __if<
394400 sends_values<Completions>,
@@ -408,9 +414,11 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
408414 __uniqued_variant_for>;
409415
410416 Receiver rcvr_;
417+ cudaError_t status_{cudaSuccess};
411418 std::atomic<std::size_t > count_{sizeof ...(SenderIds)};
412419 std::array<stream_provider_t , sizeof ...(SenderIds)> stream_providers_;
413- std::array<cudaEvent_t, sizeof ...(SenderIds)> events_;
420+ std::array<detail::cuda_event, sizeof ...(SenderIds)> events_{};
421+ std::array<cudaError_t, sizeof ...(SenderIds)> statuses_{}; // all initialized to cudaSuccess
414422 child_op_states_tuple_t child_states_;
415423 // Could be non-atomic here and atomic_ref everywhere except __completion_fn
416424 std::atomic<_when_all::state_t > state_{_when_all::started};
0 commit comments