Skip to content

Commit 22094a0

Browse files
Preemptive cancellation of TaskContext jobs
- Enable pre-execution cancellation in TaskContext by registering a cancellation token in constructor that invokes PreExecuteCancel() - Clear cancellation token in ~TaskContextImpl to prevent dangling refs - Fix CancellationContext::CancelOperation() to swap out the token before calling Cancel(), avoiding re-entrancy issues - Update BlockingCancel test to cover pre-execution cancellation Relates-To: HERESDK-12253 Signed-off-by: Mykhailo Diachenko <ext-mykhailo.z.diachenko@here.com>
1 parent b400ee0 commit 22094a0

File tree

3 files changed

+68
-12
lines changed

3 files changed

+68
-12
lines changed

olp-cpp-sdk-core/include/olp/core/client/CancellationContext.inl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#pragma once
2121

22+
#include <utility>
2223
namespace olp {
2324
namespace client {
2425

@@ -58,9 +59,10 @@ inline void CancellationContext::CancelOperation() {
5859
return;
5960
}
6061

61-
impl_->sub_operation_cancel_token_.Cancel();
62-
impl_->sub_operation_cancel_token_ = CancellationToken();
62+
auto token = CancellationToken();
63+
std::swap(token, impl_->sub_operation_cancel_token_);
6364
impl_->is_cancelled_ = true;
65+
token.Cancel();
6466
}
6567

6668
inline bool CancellationContext::IsCancelled() const {

olp-cpp-sdk-core/include/olp/core/client/TaskContext.h

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,19 @@ class CORE_API TaskContext {
193193
: execute_func_(std::move(execute_func)),
194194
callback_(std::move(callback)),
195195
context_(std::move(context)),
196-
state_{State::PENDING} {}
196+
state_{State::PENDING} {
197+
context_.ExecuteOrCancelled(
198+
[&]() -> CancellationToken {
199+
return CancellationToken([&] { PreExecuteCancel(); });
200+
},
201+
[]() {});
202+
}
197203

198-
~TaskContextImpl() override{};
204+
~TaskContextImpl() override {
205+
// Overwrite the token to prevent dangling reference to this.
206+
context_.ExecuteOrCancelled(
207+
[]() -> CancellationToken { return CancellationToken(); }, []() {});
208+
};
199209

200210
/**
201211
* @brief Checks for the cancellation, executes the task, and calls
@@ -249,6 +259,40 @@ class CORE_API TaskContext {
249259
state_.store(State::COMPLETED);
250260
}
251261

262+
void PreExecuteCancel() {
263+
State expected_state = State::PENDING;
264+
265+
if (!state_.compare_exchange_strong(expected_state, State::IN_PROGRESS)) {
266+
return;
267+
}
268+
269+
// Moving the user callback and function guarantee that they are
270+
// executed exactly once
271+
ExecuteFunc function = nullptr;
272+
UserCallback callback = nullptr;
273+
274+
{
275+
std::lock_guard<std::mutex> lock(mutex_);
276+
function = std::move(execute_func_);
277+
callback = std::move(callback_);
278+
}
279+
280+
Response user_response =
281+
client::ApiError(client::ErrorCode::Cancelled, "Cancelled");
282+
283+
if (callback) {
284+
callback(std::move(user_response));
285+
}
286+
287+
// Resources need to be released before the notification, else lambas
288+
// would have captured resources like network or `TaskScheduler`.
289+
function = nullptr;
290+
callback = nullptr;
291+
292+
condition_.Notify();
293+
state_.store(State::COMPLETED);
294+
}
295+
252296
/**
253297
* @brief Cancels the operation and waits for the notification.
254298
*

olp-cpp-sdk-core/tests/client/TaskContextTest.cpp

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -136,19 +136,29 @@ TEST(TaskContextTest, BlockingCancel) {
136136

137137
Callback callback = [&](Response r) { response = std::move(r); };
138138

139-
TaskContext context = TaskContext::Create(func, callback);
139+
{
140+
SCOPED_TRACE("Pre-exec cancellation");
141+
TaskContext context = TaskContext::Create(func, callback);
142+
EXPECT_TRUE(context.BlockingCancel(std::chrono::seconds(0)));
143+
EXPECT_FALSE(response.IsSuccessful());
144+
EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled);
145+
}
140146

141-
EXPECT_FALSE(context.BlockingCancel(std::chrono::seconds(0)));
147+
{
148+
SCOPED_TRACE("Pre-exec cancellation race");
149+
response = Response{};
150+
TaskContext context = TaskContext::Create(func, callback);
142151

143-
std::thread cancel_thread([&]() { EXPECT_TRUE(context.BlockingCancel()); });
152+
std::thread cancel_thread([&]() { EXPECT_TRUE(context.BlockingCancel()); });
144153

145-
std::thread execute_thread([&]() { context.Execute(); });
154+
std::thread execute_thread([&]() { context.Execute(); });
146155

147-
execute_thread.join();
148-
cancel_thread.join();
156+
execute_thread.join();
157+
cancel_thread.join();
149158

150-
EXPECT_FALSE(response.IsSuccessful());
151-
EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled);
159+
EXPECT_FALSE(response.IsSuccessful());
160+
EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled);
161+
}
152162
}
153163

154164
TEST(TaskContextTest, BlockingCancelIsWaiting) {

0 commit comments

Comments
 (0)