Skip to content

Commit f31431a

Browse files
committed
Optimize ASU task manager with lock-free slots
1 parent 518c842 commit f31431a

2 files changed

Lines changed: 625 additions & 22 deletions

File tree

ucm/transport/kv/asu/common/task_manager_base.h

Lines changed: 200 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,20 +24,43 @@
2424
#pragma once
2525

2626
#include <atomic>
27+
#include <cstddef>
28+
#include <cstdint>
29+
#include <limits>
2730
#include <memory>
28-
#include <mutex>
2931
#include <string>
30-
#include <unordered_map>
3132
#include <utility>
33+
#include <vector>
34+
#include <functional>
35+
3236
#include "asu_transport/types.h"
3337

3438
namespace UC::ASU {
3539

3640
template <typename Context, typename State>
3741
class TaskManagerBase {
3842
public:
39-
TaskManagerBase(State initialState, std::string taskName)
40-
: initialState_(initialState), taskName_(std::move(taskName))
43+
// Lower bound on the slot-table size; NormalizeSlotCount raises smaller requests to this.
static constexpr std::size_t kMinSlotCount = 1024;
// Slot-table size used when the constructor is not given an explicit slotCount.
static constexpr std::size_t kDefaultSlotCount = 8192;
45+
46+
static std::size_t RecommendSlotCount(std::size_t maxInflightTasks)
47+
{
48+
// Keep load factor <= 0.5 for open addressing.
49+
// For example: 4096 inflight tasks -> 8192 slots.
50+
const auto required = std::max<std::size_t>(
51+
kMinSlotCount,
52+
maxInflightTasks * 2);
53+
return NormalizeSlotCount(required);
54+
}
55+
56+
/// Builds a manager whose slot table holds NormalizeSlotCount(slotCount) entries.
///
/// @param initialState State value kept in initialState_.
/// @param taskName     Name used as the prefix of this manager's error messages.
/// @param slotCount    Requested table size; normalized before the table is allocated.
explicit TaskManagerBase(
    State initialState,
    std::string taskName,
    std::size_t slotCount = kDefaultSlotCount)
    : initialState_(initialState),
      taskName_(std::move(taskName)),
      slots_(NormalizeSlotCount(slotCount)),
      // Index mask used as a bitwise modulo. Correct only because slots_ is
      // declared (and therefore initialized) before slotMask_.
      slotMask_(slots_.size() - 1)
{
}
4366

@@ -51,27 +74,89 @@ class TaskManagerBase {
5174
auto sharedCtx = std::shared_ptr<Context>(std::move(ctx));
5275
sharedCtx->state.store(initialState_, std::memory_order_release);
5376

54-
std::lock_guard<std::mutex> lock(mutex_);
77+
TaskId newTaskId = kInvalidTaskId;
5578
do {
56-
taskId = nextTaskId_.fetch_add(1, std::memory_order_relaxed);
57-
} while (taskId == kInvalidTaskId || tasks_.find(taskId) != tasks_.end());
79+
newTaskId = nextTaskId_.fetch_add(1, std::memory_order_relaxed);
80+
} while (newTaskId == kInvalidTaskId); // kInvalidTaskId is 0, so task id starts from 1 to avoid allocating invalid IDs
81+
82+
sharedCtx->taskId = newTaskId;
83+
84+
const auto start = Hash(newTaskId) & slotMask_;
85+
const auto capacity = slots_.size();
86+
87+
for (std::size_t probe = 0; probe < capacity; ++probe) {
88+
auto& slot = slots_[(start + probe) & slotMask_];
89+
90+
// CAS: Try to transition EMPTY → WRITING
91+
std::uint8_t expected = SlotState::EMPTY;
92+
if (!slot.state.compare_exchange_strong(
93+
expected,
94+
SlotState::WRITING,
95+
std::memory_order_acq_rel,
96+
std::memory_order_acquire)) {
97+
continue;
98+
}
99+
100+
AtomicStoreCtx(slot, sharedCtx, std::memory_order_release);
101+
slot.taskId.store(newTaskId, std::memory_order_release);
102+
slot.state.store(SlotState::READY, std::memory_order_release);
103+
104+
taskId = newTaskId;
105+
return Status::OK();
106+
}
107+
108+
taskId = kInvalidTaskId;
58109

59-
sharedCtx->taskId = taskId;
60-
tasks_.emplace(taskId, std::move(sharedCtx));
61-
return Status::OK();
110+
// Consider adding RESOURCE_EXHAUSTED / NO_SPACE error codes to StatusCode
111+
return Status::Error(
112+
StatusCode::INVALID_ARGUMENT,
113+
taskName_ + " task table is full");
62114
}
63115

64116
/// Looks up the context registered under taskId.
///
/// Probes the slot table linearly from the hashed start index, wrapping with
/// slotMask_. No mutex of this class is taken.
///
/// @param taskId Id to look up.
/// @return The stored context, or nullptr when taskId is kInvalidTaskId or no
///         READY slot currently holds it.
std::shared_ptr<Context> Get(TaskId taskId)
{
    if (taskId == kInvalidTaskId) {
        return nullptr;
    }

    const auto start = Hash(taskId) & slotMask_;
    const auto capacity = slots_.size();

    // A non-READY slot does not end the search, so a miss visits every slot.
    for (std::size_t probe = 0; probe < capacity; ++probe) {
        auto& slot = slots_[(start + probe) & slotMask_];

        const auto state1 = slot.state.load(std::memory_order_acquire);
        if (state1 != SlotState::READY) {
            continue;
        }

        const auto id1 = slot.taskId.load(std::memory_order_acquire);
        if (id1 != taskId) {
            continue;
        }

        auto ptr = AtomicLoadCtx(slot, std::memory_order_acquire);
        if (!ptr) {
            continue;
        }

        // Double-check to avoid returning a ctx from a reused slot: the slot
        // must still be READY with the same id, and the context itself must
        // carry the requested id.
        const auto id2 = slot.taskId.load(std::memory_order_acquire);
        const auto state2 = slot.state.load(std::memory_order_acquire);

        if (state2 == SlotState::READY &&
            id2 == taskId &&
            ptr->taskId == taskId) {
            return ptr;
        }
    }

    return nullptr;
}
71156

72157
std::vector<std::shared_ptr<Context>> GetAll()
73158
{
74-
std::lock_guard<std::mutex> lock(mutex_);
159+
std::lock_guard<std::mutex> lock(mu_);
75160
std::vector<std::shared_ptr<Context>> tasks;
76161
tasks.reserve(tasks_.size());
77162
for (const auto& item : tasks_) { tasks.emplace_back(item.second); }
@@ -80,21 +165,114 @@ class TaskManagerBase {
80165

81166
/// Removes the task registered under taskId and frees its slot.
///
/// Probes the slot table linearly from the hashed start index. A matching READY
/// slot is claimed with a READY -> REMOVING compare-exchange, cleared, and then
/// released as EMPTY.
///
/// @param taskId Id of the task to remove.
/// @return Status::OK() when a slot was cleared; a TASK_NOT_FOUND error when
///         taskId is kInvalidTaskId or no READY slot holding it could be claimed.
Status Remove(TaskId taskId)
{
    if (taskId == kInvalidTaskId) {
        return Status::Error(
            StatusCode::TASK_NOT_FOUND,
            taskName_ + " task not found");
    }

    const auto start = Hash(taskId) & slotMask_;
    const auto capacity = slots_.size();

    for (std::size_t probe = 0; probe < capacity; ++probe) {
        auto& slot = slots_[(start + probe) & slotMask_];

        const auto state = slot.state.load(std::memory_order_acquire);
        if (state != SlotState::READY) {
            continue;  // Only process slots in READY state
        }

        const auto id = slot.taskId.load(std::memory_order_acquire);
        if (id != taskId) {
            continue;
        }

        std::uint8_t expected = SlotState::READY;
        if (!slot.state.compare_exchange_strong(
                expected,
                SlotState::REMOVING,
                std::memory_order_acq_rel,
                std::memory_order_acquire)) {
            continue;  // CAS failed, continue probing
        }

        // The slot is now held in REMOVING: clear the payload first and publish
        // EMPTY last.
        AtomicStoreCtx(slot, std::shared_ptr<Context>{}, std::memory_order_release);
        slot.taskId.store(kInvalidTaskId, std::memory_order_release);
        slot.state.store(SlotState::EMPTY, std::memory_order_release);

        return Status::OK();
    }

    return Status::Error(
        StatusCode::TASK_NOT_FOUND,
        taskName_ + " task not found");
}
210+
211+
private:
212+
// Lifecycle states of a slot, stored in Slot::state.
struct SlotState {
    static constexpr std::uint8_t EMPTY = 0;     // free; Create may claim it
    static constexpr std::uint8_t WRITING = 1;   // claimed by Create, contents not yet published
    static constexpr std::uint8_t READY = 2;     // published; the only state Get and Remove act on
    static constexpr std::uint8_t REMOVING = 3;  // claimed by Remove, being cleared
};
218+
// Ensure each Slot is aligned to 64 bytes to avoid False Sharing
219+
struct alignas(64) Slot {
220+
std::atomic<std::uint8_t> state{SlotState::EMPTY};
221+
std::atomic<TaskId> taskId{kInvalidTaskId};
222+
223+
// Use atomic_load/atomic_store free functions for shared_ptr.
224+
// This avoids requiring C++20 std::atomic<std::shared_ptr<T>>.
225+
std::shared_ptr<Context> ctx;
226+
};
227+
228+
private:
229+
static std::size_t NormalizeSlotCount(std::size_t n)
230+
{
231+
n = std::max<std::size_t>(n, kMinSlotCount);
232+
233+
std::size_t power = 1;
234+
while (power < n) {
235+
if (power > (std::numeric_limits<std::size_t>::max() >> 1)) {
236+
return power;
237+
}
238+
power <<= 1;
87239
}
88-
return Status::OK();
240+
241+
return power;
242+
}
243+
244+
/// Hashes a task id with std::hash; callers mask the result with slotMask_
/// to obtain the probe start index.
static std::size_t Hash(TaskId taskId)
{
    return std::hash<TaskId>{}(taskId);
}
248+
249+
// Atomically load shared_ptr<Context> from Slot, ensuring thread safety
250+
static std::shared_ptr<Context> AtomicLoadCtx(
251+
const Slot& slot,
252+
std::memory_order order)
253+
{
254+
return std::atomic_load_explicit(&slot.ctx, order);
255+
}
256+
257+
// Atomically store task context sharedCtx into slot
258+
static void AtomicStoreCtx(
259+
Slot& slot,
260+
std::shared_ptr<Context> ptr,
261+
std::memory_order order)
262+
{
263+
std::atomic_store_explicit(&slot.ctx, std::move(ptr), order);
89264
}
90265

91266
private:
92267
State initialState_;
93268
std::string taskName_;
94269
std::atomic<TaskId> nextTaskId_{1};
95270
// TODO: consider using a lock-free structure !
96-
std::mutex mutex_;
271+
std::mutex mu_;
97272
std::unordered_map<TaskId, std::shared_ptr<Context>> tasks_;
273+
274+
std::vector<Slot> slots_;
275+
std::size_t slotMask_{0};
98276
};
99277

100-
} // namespace UC::ASU
278+
} // namespace UC::ASU

0 commit comments

Comments
 (0)