Skip to content

Commit c7ae38c

Browse files
authored
fix race condition in nvexec's when_all implementation (#1489)
in nvexec's `when_all` * fix race condition in nvexec's when_all implementation * test for error after cuda kernel launch in when_all * don't ignore a child's error if another child has already stopped early in stdexec's `when_all` * issue only one stop request if one child stops early and then another child errors.
1 parent 3390523 commit c7ae38c

4 files changed

Lines changed: 138 additions & 41 deletions

File tree

include/nvexec/detail/event.cuh

Lines changed: 63 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,63 @@
1+
/*
2+
* Copyright (c) 2025 NVIDIA Corporation
3+
*
4+
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
5+
* (the "License"); you may not use this file except in compliance with
6+
* the License. You may obtain a copy of the License at
7+
*
8+
* https://llvm.org/LICENSE.txt
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
// clang-format Language: Cpp
18+
19+
#pragma once
20+
21+
#include "config.cuh"
22+
#include "cuda_fwd.cuh"
23+
#include "throw_on_cuda_error.cuh"
24+
25+
#include <utility>
26+
27+
namespace nvexec::detail {
28+
struct cuda_event {
29+
cuda_event() {
30+
if (auto status =
31+
STDEXEC_DBG_ERR(::cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
32+
status != cudaSuccess) {
33+
throw cuda_error(status, "cudaEventCreate");
34+
}
35+
}
36+
37+
cuda_event(cuda_event&& other) noexcept
38+
: event_(std::exchange(other.event_, nullptr)) {
39+
}
40+
41+
~cuda_event() {
42+
if (event_ != nullptr) {
43+
STDEXEC_DBG_ERR(::cudaEventDestroy(event_));
44+
}
45+
}
46+
47+
auto operator=(cuda_event&& other) noexcept -> cuda_event& {
48+
event_ = std::exchange(other.event_, nullptr);
49+
return *this;
50+
}
51+
52+
auto try_record(cudaStream_t stream) noexcept -> cudaError_t {
53+
return STDEXEC_DBG_ERR(::cudaEventRecord(event_, stream));
54+
}
55+
56+
auto try_wait(cudaStream_t stream) noexcept -> cudaError_t {
57+
return STDEXEC_DBG_ERR(::cudaStreamWaitEvent(stream, event_, 0));
58+
}
59+
60+
private:
61+
cudaEvent_t event_{};
62+
};
63+
} // namespace nvexec::detail

include/nvexec/detail/throw_on_cuda_error.cuh

Lines changed: 30 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -18,18 +18,37 @@
1818
#include "config.cuh"
1919

2020
#include <cstdio>
21+
#include <stdexcept>
2122

2223
#include <cuda_runtime_api.h>
2324

24-
namespace nvexec {
25-
namespace detail {
26-
inline cudaError_t debug_cuda_error(
27-
cudaError_t error,
28-
[[maybe_unused]] char const * file_name,
29-
[[maybe_unused]] int line) {
30-
// Clear the global CUDA error state which may have been set by the last
31-
// call. Otherwise, errors may "leak" to unrelated calls.
32-
cudaGetLastError();
25+
namespace nvexec::detail {
26+
class cuda_error : public ::std::runtime_error {
27+
private:
28+
struct __msg_storage {
29+
char __buffer[256]; // NOLINT
30+
};
31+
32+
static auto
33+
__format_cuda_error(const int __status, const char* __msg, char* __msg_buffer) noexcept
34+
-> char* {
35+
::snprintf(__msg_buffer, 256, "cudaError %d: %s", __status, __msg);
36+
return __msg_buffer;
37+
}
38+
39+
public:
40+
cuda_error(const int __status, const char* __msg, __msg_storage __msg_buffer = {0}) noexcept
41+
: ::std::runtime_error(__format_cuda_error(__status, __msg, __msg_buffer.__buffer)) {
42+
}
43+
};
44+
45+
inline auto debug_cuda_error(
46+
cudaError_t error,
47+
[[maybe_unused]] char const * file_name,
48+
[[maybe_unused]] int line) -> cudaError_t {
49+
// Clear the global CUDA error state which may have been set by the last
50+
// call. Otherwise, errors may "leak" to unrelated calls.
51+
cudaGetLastError();
3352

3453
#if defined(STDEXEC_STDERR)
3554
if (error != cudaSuccess) {
@@ -43,8 +62,7 @@ namespace nvexec {
4362
#endif
4463

4564
return error;
46-
}
47-
} // namespace detail
65+
}
66+
} // namespace nvexec::detail
4867

4968
#define STDEXEC_DBG_ERR(E) ::nvexec::detail::debug_cuda_error(E, __FILE__, __LINE__) /**/
50-
} // namespace nvexec

