Skip to content

Commit 52585fe

Browse files
committed
clean-up
1 parent cd42181 commit 52585fe

2 files changed

Lines changed: 26 additions & 20 deletions

File tree

include/nvexec/stream/when_all.cuh

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -168,19 +168,22 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
168168

169169
template <class Error>
170170
void _set_error_impl(Error&& err) noexcept {
171+
// Transition to the "error" state and switch on the prior state.
171172
// TODO: What memory orderings are actually needed here?
172-
auto old_state = op_state_->state_.exchange(_when_all::error);
173-
// If the previous state was error or stopped, then we have already requested
174-
// stop on the stop source. Otherwise, request stop.
175-
if (old_state == _when_all::started) {
173+
switch (op_state_->state_.exchange(_when_all::error)) {
174+
case _when_all::started:
175+
// We must request stop. When the previous state is "error" or "stopped", then stop
176+
// has already been requested.
176177
op_state_->stop_source_.request_stop();
177-
}
178-
// If we are the first child to complete with an error, we must save the error.
179-
// (Any subsequent errors are ignores.)
180-
if (old_state != _when_all::error) {
178+
[[fallthrough]];
179+
case _when_all::stopped:
180+
// We are the first child to complete with an error, so we must save the error.
181+
// (Any subsequent errors are ignores.)
182+
static_assert(__nothrow_constructible_from<__decay_t<Error>, Error>);
181183
op_state_->errors_.template emplace<__decay_t<Error>>(static_cast<Error&&>(err));
184+
break;
185+
case _when_all::error:; // We're already in the "error" state. Ignore the error.
182186
}
183-
op_state_->arrive();
184187
}
185188

186189
template <class... Values>
@@ -210,13 +213,13 @@ namespace nvexec::STDEXEC_STREAM_DETAIL_NS {
210213
requires tag_invocable<set_error_t, Receiver, Error>
211214
void set_error(Error&& err) && noexcept {
212215
_set_error_impl(static_cast<Error&&>(err));
216+
op_state_->arrive();
213217
}
214218

215219
void set_stopped() && noexcept {
216220
_when_all::state_t expected = _when_all::started;
217-
// Transition to the "stopped" state if and only if we're in the
218-
// "started" state. (If this fails, it's because we're in an
219-
// error state, which trumps cancellation.)
221+
// Transition to the "stopped" state if and only if we're in the "started" state. (If
222+
// this fails, it's because we're in the "error" state, which trumps cancellation.)
220223
if (op_state_->state_.compare_exchange_strong(expected, _when_all::stopped)) {
221224
op_state_->stop_source_.request_stop();
222225
}

include/stdexec/__detail/__when_all.hpp

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -323,16 +323,17 @@ namespace stdexec {
323323

324324
template <class _State, class _Receiver, class _Error>
325325
static void __set_error(_State& __state, _Receiver&, _Error&& __err) noexcept {
326+
// Transition to the "error" state and switch on the prior state.
326327
// TODO: What memory orderings are actually needed here?
327-
auto __old_state = __state.__state_.exchange(__error);
328-
// If the previous state was __error or __stopped, then we have already requested
329-
// stop on the stop source. Otherwise, request stop.
330-
if (__old_state == __started) {
328+
switch (__state.__state_.exchange(__error)) {
329+
case __started:
330+
// We must request stop. When the previous state is __error or __stopped, then stop has
331+
// already been requested.
331332
__state.__stop_source_.request_stop();
332-
}
333-
// If we are the first child to complete with an error, we must save the error.
334-
// (Any subsequent errors are ignores.)
335-
if (__old_state != __error) {
333+
[[fallthrough]];
334+
case __stopped:
335+
// We are the first child to complete with an error, so we must save the error. (Any
336+
// subsequent errors are ignored.)
336337
if constexpr (__nothrow_decay_copyable<_Error>) {
337338
__state.__errors_.template emplace<__decay_t<_Error>>(static_cast<_Error&&>(__err));
338339
} else {
@@ -342,6 +343,8 @@ namespace stdexec {
342343
__state.__errors_.template emplace<std::exception_ptr>(std::current_exception());
343344
}
344345
}
346+
break;
347+
case __error:; // We're already in the "error" state. Ignore the error.
345348
}
346349
}
347350

0 commit comments

Comments
 (0)