Skip to content

Commit c9d272a

Browse files
authored
Merge pull request #1560 from lucteo/bulk-unchunked
Bulk unchunked
2 parents 41f673e + 07dc08a commit c9d272a

4 files changed

Lines changed: 94 additions & 75 deletions

File tree

include/exec/system_context.hpp

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ namespace exec {
9898
/// The execution domain of the parallel_scheduler, used for the purposes of customizing
9999
/// sender algorithms such as `bulk_chunked` and `bulk_unchunked`.
100100
struct __parallel_scheduler_domain : stdexec::default_domain {
101-
/// Schedules new bulk chunked work.
101+
template <__bulk_chunked_or_unchunked _Sender>
102+
auto transform_sender(_Sender&& __sndr) const noexcept;
102103
template <__bulk_chunked_or_unchunked _Sender, class _Env>
103104
auto transform_sender(_Sender&& __sndr, const _Env& __env) const noexcept;
104105
};
@@ -374,9 +375,16 @@ namespace exec {
374375
auto __state = reinterpret_cast<_BulkState*>(this);
375376
if constexpr (_BulkState::__is_unchunked) {
376377
(void) __end; // not used
377-
std::apply(
378-
[&](auto&&... __args) { __state->__fun_(__begin, __args...); },
379-
*reinterpret_cast<std::tuple<_As...>*>(__base_t::__arguments_data_));
378+
// If we are not parallelizing, we need to run all the iterations sequentially.
379+
uint32_t __increments = 1;
380+
if constexpr (!_BulkState::__parallelize) {
381+
__increments = __state->__size_;
382+
}
383+
for (uint32_t __i = __begin; __i < __begin + __increments; __i++) {
384+
std::apply(
385+
[&](auto&&... __args) { __state->__fun_(__i, __args...); },
386+
*reinterpret_cast<std::tuple<_As...>*>(__base_t::__arguments_data_));
387+
}
380388
} else {
381389
// If we are not parallelizing, we need to pass the entire range to the functor.
382390
if constexpr (!_BulkState::__parallelize) {
@@ -464,7 +472,8 @@ namespace exec {
464472
// Schedule the bulk work on the system scheduler.
465473
// This will invoke `execute` on our receiver multiple times, and then a completion signal (e.g., `set_value`).
466474
if constexpr (_BulkState::__is_unchunked) {
467-
__scheduler->schedule_bulk_unchunked(__size, __storage, *__r);
475+
__scheduler
476+
->schedule_bulk_unchunked(_BulkState::__parallelize ? __size : 1, __storage, *__r);
468477
} else {
469478
__scheduler
470479
->schedule_bulk_chunked(_BulkState::__parallelize ? __size : 1, __storage, *__r);
@@ -677,9 +686,17 @@ namespace exec {
677686
template <class _Data, class _Previous>
678687
auto
679688
operator()(stdexec::bulk_unchunked_t, _Data&& __data, _Previous&& __previous) const noexcept {
680-
auto [__unused_pol, __shape, __fn] = static_cast<_Data&&>(__data);
681-
return __parallel_bulk_sender<true, _Previous, decltype(__shape), decltype(__fn), true>{
682-
__sched_, static_cast<_Previous&&>(__previous), __shape, std::move(__fn)};
689+
auto [__pol, __shape, __fn] = static_cast<_Data&&>(__data);
690+
using __policy_t = std::remove_cvref_t<decltype(__pol.__get())>;
691+
constexpr bool __parallelize = std::same_as<__policy_t, stdexec::parallel_policy>
692+
|| std::same_as<__policy_t, stdexec::parallel_unsequenced_policy>;
693+
return __parallel_bulk_sender<
694+
true,
695+
_Previous,
696+
decltype(__shape),
697+
decltype(__fn),
698+
__parallelize
699+
>{__sched_, static_cast<_Previous&&>(__previous), __shape, std::move(__fn)};
683700
}
684701

685702
parallel_scheduler __sched_;
@@ -690,6 +707,22 @@ namespace exec {
690707
using sender_concept = stdexec::sender_t;
691708
};
692709

710+
template <__bulk_chunked_or_unchunked _Sender>
711+
auto __parallel_scheduler_domain::transform_sender(_Sender&& __sndr)
712+
const noexcept {
713+
if constexpr (stdexec::__completes_on<_Sender, parallel_scheduler>) {
714+
auto __sched = stdexec::get_completion_scheduler<stdexec::set_value_t>(
715+
stdexec::get_env(__sndr));
716+
return stdexec::__sexpr_apply(
717+
static_cast<_Sender&&>(__sndr), __transform_parallel_bulk_sender{__sched});
718+
} else {
719+
static_assert(
720+
stdexec::__completes_on<_Sender, parallel_scheduler>,
721+
"No parallel_scheduler instance can be found in the sender's "
722+
"environment on which to schedule bulk work.");
723+
return __not_a_sender<stdexec::__name_of<_Sender>>();
724+
}
725+
}
693726
template <__bulk_chunked_or_unchunked _Sender, class _Env>
694727
auto __parallel_scheduler_domain::transform_sender(_Sender&& __sndr, const _Env& __env)
695728
const noexcept {

include/stdexec/__detail/__bulk.hpp

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ namespace stdexec {
141141
template <>
142142
struct __bulk_traits<bulk_unchunked_t> {
143143
using __on_not_callable =
144-
__callable_error<"In stdexec::bulk_unchunked(Sender, Shape, Function)..."_mstr>;
144+
__callable_error<"In stdexec::bulk_unchunked(Sender, Policy, Shape, Function)..."_mstr>;
145145

146146
// Curried function, after passing the required indices.
147147
template <class _Fun, class _Shape>
@@ -256,30 +256,7 @@ namespace stdexec {
256256
};
257257

258258
struct bulk_chunked_t : __generic_bulk_t<bulk_chunked_t> { };
259-
260-
struct bulk_unchunked_t {
261-
template <sender _Sender, integral _Shape, copy_constructible _Fun>
262-
STDEXEC_ATTRIBUTE(host, device)
263-
auto operator()(_Sender&& __sndr, _Shape __shape, _Fun __fun) const -> __well_formed_sender
264-
auto {
265-
auto __domain = __get_early_domain(__sndr);
266-
return stdexec::transform_sender(
267-
__domain,
268-
__make_sexpr<bulk_unchunked_t>(
269-
__data{par, __shape, static_cast<_Fun&&>(__fun)}, static_cast<_Sender&&>(__sndr)));
270-
}
271-
272-
template <integral _Shape, copy_constructible _Fun>
273-
STDEXEC_ATTRIBUTE(always_inline)
274-
auto operator()(_Shape __shape, _Fun __fun) const
275-
-> __binder_back<bulk_unchunked_t, _Shape, _Fun> {
276-
return {
277-
{static_cast<_Shape&&>(__shape), static_cast<_Fun&&>(__fun)},
278-
{},
279-
{}
280-
};
281-
}
282-
};
259+
struct bulk_unchunked_t : __generic_bulk_t<bulk_unchunked_t> { };
283260

284261
template <class _AlgoTag>
285262
struct __bulk_impl_base : __sexpr_defaults {
@@ -340,7 +317,6 @@ namespace stdexec {
340317
//! This implements the core default behavior for `bulk_unchunked`:
341318
//! When setting value, it loops over the shape and invokes the function.
342319
//! Note: This is not done concurrently. That is customized by the scheduler.
343-
//! Calling it on a scheduler that is not concurrent is an error.
344320
static constexpr auto complete =
345321
[]<class _Tag, class _State, class _Receiver, class... _Args>(
346322
__ignore,
@@ -349,17 +325,7 @@ namespace stdexec {
349325
_Tag,
350326
_Args&&... __args) noexcept -> void {
351327
if constexpr (std::same_as<_Tag, set_value_t>) {
352-
// Intercept set_value and dispatch to the bulk operation.
353328
using __shape_t = decltype(__state.__shape_);
354-
constexpr bool __scheduler_available = requires {
355-
get_completion_scheduler<set_value_t>(get_env(__rcvr));
356-
};
357-
if constexpr (__scheduler_available) {
358-
// This default implementation doesn't run on a scheduler with concurrent progress guarantees.
359-
constexpr auto __guarantee = get_forward_progress_guarantee(
360-
get_completion_scheduler<set_value_t>(get_env(__rcvr)));
361-
static_assert(__guarantee != forward_progress_guarantee::concurrent);
362-
}
363329
if constexpr (noexcept(__state.__fun_(__shape_t{}, __args...))) {
364330
// The noexcept version that doesn't need try/catch:
365331
for (__shape_t __i{}; __i != __state.__shape_; ++__i) {

test/exec/test_system_context.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ TEST_CASE("simple bulk_unchunked task on system context", "[types][system_schedu
221221
std::thread::id pool_ids[num_tasks];
222222
exec::parallel_scheduler sched = exec::get_parallel_scheduler();
223223

224-
auto bulk_snd = ex::bulk_unchunked(ex::schedule(sched), num_tasks, [&](unsigned long id) {
224+
auto bulk_snd = ex::bulk_unchunked(ex::schedule(sched), ex::par, num_tasks, [&](unsigned long id) {
225225
pool_ids[id] = std::this_thread::get_id();
226226
});
227227

@@ -233,6 +233,26 @@ TEST_CASE("simple bulk_unchunked task on system context", "[types][system_schedu
233233
}
234234
}
235235

236+
TEST_CASE("bulk_unchunked with seq will run everything on one thread", "[types][system_scheduler]") {
237+
std::thread::id this_id = std::this_thread::get_id();
238+
constexpr size_t num_tasks = 16;
239+
std::thread::id pool_ids[num_tasks];
240+
exec::parallel_scheduler sched = exec::get_parallel_scheduler();
241+
242+
auto bulk_snd = ex::bulk_unchunked(ex::schedule(sched), ex::seq, num_tasks, [&](unsigned long id) {
243+
pool_ids[id] = std::this_thread::get_id();
244+
std::this_thread::sleep_for(std::chrono::milliseconds{1});
245+
});
246+
247+
ex::sync_wait(std::move(bulk_snd));
248+
249+
for (auto pool_id: pool_ids) {
250+
REQUIRE(pool_id != std::thread::id{});
251+
REQUIRE(this_id != pool_id);
252+
REQUIRE(pool_id == pool_ids[0]); // All should be the same
253+
}
254+
}
255+
236256
TEST_CASE("bulk_chunked on parallel_scheduler performs chunking", "[types][system_scheduler]") {
237257
std::atomic<bool> has_chunking = false;
238258

0 commit comments

Comments
 (0)