diff --git a/olp-cpp-sdk-core/include/olp/core/client/CancellationContext.inl b/olp-cpp-sdk-core/include/olp/core/client/CancellationContext.inl index 2d58fc903..eeb39bc98 100644 --- a/olp-cpp-sdk-core/include/olp/core/client/CancellationContext.inl +++ b/olp-cpp-sdk-core/include/olp/core/client/CancellationContext.inl @@ -19,6 +19,7 @@ #pragma once +#include namespace olp { namespace client { @@ -58,9 +59,10 @@ inline void CancellationContext::CancelOperation() { return; } - impl_->sub_operation_cancel_token_.Cancel(); - impl_->sub_operation_cancel_token_ = CancellationToken(); + auto token = CancellationToken(); + std::swap(token, impl_->sub_operation_cancel_token_); impl_->is_cancelled_ = true; + token.Cancel(); } inline bool CancellationContext::IsCancelled() const { diff --git a/olp-cpp-sdk-core/include/olp/core/client/OlpClientSettingsFactory.h b/olp-cpp-sdk-core/include/olp/core/client/OlpClientSettingsFactory.h index c2f98677f..dbc08d340 100644 --- a/olp-cpp-sdk-core/include/olp/core/client/OlpClientSettingsFactory.h +++ b/olp-cpp-sdk-core/include/olp/core/client/OlpClientSettingsFactory.h @@ -47,12 +47,17 @@ class CORE_API OlpClientSettingsFactory final { * operations. * * Defaulted to `olp::thread::ThreadPoolTaskScheduler` with one worker - * thread spawned by default. + * thread spawned by default. The default scheduler can optionally create a + * dedicated cancellation lane backed by an extra worker thread. + * + * @param[in] thread_count The number of regular worker threads. + * @param[in] enable_cancellation_lane When true, enables the dedicated + * cancellation lane on the default `ThreadPoolTaskScheduler`. * * @return The `TaskScheduler` instance. */ static std::unique_ptr CreateDefaultTaskScheduler( - size_t thread_count = 1u); + size_t thread_count = 1u, bool enable_cancellation_lane = false); /** * @brief Creates the `Network` instance used for all the non-local requests. diff --git a/olp-cpp-sdk-core/include/olp/core/client/TaskContext.h b/olp-cpp-sdk-core/include/olp/core/client/TaskContext.h index e8a1f32fc..41ada2b98 100644 --- a/olp-cpp-sdk-core/include/olp/core/client/TaskContext.h +++ b/olp-cpp-sdk-core/include/olp/core/client/TaskContext.h @@ -28,6 +28,7 @@ #include #include #include +#include namespace olp { namespace client { @@ -56,10 +57,11 @@ class CORE_API TaskContext { template static TaskContext Create( Exec execute_func, Callback callback, - client::CancellationContext context = client::CancellationContext()) { + client::CancellationContext context = client::CancellationContext(), + std::shared_ptr task_scheduler = nullptr) { TaskContext task; task.SetExecutors(std::move(execute_func), std::move(callback), - std::move(context)); + std::move(context), std::move(task_scheduler)); return task; } @@ -126,9 +128,32 @@ class CORE_API TaskContext { * @param context The `CancellationContext` instance. */ void SetExecutors(Exec execute_func, Callback callback, - client::CancellationContext context) { - impl_ = std::make_shared>( - std::move(execute_func), std::move(callback), std::move(context)); + client::CancellationContext context, + std::shared_ptr task_scheduler) { + auto impl = std::make_shared>( + std::move(execute_func), std::move(callback), context); + + if (task_scheduler) { + std::weak_ptr> weak_impl = impl; + auto cancellation_scheduler = task_scheduler; + context.ExecuteOrCancelled( + [weak_impl, cancellation_scheduler]() -> CancellationToken { + return CancellationToken([weak_impl, cancellation_scheduler]() { + auto impl = weak_impl.lock(); + if (impl && cancellation_scheduler) { + cancellation_scheduler->ScheduleCancellationTask([weak_impl]() { + auto impl = weak_impl.lock(); + if (impl) { + impl->PreExecuteCancel(); + } + }); + return; + } + }); + }, + []() {}); + } + impl_ = std::move(impl); } /** @@ -249,6 +274,40 @@ class CORE_API TaskContext { state_.store(State::COMPLETED); } + void PreExecuteCancel() { + State expected_state = State::PENDING; + + if (!state_.compare_exchange_strong(expected_state, State::IN_PROGRESS)) { + return; + } + + // Moving the user callback and function guarantee that they are + // executed exactly once + ExecuteFunc function = nullptr; + UserCallback callback = nullptr; + + { + std::lock_guard lock(mutex_); + function = std::move(execute_func_); + callback = std::move(callback_); + } + + Response user_response = + client::ApiError(client::ErrorCode::Cancelled, "Cancelled"); + + if (callback) { + callback(std::move(user_response)); + } + + // Resources need to be released before the notification, else lambas + // would have captured resources like network or `TaskScheduler`. + function = nullptr; + callback = nullptr; + + condition_.Notify(); + state_.store(State::COMPLETED); + } + /** * @brief Cancels the operation and waits for the notification. * @@ -330,8 +389,8 @@ struct CORE_API TaskContextHash { */ size_t operator()(const TaskContext& task_context) const { return std::hash>()(task_context.impl_); - } -}; + } // namespace client +}; // namespace olp } // namespace client } // namespace olp diff --git a/olp-cpp-sdk-core/include/olp/core/thread/TaskScheduler.h b/olp-cpp-sdk-core/include/olp/core/thread/TaskScheduler.h index 8190cadc7..fefa27c23 100644 --- a/olp-cpp-sdk-core/include/olp/core/thread/TaskScheduler.h +++ b/olp-cpp-sdk-core/include/olp/core/thread/TaskScheduler.h @@ -71,6 +71,20 @@ class CORE_API TaskScheduler { EnqueueTask(std::move(func), priority); } + /** + * @brief Schedules cancellation work. + * + * By default, cancellation work falls back to the regular task queue. Custom + * schedulers can override `EnqueueCancellationTask` to dispatch cancellation + * work differently. + * + * @param[in] func The callable target that should be added to the scheduling + * pipeline for cancellation work. + */ + void ScheduleCancellationTask(CallFuncType&& func) { + EnqueueCancellationTask(std::move(func)); + } + /** * @brief Schedules the asynchronous cancellable task. * @@ -136,6 +150,19 @@ class CORE_API TaskScheduler { OLP_SDK_CORE_UNUSED(priority); EnqueueTask(std::forward(func)); } + + /** + * @brief The enqueue cancellation task interface that is implemented by the + * subclass when cancellation work should use a dedicated dispatch path. + * + * By default, cancellation work falls back to the regular task queue. + * + * @param[in] func The rvalue reference of the cancellation task that should + * be enqueued. + */ + virtual void EnqueueCancellationTask(CallFuncType&& func) { + EnqueueTask(std::forward(func)); + } }; /** diff --git a/olp-cpp-sdk-core/include/olp/core/thread/ThreadPoolTaskScheduler.h b/olp-cpp-sdk-core/include/olp/core/thread/ThreadPoolTaskScheduler.h index 078b2b951..69ade79da 100644 --- a/olp-cpp-sdk-core/include/olp/core/thread/ThreadPoolTaskScheduler.h +++ b/olp-cpp-sdk-core/include/olp/core/thread/ThreadPoolTaskScheduler.h @@ -31,6 +31,9 @@ namespace thread { * @brief An implementation of the `TaskScheduler` instance that uses a thread * pool. * + * The scheduler can optionally expose a dedicated cancellation lane backed by a + * separate worker thread. When disabled, cancellation work falls back to the + * regular task queue. */ class CORE_API ThreadPoolTaskScheduler final : public TaskScheduler { public: @@ -38,8 +41,11 @@ class CORE_API ThreadPoolTaskScheduler final : public TaskScheduler { * @brief Creates the `ThreadPoolTaskScheduler` object with one thread. * * @param thread_count The number of threads initialized in the thread pool. + * @param enable_cancellation_lane When true, creates a dedicated worker + * thread for cancellation work scheduled via `ScheduleCancellationTask`. */ - explicit ThreadPoolTaskScheduler(size_t thread_count = 1u); + explicit ThreadPoolTaskScheduler(size_t thread_count = 1u, + bool enable_cancellation_lane = false); /** * @brief Closes the `SyncQueue` instance and joins threads. @@ -81,13 +87,22 @@ class CORE_API ThreadPoolTaskScheduler final : public TaskScheduler { void EnqueueTask(TaskScheduler::CallFuncType&& func, uint32_t priority) override; + void EnqueueCancellationTask(TaskScheduler::CallFuncType&& func) override; + private: class QueueImpl; + class CancellationQueueImpl; /// Thread pool created in constructor. std::vector thread_pool_; /// SyncQueue used to manage tasks. std::unique_ptr queue_; + /// Dedicated cancellation worker thread. + std::thread cancellation_thread_; + /// SyncQueue used to manage cancellation tasks when enabled. + std::unique_ptr cancellation_queue_; + /// Indicates whether the dedicated cancellation lane is enabled. + bool cancellation_lane_enabled_; }; } // namespace thread diff --git a/olp-cpp-sdk-core/src/client/OlpClientSettingsFactory.cpp b/olp-cpp-sdk-core/src/client/OlpClientSettingsFactory.cpp index 5bc1a2ab8..6f414a088 100644 --- a/olp-cpp-sdk-core/src/client/OlpClientSettingsFactory.cpp +++ b/olp-cpp-sdk-core/src/client/OlpClientSettingsFactory.cpp @@ -37,8 +37,10 @@ namespace olp { namespace client { std::unique_ptr -OlpClientSettingsFactory::CreateDefaultTaskScheduler(size_t thread_count) { - return std::make_unique(thread_count); +OlpClientSettingsFactory::CreateDefaultTaskScheduler( + size_t thread_count, bool enable_cancellation_lane) { + return std::make_unique( + thread_count, enable_cancellation_lane); } std::shared_ptr diff --git a/olp-cpp-sdk-core/src/thread/ThreadPoolTaskScheduler.cpp b/olp-cpp-sdk-core/src/thread/ThreadPoolTaskScheduler.cpp index 0cc9d4c1e..5c3b674e7 100644 --- a/olp-cpp-sdk-core/src/thread/ThreadPoolTaskScheduler.cpp +++ b/olp-cpp-sdk-core/src/thread/ThreadPoolTaskScheduler.cpp @@ -42,6 +42,7 @@ namespace thread { namespace { constexpr auto kLogTag = "ThreadPoolTaskScheduler"; +constexpr auto kCancellationExecutorName = "OLPSDKCANCEL"; struct PrioritizedTask { TaskScheduler::CallFuncType function; @@ -61,6 +62,34 @@ void SetExecutorName(size_t idx) { OLP_SDK_LOG_INFO_F(kLogTag, "Starting thread '%s'", thread_name.c_str()); } +void SetCancellationExecutorName() { + olp::utils::Thread::SetCurrentThreadName(kCancellationExecutorName); + OLP_SDK_LOG_INFO_F(kLogTag, "Starting thread '%s'", + kCancellationExecutorName); +} + +TaskScheduler::CallFuncType WrapWithLogContext( + TaskScheduler::CallFuncType&& func) { + auto log_context = logging::GetContext(); + +#if __cplusplus >= 201402L + // At least C++14, use generalized lambda capture + return [log_context = std::move(log_context), func = std::move(func)]() { + olp::logging::ScopedLogContext scoped_context(log_context); + func(); + }; +#else + // C++11 does not support generalized lambda capture :( + return std::bind( + [](std::shared_ptr& log_context, + TaskScheduler::CallFuncType& func) { + olp::logging::ScopedLogContext scoped_context(log_context); + func(); + }, + std::move(log_context), std::move(func)); +#endif +} + } // namespace class ThreadPoolTaskScheduler::QueueImpl { @@ -77,10 +106,39 @@ class ThreadPoolTaskScheduler::QueueImpl { SyncQueue sync_queue_; }; -ThreadPoolTaskScheduler::ThreadPoolTaskScheduler(size_t thread_count) - : queue_{std::make_unique()} { +class ThreadPoolTaskScheduler::CancellationQueueImpl { + public: + using ElementType = TaskScheduler::CallFuncType; + + bool Pull(ElementType& element) { return sync_queue_.Pull(element); } + void Push(ElementType&& element) { sync_queue_.Push(std::move(element)); } + void Close() { sync_queue_.Close(); } + + private: + SyncQueueFifo sync_queue_; +}; + +ThreadPoolTaskScheduler::ThreadPoolTaskScheduler(size_t thread_count, + bool enable_cancellation_lane) + : queue_{std::make_unique()}, + cancellation_lane_enabled_{enable_cancellation_lane} { thread_pool_.reserve(thread_count); + if (cancellation_lane_enabled_) { + cancellation_queue_ = std::make_unique(); + cancellation_thread_ = std::thread([this]() { + SetCancellationExecutorName(); + + for (;;) { + TaskScheduler::CallFuncType task; + if (!cancellation_queue_->Pull(task)) { + return; + } + task(); + } + }); + } + for (size_t idx = 0; idx < thread_count; ++idx) { std::thread executor([this, idx]() { // Set thread name for easy profiling and debugging @@ -100,40 +158,38 @@ ThreadPoolTaskScheduler::ThreadPoolTaskScheduler(size_t thread_count) } ThreadPoolTaskScheduler::~ThreadPoolTaskScheduler() { + if (cancellation_queue_) { + cancellation_queue_->Close(); + } queue_->Close(); + if (cancellation_thread_.joinable()) { + cancellation_thread_.join(); + } for (auto& thread : thread_pool_) { thread.join(); } thread_pool_.clear(); } +void ThreadPoolTaskScheduler::EnqueueCancellationTask( + TaskScheduler::CallFuncType&& func) { + auto task = WrapWithLogContext(std::move(func)); + + if (!cancellation_lane_enabled_ || !cancellation_queue_) { + queue_->Push({std::move(task), thread::NORMAL}); + return; + } + + cancellation_queue_->Push(std::move(task)); +} + void ThreadPoolTaskScheduler::EnqueueTask(TaskScheduler::CallFuncType&& func) { EnqueueTask(std::move(func), thread::NORMAL); } void ThreadPoolTaskScheduler::EnqueueTask(TaskScheduler::CallFuncType&& func, uint32_t priority) { - auto logContext = logging::GetContext(); - -#if __cplusplus >= 201402L - // At least C++14, use generalized lambda capture - auto funcWithCapturedLogContext = [logContext = std::move(logContext), - func = std::move(func)]() { - olp::logging::ScopedLogContext scopedContext(logContext); - func(); - }; -#else - // C++11 does not support generalized lambda capture :( - auto funcWithCapturedLogContext = std::bind( - [](std::shared_ptr& logContext, - TaskScheduler::CallFuncType& func) { - olp::logging::ScopedLogContext scopedContext(logContext); - func(); - }, - std::move(logContext), std::move(func)); -#endif - - queue_->Push({std::move(funcWithCapturedLogContext), priority}); + queue_->Push({WrapWithLogContext(std::move(func)), priority}); } } // namespace thread diff --git a/olp-cpp-sdk-core/tests/client/TaskContextTest.cpp b/olp-cpp-sdk-core/tests/client/TaskContextTest.cpp index d2037e09b..2788e9c5d 100644 --- a/olp-cpp-sdk-core/tests/client/TaskContextTest.cpp +++ b/olp-cpp-sdk-core/tests/client/TaskContextTest.cpp @@ -22,6 +22,7 @@ #include #include #include +#include namespace { @@ -32,6 +33,7 @@ using client::CancellationToken; using client::Condition; using client::ErrorCode; using client::TaskContext; +using ThreadPoolTaskScheduler = olp::thread::ThreadPoolTaskScheduler; using ResponseType = std::string; using Response = client::ApiResponse; @@ -47,21 +49,43 @@ class TaskContextTestable : public TaskContext { template static TaskContextTestable Create( Exec execute_func, Callback callback, - CancellationContext context = CancellationContext()) { + CancellationContext context = CancellationContext(), + std::shared_ptr task_scheduler = nullptr) { TaskContextTestable task; task.SetExecutors(std::move(execute_func), std::move(callback), - std::move(context)); + std::move(context), std::move(task_scheduler)); return task; } template ::type> - void SetExecutors(Exec execute_func, Callback callback, - CancellationContext context) { + void SetExecutors( + Exec execute_func, Callback callback, CancellationContext context, + std::shared_ptr task_scheduler) { auto impl = std::make_shared>( - std::move(execute_func), std::move(callback), std::move(context)); - notify = [=]() { impl->condition_.Notify(); }; + std::move(execute_func), std::move(callback), context); + std::weak_ptr> weak_impl = impl; + auto cancellation_scheduler = task_scheduler; + context.ExecuteOrCancelled( + [weak_impl, cancellation_scheduler]() -> CancellationToken { + return CancellationToken([weak_impl, cancellation_scheduler]() { + auto impl = weak_impl.lock(); + if (impl && cancellation_scheduler) { + cancellation_scheduler->ScheduleCancellationTask( + [weak_impl, cancellation_scheduler]() { + OLP_SDK_CORE_UNUSED(cancellation_scheduler); + if (auto impl = weak_impl.lock()) { + impl->PreExecuteCancel(); + } + }); + return; + } + impl->PreExecuteCancel(); + }); + }, + []() {}); + notify = [impl]() { impl->condition_.Notify(); }; impl_ = impl; } }; @@ -127,28 +151,166 @@ TEST(TaskContextTest, ExecuteSimple) { } TEST(TaskContextTest, BlockingCancel) { - ExecuteFunc func = [&](CancellationContext c) -> Response { - EXPECT_TRUE(c.IsCancelled()); - return std::string("Success"); - }; - Response response; Callback callback = [&](Response r) { response = std::move(r); }; - TaskContext context = TaskContext::Create(func, callback); + { + SCOPED_TRACE("Pre-exec cancellation"); + bool executed = false; + ExecuteFunc func = [&](CancellationContext) -> Response { + executed = true; + return std::string("Success"); + }; + + TaskContext context = TaskContext::Create(func, callback); + EXPECT_FALSE(context.BlockingCancel(std::chrono::seconds(0))); + context.Execute(); + EXPECT_FALSE(executed); + EXPECT_FALSE(response.IsSuccessful()); + EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled); + } - EXPECT_FALSE(context.BlockingCancel(std::chrono::seconds(0))); + { + SCOPED_TRACE("Cancel during execution"); + Condition continue_execution; + Condition execution_started; + int execution_count = 0; + response = Response{}; + ExecuteFunc func = [&](CancellationContext c) -> Response { + ++execution_count; + execution_started.Notify(); + // EXPECT_TRUE(continue_execution.Wait(kWaitTime)); + const auto deadline = std::chrono::steady_clock::now() + kWaitTime; + while (!c.IsCancelled() && std::chrono::steady_clock::now() < deadline) { + std::this_thread::yield(); + } + EXPECT_TRUE(c.IsCancelled()); + return std::string("Success"); + }; + TaskContext context = TaskContext::Create(func, callback); - std::thread cancel_thread([&]() { EXPECT_TRUE(context.BlockingCancel()); }); + std::thread execute_thread([&]() { context.Execute(); }); + EXPECT_TRUE(execution_started.Wait()); - std::thread execute_thread([&]() { context.Execute(); }); + std::thread cancel_thread([&]() { EXPECT_TRUE(context.BlockingCancel()); }); - execute_thread.join(); - cancel_thread.join(); + // continue_execution.Notify(); + execute_thread.join(); + cancel_thread.join(); + + EXPECT_EQ(execution_count, 1); + EXPECT_FALSE(response.IsSuccessful()); + EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled); + } +} + +TEST(TaskContextTest, PreExecuteCancelUsesCancellationLaneWhenEnabled) { + auto task_scheduler = std::make_shared(1u, true); + + bool executed = false; + Response response; + std::promise callback_promise; + auto callback_future = callback_promise.get_future(); + std::thread::id callback_thread_id; + std::promise executor_thread_id_promise; + std::future executor_thread_id_future = + executor_thread_id_promise.get_future(); + task_scheduler->ScheduleTask([&]() { + executor_thread_id_promise.set_value(std::this_thread::get_id()); + }); + + ASSERT_EQ(executor_thread_id_future.wait_for(std::chrono::seconds(1)), + std::future_status::ready); + + TaskContext context = TaskContext::Create( + [&](CancellationContext) -> Response { + executed = true; + return std::string("Success"); + }, + [&](Response r) { + callback_thread_id = std::this_thread::get_id(); + response = std::move(r); + callback_promise.set_value(); + }, + CancellationContext(), task_scheduler); + + EXPECT_TRUE(context.BlockingCancel(kWaitTime)); + EXPECT_EQ(callback_future.wait_for(std::chrono::milliseconds(0)), + std::future_status::ready); + EXPECT_FALSE(executed); EXPECT_FALSE(response.IsSuccessful()); EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled); + EXPECT_NE(callback_thread_id, std::this_thread::get_id()); + EXPECT_NE(callback_thread_id, executor_thread_id_future.get()); +} + +TEST(TaskContextTest, PreExecuteCancelUsesRegularSchedulerWhenLaneDisabled) { + auto task_scheduler = std::make_shared(1u, false); + bool executed = false; + Response response; + std::promise callback_promise; + auto callback_future = callback_promise.get_future(); + std::thread::id callback_thread_id; + + std::promise executor_thread_id_promise; + std::future executor_thread_id_future = + executor_thread_id_promise.get_future(); + task_scheduler->ScheduleTask([&]() { + executor_thread_id_promise.set_value(std::this_thread::get_id()); + }); + + ASSERT_EQ(executor_thread_id_future.wait_for(std::chrono::seconds(1)), + std::future_status::ready); + + TaskContext context = TaskContext::Create( + [&](CancellationContext) -> Response { + executed = true; + return std::string("Success"); + }, + [&](Response r) { + callback_thread_id = std::this_thread::get_id(); + response = std::move(r); + callback_promise.set_value(); + }, + CancellationContext(), task_scheduler); + + EXPECT_TRUE(context.BlockingCancel(kWaitTime)); + EXPECT_EQ(callback_future.wait_for(std::chrono::milliseconds(0)), + std::future_status::ready); + EXPECT_FALSE(executed); + EXPECT_FALSE(response.IsSuccessful()); + EXPECT_EQ(response.GetError().GetErrorCode(), ErrorCode::Cancelled); + EXPECT_NE(callback_thread_id, std::this_thread::get_id()); + EXPECT_EQ(callback_thread_id, executor_thread_id_future.get()); +} + +TEST(TaskContextTest, + CancellationLanePreExecuteCancelRunsWhileRegularWorkerBlocked) { + auto scheduler = std::make_shared(1u, true); + std::promise response; + std::promise scheduler_blocked_promise; + + scheduler->ScheduleTask([&]() { + scheduler_blocked_promise.set_value(); + auto future = response.get_future(); + ASSERT_EQ(future.wait_for(std::chrono::seconds(1)), + std::future_status::ready); + + EXPECT_EQ(future.get().GetError().GetErrorCode(), ErrorCode::Cancelled); + }); + + ASSERT_EQ( + scheduler_blocked_promise.get_future().wait_for(std::chrono::seconds(1)), + std::future_status::ready); + + TaskContext context = TaskContext::Create( + [&](CancellationContext) -> Response { return std::string("Success"); }, + [&](Response r) { response.set_value(std::move(r)); }, + CancellationContext(), scheduler); + + ASSERT_TRUE(context.BlockingCancel(kWaitTime)); } TEST(TaskContextTest, BlockingCancelIsWaiting) { diff --git a/olp-cpp-sdk-core/tests/thread/ThreadPoolTaskSchedulerTest.cpp b/olp-cpp-sdk-core/tests/thread/ThreadPoolTaskSchedulerTest.cpp index 91b59c646..69b84b290 100644 --- a/olp-cpp-sdk-core/tests/thread/ThreadPoolTaskSchedulerTest.cpp +++ b/olp-cpp-sdk-core/tests/thread/ThreadPoolTaskSchedulerTest.cpp @@ -18,6 +18,8 @@ */ #include +#include +#include #include #include @@ -25,6 +27,8 @@ #include #include +#include +#include #include #include "mocks/TaskSchedulerMock.h" @@ -302,3 +306,87 @@ TEST(ThreadPoolTaskSchedulerTest, ExecuteOrSchedule) { EXPECT_EQ(counter, 1); } } + +TEST(ThreadPoolTaskSchedulerTest, + CancellationTaskFallsBackToDefaultSchedulerQueue) { + auto mock_scheduler = std::make_shared(); + SyncTaskType scheduled_task; + bool executed = false; + + EXPECT_CALL(*mock_scheduler, EnqueueTask(testing::_)) + .WillOnce(testing::Invoke( + [&](SyncTaskType&& task) { scheduled_task = std::move(task); })); + + mock_scheduler->ScheduleCancellationTask([&]() { executed = true; }); + + ASSERT_TRUE(static_cast(scheduled_task)); + scheduled_task(); + EXPECT_TRUE(executed); +} + +TEST(ThreadPoolTaskSchedulerTest, + CancellationTaskFallsBackToRegularQueueWhenLaneDisabled) { + auto thread_pool = std::make_shared(1u, false); + + std::promise completed_promise; + auto completed_future = completed_promise.get_future(); + + thread_pool->ScheduleCancellationTask( + [&]() { completed_promise.set_value(); }); + + EXPECT_EQ(completed_future.wait_for(std::chrono::milliseconds(kMaxWaitMs)), + std::future_status::ready); +} + +TEST(ThreadPoolTaskSchedulerTest, + CancellationLaneRunsWhileRegularWorkerBlocked) { + olp::client::Condition executor_blocked; + olp::client::Condition cancellation_done; + olp::client::Condition executor_done; + + auto thread_pool = std::make_shared(1u, true); + TaskScheduler& scheduler = *thread_pool; + + scheduler.ScheduleTask([&]() { + executor_blocked.Notify(); + ASSERT_TRUE(cancellation_done.Wait(std::chrono::seconds(1))); + executor_done.Notify(); + }); + + ASSERT_TRUE(executor_blocked.Wait(std::chrono::seconds(1))); + scheduler.ScheduleCancellationTask([&]() { cancellation_done.Notify(); }); + ASSERT_TRUE(executor_done.Wait(std::chrono::seconds(1))); +} + +TEST(ThreadPoolTaskSchedulerTest, + DefaultSchedulerFactoryUsesDedicatedCancellationLaneWhenEnabled) { + auto scheduler = + olp::client::OlpClientSettingsFactory::CreateDefaultTaskScheduler(1u, + true); + + std::promise cancellation_thread_id_promise; + std::promise executor_thread_id_promise; + + scheduler->ScheduleTask([&]() { + executor_thread_id_promise.set_value(std::this_thread::get_id()); + }); + + scheduler->ScheduleCancellationTask([&]() { + cancellation_thread_id_promise.set_value(std::this_thread::get_id()); + }); + + auto cancellation_thread_id_future = + cancellation_thread_id_promise.get_future(); + auto executor_thread_id_future = executor_thread_id_promise.get_future(); + ASSERT_EQ(cancellation_thread_id_future.wait_for(std::chrono::seconds(1)), + std::future_status::ready); + ASSERT_EQ(executor_thread_id_future.wait_for(std::chrono::seconds(1)), + std::future_status::ready); + + auto cancellation_thread_id = cancellation_thread_id_future.get(); + auto executor_thread_id = executor_thread_id_future.get(); + + EXPECT_NE(cancellation_thread_id, executor_thread_id); + EXPECT_NE(cancellation_thread_id, std::this_thread::get_id()); + EXPECT_NE(executor_thread_id, std::this_thread::get_id()); +} diff --git a/olp-cpp-sdk-dataservice-read/src/TaskSink.cpp b/olp-cpp-sdk-dataservice-read/src/TaskSink.cpp index 91d46200c..951f7f2b7 100644 --- a/olp-cpp-sdk-dataservice-read/src/TaskSink.cpp +++ b/olp-cpp-sdk-dataservice-read/src/TaskSink.cpp @@ -61,7 +61,7 @@ client::CancellationToken TaskSink::AddTask( return client::ApiError(); }, [=](client::ApiResponse) { func(context); }, - context); + context, task_scheduler_); AddTaskImpl(task, priority); return task.CancelToken(); } diff --git a/olp-cpp-sdk-dataservice-read/src/TaskSink.h b/olp-cpp-sdk-dataservice-read/src/TaskSink.h index f0ee7b578..ec530a46b 100644 --- a/olp-cpp-sdk-dataservice-read/src/TaskSink.h +++ b/olp-cpp-sdk-dataservice-read/src/TaskSink.h @@ -47,8 +47,8 @@ class TaskSink { template client::CancellationToken AddTask(Function task, Callback callback, uint32_t priority, Args&&... args) { - auto context = client::TaskContext::Create( - std::move(task), std::move(callback), std::forward(args)...); + auto context = CreateTaskContext(std::move(task), std::move(callback), + std::forward(args)...); AddTaskImpl(context, priority); return context.CancelToken(); } @@ -58,8 +58,8 @@ class TaskSink { Callback callback, uint32_t priority, Args&&... args) { - auto context = client::TaskContext::Create( - std::move(task), std::move(callback), std::forward(args)...); + auto context = CreateTaskContext(std::move(task), std::move(callback), + std::forward(args)...); if (!AddTaskImpl(context, priority)) { return olp::porting::none; } @@ -67,6 +67,20 @@ class TaskSink { } protected: + template + client::TaskContext CreateTaskContext(Function task, Callback callback) { + return client::TaskContext::Create(std::move(task), std::move(callback), + client::CancellationContext(), + task_scheduler_); + } + + template + client::TaskContext CreateTaskContext(Function task, Callback callback, + client::CancellationContext context) { + return client::TaskContext::Create(std::move(task), std::move(callback), + std::move(context), task_scheduler_); + } + bool AddTaskImpl(client::TaskContext task, uint32_t priority); bool ScheduleTask(client::TaskContext task, uint32_t priority); diff --git a/olp-cpp-sdk-dataservice-write/src/Common.h b/olp-cpp-sdk-dataservice-write/src/Common.h index 94b78c4a4..1c8b565cf 100644 --- a/olp-cpp-sdk-dataservice-write/src/Common.h +++ b/olp-cpp-sdk-dataservice-write/src/Common.h @@ -36,16 +36,20 @@ namespace write { * requests. * @param task Function that will be executed. * @param callback Function that will consume task output. - * @param args Additional agrs to pass to TaskContext. + * @param cancellation_context Cancellation context used to cancel the + * operation. * @return CancellationToken used to cancel the operation. */ template inline client::CancellationToken AddTask( const std::shared_ptr& task_scheduler, const std::shared_ptr& pending_requests, - Function task, Callback callback, Args&&... args) { - auto context = client::TaskContext::Create( - std::move(task), std::move(callback), std::forward(args)...); + Function task, Callback callback, + const client::CancellationContext& cancellation_context = + client::CancellationContext()) { + auto context = + client::TaskContext::Create(std::move(task), std::move(callback), + cancellation_context, task_scheduler); pending_requests->Insert(context); auto scheduler_func = [=] { diff --git a/olp-cpp-sdk-dataservice-write/src/StreamLayerClientImpl.cpp b/olp-cpp-sdk-dataservice-write/src/StreamLayerClientImpl.cpp index f88c938fc..e4341f36c 100644 --- a/olp-cpp-sdk-dataservice-write/src/StreamLayerClientImpl.cpp +++ b/olp-cpp-sdk-dataservice-write/src/StreamLayerClientImpl.cpp @@ -251,7 +251,8 @@ olp::client::CancellationToken StreamLayerClientImpl::Flush( if (!exec_started->load()) { callback(StreamLayerClient::FlushResponse{}); } - }); + }, + client::CancellationContext(), task_scheduler_); auto pending_requests = pending_requests_; pending_requests->Insert(task_context); @@ -285,7 +286,7 @@ client::CancellationToken StreamLayerClientImpl::PublishData( using std::placeholders::_1; client::TaskContext task_context = olp::client::TaskContext::Create( std::bind(&StreamLayerClientImpl::PublishDataTask, this, request, _1), - callback); + callback, client::CancellationContext(), task_scheduler_); auto pending_requests = pending_requests_; pending_requests->Insert(task_context); @@ -474,7 +475,7 @@ client::CancellationToken StreamLayerClientImpl::PublishSdii( auto context = olp::client::TaskContext::Create( std::bind(&StreamLayerClientImpl::PublishSdiiTask, this, std::move(request), _1), - callback); + callback, client::CancellationContext(), task_scheduler_); auto pending_requests = pending_requests_; pending_requests->Insert(context);