Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

#pragma once

#include <utility>
namespace olp {
namespace client {

Expand Down Expand Up @@ -58,9 +59,10 @@ inline void CancellationContext::CancelOperation() {
return;
}

impl_->sub_operation_cancel_token_.Cancel();
impl_->sub_operation_cancel_token_ = CancellationToken();
auto token = CancellationToken();
std::swap(token, impl_->sub_operation_cancel_token_);
impl_->is_cancelled_ = true;
token.Cancel();
}

inline bool CancellationContext::IsCancelled() const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,17 @@ class CORE_API OlpClientSettingsFactory final {
* operations.
*
* Defaulted to `olp::thread::ThreadPoolTaskScheduler` with one worker
* thread spawned by default.
* thread spawned by default. The default scheduler can optionally create a
* dedicated cancellation lane backed by an extra worker thread.
*
* @param[in] thread_count The number of regular worker threads.
* @param[in] enable_cancellation_lane When true, enables the dedicated
* cancellation lane on the default `ThreadPoolTaskScheduler`.
*
* @return The `TaskScheduler` instance.
*/
static std::unique_ptr<thread::TaskScheduler> CreateDefaultTaskScheduler(
size_t thread_count = 1u);
size_t thread_count = 1u, bool enable_cancellation_lane = false);

/**
* @brief Creates the `Network` instance used for all the non-local requests.
Expand Down
73 changes: 66 additions & 7 deletions olp-cpp-sdk-core/include/olp/core/client/TaskContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <olp/core/client/CancellationContext.h>
#include <olp/core/client/CancellationToken.h>
#include <olp/core/client/Condition.h>
#include <olp/core/thread/TaskScheduler.h>

namespace olp {
namespace client {
Expand Down Expand Up @@ -56,10 +57,11 @@ class CORE_API TaskContext {
template <typename Exec, typename Callback>
static TaskContext Create(
Exec execute_func, Callback callback,
client::CancellationContext context = client::CancellationContext()) {
client::CancellationContext context = client::CancellationContext(),
std::shared_ptr<thread::TaskScheduler> task_scheduler = nullptr) {
TaskContext task;
task.SetExecutors(std::move(execute_func), std::move(callback),
std::move(context));
std::move(context), std::move(task_scheduler));
return task;
}

Expand Down Expand Up @@ -126,9 +128,32 @@ class CORE_API TaskContext {
* @param context The `CancellationContext` instance.
*/
void SetExecutors(Exec execute_func, Callback callback,
client::CancellationContext context) {
impl_ = std::make_shared<TaskContextImpl<ExecResult>>(
std::move(execute_func), std::move(callback), std::move(context));
client::CancellationContext context,
std::shared_ptr<thread::TaskScheduler> task_scheduler) {
auto impl = std::make_shared<TaskContextImpl<ExecResult>>(
std::move(execute_func), std::move(callback), context);

if (task_scheduler) {
std::weak_ptr<TaskContextImpl<ExecResult>> weak_impl = impl;
auto cancellation_scheduler = task_scheduler;
context.ExecuteOrCancelled(
[weak_impl, cancellation_scheduler]() -> CancellationToken {
return CancellationToken([weak_impl, cancellation_scheduler]() {
auto impl = weak_impl.lock();
if (impl && cancellation_scheduler) {
cancellation_scheduler->ScheduleCancellationTask([weak_impl]() {
auto impl = weak_impl.lock();
if (impl) {
impl->PreExecuteCancel();
}
});
return;
}
});
},
[]() {});
}
impl_ = std::move(impl);
}

/**
Expand Down Expand Up @@ -249,6 +274,40 @@ class CORE_API TaskContext {
state_.store(State::COMPLETED);
}

void PreExecuteCancel() {
State expected_state = State::PENDING;

if (!state_.compare_exchange_strong(expected_state, State::IN_PROGRESS)) {
return;
}

// Moving the user callback and function guarantee that they are
// executed exactly once
ExecuteFunc function = nullptr;
UserCallback callback = nullptr;

{
std::lock_guard<std::mutex> lock(mutex_);
function = std::move(execute_func_);
callback = std::move(callback_);
}

Response user_response =
client::ApiError(client::ErrorCode::Cancelled, "Cancelled");

if (callback) {
callback(std::move(user_response));
}

// Resources need to be released before the notification, else lambas
// would have captured resources like network or `TaskScheduler`.
function = nullptr;
callback = nullptr;

condition_.Notify();
state_.store(State::COMPLETED);
}

/**
* @brief Cancels the operation and waits for the notification.
*
Expand Down Expand Up @@ -330,8 +389,8 @@ struct CORE_API TaskContextHash {
*/
size_t operator()(const TaskContext& task_context) const {
return std::hash<std::shared_ptr<TaskContext::Impl>>()(task_context.impl_);
}
};
} // namespace client
}; // namespace olp

} // namespace client
} // namespace olp
27 changes: 27 additions & 0 deletions olp-cpp-sdk-core/include/olp/core/thread/TaskScheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,20 @@ class CORE_API TaskScheduler {
EnqueueTask(std::move(func), priority);
}

