Skip to content

Commit 069a7a7

Browse files
AnyOf: pass correct executor instance and cancelation flag
1 parent 45cfa3d commit 069a7a7

4 files changed

Lines changed: 107 additions & 12 deletions

File tree

src/include/crhandle/taskhandle.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ struct Promise
148148
const bool * parentCanceled = nullptr;
149149
stdcr::coroutine_handle<> parentHandle = nullptr;
150150

151+
const bool & CancelationFlag() const noexcept
152+
{
153+
return parentCanceled ? *parentCanceled : canceled;
154+
}
151155
const E & Executor() const noexcept { return static_cast<const E &>(*this); }
152156
E & Executor() noexcept { return static_cast<E &>(*this); }
153157

@@ -168,9 +172,7 @@ struct Promise
168172
auto await_transform(TaskHandle<R, E> && innerTask) const
169173
{
170174
using InnerAwaiter = std::remove_reference_t<decltype(innerTask.Run())>;
171-
return CancelingAwaiter<InnerAwaiter>{
172-
innerTask.Run(Executor(), parentCanceled ? parentCanceled : &canceled),
173-
*this};
175+
return CancelingAwaiter<InnerAwaiter>{innerTask.Run(Executor(), &CancelationFlag()), *this};
174176
}
175177

176178
auto initial_suspend() noexcept

src/include/crhandle/taskutils.hpp

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,19 @@ namespace cr {
1111

1212
namespace internal {
1313

14+
template <typename P = void>
1415
struct CurrentHandleRetriever
1516
{
16-
stdcr::coroutine_handle<> handle;
17+
stdcr::coroutine_handle<P> handle;
1718

1819
bool await_ready() const noexcept { return false; }
19-
bool await_suspend(stdcr::coroutine_handle<> h) noexcept
20+
template <typename T = P>
21+
bool await_suspend(stdcr::coroutine_handle<T> h) noexcept
2022
{
2123
handle = h;
2224
return false;
2325
}
24-
stdcr::coroutine_handle<> await_resume() const noexcept { return handle; }
26+
stdcr::coroutine_handle<P> await_resume() const noexcept { return handle; }
2527
};
2628

2729
template <typename F, typename T, size_t... Is>
@@ -53,15 +55,16 @@ struct NonVoid<C<Ts...>>
5355

5456
} // namespace internal
5557

56-
5758
template <typename T>
5859
using NonVoid = typename internal::NonVoid<T>::Type;
5960

6061
template <Executor E, TaskResult... Rs>
6162
TaskHandle<NonVoid<std::variant<Rs...>>, E> AnyOf(TaskHandle<Rs, E>... ts)
6263
{
64+
using ThisTaskHandle = TaskHandle<NonVoid<std::variant<Rs...>>, E>;
65+
6366
std::optional<NonVoid<std::variant<Rs...>>> ret;
64-
stdcr::coroutine_handle<> thisHandle = nullptr;
67+
stdcr::coroutine_handle<> continuation = nullptr;
6568

6669
auto TaskWrapper = [&]<size_t I, typename R>(std::in_place_index_t<I> i,
6770
TaskHandle<R, E> task) -> TaskHandle<void, E> {
@@ -75,19 +78,23 @@ TaskHandle<NonVoid<std::variant<Rs...>>, E> AnyOf(TaskHandle<Rs, E>... ts)
7578
R tmp = co_await std::move(task);
7679
ret.emplace(i, std::move(tmp));
7780
}
78-
if (thisHandle)
79-
thisHandle.resume();
81+
if (continuation)
82+
continuation.resume();
8083
};
8184

8285
auto handles = internal::CreateArray(std::make_index_sequence<sizeof...(Rs)>{},
8386
std::forward_as_tuple(std::move(ts)...),
8487
TaskWrapper);
8588

89+
auto thisHandle =
90+
co_await internal::CurrentHandleRetriever<typename ThisTaskHandle::promise_type>{};
91+
const auto & thisPromise = thisHandle.promise();
92+
8693
for (auto & h : handles)
87-
h.Run();
94+
h.Run(thisPromise.Executor(), &thisPromise.CancelationFlag());
8895

8996
if (!ret.has_value()) {
90-
thisHandle = co_await internal::CurrentHandleRetriever{};
97+
continuation = thisHandle;
9198
co_await stdcr::suspend_always{};
9299
}
93100

test/dispatcher.hpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,12 @@ struct ManualDispatcher
5555
return true;
5656
}
5757

