Skip to content

Commit cf0a38f

Browse files
committed
bulk_unchunked also takes an execution policy, and it's not restricted to concurrent execution.
1 parent 6c1f415 commit cf0a38f

4 files changed

Lines changed: 56 additions & 74 deletions

File tree

include/exec/system_context.hpp

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,16 @@ namespace exec {
374374
auto __state = reinterpret_cast<_BulkState*>(this);
375375
if constexpr (_BulkState::__is_unchunked) {
376376
(void) __end; // not used
377-
std::apply(
378-
[&](auto&&... __args) { __state->__fun_(__begin, __args...); },
379-
*reinterpret_cast<std::tuple<_As...>*>(__base_t::__arguments_data_));
377+
// If we are not parallelizing, we need to run all the iterations sequentially.
378+
uint32_t __increments = 1;
379+
if constexpr (!_BulkState::__parallelize) {
380+
__increments = __state->__size_;
381+
}
382+
for (uint32_t __i = __begin; __i < __begin + __increments; __i++) {
383+
std::apply(
384+
[&](auto&&... __args) { __state->__fun_(__i, __args...); },
385+
*reinterpret_cast<std::tuple<_As...>*>(__base_t::__arguments_data_));
386+
}
380387
} else {
381388
// If we are not parallelizing, we need to pass the entire range to the functor.
382389
if constexpr (!_BulkState::__parallelize) {
@@ -464,7 +471,8 @@ namespace exec {
464471
// Schedule the bulk work on the system scheduler.
465472
// This will invoke `execute` on our receiver multiple times, and then a completion signal (e.g., `set_value`).
466473
if constexpr (_BulkState::__is_unchunked) {
467-
__scheduler->schedule_bulk_unchunked(__size, __storage, *__r);
474+
__scheduler
475+
->schedule_bulk_unchunked(_BulkState::__parallelize ? __size : 1, __storage, *__r);
468476
} else {
469477
__scheduler
470478
->schedule_bulk_chunked(_BulkState::__parallelize ? __size : 1, __storage, *__r);
@@ -677,9 +685,17 @@ namespace exec {
677685
template <class _Data, class _Previous>
678686
auto
679687
operator()(stdexec::bulk_unchunked_t, _Data&& __data, _Previous&& __previous) const noexcept {
680-
auto [__unused_pol, __shape, __fn] = static_cast<_Data&&>(__data);
681-
return __parallel_bulk_sender<true, _Previous, decltype(__shape), decltype(__fn), true>{
682-
__sched_, static_cast<_Previous&&>(__previous), __shape, std::move(__fn)};
688+
auto [__pol, __shape, __fn] = static_cast<_Data&&>(__data);
689+
using __policy_t = std::remove_cvref_t<decltype(__pol.__get())>;
690+
constexpr bool __parallelize = std::same_as<__policy_t, stdexec::parallel_policy>
691+
|| std::same_as<__policy_t, stdexec::parallel_unsequenced_policy>;
692+
return __parallel_bulk_sender<
693+
true,
694+
_Previous,
695+
decltype(__shape),
696+
decltype(__fn),
697+
__parallelize
698+
>{__sched_, static_cast<_Previous&&>(__previous), __shape, std::move(__fn)};
683699
}
684700

685701
parallel_scheduler __sched_;

include/stdexec/__detail/__bulk.hpp

Lines changed: 2 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ namespace stdexec {
141141
template <>
142142
struct __bulk_traits<bulk_unchunked_t> {
143143
using __on_not_callable =
144-
__callable_error<"In stdexec::bulk_unchunked(Sender, Shape, Function)..."_mstr>;
144+
__callable_error<"In stdexec::bulk_unchunked(Sender, Policy, Shape, Function)..."_mstr>;
145145

146146
// Curried function, after passing the required indices.
147147
template <class _Fun, class _Shape>
@@ -256,30 +256,7 @@ namespace stdexec {
256256
};
257257

258258
struct bulk_chunked_t : __generic_bulk_t<bulk_chunked_t> { };
259-
260-
struct bulk_unchunked_t {
261-
template <sender _Sender, integral _Shape, copy_constructible _Fun>
262-
STDEXEC_ATTRIBUTE(host, device)
263-
auto operator()(_Sender&& __sndr, _Shape __shape, _Fun __fun) const -> __well_formed_sender
264-
auto {
265-
auto __domain = __get_early_domain(__sndr);
266-
return stdexec::transform_sender(
267-
__domain,
268-
__make_sexpr<bulk_unchunked_t>(
269-
__data{par, __shape, static_cast<_Fun&&>(__fun)}, static_cast<_Sender&&>(__sndr)));
270-
}
271-
272-
template <integral _Shape, copy_constructible _Fun>
273-
STDEXEC_ATTRIBUTE(always_inline)
274-
auto operator()(_Shape __shape, _Fun __fun) const
275-
-> __binder_back<bulk_unchunked_t, _Shape, _Fun> {
276-
return {
277-
{static_cast<_Shape&&>(__shape), static_cast<_Fun&&>(__fun)},
278-
{},
279-
{}
280-
};
281-
}
282-
};
259+
struct bulk_unchunked_t : __generic_bulk_t<bulk_unchunked_t> { };
283260

284261
template <class _AlgoTag>
285262
struct __bulk_impl_base : __sexpr_defaults {
@@ -340,7 +317,6 @@ namespace stdexec {
340317
//! This implements the core default behavior for `bulk_unchunked`:
341318
//! When setting value, it loops over the shape and invokes the function.
342319
//! Note: This is not done in concurrently. That is customized by the scheduler.
343-
//! Calling it on a scheduler that is not concurrent is an error.
344320
static constexpr auto complete =
345321
[]<class _Tag, class _State, class _Receiver, class... _Args>(
346322
__ignore,
@@ -349,17 +325,7 @@ namespace stdexec {
349325
_Tag,
350326
_Args&&... __args) noexcept -> void {
351327
if constexpr (std::same_as<_Tag, set_value_t>) {
352-
// Intercept set_value and dispatch to the bulk operation.
353328
using __shape_t = decltype(__state.__shape_);
354-
constexpr bool __scheduler_available = requires {
355-
get_completion_scheduler<set_value_t>(get_env(__rcvr));
356-
};
357-
if constexpr (__scheduler_available) {
358-
// This default implementation doesn't run a scheduler with concurrent progres guarantees.
359-
constexpr auto __guarantee = get_forward_progress_guarantee(
360-
get_completion_scheduler<set_value_t>(get_env(__rcvr)));
361-
static_assert(__guarantee != forward_progress_guarantee::concurrent);
362-
}
363329
if constexpr (noexcept(__state.__fun_(__shape_t{}, __args...))) {
364330
// The noexcept version that doesn't need try/catch:
365331
for (__shape_t __i{}; __i != __state.__shape_; ++__i) {

test/exec/test_system_context.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ TEST_CASE("simple bulk_unchunked task on system context", "[types][system_schedu
221221
std::thread::id pool_ids[num_tasks];
222222
exec::parallel_scheduler sched = exec::get_parallel_scheduler();
223223

224-
auto bulk_snd = ex::bulk_unchunked(ex::schedule(sched), num_tasks, [&](unsigned long id) {
224+
auto bulk_snd = ex::bulk_unchunked(ex::schedule(sched), ex::par, num_tasks, [&](unsigned long id) {
225225
pool_ids[id] = std::this_thread::get_id();
226226
});
227227

test/stdexec/algos/adaptors/test_bulk.cpp

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ namespace {
7575
}
7676

7777
TEST_CASE("bulk_unchunked returns a sender", "[adaptors][bulk]") {
78-
auto snd = ex::bulk_unchunked(ex::just(19), 8, [](int, int) { });
78+
auto snd = ex::bulk_unchunked(ex::just(19), ex::par, 8, [](int, int) { });
7979
static_assert(ex::sender<decltype(snd)>);
8080
(void) snd;
8181
}
@@ -93,7 +93,7 @@ namespace {
9393
}
9494

9595
TEST_CASE("bulk_unchunked with environment returns a sender", "[adaptors][bulk]") {
96-
auto snd = ex::bulk_unchunked(ex::just(19), 8, [](int, int) { });
96+
auto snd = ex::bulk_unchunked(ex::just(19), ex::par, 8, [](int, int) { });
9797
static_assert(ex::sender_in<decltype(snd), ex::env<>>);
9898
(void) snd;
9999
}
@@ -109,7 +109,7 @@ namespace {
109109
}
110110

111111
TEST_CASE("bulk_unchunked can be piped", "[adaptors][bulk]") {
112-
ex::sender auto snd = ex::just() | ex::bulk_unchunked(42, [](int) { });
112+
ex::sender auto snd = ex::just() | ex::bulk_unchunked(ex::par, 42, [](int) { });
113113
(void) snd;
114114
}
115115

@@ -135,11 +135,11 @@ namespace {
135135

136136
TEST_CASE("bulk_unchunked keeps values_type from input sender", "[adaptors][bulk]") {
137137
constexpr int n = 42;
138-
check_val_types<ex::__mset<pack<>>>(ex::just() | ex::bulk_unchunked(n, [](int) { }));
138+
check_val_types<ex::__mset<pack<>>>(ex::just() | ex::bulk_unchunked(ex::par, n, [](int) { }));
139139
check_val_types<ex::__mset<pack<double>>>(
140-
ex::just(4.2) | ex::bulk_unchunked(n, [](int, double) { }));
140+
ex::just(4.2) | ex::bulk_unchunked(ex::par, n, [](int, double) { }));
141141
check_val_types<ex::__mset<pack<double, std::string>>>(
142-
ex::just(4.2, std::string{}) | ex::bulk_unchunked(n, [](int, double, std::string) { }));
142+
ex::just(4.2, std::string{}) | ex::bulk_unchunked(ex::par, n, [](int, double, std::string) { }));
143143
}
144144

145145
TEST_CASE("bulk keeps error_types from input sender", "[adaptors][bulk]") {
@@ -193,17 +193,17 @@ namespace {
193193
error_scheduler<int> sched3{43};
194194

195195
check_err_types<ex::__mset<>>(
196-
ex::transfer_just(sched1) | ex::bulk_unchunked(n, [](int) noexcept { }));
196+
ex::transfer_just(sched1) | ex::bulk_unchunked(ex::par, n, [](int) noexcept { }));
197197
check_err_types<ex::__mset<std::exception_ptr>>(
198-
ex::transfer_just(sched2) | ex::bulk_unchunked(n, [](int) noexcept { }));
198+
ex::transfer_just(sched2) | ex::bulk_unchunked(ex::par, n, [](int) noexcept { }));
199199
check_err_types<ex::__mset<int>>(
200-
ex::just_error(n) | ex::bulk_unchunked(n, [](int) noexcept { }));
200+
ex::just_error(n) | ex::bulk_unchunked(ex::par, n, [](int) noexcept { }));
201201
check_err_types<ex::__mset<int>>(
202-
ex::transfer_just(sched3) | ex::bulk_unchunked(n, [](int) noexcept { }));
202+
ex::transfer_just(sched3) | ex::bulk_unchunked(ex::par, n, [](int) noexcept { }));
203203
#if !STDEXEC_STD_NO_EXCEPTIONS()
204204
check_err_types<ex::__mset<std::exception_ptr, int>>(
205205
ex::transfer_just(sched3)
206-
| ex::bulk_unchunked(n, [](int) { throw std::logic_error{"err"}; }));
206+
| ex::bulk_unchunked(ex::par, n, [](int) { throw std::logic_error{"err"}; }));
207207
#endif
208208
}
209209

@@ -241,7 +241,7 @@ namespace {
241241
static int counter3[n]{};
242242
std::fill_n(counter3, n, 0);
243243

244-
ex::sender auto snd = ex::just() | ex::bulk_unchunked(n, function<int, n, counter3>);
244+
ex::sender auto snd = ex::just() | ex::bulk_unchunked(ex::par, n, function<int, n, counter3>);
245245
auto op = ex::connect(std::move(snd), expect_void_receiver{});
246246
ex::start(op);
247247

@@ -283,7 +283,7 @@ namespace {
283283
int counter[n]{0};
284284
function_object_t<int> fn{counter};
285285

286-
ex::sender auto snd = ex::just() | ex::bulk_unchunked(n, fn);
286+
ex::sender auto snd = ex::just() | ex::bulk_unchunked(ex::par, n, fn);
287287
auto op = ex::connect(std::move(snd), expect_void_receiver{});
288288
ex::start(op);
289289

@@ -325,7 +325,7 @@ namespace {
325325
constexpr int n = 9;
326326
int counter[n]{0};
327327

328-
ex::sender auto snd = ex::just() | ex::bulk_unchunked(n, [&](int i) { counter[i]++; });
328+
ex::sender auto snd = ex::just() | ex::bulk_unchunked(ex::par, n, [&](int i) { counter[i]++; });
329329
auto op = ex::connect(std::move(snd), expect_void_receiver{});
330330
ex::start(op);
331331

@@ -409,7 +409,7 @@ namespace {
409409
constexpr int magic_number = 42;
410410
int counter[n]{0};
411411

412-
auto snd = ex::just(magic_number) | ex::bulk_unchunked(n, [&](int i, int val) {
412+
auto snd = ex::just(magic_number) | ex::bulk_unchunked(ex::par, n, [&](int i, int val) {
413413
if (val == magic_number) {
414414
counter[i]++;
415415
}
@@ -459,7 +459,7 @@ namespace {
459459
std::iota(vals_expected.begin(), vals_expected.end(), 0);
460460

461461
auto snd = ex::just(std::move(vals))
462-
| ex::bulk_unchunked(n, [&](std::size_t i, std::vector<int>& vals) {
462+
| ex::bulk_unchunked(ex::par, n, [&](std::size_t i, std::vector<int>& vals) {
463463
vals[i] = static_cast<int>(i);
464464
});
465465
auto op = ex::connect(std::move(snd), expect_value_receiver{vals_expected});
@@ -494,7 +494,7 @@ namespace {
494494
constexpr int n = 2;
495495

496496
auto snd = ex::just(magic_number)
497-
| ex::bulk_unchunked(n, [](int, int) { return function_object_t<int>{nullptr}; });
497+
| ex::bulk_unchunked(ex::par, n, [](int, int) { return function_object_t<int>{nullptr}; });
498498

499499
auto op = ex::connect(std::move(snd), expect_value_receiver{magic_number});
500500
ex::start(op);
@@ -522,7 +522,7 @@ namespace {
522522
constexpr int n = 2;
523523

524524
auto snd = ex::just()
525-
| ex::bulk_unchunked(n, [](int) -> int { throw std::logic_error{"err"}; });
525+
| ex::bulk_unchunked(ex::par, n, [](int) -> int { throw std::logic_error{"err"}; });
526526
auto op = ex::connect(std::move(snd), expect_error_receiver{});
527527
ex::start(op);
528528
}
@@ -553,7 +553,7 @@ namespace {
553553
int called{};
554554

555555
auto snd = ex::just_error(std::string{"err"})
556-
| ex::bulk_unchunked(n, [&called](int) { called++; });
556+
| ex::bulk_unchunked(ex::par, n, [&called](int) { called++; });
557557
auto op = ex::connect(std::move(snd), expect_error_receiver{std::string{"err"}});
558558
ex::start(op);
559559
}
@@ -580,7 +580,7 @@ namespace {
580580
constexpr int n = 2;
581581
int called{};
582582

583-
auto snd = ex::just_stopped() | ex::bulk_unchunked(n, [&called](int) { called++; });
583+
auto snd = ex::just_stopped() | ex::bulk_unchunked(ex::par, n, [&called](int) { called++; });
584584
auto op = ex::connect(std::move(snd), expect_stopped_receiver{});
585585
ex::start(op);
586586
}
@@ -819,8 +819,8 @@ namespace {
819819
std::vector<int> counter(n, 42);
820820

821821
auto snd = ex::transfer_just(sch)
822-
| ex::bulk_unchunked(n, [&counter](std::size_t idx) { counter[idx] = 0; })
823-
| ex::bulk_unchunked(n, [&counter](std::size_t idx) { counter[idx]++; });
822+
| ex::bulk_unchunked(ex::par, n, [&counter](std::size_t idx) { counter[idx] = 0; })
823+
| ex::bulk_unchunked(ex::par, n, [&counter](std::size_t idx) { counter[idx]++; });
824824
stdexec::sync_wait(std::move(snd));
825825

826826
const std::size_t actual = static_cast<std::size_t>(
@@ -836,14 +836,14 @@ namespace {
836836
std::vector<int> counter(n, 42);
837837

838838
auto snd = ex::transfer_just(sch, 42)
839-
| ex::bulk_unchunked(
839+
| ex::bulk_unchunked(ex::par,
840840
n,
841841
[&counter](std::size_t idx, int val) {
842842
if (val == 42) {
843843
counter[idx] = 0;
844844
}
845845
})
846-
| ex::bulk_unchunked(n, [&counter](std::size_t idx, int val) {
846+
| ex::bulk_unchunked(ex::par, n, [&counter](std::size_t idx, int val) {
847847
if (val == 42) {
848848
counter[idx]++;
849849
}
@@ -868,9 +868,9 @@ namespace {
868868

869869
auto snd =
870870
ex::transfer_just(sch, std::move(vals))
871-
| ex::bulk_unchunked(
871+
| ex::bulk_unchunked(ex::par,
872872
n, [](std::size_t idx, std::vector<int>& vals) { vals[idx] = static_cast<int>(idx); })
873-
| ex::bulk_unchunked(n, [](std::size_t idx, std::vector<int>& vals) { ++vals[idx]; });
873+
| ex::bulk_unchunked(ex::par, n, [](std::size_t idx, std::vector<int>& vals) { ++vals[idx]; });
874874
auto [vals_actual] = stdexec::sync_wait(std::move(snd)).value();
875875

876876
CHECK(vals_actual == vals_expected);
@@ -881,7 +881,7 @@ namespace {
881881
SECTION("With exception") {
882882
constexpr int n = 9;
883883
auto snd = ex::transfer_just(sch)
884-
| ex::bulk_unchunked(n, [](int) { throw std::runtime_error("bulk_unchunked"); });
884+
| ex::bulk_unchunked(ex::par, n, [](int) { throw std::runtime_error("bulk_unchunked"); });
885885

886886
CHECK_THROWS_AS(stdexec::sync_wait(std::move(snd)), std::runtime_error);
887887
}
@@ -894,9 +894,9 @@ namespace {
894894

895895
stdexec::sender auto snd = stdexec::when_all(
896896
stdexec::schedule(sch)
897-
| stdexec::bulk_unchunked(n, [&](std::size_t id) { counters_1[id]++; }),
897+
| stdexec::bulk_unchunked(ex::par, n, [&](std::size_t id) { counters_1[id]++; }),
898898
stdexec::schedule(sch)
899-
| stdexec::bulk_unchunked(n, [&](std::size_t id) { counters_2[id]++; }));
899+
| stdexec::bulk_unchunked(ex::par, n, [&](std::size_t id) { counters_2[id]++; }));
900900

901901
stdexec::sync_wait(std::move(snd));
902902

@@ -1021,7 +1021,7 @@ namespace {
10211021
"default bulk_unchunked works with non-default constructible types",
10221022
"[adaptors][bulk]") {
10231023
ex::sender auto s = ex::just(non_default_constructible{42})
1024-
| ex::bulk_unchunked(1, [](int, auto&) { });
1024+
| ex::bulk_unchunked(ex::par, 1, [](int, auto&) { });
10251025
ex::sync_wait(std::move(s));
10261026
}
10271027

0 commit comments

Comments
 (0)