Skip to content

Commit b02c410

Browse files
committed
implement proposed resolution of P3718
1 parent 47185ef commit b02c410

43 files changed

Lines changed: 506 additions & 462 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

include/nvexec/nvtx.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ namespace nvexec {
121121
return {};
122122
}
123123

124-
auto get_env() const noexcept -> env_of_t<const Sender&> {
125-
return stdexec::get_env(sndr_);
124+
auto get_env() const noexcept -> stream_sender_attrs<Sender> {
125+
return {&sndr_};
126126
}
127127
};
128128
};

include/nvexec/stream/bulk.cuh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ namespace nvexec::_strm {
147147
return {};
148148
}
149149

150-
auto get_env() const noexcept -> env_of_t<const Sender&> {
151-
return stdexec::get_env(sndr_);
150+
auto get_env() const noexcept -> stream_sender_attrs<Sender> {
151+
return {&sndr_};
152152
}
153153
};
154154
};
@@ -383,8 +383,8 @@ namespace nvexec::_strm {
383383
return {};
384384
}
385385

386-
auto get_env() const noexcept -> env_of_t<const Sender&> {
387-
return stdexec::get_env(sndr_);
386+
auto get_env() const noexcept -> stream_sender_attrs<Sender> {
387+
return {&sndr_};
388388
}
389389
};
390390
};

include/nvexec/stream/common.cuh

Lines changed: 48 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,12 +68,39 @@ namespace nvexec {
6868
inline STDEXEC_ATTRIBUTE((host, device)) auto is_on_gpu() noexcept -> bool {
6969
return get_device_type() == device_type::device;
7070
}
71+
72+
namespace _strm {
73+
// Used by stream_domain to late-customize senders for execution
74+
// on the stream_scheduler.
75+
template <class Tag, class... Env>
76+
struct transform_sender_for;
77+
78+
template <class Tag>
79+
struct apply_sender_for;
80+
} // namespace _strm
7181
} // namespace nvexec
7282

