Skip to content

Commit ff5c580

Browse files
Implement cr::AllOf
1 parent 4c452be commit ff5c580

2 files changed

Lines changed: 258 additions & 4 deletions

File tree

src/include/crhandle/taskutils.hpp

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,31 @@ struct CurrentHandleRetriever
2828
};
2929

3030
template <typename F, typename T, size_t... Is>
31-
auto CreateArray(std::index_sequence<Is...>, T && tuple, F && transform)
31+
auto TupleToArray(std::index_sequence<Is...>, T && tuple, F && transform)
3232
{
3333
return std::array{transform(std::in_place_index<Is>, std::move(std::get<Is>(tuple)))...};
3434
}
3535

36+
template <typename... Ts>
37+
std::tuple<Ts...> ToNonOptional(std::tuple<std::optional<Ts>...> && t)
38+
{
39+
return std::apply(
40+
[](auto &&... opt) {
41+
return std::tuple{*std::move(opt)...};
42+
},
43+
std::move(t));
44+
}
45+
46+
template <typename... Ts>
47+
bool AllValuesSet(const std::tuple<std::optional<Ts>...> & t) noexcept
48+
{
49+
return std::apply(
50+
[](const auto &... opt) {
51+
return (... && opt.has_value());
52+
},
53+
t);
54+
}
55+
3656
template <typename T>
3757
struct NonVoid
3858
{
@@ -77,9 +97,9 @@ struct AnyOfFn
7797
continuation.resume();
7898
};
7999

80-
auto tasks = internal::CreateArray(std::make_index_sequence<sizeof...(Rs)>{},
81-
std::forward_as_tuple(std::move(ts)...),
82-
TaskWrapper);
100+
auto tasks = internal::TupleToArray(std::make_index_sequence<sizeof...(Rs)>{},
101+
std::forward_as_tuple(std::move(ts)...),
102+
TaskWrapper);
83103

84104
auto thisHandle =
85105
co_await internal::CurrentHandleRetriever<typename HandleType<E, Rs...>::promise_type>{};
@@ -99,6 +119,52 @@ struct AnyOfFn
99119

100120
inline constexpr AnyOfFn AnyOf;
101121

122+
struct AllOfFn
123+
{
124+
template <Executor E, TaskResult... Rs>
125+
using HandleType = TaskHandle<std::tuple<NonVoid<Rs>...>, E>;
126+
127+
template <Executor E, TaskResult... Rs>
128+
HandleType<E, Rs...> operator()(TaskHandle<Rs, E>... ts) const
129+
{
130+
std::tuple<std::optional<NonVoid<Rs>>...> ret;
131+
stdcr::coroutine_handle<> continuation = nullptr;
132+
133+
auto TaskWrapper = [&]<size_t I, typename R>(std::in_place_index_t<I>,
134+
TaskHandle<R, E> task) -> TaskHandle<void, E> {
135+
if constexpr (std::is_same_v<R, void>) {
136+
co_await std::move(task);
137+
std::get<I>(ret).emplace(NonVoid<void>{});
138+
} else {
139+
R tmp = co_await std::move(task);
140+
std::get<I>(ret).emplace(std::move(tmp));
141+
}
142+
if (continuation && internal::AllValuesSet(ret))
143+
continuation.resume();
144+
};
145+
146+
auto tasks = internal::TupleToArray(std::make_index_sequence<sizeof...(Rs)>{},
147+
std::forward_as_tuple(std::move(ts)...),
148+
TaskWrapper);
149+
150+
auto thisHandle =
151+
co_await internal::CurrentHandleRetriever<typename HandleType<E, Rs...>::promise_type>{};
152+
const auto & thisPromise = thisHandle.promise();
153+
154+
for (auto & h : tasks)
155+
h.Run(thisPromise.Executor(), &thisPromise.CancelationFlag());
156+
157+
if (!internal::AllValuesSet(ret)) {
158+
continuation = thisHandle;
159+
co_await stdcr::suspend_always{};
160+
}
161+
162+
co_return internal::ToNonOptional(std::move(ret));
163+
}
164+
};
165+
166+
inline constexpr AllOfFn AllOf;
167+
102168
} // namespace cr
103169

