Skip to content

Commit 78c673f

Browse files
committed
Optimize ASU task manager with lock-free slots
1 parent 518c842 commit 78c673f

2 files changed

Lines changed: 952 additions & 25 deletions

File tree

ucm/transport/kv/asu/common/task_manager_base.h

Lines changed: 252 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,50 @@
2323
* */
2424
#pragma once
2525

26+
#include <algorithm>
2627
#include <atomic>
28+
#include <cstddef>
29+
#include <cstdint>
30+
#include <limits>
2731
#include <memory>
28-
#include <mutex>
2932
#include <string>
30-
#include <unordered_map>
33+
#include <type_traits>
3134
#include <utility>
35+
#include <vector>
3236
#include "asu_transport/types.h"
3337

3438
namespace UC::ASU {
3539

3640
template <typename Context, typename State>
3741
class TaskManagerBase {
3842
public:
39-
TaskManagerBase(State initialState, std::string taskName)
40-
: initialState_(initialState), taskName_(std::move(taskName))
43+
static constexpr std::size_t kMinSlotCount = 1024;
44+
static constexpr std::size_t kDefaultMaxInflightTasks = 4096;
45+
46+
static std::size_t RecommendSlotCount(std::size_t maxInflightTasks)
47+
{
48+
// Reserve extra slots for the configured maximum inflight workload.
49+
// For example: 4096 inflight tasks -> 8192 slots.
50+
const auto required = std::max<std::size_t>(kMinSlotCount, maxInflightTasks * 2);
51+
return NormalizeSlotCount(required);
52+
}
53+
54+
explicit TaskManagerBase(State initialState, std::string taskName,
55+
std::size_t maxInflightTasks = kDefaultMaxInflightTasks)
56+
: initialState_(initialState),
57+
taskName_(std::move(taskName)),
58+
slotIndexBits_(ComputeSlotIndexBits(RecommendSlotCount(maxInflightTasks))),
59+
slots_(RecommendSlotCount(maxInflightTasks)),
60+
slotMask_(slots_.size() - 1),
61+
freeListShift_(slotIndexBits_ + 1),
62+
freeListMask_(MakeLowBitsMask(slotIndexBits_ + 1)),
63+
freeListEnd_(slots_.size()),
64+
freeListHead_(PackFreeListHead(0, 0))
4165
{
66+
for (std::size_t i = 0; i + 1 < slots_.size(); ++i) {
67+
slots_[i].freeNext.store(i + 1, std::memory_order_relaxed);
68+
}
69+
slots_[slots_.size() - 1].freeNext.store(freeListEnd_, std::memory_order_relaxed);
4270
}
4371

4472
Status Submit(std::unique_ptr<Context> ctx, TaskId& taskId)
@@ -48,53 +76,252 @@ class TaskManagerBase {
4876
return Status::Error(StatusCode::INVALID_ARGUMENT, taskName_ + " task context is null");
4977
}
5078

79+
auto slotIndex = FreeListPop();
80+
if (slotIndex == freeListEnd_) {
81+
taskId = kInvalidTaskId;
82+
return Status::Error(StatusCode::INTERNAL_ERROR, taskName_ + " task table is full");
83+
}
84+
85+
auto& slot = slots_[slotIndex];
86+
87+
std::uint8_t expected = SlotState::EMPTY;
88+
if (!slot.state.compare_exchange_strong(expected, SlotState::WRITING,
89+
std::memory_order_acq_rel,
90+
std::memory_order_acquire)) {
91+
FreeListPush(slotIndex);
92+
taskId = kInvalidTaskId;
93+
return Status::Error(StatusCode::INTERNAL_ERROR, taskName_ + " task slot not empty");
94+
}
95+
96+
const auto generation = slot.generation.fetch_add(1, std::memory_order_relaxed) + 1;
97+
if (!CanEncodeGeneration(generation)) {
98+
AtomicStoreCtx(slot, std::shared_ptr<Context>{}, std::memory_order_release);
99+
slot.taskId.store(kInvalidTaskId, std::memory_order_release);
100+
slot.state.store(SlotState::EMPTY, std::memory_order_release);
101+
FreeListPush(slotIndex);
102+
taskId = kInvalidTaskId;
103+
return Status::Error(
104+
StatusCode::INTERNAL_ERROR,
105+
taskName_ + " task id generation overflow");
106+
}
107+
108+
const auto newTaskId = MakeTaskId(slotIndex, generation);
109+
ctx->state.store(initialState_, std::memory_order_release);
110+
ctx->taskId = newTaskId;
51111
auto sharedCtx = std::shared_ptr<Context>(std::move(ctx));
52-
sharedCtx->state.store(initialState_, std::memory_order_release);
53112

54-
std::lock_guard<std::mutex> lock(mutex_);
55-
do {
56-
taskId = nextTaskId_.fetch_add(1, std::memory_order_relaxed);
57-
} while (taskId == kInvalidTaskId || tasks_.find(taskId) != tasks_.end());
113+
AtomicStoreCtx(slot, sharedCtx, std::memory_order_release);
114+
slot.taskId.store(newTaskId, std::memory_order_release);
115+
slot.state.store(SlotState::READY, std::memory_order_release);
58116

59-
sharedCtx->taskId = taskId;
60-
tasks_.emplace(taskId, std::move(sharedCtx));
117+
taskId = newTaskId;
61118
return Status::OK();
62119
}
63120

64121
std::shared_ptr<Context> Get(TaskId taskId)
65122
{
66-
std::lock_guard<std::mutex> lock(mutex_);
67-
auto iter = tasks_.find(taskId);
68-
if (iter == tasks_.end()) { return nullptr; }
69-
return iter->second;
123+
if (taskId == kInvalidTaskId) { return nullptr; }
124+
125+
const auto slotIndex = static_cast<std::size_t>(ToTaskIdUInt(taskId) & slotMask_);
126+
auto& slot = slots_[slotIndex];
127+
128+
const auto state1 = slot.state.load(std::memory_order_acquire);
129+
if (state1 != SlotState::READY) { return nullptr; }
130+
131+
const auto id1 = slot.taskId.load(std::memory_order_acquire);
132+
if (id1 != taskId) { return nullptr; }
133+
134+
auto ptr = AtomicLoadCtx(slot, std::memory_order_acquire);
135+
if (!ptr) { return nullptr; }
136+
137+
const auto id2 = slot.taskId.load(std::memory_order_acquire);
138+
const auto state2 = slot.state.load(std::memory_order_acquire);
139+
if (state2 == SlotState::READY && id2 == taskId && ptr->taskId == taskId) { return ptr; }
140+
141+
return nullptr;
70142
}
71143

72144
std::vector<std::shared_ptr<Context>> GetAll()
73145
{
74-
std::lock_guard<std::mutex> lock(mutex_);
75146
std::vector<std::shared_ptr<Context>> tasks;
76-
tasks.reserve(tasks_.size());
77-
for (const auto& item : tasks_) { tasks.emplace_back(item.second); }
147+
for (const auto& slot : slots_) {
148+
if (slot.state.load(std::memory_order_acquire) != SlotState::READY) { continue; }
149+
150+
auto ctx = AtomicLoadCtx(slot, std::memory_order_acquire);
151+
if (!ctx) { continue; }
152+
153+
const auto taskId = slot.taskId.load(std::memory_order_acquire);
154+
const auto state = slot.state.load(std::memory_order_acquire);
155+
if (state == SlotState::READY && taskId == ctx->taskId) {
156+
tasks.emplace_back(std::move(ctx));
157+
}
158+
}
78159
return tasks;
79160
}
80161

81162
Status Remove(TaskId taskId)
82163
{
83-
std::lock_guard<std::mutex> lock(mutex_);
84-
auto erased = tasks_.erase(taskId);
85-
if (erased == 0) {
164+
if (taskId == kInvalidTaskId) {
86165
return Status::Error(StatusCode::TASK_NOT_FOUND, taskName_ + " task not found");
87166
}
167+
168+
const auto slotIndex = static_cast<std::size_t>(ToTaskIdUInt(taskId) & slotMask_);
169+
auto& slot = slots_[slotIndex];
170+
171+
std::uint8_t expected = SlotState::READY;
172+
if (!slot.state.compare_exchange_strong(
173+
expected,
174+
SlotState::REMOVING,
175+
std::memory_order_acq_rel,
176+
std::memory_order_acquire)) {
177+
return Status::Error(StatusCode::TASK_NOT_FOUND, taskName_ + " task not found");
178+
}
179+
180+
if (slot.taskId.load(std::memory_order_acquire) != taskId) {
181+
slot.state.store(SlotState::READY, std::memory_order_release);
182+
return Status::Error(StatusCode::TASK_NOT_FOUND, taskName_ + " task not found");
183+
}
184+
185+
AtomicStoreCtx(slot, std::shared_ptr<Context>{}, std::memory_order_release);
186+
slot.taskId.store(kInvalidTaskId, std::memory_order_release);
187+
slot.state.store(SlotState::EMPTY, std::memory_order_release);
188+
189+
FreeListPush(slotIndex);
88190
return Status::OK();
89191
}
90192

193+
private:
194+
using TaskIdUInt = std::make_unsigned_t<TaskId>;
195+
196+
struct SlotState {
197+
static constexpr std::uint8_t EMPTY = 0;
198+
static constexpr std::uint8_t WRITING = 1;
199+
static constexpr std::uint8_t READY = 2;
200+
static constexpr std::uint8_t REMOVING = 3;
201+
};
202+
203+
struct alignas(64) Slot {
204+
std::atomic<std::uint8_t> state{SlotState::EMPTY};
205+
std::atomic<TaskIdUInt> generation{0};
206+
std::atomic<TaskId> taskId{kInvalidTaskId};
207+
std::shared_ptr<Context> ctx;
208+
std::atomic<std::size_t> freeNext{0};
209+
};
210+
211+
private:
212+
static TaskIdUInt ToTaskIdUInt(TaskId taskId) { return static_cast<TaskIdUInt>(taskId); }
213+
214+
static std::size_t NormalizeSlotCount(std::size_t n)
215+
{
216+
n = std::max<std::size_t>(n, kMinSlotCount);
217+
218+
std::size_t power = 1;
219+
while (power < n) {
220+
if (power > (std::numeric_limits<std::size_t>::max() >> 1)) { return power; }
221+
power <<= 1;
222+
}
223+
224+
return power;
225+
}
226+
227+
static std::size_t ComputeSlotIndexBits(std::size_t slotCount)
228+
{
229+
std::size_t bits = 0;
230+
for (auto s = slotCount; s > 1; s >>= 1) { ++bits; }
231+
return bits;
232+
}
233+
234+
static std::uint64_t MakeLowBitsMask(std::size_t bits)
235+
{
236+
if (bits >= std::numeric_limits<std::uint64_t>::digits) {
237+
return std::numeric_limits<std::uint64_t>::max();
238+
}
239+
return (1ULL << bits) - 1;
240+
}
241+
242+
bool CanEncodeGeneration(TaskIdUInt generation) const
243+
{
244+
constexpr std::size_t kTotalBits =
245+
static_cast<std::size_t>(std::numeric_limits<TaskIdUInt>::digits);
246+
247+
if (slotIndexBits_ >= kTotalBits) { return false; }
248+
249+
if (generation == 0) { return false; }
250+
251+
const auto generationBits = kTotalBits - slotIndexBits_;
252+
if (generationBits >= kTotalBits) { return true; }
253+
254+
const auto maxGeneration = (static_cast<TaskIdUInt>(1) << generationBits) - 1;
255+
return generation <= maxGeneration;
256+
}
257+
258+
TaskId MakeTaskId(std::size_t slotIndex, TaskIdUInt generation) const
259+
{
260+
const auto raw =
261+
(generation << slotIndexBits_) | static_cast<TaskIdUInt>(slotIndex & slotMask_);
262+
return static_cast<TaskId>(raw);
263+
}
264+
265+
static std::shared_ptr<Context> AtomicLoadCtx(const Slot& slot, std::memory_order order)
266+
{
267+
return std::atomic_load_explicit(&slot.ctx, order);
268+
}
269+
270+
static void AtomicStoreCtx(Slot& slot, std::shared_ptr<Context> ptr, std::memory_order order)
271+
{
272+
std::atomic_store_explicit(&slot.ctx, std::move(ptr), order);
273+
}
274+
275+
std::uint64_t PackFreeListHead(std::uint64_t generation, std::size_t index) const
276+
{
277+
return (generation << freeListShift_) | static_cast<std::uint64_t>(index);
278+
}
279+
280+
std::size_t FreeListPop()
281+
{
282+
auto oldHead = freeListHead_.load(std::memory_order_acquire);
283+
while (true) {
284+
const auto index = static_cast<std::size_t>(oldHead & freeListMask_);
285+
if (index == freeListEnd_) { return freeListEnd_; }
286+
287+
const auto nextIndex = slots_[index].freeNext.load(std::memory_order_acquire);
288+
const auto oldGen = oldHead >> freeListShift_;
289+
const auto newHead = PackFreeListHead(oldGen + 1, nextIndex);
290+
if (freeListHead_.compare_exchange_weak(oldHead, newHead, std::memory_order_acq_rel,
291+
std::memory_order_acquire)) {
292+
return index;
293+
}
294+
}
295+
}
296+
297+
void FreeListPush(std::size_t slotIndex)
298+
{
299+
auto oldHead = freeListHead_.load(std::memory_order_acquire);
300+
while (true) {
301+
slots_[slotIndex].freeNext.store(static_cast<std::size_t>(oldHead & freeListMask_),
302+
std::memory_order_release);
303+
304+
const auto oldGen = oldHead >> freeListShift_;
305+
const auto newHead = PackFreeListHead(oldGen + 1, slotIndex);
306+
if (freeListHead_.compare_exchange_weak(oldHead, newHead, std::memory_order_release,
307+
std::memory_order_acquire)) {
308+
return;
309+
}
310+
}
311+
}
312+
91313
private:
92314
State initialState_;
93315
std::string taskName_;
94-
std::atomic<TaskId> nextTaskId_{1};
95-
// TODO: consider using a lock-free structure !
96-
std::mutex mutex_;
97-
std::unordered_map<TaskId, std::shared_ptr<Context>> tasks_;
316+
std::size_t slotIndexBits_{0};
317+
318+
std::vector<Slot> slots_;
319+
std::size_t slotMask_{0};
320+
321+
std::size_t freeListShift_{0};
322+
std::uint64_t freeListMask_{0};
323+
std::size_t freeListEnd_{0};
324+
std::atomic<std::uint64_t> freeListHead_{0};
98325
};
99326

100327
} // namespace UC::ASU

0 commit comments

Comments
 (0)