Skip to content

Commit 4bb5927

Browse files
committed
unroll implementation of __tuple for up to 4 elements
1 parent 5bdb0c1 commit 4bb5927

8 files changed

Lines changed: 231 additions & 151 deletions

File tree

include/exec/sequence.hpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
#include <stdexec/__detail/__tuple.hpp>
2020
#include <stdexec/__detail/__variant.hpp>
2121

22+
STDEXEC_PRAGMA_PUSH()
23+
STDEXEC_PRAGMA_IGNORE_GNU("-Wmissing-braces")
24+
2225
namespace exec {
2326
namespace _seq {
2427
template <class... Sndrs>
@@ -96,7 +99,7 @@ namespace exec {
9699
// result is that the first sender in `sndrs` is not moved from, but the rest are.
97100
_ops.template emplace_from_at<0>(
98101
stdexec::connect,
99-
stdexec::__tup::get<0>(static_cast<CvrefSndrs&&>(sndrs)),
102+
sndrs.template __get<0>(static_cast<CvrefSndrs&&>(sndrs)),
100103
_rcvr_t<0>{this});
101104
}
102105

@@ -107,7 +110,7 @@ namespace exec {
107110
if constexpr (Idx == sizeof...(Sndrs) + 1) {
108111
stdexec::set_value(static_cast<Rcvr&&>(_rcvr), static_cast<Args&&>(args)...);
109112
} else {
110-
auto& sndr = stdexec::__tup::get<Idx>(_sndrs);
113+
auto& sndr = _sndrs.template __get<Idx>(_sndrs);
111114
auto& op = _ops.template emplace_from_at<Idx>(
112115
stdexec::connect, std::move(sndr), _rcvr_t<Idx>{this});
113116
stdexec::start(op);
@@ -197,7 +200,7 @@ namespace exec {
197200
template <class... Sndrs>
198201
requires(sizeof...(Sndrs) > 1) && stdexec::__domain::__has_common_domain<Sndrs...>
199202
STDEXEC_ATTRIBUTE((host, device)) _sndr<Sndrs...> sequence_t::operator()(Sndrs... sndrs) const {
200-
return _sndr<Sndrs...>{{}, {}, {{static_cast<Sndrs&&>(sndrs)}...}};
203+
return _sndr<Sndrs...>{{}, {}, {static_cast<Sndrs&&>(sndrs)...}};
201204
}
202205
} // namespace _seq
203206

@@ -215,3 +218,5 @@ namespace std {
215218
using type = stdexec::__m_at_c<I, exec::sequence_t, stdexec::__, Sndrs...>;
216219
};
217220
} // namespace std
221+
222+
STDEXEC_PRAGMA_POP()

include/exec/start_now.hpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
#include <atomic>
2929

30+
STDEXEC_PRAGMA_PUSH()
31+
STDEXEC_PRAGMA_IGNORE_GNU("-Wmissing-braces")
32+
3033
namespace exec {
3134
/////////////////////////////////////////////////////////////////////////////
3235
// NOT TO SPEC: __start_now
@@ -174,9 +177,9 @@ namespace exec {
174177
public:
175178
__storage(_Env&& __env, _AsyncScope& __scope, stdexec::__cvref_t<_SenderIds>&&... __sndr)
176179
: __storage_base<_EnvId>(static_cast<_Env&&>(__env), sizeof...(__sndr))
177-
, __op_state_{{stdexec::connect(
180+
, __op_state_{stdexec::connect(
178181
__scope.nest(static_cast<stdexec::__cvref_t<_SenderIds>&&>(__sndr)),
179-
__receiver_t{this})}...} {
182+
__receiver_t{this})...} {
180183
// Start all of the child operations
181184
__op_state_.for_each(stdexec::start, __op_state_);
182185
}
@@ -250,3 +253,5 @@ namespace exec {
250253
using __start_now_::start_now_t;
251254
inline constexpr start_now_t start_now{};
252255
} // namespace exec
256+
257+
STDEXEC_PRAGMA_POP()

include/exec/when_any.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ namespace exec {
206206
[this]<class... _Senders>(_Senders&&... __sndrs) noexcept(__nothrow_construct)
207207
-> __opstate_tuple {
208208
return __opstate_tuple{
209-
{stdexec::connect(static_cast<_Senders&&>(__sndrs), __receiver_t{this})}...};
209+
stdexec::connect(static_cast<_Senders&&>(__sndrs), __receiver_t{this})...};
210210
},
211211
static_cast<_SenderTuple&&>(__senders))} {
212212
}
@@ -253,7 +253,7 @@ namespace exec {
253253
template <__not_decays_to<__t>... _Senders>
254254
explicit(sizeof...(_Senders) == 1)
255255
__t(_Senders&&... __senders) noexcept((__nothrow_decay_copyable<_Senders> && ...))
256-
: __senders_{{static_cast<_Senders&&>(__senders)}...} {
256+
: __senders_{static_cast<_Senders&&>(__senders)...} {
257257
}
258258

259259
template <__decays_to<__t> _Self, receiver _Receiver>

include/nvexec/stream/when_all.cuh

Lines changed: 95 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
STDEXEC_PRAGMA_PUSH()
3535
STDEXEC_PRAGMA_IGNORE_EDG(cuda_compile)
36+
STDEXEC_PRAGMA_IGNORE_GNU("-Wmissing-braces")
3637

3738
namespace nvexec::_strm {
3839

@@ -83,7 +84,7 @@ namespace nvexec::_strm {
8384
template <class... As, class TupleT>
8485
__launch_bounds__(1) __global__ void copy_kernel(TupleT* tpl, As... as) {
8586
static_assert(trivially_copyable<As...>);
86-
*tpl = __decayed_tuple<As...>{{static_cast<As&&>(as)}...};
87+
*tpl = __decayed_tuple<As...>{static_cast<As&&>(as)...};
8788
}
8889

8990
template <class... Env, class... Senders>
@@ -110,6 +111,14 @@ namespace nvexec::_strm {
110111
__minvoke<__mpush_back<__q<completion_signatures>>, non_values, values>,
111112
non_values>;
112113
};
114+
// Callable applied to a child operation state `op`.
// When `OpT` satisfies STDEXEC_IS_BASE_OF(stream_op_state_base, OpT) (presumably: OpT derives
// from stream_op_state_base -- confirm against the macro's definition) and the status stored in
// `op.stream_provider_` is still cudaSuccess, it calls cudaStreamSynchronize on `op.get_stream()`
// and stores the result (passed through STDEXEC_DBG_ERR) back into `op.stream_provider_.status_`.
// For any other `OpT` the call does nothing.
115+
inline constexpr auto _sync_op = []<class OpT>(OpT& op) noexcept {
116+
if constexpr (STDEXEC_IS_BASE_OF(stream_op_state_base, OpT)) {
// Only synchronize while the recorded status is cudaSuccess, so that an
// earlier non-success status is never overwritten.
117+
if (op.stream_provider_.status_ == cudaSuccess) {
118+
op.stream_provider_.status_ = STDEXEC_DBG_ERR(cudaStreamSynchronize(op.get_stream()));
119+
}
120+
}
121+
};
113122
} // namespace _when_all
114123

115124
template <bool WithCompletionScheduler, class Scheduler, class... SenderIds>
@@ -131,7 +140,7 @@ namespace nvexec::_strm {
131140
template <class... Sndrs>
132141
explicit __t(context_state_t context_state, Sndrs&&... __sndrs)
133142
: env_{context_state}
134-
, sndrs_{{static_cast<Sndrs&&>(__sndrs)}...} {
143+
, sndrs_{static_cast<Sndrs&&>(__sndrs)...} {
135144
}
136145

137146
private:
@@ -159,76 +168,31 @@ namespace nvexec::_strm {
159168
struct receiver_t {
160169
using WhenAll = __copy_cvref_t<CvrefReceiverId, stdexec::__t<when_all_sender_t>>;
161170
using Receiver = stdexec::__t<__decay_t<CvrefReceiverId>>;
171+
using SenderId = __m_at_c<Index, SenderIds...>;
172+
using Completions = completion_sigs<env_of_t<Receiver>, CvrefReceiverId>;
162173
using Env = //
163174
make_terminal_stream_env_t<
164175
exec::make_env_t<env_of_t<Receiver>, stdexec::prop<get_stop_token_t, inplace_stop_token>>>;
165176

166177
struct __t : stream_receiver_base {
167178
using receiver_concept = stdexec::receiver_t;
168179
using __id = receiver_t;
169-
using SenderId = nvexec::detail::nth_type<Index, SenderIds...>;
170-
using Completions = completion_sigs<env_of_t<Receiver>, CvrefReceiverId>;
171-
172-
template <class Error>
173-
void _set_error_impl(Error&& err) noexcept {
174-
// Transition to the "error" state and switch on the prior state.
175-
// TODO: What memory orderings are actually needed here?
176-
switch (op_state_->state_.exchange(_when_all::error)) {
177-
case _when_all::started:
178-
// We must request stop. When the previous state is "error" or "stopped", then stop
179-
// has already been requested.
180-
op_state_->stop_source_.request_stop();
181-
[[fallthrough]];
182-
case _when_all::stopped:
183-
// We are the first child to complete with an error, so we must save the error.
184-
// (Any subsequent errors are ignored.)
185-
static_assert(__nothrow_constructible_from<__decay_t<Error>, Error>);
186-
op_state_->errors_.template emplace<__decay_t<Error>>(static_cast<Error&&>(err));
187-
break;
188-
case _when_all::error:; // We're already in the "error" state. Ignore the error.
189-
}
190-
}
191180

192181
template <class... Values>
193-
void set_value(Values&&... vals) && noexcept {
194-
if constexpr (__v<sends_values<Completions>>) {
195-
// We only need to bother recording the completion values
196-
// if we're not already in the "error" or "stopped" state.
197-
if (op_state_->state_ == _when_all::started) {
198-
cudaStream_t stream = __tup::get<Index>(op_state_->child_states_).get_stream();
199-
if constexpr (sizeof...(Values)) {
200-
_when_all::copy_kernel<Values&&...><<<1, 1, 0, stream>>>(
201-
&__tup::get<Index>(*op_state_->values_), static_cast<Values&&>(vals)...);
202-
op_state_->statuses_[Index] = cudaGetLastError();
203-
}
204-
205-
if constexpr (stream_receiver<Receiver>) {
206-
if (op_state_->statuses_[Index] == cudaSuccess) {
207-
op_state_->statuses_[Index] = op_state_->events_[Index].try_record(stream);
208-
}
209-
}
210-
}
211-
}
212-
op_state_->arrive();
182+
STDEXEC_ATTRIBUTE((always_inline)) void set_value(Values&&... vals) && noexcept {
183+
op_state_->template _set_value<Index>(static_cast<Values&&>(vals)...);
213184
}
214185

215186
template <class Error>
216-
requires tag_invocable<set_error_t, Receiver, Error>
217-
void set_error(Error&& err) && noexcept {
218-
_set_error_impl(static_cast<Error&&>(err));
219-
op_state_->arrive();
187+
STDEXEC_ATTRIBUTE((always_inline)) void set_error(Error&& err) && noexcept {
188+
op_state_->_set_error(static_cast<Error&&>(err));
220189
}
221190

222-
void set_stopped() && noexcept {
223-
_when_all::state_t expected = _when_all::started;
224-
// Transition to the "stopped" state if and only if we're in the "started" state. (If
225-
// this fails, it's because we're in the "error" state, which trumps cancellation.)
226-
if (op_state_->state_.compare_exchange_strong(expected, _when_all::stopped)) {
227-
op_state_->stop_source_.request_stop();
228-
}
229-
op_state_->arrive();
191+
STDEXEC_ATTRIBUTE((always_inline)) void set_stopped() && noexcept {
192+
op_state_->_set_stopped();
230193
}
231194

195+
[[nodiscard]]
232196
auto get_env() const noexcept -> Env {
233197
auto env = make_terminal_stream_env(
234198
exec::make_env(
@@ -264,11 +228,11 @@ namespace nvexec::_strm {
264228
using __child_ops_t = __tuple_for<child_op_state_t<SenderIds, Is>...>;
265229
return when_all.sndrs_.apply(
266230
[parent_op]<class... Children>(Children&&... children) -> __child_ops_t {
267-
return __child_ops_t{{_strm::exit_op_state(
231+
return __child_ops_t{_strm::exit_op_state(
268232
static_cast<Children&&>(children),
269233
stdexec::__t<receiver_t<CvrefReceiverId, Is>>{{}, parent_op},
270234
stdexec::get_completion_scheduler<set_value_t>(stdexec::get_env(children))
271-
.context_state_)}...};
235+
.context_state_)...};
272236
},
273237
static_cast<WhenAll&&>(when_all).sndrs_);
274238
}
@@ -277,20 +241,11 @@ namespace nvexec::_strm {
277241
decltype(operation_t::connect_children_({}, __declval<WhenAll>(), Indices{}));
278242

279243
void arrive() noexcept {
280-
if (0 == --count_) {
244+
if (1 == count_.fetch_sub(1)) {
281245
complete();
282246
}
283247
}
284248

285-
template <class OpT>
286-
static void sync(OpT& op) noexcept {
287-
if constexpr (STDEXEC_IS_BASE_OF(stream_op_state_base, OpT)) {
288-
if (op.stream_provider_.status_ == cudaSuccess) {
289-
op.stream_provider_.status_ = STDEXEC_DBG_ERR(cudaStreamSynchronize(op.get_stream()));
290-
}
291-
}
292-
}
293-
294249
void complete() noexcept {
295250
// Stop callback is no longer needed. Destroy it.
296251
on_stop_.reset();
@@ -316,7 +271,8 @@ namespace nvexec::_strm {
316271
}
317272
}
318273
} else {
319-
child_states_.apply([](auto&... ops) { (sync(ops), ...); }, child_states_);
274+
// Synchronize the streams of all the child operations
275+
child_states_.for_each(_when_all::_sync_op, child_states_);
320276
}
321277
}
322278

@@ -327,22 +283,18 @@ namespace nvexec::_strm {
327283
if constexpr (__v<sends_values<Completions>>) {
328284
// All child operations completed successfully:
329285
values_->apply(
330-
[this](auto&... opt_vals) -> void {
286+
[this]<class... Tuples>(Tuples&&... value_tupls) -> void {
331287
__tup::__cat_apply(
332-
[this](auto&... all_vals) -> void {
333-
stdexec::set_value(static_cast<Receiver&&>(rcvr_), std::move(all_vals)...);
288+
__mk_completion_fn(stdexec::set_value, rcvr_),
289+
static_cast<Tuples&&>(value_tupls)...);
334290
},
335-
opt_vals...);
336-
},
337-
*values_);
291+
static_cast<child_values_tuple_t&&>(*values_));
338292
}
339293
break;
340294
case _when_all::error:
341295
errors_.visit(
342-
[this](auto& err) noexcept {
343-
stdexec::set_error(static_cast<Receiver&&>(rcvr_), std::move(err));
344-
},
345-
errors_);
296+
__mk_completion_fn(stdexec::set_error, rcvr_),
297+
static_cast<errors_variant_t&&>(errors_));
346298
break;
347299
case _when_all::stopped:
348300
stdexec::set_stopped(static_cast<Receiver&&>(rcvr_));
@@ -389,24 +341,83 @@ namespace nvexec::_strm {
389341
// the child operations.
390342
stdexec::set_stopped(static_cast<Receiver&&>(rcvr_));
391343
} else {
344+
child_states_.for_each(stdexec::start, child_states_);
392345
if constexpr (sizeof...(SenderIds) == 0) {
393346
complete();
394-
} else {
395-
child_states_.for_each(stdexec::start, child_states_);
396347
}
397348
}
398349
}
399350

// Moves the operation into the "error" state and, if this is the first error, stores `err`.
//
// `state_` is exchanged with `_when_all::error` and the previous value decides what happens:
//   - started: request stop on `stop_source_`, then store the error (falls through);
//   - stopped: store the error in `errors_` as a `__decay_t<Error>`;
//   - error:   nothing is stored; `err` is dropped.
// This function does not call `arrive()`; callers are expected to do that themselves.
351+
template <class Error>
352+
void _set_error_impl(Error&& err) noexcept {
353+
// Transition to the "error" state and switch on the prior state.
354+
// TODO: What memory orderings are actually needed here?
355+
switch (state_.exchange(_when_all::error)) {
356+
case _when_all::started:
357+
// We must request stop. When the previous state is "error" or "stopped", then stop
358+
// has already been requested.
359+
stop_source_.request_stop();
360+
[[fallthrough]];
361+
case _when_all::stopped:
362+
// We are the first child to complete with an error, so we must save the error.
363+
// (Any subsequent errors are ignored.)
// The static_assert below backs the `noexcept` on this function: constructing the
// decayed error from `Error` must not throw.
364+
static_assert(__nothrow_constructible_from<__decay_t<Error>, Error>);
365+
errors_.template emplace<__decay_t<Error>>(static_cast<Error&&>(err));
366+
break;
367+
case _when_all::error:; // We're already in the "error" state. Ignore the error.
368+
}
369+
}
370+
// Handles a value completion from the child at position `Index`.
//
// If `sends_values<Completions>` holds and `state_` is still `_when_all::started`:
//   1. when `args` is non-empty, launches `_when_all::copy_kernel` (1 block, 1 thread) on the
//      child's stream to write `args...` into element `Index` of `*values_`, and stores
//      `cudaGetLastError()` in `statuses_[Index]`;
//   2. when `Receiver` models `stream_receiver` and `statuses_[Index]` is cudaSuccess, stores
//      the result of `events_[Index].try_record(stream)` in `statuses_[Index]`.
// In every case `arrive()` is called exactly once at the end.
371+
template <std::size_t Index, class... Args>
372+
void _set_value(Args&&... args) noexcept {
373+
if constexpr (__v<sends_values<Completions>>) {
374+
// We only need to bother recording the completion values
375+
// if we're not already in the "error" or "stopped" state.
376+
if (state_.load() == _when_all::started) {
// Stream of the child operation state stored at position `Index`.
377+
cudaStream_t stream = child_states_.template __get<Index>(child_states_).get_stream();
// With no values there is nothing to copy, so no kernel is launched.
378+
if constexpr (sizeof...(Args)) {
379+
_when_all::copy_kernel<Args&&...><<<1, 1, 0, stream>>>(
380+
&(values_->template __get<Index>(*values_)), static_cast<Args&&>(args)...);
// Capture the status reported right after the launch for this child.
381+
statuses_[Index] = cudaGetLastError();
382+
}
383+
384+
if constexpr (stream_receiver<Receiver>) {
// Only attempt to record the event while this child's status is cudaSuccess.
385+
if (statuses_[Index] == cudaSuccess) {
386+
statuses_[Index] = events_[Index].try_record(stream);
387+
}
388+
}
389+
}
390+
}
391+
arrive();
392+
}
393+
// Handles an error completion from a child: forwards `err` to `_set_error_impl`
// and then calls `arrive()`.
394+
template <class Error>
395+
void _set_error(Error&& err) noexcept {
396+
_set_error_impl(static_cast<Error&&>(err));
397+
arrive();
398+
}
399+
// Handles a stopped completion from a child.
// Attempts to change `state_` from `_when_all::started` to `_when_all::stopped`; only when that
// exchange succeeds is stop requested on `stop_source_`. `arrive()` is called in either case.
400+
void _set_stopped() noexcept {
// `expected` is the value compare_exchange_strong compares `state_` against.
401+
auto expected = _when_all::started;
402+
// Transition to the "stopped" state if and only if we're in the
403+
// "started" state. (If this fails, it's because we're in an
404+
// error state, which trumps cancellation.)
405+
if (state_.compare_exchange_strong(expected, _when_all::stopped)) {
406+
stop_source_.request_stop();
407+
}
408+
arrive();
409+
}
410+
400411
// tuple<tuple<Vs1...>, tuple<Vs2...>, ...>
401412
using child_values_tuple_t = //
402413
__if<
403414
sends_values<Completions>,
404415
__minvoke<
405-
__q<__tuple_for>,
416+
__qq<__tuple_for>,
406417
__value_types_of_t<
407418
stdexec::__t<SenderIds>,
408419
_when_all::env_t<Env>,
409-
__q<__decayed_tuple>,
420+
__qq<__decayed_tuple>,
410421
__msingle_or<void>>...>,
411422
__>;
412423

@@ -447,6 +458,7 @@ namespace nvexec::_strm {
447458
return {};
448459
}
449460

461+
[[nodiscard]]
450462
auto get_env() const noexcept -> const env& {
451463
return env_;
452464
}

include/stdexec/__detail/__env.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ namespace stdexec {
401401
STDEXEC_ATTRIBUTE((always_inline)) constexpr decltype(auto) __get_1st() const noexcept {
402402
constexpr bool __flags[] = {__queryable<_Envs, _Query, _Args...>...};
403403
constexpr std::size_t __idx = __pos_of(__flags, __flags + sizeof...(_Envs));
404-
return __tup::get<__idx>(__tup_);
404+
return __tup_.template __get<__idx>(__tup_);
405405
}
406406

407407
template <class _Query, class... _Args>

0 commit comments

Comments
 (0)