Skip to content

Commit d1cb9e7

Browse files
Preemptive cancellation of TaskContext jobs
- Enable pre-execution cancellation in TaskContext by registering a cancellation token in constructor that invokes PreExecuteCancel() - Clear cancellation token in ~TaskContextImpl to prevent dangling refs - Fix CancellationContext::CancelOperation() to swap out the token before calling Cancel(), avoiding re-entrancy issues - Update BlockingCancel test to cover pre-execution cancellation Relates-To: HERESDK-12253 Signed-off-by: Mykhailo Diachenko <ext-mykhailo.z.diachenko@here.com>
1 parent b400ee0 commit d1cb9e7

File tree

6 files changed

+92
-30
lines changed

6 files changed

+92
-30
lines changed

olp-cpp-sdk-core/include/olp/core/client/CancellationContext.inl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#pragma once
2121

22+
#include <utility>
2223
namespace olp {
2324
namespace client {
2425

@@ -58,9 +59,10 @@ inline void CancellationContext::CancelOperation() {
5859
return;
5960
}
6061

61-
impl_->sub_operation_cancel_token_.Cancel();
62-
impl_->sub_operation_cancel_token_ = CancellationToken();
62+
auto token = CancellationToken();
63+
std::swap(token, impl_->sub_operation_cancel_token_);
6364
impl_->is_cancelled_ = true;
65+
token.Cancel();
6466
}
6567

6668
inline bool CancellationContext::IsCancelled() const {

olp-cpp-sdk-core/include/olp/core/client/TaskContext.h

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,19 @@ class CORE_API TaskContext {
193193
: execute_func_(std::move(execute_func)),
194194
callback_(std::move(callback)),
195195
context_(std::move(context)),
196-
state_{State::PENDING} {}
196+
state_{State::PENDING} {
197+
context_.ExecuteOrCancelled(
198+
[&]() -> CancellationToken {
199+
return CancellationToken([&] { PreExecuteCancel(); });
200+
},
201+
[]() {});
202+
}
197203

198-
~TaskContextImpl() override{};
204+
~TaskContextImpl() override {
205+
// Overwrite the token to prevent dangling reference to this.
206+
context_.ExecuteOrCancelled(
207+
[]() -> CancellationToken { return CancellationToken(); }, []() {});
208+
};
199209

200210
/**
201211
* @brief Checks for the cancellation, executes the task, and calls
@@ -249,6 +259,40 @@ class CORE_API TaskContext {
249259
state_.store(State::COMPLETED);
250260
}
251261

262+
void PreExecuteCancel() {
263+
State expected_state = State::PENDING;
264+
265+
if (!state_.compare_exchange_strong(expected_state, State::IN_PROGRESS)) {
266+
return;
267+
}
268+
269+
// Moving the user callback and function guarantee that they are
270+
// executed exactly once
271+
ExecuteFunc function = nullptr;
272+
UserCallback callback = nullptr;
273+
274+
{
275+
std::lock_guard<std::mutex> lock(mutex_);
276+
function = std::move(execute_func_);
277+
callback = std::move(callback_);
278+
}
279+
280+
Response user_response =
281+
client::ApiError(client::ErrorCode::Cancelled, "Cancelled");
282+
283+
if (callback) {
284+
callback(std::move(user_response));
285+
}
286+
287+
// Resources need to be released before the notification, else lambas
288+
// would have captured resources like network or `TaskScheduler`.
289+
function = nullptr;
290+
callback = nullptr;
291+
292+
condition_.Notify();
293+
state_.store(State::COMPLETED);
294+
}
295+
252296
/**
253297
* @brief Cancels the operation and waits for the notification.
254298
*

olp-cpp-sdk-core/tests/client/TaskContextTest.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,19 +136,29 @@ TEST(TaskContextTest, BlockingCancel) {
136136

137137
Callback callback = [&](Response r) { response = std::move(r); };
138138

139-
TaskContext context = TaskContext::Create(func, callback);
139+
{
140+
SCOPED_TRACE("Pre-exec cancellation");
141+
TaskContext context = TaskContext::Create(func, callback);
142+
EXPECT_TRUE(context.BlockingCancel(std::chrono::seconds(0)));
143+
EXPECT_FALSE(response.IsSuccessful());
144+
EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled);
145+
}
140146

141-
EXPECT_FALSE(context.BlockingCancel(std::chrono::seconds(0)));
147+
{
148+
SCOPED_TRACE("Pre-exec cancellation race");
149+
response = Response{};
150+
TaskContext context = TaskContext::Create(func, callback);
142151

143-
std::thread cancel_thread([&]() { EXPECT_TRUE(context.BlockingCancel()); });
152+
std::thread cancel_thread([&]() { EXPECT_TRUE(context.BlockingCancel()); });
144153

145-
std::thread execute_thread([&]() { context.Execute(); });
154+
std::thread execute_thread([&]() { context.Execute(); });
146155

147-
execute_thread.join();
148-
cancel_thread.join();
156+
execute_thread.join();
157+
cancel_thread.join();
149158

150-
EXPECT_FALSE(response.IsSuccessful());
151-
EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled);
159+
EXPECT_FALSE(response.IsSuccessful());
160+
EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled);
161+
}
152162
}
153163

154164
TEST(TaskContextTest, BlockingCancelIsWaiting) {

olp-cpp-sdk-dataservice-read/tests/VersionedLayerClientImplTest.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
* License-Filename: LICENSE
1818
*/
1919

20+
#include <future>
2021
#include <thread>
2122

2223
#include <gtest/gtest.h>
@@ -1138,15 +1139,19 @@ TEST(VersionedLayerClientTest, PrefetchPartitionsCancel) {
11381139
settings);
11391140
{
11401141
SCOPED_TRACE("Cancel request");
1141-
std::promise<void> block_promise;
1142-
auto block_future = block_promise.get_future();
1143-
settings.task_scheduler->ScheduleTask(
1144-
[&block_future]() { block_future.get(); });
1142+
std::promise<void> block_task;
1143+
std::promise<void> block_main;
1144+
auto block_task_future = block_task.get_future();
1145+
settings.task_scheduler->ScheduleTask([&block_task_future, &block_main]() {
1146+
block_main.set_value();
1147+
block_task_future.get();
1148+
});
11451149
auto cancellable = client.PrefetchPartitions(request, nullptr);
11461150

11471151
// cancel the request and unblock queue
11481152
cancellable.GetCancellationToken().Cancel();
1149-
block_promise.set_value();
1153+
block_main.get_future().wait();
1154+
block_task.set_value();
11501155
auto future = cancellable.GetFuture();
11511156

11521157
ASSERT_EQ(future.wait_for(kTimeout), std::future_status::ready);

olp-cpp-sdk-dataservice-read/tests/VolatileLayerClientImplTest.cpp

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -357,12 +357,10 @@ TEST(VolatileLayerClientImplTest, GetDataCancelOnClientDestroy) {
357357
read::DataResponse data_response;
358358
{
359359
// Client owns the task scheduler
360-
auto caller_thread_id = std::this_thread::get_id();
361360
read::VolatileLayerClientImpl client(kHrn, kLayerId, std::move(settings));
362361
client.GetData(read::DataRequest().WithPartitionId(kPartitionId),
363362
[&](read::DataResponse response) {
364363
data_response = std::move(response);
365-
EXPECT_NE(caller_thread_id, std::this_thread::get_id());
366364
});
367365
}
368366

@@ -1013,8 +1011,7 @@ TEST(VolatileLayerClientImplTest, PrefetchTilesCancelOnClientDestroy) {
10131011

10141012
read::PrefetchTilesResponse response;
10151013
{
1016-
// Client owns the task scheduler
1017-
auto caller_thread_id = std::this_thread::get_id();
1014+
// Client owns the task schedule
10181015
read::VolatileLayerClientImpl client(kHrn, kLayerId, std::move(settings));
10191016
std::vector<olp::geo::TileKey> tile_keys = {
10201017
olp::geo::TileKey::FromHereTile(kTileId)};
@@ -1023,11 +1020,10 @@ TEST(VolatileLayerClientImplTest, PrefetchTilesCancelOnClientDestroy) {
10231020
.WithMinLevel(11)
10241021
.WithMaxLevel(12);
10251022

1026-
client.PrefetchTiles(
1027-
request, [&](read::PrefetchTilesResponse prefetch_response) {
1028-
response = std::move(prefetch_response);
1029-
EXPECT_NE(caller_thread_id, std::this_thread::get_id());
1030-
});
1023+
client.PrefetchTiles(request,
1024+
[&](read::PrefetchTilesResponse prefetch_response) {
1025+
response = std::move(prefetch_response);
1026+
});
10311027
}
10321028

10331029
// Callback must be called during client destructor.

tests/integration/olp-cpp-sdk-dataservice-read/VersionedLayerClientPrefetchPartitionsTest.cpp

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -301,15 +301,20 @@ TEST_F(VersionedLayerClientPrefetchPartitionsTest, PrefetchPartitionsCancel) {
301301
const auto request =
302302
read::PrefetchPartitionsRequest().WithPartitionIds(partitions);
303303

304-
std::promise<void> block_promise;
305-
auto block_future = block_promise.get_future();
304+
std::promise<void> block_task_promise;
305+
std::promise<void> block_main_promise;
306+
auto block_future = block_task_promise.get_future();
306307
settings_.task_scheduler->ScheduleTask(
307-
[&block_future]() { block_future.get(); });
308+
[&block_future, &block_main_promise]() {
309+
block_main_promise.set_value();
310+
block_future.get();
311+
});
308312
auto cancellable = client.PrefetchPartitions(request);
309313

310314
// cancel the request and unblock queue
311315
cancellable.GetCancellationToken().Cancel();
312-
block_promise.set_value();
316+
block_main_promise.get_future().get();
317+
block_task_promise.set_value();
313318
auto future = cancellable.GetFuture();
314319

315320
ASSERT_EQ(future.wait_for(kTimeout), std::future_status::ready);

0 commit comments

Comments
 (0)