Skip to content

Commit 116cc61

Browse files
author
shichangzhang064
committed
feat: bridge TE wait via poll executor for coroutine-friendly paths
1 parent 92f0829 commit 116cc61

7 files changed

Lines changed: 188 additions & 81 deletions

File tree

mooncake-integration/store/store_py.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,7 @@ PYBIND11_MODULE(store, m) {
10451045
size_t async_sender_thread_count = 0,
10461046
size_t async_max_batch_size = 2000,
10471047
size_t async_route_queue_size = 0,
1048+
size_t te_async_poll_worker_num = 32,
10481049
const py::object& engine = py::none()) {
10491050
auto& resource_tracker = ResourceTracker::getInstance();
10501051
self.use_dummy_client_ = false;
@@ -1070,7 +1071,7 @@ PYBIND11_MODULE(store, m) {
10701071
local_transfer_mode, local_memcpy_async_worker_num,
10711072
metrics_port, enable_metrics_http, {},
10721073
async_sender_thread_count, async_max_batch_size,
1073-
async_route_queue_size);
1074+
async_route_queue_size, te_async_poll_worker_num);
10741075

10751076
auto ret = real_client->setup(config);
10761077
self.store_ = real_client;
@@ -1091,6 +1092,7 @@ PYBIND11_MODULE(store, m) {
10911092
py::arg("async_sender_thread_count") = 0,
10921093
py::arg("async_max_batch_size") = 2000,
10931094
py::arg("async_route_queue_size") = 0,
1095+
py::arg("te_async_poll_worker_num") = 32,
10941096
py::arg("engine") = py::none(),
10951097
"Setup the store in P2P architecture.")
10961098
.def(

mooncake-store/include/client_config_builder.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,11 @@ struct P2PClientConfig : RealClientConfigBase {
171171
// When local_transfer_mode == MEMCPY, the following parameter is used:
172172
// 0 means forbid async memcpy (fall back to synchronous).
173173
size_t local_memcpy_async_worker_num = 32;
174+
175+
// Worker threads for offloading TransferEngine batch polling (WaitAllTransferBatches)
176+
// in DataManager: local TE Put path and remote forward TE (co_await) paths. Independent
177+
// of local_transfer_mode. 0 means synchronous TE wait on the caller/coroutine thread.
178+
size_t te_async_poll_worker_num = 32;
174179
};
175180

176181
// ============================================================================
@@ -238,7 +243,8 @@ class ClientConfigBuilder {
238243
bool enable_metrics_http = true,
239244
const std::map<std::string, std::string>& labels = {},
240245
size_t async_sender_thread_count = 0,
241-
size_t async_max_batch_size = 2000, size_t async_route_queue_size = 0) {
246+
size_t async_max_batch_size = 2000, size_t async_route_queue_size = 0,
247+
size_t te_async_poll_worker_num = 32) {
242248
P2PClientConfig config;
243249
fill_real_client_config_base(
244250
config, local_hostname, metadata_connstring, protocol, rdma_devices,
@@ -251,6 +257,7 @@ class ClientConfigBuilder {
251257
config.route_cache_ttl_ms = route_cache_ttl_ms;
252258
config.local_transfer_mode =
253259
parse_p2p_local_transfer_mode(local_transfer_mode);
260+
config.te_async_poll_worker_num = te_async_poll_worker_num;
254261
if (config.local_transfer_mode == LocalTransferMode::MEMCPY) {
255262
config.local_memcpy_async_worker_num =
256263
local_memcpy_async_worker_num;

mooncake-store/include/data_manager.h

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include <thread>
1414
#include <unordered_map>
1515
#include <vector>
16+
#include <async_simple/coro/Lazy.h>
1617
#include <ylt/util/tl/expected.hpp>
1718
#include "async_memcpy_executor.h"
1819
#include "client_buffer.hpp"
@@ -51,6 +52,11 @@ struct LocalTransferConfig {
5152
// When mode == MEMCPY, the following parameters are used:
5253
// 0 means forbid async memcpy (fall back to synchronous).
5354
size_t local_memcpy_async_worker_num = 32;
55+
56+
// Dedicated worker threads to offload TransferEngine batch polling
57+
// (WaitAllTransferBatches) for local TE Put and remote forward TE paths.
58+
// Independent of `mode`. 0 keeps synchronous TE wait on the caller thread.
59+
size_t te_async_poll_worker_num = 32;
5460
};
5561

5662
/**
@@ -211,15 +217,17 @@ class DataManager {
211217
const UUID& pin_token);
212218

213219
/**
214-
* @brief TE transfer without tier DRAM staging (PrepareDRAM*).
220+
* @brief TE transfer without tier DRAM staging (PrepareDRAM*), with TE
221+
* completion polling offloaded when `te_async_poll_worker_num > 0`.
215222
*
216223
* Caller guarantees `local_transfer_base` covers a contiguous layout of
217-
* `total_size` bytes that is valid for TransferEngine (typically registered
218-
* DRAM). Used by forward RDMA paths where buffers are already TE-ready.
224+
* `total_size` bytes valid for TransferEngine (typically registered DRAM).
225+
* Used by forward RDMA paths where buffers are already TE-ready.
219226
*
220227
* @param opcode WRITE: local -> peer_buffers; READ: peer_buffers -> local
221228
*/
222-
tl::expected<void, ErrorCode> TransferWithTeNoTierStaging(
229+
async_simple::coro::Lazy<tl::expected<void, ErrorCode>>
230+
TransferWithTeNoTierStagingAsync(
223231
void* local_transfer_base, size_t total_size,
224232
const std::vector<RemoteBufferDesc>& peer_buffers,
225233
Transport::TransferRequest::OpCode opcode);
@@ -350,6 +358,19 @@ class DataManager {
350358
const std::vector<RemoteBufferDesc>& remote_buffers,
351359
Transport::TransferRequest::OpCode opcode);
352360

361+
tl::expected<std::vector<std::tuple<Transport::BatchID, size_t, std::string>>,
362+
ErrorCode>
363+
SubmitTeNoTierStagingBatches(
364+
void* local_transfer_base, size_t total_size,
365+
const std::vector<RemoteBufferDesc>& peer_buffers,
366+
Transport::TransferRequest::OpCode opcode);
367+
368+
/** Synchronous TE wait on caller thread (used when no poll executor). */
369+
tl::expected<void, ErrorCode> TransferWithTeNoTierStaging(
370+
void* local_transfer_base, size_t total_size,
371+
const std::vector<RemoteBufferDesc>& peer_buffers,
372+
Transport::TransferRequest::OpCode opcode);
373+
353374
tl::expected<
354375
std::vector<std::tuple<Transport::BatchID, size_t, std::string>>,
355376
ErrorCode>
@@ -517,6 +538,7 @@ class DataManager {
517538

518539
LocalTransferConfig local_transfer_config_;
519540
std::unique_ptr<AsyncMemcpyExecutor> async_memcpy_executor_;
541+
std::unique_ptr<AsyncMemcpyExecutor> te_poll_executor_;
520542
std::chrono::milliseconds lease_duration_;
521543
std::chrono::milliseconds lease_scan_interval_;
522544
std::atomic<bool> lease_scanner_stop_requested_{false};

mooncake-store/src/data_manager.cpp

Lines changed: 108 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
#include "utils/scoped_vlog_timer.h"
1919
#include "utils.h"
2020

21+
#include <async_simple/coro/FutureAwaiter.h>
22+
2123
namespace mooncake {
2224

2325
namespace {
@@ -148,6 +150,10 @@ DataManager::DataManager(std::unique_ptr<TieredBackend> tiered_backend,
148150
async_memcpy_executor_ = std::make_unique<AsyncMemcpyExecutor>(
149151
local_transfer_config_.local_memcpy_async_worker_num);
150152
}
153+
if (local_transfer_config_.te_async_poll_worker_num > 0) {
154+
te_poll_executor_ = std::make_unique<AsyncMemcpyExecutor>(
155+
local_transfer_config_.te_async_poll_worker_num);
156+
}
151157

152158
lease_duration_ = std::chrono::milliseconds(GetEnvOr<uint32_t>(
153159
"P2P_RPC_LEASE_DURATION_MS", kDefaultLeaseDurationMs));
@@ -162,8 +168,10 @@ DataManager::DataManager(std::unique_ptr<TieredBackend> tiered_backend,
162168
? "TE"
163169
: "MEMCPY")
164170
<< ", te_endpoint=" << local_transfer_config_.te_endpoint
165-
<< ", async_memcpy_workers="
171+
<< ", memcpy_async_workers="
166172
<< local_transfer_config_.local_memcpy_async_worker_num
173+
<< ", te_async_poll_workers="
174+
<< local_transfer_config_.te_async_poll_worker_num
167175
<< ", lease_duration_ms=" << lease_duration_.count()
168176
<< ", lease_scan_interval_ms=" << lease_scan_interval_.count();
169177
}
@@ -176,6 +184,9 @@ void DataManager::Stop() {
176184
if (async_memcpy_executor_) {
177185
async_memcpy_executor_->Shutdown();
178186
}
187+
if (te_poll_executor_) {
188+
te_poll_executor_->Shutdown();
189+
}
179190
if (tiered_backend_) {
180191
tiered_backend_->Stop();
181192
}
@@ -485,16 +496,6 @@ tl::expected<std::unique_ptr<TaskHandle<void>>, ErrorCode> DataManager::Put(
485496
return tl::unexpected(ErrorCode::INTERNAL_ERROR);
486497
}
487498

488-
// TODO: The returned CallableTaskHandle's WaitAsync() falls back to a
489-
// synchronous Wait() on the coroutine's current thread, because the
490-
// WaitAllTransferBatches() is a loop with no async completion notification.
491-
// Possible optimizations:
492-
// (1) run a polling coroutine on yalantinglibs coro_io's io_context (via
493-
// co_await coro_io::sleep_for(100us) + getTransferStatus), no new thread;
494-
// (2) introduce a lightweight timer service to bridge cv-poll to
495-
// async_simple::Promise;
496-
// (3) introduce a completion callback from transfer_engine itself.
497-
// Once any of these lands, switch the return type to FutureHandle.
498499
tl::expected<std::unique_ptr<TaskHandle<void>>, ErrorCode>
499500
DataManager::PutViaTe(std::string_view key, std::vector<Slice>& slices) {
500501
// using Te, treat local memory as remote memory
@@ -536,46 +537,51 @@ DataManager::PutViaTe(std::string_view key, std::vector<Slice>& slices) {
536537
return tl::unexpected(submit_result.error());
537538
}
538539

539-
return CallableTaskHandle<void>::Create(
540-
[this, ctx = std::move(*submit_result), alloc_handle, kctx,
541-
pending_write_token]() mutable -> tl::expected<void, ErrorCode> {
542-
ScopedVLogTimer timer(1, "DataManager::PutViaTe");
543-
timer.LogRequest("key=", kctx.key);
540+
auto te_phase = [this, ctx = std::move(*submit_result), alloc_handle, kctx,
541+
pending_write_token]() mutable -> tl::expected<void, ErrorCode> {
542+
ScopedVLogTimer timer(1, "DataManager::PutViaTe");
543+
timer.LogRequest("key=", kctx.key);
544+
545+
auto wait_result = WaitAllTransferBatches(ctx.transfer_batches);
546+
if (!wait_result) {
547+
LOG(ERROR) << "WaitAllTransferBatches failed"
548+
<< ", key=" << kctx.key
549+
<< ", error_code=" << toString(wait_result.error());
550+
(void)WriteRevokeInternal(kctx, pending_write_token);
551+
return tl::unexpected(wait_result.error());
552+
}
544553

545-
auto wait_result = WaitAllTransferBatches(ctx.transfer_batches);
546-
if (!wait_result) {
547-
LOG(ERROR) << "WaitAllTransferBatches failed"
554+
if (ctx.handle->loc.data.type != MemoryType::DRAM && ctx.temp_buffer) {
555+
auto& loc_data = ctx.handle->loc.data;
556+
auto copy_result = CopyFromDRAMBuffer(
557+
ctx.temp_buffer.get(),
558+
reinterpret_cast<void*>(loc_data.buffer->data()),
559+
loc_data.type, loc_data.buffer->size(), ctx.handle->backend);
560+
if (!copy_result) {
561+
LOG(ERROR) << "CopyFromDRAMBuffer failed"
548562
<< ", key=" << kctx.key
549-
<< ", error_code=" << toString(wait_result.error());
563+
<< ", error_code=" << toString(copy_result.error());
550564
(void)WriteRevokeInternal(kctx, pending_write_token);
551-
return tl::unexpected(wait_result.error());
565+
return tl::unexpected(copy_result.error());
552566
}
567+
}
553568

554-
if (ctx.handle->loc.data.type != MemoryType::DRAM &&
555-
ctx.temp_buffer) {
556-
auto& loc_data = ctx.handle->loc.data;
557-
auto copy_result = CopyFromDRAMBuffer(
558-
ctx.temp_buffer.get(),
559-
reinterpret_cast<void*>(loc_data.buffer->data()),
560-
loc_data.type, loc_data.buffer->size(),
561-
ctx.handle->backend);
562-
if (!copy_result) {
563-
LOG(ERROR)
564-
<< "CopyFromDRAMBuffer failed"
565-
<< ", key=" << kctx.key
566-
<< ", error_code=" << toString(copy_result.error());
567-
(void)WriteRevokeInternal(kctx, pending_write_token);
568-
return tl::unexpected(copy_result.error());
569-
}
570-
}
569+
auto commit_result = WriteCommitInternal(kctx, pending_write_token);
570+
if (!commit_result) {
571+
return tl::unexpected(commit_result.error());
572+
}
573+
timer.LogResponse("error_code=", ErrorCode::OK);
574+
return {};
575+
};
571576

572-
auto commit_result = WriteCommitInternal(kctx, pending_write_token);
573-
if (!commit_result) {
574-
return tl::unexpected(commit_result.error());
575-
}
576-
timer.LogResponse("error_code=", ErrorCode::OK);
577-
return {};
578-
});
577+
if (te_poll_executor_) {
578+
auto future = te_poll_executor_
579+
->SubmitSingleTask<tl::expected<void, ErrorCode>>(
580+
std::move(te_phase));
581+
return FutureHandle<void>::Create(std::shared_ptr<void>{},
582+
std::move(future));
583+
}
584+
return CallableTaskHandle<void>::Create(std::move(te_phase));
579585
}
580586

581587
tl::expected<std::unique_ptr<TaskHandle<void>>, ErrorCode>
@@ -756,19 +762,27 @@ tl::expected<ReadTaskHandle, ErrorCode> DataManager::BuildDataCopierViaTe(
756762

757763
ReadTaskHandle res;
758764
res.data_size = static_cast<int64_t>(source_size);
759-
res.task_handle = CallableTaskHandle<void>::Create(
760-
[this, ctx = std::move(submit_result.value()),
761-
h = handle]() mutable -> tl::expected<void, ErrorCode> {
762-
ScopedVLogTimer timer(1, "DataManager::BuildDataCopierViaTe");
763-
auto wait_result = WaitAllTransferBatches(ctx.transfer_batches);
764-
if (!wait_result) {
765-
LOG(ERROR) << "Failed to wait TE read transfer, error_code="
766-
<< wait_result.error();
767-
return tl::unexpected(wait_result.error());
768-
}
769-
timer.LogResponse("error_code=", ErrorCode::OK);
770-
return {};
771-
});
765+
auto te_wait = [this, ctx = std::move(submit_result.value()),
766+
h = handle]() mutable -> tl::expected<void, ErrorCode> {
767+
ScopedVLogTimer timer(1, "DataManager::BuildDataCopierViaTe");
768+
auto wait_result = WaitAllTransferBatches(ctx.transfer_batches);
769+
if (!wait_result) {
770+
LOG(ERROR) << "Failed to wait TE read transfer, error_code="
771+
<< wait_result.error();
772+
return tl::unexpected(wait_result.error());
773+
}
774+
timer.LogResponse("error_code=", ErrorCode::OK);
775+
return {};
776+
};
777+
if (te_poll_executor_) {
778+
auto future = te_poll_executor_
779+
->SubmitSingleTask<tl::expected<void, ErrorCode>>(
780+
std::move(te_wait));
781+
res.task_handle = FutureHandle<void>::Create(std::shared_ptr<void>{},
782+
std::move(future));
783+
} else {
784+
res.task_handle = CallableTaskHandle<void>::Create(std::move(te_wait));
785+
}
772786
return res;
773787
}
774788

@@ -1375,7 +1389,9 @@ DataManager::SubmitTeTransferBatches(
13751389
return submitted_batches;
13761390
}
13771391

1378-
tl::expected<void, ErrorCode> DataManager::TransferWithTeNoTierStaging(
1392+
tl::expected<std::vector<std::tuple<Transport::BatchID, size_t, std::string>>,
1393+
ErrorCode>
1394+
DataManager::SubmitTeNoTierStagingBatches(
13791395
void* local_transfer_base, size_t total_size,
13801396
const std::vector<RemoteBufferDesc>& peer_buffers,
13811397
Transport::TransferRequest::OpCode opcode) {
@@ -1394,8 +1410,16 @@ tl::expected<void, ErrorCode> DataManager::TransferWithTeNoTierStaging(
13941410
LOG(ERROR) << "TransferEngine not initialized";
13951411
return tl::make_unexpected(ErrorCode::INTERNAL_ERROR);
13961412
}
1397-
auto batches = SubmitTeTransferBatches(local_transfer_base, total_size,
1398-
peer_buffers, opcode);
1413+
return SubmitTeTransferBatches(local_transfer_base, total_size, peer_buffers,
1414+
opcode);
1415+
}
1416+
1417+
tl::expected<void, ErrorCode> DataManager::TransferWithTeNoTierStaging(
1418+
void* local_transfer_base, size_t total_size,
1419+
const std::vector<RemoteBufferDesc>& peer_buffers,
1420+
Transport::TransferRequest::OpCode opcode) {
1421+
auto batches = SubmitTeNoTierStagingBatches(
1422+
local_transfer_base, total_size, peer_buffers, opcode);
13991423
if (!batches) {
14001424
return tl::unexpected(batches.error());
14011425
}
@@ -1409,6 +1433,29 @@ tl::expected<void, ErrorCode> DataManager::TransferWithTeNoTierStaging(
14091433
return {};
14101434
}
14111435

1436+
async_simple::coro::Lazy<tl::expected<void, ErrorCode>>
1437+
DataManager::TransferWithTeNoTierStagingAsync(
1438+
void* local_transfer_base, size_t total_size,
1439+
const std::vector<RemoteBufferDesc>& peer_buffers,
1440+
Transport::TransferRequest::OpCode opcode) {
1441+
if (!te_poll_executor_) {
1442+
co_return TransferWithTeNoTierStaging(local_transfer_base, total_size,
1443+
peer_buffers, opcode);
1444+
}
1445+
auto batches = SubmitTeNoTierStagingBatches(
1446+
local_transfer_base, total_size, peer_buffers, opcode);
1447+
if (!batches) {
1448+
co_return tl::make_unexpected(batches.error());
1449+
}
1450+
auto batch_vec = std::move(batches.value());
1451+
auto fut =
1452+
te_poll_executor_->SubmitSingleTask<tl::expected<void, ErrorCode>>(
1453+
[this, batch_vec = std::move(batch_vec)]() mutable {
1454+
return WaitAllTransferBatches(batch_vec);
1455+
});
1456+
co_return co_await std::move(fut);
1457+
}
1458+
14121459
tl::expected<void, ErrorCode> DataManager::ValidateRemoteBuffers(
14131460
const std::vector<RemoteBufferDesc>& buffers) {
14141461
if (buffers.empty()) {

0 commit comments

Comments
 (0)