Skip to content

Commit 70e2266

Browse files
committed
fix race condition in nvexec's when_all implementation
1 parent 3390523 commit 70e2266

4 files changed

Lines changed: 129 additions & 35 deletions

File tree

include/nvexec/detail/event.cuh

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright (c) 2025 NVIDIA Corporation
3+
*
4+
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
5+
* (the "License"); you may not use this file except in compliance with
6+
* the License. You may obtain a copy of the License at
7+
*
8+
* https://llvm.org/LICENSE.txt
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
// clang-format Language: Cpp
18+
19+
#pragma once
20+
21+
#include "config.cuh"
22+
#include "cuda_fwd.cuh"
23+
#include "throw_on_cuda_error.cuh"
24+
25+
#include <utility>
26+
27+
namespace nvexec::detail {
28+
// RAII wrapper around a cudaEvent_t used for ordering work between streams.
// The event is created with cudaEventDisableTiming because it is only used
// for stream synchronization (record/wait), never for elapsed-time queries.
// Move-only: a moved-from object holds nullptr and destroys nothing.
struct cuda_event {
  // Eagerly creates the event; throws cuda_error on failure.
  cuda_event() {
    if (auto status =
          STDEXEC_DBG_ERR(::cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
        status != cudaSuccess) {
      // BUG FIX: report the API actually called (was "cudaEventCreate").
      throw cuda_error(status, "cudaEventCreateWithFlags");
    }
  }

  // Transfers ownership; the source is left empty.
  cuda_event(cuda_event&& other) noexcept
    : event_(std::exchange(other.event_, nullptr)) {
  }

  ~cuda_event() {
    if (event_ != nullptr) {
      STDEXEC_DBG_ERR(::cudaEventDestroy(event_));
    }
  }

  // Move assignment.
  // BUG FIX: the original overwrote event_ unconditionally, leaking any
  // event this object already owned. Destroy the current event first, and
  // guard against self-move so we don't destroy the event we are adopting.
  auto operator=(cuda_event&& other) noexcept -> cuda_event& {
    if (this != &other) {
      if (event_ != nullptr) {
        STDEXEC_DBG_ERR(::cudaEventDestroy(event_));
      }
      event_ = std::exchange(other.event_, nullptr);
    }
    return *this;
  }

  // Records the event on `stream`; returns the CUDA status rather than
  // throwing, so callers on completion paths can stay noexcept.
  auto try_record(cudaStream_t stream) noexcept -> cudaError_t {
    return STDEXEC_DBG_ERR(::cudaEventRecord(event_, stream));
  }

  // Makes `stream` wait (device-side) for the last recorded occurrence of
  // this event; returns the CUDA status rather than throwing.
  auto try_wait(cudaStream_t stream) noexcept -> cudaError_t {
    return STDEXEC_DBG_ERR(::cudaStreamWaitEvent(stream, event_, 0));
  }

 private:
  cudaEvent_t event_{};
};
63+
} // namespace nvexec::detail

include/nvexec/detail/throw_on_cuda_error.cuh

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,18 +18,37 @@
1818
#include "config.cuh"
1919

2020
#include <cstdio>
21+
#include <stdexcept>
2122

2223
#include <cuda_runtime_api.h>
2324

24-
namespace nvexec {
25-
namespace detail {
26-
inline cudaError_t debug_cuda_error(
27-
cudaError_t error,
28-
[[maybe_unused]] char const * file_name,
29-
[[maybe_unused]] int line) {
30-
// Clear the global CUDA error state which may have been set by the last
31-
// call. Otherwise, errors may "leak" to unrelated calls.
32-
cudaGetLastError();
25+
namespace nvexec::detail {
26+
// Exception type carrying a formatted "cudaError <code>: <message>" string.
class cuda_error : public ::std::runtime_error {
 private:
  // Fixed-size storage for the formatted message. It is passed *by value*
  // as a defaulted constructor argument so the buffer lives in the caller's
  // frame for the whole duration of the base-class constructor call.
  struct __msg_storage {
    char __buffer[256]; // NOLINT
  };

  // Formats the status/message pair into __msg_buffer and returns it so the
  // result can be forwarded straight to the runtime_error constructor.
  static auto
    __format_cuda_error(const int __status, const char* __msg, char* __msg_buffer) noexcept
      -> char* {
    // BUG FIX: derive the bound from the storage declaration instead of
    // repeating the magic number 256, so the two sizes cannot drift apart.
    ::snprintf(
      __msg_buffer, sizeof(__msg_storage::__buffer), "cudaError %d: %s", __status, __msg);
    return __msg_buffer;
  }

 public:
  // __status: the cudaError_t value (as int); __msg: a short context string.
  // __msg_buffer is an implementation detail — leave it defaulted.
  cuda_error(const int __status, const char* __msg, __msg_storage __msg_buffer = {0}) noexcept
    : ::std::runtime_error(__format_cuda_error(__status, __msg, __msg_buffer.__buffer)) {
  }
};
44+
45+
inline auto debug_cuda_error(
46+
cudaError_t error,
47+
[[maybe_unused]] char const * file_name,
48+
[[maybe_unused]] int line) -> cudaError_t {
49+
// Clear the global CUDA error state which may have been set by the last
50+
// call. Otherwise, errors may "leak" to unrelated calls.
51+
cudaGetLastError();
3352

3453
#if defined(STDEXEC_STDERR)
3554
if (error != cudaSuccess) {
@@ -43,8 +62,7 @@ namespace nvexec {
4362
#endif
4463

4564
return error;
46-
}
47-
} // namespace detail
65+
}
66+
} // namespace nvexec::detail
4867

4968
#define STDEXEC_DBG_ERR(E) ::nvexec::detail::debug_cuda_error(E, __FILE__, __LINE__) /**/
50-
} // namespace nvexec

include/nvexec/stream/when_all.cuh

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <utility>
2626

2727
#include "common.cuh"
28+
#include "../detail/event.cuh"
2829
#include "../detail/throw_on_cuda_error.cuh"
2930

3031
STDEXEC_PRAGMA_PUSH()
@@ -166,9 +167,17 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
166167
using Completions = completion_sigs<env_of_t<Receiver>, CvrefReceiverId>;
167168

168169
template <class Error>
169-
void _set_error_impl(Error&& err, _when_all::state_t expected) noexcept {
170+
void _set_error_impl(Error&& err) noexcept {
170171
// TODO: What memory orderings are actually needed here?
171-
if (op_state_->state_.compare_exchange_strong(expected, _when_all::error)) {
172+
auto old_state = op_state_->__state_.exchange(_when_all::error);
173+
// If the previous state was error or stopped, then we have already requested
174+
// stop on the stop source. Otherwise, request stop.
175+
if (old_state == _when_all::started) {
176+
op_state_->__stop_source_.request_stop();
177+
}
178+
// If we are the first child to complete with an error, we must save the error.
179+
// (Any subsequent errors are ignored.)
180+
if (old_state != _when_all::error) {
172181
op_state_->stop_source_.request_stop();
173182
// We won the race, free to write the error into the operation
174183
// state without worry.
@@ -187,12 +196,12 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
187196
if constexpr (sizeof...(Values)) {
188197
_when_all::copy_kernel<Values&&...><<<1, 1, 0, stream>>>(
189198
&__tup::get<Index>(*op_state_->values_), static_cast<Values&&>(vals)...);
199+
op_state_->statuses_[Index] = cudaGetLastError();
190200
}
191201

192202
if constexpr (stream_receiver<Receiver>) {
193-
if (op_state_->status_ == cudaSuccess) {
194-
op_state_->status_ =
195-
STDEXEC_DBG_ERR(cudaEventRecord(op_state_->events_[Index], stream));
203+
if (op_state_->statuses_[Index] == cudaSuccess) {
204+
op_state_->statuses_[Index] = op_state_->events_[Index].try_record(stream);
196205
}
197206
}
198207
}
@@ -203,7 +212,7 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
203212
template <class Error>
204213
requires tag_invocable<set_error_t, Receiver, Error>
205214
void set_error(Error&& err) && noexcept {
206-
_set_error_impl(static_cast<Error&&>(err), _when_all::started);
215+
_set_error_impl(static_cast<Error&&>(err));
207216
}
208217

209218
void set_stopped() && noexcept {
@@ -238,8 +247,6 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
238247
using Env = env_of_t<Receiver>;
239248
using Completions = completion_sigs<Env, CvrefReceiverId>;
240249

241-
cudaError_t status_{cudaSuccess};
242-
243250
template <class SenderId, std::size_t Index>
244251
using child_op_state_t = exit_operation_state_t<
245252
__copy_cvref_t<WhenAll, stdexec::__t<SenderId>>,
@@ -285,6 +292,14 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
285292
// Stop callback is no longer needed. Destroy it.
286293
on_stop_.reset();
287294

295+
// See if any child operations completed with an error status:
296+
for (auto status: statuses_) {
297+
if (status != cudaSuccess) {
298+
status_ = status;
299+
break;
300+
}
301+
}
302+
288303
// Synchronize streams
289304
if (status_ == cudaSuccess) {
290305
if constexpr (stream_receiver<Receiver>) {
@@ -294,7 +309,7 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
294309

295310
for (int i = 0; i < sizeof...(SenderIds); i++) {
296311
if (status_ == cudaSuccess) {
297-
status_ = STDEXEC_DBG_ERR(cudaStreamWaitEvent(stream, events_[i], 0));
312+
status_ = events_[i].try_wait(stream);
298313
}
299314
}
300315
} else {
@@ -354,19 +369,10 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
354369
, child_states_{
355370
operation_t::connect_children_(this, static_cast<WhenAll&&>(when_all), Indices{})} {
356371
status_ = STDEXEC_DBG_ERR(cudaMallocManaged(&values_, sizeof(child_values_tuple_t)));
357-
for (std::size_t i = 0; i < sizeof...(SenderIds); ++i) {
358-
if (status_ == cudaSuccess) {
359-
status_ = STDEXEC_DBG_ERR(cudaEventCreate(&events_[i], cudaEventDisableTiming));
360-
}
361-
}
362372
}
363373

364374
~operation_t() {
365375
STDEXEC_DBG_ERR(cudaFree(values_));
366-
367-
for (int i = 0; i < sizeof...(SenderIds); i++) {
368-
STDEXEC_DBG_ERR(cudaEventDestroy(events_[i]));
369-
}
370376
}
371377

372378
STDEXEC_IMMOVABLE(operation_t);
@@ -388,7 +394,7 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
388394
}
389395
}
390396

391-
// tuple<optional<tuple<Vs1...>>, optional<tuple<Vs2...>>, ...>
397+
// tuple<tuple<Vs1...>, tuple<Vs2...>, ...>
392398
using child_values_tuple_t = //
393399
__if<
394400
sends_values<Completions>,
@@ -408,9 +414,11 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
408414
__uniqued_variant_for>;
409415

410416
Receiver rcvr_;
417+
cudaError_t status_{cudaSuccess};
411418
std::atomic<std::size_t> count_{sizeof...(SenderIds)};
412419
std::array<stream_provider_t, sizeof...(SenderIds)> stream_providers_;
413-
std::array<cudaEvent_t, sizeof...(SenderIds)> events_;
420+
std::array<detail::cuda_event, sizeof...(SenderIds)> events_{};
421+
std::array<cudaError_t, sizeof...(SenderIds)> statuses_{}; // all initialized to cudaSuccess
414422
child_op_states_tuple_t child_states_;
415423
// Could be non-atomic here and atomic_ref everywhere except __completion_fn
416424
std::atomic<_when_all::state_t> state_{_when_all::started};

include/stdexec/__detail/__when_all.hpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,10 +324,15 @@ namespace stdexec {
324324
template <class _State, class _Receiver, class _Error>
325325
static void __set_error(_State& __state, _Receiver&, _Error&& __err) noexcept {
326326
// TODO: What memory orderings are actually needed here?
327-
if (__error != __state.__state_.exchange(__error)) {
327+
auto __old_state = __state.__state_.exchange(__error);
328+
// If the previous state was __error or __stopped, then we have already requested
329+
// stop on the stop source. Otherwise, request stop.
330+
if (__old_state == __started) {
328331
__state.__stop_source_.request_stop();
329-
// We won the race, free to write the error into the operation
330-
// state without worry.
332+
}
333+
// If we are the first child to complete with an error, we must save the error.
334+
// (Any subsequent errors are ignored.)
335+
if (__old_state != __error) {
331336
if constexpr (__nothrow_decay_copyable<_Error>) {
332337
__state.__errors_.template emplace<__decay_t<_Error>>(static_cast<_Error&&>(__err));
333338
} else {

0 commit comments

Comments
 (0)