Skip to content

Commit 5e49b85

Browse files
Preemptive cancellation of TaskContext jobs
- Enable pre-execution cancellation in TaskContext by registering a cancellation token in constructor that invokes PreExecuteCancel() - Clear cancellation token in ~TaskContextImpl to prevent dangling refs - Fix CancellationContext::CancelOperation() to swap out the token before calling Cancel(), avoiding re-entrancy issues - Update BlockingCancel test to cover pre-execution cancellation Relates-To: HERESDK-12253 Signed-off-by: Mykhailo Diachenko <ext-mykhailo.z.diachenko@here.com>
1 parent b400ee0 commit 5e49b85

File tree

6 files changed

+132
-39
lines changed

6 files changed

+132
-39
lines changed

olp-cpp-sdk-core/include/olp/core/client/CancellationContext.inl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#pragma once
2121

22+
#include <utility>
2223
namespace olp {
2324
namespace client {
2425

@@ -58,9 +59,10 @@ inline void CancellationContext::CancelOperation() {
5859
return;
5960
}
6061

61-
impl_->sub_operation_cancel_token_.Cancel();
62-
impl_->sub_operation_cancel_token_ = CancellationToken();
62+
auto token = CancellationToken();
63+
std::swap(token, impl_->sub_operation_cancel_token_);
6364
impl_->is_cancelled_ = true;
65+
token.Cancel();
6466
}
6567

6668
inline bool CancellationContext::IsCancelled() const {

olp-cpp-sdk-core/include/olp/core/client/TaskContext.h

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,19 @@ class CORE_API TaskContext {
127127
*/
128128
void SetExecutors(Exec execute_func, Callback callback,
129129
client::CancellationContext context) {
130-
impl_ = std::make_shared<TaskContextImpl<ExecResult>>(
131-
std::move(execute_func), std::move(callback), std::move(context));
130+
auto impl = std::make_shared<TaskContextImpl<ExecResult>>(
131+
std::move(execute_func), std::move(callback), context);
132+
std::weak_ptr<TaskContextImpl<ExecResult>> weak_impl = impl;
133+
context.ExecuteOrCancelled(
134+
[weak_impl]() -> CancellationToken {
135+
return CancellationToken([weak_impl]() {
136+
if (auto impl = weak_impl.lock()) {
137+
impl->PreExecuteCancel();
138+
}
139+
});
140+
},
141+
[]() {});
142+
impl_ = std::move(impl);
132143
}
133144

134145
/**
@@ -195,8 +206,6 @@ class CORE_API TaskContext {
195206
context_(std::move(context)),
196207
state_{State::PENDING} {}
197208

198-
~TaskContextImpl() override{};
199-
200209
/**
201210
* @brief Checks for the cancellation, executes the task, and calls
202211
* the callback with the result or error.
@@ -249,6 +258,40 @@ class CORE_API TaskContext {
249258
state_.store(State::COMPLETED);
250259
}
251260

261+
void PreExecuteCancel() {
262+
State expected_state = State::PENDING;
263+
264+
if (!state_.compare_exchange_strong(expected_state, State::IN_PROGRESS)) {
265+
return;
266+
}
267+
268+
// Moving the user callback and function guarantee that they are
269+
// executed exactly once
270+
ExecuteFunc function = nullptr;
271+
UserCallback callback = nullptr;
272+
273+
{
274+
std::lock_guard<std::mutex> lock(mutex_);
275+
function = std::move(execute_func_);
276+
callback = std::move(callback_);
277+
}
278+
279+
Response user_response =
280+
client::ApiError(client::ErrorCode::Cancelled, "Cancelled");
281+
282+
if (callback) {
283+
callback(std::move(user_response));
284+
}
285+
286+
// Resources need to be released before the notification, else lambas
287+
// would have captured resources like network or `TaskScheduler`.
288+
function = nullptr;
289+
callback = nullptr;
290+
291+
condition_.Notify();
292+
state_.store(State::COMPLETED);
293+
}
294+
252295
/**
253296
* @brief Cancels the operation and waits for the notification.
254297
*

olp-cpp-sdk-core/tests/client/TaskContextTest.cpp

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,18 @@ class TaskContextTestable : public TaskContext {
6060
void SetExecutors(Exec execute_func, Callback callback,
6161
CancellationContext context) {
6262
auto impl = std::make_shared<TaskContextImpl<ExecResult>>(
63-
std::move(execute_func), std::move(callback), std::move(context));
64-
notify = [=]() { impl->condition_.Notify(); };
63+
std::move(execute_func), std::move(callback), context);
64+
std::weak_ptr<TaskContextImpl<ExecResult>> weak_impl = impl;
65+
context.ExecuteOrCancelled(
66+
[weak_impl]() -> CancellationToken {
67+
return CancellationToken([weak_impl]() {
68+
if (auto impl = weak_impl.lock()) {
69+
impl->PreExecuteCancel();
70+
}
71+
});
72+
},
73+
[]() {});
74+
notify = [impl]() { impl->condition_.Notify(); };
6575
impl_ = impl;
6676
}
6777
};
@@ -127,28 +137,58 @@ TEST(TaskContextTest, ExecuteSimple) {
127137
}
128138

129139
TEST(TaskContextTest, BlockingCancel) {
130-
ExecuteFunc func = [&](CancellationContext c) -> Response {
131-
EXPECT_TRUE(c.IsCancelled());
132-
return std::string("Success");
133-
};
134-
135140
Response response;
136141

137142
Callback callback = [&](Response r) { response = std::move(r); };
138143

139-
TaskContext context = TaskContext::Create(func, callback);
144+
{
145+
SCOPED_TRACE("Pre-exec cancellation");
146+
bool executed = false;
147+
ExecuteFunc func = [&](CancellationContext) -> Response {
148+
executed = true;
149+
return std::string("Success");
150+
};
151+
152+
TaskContext context = TaskContext::Create(func, callback);
153+
EXPECT_TRUE(context.BlockingCancel(std::chrono::seconds(0)));
154+
EXPECT_FALSE(executed);
155+
EXPECT_FALSE(response.IsSuccessful());
156+
EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled);
157+
}
140158

141-
EXPECT_FALSE(context.BlockingCancel(std::chrono::seconds(0)));
159+
{
160+
SCOPED_TRACE("Cancel during execution");
161+
Condition continue_execution;
162+
Condition execution_started;
163+
int execution_count = 0;
164+
response = Response{};
165+
ExecuteFunc func = [&](CancellationContext c) -> Response {
166+
++execution_count;
167+
execution_started.Notify();
168+
EXPECT_TRUE(continue_execution.Wait(kWaitTime));
169+
const auto deadline = std::chrono::steady_clock::now() + kWaitTime;
170+
while (!c.IsCancelled() && std::chrono::steady_clock::now() < deadline) {
171+
std::this_thread::yield();
172+
}
173+
EXPECT_TRUE(c.IsCancelled());
174+
return std::string("Success");
175+
};
176+
TaskContext context = TaskContext::Create(func, callback);
142177

143-
std::thread cancel_thread([&]() { EXPECT_TRUE(context.BlockingCancel()); });
178+
std::thread execute_thread([&]() { context.Execute(); });
179+
EXPECT_TRUE(execution_started.Wait());
144180

145-
std::thread execute_thread([&]() { context.Execute(); });
181+
std::thread cancel_thread([&]() { EXPECT_TRUE(context.BlockingCancel()); });
146182

147-
execute_thread.join();
148-
cancel_thread.join();
183+
continue_execution.Notify();
149184

150-
EXPECT_FALSE(response.IsSuccessful());
151-
EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled);
185+
execute_thread.join();
186+
cancel_thread.join();
187+
188+
EXPECT_EQ(execution_count, 1);
189+
EXPECT_FALSE(response.IsSuccessful());
190+
EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled);
191+
}
152192
}
153193

154194
TEST(TaskContextTest, BlockingCancelIsWaiting) {

olp-cpp-sdk-dataservice-read/tests/VersionedLayerClientImplTest.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
* License-Filename: LICENSE
1818
*/
1919

20+
#include <future>
2021
#include <thread>
2122

2223
#include <gtest/gtest.h>
@@ -1138,15 +1139,20 @@ TEST(VersionedLayerClientTest, PrefetchPartitionsCancel) {
11381139
settings);
11391140
{
11401141
SCOPED_TRACE("Cancel request");
1141-
std::promise<void> block_promise;
1142-
auto block_future = block_promise.get_future();
1143-
settings.task_scheduler->ScheduleTask(
1144-
[&block_future]() { block_future.get(); });
1142+
std::promise<void> block_task;
1143+
std::promise<void> block_main;
1144+
auto block_task_future = block_task.get_future();
1145+
auto block_main_future = block_main.get_future();
1146+
settings.task_scheduler->ScheduleTask([&block_task_future, &block_main]() {
1147+
block_main.set_value();
1148+
block_task_future.get();
1149+
});
11451150
auto cancellable = client.PrefetchPartitions(request, nullptr);
11461151

11471152
// cancel the request and unblock queue
11481153
cancellable.GetCancellationToken().Cancel();
1149-
block_promise.set_value();
1154+
block_main_future.wait();
1155+
block_task.set_value();
11501156
auto future = cancellable.GetFuture();
11511157

11521158
ASSERT_EQ(future.wait_for(kTimeout), std::future_status::ready);

olp-cpp-sdk-dataservice-read/tests/VolatileLayerClientImplTest.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -357,12 +357,10 @@ TEST(VolatileLayerClientImplTest, GetDataCancelOnClientDestroy) {
357357
read::DataResponse data_response;
358358
{
359359
// Client owns the task scheduler
360-
auto caller_thread_id = std::this_thread::get_id();
361360
read::VolatileLayerClientImpl client(kHrn, kLayerId, std::move(settings));
362361
client.GetData(read::DataRequest().WithPartitionId(kPartitionId),
363362
[&](read::DataResponse response) {
364363
data_response = std::move(response);
365-
EXPECT_NE(caller_thread_id, std::this_thread::get_id());
366364
});
367365
}
368366

@@ -1013,8 +1011,7 @@ TEST(VolatileLayerClientImplTest, PrefetchTilesCancelOnClientDestroy) {
10131011

10141012
read::PrefetchTilesResponse response;
10151013
{
1016-
// Client owns the task scheduler
1017-
auto caller_thread_id = std::this_thread::get_id();
1014+
// Client owns the task schedule
10181015
read::VolatileLayerClientImpl client(kHrn, kLayerId, std::move(settings));
10191016
std::vector<olp::geo::TileKey> tile_keys = {
10201017
olp::geo::TileKey::FromHereTile(kTileId)};
@@ -1023,11 +1020,10 @@ TEST(VolatileLayerClientImplTest, PrefetchTilesCancelOnClientDestroy) {
10231020
.WithMinLevel(11)
10241021
.WithMaxLevel(12);
10251022

1026-
client.PrefetchTiles(
1027-
request, [&](read::PrefetchTilesResponse prefetch_response) {
1028-
response = std::move(prefetch_response);
1029-
EXPECT_NE(caller_thread_id, std::this_thread::get_id());
1030-
});
1023+
client.PrefetchTiles(request,
1024+
[&](read::PrefetchTilesResponse prefetch_response) {
1025+
response = std::move(prefetch_response);
1026+
});
10311027
}
10321028

10331029
// Callback must be called during client destructor.

tests/integration/olp-cpp-sdk-dataservice-read/VersionedLayerClientPrefetchPartitionsTest.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,15 +301,21 @@ TEST_F(VersionedLayerClientPrefetchPartitionsTest, PrefetchPartitionsCancel) {
301301
const auto request =
302302
read::PrefetchPartitionsRequest().WithPartitionIds(partitions);
303303

304-
std::promise<void> block_promise;
305-
auto block_future = block_promise.get_future();
304+
std::promise<void> block_task_promise;
305+
std::promise<void> block_main_promise;
306+
auto block_future = block_task_promise.get_future();
307+
auto block_main_future = block_main_promise.get_future();
306308
settings_.task_scheduler->ScheduleTask(
307-
[&block_future]() { block_future.get(); });
309+
[&block_future, &block_main_promise]() {
310+
block_main_promise.set_value();
311+
block_future.get();
312+
});
308313
auto cancellable = client.PrefetchPartitions(request);
309314

310315
// cancel the request and unblock queue
311316
cancellable.GetCancellationToken().Cancel();
312-
block_promise.set_value();
317+
block_main_future.get();
318+
block_task_promise.set_value();
313319
auto future = cancellable.GetFuture();
314320

315321
ASSERT_EQ(future.wait_for(kTimeout), std::future_status::ready);

0 commit comments

Comments
 (0)