/**
* @brief Schedules cancellation work.
*
* By default, cancellation work falls back to the regular task queue. Custom
* schedulers can override `EnqueueCancellationTask` to dispatch cancellation
* work differently.
*
* @param[in] func The callable target that should be added to the scheduling
* pipeline for cancellation work.
*/
void ScheduleCancellationTask(CallFuncType&& func) {
EnqueueCancellationTask(std::move(func));
}

/**
* @brief Schedules the asynchronous cancellable task.
*
Expand Down Expand Up @@ -136,6 +150,19 @@ class CORE_API TaskScheduler {
OLP_SDK_CORE_UNUSED(priority);
EnqueueTask(std::forward<CallFuncType>(func));
}

/**
* @brief The enqueue cancellation task interface that is implemented by the
* subclass when cancellation work should use a dedicated dispatch path.
*
* By default, cancellation work falls back to the regular task queue.
*
* @param[in] func The rvalue reference of the cancellation task that should
* be enqueued.
*/
virtual void EnqueueCancellationTask(CallFuncType&& func) {
EnqueueTask(std::forward<CallFuncType>(func));
}
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,21 @@ namespace thread {
* @brief An implementation of the `TaskScheduler` instance that uses a thread
* pool.
*
* The scheduler can optionally expose a dedicated cancellation lane backed by a
* separate worker thread. When disabled, cancellation work falls back to the
* regular task queue.
*/
class CORE_API ThreadPoolTaskScheduler final : public TaskScheduler {
public:
/**
* @brief Creates the `ThreadPoolTaskScheduler` object with one thread.
*
* @param thread_count The number of threads initialized in the thread pool.
* @param enable_cancellation_lane When true, creates a dedicated worker
* thread for cancellation work scheduled via `ScheduleCancellationTask`.
*/
explicit ThreadPoolTaskScheduler(size_t thread_count = 1u);
explicit ThreadPoolTaskScheduler(size_t thread_count = 1u,
bool enable_cancellation_lane = false);

/**
* @brief Closes the `SyncQueue` instance and joins threads.
Expand Down Expand Up @@ -81,13 +87,22 @@ class CORE_API ThreadPoolTaskScheduler final : public TaskScheduler {
void EnqueueTask(TaskScheduler::CallFuncType&& func,
uint32_t priority) override;

void EnqueueCancellationTask(TaskScheduler::CallFuncType&& func) override;

private:
class QueueImpl;
class CancellationQueueImpl;

/// Thread pool created in constructor.
std::vector<std::thread> thread_pool_;
/// SyncQueue used to manage tasks.
std::unique_ptr<QueueImpl> queue_;
/// Dedicated cancellation worker thread.
std::thread cancellation_thread_;
/// SyncQueue used to manage cancellation tasks when enabled.
std::unique_ptr<CancellationQueueImpl> cancellation_queue_;
/// Indicates whether the dedicated cancellation lane is enabled.
bool cancellation_lane_enabled_;
};

} // namespace thread
Expand Down
6 changes: 4 additions & 2 deletions olp-cpp-sdk-core/src/client/OlpClientSettingsFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ namespace olp {
namespace client {

std::unique_ptr<thread::TaskScheduler>
OlpClientSettingsFactory::CreateDefaultTaskScheduler(size_t thread_count) {
return std::make_unique<thread::ThreadPoolTaskScheduler>(thread_count);
OlpClientSettingsFactory::CreateDefaultTaskScheduler(
size_t thread_count, bool enable_cancellation_lane) {
return std::make_unique<thread::ThreadPoolTaskScheduler>(
thread_count, enable_cancellation_lane);
}

std::shared_ptr<http::Network>
Expand Down
102 changes: 79 additions & 23 deletions olp-cpp-sdk-core/src/thread/ThreadPoolTaskScheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ namespace thread {

namespace {
constexpr auto kLogTag = "ThreadPoolTaskScheduler";
constexpr auto kCancellationExecutorName = "OLPSDKCANCEL";

struct PrioritizedTask {
TaskScheduler::CallFuncType function;
Expand All @@ -61,6 +62,34 @@ void SetExecutorName(size_t idx) {
OLP_SDK_LOG_INFO_F(kLogTag, "Starting thread '%s'", thread_name.c_str());
}

void SetCancellationExecutorName() {
olp::utils::Thread::SetCurrentThreadName(kCancellationExecutorName);
OLP_SDK_LOG_INFO_F(kLogTag, "Starting thread '%s'",
kCancellationExecutorName);
}

TaskScheduler::CallFuncType WrapWithLogContext(
TaskScheduler::CallFuncType&& func) {
auto log_context = logging::GetContext();

#if __cplusplus >= 201402L
// At least C++14, use generalized lambda capture
return [log_context = std::move(log_context), func = std::move(func)]() {
olp::logging::ScopedLogContext scoped_context(log_context);
func();
};
#else
// C++11 does not support generalized lambda capture :(
return std::bind(
[](std::shared_ptr<const olp::logging::LogContext>& log_context,
TaskScheduler::CallFuncType& func) {
olp::logging::ScopedLogContext scoped_context(log_context);
func();
},
std::move(log_context), std::move(func));
#endif
}

} // namespace

class ThreadPoolTaskScheduler::QueueImpl {
Expand All @@ -77,10 +106,39 @@ class ThreadPoolTaskScheduler::QueueImpl {
SyncQueue<ElementType, PriorityQueue> sync_queue_;
};

ThreadPoolTaskScheduler::ThreadPoolTaskScheduler(size_t thread_count)
: queue_{std::make_unique<QueueImpl>()} {
class ThreadPoolTaskScheduler::CancellationQueueImpl {
public:
using ElementType = TaskScheduler::CallFuncType;

bool Pull(ElementType& element) { return sync_queue_.Pull(element); }
void Push(ElementType&& element) { sync_queue_.Push(std::move(element)); }
void Close() { sync_queue_.Close(); }

private:
SyncQueueFifo<ElementType> sync_queue_;
};

ThreadPoolTaskScheduler::ThreadPoolTaskScheduler(size_t thread_count,
bool enable_cancellation_lane)
: queue_{std::make_unique<QueueImpl>()},
cancellation_lane_enabled_{enable_cancellation_lane} {
thread_pool_.reserve(thread_count);

if (cancellation_lane_enabled_) {
cancellation_queue_ = std::make_unique<CancellationQueueImpl>();
cancellation_thread_ = std::thread([this]() {
SetCancellationExecutorName();

for (;;) {
TaskScheduler::CallFuncType task;
if (!cancellation_queue_->Pull(task)) {
return;
}
task();
}
});
}

for (size_t idx = 0; idx < thread_count; ++idx) {
std::thread executor([this, idx]() {
// Set thread name for easy profiling and debugging
Expand All @@ -100,40 +158,38 @@ ThreadPoolTaskScheduler::ThreadPoolTaskScheduler(size_t thread_count)
}

ThreadPoolTaskScheduler::~ThreadPoolTaskScheduler() {
if (cancellation_queue_) {
cancellation_queue_->Close();
}
queue_->Close();
if (cancellation_thread_.joinable()) {
cancellation_thread_.join();
}
for (auto& thread : thread_pool_) {
thread.join();
}
thread_pool_.clear();
}

void ThreadPoolTaskScheduler::EnqueueCancellationTask(
TaskScheduler::CallFuncType&& func) {
auto task = WrapWithLogContext(std::move(func));

if (!cancellation_lane_enabled_ || !cancellation_queue_) {
queue_->Push({std::move(task), thread::NORMAL});
return;
}

cancellation_queue_->Push(std::move(task));
}

void ThreadPoolTaskScheduler::EnqueueTask(TaskScheduler::CallFuncType&& func) {
EnqueueTask(std::move(func), thread::NORMAL);
}

void ThreadPoolTaskScheduler::EnqueueTask(TaskScheduler::CallFuncType&& func,
uint32_t priority) {
auto logContext = logging::GetContext();

#if __cplusplus >= 201402L
// At least C++14, use generalized lambda capture
auto funcWithCapturedLogContext = [logContext = std::move(logContext),
func = std::move(func)]() {
olp::logging::ScopedLogContext scopedContext(logContext);
func();
};
#else
// C++11 does not support generalized lambda capture :(
auto funcWithCapturedLogContext = std::bind(
[](std::shared_ptr<const olp::logging::LogContext>& logContext,
TaskScheduler::CallFuncType& func) {
olp::logging::ScopedLogContext scopedContext(logContext);
func();
},
std::move(logContext), std::move(func));
#endif

queue_->Push({std::move(funcWithCapturedLogContext), priority});
queue_->Push({WrapWithLogContext(std::move(func)), priority});
}

} // namespace thread
Expand Down
Loading
Loading