7383
namespace nvexec {
7484
struct stream_context;
7585

76-
struct stream_domain;
86+
// The stream_domain is how the stream scheduler customizes the sender algorithms. All of the
87+
// algorithms use the current scheduler's domain to transform senders before starting them.
88+
struct stream_domain : stdexec::default_domain {
89+
template <stdexec::sender_expr Sender, class Tag = stdexec::tag_of_t<Sender>, class... Env>
90+
requires stdexec::
91+
__callable<stdexec::__sexpr_apply_t, Sender, _strm::transform_sender_for<Tag, Env...>>
92+
static auto transform_sender(Sender&& sndr, const Env&... env) {
93+
return stdexec::__sexpr_apply(
94+
static_cast<Sender&&>(sndr), _strm::transform_sender_for<Tag, Env...>{env...});
95+
}
96+
97+
template <class Tag, stdexec::sender Sender, class... Args>
98+
requires stdexec::__callable<_strm::apply_sender_for<Tag>, Sender, Args...>
99+
static auto apply_sender(Tag, Sender&& sndr, Args&&... args) {
100+
return _strm::apply_sender_for<Tag>{}(
101+
static_cast<Sender&&>(sndr), static_cast<Args&&>(args)...);
102+
}
103+
};
77104

78105
namespace _strm {
79106

@@ -86,14 +113,6 @@ namespace nvexec {
86113
((STDEXEC_IS_TRIVIALLY_COPYABLE(Ts) || std::is_reference_v<Ts>) && ...);
87114
#endif
88115

89-
// Used by stream_domain to late-customize senders for execution
90-
// on the stream_scheduler.
91-
template <class Tag, class... Env>
92-
struct transform_sender_for;
93-
94-
template <class Tag>
95-
struct apply_sender_for;
96-
97116
inline auto get_stream_priority(stream_priority priority) -> std::pair<int, cudaError_t> {
98117
int least{};
99118
int greatest{};
@@ -336,6 +355,26 @@ namespace nvexec {
336355
}
337356
};
338357

358+
template <class Sender>
359+
struct stream_sender_attrs {
360+
using __t = stream_sender_attrs;
361+
using __id = stream_sender_attrs;
362+
363+
STDEXEC_ATTRIBUTE((nodiscard)) constexpr auto query(get_domain_late_t) const noexcept -> stream_domain {
364+
return {};
365+
}
366+
367+
template <__forwarding_query Query>
368+
requires __env::__queryable<env_of_t<Sender>, Query>
369+
STDEXEC_ATTRIBUTE((nodiscard)) constexpr auto query(Query) const //
370+
noexcept(__env::__nothrow_queryable<env_of_t<Sender>, Query>)
371+
-> __env::__query_result_t<env_of_t<Sender>, Query> {
372+
return stdexec::get_env(*child_).query(Query{});
373+
}
374+
375+
const Sender* child_{};
376+
};
377+
339378
template <class BaseEnv>
340379
auto make_stream_env(BaseEnv&& base_env, stream_provider_t* stream_provider) noexcept {
341380
return __env::__join(

include/nvexec/stream/continues_on.cuh

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,10 @@ namespace nvexec::_strm {
116116
};
117117
} // namespace _continues_on
118118

119-
template <class SenderId>
119+
template <class Scheduler, class SenderId>
120120
struct continues_on_sender_t {
121121
using Sender = stdexec::__t<SenderId>;
122+
using LateDomain = __detail::__early_domain_of_t<Sender, __none_such>;
122123

123124
struct __t : stream_sender_base {
124125
using __id = continues_on_sender_t;
@@ -128,6 +129,7 @@ namespace nvexec::_strm {
128129
stdexec::__t<
129130
_continues_on::operation_state_t<__cvref_id<Self, Sender>, stdexec::__id<Receiver>>>;
130131

132+
Scheduler sched_;
131133
context_state_t context_state_;
132134
Sender sndr_;
133135

@@ -155,17 +157,18 @@ namespace nvexec::_strm {
155157
}
156158

157159
template <__decays_to<__t> Self, class... Env>
158-
static auto
159-
get_completion_signatures(Self&&, Env&&...) -> _completion_signatures_t<Self, Env...> {
160+
static auto get_completion_signatures(Self&&, Env&&...) //
161+
-> _completion_signatures_t<Self, Env...> {
160162
return {};
161163
}
162164

163-
auto get_env() const noexcept -> env_of_t<const Sender&> {
164-
return stdexec::get_env(sndr_);
165+
auto get_env() const noexcept -> __sched_attrs<Scheduler, LateDomain> {
166+
return {sched_, {}};
165167
}
166168

167-
__t(context_state_t context_state, Sender sndr)
168-
: context_state_(context_state)
169+
__t(Scheduler sched, context_state_t context_state, Sender sndr)
170+
: sched_(sched)
171+
, context_state_(context_state)
169172
, sndr_{static_cast<Sender&&>(sndr)} {
170173
}
171174
};
@@ -180,18 +183,18 @@ namespace nvexec::_strm {
180183
template <class Sched, class Sender>
181184
requires gpu_stream_scheduler<_current_scheduler_t<Sender>>
182185
auto operator()(__ignore, Sched sched, Sender&& sndr) const {
183-
using _sender_t = __t<continues_on_sender_t<__id<__decay_t<Sender>>>>;
186+
using _sender_t = __t<continues_on_sender_t<Sched, __id<__decay_t<Sender>>>>;
184187
auto stream_sched = get_completion_scheduler<set_value_t>(get_env(sndr));
185188
return schedule_from(
186189
static_cast<Sched&&>(sched),
187-
_sender_t{stream_sched.context_state_, static_cast<Sender&&>(sndr)});
190+
_sender_t{sched, stream_sched.context_state_, static_cast<Sender&&>(sndr)});
188191
}
189192
};
190193

191194
} // namespace nvexec::_strm
192195

193196
namespace stdexec::__detail {
194-
template <class SenderId>
195-
inline constexpr __mconst<nvexec::_strm::continues_on_sender_t<__name_of<__t<SenderId>>>>
196-
__name_of_v<nvexec::_strm::continues_on_sender_t<SenderId>>{};
197+
template <class Scheduler, class SenderId>
198+
inline constexpr __mconst<nvexec::_strm::continues_on_sender_t<Scheduler, __name_of<__t<SenderId>>>>
199+
__name_of_v<nvexec::_strm::continues_on_sender_t<Scheduler, SenderId>>{};
197200
} // namespace stdexec::__detail

include/nvexec/stream/ensure_started.cuh

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,7 @@ namespace nvexec::_strm {
4545
const inplace_stop_source& stop_source,
4646
stream_provider_t* stream_provider) noexcept {
4747
return make_stream_env(
48-
__env::__from{[&](get_stop_token_t) noexcept {
49-
return stop_source.get_token();
50-
}},
48+
__env::__from{[&](get_stop_token_t) noexcept { return stop_source.get_token(); }},
5149
stream_provider);
5250
}
5351

@@ -355,8 +353,8 @@ namespace nvexec::_strm {
355353
return operation_t<Receiver>{static_cast<Receiver&&>(rcvr), std::move(shared_state_)};
356354
}
357355

358-
auto get_env() const noexcept -> env_of_t<const Sender&> {
359-
return stdexec::get_env(sndr_);
356+
auto get_env() const noexcept -> stream_sender_attrs<Sender> {
357+
return {&sndr_};
360358
}
361359

362360
template <class... Tys>
@@ -398,7 +396,7 @@ namespace nvexec::_strm {
398396
using _sender_t = __t<ensure_started_sender_t<__id<__decay_t<Sender>>>>;
399397

400398
template <class Env, stream_completing_sender Sender>
401-
auto operator()(__ignore, Env&& /*env*/, Sender&& sndr) const -> _sender_t<Sender> {
399+
auto operator()(__ignore, Env&&, Sender&& sndr) const -> _sender_t<Sender> {
402400
auto sched = get_completion_scheduler<set_value_t>(get_env(sndr));
403401
return _sender_t<Sender>{sched.context_state_, static_cast<Sender&&>(sndr)};
404402
}

include/nvexec/stream/launch.cuh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,8 @@ namespace nvexec {
151151
return {};
152152
}
153153

154-
auto get_env() const noexcept -> env_of_t<const Sender&> {
155-
return stdexec::get_env(sndr_);
154+
auto get_env() const noexcept -> stream_sender_attrs<Sender> {
155+
return {&sndr_};
156156
}
157157
};
158158
};

0 commit comments

Comments
 (0)