104170
#endif

test/test_taskutils.cpp

Lines changed: 188 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,4 +232,192 @@ TEST_F(TaskUtilsFixture, anyof_cancels_inner_tasks)
232232
EXPECT_FALSE(state2.done);
233233
}
234234

235+
TEST_F(TaskUtilsFixture, allof_delivers_all_results)
236+
{
237+
struct State
238+
{
239+
stdcr::coroutine_handle<> handle = nullptr;
240+
} state1, state2;
241+
242+
static auto IntegerTask = [](State & s) -> cr::TaskHandle<int> {
243+
co_await Awaitable<State>{s};
244+
co_return 42;
245+
};
246+
static auto StringTask = [](State & s) -> cr::TaskHandle<std::string> {
247+
co_await Awaitable<State>{s};
248+
co_return "Hello World";
249+
};
250+
static auto ImmediateTask = []() -> cr::TaskHandle<double> {
251+
co_return 3.14;
252+
};
253+
254+
std::optional<std::tuple<int, std::string, double>> result;
255+
256+
auto OuterTask = [&]() -> cr::DetachedHandle {
257+
result.emplace(co_await cr::AllOf(IntegerTask(state1), StringTask(state2), ImmediateTask()));
258+
};
259+
260+
OuterTask();
261+
EXPECT_TRUE(state1.handle);
262+
EXPECT_TRUE(state2.handle);
263+
EXPECT_FALSE(result);
264+
265+
state2.handle.resume();
266+
EXPECT_FALSE(result);
267+
268+
state1.handle.resume();
269+
EXPECT_TRUE(result);
270+
EXPECT_STREQ("Hello World", std::get<std::string>(*result).c_str());
271+
EXPECT_EQ(42, std::get<int>(*result));
272+
EXPECT_EQ(3.14, std::get<double>(*result));
273+
}
274+
275+
TEST_F(TaskUtilsFixture, allof_handles_void_tasks)
276+
{
277+
struct State
278+
{
279+
stdcr::coroutine_handle<> handle = nullptr;
280+
bool done = false;
281+
} state1, state2;
282+
283+
std::optional<std::tuple<std::monostate, std::monostate>> result;
284+
285+
static auto VoidTask = [](State & s) -> cr::TaskHandle<void> {
286+
co_await Awaitable<State>{s};
287+
s.done = true;
288+
};
289+
290+
auto OuterTask = [&]() -> cr::DetachedHandle {
291+
auto ret = co_await cr::AllOf(VoidTask(state1), VoidTask(state2));
292+
result.emplace(std::move(ret));
293+
};
294+
295+
OuterTask();
296+
EXPECT_TRUE(state1.handle);
297+
EXPECT_TRUE(state2.handle);
298+
EXPECT_FALSE(state1.done);
299+
EXPECT_FALSE(state2.done);
300+
EXPECT_FALSE(result);
301+
302+
state1.handle.resume();
303+
EXPECT_TRUE(state1.done);
304+
EXPECT_FALSE(state2.done);
305+
EXPECT_FALSE(result);
306+
307+
state2.handle.resume();
308+
EXPECT_TRUE(state1.done);
309+
EXPECT_TRUE(state2.done);
310+
EXPECT_TRUE(result);
311+
}
312+
313+
TEST_F(TaskUtilsFixture, allof_handles_immediate_tasks)
314+
{
315+
static auto IntegerTask = []() -> cr::TaskHandle<int> {
316+
co_return 42;
317+
};
318+
static auto StringTask = []() -> cr::TaskHandle<std::string> {
319+
co_return "Hello World";
320+
};
321+
322+
std::optional<std::tuple<std::string, int, std::string>> result;
323+
324+
auto OuterTask = [&]() -> cr::DetachedHandle {
325+
auto ret = co_await cr::AllOf(StringTask(), IntegerTask(), StringTask());
326+
result.emplace(std::move(ret));
327+
};
328+
329+
OuterTask();
330+
EXPECT_TRUE(result);
331+
EXPECT_STREQ("Hello World", std::get<0>(*result).c_str());
332+
EXPECT_EQ(42, std::get<1>(*result));
333+
EXPECT_STREQ("Hello World", std::get<2>(*result).c_str());
334+
}
335+
336+
TEST_F(TaskUtilsFixture, allof_uses_provided_executor_instance)
337+
{
338+
ManualDispatcher dispatcher;
339+
340+
struct State
341+
{
342+
stdcr::coroutine_handle<> handle = nullptr;
343+
bool done = false;
344+
} state1, state2;
345+
346+
std::optional<std::tuple<std::monostate, std::monostate>> result;
347+
348+
static auto VoidTask = [](State & s) -> cr::TaskHandle<void, ManualDispatcher::Executor> {
349+
co_await Awaitable<State>{s};
350+
s.done = true;
351+
};
352+
static auto OuterTask = [](State & state1,
353+
State & state2,
354+
auto & result) -> cr::TaskHandle<void, ManualDispatcher::Executor> {
355+
auto ret = co_await cr::AllOf(VoidTask(state1), VoidTask(state2));
356+
result.emplace(std::move(ret));
357+
};
358+
359+
auto handle = OuterTask(state1, state2, result);
360+
handle.Run(ManualDispatcher::Executor{&dispatcher});
361+
EXPECT_FALSE(state1.handle);
362+
EXPECT_FALSE(state2.handle);
363+
364+
dispatcher.ProcessAll();
365+
EXPECT_TRUE(state1.handle);
366+
EXPECT_TRUE(state2.handle);
367+
EXPECT_FALSE(state1.done);
368+
EXPECT_FALSE(state2.done);
369+
EXPECT_FALSE(result);
370+
371+
state2.handle.resume();
372+
dispatcher.ProcessAll();
373+
EXPECT_FALSE(state1.done);
374+
EXPECT_TRUE(state2.done);
375+
EXPECT_FALSE(result);
376+
377+
state1.handle.resume();
378+
dispatcher.ProcessAll();
379+
EXPECT_TRUE(state1.done);
380+
EXPECT_TRUE(state2.done);
381+
EXPECT_TRUE(result);
382+
}
383+
384+
TEST_F(TaskUtilsFixture, allof_cancels_inner_tasks)
385+
{
386+
ManualDispatcher dispatcher;
387+
388+
struct State
389+
{
390+
stdcr::coroutine_handle<> handle = nullptr;
391+
bool done = false;
392+
} state1, state2;
393+
394+
std::optional<std::tuple<std::monostate, std::monostate>> result;
395+
396+
static auto VoidTask = [](State & s) -> cr::TaskHandle<void, ManualDispatcher::Executor> {
397+
co_await Awaitable<State>{s};
398+
s.done = true;
399+
};
400+
static auto OuterTask = [](State & state1,
401+
State & state2,
402+
auto & result) -> cr::TaskHandle<void, ManualDispatcher::Executor> {
403+
auto ret = co_await cr::AllOf(VoidTask(state1), VoidTask(state2));
404+
result.emplace(std::move(ret));
405+
};
406+
407+
auto handle = OuterTask(state1, state2, result);
408+
handle.Run(ManualDispatcher::Executor{&dispatcher});
409+
dispatcher.ProcessAll();
410+
EXPECT_TRUE(state1.handle);
411+
EXPECT_TRUE(state2.handle);
412+
EXPECT_FALSE(result);
413+
414+
handle = {};
415+
state1.handle.resume();
416+
state2.handle.resume();
417+
dispatcher.ProcessAll();
418+
EXPECT_FALSE(result);
419+
EXPECT_FALSE(state1.done);
420+
EXPECT_FALSE(state2.done);
421+
}
422+
235423
} // namespace

0 commit comments

Comments
 (0)