|
| 1 | +/* |
| 2 | + * Copyright (c) 2025 NVIDIA Corporation |
| 3 | + * |
| 4 | + * Licensed under the Apache License Version 2.0 with LLVM Exceptions |
| 5 | + * (the "License"); you may not use this file except in compliance with |
| 6 | + * the License. You may obtain a copy of the License at |
| 7 | + * |
| 8 | + * https://llvm.org/LICENSE.txt |
| 9 | + * |
| 10 | + * Unless required by applicable law or agreed to in writing, software |
| 11 | + * distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | + * See the License for the specific language governing permissions and |
| 14 | + * limitations under the License. |
| 15 | + */ |
| 16 | +#pragma once |
| 17 | + |
| 18 | +#include "../stdexec/execution.hpp" |
| 19 | +#include "../stdexec/__detail/__receiver_ref.hpp" |
| 20 | + |
| 21 | +#include <exception> |
| 22 | + |
| 23 | +namespace exec { |
| 24 | + struct PREDECESSOR_RESULTS_ARE_NOT_DECAY_COPYABLE { }; |
| 25 | + |
| 26 | + struct fork_t { |
| 27 | + template <class Sndr, class... Closures> |
| 28 | + struct _sndr_t; |
| 29 | + |
| 30 | + struct _dematerialize_fn { |
| 31 | + struct _impl_fn { |
| 32 | + template <class Rcvr, class Tag, class... Args> |
| 33 | + STDEXEC_ATTRIBUTE((always_inline, host, device)) void operator()(Rcvr& rcvr, Tag, const Args&... args) const noexcept { |
| 34 | + Tag{}(static_cast<Rcvr&&>(rcvr), args...); |
| 35 | + } |
| 36 | + }; |
| 37 | + |
| 38 | + template <class Rcvr, class Tuple> |
| 39 | + STDEXEC_ATTRIBUTE((always_inline, host, device)) void operator()(Rcvr& rcvr, const Tuple& tupl) const noexcept { |
| 40 | + tupl.apply(_impl_fn{}, tupl, rcvr); |
| 41 | + } |
| 42 | + }; |
| 43 | + |
| 44 | + struct _mk_when_all_fn { |
| 45 | + template <class CacheSndr, class... Closures> |
| 46 | + STDEXEC_ATTRIBUTE((always_inline, host, device)) auto operator()(CacheSndr sndr, Closures&&... closures) const { |
| 47 | + return stdexec::when_all(static_cast<Closures&&>(closures)(sndr)...); |
| 48 | + } |
| 49 | + }; |
| 50 | + |
| 51 | + template <class Completions> |
| 52 | + using _maybe_eptr_completion_t = stdexec::__if_c< |
| 53 | + stdexec::__nothrow_decay_copyable_results_t<Completions>::value, |
| 54 | + stdexec::__mset_nil, |
| 55 | + stdexec::__tuple_for<stdexec::set_error_t, ::std::exception_ptr>>; |
| 56 | + |
| 57 | + template <class Completions> |
| 58 | + using _variant_t = typename stdexec::__mset_insert< |
| 59 | + stdexec::__for_each_completion_signature<Completions, stdexec::__decayed_tuple, stdexec::__mset>, |
| 60 | + _maybe_eptr_completion_t<Completions>>::template rebind<stdexec::__variant_for>; |
| 61 | + |
| 62 | + template <class Domain> |
| 63 | + struct _env_t { |
| 64 | + STDEXEC_ATTRIBUTE((always_inline, host, device)) static constexpr auto query(stdexec::get_domain_t) noexcept -> Domain { |
| 65 | + return {}; |
| 66 | + } |
| 67 | + }; |
| 68 | + |
| 69 | + template <class Tag, class... Args> |
| 70 | + using _cref_sig_t = Tag(const Args&...); |
| 71 | + |
| 72 | + // Given a set of async results, each of the form `tuple<Tag, Args...>`, compute |
| 73 | + // the corresponding completion signatures, where each signature is of the form |
| 74 | + // `Tag(const Args&...)`. |
| 75 | + template <class... AsyncResults> |
| 76 | + using _cache_sndr_completions_t = |
| 77 | + stdexec::completion_signatures<stdexec::__mapply<stdexec::__q<_cref_sig_t>, AsyncResults>...>; |
| 78 | + |
| 79 | + template <class Variant, class Domain> |
| 80 | + struct _cache_sndr_t { |
| 81 | + using sender_concept = stdexec::sender_t; |
| 82 | + |
| 83 | + template <class Rcvr> |
| 84 | + struct _opstate_t { |
| 85 | + using operation_state_concept = stdexec::operation_state_t; |
| 86 | + |
| 87 | + STDEXEC_ATTRIBUTE((host, device)) void start() noexcept { |
| 88 | + Variant::visit(_dematerialize_fn{}, *_results_, _rcvr_); |
| 89 | + } |
| 90 | + |
| 91 | + Rcvr _rcvr_; |
| 92 | + const Variant* _results_; |
| 93 | + }; |
| 94 | + |
| 95 | + template <class _Self, class... _Env> |
| 96 | + STDEXEC_ATTRIBUTE((host, device)) static auto get_completion_signatures(_Self&&, _Env&&...) noexcept { |
| 97 | + return stdexec::__mapply<stdexec::__qq<_cache_sndr_completions_t>, Variant>{}; |
| 98 | + } |
| 99 | + |
| 100 | + template <class Rcvr> |
| 101 | + STDEXEC_ATTRIBUTE((host, device)) auto connect(Rcvr rcvr) const -> _opstate_t<Rcvr> { |
| 102 | + return _opstate_t<Rcvr>{static_cast<Rcvr&&>(rcvr), _results_}; |
| 103 | + } |
| 104 | + |
| 105 | + STDEXEC_ATTRIBUTE((host, device)) static auto get_env() noexcept -> _env_t<Domain> { |
| 106 | + return {}; |
| 107 | + } |
| 108 | + |
| 109 | + const Variant* _results_; |
| 110 | + }; |
| 111 | + |
| 112 | + template <class Completions, class Closures, class Domain> |
| 113 | + using _when_all_sndr_t = stdexec::__tup:: |
| 114 | + __apply_result_t<_mk_when_all_fn, Closures, _cache_sndr_t<_variant_t<Completions>, Domain>>; |
| 115 | + |
| 116 | + template <class Sndr, class Closures, class Rcvr> |
| 117 | + struct _opstate_t { |
| 118 | + using operation_state_concept = stdexec::operation_state_t; |
| 119 | + using _env_t = stdexec::__call_result_t<stdexec::__env::__fwd_fn, stdexec::env_of_t<Rcvr>>; |
| 120 | + using _child_completions_t = stdexec::completion_signatures_of_t<Sndr, _env_t>; |
| 121 | + using _domain_t = stdexec::__early_domain_of_t<Sndr, stdexec::__none_such>; |
| 122 | + using _when_all_sndr_t = fork_t::_when_all_sndr_t<_child_completions_t, Closures, _domain_t>; |
| 123 | + using _child_opstate_t = |
| 124 | + stdexec::connect_result_t<Sndr, stdexec::__rcvr_ref_t<_opstate_t, _env_t>>; |
| 125 | + using _fork_opstate_t = |
| 126 | + stdexec::connect_result_t<_when_all_sndr_t, stdexec::__rcvr_ref_t<Rcvr>>; |
| 127 | + using _cache_sndr_t = fork_t::_cache_sndr_t<_variant_t<_child_completions_t>, _domain_t>; |
| 128 | + |
| 129 | + STDEXEC_ATTRIBUTE((host, device)) explicit _opstate_t(Sndr&& sndr, Closures&& closures, Rcvr rcvr) noexcept |
| 130 | + : _rcvr_(static_cast<Rcvr&&>(rcvr)) |
| 131 | + , _fork_opstate_( |
| 132 | + stdexec::connect( |
| 133 | + closures.apply( |
| 134 | + _mk_when_all_fn{}, |
| 135 | + static_cast<Closures&&>(closures), |
| 136 | + _cache_sndr_t{&_cache_}), |
| 137 | + stdexec::__ref_rcvr(_rcvr_))) { |
| 138 | + _child_opstate_.__construct_from( |
| 139 | + stdexec::connect, static_cast<Sndr&&>(sndr), stdexec::__ref_rcvr(*this)); |
| 140 | + } |
| 141 | + |
| 142 | + STDEXEC_IMMOVABLE(_opstate_t); |
| 143 | + |
| 144 | + STDEXEC_ATTRIBUTE((host, device)) ~_opstate_t() { |
| 145 | + // If this opstate was never started, we must explicitly destroy the _child_opstate_. |
| 146 | + if (_cache_.is_valueless()) { |
| 147 | + _child_opstate_.__destroy(); |
| 148 | + } |
| 149 | + } |
| 150 | + |
| 151 | + STDEXEC_ATTRIBUTE((host, device)) void start() noexcept { |
| 152 | + stdexec::start(_child_opstate_.__get()); |
| 153 | + } |
| 154 | + |
| 155 | + template <class Tag, class... Args> |
| 156 | + STDEXEC_ATTRIBUTE((host, device)) void _complete(Tag, Args&&... args) noexcept { |
| 157 | + try { |
| 158 | + using _tuple_t = stdexec::__decayed_tuple<Tag, Args...>; |
| 159 | + _cache_.template emplace<_tuple_t>(Tag{}, static_cast<Args&&>(args)...); |
| 160 | + } catch (...) { |
| 161 | + if constexpr (!stdexec::__nothrow_decay_copyable<Args...>) { |
| 162 | + using _tuple_t = stdexec::__tuple_for<stdexec::set_error_t, ::std::exception_ptr>; |
| 163 | + _cache_._results_.template emplace<_tuple_t>( |
| 164 | + stdexec::set_error, ::std::current_exception()); |
| 165 | + } |
| 166 | + } |
| 167 | + _child_opstate_.__destroy(); |
| 168 | + stdexec::start(_fork_opstate_); |
| 169 | + } |
| 170 | + |
| 171 | + template <class... Values> |
| 172 | + STDEXEC_ATTRIBUTE((always_inline, host, device)) void set_value(Values&&... values) noexcept { |
| 173 | + this->_complete(stdexec::set_value, static_cast<Values&&>(values)...); |
| 174 | + } |
| 175 | + |
| 176 | + template <class Error> |
| 177 | + STDEXEC_ATTRIBUTE((always_inline, host, device)) void set_error(Error&& err) noexcept { |
| 178 | + this->_complete(stdexec::set_error, static_cast<Error&&>(err)); |
| 179 | + } |
| 180 | + |
| 181 | + STDEXEC_ATTRIBUTE((always_inline, host, device)) void set_stopped() noexcept { |
| 182 | + this->_complete(stdexec::set_stopped); |
| 183 | + } |
| 184 | + |
| 185 | + STDEXEC_ATTRIBUTE((nodiscard, host, device)) constexpr auto get_env() const noexcept // |
| 186 | + -> stdexec::__fwd_env_t<stdexec::env_of_t<Rcvr>> { |
| 187 | + return stdexec::__env::__fwd_fn{}(stdexec::get_env(_rcvr_)); |
| 188 | + } |
| 189 | + |
| 190 | + Rcvr _rcvr_; |
| 191 | + _variant_t<_child_completions_t> _cache_{}; |
| 192 | + stdexec::__manual_lifetime<_child_opstate_t> _child_opstate_{}; |
| 193 | + _fork_opstate_t _fork_opstate_; |
| 194 | + }; |
| 195 | + |
| 196 | + template <class... Closures> |
| 197 | + struct _closure_t { |
| 198 | + using _closures_t = stdexec::__tuple_for<Closures...>; |
| 199 | + |
| 200 | + template <class Sndr> |
| 201 | + STDEXEC_ATTRIBUTE((host, device)) friend constexpr auto operator|(Sndr sndr, _closure_t self) noexcept // |
| 202 | + -> _sndr_t<Sndr, Closures...> { |
| 203 | + return _sndr_t<Sndr, Closures...>{ |
| 204 | + {}, static_cast<_closures_t&&>(self._closures_), static_cast<Sndr&&>(sndr)}; |
| 205 | + } |
| 206 | + |
| 207 | + _closures_t _closures_; |
| 208 | + }; |
| 209 | + |
| 210 | + template <class Sndr, class... Closures> |
| 211 | + requires stdexec::sender<Sndr> |
| 212 | + STDEXEC_ATTRIBUTE((host, device)) auto operator()(Sndr sndr, Closures... closures) const -> _sndr_t<Sndr, Closures...> { |
| 213 | + return {{}, {static_cast<Closures&&>(closures)...}, static_cast<Sndr&&>(sndr)}; |
| 214 | + } |
| 215 | + |
| 216 | + template <class... Closures> |
| 217 | + requires((!stdexec::sender<Closures>) && ...) |
| 218 | + STDEXEC_ATTRIBUTE((host, device)) auto operator()(Closures... closures) const -> _closure_t<Closures...> { |
| 219 | + return {{static_cast<Closures&&>(closures)...}}; |
| 220 | + } |
| 221 | + }; |
| 222 | + |
| 223 | + template <> |
| 224 | + struct fork_t::_env_t<stdexec::__none_such> { }; |
| 225 | + |
| 226 | + template <class Sndr, class... Closures> |
| 227 | + struct fork_t::_sndr_t { |
| 228 | + using sender_concept = stdexec::sender_t; |
| 229 | + using _closures_t = stdexec::__tuple_for<Closures...>; |
| 230 | + |
| 231 | + template <class Self, class... Env> |
| 232 | + STDEXEC_ATTRIBUTE((host, device)) static auto get_completion_signatures(Self&&, Env&&...) noexcept { |
| 233 | + using namespace stdexec; |
| 234 | + using _domain_t = __early_domain_of_t<Sndr, __none_such>; |
| 235 | + using _child_t = __copy_cvref_t<Self, Sndr>; |
| 236 | + using _child_completions_t = completion_signatures_of_t<_child_t, __fwd_env_t<Env>...>; |
| 237 | + using __decay_copyable_results_t = stdexec::__decay_copyable_results_t<_child_completions_t>; |
| 238 | + |
| 239 | + if constexpr (!__decay_copyable_results_t::value) { |
| 240 | + return _ERROR_< |
| 241 | + _WHAT_<>(PREDECESSOR_RESULTS_ARE_NOT_DECAY_COPYABLE), |
| 242 | + _IN_ALGORITHM_(exec::fork_t)>(); |
| 243 | + } else { |
| 244 | + using _sndr_t = _when_all_sndr_t<_child_completions_t, _closures_t, _domain_t>; |
| 245 | + return completion_signatures_of_t<_sndr_t, __fwd_env_t<Env>...>{}; |
| 246 | + } |
| 247 | + } |
| 248 | + |
| 249 | + template <class Rcvr> |
| 250 | + STDEXEC_ATTRIBUTE((host, device)) auto connect(Rcvr rcvr) && -> _opstate_t<Sndr, _closures_t, Rcvr> { |
| 251 | + return _opstate_t<Sndr, _closures_t, Rcvr>{ |
| 252 | + static_cast<Sndr&&>(sndr_), |
| 253 | + static_cast<_closures_t&&>(_closures_), |
| 254 | + static_cast<Rcvr&&>(rcvr)}; |
| 255 | + } |
| 256 | + |
| 257 | + template <class Rcvr> |
| 258 | + STDEXEC_ATTRIBUTE((host, device)) auto connect(Rcvr rcvr) const & -> _opstate_t<Sndr const &, _closures_t const &, Rcvr> { |
| 259 | + return _opstate_t<Sndr const &, _closures_t const &, Rcvr>{ |
| 260 | + sndr_, _closures_, static_cast<Rcvr&&>(rcvr)}; |
| 261 | + } |
| 262 | + |
| 263 | + STDEXEC_ATTRIBUTE((host, device)) constexpr auto get_env() const noexcept -> stdexec::__fwd_env_t<stdexec::env_of_t<Sndr>> { |
| 264 | + return stdexec::__env::__fwd_fn{}(stdexec::get_env(sndr_)); |
| 265 | + } |
| 266 | + |
| 267 | + STDEXEC_ATTRIBUTE((no_unique_address)) fork_t _tag_; |
| 268 | + stdexec::__tuple_for<Closures...> _closures_; |
| 269 | + Sndr sndr_; |
| 270 | + }; |
| 271 | + |
| 272 | + inline constexpr fork_t fork{}; |
| 273 | +} // namespace exec |
0 commit comments