Skip to content

Commit e751c5e

Browse files
committed
The execution policy passed to bulk* is taken into consideration in static thread pool's customization.
1 parent f102fe9 commit e751c5e

2 files changed

Lines changed: 125 additions & 27 deletions

File tree

include/exec/static_thread_pool.hpp

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -175,14 +175,14 @@ namespace exec {
175175
// TODO: code to reconstitute a static_thread_pool_ schedule sender
176176
};
177177

178-
template <class SenderId, std::integral Shape, class Fun>
178+
template <class SenderId, bool parallelize, std::integral Shape, class Fun>
179179
struct bulk_sender {
180180
using Sender = stdexec::__t<SenderId>;
181181
struct __t;
182182
};
183183

184-
template <sender Sender, std::integral Shape, class Fun>
185-
using bulk_sender_t = __t<bulk_sender<__id<__decay_t<Sender>>, Shape, Fun>>;
184+
template <sender Sender, bool parallelize, std::integral Shape, class Fun>
185+
using bulk_sender_t = __t<bulk_sender<__id<__decay_t<Sender>>, parallelize, Shape, Fun>>;
186186

187187
#if STDEXEC_MSVC()
188188
// MSVCBUG https://developercommunity.visualstudio.com/t/Alias-template-with-pack-expansion-in-no/10437850
@@ -209,37 +209,45 @@ namespace exec {
209209
// there's no need to advertise completion with `exception_ptr`
210210
>;
211211

212-
template <class CvrefSender, class Receiver, class Shape, class Fun, bool MayThrow>
212+
template <class CvrefSender, class Receiver, bool parallelize, class Shape, class Fun, bool MayThrow>
213213
struct bulk_shared_state;
214214

215-
template <class CvrefSenderId, class ReceiverId, class Shape, class Fun, bool MayThrow>
215+
template <
216+
class CvrefSenderId,
217+
class ReceiverId,
218+
bool parallelize,
219+
class Shape,
220+
class Fun,
221+
bool MayThrow>
216222
struct bulk_receiver {
217223
using CvrefSender = __cvref_t<CvrefSenderId>;
218224
using Receiver = stdexec::__t<ReceiverId>;
219225
struct __t;
220226
};
221227

222-
template <class CvrefSender, class Receiver, class Shape, class Fun, bool MayThrow>
223-
using bulk_receiver_t =
224-
__t<bulk_receiver<__cvref_id<CvrefSender>, __id<Receiver>, Shape, Fun, MayThrow>>;
228+
template <class CvrefSender, class Receiver, bool parallelize, class Shape, class Fun, bool MayThrow>
229+
using bulk_receiver_t = __t<
230+
bulk_receiver<__cvref_id<CvrefSender>, __id<Receiver>, parallelize, Shape, Fun, MayThrow>>;
225231

226-
template <class CvrefSenderId, class ReceiverId, std::integral Shape, class Fun>
232+
template <class CvrefSenderId, class ReceiverId, bool parallelize, std::integral Shape, class Fun>
227233
struct bulk_op_state {
228234
using CvrefSender = stdexec::__cvref_t<CvrefSenderId>;
229235
using Receiver = stdexec::__t<ReceiverId>;
230236
struct __t;
231237
};
232238

233-
template <class Sender, class Receiver, std::integral Shape, class Fun>
234-
using bulk_op_state_t =
235-
__t<bulk_op_state<__id<__decay_t<Sender>>, __id<__decay_t<Receiver>>, Shape, Fun>>;
239+
template <class Sender, class Receiver, bool parallelize, std::integral Shape, class Fun>
240+
using bulk_op_state_t = __t<
241+
bulk_op_state<__id<__decay_t<Sender>>, __id<__decay_t<Receiver>>, parallelize, Shape, Fun>>;
236242

237243
struct transform_bulk {
238244
template <class Data, class Sender>
239245
auto operator()(bulk_chunked_t, Data&& data, Sender&& sndr) {
240246
auto [pol, shape, fun] = static_cast<Data&&>(data);
241-
// TODO: handle non-par execution policies
242-
return bulk_sender_t<Sender, decltype(shape), decltype(fun)>{
247+
using policy_t = std::remove_cvref_t<decltype(pol.__get())>;
248+
constexpr bool parallelize = std::same_as<policy_t, parallel_policy>
249+
|| std::same_as<policy_t, parallel_unsequenced_policy>;
250+
return bulk_sender_t<Sender, parallelize, decltype(shape), decltype(fun)>{
243251
pool_, static_cast<Sender&&>(sndr), shape, std::move(fun)};
244252
}
245253

@@ -1076,8 +1084,8 @@ namespace exec {
10761084

10771085
//////////////////////////////////////////////////////////////////////////////////////////////////
10781086
// What follows is the implementation for parallel bulk execution on static_thread_pool_.
1079-
template <class SenderId, std::integral Shape, class Fun>
1080-
struct static_thread_pool_::bulk_sender<SenderId, Shape, Fun>::__t {
1087+
template <class SenderId, bool parallelize, std::integral Shape, class Fun>
1088+
struct static_thread_pool_::bulk_sender<SenderId, parallelize, Shape, Fun>::__t {
10811089
using __id = bulk_sender;
10821090
using sender_concept = sender_t;
10831091

@@ -1108,7 +1116,8 @@ namespace exec {
11081116

11091117
template <class Self, class Receiver>
11101118
using bulk_op_state_t = //
1111-
stdexec::__t<bulk_op_state<__cvref_id<Self, Sender>, stdexec::__id<Receiver>, Shape, Fun>>;
1119+
stdexec::__t<
1120+
bulk_op_state<__cvref_id<Self, Sender>, stdexec::__id<Receiver>, parallelize, Shape, Fun>>;
11121121

11131122
template <__decays_to<__t> Self, receiver Receiver>
11141123
requires receiver_of<Receiver, __completions_t<Self, env_of_t<Receiver>>>
@@ -1139,7 +1148,7 @@ namespace exec {
11391148
};
11401149

11411150
//! The customized operation state for `stdexec::bulk` operations
1142-
template <class CvrefSender, class Receiver, class Shape, class Fun, bool MayThrow>
1151+
template <class CvrefSender, class Receiver, bool parallelize, class Shape, class Fun, bool MayThrow>
11431152
struct static_thread_pool_::bulk_shared_state {
11441153
//! The actual `bulk_task` holds a pointer to the shared state
11451154
//! and its `__execute` function reads from that shared state.
@@ -1223,8 +1232,12 @@ namespace exec {
12231232
//! That is, we don't need an agent for each of the shape values.
12241233
[[nodiscard]]
12251234
auto num_agents_required() const -> std::uint32_t {
1226-
return static_cast<std::uint32_t>(
1227-
std::min(shape_, static_cast<Shape>(pool_.available_parallelism())));
1235+
if constexpr (parallelize) {
1236+
return static_cast<std::uint32_t>(
1237+
std::min(shape_, static_cast<Shape>(pool_.available_parallelism())));
1238+
} else {
1239+
return static_cast<std::uint32_t>(1);
1240+
}
12281241
}
12291242

12301243
template <class F>
@@ -1253,12 +1266,20 @@ namespace exec {
12531266
};
12541267

12551268
//! A customized receiver to allow parallel execution of `stdexec::bulk` operations:
1256-
template <class CvrefSenderId, class ReceiverId, class Shape, class Fun, bool MayThrow>
1257-
struct static_thread_pool_::bulk_receiver<CvrefSenderId, ReceiverId, Shape, Fun, MayThrow>::__t {
1269+
template <
1270+
class CvrefSenderId,
1271+
class ReceiverId,
1272+
bool parallelize,
1273+
class Shape,
1274+
class Fun,
1275+
bool MayThrow>
1276+
struct static_thread_pool_::
1277+
bulk_receiver<CvrefSenderId, ReceiverId, parallelize, Shape, Fun, MayThrow>::__t {
12581278
using __id = bulk_receiver;
12591279
using receiver_concept = receiver_t;
12601280

1261-
using shared_state = bulk_shared_state<CvrefSender, Receiver, Shape, Fun, MayThrow>;
1281+
using shared_state =
1282+
bulk_shared_state<CvrefSender, Receiver, parallelize, Shape, Fun, MayThrow>;
12621283

12631284
shared_state& shared_state_;
12641285

@@ -1308,8 +1329,9 @@ namespace exec {
13081329
}
13091330
};
13101331

1311-
template <class CvrefSenderId, class ReceiverId, std::integral Shape, class Fun>
1312-
struct static_thread_pool_::bulk_op_state<CvrefSenderId, ReceiverId, Shape, Fun>::__t {
1332+
template <class CvrefSenderId, class ReceiverId, bool parallelize, std::integral Shape, class Fun>
1333+
struct static_thread_pool_::bulk_op_state<CvrefSenderId, ReceiverId, parallelize, Shape, Fun>::
1334+
__t {
13131335
using __id = bulk_op_state;
13141336

13151337
static constexpr bool may_throw = //
@@ -1319,8 +1341,9 @@ namespace exec {
13191341
__mbind_front_q<bulk_non_throwing, Fun, Shape>,
13201342
__q<__mand>>>;
13211343

1322-
using bulk_rcvr = bulk_receiver_t<CvrefSender, Receiver, Shape, Fun, may_throw>;
1323-
using shared_state = bulk_shared_state<CvrefSender, Receiver, Shape, Fun, may_throw>;
1344+
using bulk_rcvr = bulk_receiver_t<CvrefSender, Receiver, parallelize, Shape, Fun, may_throw>;
1345+
using shared_state =
1346+
bulk_shared_state<CvrefSender, Receiver, parallelize, Shape, Fun, may_throw>;
13241347
using inner_op_state = connect_result_t<CvrefSender, bulk_rcvr>;
13251348

13261349
shared_state shared_state_;

test/stdexec/algos/adaptors/test_bulk.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1023,6 +1023,81 @@ namespace {
10231023
ex::sync_wait(std::move(s));
10241024
}
10251025

1026+
template <class Sched, class Policy>
1027+
int number_of_threads_in_bulk(Sched sch, const Policy& policy, int n) {
1028+
std::vector<std::thread::id> tids(n);
1029+
auto fun = [&tids](std::size_t idx) {
1030+
tids[idx] = std::this_thread::get_id();
1031+
std::this_thread::sleep_for(std::chrono::milliseconds{10});
1032+
};
1033+
1034+
auto snd = ex::just() //
1035+
| ex::continues_on(sch) //
1036+
| ex::bulk(policy, tids.size(), fun);
1037+
stdexec::sync_wait(std::move(snd));
1038+
1039+
std::sort(tids.begin(), tids.end());
1040+
return static_cast<int>(std::unique(tids.begin(), tids.end()) - tids.begin());
1041+
}
1042+
1043+
TEST_CASE(
1044+
"static thread pool execute bulk work in accordance with the execution policy",
1045+
"[adaptors][bulk]") {
1046+
exec::static_thread_pool pool{4};
1047+
ex::scheduler auto sch = pool.get_scheduler();
1048+
1049+
SECTION("seq execution policy") {
1050+
REQUIRE(number_of_threads_in_bulk(sch, ex::seq, 42) == 1);
1051+
}
1052+
SECTION("unseq execution policy") {
1053+
REQUIRE(number_of_threads_in_bulk(sch, ex::unseq, 42) == 1);
1054+
}
1055+
SECTION("par execution policy") {
1056+
REQUIRE(number_of_threads_in_bulk(sch, ex::par, 42) > 1);
1057+
}
1058+
SECTION("par_unseq execution policy") {
1059+
REQUIRE(number_of_threads_in_bulk(sch, ex::par_unseq, 42) > 1);
1060+
}
1061+
}
1062+
1063+
template <class Sched, class Policy>
1064+
int number_of_threads_in_bulk_chunked(Sched sch, const Policy& policy, int n) {
1065+
std::vector<std::thread::id> tids(n);
1066+
auto fun = [&tids](std::size_t b, std::size_t e) {
1067+
while (b < e)
1068+
tids[b++] = std::this_thread::get_id();
1069+
std::this_thread::sleep_for(std::chrono::milliseconds{10});
1070+
};
1071+
1072+
auto snd = ex::just() //
1073+
| ex::continues_on(sch) //
1074+
| ex::bulk_chunked(policy, tids.size(), fun);
1075+
stdexec::sync_wait(std::move(snd));
1076+
1077+
std::sort(tids.begin(), tids.end());
1078+
return static_cast<int>(std::unique(tids.begin(), tids.end()) - tids.begin());
1079+
}
1080+
1081+
TEST_CASE(
1082+
"static thread pool execute bulk_chunked work in accordance with the execution policy",
1083+
"[adaptors][bulk]") {
1084+
exec::static_thread_pool pool{4};
1085+
ex::scheduler auto sch = pool.get_scheduler();
1086+
1087+
SECTION("seq execution policy") {
1088+
REQUIRE(number_of_threads_in_bulk_chunked(sch, ex::seq, 42) == 1);
1089+
}
1090+
SECTION("unseq execution policy") {
1091+
REQUIRE(number_of_threads_in_bulk_chunked(sch, ex::unseq, 42) == 1);
1092+
}
1093+
SECTION("par execution policy") {
1094+
REQUIRE(number_of_threads_in_bulk_chunked(sch, ex::par, 42) > 1);
1095+
}
1096+
SECTION("par_unseq execution policy") {
1097+
REQUIRE(number_of_threads_in_bulk_chunked(sch, ex::par_unseq, 42) > 1);
1098+
}
1099+
}
1100+
10261101
struct my_domain {
10271102
template <ex::sender_expr_for<ex::bulk_chunked_t> Sender, class... Env>
10281103
static auto transform_sender(Sender, const Env&...) {

0 commit comments

Comments
 (0)