include/nvexec/stream/when_all.cuh

Lines changed: 34 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
#include <utility>
2626

2727
#include "common.cuh"
28+
#include "../detail/event.cuh"
2829
#include "../detail/throw_on_cuda_error.cuh"
2930

3031
STDEXEC_PRAGMA_PUSH()
@@ -166,15 +167,23 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
166167
using Completions = completion_sigs<env_of_t<Receiver>, CvrefReceiverId>;
167168

168169
template <class Error>
169-
void _set_error_impl(Error&& err, _when_all::state_t expected) noexcept {
170+
void _set_error_impl(Error&& err) noexcept {
171+
// Transition to the "error" state and switch on the prior state.
170172
// TODO: What memory orderings are actually needed here?
171-
if (op_state_->state_.compare_exchange_strong(expected, _when_all::error)) {
173+
switch (op_state_->state_.exchange(_when_all::error)) {
174+
case _when_all::started:
175+
// We must request stop. When the previous state is "error" or "stopped", then stop
176+
// has already been requested.
172177
op_state_->stop_source_.request_stop();
173-
// We won the race, free to write the error into the operation
174-
// state without worry.
178+
[[fallthrough]];
179+
case _when_all::stopped:
180+
// We are the first child to complete with an error, so we must save the error.
181+
// (Any subsequent errors are ignored.)
182+
static_assert(__nothrow_constructible_from<__decay_t<Error>, Error>);
175183
op_state_->errors_.template emplace<__decay_t<Error>>(static_cast<Error&&>(err));
184+
break;
185+
case _when_all::error:; // We're already in the "error" state. Ignore the error.
176186
}
177-
op_state_->arrive();
178187
}
179188

180189
template <class... Values>
@@ -187,12 +196,12 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
187196
if constexpr (sizeof...(Values)) {
188197
_when_all::copy_kernel<Values&&...><<<1, 1, 0, stream>>>(
189198
&__tup::get<Index>(*op_state_->values_), static_cast<Values&&>(vals)...);
199+
op_state_->statuses_[Index] = cudaGetLastError();
190200
}
191201

192202
if constexpr (stream_receiver<Receiver>) {
193-
if (op_state_->status_ == cudaSuccess) {
194-
op_state_->status_ =
195-
STDEXEC_DBG_ERR(cudaEventRecord(op_state_->events_[Index], stream));
203+
if (op_state_->statuses_[Index] == cudaSuccess) {
204+
op_state_->statuses_[Index] = op_state_->events_[Index].try_record(stream);
196205
}
197206
}
198207
}
@@ -203,14 +212,14 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
203212
template <class Error>
204213
requires tag_invocable<set_error_t, Receiver, Error>
205214
void set_error(Error&& err) && noexcept {
206-
_set_error_impl(static_cast<Error&&>(err), _when_all::started);
215+
_set_error_impl(static_cast<Error&&>(err));
216+
op_state_->arrive();
207217
}
208218

209219
void set_stopped() && noexcept {
210220
_when_all::state_t expected = _when_all::started;
211-
// Transition to the "stopped" state if and only if we're in the
212-
// "started" state. (If this fails, it's because we're in an
213-
// error state, which trumps cancellation.)
221+
// Transition to the "stopped" state if and only if we're in the "started" state. (If
222+
// this fails, it's because we're in the "error" state, which trumps cancellation.)
214223
if (op_state_->state_.compare_exchange_strong(expected, _when_all::stopped)) {
215224
op_state_->stop_source_.request_stop();
216225
}
@@ -238,8 +247,6 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
238247
using Env = env_of_t<Receiver>;
239248
using Completions = completion_sigs<Env, CvrefReceiverId>;
240249

241-
cudaError_t status_{cudaSuccess};
242-
243250
template <class SenderId, std::size_t Index>
244251
using child_op_state_t = exit_operation_state_t<
245252
__copy_cvref_t<WhenAll, stdexec::__t<SenderId>>,
@@ -285,6 +292,14 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
285292
// Stop callback is no longer needed. Destroy it.
286293
on_stop_.reset();
287294

295+
// See if any child operations completed with an error status:
296+
for (auto status: statuses_) {
297+
if (status != cudaSuccess) {
298+
status_ = status;
299+
break;
300+
}
301+
}
302+
288303
// Synchronize streams
289304
if (status_ == cudaSuccess) {
290305
if constexpr (stream_receiver<Receiver>) {
@@ -294,7 +309,7 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
294309

295310
for (int i = 0; i < sizeof...(SenderIds); i++) {
296311
if (status_ == cudaSuccess) {
297-
status_ = STDEXEC_DBG_ERR(cudaStreamWaitEvent(stream, events_[i], 0));
312+
status_ = events_[i].try_wait(stream);
298313
}
299314
}
300315
} else {
@@ -354,19 +369,10 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
354369
, child_states_{
355370
operation_t::connect_children_(this, static_cast<WhenAll&&>(when_all), Indices{})} {
356371
status_ = STDEXEC_DBG_ERR(cudaMallocManaged(&values_, sizeof(child_values_tuple_t)));
357-
for (std::size_t i = 0; i < sizeof...(SenderIds); ++i) {
358-
if (status_ == cudaSuccess) {
359-
status_ = STDEXEC_DBG_ERR(cudaEventCreate(&events_[i], cudaEventDisableTiming));
360-
}
361-
}
362372
}
363373

364374
~operation_t() {
365375
STDEXEC_DBG_ERR(cudaFree(values_));
366-
367-
for (int i = 0; i < sizeof...(SenderIds); i++) {
368-
STDEXEC_DBG_ERR(cudaEventDestroy(events_[i]));
369-
}
370376
}
371377

372378
STDEXEC_IMMOVABLE(operation_t);
@@ -388,7 +394,7 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
388394
}
389395
}
390396

391-
// tuple<optional<tuple<Vs1...>>, optional<tuple<Vs2...>>, ...>
397+
// tuple<tuple<Vs1...>, tuple<Vs2...>, ...>
392398
using child_values_tuple_t = //
393399
__if<
394400
sends_values<Completions>,
@@ -408,9 +414,11 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
408414
__uniqued_variant_for>;
409415

410416
Receiver rcvr_;
417+
cudaError_t status_{cudaSuccess};
411418
std::atomic<std::size_t> count_{sizeof...(SenderIds)};
412419
std::array<stream_provider_t, sizeof...(SenderIds)> stream_providers_;
413-
std::array<cudaEvent_t, sizeof...(SenderIds)> events_;
420+
std::array<detail::cuda_event, sizeof...(SenderIds)> events_{};
421+
std::array<cudaError_t, sizeof...(SenderIds)> statuses_{}; // all initialized to cudaSuccess
414422
child_op_states_tuple_t child_states_;
415423
// Could be non-atomic here and atomic_ref everywhere except __completion_fn
416424
std::atomic<_when_all::state_t> state_{_when_all::started};

include/stdexec/__detail/__when_all.hpp

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -323,11 +323,17 @@ namespace stdexec {
323323

324324
template <class _State, class _Receiver, class _Error>
325325
static void __set_error(_State& __state, _Receiver&, _Error&& __err) noexcept {
326+
// Transition to the "error" state and switch on the prior state.
326327
// TODO: What memory orderings are actually needed here?
327-
if (__error != __state.__state_.exchange(__error)) {
328+
switch (__state.__state_.exchange(__error)) {
329+
case __started:
330+
// We must request stop. When the previous state is __error or __stopped, then stop has
331+
// already been requested.
328332
__state.__stop_source_.request_stop();
329-
// We won the race, free to write the error into the operation
330-
// state without worry.
333+
[[fallthrough]];
334+
case __stopped:
335+
// We are the first child to complete with an error, so we must save the error. (Any
336+
// subsequent errors are ignored.)
331337
if constexpr (__nothrow_decay_copyable<_Error>) {
332338
__state.__errors_.template emplace<__decay_t<_Error>>(static_cast<_Error&&>(__err));
333339
} else {
@@ -337,6 +343,8 @@ namespace stdexec {
337343
__state.__errors_.template emplace<std::exception_ptr>(std::current_exception());
338344
}
339345
}
346+
break;
347+
case __error:; // We're already in the "error" state. Ignore the error.
340348
}
341349
}
342350

0 commit comments

Comments
 (0)