58+
void ProcessAll()
59+
{
60+
while (ProcessOneTask())
61+
;
62+
}
63+
5864
std::deque<Task> queue;
5965
};
6066

test/test_taskutils.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include "crhandle/detachedhandle.hpp"
55
#include "crhandle/taskhandle.hpp"
66
#include "crhandle/taskutils.hpp"
7+
#include "dispatcher.hpp"
78

89
#include <optional>
910

@@ -152,4 +153,83 @@ TEST_F(TaskUtilsFixture, anyof_handles_immediate_task_and_short_circuits)
152153
EXPECT_FALSE(stringResult2.has_value());
153154
}
154155

156+
TEST_F(TaskUtilsFixture, anyof_uses_provided_executor_instance)
157+
{
158+
ManualDispatcher dispatcher;
159+
160+
struct State
161+
{
162+
stdcr::coroutine_handle<> handle = nullptr;
163+
} state1, state2;
164+
size_t index = std::numeric_limits<size_t>::max();
165+
166+
static auto VoidTask = [](State & s) -> cr::TaskHandle<void, ManualDispatcher::Executor> {
167+
co_await Awaitable<State>{s};
168+
};
169+
static auto OuterTask = [](State & state1,
170+
State & state2,
171+
size_t & index) -> cr::TaskHandle<void, ManualDispatcher::Executor> {
172+
auto result = co_await cr::AnyOf(VoidTask(state1), VoidTask(state2));
173+
index = result.index();
174+
};
175+
176+
auto handle = OuterTask(state1, state2, index);
177+
handle.Run(ManualDispatcher::Executor{&dispatcher});
178+
EXPECT_FALSE(state1.handle);
179+
EXPECT_FALSE(state2.handle);
180+
181+
dispatcher.ProcessAll();
182+
EXPECT_TRUE(state1.handle);
183+
EXPECT_TRUE(state2.handle);
184+
EXPECT_EQ(std::numeric_limits<size_t>::max(), index);
185+
186+
state2.handle.resume();
187+
EXPECT_EQ(std::numeric_limits<size_t>::max(), index);
188+
189+
dispatcher.ProcessAll();
190+
EXPECT_EQ(1u, index);
191+
192+
state1.handle.resume();
193+
dispatcher.ProcessAll();
194+
EXPECT_EQ(1u, index);
195+
}
196+
197+
TEST_F(TaskUtilsFixture, anyof_cancels_inner_tasks)
198+
{
199+
ManualDispatcher dispatcher;
200+
201+
struct State
202+
{
203+
stdcr::coroutine_handle<> handle = nullptr;
204+
bool done = false;
205+
} state1, state2;
206+
size_t index = std::numeric_limits<size_t>::max();
207+
208+
static auto VoidTask = [](State & s) -> cr::TaskHandle<void, ManualDispatcher::Executor> {
209+
co_await Awaitable<State>{s};
210+
s.done = true;
211+
};
212+
static auto OuterTask = [](State & state1,
213+
State & state2,
214+
size_t & index) -> cr::TaskHandle<void, ManualDispatcher::Executor> {
215+
auto result = co_await cr::AnyOf(VoidTask(state1), VoidTask(state2));
216+
index = result.index();
217+
};
218+
219+
auto handle = OuterTask(state1, state2, index);
220+
handle.Run(ManualDispatcher::Executor{&dispatcher});
221+
dispatcher.ProcessAll();
222+
EXPECT_TRUE(state1.handle);
223+
EXPECT_TRUE(state2.handle);
224+
EXPECT_EQ(std::numeric_limits<size_t>::max(), index);
225+
226+
handle = {};
227+
state1.handle.resume();
228+
state2.handle.resume();
229+
dispatcher.ProcessAll();
230+
EXPECT_EQ(std::numeric_limits<size_t>::max(), index);
231+
EXPECT_FALSE(state1.done);
232+
EXPECT_FALSE(state2.done);
233+
}
234+
155235
} // namespace

0 commit comments

Comments
 (0)