Skip to content

Commit 49d6ec2

Browse files
authored
add exec::fork for broadcasting a sender's results to multiple continuations (#1541)
1 parent b929497 commit 49d6ec2

21 files changed

Lines changed: 622 additions & 146 deletions

include/exec/any_sender_of.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -824,7 +824,7 @@ namespace exec {
824824
};
825825

826826
template <class _Env>
827-
using __env_t = __env::__join_t<prop<get_stop_token_t, inplace_stop_token>, _Env>;
827+
using __env_t = __join_env_t<prop<get_stop_token_t, inplace_stop_token>, _Env>;
828828

829829
template <class _ReceiverId>
830830
struct __stoppable_receiver {

include/exec/async_scope.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ namespace exec {
693693
};
694694

695695
template <class _Env>
696-
using __spawn_env_t = __env::__join_t<_Env, __spawn_env_>;
696+
using __spawn_env_t = __join_env_t<_Env, __spawn_env_>;
697697

698698
template <class _EnvId>
699699
struct __spawn_op_base {

include/exec/env.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ namespace exec {
4646
stdexec::__nothrow_move_constructible _Base,
4747
stdexec::__nothrow_move_constructible _Env>
4848
auto operator()(_Base&& __base, _Env&& __env) const noexcept
49-
-> stdexec::__env::__join_t<_Env, _Base> {
49+
-> stdexec::__join_env_t<_Env, _Base> {
5050
return stdexec::__env::__join(static_cast<_Env&&>(__env), static_cast<_Base&&>(__base));
5151
}
5252

@@ -158,7 +158,7 @@ namespace exec {
158158
_Sender __sndr_;
159159
_Attrs __attrs_;
160160

161-
auto get_env() const noexcept -> __env::__join_t<const _Attrs&, env_of_t<_Sender>> {
161+
auto get_env() const noexcept -> __join_env_t<const _Attrs&, env_of_t<_Sender>> {
162162
return stdexec::__env::__join(__attrs_, stdexec::get_env(__sndr_));
163163
}
164164

include/exec/finally.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,8 @@ namespace exec {
314314
};
315315
} // namespace __final
316316

317-
using __final ::finally_t;
318-
inline constexpr __final ::finally_t finally{};
317+
using __final::finally_t;
318+
inline constexpr __final::finally_t finally{};
319319
} // namespace exec
320320

321321
namespace stdexec {

include/exec/fork.hpp

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
/*
2+
* Copyright (c) 2025 NVIDIA Corporation
3+
*
4+
* Licensed under the Apache License Version 2.0 with LLVM Exceptions
5+
* (the "License"); you may not use this file except in compliance with
6+
* the License. You may obtain a copy of the License at
7+
*
8+
* https://llvm.org/LICENSE.txt
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#pragma once
17+
18+
#include "../stdexec/execution.hpp"
19+
#include "../stdexec/__detail/__receiver_ref.hpp"
20+
21+
#include <exception>
22+
23+
namespace exec {
24+
struct PREDECESSOR_RESULTS_ARE_NOT_DECAY_COPYABLE { };
25+
26+
struct fork_t {
27+
template <class Sndr, class... Closures>
28+
struct _sndr_t;
29+
30+
struct _dematerialize_fn {
31+
struct _impl_fn {
32+
template <class Rcvr, class Tag, class... Args>
33+
STDEXEC_ATTRIBUTE((always_inline, host, device)) void operator()(Rcvr& rcvr, Tag, const Args&... args) const noexcept {
34+
Tag{}(static_cast<Rcvr&&>(rcvr), args...);
35+
}
36+
};
37+
38+
template <class Rcvr, class Tuple>
39+
STDEXEC_ATTRIBUTE((always_inline, host, device)) void operator()(Rcvr& rcvr, const Tuple& tupl) const noexcept {
40+
tupl.apply(_impl_fn{}, tupl, rcvr);
41+
}
42+
};
43+
44+
struct _mk_when_all_fn {
45+
template <class CacheSndr, class... Closures>
46+
STDEXEC_ATTRIBUTE((always_inline, host, device)) auto operator()(CacheSndr sndr, Closures&&... closures) const {
47+
return stdexec::when_all(static_cast<Closures&&>(closures)(sndr)...);
48+
}
49+
};
50+
51+
template <class Completions>
52+
using _maybe_eptr_completion_t = stdexec::__if_c<
53+
stdexec::__nothrow_decay_copyable_results_t<Completions>::value,
54+
stdexec::__mset_nil,
55+
stdexec::__tuple_for<stdexec::set_error_t, ::std::exception_ptr>>;
56+
57+
template <class Completions>
58+
using _variant_t = typename stdexec::__mset_insert<
59+
stdexec::__for_each_completion_signature<Completions, stdexec::__decayed_tuple, stdexec::__mset>,
60+
_maybe_eptr_completion_t<Completions>>::template rebind<stdexec::__variant_for>;
61+
62+
template <class Domain>
63+
struct _env_t {
64+
STDEXEC_ATTRIBUTE((always_inline, host, device)) static constexpr auto query(stdexec::get_domain_t) noexcept -> Domain {
65+
return {};
66+
}
67+
};
68+
69+
template <class Tag, class... Args>
70+
using _cref_sig_t = Tag(const Args&...);
71+
72+
// Given a set of async results, each of the form `tuple<Tag, Args...>`, compute
73+
// the corresponding completion signatures, where each signature is of the form
74+
// `Tag(const Args&...)`.
75+
template <class... AsyncResults>
76+
using _cache_sndr_completions_t =
77+
stdexec::completion_signatures<stdexec::__mapply<stdexec::__q<_cref_sig_t>, AsyncResults>...>;
78+
79+
template <class Variant, class Domain>
80+
struct _cache_sndr_t {
81+
using sender_concept = stdexec::sender_t;
82+
83+
template <class Rcvr>
84+
struct _opstate_t {
85+
using operation_state_concept = stdexec::operation_state_t;
86+
87+
STDEXEC_ATTRIBUTE((host, device)) void start() noexcept {
88+
Variant::visit(_dematerialize_fn{}, *_results_, _rcvr_);
89+
}
90+
91+
Rcvr _rcvr_;
92+
const Variant* _results_;
93+
};
94+
95+
template <class _Self, class... _Env>
96+
STDEXEC_ATTRIBUTE((host, device)) static auto get_completion_signatures(_Self&&, _Env&&...) noexcept {
97+
return stdexec::__mapply<stdexec::__qq<_cache_sndr_completions_t>, Variant>{};
98+
}
99+
100+
template <class Rcvr>
101+
STDEXEC_ATTRIBUTE((host, device)) auto connect(Rcvr rcvr) const -> _opstate_t<Rcvr> {
102+
return _opstate_t<Rcvr>{static_cast<Rcvr&&>(rcvr), _results_};
103+
}
104+
105+
STDEXEC_ATTRIBUTE((host, device)) static auto get_env() noexcept -> _env_t<Domain> {
106+
return {};
107+
}
108+
109+
const Variant* _results_;
110+
};
111+
112+
template <class Completions, class Closures, class Domain>
113+
using _when_all_sndr_t = stdexec::__tup::
114+
__apply_result_t<_mk_when_all_fn, Closures, _cache_sndr_t<_variant_t<Completions>, Domain>>;
115+
116+
template <class Sndr, class Closures, class Rcvr>
117+
struct _opstate_t {
118+
using operation_state_concept = stdexec::operation_state_t;
119+
using _env_t = stdexec::__call_result_t<stdexec::__env::__fwd_fn, stdexec::env_of_t<Rcvr>>;
120+
using _child_completions_t = stdexec::completion_signatures_of_t<Sndr, _env_t>;
121+
using _domain_t = stdexec::__early_domain_of_t<Sndr, stdexec::__none_such>;
122+
using _when_all_sndr_t = fork_t::_when_all_sndr_t<_child_completions_t, Closures, _domain_t>;
123+
using _child_opstate_t =
124+
stdexec::connect_result_t<Sndr, stdexec::__rcvr_ref_t<_opstate_t, _env_t>>;
125+
using _fork_opstate_t =
126+
stdexec::connect_result_t<_when_all_sndr_t, stdexec::__rcvr_ref_t<Rcvr>>;
127+
using _cache_sndr_t = fork_t::_cache_sndr_t<_variant_t<_child_completions_t>, _domain_t>;
128+
129+
STDEXEC_ATTRIBUTE((host, device)) explicit _opstate_t(Sndr&& sndr, Closures&& closures, Rcvr rcvr) noexcept
130+
: _rcvr_(static_cast<Rcvr&&>(rcvr))
131+
, _fork_opstate_(
132+
stdexec::connect(
133+
closures.apply(
134+
_mk_when_all_fn{},
135+
static_cast<Closures&&>(closures),
136+
_cache_sndr_t{&_cache_}),
137+
stdexec::__ref_rcvr(_rcvr_))) {
138+
_child_opstate_.__construct_from(
139+
stdexec::connect, static_cast<Sndr&&>(sndr), stdexec::__ref_rcvr(*this));
140+
}
141+
142+
STDEXEC_IMMOVABLE(_opstate_t);
143+
144+
STDEXEC_ATTRIBUTE((host, device)) ~_opstate_t() {
145+
// If this opstate was never started, we must explicitly destroy the _child_opstate_.
146+
if (_cache_.is_valueless()) {
147+
_child_opstate_.__destroy();
148+
}
149+
}
150+
151+
STDEXEC_ATTRIBUTE((host, device)) void start() noexcept {
152+
stdexec::start(_child_opstate_.__get());
153+
}
154+
155+
template <class Tag, class... Args>
156+
STDEXEC_ATTRIBUTE((host, device)) void _complete(Tag, Args&&... args) noexcept {
157+
try {
158+
using _tuple_t = stdexec::__decayed_tuple<Tag, Args...>;
159+
_cache_.template emplace<_tuple_t>(Tag{}, static_cast<Args&&>(args)...);
160+
} catch (...) {
161+
if constexpr (!stdexec::__nothrow_decay_copyable<Args...>) {
162+
using _tuple_t = stdexec::__tuple_for<stdexec::set_error_t, ::std::exception_ptr>;
163+
_cache_._results_.template emplace<_tuple_t>(
164+
stdexec::set_error, ::std::current_exception());
165+
}
166+
}
167+
_child_opstate_.__destroy();
168+
stdexec::start(_fork_opstate_);
169+
}
170+
171+
template <class... Values>
172+
STDEXEC_ATTRIBUTE((always_inline, host, device)) void set_value(Values&&... values) noexcept {
173+
this->_complete(stdexec::set_value, static_cast<Values&&>(values)...);
174+
}
175+
176+
template <class Error>
177+
STDEXEC_ATTRIBUTE((always_inline, host, device)) void set_error(Error&& err) noexcept {
178+
this->_complete(stdexec::set_error, static_cast<Error&&>(err));
179+
}
180+
181+
STDEXEC_ATTRIBUTE((always_inline, host, device)) void set_stopped() noexcept {
182+
this->_complete(stdexec::set_stopped);
183+
}
184+
185+
STDEXEC_ATTRIBUTE((nodiscard, host, device)) constexpr auto get_env() const noexcept //
186+
-> stdexec::__fwd_env_t<stdexec::env_of_t<Rcvr>> {
187+
return stdexec::__env::__fwd_fn{}(stdexec::get_env(_rcvr_));
188+
}
189+
190+
Rcvr _rcvr_;
191+
_variant_t<_child_completions_t> _cache_{};
192+
stdexec::__manual_lifetime<_child_opstate_t> _child_opstate_{};
193+
_fork_opstate_t _fork_opstate_;
194+
};
195+
196+
template <class... Closures>
197+
struct _closure_t {
198+
using _closures_t = stdexec::__tuple_for<Closures...>;
199+
200+
template <class Sndr>
201+
STDEXEC_ATTRIBUTE((host, device)) friend constexpr auto operator|(Sndr sndr, _closure_t self) noexcept //
202+
-> _sndr_t<Sndr, Closures...> {
203+
return _sndr_t<Sndr, Closures...>{
204+
{}, static_cast<_closures_t&&>(self._closures_), static_cast<Sndr&&>(sndr)};
205+
}
206+
207+
_closures_t _closures_;
208+
};
209+
210+
template <class Sndr, class... Closures>
211+
requires stdexec::sender<Sndr>
212+
STDEXEC_ATTRIBUTE((host, device)) auto operator()(Sndr sndr, Closures... closures) const -> _sndr_t<Sndr, Closures...> {
213+
return {{}, {static_cast<Closures&&>(closures)...}, static_cast<Sndr&&>(sndr)};
214+
}
215+
216+
template <class... Closures>
217+
requires((!stdexec::sender<Closures>) && ...)
218+
STDEXEC_ATTRIBUTE((host, device)) auto operator()(Closures... closures) const -> _closure_t<Closures...> {
219+
return {{static_cast<Closures&&>(closures)...}};
220+
}
221+
};
222+
223+
template <>
224+
struct fork_t::_env_t<stdexec::__none_such> { };
225+
226+
template <class Sndr, class... Closures>
227+
struct fork_t::_sndr_t {
228+
using sender_concept = stdexec::sender_t;
229+
using _closures_t = stdexec::__tuple_for<Closures...>;
230+
231+
template <class Self, class... Env>
232+
STDEXEC_ATTRIBUTE((host, device)) static auto get_completion_signatures(Self&&, Env&&...) noexcept {
233+
using namespace stdexec;
234+
using _domain_t = __early_domain_of_t<Sndr, __none_such>;
235+
using _child_t = __copy_cvref_t<Self, Sndr>;
236+
using _child_completions_t = completion_signatures_of_t<_child_t, __fwd_env_t<Env>...>;
237+
using __decay_copyable_results_t = stdexec::__decay_copyable_results_t<_child_completions_t>;
238+
239+
if constexpr (!__decay_copyable_results_t::value) {
240+
return _ERROR_<
241+
_WHAT_<>(PREDECESSOR_RESULTS_ARE_NOT_DECAY_COPYABLE),
242+
_IN_ALGORITHM_(exec::fork_t)>();
243+
} else {
244+
using _sndr_t = _when_all_sndr_t<_child_completions_t, _closures_t, _domain_t>;
245+
return completion_signatures_of_t<_sndr_t, __fwd_env_t<Env>...>{};
246+
}
247+
}
248+
249+
template <class Rcvr>
250+
STDEXEC_ATTRIBUTE((host, device)) auto connect(Rcvr rcvr) && -> _opstate_t<Sndr, _closures_t, Rcvr> {
251+
return _opstate_t<Sndr, _closures_t, Rcvr>{
252+
static_cast<Sndr&&>(sndr_),
253+
static_cast<_closures_t&&>(_closures_),
254+
static_cast<Rcvr&&>(rcvr)};
255+
}
256+
257+
template <class Rcvr>
258+
STDEXEC_ATTRIBUTE((host, device)) auto connect(Rcvr rcvr) const & -> _opstate_t<Sndr const &, _closures_t const &, Rcvr> {
259+
return _opstate_t<Sndr const &, _closures_t const &, Rcvr>{
260+
sndr_, _closures_, static_cast<Rcvr&&>(rcvr)};
261+
}
262+
263+
STDEXEC_ATTRIBUTE((host, device)) constexpr auto get_env() const noexcept -> stdexec::__fwd_env_t<stdexec::env_of_t<Sndr>> {
264+
return stdexec::__env::__fwd_fn{}(stdexec::get_env(sndr_));
265+
}
266+
267+
STDEXEC_ATTRIBUTE((no_unique_address)) fork_t _tag_;
268+
stdexec::__tuple_for<Closures...> _closures_;
269+
Sndr sndr_;
270+
};
271+
272+
inline constexpr fork_t fork{};
273+
} // namespace exec

include/exec/when_any.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ namespace exec {
3737
};
3838

3939
template <class _BaseEnv>
40-
using __env_t = __env::__join_t<prop<get_stop_token_t, inplace_stop_token>, _BaseEnv>;
40+
using __env_t = __join_env_t<prop<get_stop_token_t, inplace_stop_token>, _BaseEnv>;
4141

4242
template <class... _Ts>
4343
using __nothrow_decay_copyable_and_move_constructible_t = __mbool<(

0 commit comments

Comments
 (0)