Skip to content

Commit df2ce08

Browse files
authored
make as_awaitable savvy to senders that are known to complete inline (#1901)
* make `as_awaitable` savvy to senders that are known to complete inline
1 parent 18c0a62 commit df2ce08

7 files changed

Lines changed: 202 additions & 88 deletions

File tree

include/stdexec/__detail/__affine_on.hpp

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,9 @@ namespace STDEXEC
5151
struct affine_on_t
5252
{
5353
template <sender _Sender>
54-
constexpr auto operator()(_Sender&& __sndr) const -> __well_formed_sender auto
54+
constexpr auto operator()(_Sender &&__sndr) const -> __well_formed_sender auto
5555
{
56-
return __make_sexpr<affine_on_t>({}, static_cast<_Sender&&>(__sndr));
56+
return __make_sexpr<affine_on_t>({}, static_cast<_Sender &&>(__sndr));
5757
}
5858

5959
constexpr auto operator()() const noexcept
@@ -62,10 +62,10 @@ namespace STDEXEC
6262
}
6363

6464
template <class _Sender, class _Env>
65-
static constexpr auto transform_sender(set_value_t, _Sender&& __sndr, _Env const & __env)
65+
static constexpr auto transform_sender(set_value_t, _Sender &&__sndr, _Env const &__env)
6666
{
6767
static_assert(sender_expr_for<_Sender, affine_on_t>);
68-
auto& [__tag, __ign, __child] = __sndr;
68+
auto &[__tag, __ign, __child] = __sndr;
6969
using __child_t = decltype(__child);
7070
using __cv_child_t = __copy_cvref_t<_Sender, __child_t>;
7171
using __sched_t = __call_result_or_t<get_scheduler_t, __not_a_scheduler<>, _Env const &>;
@@ -116,14 +116,26 @@ namespace STDEXEC
116116

117117
namespace __affine_on
118118
{
119+
template <class _Attrs>
119120
struct __attrs
120121
{
121-
template <class _Tag>
122-
constexpr auto query(__get_completion_behavior_t<_Tag>) const noexcept
122+
template <class _Tag, class... _Env>
123+
requires __queryable_with<_Attrs, __get_completion_behavior_t<_Tag>, _Env const &...>
124+
constexpr auto query(__get_completion_behavior_t<_Tag>, _Env const &...) const noexcept
123125
{
124-
// FUTURE: when the child sender completes inline *and* the current scheduler also
125-
// completes inline, we can return "inline" here instead of "__asynchronous_affine".
126-
return __completion_behavior::__asynchronous_affine;
126+
using __behavior_t =
127+
__query_result_t<_Attrs, __get_completion_behavior_t<_Tag>, _Env const &...>;
128+
129+
// When the child sender completes inline, we can return "inline" here instead of
130+
// "__asynchronous_affine".
131+
if constexpr (__behavior_t::value == __completion_behavior::__inline_completion)
132+
{
133+
return __completion_behavior::__inline_completion;
134+
}
135+
else
136+
{
137+
return __completion_behavior::__asynchronous_affine;
138+
}
127139
}
128140
};
129141
} // namespace __affine_on
@@ -132,9 +144,9 @@ namespace STDEXEC
132144
struct __sexpr_impl<affine_on_t> : __sexpr_defaults
133145
{
134146
static constexpr auto __get_attrs = //
135-
[](__ignore, __ignore, __ignore) noexcept
147+
[]<class _Child>(__ignore, __ignore, _Child const &) noexcept
136148
{
137-
return __affine_on::__attrs{};
149+
return __affine_on::__attrs<env_of_t<_Child>>{};
138150
};
139151
};
140152
} // namespace STDEXEC

include/stdexec/__detail/__as_awaitable.hpp

Lines changed: 142 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -84,123 +84,214 @@ namespace STDEXEC
8484
using __expected_t =
8585
std::variant<std::monostate, __value_or_void_t<_Value>, std::exception_ptr>;
8686

87-
// Helper to cast a coroutine_handle<void> to coroutine_handle<_Promise>
88-
template <class _Promise>
89-
constexpr auto __coroutine_handle_cast(__std::coroutine_handle<> __hcoro) noexcept
90-
-> __std::coroutine_handle<_Promise>
91-
{
92-
return __std::coroutine_handle<_Promise>::from_address(__hcoro.address());
93-
}
87+
template <class _Tag, class _Sender, class... _Env>
88+
concept __completes_inline_for = __never_sends<_Tag, _Sender, _Env...>
89+
|| STDEXEC::__completes_inline<_Tag, env_of_t<_Sender>, _Env...>;
90+
91+
template <class _Sender, class... _Env>
92+
concept __completes_inline = __completes_inline_for<set_value_t, _Sender, _Env...>
93+
&& __completes_inline_for<set_error_t, _Sender, _Env...>
94+
&& __completes_inline_for<set_stopped_t, _Sender, _Env...>;
9495

9596
template <class _Value>
9697
struct __receiver_base
9798
{
9899
using receiver_concept = receiver_t;
99100

100101
template <class... _Us>
101-
requires __std::constructible_from<__value_or_void_t<_Value>, _Us...>
102102
void set_value(_Us&&... __us) noexcept
103103
{
104104
STDEXEC_TRY
105105
{
106-
__result_->template emplace<1>(static_cast<_Us&&>(__us)...);
107-
__continuation_.resume();
106+
__result_.template emplace<1>(static_cast<_Us&&>(__us)...);
108107
}
109108
STDEXEC_CATCH_ALL
110109
{
111-
STDEXEC::set_error(static_cast<__receiver_base&&>(*this), std::current_exception());
110+
__result_.template emplace<2>(std::current_exception());
112111
}
113112
}
114113

115114
template <class _Error>
116115
void set_error(_Error&& __err) noexcept
117116
{
118117
if constexpr (__decays_to<_Error, std::exception_ptr>)
119-
__result_->template emplace<2>(static_cast<_Error&&>(__err));
118+
__result_.template emplace<2>(static_cast<_Error&&>(__err));
120119
else if constexpr (__decays_to<_Error, std::error_code>)
121-
__result_->template emplace<2>(std::make_exception_ptr(std::system_error(__err)));
120+
__result_.template emplace<2>(std::make_exception_ptr(std::system_error(__err)));
122121
else
123-
__result_->template emplace<2>(std::make_exception_ptr(static_cast<_Error&&>(__err)));
124-
__continuation_.resume();
122+
__result_.template emplace<2>(std::make_exception_ptr(static_cast<_Error&&>(__err)));
125123
}
126124

127-
__expected_t<_Value>* __result_;
128-
__std::coroutine_handle<> __continuation_;
125+
__expected_t<_Value>& __result_;
129126
};
130127

131128
template <class _Promise, class _Value>
132-
struct __receiver : __receiver_base<_Value>
129+
struct __sync_receiver : __receiver_base<_Value>
133130
{
134-
constexpr void set_stopped() noexcept
131+
constexpr explicit __sync_receiver(__expected_t<_Value>& __result,
132+
__std::coroutine_handle<_Promise> __continuation) noexcept
133+
: __receiver_base<_Value>{__result}
134+
, __continuation_{__continuation}
135+
{}
136+
137+
void set_stopped() noexcept
135138
{
136-
auto __continuation = __coroutine_handle_cast<_Promise>(this->__continuation_);
137-
// Do not use type deduction here so that we perform any conversions necessary on
138-
// the stopped continuation:
139-
__std::coroutine_handle<> __on_stopped = __continuation.promise().unhandled_stopped();
140-
__on_stopped.resume();
139+
// no-op: the __result_ variant will remain engaged with the monostate
140+
// alternative, which signals that the operation was stopped.
141141
}
142142

143143
// Forward get_env query to the coroutine promise
144144
constexpr auto get_env() const noexcept -> env_of_t<_Promise&>
145145
{
146-
auto const __continuation = __coroutine_handle_cast<_Promise>(this->__continuation_);
147-
return STDEXEC::get_env(__continuation.promise());
146+
return STDEXEC::get_env(__continuation_.promise());
147+
}
148+
149+
__std::coroutine_handle<_Promise> __continuation_;
150+
};
151+
152+
// The receiver type used to connect to senders that could complete asynchronously.
153+
template <class _Promise, class _Value>
154+
struct __async_receiver : __sync_receiver<_Promise, _Value>
155+
{
156+
constexpr explicit __async_receiver(__expected_t<_Value>& __result,
157+
__std::coroutine_handle<_Promise> __continuation) noexcept
158+
: __sync_receiver<_Promise, _Value>{__result, __continuation}
159+
{}
160+
161+
template <class... _Us>
162+
void set_value(_Us&&... __us) noexcept
163+
{
164+
this->__sync_receiver<_Promise, _Value>::set_value(static_cast<_Us&&>(__us)...);
165+
this->__continuation_.resume();
166+
}
167+
168+
template <class _Error>
169+
void set_error(_Error&& __err) noexcept
170+
{
171+
this->__sync_receiver<_Promise, _Value>::set_error(static_cast<_Error&&>(__err));
172+
this->__continuation_.resume();
173+
}
174+
175+
constexpr void set_stopped() noexcept
176+
{
177+
STDEXEC_TRY
178+
{
179+
// Resuming the stopped continuation unwinds the coroutine stack until we reach
180+
// a promise that can handle the stopped signal. The coroutine referred to by
181+
// __continuation_ will never be resumed.
182+
__std::coroutine_handle<> __on_stopped =
183+
this->__continuation_.promise().unhandled_stopped();
184+
__on_stopped.resume();
185+
}
186+
STDEXEC_CATCH_ALL
187+
{
188+
this->__result_.template emplace<2>(std::current_exception());
189+
this->__continuation_.resume();
190+
}
148191
}
149192
};
150193

151194
template <class _Sender, class _Promise>
152-
using __receiver_t = __receiver<_Promise, __detail::__value_t<_Sender, _Promise>>;
195+
using __sync_receiver_t = __sync_receiver<_Promise, __detail::__value_t<_Sender, _Promise>>;
196+
197+
template <class _Sender, class _Promise>
198+
using __async_receiver_t = __async_receiver<_Promise, __detail::__value_t<_Sender, _Promise>>;
153199

154200
template <class _Value>
155201
struct __sender_awaitable_base
156202
{
157-
[[nodiscard]]
158-
constexpr auto await_ready() const noexcept -> bool
203+
static constexpr auto await_ready() noexcept -> bool
159204
{
160205
return false;
161206
}
162207

163208
constexpr auto await_resume() -> _Value
164209
{
165-
switch (__result_.index())
210+
// If the operation completed with set_stopped (as denoted by the monostate
211+
// alternative being active), we should not be resuming this coroutine at all.
212+
STDEXEC_ASSERT(__result_.index() != 0);
213+
if (__result_.index() == 2)
166214
{
167-
case 0: // receiver contract not satisfied
168-
STDEXEC_ASSERT(false && +"_Should never get here" == nullptr);
169-
break;
170-
case 1: // set_value
171-
if constexpr (!__same_as<_Value, void>)
172-
return static_cast<_Value&&>(std::get<1>(__result_));
173-
else
174-
return;
175-
case 2: // set_error
176-
std::rethrow_exception(std::get<2>(__result_));
215+
// The operation completed with set_error, so we need to rethrow the exception.
216+
std::rethrow_exception(std::move(std::get<2>(__result_)));
177217
}
178-
std::terminate();
218+
// The operation completed with set_value, so we can just return the value, which
219+
// may be void.
220+
return static_cast<std::add_rvalue_reference_t<_Value>>(std::get<1>(__result_));
179221
}
180222

181223
protected:
182-
__expected_t<_Value> __result_;
224+
__expected_t<_Value> __result_{};
183225
};
184226

227+
//////////////////////////////////////////////////////////////////////////////////////
228+
// __sender_awaitable: awaitable type returned by as_awaitable when given a sender
229+
// that does not have an as_awaitable member function
185230
template <class _Promise, class _Sender>
186231
struct __sender_awaitable : __sender_awaitable_base<__detail::__value_t<_Sender, _Promise>>
187232
{
188-
constexpr __sender_awaitable(_Sender&& sndr, __std::coroutine_handle<_Promise> __hcoro)
189-
noexcept(__nothrow_connectable<_Sender, __receiver>)
190-
: __op_state_(connect(static_cast<_Sender&&>(sndr),
191-
__receiver{
192-
{&this->__result_, __hcoro}
193-
}))
233+
constexpr explicit __sender_awaitable(_Sender&& __sndr,
234+
__std::coroutine_handle<_Promise> __hcoro)
235+
noexcept(__nothrow_connectable<_Sender, __receiver_t>)
236+
: __opstate_(STDEXEC::connect(static_cast<_Sender&&>(__sndr),
237+
__receiver_t(this->__result_, __hcoro)))
194238
{}
195239

196240
constexpr void await_suspend(__std::coroutine_handle<_Promise>) noexcept
197241
{
198-
STDEXEC::start(__op_state_);
242+
STDEXEC::start(__opstate_);
243+
}
244+
245+
private:
246+
using __receiver_t = __async_receiver_t<_Sender, _Promise>;
247+
connect_result_t<_Sender, __receiver_t> __opstate_;
248+
};
249+
250+
// When the sender is known to complete inline, we can connect and start the operation
251+
// in await_suspend.
252+
template <class _Promise, class _Sender>
253+
requires __completes_inline<_Sender, env_of_t<_Promise&>>
254+
struct __sender_awaitable<_Promise, _Sender>
255+
: __sender_awaitable_base<__detail::__value_t<_Sender, _Promise>>
256+
{
257+
constexpr explicit __sender_awaitable(_Sender&& sndr, __ignore)
258+
noexcept(__nothrow_move_constructible<_Sender>)
259+
: __sndr_(static_cast<_Sender&&>(sndr))
260+
{}
261+
262+
bool await_suspend(__std::coroutine_handle<_Promise> __hcoro)
263+
{
264+
{
265+
auto __opstate = STDEXEC::connect(static_cast<_Sender&&>(__sndr_),
266+
__receiver_t(this->__result_, __hcoro));
267+
// The following call to start will complete synchronously, writing its result
268+
// into the __result_ variant.
269+
STDEXEC::start(__opstate);
270+
}
271+
272+
if (this->__result_.index() == 0)
273+
{
274+
// The operation completed with set_stopped, so we need to call
275+
// unhandled_stopped() on the promise to propagate the stop signal. That will
276+
// result in the coroutine being torn down, so beware. We then resume the
277+
// returned coroutine handle (which may be a noop_coroutine).
278+
__std::coroutine_handle<> __on_stopped = __hcoro.promise().unhandled_stopped();
279+
__on_stopped.resume();
280+
281+
// By returning true, we indicate that the coroutine should not be resumed
282+
// (because it no longer exists).
283+
return true;
284+
}
285+
286+
// The operation completed with set_value or set_error, so we can just resume the
287+
// current coroutine. await_resume with either return the value or throw as
288+
// appropriate.
289+
return false;
199290
}
200291

201292
private:
202-
using __receiver = __receiver_t<_Sender, _Promise>;
203-
connect_result_t<_Sender, __receiver> __op_state_;
293+
using __receiver_t = __sync_receiver_t<_Sender, _Promise>;
294+
_Sender __sndr_;
204295
};
205296

206297
template <class _Sender, class _Promise>
@@ -211,7 +302,6 @@ namespace STDEXEC
211302
template <class _Sender, class _Promise>
212303
concept __awaitable_adapted_sender = sender_in<_Sender, env_of_t<_Promise&>>
213304
&& __minvocable_q<__detail::__value_t, _Sender, _Promise>
214-
&& sender_to<_Sender, __receiver_t<_Sender, _Promise>>
215305
&& requires(_Promise& __promise) {
216306
{
217307
__promise.unhandled_stopped()

0 commit comments

Comments
 (0)