From b7f75993f2ad20ce8faf01150a556adb5b0e63d9 Mon Sep 17 00:00:00 2001 From: shichangzhang064 Date: Thu, 7 May 2026 20:25:42 +0800 Subject: [PATCH 01/14] [Store] Implement 3-phase read/write contexts in DataManager --- mooncake-store/include/data_manager.h | 149 +++- mooncake-store/src/data_manager.cpp | 748 ++++++++++++++++++--- mooncake-store/tests/data_manager_test.cpp | 203 ++++++ 3 files changed, 990 insertions(+), 110 deletions(-) diff --git a/mooncake-store/include/data_manager.h b/mooncake-store/include/data_manager.h index a34f67dbab..c06fe6b531 100644 --- a/mooncake-store/include/data_manager.h +++ b/mooncake-store/include/data_manager.h @@ -1,11 +1,18 @@ #pragma once +#include +#include +#include +#include +#include +#include +#include #include -#include #include #include -#include -#include +#include +#include +#include #include #include "async_memcpy_executor.h" #include "client_buffer.hpp" @@ -58,6 +65,20 @@ class DataManager { friend class DataManagerTest; public: + using TimePoint = std::chrono::time_point; + + struct PreWriteResult { + RemoteBufferDesc remote_buffer; + uint64_t deadline_ms = 0; + UUID pending_write_token{0, 0}; + }; + + struct PinKeyResult { + RemoteBufferDesc remote_buffer; + uint64_t deadline_ms = 0; + UUID pin_token{0, 0}; + }; + /** * @brief Constructor * @param tiered_backend Unique pointer to TieredBackend instance (takes @@ -70,14 +91,9 @@ class DataManager { size_t lock_shard_count = 1024, const LocalTransferConfig& local_transfer_config = {}); - void Stop() { - if (async_memcpy_executor_) { - async_memcpy_executor_->Shutdown(); - } - if (tiered_backend_) { - tiered_backend_->Stop(); - } - } + ~DataManager(); + + void Stop(); /** * @brief Cleanup: delegates to TieredBackend::Destroy(). @@ -186,6 +202,19 @@ class DataManager { std::string_view key, const std::vector& src_buffers, std::optional tier_id = std::nullopt); + tl::expected PreWrite( + std::string_view key, size_t size_bytes, + std::optional tier_id = std::nullopt); + + tl::expected WriteCommit( + std::string_view key, const UUID& pending_write_token); + + tl::expected PinKey( + std::string_view key, std::optional tier_id = std::nullopt); + + tl::expected UnPinKey(std::string_view key, + const UUID& pin_token); + // ================================================================ // Utilities // ================================================================ @@ -212,6 +241,34 @@ class DataManager { std::optional tier_id = std::nullopt) const; private: + struct KeyCtx { + std::string_view key; + std::string key_string; + size_t hash = 0; + size_t pending_write_shard_idx = 0; + size_t pinned_key_shard_idx = 0; + }; + + KeyCtx BuildKeyCtx(std::string_view key) const; + PendingWriteShard& GetPendingWriteShard(const KeyCtx& ctx); + PinnedKeyShard& GetPinnedKeyShard(const KeyCtx& ctx); + + tl::expected PreWriteInternal( + const KeyCtx& ctx, size_t size_bytes, std::optional tier_id); + tl::expected WriteCommitInternal(const KeyCtx& ctx, + const UUID& pending_write_token); + tl::expected PinKeyInternal( + const KeyCtx& ctx, std::optional tier_id); + tl::expected UnPinKeyInternal(const KeyCtx& ctx, + const UUID& pin_token); + + tl::expected LookupPendingWriteHandleInternal( + const KeyCtx& ctx, const UUID& pending_write_token); + tl::expected LookupPinnedKeyHandleInternal( + const KeyCtx& ctx, const UUID& pin_token); + void AbortPendingWriteInternal(const KeyCtx& ctx, + const UUID& pending_write_token); + std::shared_mutex& GetKeyLock(std::string_view key) { size_t hash = std::hash{}(key); return lock_shards_[hash % lock_shard_count_]; @@ -360,6 +417,68 @@ class DataManager { // Wait for all tasks to reach a terminal state, then free the batch. void CancelBatchTETask(Transport::BatchID batch_id, size_t num_tasks); + using OrderedDeadlineList = std::list>; + using OrderedDeadlineListIt = OrderedDeadlineList::iterator; + + struct PendingWriteRecord { + UUID pending_write_token{0, 0}; + TimePoint deadline{}; + AllocationHandle handle; + OrderedDeadlineListIt list_it; + }; + + struct PinnedKeyRecord { + UUID pin_token{0, 0}; + TimePoint deadline{}; + AllocationHandle handle; + uint32_t ref_count = 1; + OrderedDeadlineListIt list_it; + }; + + struct PendingWriteShard { + mutable std::shared_mutex mutex; + std::unordered_map by_key; + OrderedDeadlineList ordered_list; + }; + + struct PinnedKeyShard { + mutable std::shared_mutex mutex; + std::unordered_map by_key; + OrderedDeadlineList ordered_list; + }; + + size_t HashKey(std::string_view key) const; + PendingWriteShard& GetPendingWriteShard(std::string_view key); + PinnedKeyShard& GetPinnedKeyShard(std::string_view key); + const std::chrono::milliseconds& lease_duration() const { + return lease_duration_; + } + uint64_t TimePointToDeadlineMs(TimePoint deadline) const; + TimePoint DeadlineMsToTimePoint(uint64_t deadline_ms) const; + bool IsExpired(TimePoint deadline) const; + RemoteBufferDesc BuildRemoteBufferDesc(const AllocationHandle& handle) const; + void LeaseScannerMain(); + void ShutdownLeaseScanner(); + size_t ScanExpiredPendingWrites(PendingWriteShard& shard, TimePoint now); + size_t ScanExpiredPinnedKeys(PinnedKeyShard& shard, TimePoint now); + bool ErasePendingWriteLocked(PendingWriteShard& shard, + const std::string& key); + bool ErasePinnedKeyLocked(PinnedKeyShard& shard, const std::string& key); + bool RemoveExpiredPendingWriteLocked(PendingWriteShard& shard, + const std::string& key, + TimePoint now); + bool RemoveExpiredPinnedKeyLocked(PinnedKeyShard& shard, + const std::string& key, TimePoint now); + void TouchOrderedDeadlineNode(OrderedDeadlineList& ordered_list, + OrderedDeadlineListIt it, + const std::string& key, TimePoint deadline); + + tl::expected LookupPendingWriteHandle( + std::string_view key, const UUID& pending_write_token); + tl::expected LookupPinnedKeyHandle( + std::string_view key, const UUID& pin_token); + void AbortPendingWrite(std::string_view key, const UUID& pending_write_token); + private: std::unique_ptr tiered_backend_; // Owned by DataManager std::shared_ptr transfer_engine_; // Shared with Client @@ -369,6 +488,8 @@ class DataManager { // (default: 1024) size_t lock_shard_count_; std::vector lock_shards_; + std::vector pending_write_shards_; + std::vector pinned_key_shards_; // Callback for rectifying stale read routes std::function)> @@ -376,6 +497,12 @@ class DataManager { LocalTransferConfig local_transfer_config_; std::unique_ptr async_memcpy_executor_; + std::chrono::milliseconds lease_duration_; + std::chrono::milliseconds lease_scan_interval_; + std::atomic lease_scanner_stop_requested_{false}; + std::condition_variable lease_scanner_cv_; + std::mutex lease_scanner_mutex_; + std::thread lease_scanner_thread_; }; } // namespace mooncake diff --git a/mooncake-store/src/data_manager.cpp b/mooncake-store/src/data_manager.cpp index 87fb89bd54..7daf43a4f6 100644 --- a/mooncake-store/src/data_manager.cpp +++ b/mooncake-store/src/data_manager.cpp @@ -22,6 +22,9 @@ namespace mooncake { namespace { +constexpr uint32_t kDefaultLeaseDurationMs = 5000; +constexpr uint32_t kDefaultLeaseScanIntervalMs = 1000; + struct LocalCopyPlan { AllocationHandle source_handle; const char* source_ptr = nullptr; @@ -112,6 +115,10 @@ ErrorCode ExecuteLocalCopyPlan(const LocalCopyPlan& plan) { return ErrorCode::OK; } +bool IsZeroUuid(const UUID& uuid) { + return uuid.first == 0 && uuid.second == 0; +} + } // namespace // ================================================================ @@ -126,6 +133,8 @@ DataManager::DataManager(std::unique_ptr tiered_backend, transfer_engine_(transfer_engine), lock_shard_count_(lock_shard_count > 0 ? lock_shard_count : 1024), lock_shards_(lock_shard_count_), + pending_write_shards_(lock_shard_count_), + pinned_key_shards_(lock_shard_count_), local_transfer_config_(local_transfer_config) { if (!tiered_backend_) { LOG(FATAL) << "TieredBackend cannot be null"; @@ -140,6 +149,14 @@ DataManager::DataManager(std::unique_ptr tiered_backend, local_transfer_config_.local_memcpy_async_worker_num); } + lease_duration_ = std::chrono::milliseconds( + GetEnvOr("P2P_RPC_LEASE_DURATION_MS", + kDefaultLeaseDurationMs)); + lease_scan_interval_ = std::chrono::milliseconds( + std::max(1, GetEnvOr("P2P_RPC_LEASE_SCAN_INTERVAL", + kDefaultLeaseScanIntervalMs))); + lease_scanner_thread_ = std::thread(&DataManager::LeaseScannerMain, this); + LOG(INFO) << "DataManager initialized with " << lock_shard_count_ << " lock shards, local_transfer_mode=" << (local_transfer_config_.mode == LocalTransferMode::TE @@ -147,7 +164,294 @@ DataManager::DataManager(std::unique_ptr tiered_backend, : "MEMCPY") << ", te_endpoint=" << local_transfer_config_.te_endpoint << ", async_memcpy_workers=" - << local_transfer_config_.local_memcpy_async_worker_num; + << local_transfer_config_.local_memcpy_async_worker_num + << ", lease_duration_ms=" << lease_duration_.count() + << ", lease_scan_interval_ms=" << lease_scan_interval_.count(); +} + +DataManager::~DataManager() { Stop(); } + +void DataManager::Stop() { + ShutdownLeaseScanner(); + if (async_memcpy_executor_) { + async_memcpy_executor_->Shutdown(); + } + if (tiered_backend_) { + tiered_backend_->Stop(); + } +} + +size_t DataManager::HashKey(std::string_view key) const { + return std::hash{}(key); +} + +DataManager::KeyCtx DataManager::BuildKeyCtx(std::string_view key) const { + KeyCtx ctx; + ctx.key = key; + ctx.key_string = std::string(key); + ctx.hash = HashKey(key); + ctx.pending_write_shard_idx = + pending_write_shards_.empty() + ? 0 + : (ctx.hash % pending_write_shards_.size()); + ctx.pinned_key_shard_idx = + pinned_key_shards_.empty() ? 0 : (ctx.hash % pinned_key_shards_.size()); + return ctx; +} + +DataManager::PendingWriteShard& DataManager::GetPendingWriteShard( + const KeyCtx& ctx) { + return pending_write_shards_[ctx.pending_write_shard_idx]; +} + +DataManager::PinnedKeyShard& DataManager::GetPinnedKeyShard(const KeyCtx& ctx) { + return pinned_key_shards_[ctx.pinned_key_shard_idx]; +} + +DataManager::PendingWriteShard& DataManager::GetPendingWriteShard( + std::string_view key) { + return pending_write_shards_[HashKey(key) % pending_write_shards_.size()]; +} + +DataManager::PinnedKeyShard& DataManager::GetPinnedKeyShard( + std::string_view key) { + return pinned_key_shards_[HashKey(key) % pinned_key_shards_.size()]; +} + +uint64_t DataManager::TimePointToDeadlineMs(TimePoint deadline) const { + const auto remaining = + std::max(std::chrono::milliseconds::zero(), + std::chrono::duration_cast( + deadline - std::chrono::steady_clock::now())); + const auto system_deadline = + std::chrono::system_clock::now() + remaining; + return static_cast( + std::chrono::duration_cast( + system_deadline.time_since_epoch()) + .count()); +} + +DataManager::TimePoint DataManager::DeadlineMsToTimePoint( + uint64_t deadline_ms) const { + const auto system_deadline = + std::chrono::system_clock::time_point(std::chrono::milliseconds( + static_cast(deadline_ms))); + const auto remaining = + std::chrono::duration_cast( + system_deadline - std::chrono::system_clock::now()); + if (remaining <= std::chrono::milliseconds::zero()) { + return std::chrono::steady_clock::now(); + } + return std::chrono::steady_clock::now() + remaining; +} + +bool DataManager::IsExpired(TimePoint deadline) const { + return deadline <= std::chrono::steady_clock::now(); +} + +RemoteBufferDesc DataManager::BuildRemoteBufferDesc( + const AllocationHandle& handle) const { + const auto& loc_data = handle->loc.data; + RemoteBufferDesc remote_buffer; + remote_buffer.segment_endpoint = local_transfer_config_.te_endpoint; + remote_buffer.addr = + reinterpret_cast(loc_data.buffer ? loc_data.buffer->data() + : nullptr); + remote_buffer.size = loc_data.buffer ? loc_data.buffer->size() : 0; + return remote_buffer; +} + +void DataManager::TouchOrderedDeadlineNode(OrderedDeadlineList& ordered_list, + OrderedDeadlineListIt it, + const std::string& key, + TimePoint deadline) { + it->first = key; + it->second = deadline; + ordered_list.splice(ordered_list.end(), ordered_list, it); +} + +bool DataManager::ErasePendingWriteLocked(PendingWriteShard& shard, + const std::string& key) { + auto it = shard.by_key.find(key); + if (it == shard.by_key.end()) { + return false; + } + shard.ordered_list.erase(it->second.list_it); + shard.by_key.erase(it); + return true; +} + +bool DataManager::ErasePinnedKeyLocked(PinnedKeyShard& shard, + const std::string& key) { + auto it = shard.by_key.find(key); + if (it == shard.by_key.end()) { + return false; + } + shard.ordered_list.erase(it->second.list_it); + shard.by_key.erase(it); + return true; +} + +bool DataManager::RemoveExpiredPendingWriteLocked(PendingWriteShard& shard, + const std::string& key, + TimePoint now) { + auto it = shard.by_key.find(key); + if (it == shard.by_key.end() || it->second.deadline > now) { + return false; + } + shard.ordered_list.erase(it->second.list_it); + shard.by_key.erase(it); + return true; +} + +bool DataManager::RemoveExpiredPinnedKeyLocked(PinnedKeyShard& shard, + const std::string& key, + TimePoint now) { + auto it = shard.by_key.find(key); + if (it == shard.by_key.end() || it->second.deadline > now) { + return false; + } + shard.ordered_list.erase(it->second.list_it); + shard.by_key.erase(it); + return true; +} + +size_t DataManager::ScanExpiredPendingWrites(PendingWriteShard& shard, + TimePoint now) { + size_t removed = 0; + while (!shard.ordered_list.empty()) { + auto list_it = shard.ordered_list.begin(); + if (list_it->second > now) { + break; + } + auto record_it = shard.by_key.find(list_it->first); + if (record_it == shard.by_key.end() || + record_it->second.list_it != list_it) { + shard.ordered_list.erase(list_it); + continue; + } + shard.by_key.erase(record_it); + shard.ordered_list.erase(list_it); + ++removed; + } + return removed; +} + +size_t DataManager::ScanExpiredPinnedKeys(PinnedKeyShard& shard, TimePoint now) { + size_t removed = 0; + while (!shard.ordered_list.empty()) { + auto list_it = shard.ordered_list.begin(); + if (list_it->second > now) { + break; + } + auto record_it = shard.by_key.find(list_it->first); + if (record_it == shard.by_key.end() || + record_it->second.list_it != list_it) { + shard.ordered_list.erase(list_it); + continue; + } + shard.by_key.erase(record_it); + shard.ordered_list.erase(list_it); + ++removed; + } + return removed; +} + +tl::expected DataManager::LookupPendingWriteHandle( + std::string_view key, const UUID& pending_write_token) { + return LookupPendingWriteHandleInternal(BuildKeyCtx(key), + pending_write_token); +} + +tl::expected +DataManager::LookupPendingWriteHandleInternal(const KeyCtx& ctx, + const UUID& pending_write_token) { + const auto now = std::chrono::steady_clock::now(); + auto& shard = GetPendingWriteShard(ctx); + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(ctx.key_string); + if (it == shard.by_key.end()) { + return tl::unexpected(ErrorCode::OBJECT_NOT_FOUND); + } + if (it->second.deadline <= now) { + return tl::unexpected(ErrorCode::LEASE_EXPIRED); + } + if (it->second.pending_write_token != pending_write_token) { + return tl::unexpected(ErrorCode::INVALID_WRITE); + } + return it->second.handle; +} + +tl::expected DataManager::LookupPinnedKeyHandle( + std::string_view key, const UUID& pin_token) { + return LookupPinnedKeyHandleInternal(BuildKeyCtx(key), pin_token); +} + +tl::expected DataManager::LookupPinnedKeyHandleInternal( + const KeyCtx& ctx, const UUID& pin_token) { + const auto now = std::chrono::steady_clock::now(); + auto& shard = GetPinnedKeyShard(ctx); + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(ctx.key_string); + if (it == shard.by_key.end()) { + return tl::unexpected(ErrorCode::OBJECT_NOT_FOUND); + } + if (it->second.deadline <= now) { + return tl::unexpected(ErrorCode::LEASE_EXPIRED); + } + if (it->second.pin_token != pin_token) { + return tl::unexpected(ErrorCode::INVALID_READ); + } + return it->second.handle; +} + +void DataManager::AbortPendingWrite(std::string_view key, + const UUID& pending_write_token) { + AbortPendingWriteInternal(BuildKeyCtx(key), pending_write_token); +} + +void DataManager::AbortPendingWriteInternal(const KeyCtx& ctx, + const UUID& pending_write_token) { + if (ctx.key.empty() || IsZeroUuid(pending_write_token)) return; + std::unique_lock key_lock(GetKeyLock(ctx.key)); + auto& shard = GetPendingWriteShard(ctx); + std::unique_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(ctx.key_string); + if (it == shard.by_key.end()) return; + if (it->second.pending_write_token != pending_write_token) return; + shard.ordered_list.erase(it->second.list_it); + shard.by_key.erase(it); +} + +void DataManager::ShutdownLeaseScanner() { + lease_scanner_stop_requested_.store(true); + lease_scanner_cv_.notify_all(); + if (lease_scanner_thread_.joinable()) { + lease_scanner_thread_.join(); + } +} + +void DataManager::LeaseScannerMain() { + std::unique_lock wait_lock(lease_scanner_mutex_); + while (!lease_scanner_stop_requested_.load()) { + lease_scanner_cv_.wait_for(wait_lock, lease_scan_interval_, [this]() { + return lease_scanner_stop_requested_.load(); + }); + if (lease_scanner_stop_requested_.load()) { + break; + } + const auto now = std::chrono::steady_clock::now(); + wait_lock.unlock(); + for (auto& shard : pending_write_shards_) { + std::unique_lock shard_lock(shard.mutex); + ScanExpiredPendingWrites(shard, now); + } + for (auto& shard : pinned_key_shards_) { + std::unique_lock shard_lock(shard.mutex); + ScanExpiredPinnedKeys(shard, now); + } + wait_lock.lock(); + } } // ================================================================ @@ -187,6 +491,7 @@ tl::expected>, ErrorCode> DataManager::Put( tl::expected>, ErrorCode> DataManager::PutViaTe(std::string_view key, std::vector& slices) { // using Te, treat local memory as remote memory + const KeyCtx kctx = BuildKeyCtx(key); size_t total_size = 0; for (const auto& s : slices) total_size += s.size; auto src_buffers = SlicesToRemoteBufferDescs(slices); @@ -197,28 +502,19 @@ DataManager::PutViaTe(std::string_view key, std::vector& slices) { return tl::unexpected(validate_result.error()); } - AllocationHandle alloc_handle; - { - std::unique_lock lock(GetKeyLock(key)); - - if (tiered_backend_->Exist(key)) { - LOG(WARNING) << "key already exists: " << key; - return tl::make_unexpected(ErrorCode::OBJECT_ALREADY_EXISTS); - } + auto prewrite_result = PreWriteInternal(kctx, total_size, std::nullopt); + if (!prewrite_result) { + return tl::unexpected(prewrite_result.error()); + } + const UUID pending_write_token = prewrite_result->pending_write_token; - auto alloc_result = tiered_backend_->Allocate(total_size); - if (!alloc_result) { - auto err = alloc_result.error(); - if (err == ErrorCode::REPLICA_NUM_EXCEEDED || - err == ErrorCode::OBJECT_ALREADY_EXISTS || - err == ErrorCode::REPLICA_ALREADY_EXISTS) { - LOG(WARNING) << "object already exists for key: " << key - << ", error code: " << err; - } - return tl::unexpected(err); - } - alloc_handle = alloc_result.value(); + auto handle_result = + LookupPendingWriteHandleInternal(kctx, pending_write_token); + if (!handle_result) { + AbortPendingWriteInternal(kctx, pending_write_token); + return tl::unexpected(handle_result.error()); } + AllocationHandle alloc_handle = handle_result.value(); auto submit_result = SubmitTeTransferInternal( alloc_handle, src_buffers, Transport::TransferRequest::READ); @@ -226,20 +522,22 @@ DataManager::PutViaTe(std::string_view key, std::vector& slices) { LOG(ERROR) << "SubmitTeTransferInternal failed" << ", key=" << key << ", error_code=" << toString(submit_result.error()); + AbortPendingWriteInternal(kctx, pending_write_token); return tl::unexpected(submit_result.error()); } return CallableTaskHandle::Create( - [this, ctx = std::move(*submit_result), alloc_handle, - key]() mutable -> tl::expected { + [this, ctx = std::move(*submit_result), alloc_handle, kctx, + pending_write_token]() mutable -> tl::expected { ScopedVLogTimer timer(1, "DataManager::PutViaTe"); - timer.LogRequest("key=", key); + timer.LogRequest("key=", kctx.key); auto wait_result = WaitAllTransferBatches(ctx.transfer_batches); if (!wait_result) { LOG(ERROR) << "WaitAllTransferBatches failed" - << ", key=" << key + << ", key=" << kctx.key << ", error_code=" << toString(wait_result.error()); + AbortPendingWriteInternal(kctx, pending_write_token); return tl::unexpected(wait_result.error()); } @@ -254,16 +552,16 @@ DataManager::PutViaTe(std::string_view key, std::vector& slices) { if (!copy_result) { LOG(ERROR) << "CopyFromDRAMBuffer failed" - << ", key=" << key + << ", key=" << kctx.key << ", error_code=" << toString(copy_result.error()); + AbortPendingWriteInternal(kctx, pending_write_token); return tl::unexpected(copy_result.error()); } } - std::unique_lock lock(GetKeyLock(key)); - auto commit_result = tiered_backend_->Commit(key, alloc_handle); + auto commit_result = WriteCommitInternal(kctx, pending_write_token); if (!commit_result) { - LOG(WARNING) << "commit race for key: " << key; + return tl::unexpected(commit_result.error()); } timer.LogResponse("error_code=", ErrorCode::OK); return {}; @@ -276,56 +574,49 @@ DataManager::PutViaMemcpy(std::string_view key, std::vector& slices) { LOG(ERROR) << "PutLocal in memcpy mode only supports a single slice"; return tl::unexpected(ErrorCode::NOT_IMPLEMENTED); } + const KeyCtx kctx = BuildKeyCtx(key); Slice slice = slices[0]; - AllocationHandle alloc_handle; - { - std::unique_lock lock(GetKeyLock(key)); - - if (tiered_backend_->Exist(key)) { - LOG(WARNING) << "Key already exists: " << key; - return tl::make_unexpected(ErrorCode::OBJECT_ALREADY_EXISTS); - } + auto prewrite_result = PreWriteInternal(kctx, slice.size, std::nullopt); + if (!prewrite_result) { + return tl::unexpected(prewrite_result.error()); + } + const UUID pending_write_token = prewrite_result->pending_write_token; - auto handle = tiered_backend_->Allocate(slice.size); - if (!handle.has_value()) { - LOG(ERROR) << "Failed to allocate space for key: " << key - << ", error: " << handle.error(); - return tl::make_unexpected(handle.error()); - } - alloc_handle = handle.value(); + auto handle_result = + LookupPendingWriteHandleInternal(kctx, pending_write_token); + if (!handle_result) { + AbortPendingWriteInternal(kctx, pending_write_token); + return tl::unexpected(handle_result.error()); } + AllocationHandle alloc_handle = handle_result.value(); - auto write_fn = [this, key, slice, - alloc_handle]() -> tl::expected { + auto write_fn = [this, kctx, slice, alloc_handle, + pending_write_token]() -> tl::expected { DataSource source; source.buffer = std::make_unique(slice.ptr, slice.size); source.type = MemoryType::DRAM; auto write_result = tiered_backend_->Write(source, alloc_handle); if (!write_result.has_value()) { - LOG(ERROR) << "Failed to write data for key: " << key + LOG(ERROR) << "Failed to write data for key: " << kctx.key << ", error: " << write_result.error(); + AbortPendingWriteInternal(kctx, pending_write_token); return tl::make_unexpected(write_result.error()); } return {}; }; - auto commit_fn = [this, key, - alloc_handle]() -> tl::expected { - std::unique_lock lock(GetKeyLock(key)); - auto commit_result = tiered_backend_->Commit(key, alloc_handle); - if (!commit_result.has_value()) { - auto err = commit_result.error(); - if (err != ErrorCode::REPLICA_NUM_EXCEEDED && - err != ErrorCode::OBJECT_ALREADY_EXISTS && - err != ErrorCode::REPLICA_ALREADY_EXISTS) { - LOG(ERROR) << "Failed to commit data for key: " << key + auto commit_fn = [this, kctx, + pending_write_token]() -> tl::expected { + auto commit_result = + WriteCommitInternal(kctx, pending_write_token); + if (!commit_result) { + LOG(ERROR) << "Failed to commit data for key: " << kctx.key << ", error: " << commit_result.error(); return tl::make_unexpected(commit_result.error()); } - } - return {}; + return {}; }; auto write_and_commit = [write_fn = std::move(write_fn), @@ -514,12 +805,11 @@ tl::expected DataManager::BuildDataCopierViaMemcpy( // Remote data transfer — called by RPC service layer // ================================================================ -// Attention!!! -// This method runs without key lock. tl::expected DataManager::ReadRemoteData( std::string_view key, const std::vector& dest_buffers) { ScopedVLogTimer timer(1, "DataManager::ReadRemoteData"); timer.LogRequest("key=", key, "buffer_count=", dest_buffers.size()); + const KeyCtx kctx = BuildKeyCtx(key); auto validate_result = ValidateRemoteBuffers(dest_buffers); if (!validate_result) { @@ -529,15 +819,35 @@ tl::expected DataManager::ReadRemoteData( return tl::make_unexpected(validate_result.error()); } - auto handle_result = tiered_backend_->Get(key); - if (!handle_result.has_value()) { - LOG(ERROR) << "ReadRemoteData: Failed to get data for key: " << key - << ", error: " << toString(handle_result.error()); + // Reverse RDMA path: use the same 3-phase pin/unpin semantics even though + // the control plane is still a single RPC. + auto pin_result = PinKeyInternal(kctx, std::nullopt); + if (!pin_result) { + timer.LogResponse("error_code=", pin_result.error()); + return tl::make_unexpected(pin_result.error()); + } + const UUID pin_token = pin_result->pin_token; + + auto handle_result = LookupPinnedKeyHandleInternal(kctx, pin_token); + if (!handle_result) { + (void)UnPinKeyInternal(kctx, pin_token); timer.LogResponse("error_code=", handle_result.error()); return tl::make_unexpected(handle_result.error()); } - return TransferDataToRemote(handle_result.value(), dest_buffers); + auto transfer_result = TransferDataToRemote(handle_result.value(), dest_buffers); + auto unpin_result = UnPinKeyInternal(kctx, pin_token); + + if (!transfer_result) { + timer.LogResponse("error_code=", transfer_result.error()); + return tl::make_unexpected(transfer_result.error()); + } + if (!unpin_result) { + timer.LogResponse("error_code=", unpin_result.error()); + return tl::make_unexpected(unpin_result.error()); + } + timer.LogResponse("error_code=", ErrorCode::OK); + return {}; } tl::expected DataManager::TransferDataToRemote( @@ -561,6 +871,7 @@ tl::expected DataManager::WriteRemoteData( std::optional tier_id) { ScopedVLogTimer timer(1, "DataManager::WriteRemoteData"); timer.LogRequest("key=", key, "buffer_count=", src_buffers.size()); + const KeyCtx kctx = BuildKeyCtx(key); auto validate_result = ValidateRemoteBuffers(src_buffers); if (!validate_result) { @@ -573,47 +884,37 @@ tl::expected DataManager::WriteRemoteData( size_t total_size = 0; for (const auto& buf : src_buffers) total_size += buf.size; - AllocationHandle handle; - { - std::unique_lock lock(GetKeyLock(key)); - - if (tiered_backend_->Exist(key)) { - LOG(WARNING) << "Key already exists: " << key; - timer.LogResponse("error_code=", ErrorCode::OBJECT_ALREADY_EXISTS); - return tl::make_unexpected(ErrorCode::OBJECT_ALREADY_EXISTS); - } + // Reverse RDMA path: still one RPC, but internally use the 3-phase write + // model (PreWrite -> transfer -> WriteCommit). + auto prewrite_result = PreWriteInternal(kctx, total_size, tier_id); + if (!prewrite_result) { + timer.LogResponse("error_code=", prewrite_result.error()); + return tl::make_unexpected(prewrite_result.error()); + } + const UUID pending_write_token = prewrite_result->pending_write_token; - auto handle_result = tiered_backend_->Allocate(total_size, tier_id); - if (!handle_result.has_value()) { - LOG(ERROR) << "WriteRemoteData: Failed to allocate space for key: " - << key; - timer.LogResponse("error_code=", handle_result.error()); - return tl::make_unexpected(handle_result.error()); - } - handle = handle_result.value(); + auto handle_result = + LookupPendingWriteHandleInternal(kctx, pending_write_token); + if (!handle_result) { + AbortPendingWriteInternal(kctx, pending_write_token); + timer.LogResponse("error_code=", handle_result.error()); + return tl::make_unexpected(handle_result.error()); } + AllocationHandle handle = handle_result.value(); + UUID result_tier_id = handle->loc.tier->GetTierId(); - // Transfer phase — no lock held, RDMA runs concurrently. + // Transfer phase — no long key lock held. auto transfer_result = TransferDataFromRemote(handle, src_buffers); - if (!transfer_result.has_value()) { - LOG(ERROR) << "WriteRemoteData: Transfer failed for key: " << key - << ", error: " << toString(transfer_result.error()); + if (!transfer_result) { + AbortPendingWriteInternal(kctx, pending_write_token); timer.LogResponse("error_code=", transfer_result.error()); return tl::make_unexpected(transfer_result.error()); } - // Commit phase: re-acquire lock, commit the handle. - UUID result_tier_id; - { - std::unique_lock lock(GetKeyLock(key)); - auto commit_result = tiered_backend_->Commit(key, handle); - if (!commit_result.has_value()) { - LOG(ERROR) << "WriteRemoteData: Failed to commit data for key: " - << key; - timer.LogResponse("error_code=", commit_result.error()); - return tl::make_unexpected(commit_result.error()); - } - result_tier_id = handle->loc.tier->GetTierId(); + auto commit_result = WriteCommitInternal(kctx, pending_write_token); + if (!commit_result) { + timer.LogResponse("error_code=", commit_result.error()); + return tl::make_unexpected(commit_result.error()); } timer.LogResponse("error_code=", ErrorCode::OK, @@ -621,6 +922,244 @@ tl::expected DataManager::WriteRemoteData( return result_tier_id; } +tl::expected DataManager::PreWrite( + std::string_view key, size_t size_bytes, std::optional tier_id) { + return PreWriteInternal(BuildKeyCtx(key), size_bytes, tier_id); +} + +tl::expected DataManager::PreWriteInternal( + const KeyCtx& ctx, size_t size_bytes, std::optional tier_id) { + ScopedVLogTimer timer(1, "DataManager::PreWrite"); + timer.LogRequest("key=", ctx.key, "size_bytes=", size_bytes); + + if (ctx.key.empty() || size_bytes == 0) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + const auto now = std::chrono::steady_clock::now(); + const auto deadline = now + lease_duration_; + + std::unique_lock key_lock(GetKeyLock(ctx.key)); + auto& shard = GetPendingWriteShard(ctx); + std::unique_lock shard_lock(shard.mutex); + + RemoveExpiredPendingWriteLocked(shard, ctx.key_string, now); + if (shard.by_key.find(ctx.key_string) != shard.by_key.end()) { + timer.LogResponse("error_code=", ErrorCode::OBJECT_HAS_LEASE); + return tl::make_unexpected(ErrorCode::OBJECT_HAS_LEASE); + } + if (tiered_backend_->Exist(ctx.key)) { + timer.LogResponse("error_code=", ErrorCode::OBJECT_ALREADY_EXISTS); + return tl::make_unexpected(ErrorCode::OBJECT_ALREADY_EXISTS); + } + + auto handle_result = tiered_backend_->Allocate(size_bytes, tier_id); + if (!handle_result) { + timer.LogResponse("error_code=", handle_result.error()); + return tl::make_unexpected(handle_result.error()); + } + + auto handle = std::move(handle_result.value()); + auto list_it = + shard.ordered_list.emplace(shard.ordered_list.end(), ctx.key_string, deadline); + + const UUID pending_write_token = generate_uuid(); + PendingWriteRecord record; + record.pending_write_token = pending_write_token; + record.deadline = deadline; + record.handle = handle; + record.list_it = list_it; + shard.by_key.insert_or_assign(ctx.key_string, std::move(record)); + + PreWriteResult result; + result.remote_buffer = BuildRemoteBufferDesc(handle); + result.deadline_ms = TimePointToDeadlineMs(deadline); + result.pending_write_token = pending_write_token; + timer.LogResponse("error_code=", ErrorCode::OK, + "deadline_ms=", result.deadline_ms); + return result; +} + +tl::expected DataManager::WriteCommit( + std::string_view key, const UUID& pending_write_token) { + return WriteCommitInternal(BuildKeyCtx(key), pending_write_token); +} + +tl::expected DataManager::WriteCommitInternal( + const KeyCtx& ctx, const UUID& pending_write_token) { + ScopedVLogTimer timer(1, "DataManager::WriteCommit"); + timer.LogRequest("key=", ctx.key); + + if (ctx.key.empty() || IsZeroUuid(pending_write_token)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + const auto now = std::chrono::steady_clock::now(); + + std::unique_lock key_lock(GetKeyLock(ctx.key)); + auto& shard = GetPendingWriteShard(ctx); + std::unique_lock shard_lock(shard.mutex); + + auto record_it = shard.by_key.find(ctx.key_string); + if (record_it == shard.by_key.end()) { + timer.LogResponse("error_code=", ErrorCode::OK, "idempotent=", true); + return {}; + } + if (record_it->second.deadline <= now) { + shard.ordered_list.erase(record_it->second.list_it); + shard.by_key.erase(record_it); + timer.LogResponse("error_code=", ErrorCode::LEASE_EXPIRED); + return tl::make_unexpected(ErrorCode::LEASE_EXPIRED); + } + if (record_it->second.pending_write_token != pending_write_token) { + timer.LogResponse("error_code=", ErrorCode::INVALID_WRITE); + return tl::make_unexpected(ErrorCode::INVALID_WRITE); + } + + auto handle = record_it->second.handle; + // Once we attempt the commit, always erase the pending record regardless of + // the commit result. + // + // Rationale: most commit failures are not transient and cannot be resolved + // by retrying WriteCommit alone (e.g. tier commit failure or master + // callback failure). The caller should restart a full write flow instead of + // reusing a potentially stale prewrite context. + // + // Note: after erasing the record, subsequent WriteCommit calls become + // idempotent no-ops (return OK) due to record absence. + auto commit_result = tiered_backend_->Commit(ctx.key, handle); + shard.ordered_list.erase(record_it->second.list_it); + shard.by_key.erase(record_it); + + if (!commit_result) { + timer.LogResponse("error_code=", commit_result.error(), + "record_erased=", true); + return tl::make_unexpected(commit_result.error()); + } + + timer.LogResponse("error_code=", ErrorCode::OK, "record_erased=", true); + return {}; +} + +tl::expected DataManager::PinKey( + std::string_view key, std::optional tier_id) { + return PinKeyInternal(BuildKeyCtx(key), tier_id); +} + +tl::expected DataManager::PinKeyInternal( + const KeyCtx& ctx, std::optional tier_id) { + ScopedVLogTimer timer(1, "DataManager::PinKey"); + timer.LogRequest("key=", ctx.key); + + if (ctx.key.empty()) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + const auto now = std::chrono::steady_clock::now(); + const auto deadline = now + lease_duration_; + + std::shared_lock key_lock(GetKeyLock(ctx.key)); + auto& shard = GetPinnedKeyShard(ctx); + std::unique_lock shard_lock(shard.mutex); + + RemoveExpiredPinnedKeyLocked(shard, ctx.key_string, now); + auto record_it = shard.by_key.find(ctx.key_string); + if (record_it != shard.by_key.end()) { + record_it->second.ref_count++; + record_it->second.deadline = deadline; + TouchOrderedDeadlineNode(shard.ordered_list, record_it->second.list_it, + ctx.key_string, deadline); + + PinKeyResult result; + result.remote_buffer = BuildRemoteBufferDesc(record_it->second.handle); + result.deadline_ms = TimePointToDeadlineMs(deadline); + result.pin_token = record_it->second.pin_token; + timer.LogResponse("error_code=", ErrorCode::OK, + "ref_count=", record_it->second.ref_count); + return result; + } + + auto handle_result = tiered_backend_->Get(ctx.key, tier_id); + if (!handle_result) { + timer.LogResponse("error_code=", handle_result.error()); + return tl::make_unexpected(handle_result.error()); + } + + auto handle = std::move(handle_result.value()); + auto list_it = + shard.ordered_list.emplace(shard.ordered_list.end(), ctx.key_string, deadline); + + const UUID pin_token_value = generate_uuid(); + PinnedKeyRecord record; + record.pin_token = pin_token_value; + record.deadline = deadline; + record.handle = handle; + record.ref_count = 1; + record.list_it = list_it; + shard.by_key.insert_or_assign(ctx.key_string, std::move(record)); + + PinKeyResult result; + result.remote_buffer = BuildRemoteBufferDesc(handle); + result.deadline_ms = TimePointToDeadlineMs(deadline); + result.pin_token = pin_token_value; + timer.LogResponse("error_code=", ErrorCode::OK, + "deadline_ms=", result.deadline_ms); + return result; +} + +tl::expected DataManager::UnPinKey(std::string_view key, + const UUID& pin_token) { + return UnPinKeyInternal(BuildKeyCtx(key), pin_token); +} + +tl::expected DataManager::UnPinKeyInternal( + const KeyCtx& ctx, const UUID& pin_token) { + ScopedVLogTimer timer(1, "DataManager::UnPinKey"); + timer.LogRequest("key=", ctx.key); + + if (ctx.key.empty() || IsZeroUuid(pin_token)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + const auto now = std::chrono::steady_clock::now(); + + std::shared_lock key_lock(GetKeyLock(ctx.key)); + auto& shard = GetPinnedKeyShard(ctx); + std::unique_lock shard_lock(shard.mutex); + + auto record_it = shard.by_key.find(ctx.key_string); + if (record_it == shard.by_key.end()) { + timer.LogResponse("error_code=", ErrorCode::OK, "idempotent=", true); + return {}; + } + if (record_it->second.deadline <= now) { + shard.ordered_list.erase(record_it->second.list_it); + shard.by_key.erase(record_it); + timer.LogResponse("error_code=", ErrorCode::LEASE_EXPIRED); + return tl::make_unexpected(ErrorCode::LEASE_EXPIRED); + } + if (record_it->second.pin_token != pin_token) { + timer.LogResponse("error_code=", ErrorCode::INVALID_READ); + return tl::make_unexpected(ErrorCode::INVALID_READ); + } + + if (record_it->second.ref_count > 1) { + record_it->second.ref_count--; + timer.LogResponse("error_code=", ErrorCode::OK, + "ref_count=", record_it->second.ref_count); + return {}; + } + + shard.ordered_list.erase(record_it->second.list_it); + shard.by_key.erase(record_it); + timer.LogResponse("error_code=", ErrorCode::OK, "ref_count=", 0); + return {}; +} + tl::expected DataManager::TransferDataFromRemote( AllocationHandle handle, const std::vector& src_buffers) { auto submit_result = SubmitTeTransferInternal( @@ -1134,6 +1673,17 @@ tl::expected DataManager::Delete(std::string_view key, ScopedVLogTimer timer(1, "DataManager::Delete"); timer.LogRequest("key=", key); + // NOTE (weak delete semantics): + // TieredBackend::Delete only removes the metadata entry (or a replica entry) + // from the in-memory index. It does NOT directly free underlying memory. + // The actual buffer lifetime is still governed by AllocationHandle's RAII + // reference counting. + // + // We still guard Delete against in-flight 3-phase contexts: + // - PendingWriteRecord holds a strong handle reference until WriteCommit or + // lease cleanup. + // - PinnedKeyRecord holds a strong handle reference until UnPinKey reaches + // ref_count==0 or lease cleanup. std::unique_lock lock(GetKeyLock(key)); auto result = tiered_backend_->Delete(key, tier_id); diff --git a/mooncake-store/tests/data_manager_test.cpp b/mooncake-store/tests/data_manager_test.cpp index 18a189a921..97557610e3 100644 --- a/mooncake-store/tests/data_manager_test.cpp +++ b/mooncake-store/tests/data_manager_test.cpp @@ -284,6 +284,209 @@ TEST_F(DataManagerTest, DeleteWithTierId) { ASSERT_FALSE(data_manager_->Exist(key)); } +// Test PreWrite: concurrent PreWrite should be rejected by a pending lease. +TEST_F(DataManagerTest, PreWriteRejectsConcurrentLease) { + const std::string key = "prewrite_lifecycle_key"; + auto prewrite_result = data_manager_->PreWrite(key, 256, GetTierId()); + ASSERT_TRUE(prewrite_result.has_value()) + << "PreWrite failed: " << toString(prewrite_result.error()); + + auto second_prewrite = data_manager_->PreWrite(key, 256, GetTierId()); + ASSERT_FALSE(second_prewrite.has_value()); + EXPECT_EQ(second_prewrite.error(), ErrorCode::OBJECT_HAS_LEASE); + + auto& shard = data_manager_->GetPendingWriteShard(key); + { + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + EXPECT_EQ(it->second.pending_write_token, + prewrite_result->pending_write_token); + } +} + +// Test WriteCommit: successful commit should erase the pending write record. +TEST_F(DataManagerTest, WriteCommitErasesPendingWriteRecord) { + const std::string key = "write_commit_erases_record_key"; + auto prewrite_result = data_manager_->PreWrite(key, 256, GetTierId()); + ASSERT_TRUE(prewrite_result.has_value()) + << "PreWrite failed: " << toString(prewrite_result.error()); + + auto commit_result = + data_manager_->WriteCommit(key, prewrite_result->pending_write_token); + ASSERT_TRUE(commit_result.has_value()) + << "WriteCommit failed: " << toString(commit_result.error()); + EXPECT_TRUE(data_manager_->Exist(key)); + + auto& shard = data_manager_->GetPendingWriteShard(key); + { + std::shared_lock shard_lock(shard.mutex); + EXPECT_EQ(shard.by_key.count(key), 0U); + } +} + +// Test WriteCommit: token mismatch should fail without erasing the record. +TEST_F(DataManagerTest, WriteCommitTokenMismatchKeepsPendingWriteRecord) { + const std::string key = "write_commit_token_mismatch_key"; + auto prewrite_result = data_manager_->PreWrite(key, 256, GetTierId()); + ASSERT_TRUE(prewrite_result.has_value()) + << "PreWrite failed: " << toString(prewrite_result.error()); + + UUID wrong_token = prewrite_result->pending_write_token; + wrong_token.first += 1; + auto wrong_commit = data_manager_->WriteCommit(key, wrong_token); + ASSERT_FALSE(wrong_commit.has_value()); + EXPECT_EQ(wrong_commit.error(), ErrorCode::INVALID_WRITE); + + auto& shard = data_manager_->GetPendingWriteShard(key); + { + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + EXPECT_EQ(it->second.pending_write_token, + prewrite_result->pending_write_token); + } +} + +// Test Pin/Unpin: ref_count increments on PinKey and reaches zero on final UnPinKey. +TEST_F(DataManagerTest, PinKeyTracksRefCountUntilFinalUnpin) { + const std::string key = "pin_ref_count_key"; + const std::string test_data = "Pin key ref count payload"; + + auto buffer = StringToBuffer(test_data); + ASSERT_TRUE(DoPut(key, buffer.get(), test_data.size()).has_value()); + + auto first_pin = data_manager_->PinKey(key, GetTierId()); + ASSERT_TRUE(first_pin.has_value()) + << "First PinKey failed: " << toString(first_pin.error()); + + auto second_pin = data_manager_->PinKey(key, GetTierId()); + ASSERT_TRUE(second_pin.has_value()) + << "Second PinKey failed: " << toString(second_pin.error()); + EXPECT_EQ(first_pin->pin_token, second_pin->pin_token); + + auto& shard = data_manager_->GetPinnedKeyShard(key); + { + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + EXPECT_EQ(it->second.ref_count, 2U); + } + + auto first_unpin = data_manager_->UnPinKey(key, first_pin->pin_token); + ASSERT_TRUE(first_unpin.has_value()) + << "First UnPinKey failed: " << toString(first_unpin.error()); + { + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + EXPECT_EQ(it->second.ref_count, 1U); + } + + auto second_unpin = data_manager_->UnPinKey(key, second_pin->pin_token); + ASSERT_TRUE(second_unpin.has_value()) + << "Second UnPinKey failed: " << toString(second_unpin.error()); + { + std::shared_lock shard_lock(shard.mutex); + EXPECT_EQ(shard.by_key.count(key), 0U); + } +} + +// Test UnPinKey: token mismatch should fail without erasing or decrementing the record. +TEST_F(DataManagerTest, UnPinKeyTokenMismatchKeepsPinnedRecord) { + const std::string key = "unpin_token_mismatch_key"; + const std::string test_data = "Unpin token mismatch payload"; + + auto buffer = StringToBuffer(test_data); + ASSERT_TRUE(DoPut(key, buffer.get(), test_data.size()).has_value()); + + auto pin_result = data_manager_->PinKey(key, GetTierId()); + ASSERT_TRUE(pin_result.has_value()) + << "PinKey failed: " << toString(pin_result.error()); + + UUID wrong_token = pin_result->pin_token; + wrong_token.first += 1; + auto wrong_unpin = data_manager_->UnPinKey(key, wrong_token); + ASSERT_FALSE(wrong_unpin.has_value()); + EXPECT_EQ(wrong_unpin.error(), ErrorCode::INVALID_READ); + + auto& shard = data_manager_->GetPinnedKeyShard(key); + { + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + EXPECT_EQ(it->second.ref_count, 1U); + EXPECT_EQ(it->second.pin_token, pin_result->pin_token); + } +} + +// Test WriteCommit: lease expiry causes commit failure and allows subsequent PreWrite. +TEST_F(DataManagerTest, WriteCommitFailsAfterLeaseExpiryAndAllowsRetryPreWrite) { + const std::string key = "expired_prewrite_key"; + auto prewrite_result = data_manager_->PreWrite(key, 512, GetTierId()); + ASSERT_TRUE(prewrite_result.has_value()) + << "PreWrite failed: " << toString(prewrite_result.error()); + + auto& shard = data_manager_->GetPendingWriteShard(key); + { + std::unique_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + const auto expired_deadline = + std::chrono::steady_clock::now() - std::chrono::milliseconds(1); + it->second.deadline = expired_deadline; + it->second.list_it->second = expired_deadline; + } + + auto commit_result = + data_manager_->WriteCommit(key, prewrite_result->pending_write_token); + ASSERT_FALSE(commit_result.has_value()); + EXPECT_EQ(commit_result.error(), ErrorCode::LEASE_EXPIRED); + + { + std::shared_lock shard_lock(shard.mutex); + EXPECT_EQ(shard.by_key.count(key), 0U); + } + + auto retry_prewrite = data_manager_->PreWrite(key, 512, GetTierId()); + ASSERT_TRUE(retry_prewrite.has_value()) + << "PreWrite should succeed after expired record is cleaned: " + << toString(retry_prewrite.error()); +} + +// Test UnPinKey: lease expiry causes unpin failure and cleans up the pin record. +TEST_F(DataManagerTest, ExpiredPinnedLeaseBlocksUnpinButDeleteCanProceedAfterCleanup) { + const std::string key = "expired_pin_key"; + const std::string test_data = "Expired pin payload"; + + auto buffer = StringToBuffer(test_data); + ASSERT_TRUE(DoPut(key, buffer.get(), test_data.size()).has_value()); + + auto pin_result = data_manager_->PinKey(key, GetTierId()); + ASSERT_TRUE(pin_result.has_value()) + << "PinKey failed: " << toString(pin_result.error()); + + auto& shard = data_manager_->GetPinnedKeyShard(key); + { + std::unique_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + const auto expired_deadline = + std::chrono::steady_clock::now() - std::chrono::milliseconds(1); + it->second.deadline = expired_deadline; + it->second.list_it->second = expired_deadline; + } + + auto unpin_result = data_manager_->UnPinKey(key, pin_result->pin_token); + ASSERT_FALSE(unpin_result.has_value()); + EXPECT_EQ(unpin_result.error(), ErrorCode::LEASE_EXPIRED); + + { + std::shared_lock shard_lock(shard.mutex); + EXPECT_EQ(shard.by_key.count(key), 0U); + } +} + // Test concurrent Put operations TEST_F(DataManagerTest, ConcurrentPut) { const int num_keys = 10; From 381b82f243cc8f6d5c61666aa4f65884dc62482a Mon Sep 17 00:00:00 2001 From: shichangzhang064 Date: Fri, 8 May 2026 14:37:37 +0800 Subject: [PATCH 02/14] fix: avoid move in data Manager initialization --- mooncake-store/src/data_manager.cpp | 11 +++++++---- mooncake-store/src/p2p_client_service.cpp | 4 ++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/mooncake-store/src/data_manager.cpp b/mooncake-store/src/data_manager.cpp index 7daf43a4f6..e87073bef5 100644 --- a/mooncake-store/src/data_manager.cpp +++ b/mooncake-store/src/data_manager.cpp @@ -254,10 +254,13 @@ RemoteBufferDesc DataManager::BuildRemoteBufferDesc( const auto& loc_data = handle->loc.data; RemoteBufferDesc remote_buffer; remote_buffer.segment_endpoint = local_transfer_config_.te_endpoint; - remote_buffer.addr = - reinterpret_cast(loc_data.buffer ? loc_data.buffer->data() - : nullptr); - remote_buffer.size = loc_data.buffer ? loc_data.buffer->size() : 0; + remote_buffer.addr = 0; + remote_buffer.size = 0; + if (loc_data.buffer) { + remote_buffer.addr = + reinterpret_cast(loc_data.buffer->data()); + remote_buffer.size = loc_data.buffer->size(); + } return remote_buffer; } diff --git a/mooncake-store/src/p2p_client_service.cpp b/mooncake-store/src/p2p_client_service.cpp index 46dfe3f808..32e887c260 100644 --- a/mooncake-store/src/p2p_client_service.cpp +++ b/mooncake-store/src/p2p_client_service.cpp @@ -248,8 +248,8 @@ ErrorCode P2PClientService::InitStorage(const P2PClientConfig& config) { config.local_memcpy_async_worker_num; } - data_manager_ = DataManager(std::move(tiered_backend), transfer_engine_, - config.lock_shard_count, local_transfer_config); + data_manager_.emplace(std::move(tiered_backend), transfer_engine_, + config.lock_shard_count, local_transfer_config); // Set rectify callback on DataManager to remove stale replicas from master data_manager_->SetRectifyCallback([this](std::string_view key, std::optional tier_id) { From 368c3c06aaaa63c4243c9d6cc5f6fc6773accf3a Mon Sep 17 00:00:00 2001 From: shichangzhang064 Date: Fri, 8 May 2026 16:19:07 +0800 Subject: [PATCH 03/14] fix: clear records before shutdown DataManager --- mooncake-store/include/data_manager.h | 2 ++ mooncake-store/src/data_manager.cpp | 14 ++++++++++++++ 2 files changed, 16 insertions(+) diff --git a/mooncake-store/include/data_manager.h b/mooncake-store/include/data_manager.h index c06fe6b531..05e3b3a45c 100644 --- a/mooncake-store/include/data_manager.h +++ b/mooncake-store/include/data_manager.h @@ -241,6 +241,8 @@ class DataManager { std::optional tier_id = std::nullopt) const; private: + void ClearLeaseRecords(); + struct KeyCtx { std::string_view key; std::string key_string; diff --git a/mooncake-store/src/data_manager.cpp b/mooncake-store/src/data_manager.cpp index e87073bef5..2adc105387 100644 --- a/mooncake-store/src/data_manager.cpp +++ b/mooncake-store/src/data_manager.cpp @@ -173,6 +173,7 @@ DataManager::~DataManager() { Stop(); } void DataManager::Stop() { ShutdownLeaseScanner(); + ClearLeaseRecords(); if (async_memcpy_executor_) { async_memcpy_executor_->Shutdown(); } @@ -181,6 +182,19 @@ void DataManager::Stop() { } } +void DataManager::ClearLeaseRecords() { + for (auto& shard : pending_write_shards_) { + std::unique_lock lock(shard.mutex); + shard.by_key.clear(); + shard.ordered_list.clear(); + } + for (auto& shard : pinned_key_shards_) { + std::unique_lock lock(shard.mutex); + shard.by_key.clear(); + shard.ordered_list.clear(); + } +} + size_t DataManager::HashKey(std::string_view key) const { return std::hash{}(key); } From 768e1e23061a7df9f6d8d7a8686d7a24a80da0bf Mon Sep 17 00:00:00 2001 From: shichangzhang064 Date: Fri, 8 May 2026 16:42:01 +0800 Subject: [PATCH 04/14] fix: add shard forward delearation in dataManager.h --- mooncake-store/include/data_manager.h | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mooncake-store/include/data_manager.h b/mooncake-store/include/data_manager.h index 05e3b3a45c..fdf3d89f28 100644 --- a/mooncake-store/include/data_manager.h +++ b/mooncake-store/include/data_manager.h @@ -243,6 +243,10 @@ class DataManager { private: void ClearLeaseRecords(); + // Forward declarations for nested shard structs used by internal helpers. + struct PendingWriteShard; + struct PinnedKeyShard; + struct KeyCtx { std::string_view key; std::string key_string; From ed2c0620075c96b6cb0bbc85deed405ae9f8f551 Mon Sep 17 00:00:00 2001 From: Shichang-Zhang Date: Sat, 9 May 2026 10:28:37 +0800 Subject: [PATCH 05/14] fix: code format fix --- mooncake-store/include/data_manager.h | 17 +++--- mooncake-store/src/data_manager.cpp | 62 +++++++++++----------- mooncake-store/tests/data_manager_test.cpp | 18 ++++--- 3 files changed, 52 insertions(+), 45 deletions(-) diff --git a/mooncake-store/include/data_manager.h b/mooncake-store/include/data_manager.h index fdf3d89f28..991d01f695 100644 --- a/mooncake-store/include/data_manager.h +++ b/mooncake-store/include/data_manager.h @@ -206,8 +206,8 @@ class DataManager { std::string_view key, size_t size_bytes, std::optional tier_id = std::nullopt); - tl::expected WriteCommit( - std::string_view key, const UUID& pending_write_token); + tl::expected WriteCommit(std::string_view key, + const UUID& pending_write_token); tl::expected PinKey( std::string_view key, std::optional tier_id = std::nullopt); @@ -261,8 +261,8 @@ class DataManager { tl::expected PreWriteInternal( const KeyCtx& ctx, size_t size_bytes, std::optional tier_id); - tl::expected WriteCommitInternal(const KeyCtx& ctx, - const UUID& pending_write_token); + tl::expected WriteCommitInternal( + const KeyCtx& ctx, const UUID& pending_write_token); tl::expected PinKeyInternal( const KeyCtx& ctx, std::optional tier_id); tl::expected UnPinKeyInternal(const KeyCtx& ctx, @@ -462,7 +462,8 @@ class DataManager { uint64_t TimePointToDeadlineMs(TimePoint deadline) const; TimePoint DeadlineMsToTimePoint(uint64_t deadline_ms) const; bool IsExpired(TimePoint deadline) const; - RemoteBufferDesc BuildRemoteBufferDesc(const AllocationHandle& handle) const; + RemoteBufferDesc BuildRemoteBufferDesc( + const AllocationHandle& handle) const; void LeaseScannerMain(); void ShutdownLeaseScanner(); size_t ScanExpiredPendingWrites(PendingWriteShard& shard, TimePoint now); @@ -471,8 +472,7 @@ class DataManager { const std::string& key); bool ErasePinnedKeyLocked(PinnedKeyShard& shard, const std::string& key); bool RemoveExpiredPendingWriteLocked(PendingWriteShard& shard, - const std::string& key, - TimePoint now); + const std::string& key, TimePoint now); bool RemoveExpiredPinnedKeyLocked(PinnedKeyShard& shard, const std::string& key, TimePoint now); void TouchOrderedDeadlineNode(OrderedDeadlineList& ordered_list, @@ -483,7 +483,8 @@ class DataManager { std::string_view key, const UUID& pending_write_token); tl::expected LookupPinnedKeyHandle( std::string_view key, const UUID& pin_token); - void AbortPendingWrite(std::string_view key, const UUID& pending_write_token); + void AbortPendingWrite(std::string_view key, + const UUID& pending_write_token); private: std::unique_ptr tiered_backend_; // Owned by DataManager diff --git a/mooncake-store/src/data_manager.cpp b/mooncake-store/src/data_manager.cpp index 2adc105387..b706ba7893 100644 --- a/mooncake-store/src/data_manager.cpp +++ b/mooncake-store/src/data_manager.cpp @@ -149,9 +149,8 @@ DataManager::DataManager(std::unique_ptr tiered_backend, local_transfer_config_.local_memcpy_async_worker_num); } - lease_duration_ = std::chrono::milliseconds( - GetEnvOr("P2P_RPC_LEASE_DURATION_MS", - kDefaultLeaseDurationMs)); + lease_duration_ = std::chrono::milliseconds(GetEnvOr( + "P2P_RPC_LEASE_DURATION_MS", kDefaultLeaseDurationMs)); lease_scan_interval_ = std::chrono::milliseconds( std::max(1, GetEnvOr("P2P_RPC_LEASE_SCAN_INTERVAL", kDefaultLeaseScanIntervalMs))); @@ -237,8 +236,7 @@ uint64_t DataManager::TimePointToDeadlineMs(TimePoint deadline) const { std::max(std::chrono::milliseconds::zero(), std::chrono::duration_cast( deadline - std::chrono::steady_clock::now())); - const auto system_deadline = - std::chrono::system_clock::now() + remaining; + const auto system_deadline = std::chrono::system_clock::now() + remaining; return static_cast( std::chrono::duration_cast( system_deadline.time_since_epoch()) @@ -247,9 +245,8 @@ uint64_t DataManager::TimePointToDeadlineMs(TimePoint deadline) const { DataManager::TimePoint DataManager::DeadlineMsToTimePoint( uint64_t deadline_ms) const { - const auto system_deadline = - std::chrono::system_clock::time_point(std::chrono::milliseconds( - static_cast(deadline_ms))); + const auto system_deadline = std::chrono::system_clock::time_point( + std::chrono::milliseconds(static_cast(deadline_ms))); const auto remaining = std::chrono::duration_cast( system_deadline - std::chrono::system_clock::now()); @@ -354,7 +351,8 @@ size_t DataManager::ScanExpiredPendingWrites(PendingWriteShard& shard, return removed; } -size_t DataManager::ScanExpiredPinnedKeys(PinnedKeyShard& shard, TimePoint now) { +size_t DataManager::ScanExpiredPinnedKeys(PinnedKeyShard& shard, + TimePoint now) { size_t removed = 0; while (!shard.ordered_list.empty()) { auto list_it = shard.ordered_list.begin(); @@ -404,8 +402,9 @@ tl::expected DataManager::LookupPinnedKeyHandle( return LookupPinnedKeyHandleInternal(BuildKeyCtx(key), pin_token); } -tl::expected DataManager::LookupPinnedKeyHandleInternal( - const KeyCtx& ctx, const UUID& pin_token) { +tl::expected +DataManager::LookupPinnedKeyHandleInternal(const KeyCtx& ctx, + const UUID& pin_token) { const auto now = std::chrono::steady_clock::now(); auto& shard = GetPinnedKeyShard(ctx); std::shared_lock shard_lock(shard.mutex); @@ -423,7 +422,7 @@ tl::expected DataManager::LookupPinnedKeyHandleInte } void DataManager::AbortPendingWrite(std::string_view key, - const UUID& pending_write_token) { + const UUID& pending_write_token) { AbortPendingWriteInternal(BuildKeyCtx(key), pending_write_token); } @@ -626,14 +625,13 @@ DataManager::PutViaMemcpy(std::string_view key, std::vector& slices) { auto commit_fn = [this, kctx, pending_write_token]() -> tl::expected { - auto commit_result = - WriteCommitInternal(kctx, pending_write_token); - if (!commit_result) { - LOG(ERROR) << "Failed to commit data for key: " << kctx.key - << ", error: " << commit_result.error(); - return tl::make_unexpected(commit_result.error()); - } - return {}; + auto commit_result = WriteCommitInternal(kctx, pending_write_token); + if (!commit_result) { + LOG(ERROR) << "Failed to commit data for key: " << kctx.key + << ", error: " << commit_result.error(); + return tl::make_unexpected(commit_result.error()); + } + return {}; }; auto write_and_commit = [write_fn = std::move(write_fn), @@ -852,7 +850,8 @@ tl::expected DataManager::ReadRemoteData( return tl::make_unexpected(handle_result.error()); } - auto transfer_result = TransferDataToRemote(handle_result.value(), dest_buffers); + auto transfer_result = + TransferDataToRemote(handle_result.value(), dest_buffers); auto unpin_result = UnPinKeyInternal(kctx, pin_token); if (!transfer_result) { @@ -944,8 +943,9 @@ tl::expected DataManager::PreWrite( return PreWriteInternal(BuildKeyCtx(key), size_bytes, tier_id); } -tl::expected DataManager::PreWriteInternal( - const KeyCtx& ctx, size_t size_bytes, std::optional tier_id) { +tl::expected +DataManager::PreWriteInternal(const KeyCtx& ctx, size_t size_bytes, + std::optional tier_id) { ScopedVLogTimer timer(1, "DataManager::PreWrite"); timer.LogRequest("key=", ctx.key, "size_bytes=", size_bytes); @@ -978,8 +978,8 @@ tl::expected DataManager::PreWriteIntern } auto handle = std::move(handle_result.value()); - auto list_it = - shard.ordered_list.emplace(shard.ordered_list.end(), ctx.key_string, deadline); + auto list_it = shard.ordered_list.emplace(shard.ordered_list.end(), + ctx.key_string, deadline); const UUID pending_write_token = generate_uuid(); PendingWriteRecord record; @@ -1106,8 +1106,8 @@ tl::expected DataManager::PinKeyInternal( } auto handle = std::move(handle_result.value()); - auto list_it = - shard.ordered_list.emplace(shard.ordered_list.end(), ctx.key_string, deadline); + auto list_it = shard.ordered_list.emplace(shard.ordered_list.end(), + ctx.key_string, deadline); const UUID pin_token_value = generate_uuid(); PinnedKeyRecord record; @@ -1691,10 +1691,10 @@ tl::expected DataManager::Delete(std::string_view key, timer.LogRequest("key=", key); // NOTE (weak delete semantics): - // TieredBackend::Delete only removes the metadata entry (or a replica entry) - // from the in-memory index. It does NOT directly free underlying memory. - // The actual buffer lifetime is still governed by AllocationHandle's RAII - // reference counting. + // TieredBackend::Delete only removes the metadata entry (or a replica + // entry) from the in-memory index. It does NOT directly free underlying + // memory. The actual buffer lifetime is still governed by + // AllocationHandle's RAII reference counting. // // We still guard Delete against in-flight 3-phase contexts: // - PendingWriteRecord holds a strong handle reference until WriteCommit or diff --git a/mooncake-store/tests/data_manager_test.cpp b/mooncake-store/tests/data_manager_test.cpp index 97557610e3..441dfc9868 100644 --- a/mooncake-store/tests/data_manager_test.cpp +++ b/mooncake-store/tests/data_manager_test.cpp @@ -348,7 +348,8 @@ TEST_F(DataManagerTest, WriteCommitTokenMismatchKeepsPendingWriteRecord) { } } -// Test Pin/Unpin: ref_count increments on PinKey and reaches zero on final UnPinKey. +// Test Pin/Unpin: ref_count increments on PinKey and reaches zero on final +// UnPinKey. TEST_F(DataManagerTest, PinKeyTracksRefCountUntilFinalUnpin) { const std::string key = "pin_ref_count_key"; const std::string test_data = "Pin key ref count payload"; @@ -392,7 +393,8 @@ TEST_F(DataManagerTest, PinKeyTracksRefCountUntilFinalUnpin) { } } -// Test UnPinKey: token mismatch should fail without erasing or decrementing the record. +// Test UnPinKey: token mismatch should fail without erasing or decrementing the +// record. TEST_F(DataManagerTest, UnPinKeyTokenMismatchKeepsPinnedRecord) { const std::string key = "unpin_token_mismatch_key"; const std::string test_data = "Unpin token mismatch payload"; @@ -420,8 +422,10 @@ TEST_F(DataManagerTest, UnPinKeyTokenMismatchKeepsPinnedRecord) { } } -// Test WriteCommit: lease expiry causes commit failure and allows subsequent PreWrite. -TEST_F(DataManagerTest, WriteCommitFailsAfterLeaseExpiryAndAllowsRetryPreWrite) { +// Test WriteCommit: lease expiry causes commit failure and allows subsequent +// PreWrite. +TEST_F(DataManagerTest, + WriteCommitFailsAfterLeaseExpiryAndAllowsRetryPreWrite) { const std::string key = "expired_prewrite_key"; auto prewrite_result = data_manager_->PreWrite(key, 512, GetTierId()); ASSERT_TRUE(prewrite_result.has_value()) @@ -454,8 +458,10 @@ TEST_F(DataManagerTest, WriteCommitFailsAfterLeaseExpiryAndAllowsRetryPreWrite) << toString(retry_prewrite.error()); } -// Test UnPinKey: lease expiry causes unpin failure and cleans up the pin record. -TEST_F(DataManagerTest, ExpiredPinnedLeaseBlocksUnpinButDeleteCanProceedAfterCleanup) { +// Test UnPinKey: lease expiry causes unpin failure and cleans up the pin +// record. +TEST_F(DataManagerTest, + ExpiredPinnedLeaseBlocksUnpinButDeleteCanProceedAfterCleanup) { const std::string key = "expired_pin_key"; const std::string test_data = "Expired pin payload"; From a754e2e28a99c913570b74f8c9137eb7adf74f38 Mon Sep 17 00:00:00 2001 From: shichangzhang064 Date: Sat, 9 May 2026 11:36:33 +0800 Subject: [PATCH 06/14] fix: add unpin check in remoteReadData --- mooncake-store/src/data_manager.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/mooncake-store/src/data_manager.cpp b/mooncake-store/src/data_manager.cpp index b706ba7893..dc8ed20aec 100644 --- a/mooncake-store/src/data_manager.cpp +++ b/mooncake-store/src/data_manager.cpp @@ -856,6 +856,11 @@ tl::expected DataManager::ReadRemoteData( if (!transfer_result) { timer.LogResponse("error_code=", transfer_result.error()); + if (!unpin_result) { + LOG(ERROR) << "Also failed to unpin key " << kctx.key + << " after transfer failure: " + << toString(unpin_result.error()); + } return tl::make_unexpected(transfer_result.error()); } if (!unpin_result) { From 4f2a0b41a034fba4e56aa95071e20b8fdde7c9c0 Mon Sep 17 00:00:00 2001 From: shichangzhang064 Date: Sat, 9 May 2026 15:05:07 +0800 Subject: [PATCH 07/14] feat: reverse rdma remote read stay unchanged --- mooncake-store/src/data_manager.cpp | 26 +++----------------------- 1 file changed, 3 insertions(+), 23 deletions(-) diff --git a/mooncake-store/src/data_manager.cpp b/mooncake-store/src/data_manager.cpp index dc8ed20aec..068f37c85b 100644 --- a/mooncake-store/src/data_manager.cpp +++ b/mooncake-store/src/data_manager.cpp @@ -824,7 +824,6 @@ tl::expected DataManager::ReadRemoteData( std::string_view key, const std::vector& dest_buffers) { ScopedVLogTimer timer(1, "DataManager::ReadRemoteData"); timer.LogRequest("key=", key, "buffer_count=", dest_buffers.size()); - const KeyCtx kctx = BuildKeyCtx(key); auto validate_result = ValidateRemoteBuffers(dest_buffers); if (!validate_result) { @@ -834,39 +833,20 @@ tl::expected DataManager::ReadRemoteData( return tl::make_unexpected(validate_result.error()); } - // Reverse RDMA path: use the same 3-phase pin/unpin semantics even though - // the control plane is still a single RPC. - auto pin_result = PinKeyInternal(kctx, std::nullopt); - if (!pin_result) { - timer.LogResponse("error_code=", pin_result.error()); - return tl::make_unexpected(pin_result.error()); - } - const UUID pin_token = pin_result->pin_token; - - auto handle_result = LookupPinnedKeyHandleInternal(kctx, pin_token); + // Reverse RDMA read stays on the direct object-handle path. Only forward + // RDMA read uses the 3-phase PinKey -> TE Read -> UnPinKey flow. + auto handle_result = tiered_backend_->Get(key); if (!handle_result) { - (void)UnPinKeyInternal(kctx, pin_token); timer.LogResponse("error_code=", handle_result.error()); return tl::make_unexpected(handle_result.error()); } auto transfer_result = TransferDataToRemote(handle_result.value(), dest_buffers); - auto unpin_result = UnPinKeyInternal(kctx, pin_token); - if (!transfer_result) { timer.LogResponse("error_code=", transfer_result.error()); - if (!unpin_result) { - LOG(ERROR) << "Also failed to unpin key " << kctx.key - << " after transfer failure: " - << toString(unpin_result.error()); - } return tl::make_unexpected(transfer_result.error()); } - if (!unpin_result) { - timer.LogResponse("error_code=", unpin_result.error()); - return tl::make_unexpected(unpin_result.error()); - } timer.LogResponse("error_code=", ErrorCode::OK); return {}; } From f063821b2cbe81d13008ad966d708028ca4c4115 Mon Sep 17 00:00:00 2001 From: shichangzhang064 Date: Sun, 10 May 2026 15:53:17 +0800 Subject: [PATCH 08/14] feat: add forward rdma rpc structure and types --- mooncake-store/include/client_rpc_service.h | 11 ++ mooncake-store/include/client_rpc_types.h | 44 +++++++ mooncake-store/include/data_manager.h | 25 +--- mooncake-store/include/p2p_rpc_types.h | 21 ++++ mooncake-store/include/peer_client.h | 19 +++ mooncake-store/include/rpc_types.h | 8 ++ mooncake-store/include/types.h | 25 ++++ mooncake-store/src/client_rpc_service.cpp | 129 ++++++++++++++++++++ mooncake-store/src/data_manager.cpp | 87 +++++++------ mooncake-store/src/peer_client.cpp | 118 ++++++++++++++++++ 10 files changed, 423 insertions(+), 64 deletions(-) diff --git a/mooncake-store/include/client_rpc_service.h b/mooncake-store/include/client_rpc_service.h index 288df6ee23..ea60fba240 100644 --- a/mooncake-store/include/client_rpc_service.h +++ b/mooncake-store/include/client_rpc_service.h @@ -55,6 +55,17 @@ class ClientRpcService { tl::expected WriteRemoteData( const RemoteWriteRequest& request); + tl::expected PreWrite( + const PreWriteRequest& request); + + tl::expected WriteCommit( + const WriteCommitRequest& request); + + tl::expected PinKey( + const PinKeyRequest& request); + + tl::expected UnPinKey(const UnPinKeyRequest& request); + private: DataManager& data_manager_; // Reference: owned by Client, same lifetime P2PClientMetric* metrics_; // Optional: owned by P2PClientService diff --git a/mooncake-store/include/client_rpc_types.h b/mooncake-store/include/client_rpc_types.h index a9f2263caf..1940167544 100644 --- a/mooncake-store/include/client_rpc_types.h +++ b/mooncake-store/include/client_rpc_types.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include "types.h" #include "ylt/struct_json/json_reader.h" @@ -84,4 +85,47 @@ struct BatchRemoteWriteRequest { YLT_REFL(BatchRemoteWriteRequest, keys, src_buffers_list, target_tier_ids); +struct PreWriteRequest { + std::string_view key; + uint64_t size_bytes = 0; + std::optional target_tier_id; +}; + +YLT_REFL(PreWriteRequest, key, size_bytes, target_tier_id); + +struct PreWriteResponse { + RemoteBufferDesc remote_buffer; + UUID pending_write_token; +}; + +YLT_REFL(PreWriteResponse, remote_buffer, pending_write_token); + +struct WriteCommitRequest { + std::string_view key; + UUID pending_write_token; +}; + +YLT_REFL(WriteCommitRequest, key, pending_write_token); + +struct PinKeyRequest { + std::string_view key; + std::optional target_tier_id; +}; + +YLT_REFL(PinKeyRequest, key, target_tier_id); + +struct PinKeyResponse { + RemoteBufferDesc remote_buffer; + UUID pin_token; +}; + +YLT_REFL(PinKeyResponse, remote_buffer, pin_token); + +struct UnPinKeyRequest { + std::string_view key; + UUID pin_token; +}; + +YLT_REFL(UnPinKeyRequest, key, pin_token); + } // namespace mooncake diff --git a/mooncake-store/include/data_manager.h b/mooncake-store/include/data_manager.h index 991d01f695..3a2ab92d3c 100644 --- a/mooncake-store/include/data_manager.h +++ b/mooncake-store/include/data_manager.h @@ -67,18 +67,6 @@ class DataManager { public: using TimePoint = std::chrono::time_point; - struct PreWriteResult { - RemoteBufferDesc remote_buffer; - uint64_t deadline_ms = 0; - UUID pending_write_token{0, 0}; - }; - - struct PinKeyResult { - RemoteBufferDesc remote_buffer; - uint64_t deadline_ms = 0; - UUID pin_token{0, 0}; - }; - /** * @brief Constructor * @param tiered_backend Unique pointer to TieredBackend instance (takes @@ -202,14 +190,14 @@ class DataManager { std::string_view key, const std::vector& src_buffers, std::optional tier_id = std::nullopt); - tl::expected PreWrite( + tl::expected PreWrite( std::string_view key, size_t size_bytes, std::optional tier_id = std::nullopt); tl::expected WriteCommit(std::string_view key, const UUID& pending_write_token); - tl::expected PinKey( + tl::expected PinKey( std::string_view key, std::optional tier_id = std::nullopt); tl::expected UnPinKey(std::string_view key, @@ -259,11 +247,12 @@ class DataManager { PendingWriteShard& GetPendingWriteShard(const KeyCtx& ctx); PinnedKeyShard& GetPinnedKeyShard(const KeyCtx& ctx); - tl::expected PreWriteInternal( - const KeyCtx& ctx, size_t size_bytes, std::optional tier_id); + tl::expected PreWriteInternal( + const KeyCtx& ctx, size_t size_bytes, std::optional tier_id, + bool enforce_dram_allocation); tl::expected WriteCommitInternal( const KeyCtx& ctx, const UUID& pending_write_token); - tl::expected PinKeyInternal( + tl::expected PinKeyInternal( const KeyCtx& ctx, std::optional tier_id); tl::expected UnPinKeyInternal(const KeyCtx& ctx, const UUID& pin_token); @@ -459,8 +448,6 @@ class DataManager { const std::chrono::milliseconds& lease_duration() const { return lease_duration_; } - uint64_t TimePointToDeadlineMs(TimePoint deadline) const; - TimePoint DeadlineMsToTimePoint(uint64_t deadline_ms) const; bool IsExpired(TimePoint deadline) const; RemoteBufferDesc BuildRemoteBufferDesc( const AllocationHandle& handle) const; diff --git a/mooncake-store/include/p2p_rpc_types.h b/mooncake-store/include/p2p_rpc_types.h index be03e2c741..43739c5fa7 100644 --- a/mooncake-store/include/p2p_rpc_types.h +++ b/mooncake-store/include/p2p_rpc_types.h @@ -2,6 +2,7 @@ #include #include +#include #include #include "replica.h" @@ -43,6 +44,26 @@ inline std::ostream& operator<<(std::ostream& os, return os; } +struct P2PWriteRouteConfig { + WriteRouteRequestConfig route_config; + std::optional rdma_direction_mode; +}; +YLT_REFL(P2PWriteRouteConfig, route_config, rdma_direction_mode); + +inline std::ostream& operator<<(std::ostream& os, + const P2PWriteRouteConfig& config) { + os << "P2PWriteRouteConfig: { route_config: [" << config.route_config + << "], rdma_direction_mode: "; + if (config.rdma_direction_mode.has_value()) { + os << *config.rdma_direction_mode; + } else { + // Unset: product rule is to resolve to REVERSE (see client default). + os << RdmaDirectionMode::REVERSE; + } + os << " }"; + return os; +} + /** * @brief Request structure for getting write route. */ diff --git a/mooncake-store/include/peer_client.h b/mooncake-store/include/peer_client.h index 068591c146..f70c1c6c97 100644 --- a/mooncake-store/include/peer_client.h +++ b/mooncake-store/include/peer_client.h @@ -27,11 +27,30 @@ class PeerClient { async_simple::coro::Lazy> AsyncWriteRemoteData(const RemoteWriteRequest& request); + async_simple::coro::Lazy> + AsyncPreWrite(const PreWriteRequest& request); + + async_simple::coro::Lazy> AsyncWriteCommit( + const WriteCommitRequest& request); + + async_simple::coro::Lazy> + AsyncPinKey(const PinKeyRequest& request); + + async_simple::coro::Lazy> AsyncUnPinKey( + const UnPinKeyRequest& request); + // --- Sync single-key interfaces --- tl::expected ReadRemoteData( const RemoteReadRequest& request); tl::expected WriteRemoteData( const RemoteWriteRequest& request); + tl::expected PreWrite( + const PreWriteRequest& request); + tl::expected WriteCommit( + const WriteCommitRequest& request); + tl::expected PinKey( + const PinKeyRequest& request); + tl::expected UnPinKey(const UnPinKeyRequest& request); private: std::shared_ptr> diff --git a/mooncake-store/include/rpc_types.h b/mooncake-store/include/rpc_types.h index 09067217b7..4491c09754 100644 --- a/mooncake-store/include/rpc_types.h +++ b/mooncake-store/include/rpc_types.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "types.h" #include "replica.h" #include "heartbeat_type.h" @@ -36,6 +38,12 @@ YLT_REFL(GetReplicaListRequestConfig, max_candidates, p2p_config); typedef GetReplicaListRequestConfig ReadRouteConfig; typedef P2PGetReplicaListConfigExtra P2PReadRouteConfigExtra; +struct P2PReadRouteConfig { + ReadRouteConfig route_config; + std::optional rdma_direction_mode; +}; +YLT_REFL(P2PReadRouteConfig, route_config, rdma_direction_mode); + /** * @brief Extra info for centralized read route response (Internal use) */ diff --git a/mooncake-store/include/types.h b/mooncake-store/include/types.h index e812954c70..10f56ac448 100644 --- a/mooncake-store/include/types.h +++ b/mooncake-store/include/types.h @@ -467,6 +467,31 @@ inline std::ostream& operator<<( return os; } +// Who initiates TE/RDMA for the data plane: REVERSE matches the historical +// target-initiated path and is the conventional default when unset optional or +// client-level config omits an explicit override. +enum class RdmaDirectionMode : uint8_t { + REVERSE = 0, + FORWARD = 1, +}; + +// Logging only: prints REVERSE / FORWARD / UNKNOWN for invalid numeric values. +inline std::ostream& operator<<(std::ostream& os, + const RdmaDirectionMode& mode) noexcept { + switch (mode) { + case RdmaDirectionMode::REVERSE: + os << "REVERSE"; + break; + case RdmaDirectionMode::FORWARD: + os << "FORWARD"; + break; + default: + os << "UNKNOWN"; + break; + } + return os; +} + } // namespace mooncake namespace std { diff --git a/mooncake-store/src/client_rpc_service.cpp b/mooncake-store/src/client_rpc_service.cpp index d3a62fd5f0..d186ac39f6 100644 --- a/mooncake-store/src/client_rpc_service.cpp +++ b/mooncake-store/src/client_rpc_service.cpp @@ -1,4 +1,7 @@ #include "client_rpc_service.h" + +#include + #include #include #include "utils/scoped_vlog_timer.h" @@ -51,6 +54,38 @@ bool IsValidRequest(const RemoteWriteRequest& request) { return true; } +bool IsValidRequest(const PreWriteRequest& request) { + if (request.key.empty() || request.size_bytes == 0) { + LOG(ERROR) << "PreWriteRequest: invalid key or size"; + return false; + } + return true; +} + +bool IsValidRequest(const WriteCommitRequest& request) { + if (request.key.empty() || IsZeroUuid(request.pending_write_token)) { + LOG(ERROR) << "WriteCommitRequest: invalid key or token"; + return false; + } + return true; +} + +bool IsValidRequest(const PinKeyRequest& request) { + if (request.key.empty()) { + LOG(ERROR) << "PinKeyRequest: empty key"; + return false; + } + return true; +} + +bool IsValidRequest(const UnPinKeyRequest& request) { + if (request.key.empty() || IsZeroUuid(request.pin_token)) { + LOG(ERROR) << "UnPinKeyRequest: invalid key or token"; + return false; + } + return true; +} + } // anonymous namespace ClientRpcService::ClientRpcService(DataManager& data_manager, @@ -157,10 +192,104 @@ tl::expected ClientRpcService::WriteRemoteData( return result; } +tl::expected ClientRpcService::PreWrite( + const PreWriteRequest& request) { + ScopedVLogTimer timer(1, "ClientRpcService::PreWrite"); + timer.LogRequest("key=", request.key, "size_bytes=", request.size_bytes); + + if (!IsValidRequest(request)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + auto result = data_manager_.PreWrite(request.key, request.size_bytes, + request.target_tier_id); + if (!result) { + LOG(ERROR) << "PreWrite failed for key: " << request.key + << ", error: " << toString(result.error()); + timer.LogResponse("error_code=", result.error()); + return tl::make_unexpected(result.error()); + } + + timer.LogResponse("error_code=", ErrorCode::OK); + return std::move(*result); +} + +tl::expected ClientRpcService::WriteCommit( + const WriteCommitRequest& request) { + ScopedVLogTimer timer(1, "ClientRpcService::WriteCommit"); + timer.LogRequest("key=", request.key); + + if (!IsValidRequest(request)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + auto result = + data_manager_.WriteCommit(request.key, request.pending_write_token); + if (!result) { + LOG(ERROR) << "WriteCommit failed for key: " << request.key + << ", error: " << toString(result.error()); + timer.LogResponse("error_code=", result.error()); + return result; + } + + timer.LogResponse("error_code=", ErrorCode::OK); + return {}; +} + +tl::expected ClientRpcService::PinKey( + const PinKeyRequest& request) { + ScopedVLogTimer timer(1, "ClientRpcService::PinKey"); + timer.LogRequest("key=", request.key); + + if (!IsValidRequest(request)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + auto result = data_manager_.PinKey(request.key, request.target_tier_id); + if (!result) { + LOG(ERROR) << "PinKey failed for key: " << request.key + << ", error: " << toString(result.error()); + timer.LogResponse("error_code=", result.error()); + return tl::make_unexpected(result.error()); + } + + timer.LogResponse("error_code=", ErrorCode::OK); + return std::move(*result); +} + +tl::expected ClientRpcService::UnPinKey( + const UnPinKeyRequest& request) { + ScopedVLogTimer timer(1, "ClientRpcService::UnPinKey"); + timer.LogRequest("key=", request.key); + + if (!IsValidRequest(request)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + auto result = data_manager_.UnPinKey(request.key, request.pin_token); + if (!result) { + LOG(ERROR) << "UnPinKey failed for key: " << request.key + << ", error: " << toString(result.error()); + timer.LogResponse("error_code=", result.error()); + return result; + } + + timer.LogResponse("error_code=", ErrorCode::OK); + return {}; +} + void RegisterClientRpcService(coro_rpc::coro_rpc_server& server, ClientRpcService& service) { server.register_handler<&ClientRpcService::ReadRemoteData>(&service); server.register_handler<&ClientRpcService::WriteRemoteData>(&service); + server.register_handler<&ClientRpcService::PreWrite>(&service); + server.register_handler<&ClientRpcService::WriteCommit>(&service); + server.register_handler<&ClientRpcService::PinKey>(&service); + server.register_handler<&ClientRpcService::UnPinKey>(&service); } } // namespace mooncake diff --git a/mooncake-store/src/data_manager.cpp b/mooncake-store/src/data_manager.cpp index 068f37c85b..a41830d21f 100644 --- a/mooncake-store/src/data_manager.cpp +++ b/mooncake-store/src/data_manager.cpp @@ -231,31 +231,6 @@ DataManager::PinnedKeyShard& DataManager::GetPinnedKeyShard( return pinned_key_shards_[HashKey(key) % pinned_key_shards_.size()]; } -uint64_t DataManager::TimePointToDeadlineMs(TimePoint deadline) const { - const auto remaining = - std::max(std::chrono::milliseconds::zero(), - std::chrono::duration_cast( - deadline - std::chrono::steady_clock::now())); - const auto system_deadline = std::chrono::system_clock::now() + remaining; - return static_cast( - std::chrono::duration_cast( - system_deadline.time_since_epoch()) - .count()); -} - -DataManager::TimePoint DataManager::DeadlineMsToTimePoint( - uint64_t deadline_ms) const { - const auto system_deadline = std::chrono::system_clock::time_point( - std::chrono::milliseconds(static_cast(deadline_ms))); - const auto remaining = - std::chrono::duration_cast( - system_deadline - std::chrono::system_clock::now()); - if (remaining <= std::chrono::milliseconds::zero()) { - return std::chrono::steady_clock::now(); - } - return std::chrono::steady_clock::now() + remaining; -} - bool DataManager::IsExpired(TimePoint deadline) const { return deadline <= std::chrono::steady_clock::now(); } @@ -518,7 +493,9 @@ DataManager::PutViaTe(std::string_view key, std::vector& slices) { return tl::unexpected(validate_result.error()); } - auto prewrite_result = PreWriteInternal(kctx, total_size, std::nullopt); + // Local Put: allocation follows tier backend policy (not restricted to DRAM). + auto prewrite_result = + PreWriteInternal(kctx, total_size, std::nullopt, false); if (!prewrite_result) { return tl::unexpected(prewrite_result.error()); } @@ -593,7 +570,9 @@ DataManager::PutViaMemcpy(std::string_view key, std::vector& slices) { const KeyCtx kctx = BuildKeyCtx(key); Slice slice = slices[0]; - auto prewrite_result = PreWriteInternal(kctx, slice.size, std::nullopt); + // Same allocation policy as PutViaTe. + auto prewrite_result = + PreWriteInternal(kctx, slice.size, std::nullopt, false); if (!prewrite_result) { return tl::unexpected(prewrite_result.error()); } @@ -886,8 +865,9 @@ tl::expected DataManager::WriteRemoteData( for (const auto& buf : src_buffers) total_size += buf.size; // Reverse RDMA path: still one RPC, but internally use the 3-phase write - // model (PreWrite -> transfer -> WriteCommit). - auto prewrite_result = PreWriteInternal(kctx, total_size, tier_id); + // model (PreWrite -> transfer -> WriteCommit). Target tier may be non-DRAM. + auto prewrite_result = + PreWriteInternal(kctx, total_size, tier_id, false); if (!prewrite_result) { timer.LogResponse("error_code=", prewrite_result.error()); return tl::make_unexpected(prewrite_result.error()); @@ -923,14 +903,17 @@ tl::expected DataManager::WriteRemoteData( return result_tier_id; } -tl::expected DataManager::PreWrite( +tl::expected DataManager::PreWrite( std::string_view key, size_t size_bytes, std::optional tier_id) { - return PreWriteInternal(BuildKeyCtx(key), size_bytes, tier_id); + // RPC PreWrite: forward path is wired for DRAM only for now; other tiers + // TODO (staging / TE registration). + return PreWriteInternal(BuildKeyCtx(key), size_bytes, tier_id, true); } -tl::expected +tl::expected DataManager::PreWriteInternal(const KeyCtx& ctx, size_t size_bytes, - std::optional tier_id) { + std::optional tier_id, + bool enforce_dram_allocation) { ScopedVLogTimer timer(1, "DataManager::PreWrite"); timer.LogRequest("key=", ctx.key, "size_bytes=", size_bytes); @@ -963,6 +946,14 @@ DataManager::PreWriteInternal(const KeyCtx& ctx, size_t size_bytes, } auto handle = std::move(handle_result.value()); + // When enforce_dram_allocation is true (RPC PreWrite): only DRAM is wired + // for forward TE today; non-DRAM tiers TODO. Local Put / WriteRemoteData use + // false and skip this check. + if (enforce_dram_allocation && + handle->loc.data.type != MemoryType::DRAM) { + timer.LogResponse("error_code=", ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); + return tl::make_unexpected(ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); + } auto list_it = shard.ordered_list.emplace(shard.ordered_list.end(), ctx.key_string, deadline); @@ -974,12 +965,10 @@ DataManager::PreWriteInternal(const KeyCtx& ctx, size_t size_bytes, record.list_it = list_it; shard.by_key.insert_or_assign(ctx.key_string, std::move(record)); - PreWriteResult result; + PreWriteResponse result; result.remote_buffer = BuildRemoteBufferDesc(handle); - result.deadline_ms = TimePointToDeadlineMs(deadline); result.pending_write_token = pending_write_token; - timer.LogResponse("error_code=", ErrorCode::OK, - "deadline_ms=", result.deadline_ms); + timer.LogResponse("error_code=", ErrorCode::OK); return result; } @@ -1045,12 +1034,12 @@ tl::expected DataManager::WriteCommitInternal( return {}; } -tl::expected DataManager::PinKey( +tl::expected DataManager::PinKey( std::string_view key, std::optional tier_id) { return PinKeyInternal(BuildKeyCtx(key), tier_id); } -tl::expected DataManager::PinKeyInternal( +tl::expected DataManager::PinKeyInternal( const KeyCtx& ctx, std::optional tier_id) { ScopedVLogTimer timer(1, "DataManager::PinKey"); timer.LogRequest("key=", ctx.key); @@ -1070,14 +1059,19 @@ tl::expected DataManager::PinKeyInternal( RemoveExpiredPinnedKeyLocked(shard, ctx.key_string, now); auto record_it = shard.by_key.find(ctx.key_string); if (record_it != shard.by_key.end()) { + // PinKey forward path: DRAM-only for now; non-DRAM replica handling TODO. + if (record_it->second.handle->loc.data.type != MemoryType::DRAM) { + timer.LogResponse("error_code=", + ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); + return tl::make_unexpected(ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); + } record_it->second.ref_count++; record_it->second.deadline = deadline; TouchOrderedDeadlineNode(shard.ordered_list, record_it->second.list_it, ctx.key_string, deadline); - PinKeyResult result; + PinKeyResponse result; result.remote_buffer = BuildRemoteBufferDesc(record_it->second.handle); - result.deadline_ms = TimePointToDeadlineMs(deadline); result.pin_token = record_it->second.pin_token; timer.LogResponse("error_code=", ErrorCode::OK, "ref_count=", record_it->second.ref_count); @@ -1091,6 +1085,11 @@ tl::expected DataManager::PinKeyInternal( } auto handle = std::move(handle_result.value()); + // PinKey forward path: DRAM-only for now; non-DRAM replica handling TODO. + if (handle->loc.data.type != MemoryType::DRAM) { + timer.LogResponse("error_code=", ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); + return tl::make_unexpected(ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); + } auto list_it = shard.ordered_list.emplace(shard.ordered_list.end(), ctx.key_string, deadline); @@ -1103,12 +1102,10 @@ tl::expected DataManager::PinKeyInternal( record.list_it = list_it; shard.by_key.insert_or_assign(ctx.key_string, std::move(record)); - PinKeyResult result; + PinKeyResponse result; result.remote_buffer = BuildRemoteBufferDesc(handle); - result.deadline_ms = TimePointToDeadlineMs(deadline); result.pin_token = pin_token_value; - timer.LogResponse("error_code=", ErrorCode::OK, - "deadline_ms=", result.deadline_ms); + timer.LogResponse("error_code=", ErrorCode::OK); return result; } diff --git a/mooncake-store/src/peer_client.cpp b/mooncake-store/src/peer_client.cpp index ad44b78b0c..feb37bbe77 100644 --- a/mooncake-store/src/peer_client.cpp +++ b/mooncake-store/src/peer_client.cpp @@ -78,6 +78,104 @@ PeerClient::AsyncWriteRemoteData(const RemoteWriteRequest& request) { co_return result->result(); } +async_simple::coro::Lazy> +PeerClient::AsyncPreWrite(const PreWriteRequest& request) { + if (!client_pool_) { + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto ret = co_await client_pool_->send_request( + [&](coro_io::client_reuse_hint, coro_rpc::coro_rpc_client& client) { + return client.send_request<&ClientRpcService::PreWrite>(request); + }); + if (!ret.has_value()) { + LOG(ERROR) << "AsyncPreWrite: client not available"; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto result = co_await std::move(ret.value()); + if (!result) { + LOG(ERROR) << "AsyncPreWrite: RPC call failed: " << result.error().msg; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + co_return result->result(); +} + +async_simple::coro::Lazy> +PeerClient::AsyncWriteCommit(const WriteCommitRequest& request) { + if (!client_pool_) { + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto ret = co_await client_pool_->send_request( + [&](coro_io::client_reuse_hint, coro_rpc::coro_rpc_client& client) { + return client.send_request<&ClientRpcService::WriteCommit>(request); + }); + if (!ret.has_value()) { + LOG(ERROR) << "AsyncWriteCommit: client not available"; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto result = co_await std::move(ret.value()); + if (!result) { + LOG(ERROR) << "AsyncWriteCommit: RPC call failed: " + << result.error().msg; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + co_return result->result(); +} + +async_simple::coro::Lazy> +PeerClient::AsyncPinKey(const PinKeyRequest& request) { + if (!client_pool_) { + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto ret = co_await client_pool_->send_request( + [&](coro_io::client_reuse_hint, coro_rpc::coro_rpc_client& client) { + return client.send_request<&ClientRpcService::PinKey>(request); + }); + if (!ret.has_value()) { + LOG(ERROR) << "AsyncPinKey: client not available"; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto result = co_await std::move(ret.value()); + if (!result) { + LOG(ERROR) << "AsyncPinKey: RPC call failed: " << result.error().msg; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + co_return result->result(); +} + +async_simple::coro::Lazy> +PeerClient::AsyncUnPinKey(const UnPinKeyRequest& request) { + if (!client_pool_) { + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto ret = co_await client_pool_->send_request( + [&](coro_io::client_reuse_hint, coro_rpc::coro_rpc_client& client) { + return client.send_request<&ClientRpcService::UnPinKey>(request); + }); + if (!ret.has_value()) { + LOG(ERROR) << "AsyncUnPinKey: client not available"; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto result = co_await std::move(ret.value()); + if (!result) { + LOG(ERROR) << "AsyncUnPinKey: RPC call failed: " + << result.error().msg; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + co_return result->result(); +} + tl::expected PeerClient::ReadRemoteData( const RemoteReadRequest& request) { return async_simple::coro::syncAwait(AsyncReadRemoteData(request)); @@ -88,4 +186,24 @@ tl::expected PeerClient::WriteRemoteData( return async_simple::coro::syncAwait(AsyncWriteRemoteData(request)); } +tl::expected PeerClient::PreWrite( + const PreWriteRequest& request) { + return async_simple::coro::syncAwait(AsyncPreWrite(request)); +} + +tl::expected PeerClient::WriteCommit( + const WriteCommitRequest& request) { + return async_simple::coro::syncAwait(AsyncWriteCommit(request)); +} + +tl::expected PeerClient::PinKey( + const PinKeyRequest& request) { + return async_simple::coro::syncAwait(AsyncPinKey(request)); +} + +tl::expected PeerClient::UnPinKey( + const UnPinKeyRequest& request) { + return async_simple::coro::syncAwait(AsyncUnPinKey(request)); +} + } // namespace mooncake From 4ddd121d752745aa8cb4a2dff6743fbee1e8c96a Mon Sep 17 00:00:00 2001 From: shichangzhang064 Date: Mon, 11 May 2026 20:45:58 +0800 Subject: [PATCH 09/14] feat: add forward transfer config and p2p client service call --- mooncake-integration/store/store_py.cpp | 84 +++-- .../include/centralized_client_service.h | 12 +- mooncake-store/include/client_service.h | 16 +- mooncake-store/include/data_manager.h | 20 + mooncake-store/include/dummy_client.h | 12 +- mooncake-store/include/p2p_client_service.h | 91 +++-- mooncake-store/include/p2p_rpc_types.h | 25 +- mooncake-store/include/pyclient.h | 12 +- mooncake-store/include/real_client.h | 26 +- mooncake-store/include/rpc_types.h | 15 +- .../src/centralized_client_service.cpp | 16 +- mooncake-store/src/data_manager.cpp | 64 +++- mooncake-store/src/dummy_client.cpp | 12 +- mooncake-store/src/p2p_client_service.cpp | 355 ++++++++++++++++-- mooncake-store/src/real_client.cpp | 26 +- 15 files changed, 609 insertions(+), 177 deletions(-) diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index c5494c807d..39cb86823a 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -5,6 +5,8 @@ #include "pyclient.h" #include "dummy_client.h" #include "real_client.h" +#include "p2p_rpc_types.h" +#include "rpc_types.h" #include // for atexit #include @@ -231,8 +233,8 @@ class MooncakeStorePyWrapper { pybind11::bytes get( const std::string& key, - const std::optional& config_opt = std::nullopt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt = std::nullopt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); if (!is_client_initialized()) { LOG(ERROR) << "Client is not initialized"; return pybind11::bytes("\\0", 0); @@ -268,8 +270,8 @@ class MooncakeStorePyWrapper { std::vector get_batch( const std::vector& keys, - const std::optional& config_opt = std::nullopt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt = std::nullopt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); const auto kNullString = pybind11::bytes("\\0", 0); if (!is_client_initialized()) { LOG(ERROR) << "Client is not initialized"; @@ -302,7 +304,7 @@ class MooncakeStorePyWrapper { pybind11::object get_tensor_with_tp( const std::string& key, int tp_rank = 0, int tp_size = 1, int split_dim = 0, - const std::optional& config_opt = std::nullopt) { + const std::optional& config_opt = std::nullopt) { if (tp_size <= 1) return get_tensor(key, config_opt); return get_tensor(get_tp_key_name(key, tp_rank), config_opt); } @@ -310,7 +312,7 @@ class MooncakeStorePyWrapper { pybind11::list batch_get_tensor_with_tp( const std::vector& base_keys, int tp_rank = 0, int tp_size = 1, - const std::optional& config_opt = std::nullopt) { + const std::optional& config_opt = std::nullopt) { if (tp_size <= 1) return batch_get_tensor(base_keys, config_opt); std::vector shard_keys; @@ -323,8 +325,8 @@ class MooncakeStorePyWrapper { pybind11::object get_tensor( const std::string& key, - const std::optional& config_opt = std::nullopt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt = std::nullopt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); if (!is_client_initialized() || use_dummy_client_) { LOG(ERROR) << "Client not initialized or Dummy client not " "supported for tensors"; @@ -342,8 +344,8 @@ class MooncakeStorePyWrapper { pybind11::list batch_get_tensor( const std::vector& keys, - const std::optional& config_opt = std::nullopt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt = std::nullopt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); if (!is_client_initialized() || use_dummy_client_) { LOG(ERROR) << "Client not initialized or Dummy client not " "supported for tensors"; @@ -367,8 +369,8 @@ class MooncakeStorePyWrapper { pybind11::object get_tensor_into( const std::string& key, uintptr_t buffer_ptr, size_t size, - const std::optional& config_opt = std::nullopt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt = std::nullopt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); char* buffer = reinterpret_cast(buffer_ptr); if (!is_client_initialized()) { LOG(ERROR) << "Client is not initialized"; @@ -396,8 +398,8 @@ class MooncakeStorePyWrapper { const std::vector& keys, const std::vector& buffer_ptrs, const std::vector& sizes, - const std::optional& config_opt = std::nullopt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt = std::nullopt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); std::vector buffers; buffers.reserve(buffer_ptrs.size()); for (uintptr_t ptr : buffer_ptrs) { @@ -460,7 +462,7 @@ class MooncakeStorePyWrapper { pybind11::object get_tensor_with_tp_into( const std::string& key, uintptr_t buffer_ptr, size_t size, int tp_rank = 0, int tp_size = 1, int split_dim = 0, - const std::optional& config_opt = std::nullopt) { + const std::optional& config_opt = std::nullopt) { if (!is_client_initialized()) { LOG(ERROR) << "Client is not initialized"; return pybind11::none(); @@ -487,7 +489,7 @@ class MooncakeStorePyWrapper { const std::vector& base_keys, const std::vector& buffer_ptrs, const std::vector& sizes, int tp_rank = 0, int tp_size = 1, - const std::optional& config_opt = std::nullopt) { + const std::optional& config_opt = std::nullopt) { if (!is_client_initialized()) { LOG(ERROR) << "Client is not initialized"; py::list empty_list; @@ -877,6 +879,19 @@ PYBIND11_MODULE(store, m) { .def_readwrite("max_candidates", &ReadRouteConfig::max_candidates) .def_readwrite("p2p_config", &ReadRouteConfig::p2p_config); + py::enum_(m, "RdmaDirectionMode") + .value("REVERSE", RdmaDirectionMode::REVERSE) + .value("FORWARD", RdmaDirectionMode::FORWARD); + + py::class_(m, "ReadConfigExt") + .def(py::init<>()) + .def(py::init()) + .def_readwrite("route_config", &ReadConfigExt::route_config) + .def_readwrite("rdma_direction_mode", + &ReadConfigExt::rdma_direction_mode); + + py::implicitly_convertible(); + py::class_(m, "WriteRouteRequestConfig") .def(py::init<>()) // Default constructor .def_readwrite("max_candidates", @@ -894,6 +909,13 @@ PYBIND11_MODULE(store, m) { return oss.str(); }); + py::class_(m, "WriteConfigExt") + .def(py::init<>()) + .def(py::init()) + .def_readwrite("route_config", &WriteConfigExt::route_config) + .def_readwrite("rdma_direction_mode", + &WriteConfigExt::rdma_direction_mode); + py::enum_(m, "ReplicaStatus") .value("UNDEFINED", ReplicaStatus::UNDEFINED) .value("INITIALIZED", ReplicaStatus::INITIALIZED) @@ -1152,8 +1174,8 @@ PYBIND11_MODULE(store, m) { .def( "get_buffer", [](MooncakeStorePyWrapper& self, const std::string& key, - const std::optional& config_opt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); py::gil_scoped_release release; return self.store_->get_buffer(key, config); }, @@ -1163,8 +1185,8 @@ PYBIND11_MODULE(store, m) { "batch_get_buffer", [](MooncakeStorePyWrapper& self, const std::vector& keys, - const std::optional& config_opt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); py::gil_scoped_release release; if (self.use_dummy_client_) { LOG(ERROR) << "batch_get_buffer is not supported for dummy " @@ -1234,7 +1256,10 @@ PYBIND11_MODULE(store, m) { " tp_size: The total tensor parallel size (default 1).\n" " split_dim: The dimension to split the tensor along " "(default 0).\n" - " config: ReadRouteConfig.") + " config: ReadConfigExt (optional; omit for defaults). " + "ReadRouteConfig is accepted and treated as ReadConfigExt with " + "default RdmaDirectionMode.REVERSE. Set config.rdma_direction_mode " + "to FORWARD for forward RDMA read.") .def("batch_get_tensor_with_tp", &MooncakeStorePyWrapper::batch_get_tensor_with_tp, py::arg("base_keys"), py::arg("tp_rank") = 0, @@ -1296,7 +1321,10 @@ PYBIND11_MODULE(store, m) { " tp_size: The total tensor parallel size (default 1).\n" " split_dim: The dimension to split the tensor along" "(default 0).\n" - " config: ReadRouteConfig.") + " config: ReadConfigExt (optional; omit for defaults). " + "ReadRouteConfig is accepted and treated as ReadConfigExt with " + "default RdmaDirectionMode.REVERSE. Set config.rdma_direction_mode " + "to FORWARD for forward RDMA read.") .def( "batch_get_tensor_with_tp_into", &MooncakeStorePyWrapper::batch_get_tensor_with_tp_into, @@ -1331,8 +1359,8 @@ PYBIND11_MODULE(store, m) { "get_into", [](MooncakeStorePyWrapper& self, const std::string& key, uintptr_t buffer_ptr, size_t size, - const std::optional& config_opt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); // Get data directly into user-provided buffer void* buffer = reinterpret_cast(buffer_ptr); py::gil_scoped_release release; @@ -1352,8 +1380,8 @@ PYBIND11_MODULE(store, m) { const std::vector& keys, const std::vector& buffer_ptrs, const std::vector& sizes, - const std::optional& config_opt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); std::vector buffers; buffers.reserve(buffer_ptrs.size()); for (uintptr_t ptr : buffer_ptrs) { @@ -1543,9 +1571,9 @@ PYBIND11_MODULE(store, m) { const std::vector>& all_buffer_ptrs, const std::vector>& all_sizes, bool aggregate_same_segment_task = false, - const std::optional& config_opt = + const std::optional& config_opt = std::nullopt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); py::gil_scoped_release release; if (self.use_dummy_client_) { LOG(ERROR) diff --git a/mooncake-store/include/centralized_client_service.h b/mooncake-store/include/centralized_client_service.h index 6df407acc0..a840959732 100644 --- a/mooncake-store/include/centralized_client_service.h +++ b/mooncake-store/include/centralized_client_service.h @@ -54,11 +54,11 @@ class CentralizedClientService tl::expected, ErrorCode> Query( const std::string& object_key, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; std::vector, ErrorCode>> BatchQuery(const std::vector& object_keys, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; tl::expected IsExist(const std::string& key) override; @@ -76,24 +76,24 @@ class CentralizedClientService tl::expected Get( const std::string& key, const std::vector& buffers, const std::vector& sizes, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; std::vector> BatchGet( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - const ReadRouteConfig& config = {}, + const ReadConfigExt& config = {}, bool aggregate_same_segment_task = false) override; tl::expected, ErrorCode> Get( const std::string& key, std::shared_ptr allocator, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; std::vector, ErrorCode>> BatchGet(const std::vector& keys, std::shared_ptr allocator, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; tl::expected Put(const ObjectKey& key, std::vector& slices, diff --git a/mooncake-store/include/client_service.h b/mooncake-store/include/client_service.h index 955cdc39cb..53567571fe 100644 --- a/mooncake-store/include/client_service.h +++ b/mooncake-store/include/client_service.h @@ -18,6 +18,7 @@ #include "transfer_engine.h" #include "types.h" #include "p2p_rpc_types.h" +#include "rpc_types.h" #include "replica.h" #include "master_client.h" #include @@ -27,7 +28,8 @@ namespace mooncake { -using WriteConfig = std::variant; +using WriteConfig = + std::variant; /** * @brief Result of a query operation containing replica information @@ -117,7 +119,7 @@ class ClientService { * indicating failure */ virtual tl::expected, ErrorCode> Query( - const std::string& object_key, const ReadRouteConfig& config = {}) = 0; + const std::string& object_key, const ReadConfigExt& config = {}) = 0; /** * @brief Batch query object metadata without transferring data @@ -126,7 +128,7 @@ class ClientService { */ virtual std::vector, ErrorCode>> BatchQuery(const std::vector& object_keys, - const ReadRouteConfig& config = {}) = 0; + const ReadConfigExt& config = {}) = 0; /** * @brief Gets data with memory allocation @@ -139,12 +141,12 @@ class ClientService { virtual tl::expected, ErrorCode> Get( const std::string& key, std::shared_ptr allocator, - const ReadRouteConfig& config = {}) = 0; + const ReadConfigExt& config = {}) = 0; virtual std::vector, ErrorCode>> BatchGet(const std::vector& keys, std::shared_ptr allocator, - const ReadRouteConfig& config = {}) = 0; + const ReadConfigExt& config = {}) = 0; /** * @brief Gets data into user-provided buffers without memory allocation @@ -157,7 +159,7 @@ class ClientService { virtual tl::expected Get( const std::string& key, const std::vector& buffers, const std::vector& sizes, - const ReadRouteConfig& config = {}) = 0; + const ReadConfigExt& config = {}) = 0; /** * @brief Batch get data into user-provided buffers @@ -175,7 +177,7 @@ class ClientService { const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - const ReadRouteConfig& config = {}, + const ReadConfigExt& config = {}, bool aggregate_same_segment_task = false) = 0; /** diff --git a/mooncake-store/include/data_manager.h b/mooncake-store/include/data_manager.h index 3a2ab92d3c..b37c992a91 100644 --- a/mooncake-store/include/data_manager.h +++ b/mooncake-store/include/data_manager.h @@ -203,6 +203,20 @@ class DataManager { tl::expected UnPinKey(std::string_view key, const UUID& pin_token); + /** + * @brief TE transfer without tier DRAM staging (PrepareDRAM*). + * + * Caller guarantees `local_transfer_base` covers a contiguous layout of + * `total_size` bytes that is valid for TransferEngine (typically registered + * DRAM). Used by forward RDMA paths where buffers are already TE-ready. + * + * @param opcode WRITE: local -> peer_buffers; READ: peer_buffers -> local + */ + tl::expected TransferWithTeNoTierStaging( + void* local_transfer_base, size_t total_size, + const std::vector& peer_buffers, + Transport::TransferRequest::OpCode opcode); + // ================================================================ // Utilities // ================================================================ @@ -329,6 +343,12 @@ class DataManager { const std::vector& remote_buffers, Transport::TransferRequest::OpCode opcode); + tl::expected>, + ErrorCode> + SubmitTeTransferBatches(void* transfer_ptr, size_t total_data_size, + const std::vector& remote_buffers, + Transport::TransferRequest::OpCode opcode); + /** * @brief Helper to wait for a transfer batch to complete * @param batch_id Batch ID to poll diff --git a/mooncake-store/include/dummy_client.h b/mooncake-store/include/dummy_client.h index c7efd5330d..eadf884054 100644 --- a/mooncake-store/include/dummy_client.h +++ b/mooncake-store/include/dummy_client.h @@ -75,19 +75,19 @@ class DummyClient : public PyClient { int unregister_buffer(void* buffer) override; int64_t get_into(const std::string& key, void* buffer, size_t size, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; std::vector batch_get_into( const std::vector& keys, const std::vector& buffers, const std::vector& sizes, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; std::vector batch_get_into_multi_buffers( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, bool aggregate_same_segment_task, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; int put_from(const std::string& key, void* buffer, size_t size, const WriteConfig& config) override; @@ -109,14 +109,14 @@ class DummyClient : public PyClient { const WriteConfig& config) override; std::shared_ptr get_buffer( - const std::string& key, const ReadRouteConfig& config = {}) override; + const std::string& key, const ReadConfigExt& config = {}) override; std::tuple get_buffer_info( - const std::string& key, const ReadRouteConfig& config = {}) override; + const std::string& key, const ReadConfigExt& config = {}) override; std::vector> batch_get_buffer( const std::vector& keys, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; int put_parts(const std::string& key, std::vector> values, diff --git a/mooncake-store/include/p2p_client_service.h b/mooncake-store/include/p2p_client_service.h index 5571c1a251..38fba02261 100644 --- a/mooncake-store/include/p2p_client_service.h +++ b/mooncake-store/include/p2p_client_service.h @@ -91,7 +91,7 @@ class P2PClientService final : public ClientService { */ tl::expected, ErrorCode> Query( const std::string& object_key, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; /** * @brief Batch query object metadata without transferring data @@ -100,7 +100,7 @@ class P2PClientService final : public ClientService { */ std::vector, ErrorCode>> BatchQuery(const std::vector& object_keys, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; tl::expected IsExist(const std::string& key) override; @@ -114,23 +114,23 @@ class P2PClientService final : public ClientService { tl::expected, ErrorCode> Get( const std::string& key, std::shared_ptr allocator, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; std::vector, ErrorCode>> BatchGet(const std::vector& keys, std::shared_ptr allocator, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; tl::expected Get( const std::string& key, const std::vector& buffers, const std::vector& sizes, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; std::vector> BatchGet( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - const ReadRouteConfig& config = {}, + const ReadConfigExt& config = {}, bool aggregate_same_segment_task = false) override; /** @@ -245,7 +245,8 @@ class P2PClientService final : public ClientService { std::vector> InnerBatchPut( const std::vector& keys, std::vector>& batched_slices, - const WriteRouteRequestConfig& route_config); + const WriteRouteRequestConfig& route_config, + RdmaDirectionMode rdma_direction_mode); std::vector> InnerBatchPutDegraded( const std::vector& keys, @@ -254,12 +255,14 @@ class P2PClientService final : public ClientService { std::vector> InnerBatchPutNormal( const std::vector& keys, std::vector>& batched_slices, - const WriteRouteRequestConfig& route_config); + const WriteRouteRequestConfig& route_config, + RdmaDirectionMode rdma_direction_mode); std::vector>, ErrorCode>> CreatePutHandlesFromRoute(const std::vector& keys, std::vector>& batched_slices, const WriteRouteRequestConfig& route_config, + RdmaDirectionMode rdma_direction_mode, BatchGetWriteRouteResponse& batch_resp); tl::expected>, ErrorCode> @@ -295,19 +298,30 @@ class P2PClientService final : public ClientService { }; struct RemoteWriteOp : WriteOp { + P2PClientService* owner_service = nullptr; PeerClient* peer_ptr; std::shared_ptr write_req; P2PProxyDescriptor proxy; RouteCache* route_cache; std::string endpoint; - - RemoteWriteOp(PeerClient* p, std::shared_ptr wr, - P2PProxyDescriptor px, RouteCache* rc, std::string ep) - : peer_ptr(p), + DataManager* forward_dm = nullptr; + std::vector* forward_slices = nullptr; + RdmaDirectionMode rdma_direction_mode = RdmaDirectionMode::REVERSE; + + RemoteWriteOp(P2PClientService* owner, PeerClient* p, + std::shared_ptr wr, + P2PProxyDescriptor px, RouteCache* rc, std::string ep, + DataManager* fwd_dm, std::vector* fwd_slices, + RdmaDirectionMode rdma_mode) + : owner_service(owner), + peer_ptr(p), write_req(std::move(wr)), proxy(std::move(px)), route_cache(rc), - endpoint(std::move(ep)) {} + endpoint(std::move(ep)), + forward_dm(fwd_dm), + forward_slices(fwd_slices), + rdma_direction_mode(rdma_mode) {} std::string_view route() const override { return endpoint; } std::unique_ptr> Dispatch() override; @@ -316,6 +330,7 @@ class P2PClientService final : public ClientService { tl::expected>, ErrorCode> BuildWriteOps(std::string_view key, std::vector& slices, const WriteRouteRequestConfig& config, + RdmaDirectionMode rdma_direction_mode, std::vector candidates); async_simple::coro::Lazy RunWriteWithRetry( @@ -370,10 +385,10 @@ class P2PClientService final : public ClientService { const std::vector& replicas); tl::expected BuildRouteIter( - std::string_view key, const ReadRouteConfig& config); + std::string_view key, const ReadConfigExt& config); tl::expected BuildRouteIter( - std::string_view key, const ReadRouteConfig& config, + std::string_view key, const ReadConfigExt& config, std::vector pre_fetched); private: @@ -385,30 +400,30 @@ class P2PClientService final : public ClientService { std::vector> BatchCreateGetHandles( const std::vector& keys, std::shared_ptr allocator, - const ReadRouteConfig& config); + const ReadConfigExt& config); std::vector> BatchCreateGetHandles( const std::vector& keys, std::vector>& all_slices, - const ReadRouteConfig& config); + const ReadConfigExt& config); template std::vector> BatchCreateGetHandlesImpl(const std::vector& keys, - const ReadRouteConfig& config, + const ReadConfigExt& config, LocalGetFn&& local_get, RemoteGetFn&& remote_get); std::vector, ErrorCode>> BatchFetchReadRoutes(const std::vector& keys, - const ReadRouteConfig& config); + const ReadConfigExt& config); tl::expected CreateRemoteGetHandle( std::string_view key, std::shared_ptr allocator, - const ReadRouteConfig& config, std::vector pre_fetched); + const ReadConfigExt& config, std::vector pre_fetched); tl::expected CreateRemoteGetHandle( std::string_view key, std::vector& slices, - const ReadRouteConfig& config, std::vector pre_fetched); + const ReadConfigExt& config, std::vector pre_fetched); /** * @brief Launch async reads driven by a RouteIterator. @@ -417,16 +432,44 @@ class P2PClientService final : public ClientService { * and chains subsequent candidates on failure (no stack recursion). */ tl::expected InnerGetViaRoute( - std::string_view key, std::vector& slices, RouteIterator iter); + std::string_view key, std::vector& slices, RouteIterator iter, + RdmaDirectionMode rdma_direction_mode); async_simple::coro::Lazy RunReadWithRetry( RouteIterator iter, std::shared_ptr req, std::shared_ptr>> - promise); + promise, + RdmaDirectionMode rdma_direction_mode); + + // true = promise set, caller co_return; false = try next route (incl. hard + // UnPin failure after TE success on this replica). + async_simple::coro::Lazy RunForwardReadOnRoute( + const ResolvedRoute& route, std::shared_ptr req, + std::shared_ptr>> + promise, + RouteIterator& iter, ErrorCode& final_result); + + async_simple::coro::Lazy RunForwardRemotePut( + std::shared_ptr>> + promise, + PeerClient* peer, DataManager* dm, + std::shared_ptr write_req, + std::vector* slices); + + // RemoteWriteOp forward path: promise + RunForwardRemotePut coroutine. + std::unique_ptr> StartForwardRemotePut( + PeerClient* peer, DataManager* forward_dm, + std::vector* forward_slices, + std::shared_ptr write_req); + + // RemoteWriteOp reverse path: AsyncWriteRemoteData + route cache upsert. + std::unique_ptr> RunReverseRemotePut( + PeerClient* peer, std::shared_ptr write_req, + const P2PProxyDescriptor& proxy, RouteCache* route_cache); async_simple::coro::Lazy> AsyncResolveRoutesFromMaster(std::string_view key, - const ReadRouteConfig& config); + const ReadConfigExt& config); /** * @brief Get or create a PeerClient for the given endpoint. diff --git a/mooncake-store/include/p2p_rpc_types.h b/mooncake-store/include/p2p_rpc_types.h index 43739c5fa7..a72c7108ed 100644 --- a/mooncake-store/include/p2p_rpc_types.h +++ b/mooncake-store/include/p2p_rpc_types.h @@ -44,23 +44,20 @@ inline std::ostream& operator<<(std::ostream& os, return os; } -struct P2PWriteRouteConfig { - WriteRouteRequestConfig route_config; - std::optional rdma_direction_mode; +struct WriteConfigExt { + WriteRouteRequestConfig route_config{}; + RdmaDirectionMode rdma_direction_mode = RdmaDirectionMode::REVERSE; + + WriteConfigExt() = default; + /** Promotes legacy write-route-only config for variant overload resolution. */ + WriteConfigExt(WriteRouteRequestConfig r) : route_config(std::move(r)) {} }; -YLT_REFL(P2PWriteRouteConfig, route_config, rdma_direction_mode); +YLT_REFL(WriteConfigExt, route_config, rdma_direction_mode); inline std::ostream& operator<<(std::ostream& os, - const P2PWriteRouteConfig& config) { - os << "P2PWriteRouteConfig: { route_config: [" << config.route_config - << "], rdma_direction_mode: "; - if (config.rdma_direction_mode.has_value()) { - os << *config.rdma_direction_mode; - } else { - // Unset: product rule is to resolve to REVERSE (see client default). - os << RdmaDirectionMode::REVERSE; - } - os << " }"; + const WriteConfigExt& config) { + os << "WriteConfigExt: { route_config: [" << config.route_config + << "], rdma_direction_mode: " << config.rdma_direction_mode << " }"; return os; } diff --git a/mooncake-store/include/pyclient.h b/mooncake-store/include/pyclient.h index 98de68138c..c909f74577 100644 --- a/mooncake-store/include/pyclient.h +++ b/mooncake-store/include/pyclient.h @@ -40,19 +40,19 @@ class PyClient { virtual int unregister_buffer(void* buffer) = 0; virtual int64_t get_into(const std::string& key, void* buffer, size_t size, - const ReadRouteConfig& config = {}) = 0; + const ReadConfigExt& config = {}) = 0; virtual std::vector batch_get_into( const std::vector& keys, const std::vector& buffers, const std::vector& sizes, - const ReadRouteConfig& config = {}) = 0; + const ReadConfigExt& config = {}) = 0; virtual std::vector batch_get_into_multi_buffers( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, bool aggregate_same_segment_task, - const ReadRouteConfig& config = {}) = 0; + const ReadConfigExt& config = {}) = 0; virtual int put_from(const std::string& key, void* buffer, size_t size, const WriteConfig& config) = 0; @@ -73,14 +73,14 @@ class PyClient { const WriteConfig& config) = 0; virtual std::shared_ptr get_buffer( - const std::string& key, const ReadRouteConfig& config = {}) = 0; + const std::string& key, const ReadConfigExt& config = {}) = 0; virtual std::tuple get_buffer_info( - const std::string& key, const ReadRouteConfig& config = {}) = 0; + const std::string& key, const ReadConfigExt& config = {}) = 0; virtual std::vector> batch_get_buffer( const std::vector& keys, - const ReadRouteConfig& config = {}) = 0; + const ReadConfigExt& config = {}) = 0; virtual int put_parts(const std::string& key, std::vector> values, diff --git a/mooncake-store/include/real_client.h b/mooncake-store/include/real_client.h index 58e34a0283..f49dc468d1 100644 --- a/mooncake-store/include/real_client.h +++ b/mooncake-store/include/real_client.h @@ -98,7 +98,7 @@ class RealClient : public PyClient { * register_buffer() for zero-copy operations */ int64_t get_into(const std::string& key, void* buffer, size_t size, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; /** * @brief Get object data directly into pre-allocated buffers for multiple @@ -114,7 +114,7 @@ class RealClient : public PyClient { std::vector batch_get_into( const std::vector& keys, const std::vector& buffers, const std::vector& sizes, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; /** * @brief Get object data directly into pre-allocated buffers for multiple @@ -133,7 +133,7 @@ class RealClient : public PyClient { const std::vector>& all_buffers, const std::vector>& all_sizes, bool aggregate_same_segment_task, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; /** * @brief Put object data directly from a pre-allocated buffer @@ -222,7 +222,7 @@ class RealClient : public PyClient { * nullptr if error */ std::shared_ptr get_buffer( - const std::string& key, const ReadRouteConfig& config = {}) override; + const std::string& key, const ReadConfigExt& config = {}) override; /** * @brief Get buffer information (address and size) for a key @@ -230,7 +230,7 @@ class RealClient : public PyClient { * @return Tuple containing buffer address and size, or (0, 0) if error */ std::tuple get_buffer_info( - const std::string& key, const ReadRouteConfig& config = {}) override; + const std::string& key, const ReadConfigExt& config = {}) override; /** * @brief Get buffers containing the data for multiple keys (batch version) @@ -240,7 +240,7 @@ class RealClient : public PyClient { */ std::vector> batch_get_buffer( const std::vector& keys, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; int remove(const std::string& key) override; @@ -276,7 +276,7 @@ class RealClient : public PyClient { // Dummy client helper functions that return tl::expected tl::expected, ErrorCode> get_buffer_info_dummy_helper(const std::string& key, - const ReadRouteConfig& config, + const ReadConfigExt& config, const UUID& client_id); tl::expected put_dummy_helper(const std::string& key, @@ -296,7 +296,7 @@ class RealClient : public PyClient { std::vector> batch_get_into_dummy_helper( const std::vector& keys, const std::vector& buffers, const std::vector& sizes, - const ReadRouteConfig& config, const UUID& client_id); + const ReadConfigExt& config, const UUID& client_id); std::vector> batch_put_from_dummy_helper( const std::vector& keys, @@ -342,18 +342,18 @@ class RealClient : public PyClient { tl::expected get_into_internal( const std::string& key, void* buffer, size_t size, - const ReadRouteConfig& config = {}); + const ReadConfigExt& config = {}); std::vector> batch_get_into_internal( const std::vector& keys, const std::vector& buffers, - const std::vector& sizes, const ReadRouteConfig& config = {}); + const std::vector& sizes, const ReadConfigExt& config = {}); std::vector> batch_get_into_multi_buffers_internal( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - bool aggregate_same_segment_task, const ReadRouteConfig& config = {}); + bool aggregate_same_segment_task, const ReadConfigExt& config = {}); tl::expected put_from_internal(const std::string& key, void* buffer, size_t size, @@ -403,11 +403,11 @@ class RealClient : public PyClient { const std::string& key, std::shared_ptr client_buffer_allocator = nullptr, - const ReadRouteConfig& config = {}); + const ReadConfigExt& config = {}); std::vector> batch_get_buffer_internal( const std::vector& keys, - const ReadRouteConfig& config = {}); + const ReadConfigExt& config = {}); std::map> batch_get_replica_desc(const std::vector& keys); diff --git a/mooncake-store/include/rpc_types.h b/mooncake-store/include/rpc_types.h index 4491c09754..7ad47d6aff 100644 --- a/mooncake-store/include/rpc_types.h +++ b/mooncake-store/include/rpc_types.h @@ -38,11 +38,18 @@ YLT_REFL(GetReplicaListRequestConfig, max_candidates, p2p_config); typedef GetReplicaListRequestConfig ReadRouteConfig; typedef P2PGetReplicaListConfigExtra P2PReadRouteConfigExtra; -struct P2PReadRouteConfig { - ReadRouteConfig route_config; - std::optional rdma_direction_mode; +/** + * @brief P2P read API wrapper: master routing + optional RDMA direction. + */ +struct ReadConfigExt { + ReadRouteConfig route_config{}; + RdmaDirectionMode rdma_direction_mode = RdmaDirectionMode::REVERSE; + + ReadConfigExt() = default; + /** @brief Backward-compatible promotion from read-route-only config. */ + ReadConfigExt(ReadRouteConfig r) : route_config(std::move(r)) {} }; -YLT_REFL(P2PReadRouteConfig, route_config, rdma_direction_mode); +YLT_REFL(ReadConfigExt, route_config, rdma_direction_mode); /** * @brief Extra info for centralized read route response (Internal use) diff --git a/mooncake-store/src/centralized_client_service.cpp b/mooncake-store/src/centralized_client_service.cpp index 24d7752dbc..1c550bd8c3 100644 --- a/mooncake-store/src/centralized_client_service.cpp +++ b/mooncake-store/src/centralized_client_service.cpp @@ -234,7 +234,7 @@ void CentralizedClientService::InitTransferSubmitter() { tl::expected, ErrorCode> CentralizedClientService::Query(const std::string& object_key, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { auto guard = AcquireInflightGuard(); if (!guard.is_valid()) { LOG(ERROR) << "client is shutting down"; @@ -242,7 +242,7 @@ CentralizedClientService::Query(const std::string& object_key, } std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now(); - auto result = master_client_.GetReplicaList(object_key, config); + auto result = master_client_.GetReplicaList(object_key, config.route_config); if (!result) { LOG(ERROR) << "Failed to get replica list: " << result.error(); return tl::unexpected(result.error()); @@ -263,7 +263,7 @@ CentralizedClientService::Query(const std::string& object_key, std::vector, ErrorCode>> CentralizedClientService::BatchQuery( const std::vector& object_keys, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { auto guard = AcquireInflightGuard(); if (!guard.is_valid()) { LOG(ERROR) << "client is shutting down"; @@ -279,7 +279,7 @@ CentralizedClientService::BatchQuery( std::chrono::steady_clock::now(); std::vector key_views(object_keys.begin(), object_keys.end()); - auto response = master_client_.BatchGetReplicaList(key_views, config); + auto response = master_client_.BatchGetReplicaList(key_views, config.route_config); // Check if we got the expected number of responses if (response.size() != object_keys.size()) { @@ -370,7 +370,7 @@ CentralizedClientService::BatchReplicaClear( tl::expected, ErrorCode> CentralizedClientService::Get(const std::string& key, std::shared_ptr allocator, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { if (!allocator) { LOG(ERROR) << "Client buffer allocator is not provided"; return tl::unexpected(ErrorCode::INVALID_PARAMS); @@ -425,7 +425,7 @@ std::vector, ErrorCode>> CentralizedClientService::BatchGet( const std::vector& keys, std::shared_ptr allocator, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { std::vector, ErrorCode>> results( keys.size(), tl::unexpected(ErrorCode::INTERNAL_ERROR)); @@ -527,7 +527,7 @@ CentralizedClientService::BatchGet( tl::expected CentralizedClientService::Get( const std::string& key, const std::vector& buffers, - const std::vector& sizes, const ReadRouteConfig& config) { + const std::vector& sizes, const ReadConfigExt& config) { // Step 1: Query metadata from master auto query_result = Query(key, config); if (!query_result) { @@ -572,7 +572,7 @@ CentralizedClientService::BatchGet( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - const ReadRouteConfig& config, bool aggregate_same_segment_task) { + const ReadConfigExt& config, bool aggregate_same_segment_task) { if (keys.size() != all_buffers.size() || keys.size() != all_sizes.size()) { LOG(ERROR) << "Input vector sizes mismatch"; return std::vector>( diff --git a/mooncake-store/src/data_manager.cpp b/mooncake-store/src/data_manager.cpp index a41830d21f..dc23354552 100644 --- a/mooncake-store/src/data_manager.cpp +++ b/mooncake-store/src/data_manager.cpp @@ -1259,6 +1259,30 @@ DataManager::SubmitTeTransferInternal( std::move(buffer_result.value()); } + auto batches_result = SubmitTeTransferBatches( + transfer_ptr, total_data_size, remote_buffers, opcode); + if (!batches_result) { + return tl::unexpected(batches_result.error()); + } + + TeSubmitResult result; + result.transfer_batches = std::move(batches_result.value()); + result.temp_buffer = std::move(temp_buffer_owner); + result.handle = handle; + return result; +} + +tl::expected>, + ErrorCode> +DataManager::SubmitTeTransferBatches( + void* transfer_ptr, size_t total_data_size, + const std::vector& remote_buffers, + Transport::TransferRequest::OpCode opcode) { + if (!transfer_ptr || total_data_size == 0) { + LOG(ERROR) << "SubmitTeTransferBatches: invalid local transfer pointer"; + return tl::unexpected(ErrorCode::INVALID_PARAMS); + } + std::unordered_map> segment_buffers; for (size_t i = 0; i < remote_buffers.size(); ++i) { segment_buffers[remote_buffers[i].segment_endpoint].push_back(i); @@ -1331,11 +1355,41 @@ DataManager::SubmitTeTransferInternal( return tl::unexpected(ErrorCode::TRANSFER_FAIL); } - TeSubmitResult result; - result.transfer_batches = std::move(submitted_batches); - result.temp_buffer = std::move(temp_buffer_owner); - result.handle = handle; - return result; + return submitted_batches; +} + +tl::expected DataManager::TransferWithTeNoTierStaging( + void* local_transfer_base, size_t total_size, + const std::vector& peer_buffers, + Transport::TransferRequest::OpCode opcode) { + auto validate_result = ValidateRemoteBuffers(peer_buffers); + if (!validate_result) { + return tl::make_unexpected(validate_result.error()); + } + size_t total_remote_size = 0; + for (const auto& buf : peer_buffers) total_remote_size += buf.size; + if (total_remote_size != total_size) { + LOG(ERROR) << "TransferWithTeNoTierStaging: peer buffer size mismatch (" + << total_remote_size << " vs " << total_size << ")"; + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + if (!transfer_engine_->getMetadata()) { + LOG(ERROR) << "TransferEngine not initialized"; + return tl::make_unexpected(ErrorCode::INTERNAL_ERROR); + } + auto batches = SubmitTeTransferBatches(local_transfer_base, total_size, + peer_buffers, opcode); + if (!batches) { + return tl::unexpected(batches.error()); + } + auto wait_result = WaitAllTransferBatches(batches.value()); + if (!wait_result) { + LOG(ERROR) << "TransferWithTeNoTierStaging: WaitAllTransferBatches " + "failed: " + << toString(wait_result.error()); + return wait_result; + } + return {}; } tl::expected DataManager::ValidateRemoteBuffers( diff --git a/mooncake-store/src/dummy_client.cpp b/mooncake-store/src/dummy_client.cpp index 746e4b0c6c..f560c75de6 100644 --- a/mooncake-store/src/dummy_client.cpp +++ b/mooncake-store/src/dummy_client.cpp @@ -542,13 +542,13 @@ int64_t DummyClient::getSize(const std::string& key) { } std::shared_ptr DummyClient::get_buffer( - const std::string& key, const ReadRouteConfig& config) { + const std::string& key, const ReadConfigExt& config) { // Dummy client does not use BufferHandle, so we return nullptr return nullptr; } std::tuple DummyClient::get_buffer_info( - const std::string& key, const ReadRouteConfig& config) { + const std::string& key, const ReadConfigExt& config) { auto result = invoke_rpc<&RealClient::get_buffer_info_dummy_helper, std::tuple>(key, config, client_id_); @@ -560,13 +560,13 @@ std::tuple DummyClient::get_buffer_info( } std::vector> DummyClient::batch_get_buffer( - const std::vector& keys, const ReadRouteConfig& config) { + const std::vector& keys, const ReadConfigExt& config) { // TODO: implement this function return std::vector>(); } int64_t DummyClient::get_into(const std::string& key, void* buffer, size_t size, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { // TODO: implement this function return -1; } @@ -604,7 +604,7 @@ int DummyClient::put_from(const std::string& key, void* buffer, size_t size, std::vector DummyClient::batch_get_into( const std::vector& keys, const std::vector& buffer_ptrs, - const std::vector& sizes, const ReadRouteConfig& config) { + const std::vector& sizes, const ReadConfigExt& config) { std::vector buffers; for (auto ptr : buffer_ptrs) { buffers.push_back(reinterpret_cast(ptr)); @@ -644,7 +644,7 @@ std::vector DummyClient::batch_get_into_multi_buffers( const std::vector& keys, const std::vector>& all_buffer_ptrs, const std::vector>& all_sizes, - bool aggregate_same_segment_task, const ReadRouteConfig& config) { + bool aggregate_same_segment_task, const ReadConfigExt& config) { // TODO: implement this function std::vector vec(keys.size(), -1); return vec; diff --git a/mooncake-store/src/p2p_client_service.cpp b/mooncake-store/src/p2p_client_service.cpp index 32e887c260..b1b047a183 100644 --- a/mooncake-store/src/p2p_client_service.cpp +++ b/mooncake-store/src/p2p_client_service.cpp @@ -17,6 +17,62 @@ namespace mooncake { +namespace { + +// UnPin after forward read (or cleanup after TE failure): retry only for +// "other" errors. INVALID_READ = token mismatch (treat as released for flow). +// RPC_FAIL = transport/timeout-like (no repeat; owner may TTL-clean). LEASE_EXPIRED +// = server already expired the pin record. +constexpr int kForwardReadUnpinMaxAttempts = 3; + +bool UnPinErrorTreatAsEffectiveOk(ErrorCode e) { + switch (e) { + case ErrorCode::INVALID_READ: + case ErrorCode::RPC_FAIL: + case ErrorCode::LEASE_EXPIRED: + return true; + default: + return false; + } +} + +bool SlicesAreContiguous(const std::vector& slices) { + if (slices.empty()) { + return false; + } + for (size_t i = 1; i < slices.size(); ++i) { + const char* prev_end = + static_cast(slices[i - 1].ptr) + slices[i - 1].size; + if (prev_end != static_cast(slices[i].ptr)) { + return false; + } + } + return true; +} + +size_t TotalSliceBytes(const std::vector& slices) { + size_t t = 0; + for (const auto& s : slices) { + t += s.size; + } + return t; +} + +bool RemoteDestBuffersContiguous(const std::vector& bufs) { + if (bufs.empty()) { + return false; + } + for (size_t i = 1; i < bufs.size(); ++i) { + uintptr_t prev_end = bufs[i - 1].addr + bufs[i - 1].size; + if (prev_end != bufs[i].addr) { + return false; + } + } + return true; +} + +} // namespace + // ============================================================================ // Construction / Destruction // ============================================================================ @@ -558,7 +614,16 @@ std::vector> P2PClientService::BatchPut( } auto guard = AcquireInflightGuard(); - const auto* route_config = std::get_if(&config); + const auto* plain_route = std::get_if(&config); + const auto* ext_route = std::get_if(&config); + const WriteRouteRequestConfig* route_cfg_ptr = nullptr; + RdmaDirectionMode rdma_write_mode = RdmaDirectionMode::REVERSE; + if (plain_route) { + route_cfg_ptr = plain_route; + } else if (ext_route) { + route_cfg_ptr = &ext_route->route_config; + rdma_write_mode = ext_route->rdma_direction_mode; + } if (!guard.is_valid()) { LOG(ERROR) << "client is shutting down"; std::fill(results.begin(), results.end(), @@ -567,13 +632,14 @@ std::vector> P2PClientService::BatchPut( LOG(ERROR) << "BatchPut input size mismatch"; std::fill(results.begin(), results.end(), tl::unexpected(ErrorCode::INVALID_PARAMS)); - } else if (!route_config) { - LOG(ERROR) << "P2PClientService currently only supports " - "WriteRouteRequestConfig"; + } else if (!route_cfg_ptr) { + LOG(ERROR) << "P2PClientService expects WriteRouteRequestConfig or " + "WriteConfigExt"; std::fill(results.begin(), results.end(), tl::unexpected(ErrorCode::INVALID_PARAMS)); } else { - results = InnerBatchPut(keys, batched_slices, *route_config); + results = InnerBatchPut(keys, batched_slices, *route_cfg_ptr, + rdma_write_mode); } size_t success_count = 0; @@ -612,11 +678,13 @@ std::vector> P2PClientService::BatchPut( std::vector> P2PClientService::InnerBatchPut( const std::vector& keys, std::vector>& batched_slices, - const WriteRouteRequestConfig& route_config) { + const WriteRouteRequestConfig& route_config, + RdmaDirectionMode rdma_direction_mode) { if (ha_manager_ && ha_manager_->IsDegraded()) { return InnerBatchPutDegraded(keys, batched_slices); } - return InnerBatchPutNormal(keys, batched_slices, route_config); + return InnerBatchPutNormal(keys, batched_slices, route_config, + rdma_direction_mode); } std::vector> @@ -691,7 +759,8 @@ std::vector> P2PClientService::InnerBatchPutNormal( const std::vector& keys, std::vector>& batched_slices, - const WriteRouteRequestConfig& route_config) { + const WriteRouteRequestConfig& route_config, + RdmaDirectionMode rdma_direction_mode) { // Phase 1: fetch write routes from master. auto batch_routes = BatchFetchWriteRoutes(keys, batched_slices, route_config); @@ -704,8 +773,9 @@ P2PClientService::InnerBatchPutNormal( // Phase 2: // 2.1: async dispatch first-candidate writes for each key // 2.2: wrap each key in a retry chain based on rotute - auto handles = CreatePutHandlesFromRoute(keys, batched_slices, route_config, - batch_routes.value()); + auto handles = + CreatePutHandlesFromRoute(keys, batched_slices, route_config, + rdma_direction_mode, batch_routes.value()); // Phase 3: wait every retry chain and collect results. return CollectResults(handles, keys); @@ -740,6 +810,7 @@ P2PClientService::CreatePutHandlesFromRoute( const std::vector& keys, std::vector>& batched_slices, const WriteRouteRequestConfig& route_config, + RdmaDirectionMode rdma_direction_mode, BatchGetWriteRouteResponse& batch_resp) { struct WriteTask { std::unique_ptr> first_task; @@ -757,6 +828,7 @@ P2PClientService::CreatePutHandlesFromRoute( continue; } auto ops = BuildWriteOps(keys[i], batched_slices[i], route_config, + rdma_direction_mode, std::move(batch_resp.responses[i].candidates)); if (!ops) { LOG(ERROR) << "fail to build write ops" @@ -812,6 +884,7 @@ P2PClientService::CreatePutHandlesFromRoute( auto P2PClientService::BuildWriteOps(std::string_view key, std::vector& slices, const WriteRouteRequestConfig& config, + RdmaDirectionMode rdma_direction_mode, std::vector candidates) -> tl::expected>, ErrorCode> { if (candidates.empty()) { @@ -848,9 +921,15 @@ auto P2PClientService::BuildWriteOps(std::string_view key, } else { std::string endpoint = proxy.ip_address + ":" + std::to_string(proxy.rpc_port); + DataManager* fwd_dm = + (rdma_direction_mode == RdmaDirectionMode::FORWARD && + data_manager_.has_value()) + ? &*data_manager_ + : nullptr; write_ops.push_back(std::make_unique( - &GetOrCreatePeerClient(endpoint), write_req, std::move(proxy), - route_cache_ ? &*route_cache_ : nullptr, endpoint)); + this, &GetOrCreatePeerClient(endpoint), write_req, + std::move(proxy), route_cache_ ? &*route_cache_ : nullptr, + endpoint, fwd_dm, &slices, rdma_direction_mode)); } } @@ -894,6 +973,44 @@ std::unique_ptr> P2PClientService::LocalWriteOp::Dispatch() { } std::unique_ptr> P2PClientService::RemoteWriteOp::Dispatch() { + if (!owner_service) { + LOG(ERROR) << "Remote write requires P2PClientService"; + return CallableTaskHandle::Create( + []() -> tl::expected { + return tl::unexpected(ErrorCode::INTERNAL_ERROR); + }); + } + if (rdma_direction_mode == RdmaDirectionMode::FORWARD) { + return owner_service->StartForwardRemotePut( + peer_ptr, forward_dm, forward_slices, write_req); + } + return owner_service->RunReverseRemotePut(peer_ptr, write_req, proxy, + route_cache); +} + +std::unique_ptr> P2PClientService::StartForwardRemotePut( + PeerClient* peer, DataManager* forward_dm, + std::vector* forward_slices, + std::shared_ptr write_req) { + if (!forward_dm || !forward_slices) { + LOG(ERROR) << "Forward RDMA write requires local DataManager"; + return CallableTaskHandle::Create( + []() -> tl::expected { + return tl::unexpected(ErrorCode::INTERNAL_ERROR); + }); + } + auto promise = std::make_shared< + async_simple::Promise>>(); + auto future = promise->getFuture(); + RunForwardRemotePut(std::move(promise), peer, forward_dm, write_req, + forward_slices) + .start([](auto&&) {}); + return FutureHandle::Create(write_req, std::move(future)); +} + +std::unique_ptr> P2PClientService::RunReverseRemotePut( + PeerClient* peer, std::shared_ptr write_req, + const P2PProxyDescriptor& proxy, RouteCache* route_cache) { auto promise = std::make_shared< async_simple::Promise>>(); auto future = promise->getFuture(); @@ -902,7 +1019,7 @@ std::unique_ptr> P2PClientService::RemoteWriteOp::Dispatch() { auto cached_proxy = proxy; auto* cache = route_cache; - peer_ptr->AsyncWriteRemoteData(*write_req) + peer->AsyncWriteRemoteData(*write_req) .start([promise, req, cached_proxy, cache](async_simple::Try>&& remote_res) mutable { @@ -937,6 +1054,58 @@ std::unique_ptr> P2PClientService::RemoteWriteOp::Dispatch() { return FutureHandle::Create(req, std::move(future)); } +async_simple::coro::Lazy P2PClientService::RunForwardRemotePut( + std::shared_ptr>> + promise, + PeerClient* peer, DataManager* dm, + std::shared_ptr write_req, + std::vector* slices) { + if (!peer || !dm || !write_req || !slices) { + promise->setValue(tl::unexpected(ErrorCode::INTERNAL_ERROR)); + co_return; + } + if (!SlicesAreContiguous(*slices)) { + LOG(ERROR) << "Forward RDMA write requires contiguous slice buffers, key=" + << write_req->key; + promise->setValue(tl::unexpected(ErrorCode::INVALID_PARAMS)); + co_return; + } + PreWriteRequest pre_req; + pre_req.key = write_req->key; + pre_req.size_bytes = TotalSliceBytes(*slices); + pre_req.target_tier_id = write_req->target_tier_id; + + auto pre = co_await peer->AsyncPreWrite(pre_req); + if (!pre) { + if (!IsAlreadyExistsError(pre.error())) { + LOG(ERROR) << "AsyncPreWrite failed, key=" << write_req->key + << ", error=" << pre.error(); + } + promise->setValue(tl::make_unexpected(pre.error())); + co_return; + } + + std::vector dest{pre.value().remote_buffer}; + void* base = slices->front().ptr; + auto te = dm->TransferWithTeNoTierStaging( + base, TotalSliceBytes(*slices), dest, + Transport::TransferRequest::WRITE); + if (!te) { + promise->setValue(tl::make_unexpected(te.error())); + co_return; + } + + WriteCommitRequest commit; + commit.key = write_req->key; + commit.pending_write_token = pre.value().pending_write_token; + auto cm = co_await peer->AsyncWriteCommit(commit); + if (!cm) { + promise->setValue(tl::make_unexpected(cm.error())); + co_return; + } + promise->setValue(tl::expected{}); +} + async_simple::coro::Lazy P2PClientService::RunWriteWithRetry( std::shared_ptr>> promise, @@ -996,20 +1165,20 @@ async_simple::coro::Lazy P2PClientService::RunWriteWithRetry( tl::expected, ErrorCode> P2PClientService::Get( const std::string& key, std::shared_ptr allocator, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { return std::move(BatchGet({key}, allocator, config)[0]); } tl::expected P2PClientService::Get( const std::string& key, const std::vector& buffers, - const std::vector& sizes, const ReadRouteConfig& config) { + const std::vector& sizes, const ReadConfigExt& config) { return std::move(BatchGet({key}, {buffers}, {sizes}, config)[0]); } std::vector, ErrorCode>> P2PClientService::BatchGet(const std::vector& keys, std::shared_ptr allocator, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { if (!allocator) { LOG(ERROR) << "Client buffer allocator is not provided"; return std::vector< @@ -1030,7 +1199,7 @@ std::vector> P2PClientService::BatchGet( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - const ReadRouteConfig& config, bool /*aggregate_same_segment_task*/) { + const ReadConfigExt& config, bool /*aggregate_same_segment_task*/) { if (keys.size() != all_buffers.size() || keys.size() != all_sizes.size()) { LOG(ERROR) << "Input vector sizes mismatch"; return std::vector>( @@ -1140,7 +1309,7 @@ std::vector> P2PClientService::BatchCreateGetHandles( const std::vector& keys, std::shared_ptr allocator, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { auto local_get = [&](std::string_view key, size_t) -> tl::expected { if (!data_manager_.has_value()) { @@ -1160,7 +1329,7 @@ std::vector> P2PClientService::BatchCreateGetHandles( const std::vector& keys, std::vector>& all_slices, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { auto local_get = [&](std::string_view key, size_t i) -> tl::expected { if (!data_manager_.has_value()) { @@ -1180,7 +1349,7 @@ P2PClientService::BatchCreateGetHandles( template std::vector> P2PClientService::BatchCreateGetHandlesImpl( - const std::vector& keys, const ReadRouteConfig& config, + const std::vector& keys, const ReadConfigExt& config, LocalGetFn&& local_get, RemoteGetFn&& remote_get) { std::vector> handles; handles.reserve(keys.size()); @@ -1240,7 +1409,7 @@ P2PClientService::BatchCreateGetHandlesImpl( std::vector< tl::expected, ErrorCode>> P2PClientService::BatchFetchReadRoutes( - const std::vector& keys, const ReadRouteConfig& config) { + const std::vector& keys, const ReadConfigExt& config) { std::vector, ErrorCode>> result( keys.size(), std::vector{}); @@ -1263,7 +1432,7 @@ P2PClientService::BatchFetchReadRoutes( // Single batch RPC to master std::vector> responses; - responses = master_client_.BatchGetReplicaList(miss_keys, config); + responses = master_client_.BatchGetReplicaList(miss_keys, config.route_config); for (size_t k = 0; k < responses.size(); ++k) { if (!responses[k]) { if (responses[k].error() != ErrorCode::OBJECT_NOT_FOUND) { @@ -1345,7 +1514,7 @@ std::vector P2PClientService::ReplicasToRoutes( tl::expected P2PClientService::CreateRemoteGetHandle( std::string_view key, std::shared_ptr allocator, - const ReadRouteConfig& config, std::vector pre_fetched) { + const ReadConfigExt& config, std::vector pre_fetched) { auto iter = BuildRouteIter(key, config, std::move(pre_fetched)); if (!iter) { LOG(ERROR) << "Failed to build route iterator, key=" << key @@ -1362,7 +1531,8 @@ tl::expected P2PClientService::CreateRemoteGetHandle( auto read_buf = std::make_shared(std::move(*alloc_result)); std::vector slices = {{read_buf->ptr(), object_size}}; - auto result = InnerGetViaRoute(key, slices, std::move(*iter)); + auto result = + InnerGetViaRoute(key, slices, std::move(*iter), config.rdma_direction_mode); if (!result) { LOG(ERROR) << "Failed to get via route, key=" << key << ", error=" << result.error(); @@ -1377,7 +1547,7 @@ tl::expected P2PClientService::CreateRemoteGetHandle( tl::expected P2PClientService::CreateRemoteGetHandle( std::string_view key, std::vector& slices, - const ReadRouteConfig& config, std::vector pre_fetched) { + const ReadConfigExt& config, std::vector pre_fetched) { auto iter = BuildRouteIter(key, config, std::move(pre_fetched)); if (!iter) { if (iter.error() != ErrorCode::OBJECT_NOT_FOUND) { @@ -1386,7 +1556,8 @@ tl::expected P2PClientService::CreateRemoteGetHandle( } return tl::unexpected(iter.error()); } - auto result = InnerGetViaRoute(key, slices, std::move(*iter)); + auto result = + InnerGetViaRoute(key, slices, std::move(*iter), config.rdma_direction_mode); if (!result) { LOG(ERROR) << "Failed to get via route, key=" << key << ", error=" << result.error(); @@ -1396,7 +1567,8 @@ tl::expected P2PClientService::CreateRemoteGetHandle( } tl::expected P2PClientService::InnerGetViaRoute( - std::string_view key, std::vector& slices, RouteIterator iter) { + std::string_view key, std::vector& slices, RouteIterator iter, + RdmaDirectionMode rdma_direction_mode) { auto req = std::make_shared(); req->key = key; for (const auto& s : slices) { @@ -1412,7 +1584,8 @@ tl::expected P2PClientService::InnerGetViaRoute( auto future = promise->getFuture(); const uint64_t object_size = iter.object_size(); - RunReadWithRetry(std::move(iter), req, promise).start([](auto&&) {}); + RunReadWithRetry(std::move(iter), req, promise, rdma_direction_mode) + .start([](auto&&) {}); ReadTaskHandle res; res.data_size = object_size; @@ -1421,15 +1594,122 @@ tl::expected P2PClientService::InnerGetViaRoute( return res; } +async_simple::coro::Lazy P2PClientService::RunForwardReadOnRoute( + const ResolvedRoute& route, std::shared_ptr req, + std::shared_ptr>> + promise, + RouteIterator& iter, ErrorCode& final_result) { + if (!data_manager_.has_value()) { + LOG(ERROR) << "Forward RDMA read requires DataManager"; + promise->setValue(tl::unexpected(ErrorCode::INTERNAL_ERROR)); + co_return true; + } + if (!RemoteDestBuffersContiguous(req->dest_buffers)) { + LOG(ERROR) << "Forward RDMA read requires contiguous dest buffers, key=" + << req->key; + promise->setValue(tl::unexpected(ErrorCode::INVALID_PARAMS)); + co_return true; + } + PinKeyRequest pin_req; + pin_req.key = req->key; + auto pin = co_await route.peer->AsyncPinKey(pin_req); + if (!pin) { + if (pin.error() != ErrorCode::OBJECT_NOT_FOUND) { + LOG(ERROR) << "AsyncPinKey failed, key=" << req->key + << ", error=" << pin.error(); + } else { + final_result = pin.error(); + } + iter.Evict(route); + co_return false; + } + void* base = reinterpret_cast(req->dest_buffers[0].addr); + size_t total = 0; + for (const auto& d : req->dest_buffers) { + total += d.size; + } + auto tr = data_manager_->TransferWithTeNoTierStaging( + base, total, {pin.value().remote_buffer}, + Transport::TransferRequest::READ); + if (!tr) { + LOG(ERROR) << "Forward TE read failed, key=" << req->key + << ", error=" << tr.error(); + UnPinKeyRequest cleanup; + cleanup.key = req->key; + cleanup.pin_token = pin.value().pin_token; + tl::expected cleanup_unpin; + for (int attempt = 0; attempt < kForwardReadUnpinMaxAttempts; + ++attempt) { + cleanup_unpin = co_await route.peer->AsyncUnPinKey(cleanup); + if (cleanup_unpin) { + break; + } + if (UnPinErrorTreatAsEffectiveOk(cleanup_unpin.error())) { + cleanup_unpin = tl::expected{}; + break; + } + if (attempt + 1 < kForwardReadUnpinMaxAttempts) { + LOG(WARNING) << "AsyncUnPinKey retry after TE failure, key=" + << req->key << ", attempt=" << (attempt + 1) + << ", error=" << cleanup_unpin.error(); + } + } + if (!cleanup_unpin) { + LOG(ERROR) << "AsyncUnPinKey failed after TE read failure, key=" + << req->key << ", error=" << cleanup_unpin.error(); + } + iter.Evict(route); + co_return false; + } + UnPinKeyRequest unpin_req; + unpin_req.key = req->key; + unpin_req.pin_token = pin.value().pin_token; + tl::expected unpin_res; + for (int attempt = 0; attempt < kForwardReadUnpinMaxAttempts; ++attempt) { + unpin_res = co_await route.peer->AsyncUnPinKey(unpin_req); + if (unpin_res) { + break; + } + if (UnPinErrorTreatAsEffectiveOk(unpin_res.error())) { + unpin_res = tl::expected{}; + break; + } + if (attempt + 1 < kForwardReadUnpinMaxAttempts) { + LOG(WARNING) << "AsyncUnPinKey retry after forward read, key=" + << req->key << ", attempt=" << (attempt + 1) + << ", error=" << unpin_res.error(); + } + } + if (!unpin_res) { + LOG(ERROR) << "AsyncUnPinKey failed after forward read, key=" + << req->key << ", error=" << unpin_res.error(); + final_result = unpin_res.error(); + iter.Evict(route); + co_return false; + } + tl::expected ok; + promise->setValue(std::move(ok)); + co_return true; +} + // Coroutine iterates route candidates and retries on failure. async_simple::coro::Lazy P2PClientService::RunReadWithRetry( RouteIterator iter, std::shared_ptr req, std::shared_ptr>> - promise) { + promise, + RdmaDirectionMode rdma_direction_mode) { ErrorCode final_result = ErrorCode::OBJECT_NOT_FOUND; try { while (auto route = co_await iter.AsyncNext()) { try { + if (rdma_direction_mode == RdmaDirectionMode::FORWARD) { + if (co_await RunForwardReadOnRoute(*route, req, promise, iter, + final_result)) { + co_return; + } + continue; + } + auto result = co_await route->peer->AsyncReadRemoteData(*req); if (result.has_value()) { tl::expected ok; @@ -1539,13 +1819,13 @@ void P2PClientService::RouteIterator::Evict(const ResolvedRoute& route) { tl::expected P2PClientService::BuildRouteIter(std::string_view key, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { return BuildRouteIter(key, config, LoadCachedRoutes(key)); } tl::expected P2PClientService::BuildRouteIter(std::string_view key, - const ReadRouteConfig& config, + const ReadConfigExt& config, std::vector pre_fetched) { auto routes = std::move(pre_fetched); uint64_t object_size = routes.empty() ? 0 : routes.front().object_size; @@ -1565,9 +1845,9 @@ P2PClientService::BuildRouteIter(std::string_view key, async_simple::coro::Lazy> P2PClientService::AsyncResolveRoutesFromMaster(std::string_view key, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { auto replica_result = - co_await master_client_.AsyncGetReplicaList(key, config); + co_await master_client_.AsyncGetReplicaList(key, config.route_config); if (!replica_result) { if (replica_result.error() != ErrorCode::OBJECT_NOT_FOUND) { LOG(ERROR) << "Failed to query replica list, key=" << key @@ -1649,7 +1929,7 @@ std::vector> P2PClientService::BatchIsExist( // ============================================================================ tl::expected, ErrorCode> P2PClientService::Query( - const std::string& object_key, const ReadRouteConfig& config) { + const std::string& object_key, const ReadConfigExt& config) { auto guard = AcquireInflightGuard(); if (!guard.is_valid()) { LOG(ERROR) << "client is shutting down"; @@ -1664,7 +1944,7 @@ tl::expected, ErrorCode> P2PClientService::Query( } // Query master for replica list - auto result = master_client_.GetReplicaList(object_key, config); + auto result = master_client_.GetReplicaList(object_key, config.route_config); if (!result) { LOG(WARNING) << "fail to get replica list" << ", key=" << object_key << ", error=" << result.error(); @@ -1676,7 +1956,7 @@ tl::expected, ErrorCode> P2PClientService::Query( std::vector, ErrorCode>> P2PClientService::BatchQuery(const std::vector& object_keys, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { auto guard = AcquireInflightGuard(); if (!guard.is_valid()) { LOG(ERROR) << "client is shutting down"; @@ -1690,7 +1970,8 @@ P2PClientService::BatchQuery(const std::vector& object_keys, } std::vector key_views(object_keys.begin(), object_keys.end()); - auto responses = master_client_.BatchGetReplicaList(key_views, config); + auto responses = + master_client_.BatchGetReplicaList(key_views, config.route_config); std::vector, ErrorCode>> results; results.reserve(responses.size()); for (size_t i = 0; i < responses.size(); ++i) { diff --git a/mooncake-store/src/real_client.cpp b/mooncake-store/src/real_client.cpp index 0fe9a34ab0..42f6881441 100644 --- a/mooncake-store/src/real_client.cpp +++ b/mooncake-store/src/real_client.cpp @@ -824,7 +824,7 @@ tl::expected RealClient::unregister_shm_buffer_internal( std::shared_ptr RealClient::get_buffer_internal( const std::string& key, std::shared_ptr client_buffer_allocator, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { if (!client_service_) { LOG(ERROR) << "Client is not initialized"; return nullptr; @@ -848,12 +848,12 @@ std::shared_ptr RealClient::get_buffer_internal( // Implementation of get_buffer method std::shared_ptr RealClient::get_buffer( - const std::string& key, const ReadRouteConfig& config) { + const std::string& key, const ReadConfigExt& config) { return get_buffer_internal(key, client_buffer_allocator_, config); } std::tuple RealClient::get_buffer_info( - const std::string& key, const ReadRouteConfig& config) { + const std::string& key, const ReadConfigExt& config) { auto buffer_handle = get_buffer_internal(key, client_buffer_allocator_, config); if (!buffer_handle) { @@ -867,7 +867,7 @@ std::tuple RealClient::get_buffer_info( tl::expected, ErrorCode> RealClient::get_buffer_info_dummy_helper(const std::string& key, - const ReadRouteConfig& config, + const ReadConfigExt& config, const UUID& client_id) { std::shared_lock lock(dummy_client_mutex_); auto it = shm_contexts_.find(client_id); @@ -905,7 +905,7 @@ RealClient::get_buffer_info_dummy_helper(const std::string& key, // Implementation of batch_get_buffer_internal method std::vector> RealClient::batch_get_buffer_internal(const std::vector& keys, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { std::vector> final_results(keys.size(), nullptr); @@ -938,7 +938,7 @@ RealClient::batch_get_buffer_internal(const std::vector& keys, // Implementation of batch_get_buffer method std::vector> RealClient::batch_get_buffer( - const std::vector& keys, const ReadRouteConfig& config) { + const std::vector& keys, const ReadConfigExt& config) { return batch_get_buffer_internal(keys, config); } @@ -978,7 +978,7 @@ int RealClient::unregister_buffer(void* buffer) { tl::expected RealClient::get_into_internal( const std::string& key, void* buffer, size_t size, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { // NOTE: The buffer address must be previously registered with // register_buffer() for zero-copy RDMA operations to work correctly if (!client_service_) { @@ -990,7 +990,7 @@ tl::expected RealClient::get_into_internal( } int64_t RealClient::get_into(const std::string& key, void* buffer, size_t size, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { return to_py_ret(get_into_internal(key, buffer, size, config)); } @@ -1173,7 +1173,7 @@ int RealClient::put_from(const std::string& key, void* buffer, size_t size, std::vector RealClient::batch_get_into( const std::vector& keys, const std::vector& buffers, - const std::vector& sizes, const ReadRouteConfig& config) { + const std::vector& sizes, const ReadConfigExt& config) { auto internal_results = batch_get_into_internal(keys, buffers, sizes, config); std::vector results; @@ -1190,7 +1190,7 @@ std::vector> RealClient::batch_get_into_dummy_helper( const std::vector& keys, const std::vector& dummy_buffers, - const std::vector& sizes, const ReadRouteConfig& config, + const std::vector& sizes, const ReadConfigExt& config, const UUID& client_id) { std::shared_lock lock(dummy_client_mutex_); auto it = shm_contexts_.find(client_id); @@ -1245,7 +1245,7 @@ std::vector> RealClient::batch_get_into_internal(const std::vector& keys, const std::vector& buffers, const std::vector& sizes, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { // Validate preconditions if (!client_service_) { LOG(ERROR) << "Client is not initialized"; @@ -1405,7 +1405,7 @@ std::vector RealClient::batch_get_into_multi_buffers( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - bool prefer_alloc_in_same_node, const ReadRouteConfig& config) { + bool prefer_alloc_in_same_node, const ReadConfigExt& config) { auto start = std::chrono::steady_clock::now(); auto internal_results = batch_get_into_multi_buffers_internal( keys, all_buffers, all_sizes, prefer_alloc_in_same_node, config); @@ -1427,7 +1427,7 @@ RealClient::batch_get_into_multi_buffers_internal( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - bool prefer_alloc_in_same_node, const ReadRouteConfig& config) { + bool prefer_alloc_in_same_node, const ReadConfigExt& config) { // Validate preconditions if (!client_service_) { LOG(ERROR) << "Client is not initialized"; From a92f86e48cf10f541908788934ede1975ed773cf Mon Sep 17 00:00:00 2001 From: shichangzhang064 Date: Tue, 12 May 2026 11:32:54 +0800 Subject: [PATCH 10/14] feat: add put revoke for rollback --- mooncake-store/include/client_rpc_service.h | 3 ++ mooncake-store/include/client_rpc_types.h | 8 +++ mooncake-store/include/data_manager.h | 13 +++-- mooncake-store/include/peer_client.h | 5 ++ mooncake-store/src/client_rpc_service.cpp | 32 ++++++++++++ mooncake-store/src/data_manager.cpp | 54 +++++++++++++-------- mooncake-store/src/p2p_client_service.cpp | 36 ++++++++++++++ mooncake-store/src/peer_client.cpp | 30 ++++++++++++ mooncake-store/tests/data_manager_test.cpp | 42 ++++++++++++++++ 9 files changed, 200 insertions(+), 23 deletions(-) diff --git a/mooncake-store/include/client_rpc_service.h b/mooncake-store/include/client_rpc_service.h index ea60fba240..9b0bf4e476 100644 --- a/mooncake-store/include/client_rpc_service.h +++ b/mooncake-store/include/client_rpc_service.h @@ -61,6 +61,9 @@ class ClientRpcService { tl::expected WriteCommit( const WriteCommitRequest& request); + tl::expected WriteRevoke( + const WriteRevokeRequest& request); + tl::expected PinKey( const PinKeyRequest& request); diff --git a/mooncake-store/include/client_rpc_types.h b/mooncake-store/include/client_rpc_types.h index 1940167544..fde4587cda 100644 --- a/mooncake-store/include/client_rpc_types.h +++ b/mooncake-store/include/client_rpc_types.h @@ -107,6 +107,14 @@ struct WriteCommitRequest { YLT_REFL(WriteCommitRequest, key, pending_write_token); +/** Drops a pending PreWrite allocation without committing (e.g. after TE failure). */ +struct WriteRevokeRequest { + std::string_view key; + UUID pending_write_token; +}; + +YLT_REFL(WriteRevokeRequest, key, pending_write_token); + struct PinKeyRequest { std::string_view key; std::optional target_tier_id; diff --git a/mooncake-store/include/data_manager.h b/mooncake-store/include/data_manager.h index b37c992a91..452ea6a3ff 100644 --- a/mooncake-store/include/data_manager.h +++ b/mooncake-store/include/data_manager.h @@ -197,6 +197,13 @@ class DataManager { tl::expected WriteCommit(std::string_view key, const UUID& pending_write_token); + /** + * @brief Remove pending write record for key + token (no tier Commit). + * Used when forward TE fails after PreWrite on the peer. + */ + tl::expected WriteRevoke(std::string_view key, + const UUID& pending_write_token); + tl::expected PinKey( std::string_view key, std::optional tier_id = std::nullopt); @@ -266,6 +273,8 @@ class DataManager { bool enforce_dram_allocation); tl::expected WriteCommitInternal( const KeyCtx& ctx, const UUID& pending_write_token); + tl::expected WriteRevokeInternal( + const KeyCtx& ctx, const UUID& pending_write_token); tl::expected PinKeyInternal( const KeyCtx& ctx, std::optional tier_id); tl::expected UnPinKeyInternal(const KeyCtx& ctx, @@ -275,8 +284,6 @@ class DataManager { const KeyCtx& ctx, const UUID& pending_write_token); tl::expected LookupPinnedKeyHandleInternal( const KeyCtx& ctx, const UUID& pin_token); - void AbortPendingWriteInternal(const KeyCtx& ctx, - const UUID& pending_write_token); std::shared_mutex& GetKeyLock(std::string_view key) { size_t hash = std::hash{}(key); @@ -490,8 +497,6 @@ class DataManager { std::string_view key, const UUID& pending_write_token); tl::expected LookupPinnedKeyHandle( std::string_view key, const UUID& pin_token); - void AbortPendingWrite(std::string_view key, - const UUID& pending_write_token); private: std::unique_ptr tiered_backend_; // Owned by DataManager diff --git a/mooncake-store/include/peer_client.h b/mooncake-store/include/peer_client.h index f70c1c6c97..58be645af1 100644 --- a/mooncake-store/include/peer_client.h +++ b/mooncake-store/include/peer_client.h @@ -33,6 +33,9 @@ class PeerClient { async_simple::coro::Lazy> AsyncWriteCommit( const WriteCommitRequest& request); + async_simple::coro::Lazy> AsyncWriteRevoke( + const WriteRevokeRequest& request); + async_simple::coro::Lazy> AsyncPinKey(const PinKeyRequest& request); @@ -48,6 +51,8 @@ class PeerClient { const PreWriteRequest& request); tl::expected WriteCommit( const WriteCommitRequest& request); + tl::expected WriteRevoke( + const WriteRevokeRequest& request); tl::expected PinKey( const PinKeyRequest& request); tl::expected UnPinKey(const UnPinKeyRequest& request); diff --git a/mooncake-store/src/client_rpc_service.cpp b/mooncake-store/src/client_rpc_service.cpp index d186ac39f6..a736500745 100644 --- a/mooncake-store/src/client_rpc_service.cpp +++ b/mooncake-store/src/client_rpc_service.cpp @@ -70,6 +70,14 @@ bool IsValidRequest(const WriteCommitRequest& request) { return true; } +bool IsValidRequest(const WriteRevokeRequest& request) { + if (request.key.empty() || IsZeroUuid(request.pending_write_token)) { + LOG(ERROR) << "WriteRevokeRequest: invalid key or token"; + return false; + } + return true; +} + bool IsValidRequest(const PinKeyRequest& request) { if (request.key.empty()) { LOG(ERROR) << "PinKeyRequest: empty key"; @@ -238,6 +246,29 @@ tl::expected ClientRpcService::WriteCommit( return {}; } +tl::expected ClientRpcService::WriteRevoke( + const WriteRevokeRequest& request) { + ScopedVLogTimer timer(1, "ClientRpcService::WriteRevoke"); + timer.LogRequest("key=", request.key); + + if (!IsValidRequest(request)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + auto result = + data_manager_.WriteRevoke(request.key, request.pending_write_token); + if (!result) { + LOG(ERROR) << "WriteRevoke failed for key: " << request.key + << ", error: " << toString(result.error()); + timer.LogResponse("error_code=", result.error()); + return result; + } + + timer.LogResponse("error_code=", ErrorCode::OK); + return {}; +} + tl::expected ClientRpcService::PinKey( const PinKeyRequest& request) { ScopedVLogTimer timer(1, "ClientRpcService::PinKey"); @@ -288,6 +319,7 @@ void RegisterClientRpcService(coro_rpc::coro_rpc_server& server, server.register_handler<&ClientRpcService::WriteRemoteData>(&service); server.register_handler<&ClientRpcService::PreWrite>(&service); server.register_handler<&ClientRpcService::WriteCommit>(&service); + server.register_handler<&ClientRpcService::WriteRevoke>(&service); server.register_handler<&ClientRpcService::PinKey>(&service); server.register_handler<&ClientRpcService::UnPinKey>(&service); } diff --git a/mooncake-store/src/data_manager.cpp b/mooncake-store/src/data_manager.cpp index dc23354552..613c158168 100644 --- a/mooncake-store/src/data_manager.cpp +++ b/mooncake-store/src/data_manager.cpp @@ -396,22 +396,38 @@ DataManager::LookupPinnedKeyHandleInternal(const KeyCtx& ctx, return it->second.handle; } -void DataManager::AbortPendingWrite(std::string_view key, - const UUID& pending_write_token) { - AbortPendingWriteInternal(BuildKeyCtx(key), pending_write_token); +tl::expected DataManager::WriteRevoke( + std::string_view key, const UUID& pending_write_token) { + return WriteRevokeInternal(BuildKeyCtx(key), pending_write_token); } -void DataManager::AbortPendingWriteInternal(const KeyCtx& ctx, - const UUID& pending_write_token) { - if (ctx.key.empty() || IsZeroUuid(pending_write_token)) return; +tl::expected DataManager::WriteRevokeInternal( + const KeyCtx& ctx, const UUID& pending_write_token) { + ScopedVLogTimer timer(1, "DataManager::WriteRevoke"); + timer.LogRequest("key=", ctx.key); + + if (ctx.key.empty() || IsZeroUuid(pending_write_token)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + std::unique_lock key_lock(GetKeyLock(ctx.key)); auto& shard = GetPendingWriteShard(ctx); std::unique_lock shard_lock(shard.mutex); - auto it = shard.by_key.find(ctx.key_string); - if (it == shard.by_key.end()) return; - if (it->second.pending_write_token != pending_write_token) return; - shard.ordered_list.erase(it->second.list_it); - shard.by_key.erase(it); + + auto record_it = shard.by_key.find(ctx.key_string); + if (record_it == shard.by_key.end()) { + timer.LogResponse("error_code=", ErrorCode::OK, "idempotent=", true); + return {}; + } + if (record_it->second.pending_write_token != pending_write_token) { + timer.LogResponse("error_code=", ErrorCode::INVALID_WRITE); + return tl::make_unexpected(ErrorCode::INVALID_WRITE); + } + shard.ordered_list.erase(record_it->second.list_it); + shard.by_key.erase(record_it); + timer.LogResponse("error_code=", ErrorCode::OK, "record_erased=", true); + return {}; } void DataManager::ShutdownLeaseScanner() { @@ -504,7 +520,7 @@ DataManager::PutViaTe(std::string_view key, std::vector& slices) { auto handle_result = LookupPendingWriteHandleInternal(kctx, pending_write_token); if (!handle_result) { - AbortPendingWriteInternal(kctx, pending_write_token); + (void)WriteRevokeInternal(kctx, pending_write_token); return tl::unexpected(handle_result.error()); } AllocationHandle alloc_handle = handle_result.value(); @@ -515,7 +531,7 @@ DataManager::PutViaTe(std::string_view key, std::vector& slices) { LOG(ERROR) << "SubmitTeTransferInternal failed" << ", key=" << key << ", error_code=" << toString(submit_result.error()); - AbortPendingWriteInternal(kctx, pending_write_token); + (void)WriteRevokeInternal(kctx, pending_write_token); return tl::unexpected(submit_result.error()); } @@ -530,7 +546,7 @@ DataManager::PutViaTe(std::string_view key, std::vector& slices) { LOG(ERROR) << "WaitAllTransferBatches failed" << ", key=" << kctx.key << ", error_code=" << toString(wait_result.error()); - AbortPendingWriteInternal(kctx, pending_write_token); + (void)WriteRevokeInternal(kctx, pending_write_token); return tl::unexpected(wait_result.error()); } @@ -547,7 +563,7 @@ DataManager::PutViaTe(std::string_view key, std::vector& slices) { << "CopyFromDRAMBuffer failed" << ", key=" << kctx.key << ", error_code=" << toString(copy_result.error()); - AbortPendingWriteInternal(kctx, pending_write_token); + (void)WriteRevokeInternal(kctx, pending_write_token); return tl::unexpected(copy_result.error()); } } @@ -581,7 +597,7 @@ DataManager::PutViaMemcpy(std::string_view key, std::vector& slices) { auto handle_result = LookupPendingWriteHandleInternal(kctx, pending_write_token); if (!handle_result) { - AbortPendingWriteInternal(kctx, pending_write_token); + (void)WriteRevokeInternal(kctx, pending_write_token); return tl::unexpected(handle_result.error()); } AllocationHandle alloc_handle = handle_result.value(); @@ -596,7 +612,7 @@ DataManager::PutViaMemcpy(std::string_view key, std::vector& slices) { if (!write_result.has_value()) { LOG(ERROR) << "Failed to write data for key: " << kctx.key << ", error: " << write_result.error(); - AbortPendingWriteInternal(kctx, pending_write_token); + (void)WriteRevokeInternal(kctx, pending_write_token); return tl::make_unexpected(write_result.error()); } return {}; @@ -877,7 +893,7 @@ tl::expected DataManager::WriteRemoteData( auto handle_result = LookupPendingWriteHandleInternal(kctx, pending_write_token); if (!handle_result) { - AbortPendingWriteInternal(kctx, pending_write_token); + (void)WriteRevokeInternal(kctx, pending_write_token); timer.LogResponse("error_code=", handle_result.error()); return tl::make_unexpected(handle_result.error()); } @@ -887,7 +903,7 @@ tl::expected DataManager::WriteRemoteData( // Transfer phase — no long key lock held. auto transfer_result = TransferDataFromRemote(handle, src_buffers); if (!transfer_result) { - AbortPendingWriteInternal(kctx, pending_write_token); + (void)WriteRevokeInternal(kctx, pending_write_token); timer.LogResponse("error_code=", transfer_result.error()); return tl::make_unexpected(transfer_result.error()); } diff --git a/mooncake-store/src/p2p_client_service.cpp b/mooncake-store/src/p2p_client_service.cpp index b1b047a183..a963b2298b 100644 --- a/mooncake-store/src/p2p_client_service.cpp +++ b/mooncake-store/src/p2p_client_service.cpp @@ -23,6 +23,7 @@ namespace { // "other" errors. INVALID_READ = token mismatch (treat as released for flow). // RPC_FAIL = transport/timeout-like (no repeat; owner may TTL-clean). LEASE_EXPIRED // = server already expired the pin record. +// Same max-attempt count is used for WriteRevoke after forward write TE failure. constexpr int kForwardReadUnpinMaxAttempts = 3; bool UnPinErrorTreatAsEffectiveOk(ErrorCode e) { @@ -36,6 +37,16 @@ bool UnPinErrorTreatAsEffectiveOk(ErrorCode e) { } } +bool WriteRevokeErrorTreatAsEffectiveOk(ErrorCode e) { + switch (e) { + case ErrorCode::INVALID_WRITE: + case ErrorCode::RPC_FAIL: + return true; + default: + return false; + } +} + bool SlicesAreContiguous(const std::vector& slices) { if (slices.empty()) { return false; @@ -1091,6 +1102,31 @@ async_simple::coro::Lazy P2PClientService::RunForwardRemotePut( base, TotalSliceBytes(*slices), dest, Transport::TransferRequest::WRITE); if (!te) { + LOG(ERROR) << "Forward TE write failed, key=" << write_req->key + << ", error=" << te.error(); + WriteRevokeRequest revoke_req; + revoke_req.key = write_req->key; + revoke_req.pending_write_token = pre.value().pending_write_token; + tl::expected revoke_res; + for (int attempt = 0; attempt < kForwardReadUnpinMaxAttempts; ++attempt) { + revoke_res = co_await peer->AsyncWriteRevoke(revoke_req); + if (revoke_res) { + break; + } + if (WriteRevokeErrorTreatAsEffectiveOk(revoke_res.error())) { + revoke_res = tl::expected{}; + break; + } + if (attempt + 1 < kForwardReadUnpinMaxAttempts) { + LOG(WARNING) << "AsyncWriteRevoke retry after TE failure, key=" + << write_req->key << ", attempt=" << (attempt + 1) + << ", error=" << revoke_res.error(); + } + } + if (!revoke_res) { + LOG(ERROR) << "AsyncWriteRevoke failed after TE failure, key=" + << write_req->key << ", error=" << revoke_res.error(); + } promise->setValue(tl::make_unexpected(te.error())); co_return; } diff --git a/mooncake-store/src/peer_client.cpp b/mooncake-store/src/peer_client.cpp index feb37bbe77..516a8d6b3c 100644 --- a/mooncake-store/src/peer_client.cpp +++ b/mooncake-store/src/peer_client.cpp @@ -127,6 +127,31 @@ PeerClient::AsyncWriteCommit(const WriteCommitRequest& request) { co_return result->result(); } +async_simple::coro::Lazy> +PeerClient::AsyncWriteRevoke(const WriteRevokeRequest& request) { + if (!client_pool_) { + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto ret = co_await client_pool_->send_request( + [&](coro_io::client_reuse_hint, coro_rpc::coro_rpc_client& client) { + return client.send_request<&ClientRpcService::WriteRevoke>(request); + }); + if (!ret.has_value()) { + LOG(ERROR) << "AsyncWriteRevoke: client not available"; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto result = co_await std::move(ret.value()); + if (!result) { + LOG(ERROR) << "AsyncWriteRevoke: RPC call failed: " + << result.error().msg; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + co_return result->result(); +} + async_simple::coro::Lazy> PeerClient::AsyncPinKey(const PinKeyRequest& request) { if (!client_pool_) { @@ -196,6 +221,11 @@ tl::expected PeerClient::WriteCommit( return async_simple::coro::syncAwait(AsyncWriteCommit(request)); } +tl::expected PeerClient::WriteRevoke( + const WriteRevokeRequest& request) { + return async_simple::coro::syncAwait(AsyncWriteRevoke(request)); +} + tl::expected PeerClient::PinKey( const PinKeyRequest& request) { return async_simple::coro::syncAwait(AsyncPinKey(request)); diff --git a/mooncake-store/tests/data_manager_test.cpp b/mooncake-store/tests/data_manager_test.cpp index 441dfc9868..1849cdf7fd 100644 --- a/mooncake-store/tests/data_manager_test.cpp +++ b/mooncake-store/tests/data_manager_test.cpp @@ -348,6 +348,48 @@ TEST_F(DataManagerTest, WriteCommitTokenMismatchKeepsPendingWriteRecord) { } } +// WriteRevoke removes pending allocation without committing the object. +TEST_F(DataManagerTest, WriteRevokeErasesPendingWriteRecord) { + const std::string key = "write_revoke_erases_record_key"; + auto prewrite_result = data_manager_->PreWrite(key, 256, GetTierId()); + ASSERT_TRUE(prewrite_result.has_value()) + << "PreWrite failed: " << toString(prewrite_result.error()); + + auto revoke_result = + data_manager_->WriteRevoke(key, prewrite_result->pending_write_token); + ASSERT_TRUE(revoke_result.has_value()) + << "WriteRevoke failed: " << toString(revoke_result.error()); + + auto& shard = data_manager_->GetPendingWriteShard(key); + { + std::shared_lock shard_lock(shard.mutex); + EXPECT_EQ(shard.by_key.count(key), 0U); + } +} + +// WriteRevoke: wrong token should not erase the pending record. +TEST_F(DataManagerTest, WriteRevokeTokenMismatchKeepsPendingWriteRecord) { + const std::string key = "write_revoke_token_mismatch_key"; + auto prewrite_result = data_manager_->PreWrite(key, 256, GetTierId()); + ASSERT_TRUE(prewrite_result.has_value()) + << "PreWrite failed: " << toString(prewrite_result.error()); + + UUID wrong_token = prewrite_result->pending_write_token; + wrong_token.first += 1; + auto wrong_revoke = data_manager_->WriteRevoke(key, wrong_token); + ASSERT_FALSE(wrong_revoke.has_value()); + EXPECT_EQ(wrong_revoke.error(), ErrorCode::INVALID_WRITE); + + auto& shard = data_manager_->GetPendingWriteShard(key); + { + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + EXPECT_EQ(it->second.pending_write_token, + prewrite_result->pending_write_token); + } +} + // Test Pin/Unpin: ref_count increments on PinKey and reaches zero on final // UnPinKey. TEST_F(DataManagerTest, PinKeyTracksRefCountUntilFinalUnpin) { From 5de2a610a944f9cdb86742f8dcfff6bb58acae54 Mon Sep 17 00:00:00 2001 From: shichangzhang064 Date: Tue, 12 May 2026 17:34:50 +0800 Subject: [PATCH 11/14] feat: add peer client and client rpc services tests --- .../tests/client_rpc_service_test.cpp | 275 +++++++ .../tests/p2p_client_integration_test.cpp | 44 ++ mooncake-store/tests/peer_client_test.cpp | 693 ++++++++++++++++++ 3 files changed, 1012 insertions(+) diff --git a/mooncake-store/tests/client_rpc_service_test.cpp b/mooncake-store/tests/client_rpc_service_test.cpp index 63130adb42..4aae44c05f 100644 --- a/mooncake-store/tests/client_rpc_service_test.cpp +++ b/mooncake-store/tests/client_rpc_service_test.cpp @@ -203,6 +203,189 @@ TEST_F(ClientRpcServiceTest, ReadRemoteDataKeyNotFound) { auto result = rpc_service_->ReadRemoteData(request); ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::OBJECT_NOT_FOUND); +} + +// ============================================================================ +// PinKey / UnPinKey (forward read control plane) +// ============================================================================ + +TEST_F(ClientRpcServiceTest, PinKeyEmptyKey) { + PinKeyRequest req; + req.key = ""; + req.target_tier_id = std::nullopt; + + auto result = rpc_service_->PinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(ClientRpcServiceTest, PinKeyObjectNotFound) { + PinKeyRequest req; + req.key = "rpc_svc_pin_missing_key"; + req.target_tier_id = std::nullopt; + + auto result = rpc_service_->PinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::OBJECT_NOT_FOUND); +} + +TEST_F(ClientRpcServiceTest, PinKeyAfterPutThenUnPin) { + const std::string key = "rpc_svc_pin_after_put"; + const std::string blob = "payload"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()) << "Put failed"; + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + + auto pin_res = rpc_service_->PinKey(pin_req); + ASSERT_TRUE(pin_res.has_value()) + << "PinKey failed: " << static_cast(pin_res.error()); + EXPECT_GT(pin_res->remote_buffer.size, 0u); + EXPECT_NE(pin_res->pin_token.first, 0u); + EXPECT_NE(pin_res->pin_token.second, 0u); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount. + UnPinKeyRequest unpin; + unpin.key = key; + unpin.pin_token = pin_res->pin_token; + auto unpin_res = rpc_service_->UnPinKey(unpin); + ASSERT_TRUE(unpin_res.has_value()) + << "UnPinKey failed: " << static_cast(unpin_res.error()); +} + +TEST_F(ClientRpcServiceTest, PinKeyTwiceSameTokenThenUnpinTwice) { + const std::string key = "rpc_svc_pin_twice_ref"; + const std::string blob = "ref"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + + auto first = rpc_service_->PinKey(pin_req); + ASSERT_TRUE(first.has_value()) + << "first PinKey failed: " << static_cast(first.error()); + auto second = rpc_service_->PinKey(pin_req); + ASSERT_TRUE(second.has_value()) + << "second PinKey failed: " << static_cast(second.error()); + + EXPECT_EQ(first->pin_token, second->pin_token); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount. + UnPinKeyRequest unpin; + unpin.key = key; + unpin.pin_token = first->pin_token; + auto u1 = rpc_service_->UnPinKey(unpin); + ASSERT_TRUE(u1.has_value()) + << "first UnPinKey failed: " << static_cast(u1.error()); + + auto u2 = rpc_service_->UnPinKey(unpin); + ASSERT_TRUE(u2.has_value()) + << "second UnPinKey failed: " << static_cast(u2.error()); +} + +TEST_F(ClientRpcServiceTest, PinKeyAfterUnpinNewToken) { + const std::string key = "rpc_svc_pin_new_token_after_unpin"; + const std::string blob = "tok"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + + auto pin1 = rpc_service_->PinKey(pin_req); + ASSERT_TRUE(pin1.has_value()) + << "first PinKey failed: " << static_cast(pin1.error()); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount. + UnPinKeyRequest unpin1; + unpin1.key = key; + unpin1.pin_token = pin1->pin_token; + auto un1 = rpc_service_->UnPinKey(unpin1); + ASSERT_TRUE(un1.has_value()) + << "first UnPinKey failed: " << static_cast(un1.error()); + + auto pin2 = rpc_service_->PinKey(pin_req); + ASSERT_TRUE(pin2.has_value()) + << "second PinKey after unpin failed: " + << static_cast(pin2.error()); + EXPECT_NE(pin1->pin_token, pin2->pin_token); + + UnPinKeyRequest unpin2; + unpin2.key = key; + unpin2.pin_token = pin2->pin_token; + auto un2 = rpc_service_->UnPinKey(unpin2); + ASSERT_TRUE(un2.has_value()) + << "second UnPinKey failed: " << static_cast(un2.error()); +} + +TEST_F(ClientRpcServiceTest, UnPinKeyEmptyKey) { + UnPinKeyRequest req; + req.key = ""; + req.pin_token = {1, 2}; + + auto result = rpc_service_->UnPinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(ClientRpcServiceTest, UnPinKeyZeroToken) { + UnPinKeyRequest req; + req.key = "rpc_svc_unpin_zero"; + req.pin_token = {0, 0}; + + auto result = rpc_service_->UnPinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(ClientRpcServiceTest, UnPinKeyWrongTokenAfterPin) { + const std::string key = "rpc_svc_unpin_wrong_token"; + const std::string blob = "x"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + auto pin_res = rpc_service_->PinKey(pin_req); + ASSERT_TRUE(pin_res.has_value()); + + UnPinKeyRequest bad; + bad.key = key; + bad.pin_token = pin_res->pin_token; + bad.pin_token.first += 1; + auto bad_res = rpc_service_->UnPinKey(bad); + ASSERT_FALSE(bad_res.has_value()); + EXPECT_EQ(bad_res.error(), ErrorCode::INVALID_READ); + + UnPinKeyRequest ok; + ok.key = key; + ok.pin_token = pin_res->pin_token; + auto ok_res = rpc_service_->UnPinKey(ok); + ASSERT_TRUE(ok_res.has_value()) + << "UnPinKey with correct token failed: " + << static_cast(ok_res.error()); } // ============================================================================ @@ -289,4 +472,96 @@ TEST_F(ClientRpcServiceTest, WriteRemoteDataWithTierId) { EXPECT_EQ(result.error(), ErrorCode::INTERNAL_ERROR); } +// ============================================================================ +// WriteRevoke +// ============================================================================ + +TEST_F(ClientRpcServiceTest, WriteRevokeInvalidKey) { + WriteRevokeRequest request; + request.key = ""; + request.pending_write_token = {1, 0}; + + auto result = rpc_service_->WriteRevoke(request); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(ClientRpcServiceTest, WriteRevokeInvalidZeroToken) { + const std::string key = "rpc_svc_revoke_zero_token"; + WriteRevokeRequest request; + request.key = key; + request.pending_write_token = {0, 0}; + + auto result = rpc_service_->WriteRevoke(request); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(ClientRpcServiceTest, WriteRevokeIdempotentNoPendingRecord) { + const std::string key = "rpc_svc_revoke_no_pending"; + WriteRevokeRequest request; + request.key = key; + request.pending_write_token = {100, 200}; + + auto result = rpc_service_->WriteRevoke(request); + ASSERT_TRUE(result.has_value()) + << "WriteRevoke on missing pending should be OK (idempotent)"; +} + +TEST_F(ClientRpcServiceTest, WriteRevokeAfterPreWrite) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "rpc_svc_revoke_after_prewrite"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 256; + pre.target_tier_id = tier_id; + + auto pre_res = rpc_service_->PreWrite(pre); + ASSERT_TRUE(pre_res.has_value()) + << "PreWrite failed: " << static_cast(pre_res.error()); + + WriteRevokeRequest revoke; + revoke.key = key; + revoke.pending_write_token = pre_res->pending_write_token; + auto rev_res = rpc_service_->WriteRevoke(revoke); + ASSERT_TRUE(rev_res.has_value()) + << "WriteRevoke failed: " << static_cast(rev_res.error()); + + auto again = rpc_service_->WriteRevoke(revoke); + ASSERT_TRUE(again.has_value()) + << "Second WriteRevoke on same key/token should be idempotent OK"; +} + +TEST_F(ClientRpcServiceTest, WriteRevokeTokenMismatch) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "rpc_svc_revoke_token_mismatch"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 64; + pre.target_tier_id = tier_id; + + auto pre_res = rpc_service_->PreWrite(pre); + ASSERT_TRUE(pre_res.has_value()); + + UUID wrong_token = pre_res->pending_write_token; + wrong_token.first += 1; + + WriteRevokeRequest bad; + bad.key = key; + bad.pending_write_token = wrong_token; + auto bad_res = rpc_service_->WriteRevoke(bad); + ASSERT_FALSE(bad_res.has_value()); + EXPECT_EQ(bad_res.error(), ErrorCode::INVALID_WRITE); + + WriteRevokeRequest good; + good.key = key; + good.pending_write_token = pre_res->pending_write_token; + auto good_res = rpc_service_->WriteRevoke(good); + ASSERT_TRUE(good_res.has_value()); +} + } // namespace mooncake diff --git a/mooncake-store/tests/p2p_client_integration_test.cpp b/mooncake-store/tests/p2p_client_integration_test.cpp index 32d711331f..0d04b4b295 100644 --- a/mooncake-store/tests/p2p_client_integration_test.cpp +++ b/mooncake-store/tests/p2p_client_integration_test.cpp @@ -622,5 +622,49 @@ TEST_F(P2PClientIntegrationTest, LocalGetBufferHandleWithTeTransferMode) { EXPECT_TRUE(unreg_src.has_value()); } +TEST_F(P2PClientIntegrationTest, ForwardRemotePutAndGet) { + const std::vector transfer_modes = {"te", "memcpy"}; + for (const auto& mode : transfer_modes) { + SCOPED_TRACE("local_transfer_mode=" + mode); + + std::string host = "localhost:" + std::to_string(getFreeTcpPort()); + auto remote_writer = CreateP2PClient(host, /*rpc_port=*/0, mode); + ASSERT_NE(remote_writer, nullptr); + + const std::string key = "p2p_fwd_put_get_" + mode + "_" + host; + const std::string payload = "forward_payload_" + mode + "_data"; + + WriteRouteRequestConfig route; + route.allow_local = false; + route.prefer_local = false; + route.max_candidates = WriteRouteRequestConfig::RETURN_ALL_CANDIDATES; + + WriteConfigExt wcfg(route); + wcfg.rdma_direction_mode = RdmaDirectionMode::FORWARD; + + std::vector slices; + slices.emplace_back( + Slice{const_cast(payload.data()), payload.size()}); + auto put_res = remote_writer->Put(key, slices, wcfg); + ASSERT_TRUE(put_res.has_value()) + << "Forward Put failed mode=" << mode + << " err=" << static_cast(put_res.error()); + + ReadConfigExt rcfg; + rcfg.route_config.max_candidates = + GetReplicaListRequestConfig::RETURN_ALL_CANDIDATES; + rcfg.rdma_direction_mode = RdmaDirectionMode::FORWARD; + + std::vector buf(payload.size(), 0); + auto get_res = remote_writer->Get(key, {(void*)buf.data()}, {buf.size()}, + rcfg); + ASSERT_TRUE(get_res.has_value()) + << "Forward Get failed mode=" << mode + << " err=" << static_cast(get_res.error()); + EXPECT_EQ(static_cast(get_res.value()), payload.size()); + EXPECT_EQ(std::string(buf.data(), buf.size()), payload); + } +} + } // namespace testing } // namespace mooncake diff --git a/mooncake-store/tests/peer_client_test.cpp b/mooncake-store/tests/peer_client_test.cpp index e931b5f829..767b64ef3e 100644 --- a/mooncake-store/tests/peer_client_test.cpp +++ b/mooncake-store/tests/peer_client_test.cpp @@ -272,6 +272,206 @@ TEST_F(PeerClientTest, AsyncReadRemoteDataWithExistingKey) { EXPECT_EQ(result.error(), ErrorCode::INTERNAL_ERROR); } +// ============================================================================ +// PinKey / UnPinKey (async) — forward read control plane over PeerClient +// ============================================================================ + +TEST_F(PeerClientTest, AsyncPinKeyEmptyKey) { + PinKeyRequest req; + req.key = ""; + req.target_tier_id = std::nullopt; + + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncPinKey(req)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncPinKeyKeyNotFound) { + PinKeyRequest req; + req.key = "peer_async_pin_missing_key"; + req.target_tier_id = std::nullopt; + + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncPinKey(req)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::OBJECT_NOT_FOUND); +} + +TEST_F(PeerClientTest, AsyncPinKeyAfterPut) { + const std::string key = "peer_async_pin_after_put"; + const std::string blob = "payload"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()) << "Put failed"; + put.value()->Wait(); + + PinKeyRequest req; + req.key = key; + req.target_tier_id = std::nullopt; + + auto pin_res = + async_simple::coro::syncAwait(peer_client_->AsyncPinKey(req)); + ASSERT_TRUE(pin_res.has_value()) + << "AsyncPinKey failed: " << static_cast(pin_res.error()); + EXPECT_GT(pin_res->remote_buffer.size, 0u); + EXPECT_NE(pin_res->pin_token.first, 0u); + EXPECT_NE(pin_res->pin_token.second, 0u); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount via RPC. + UnPinKeyRequest unpin; + unpin.key = key; + unpin.pin_token = pin_res->pin_token; + auto unpin_res = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(unpin)); + ASSERT_TRUE(unpin_res.has_value()) + << "AsyncUnPinKey failed: " << static_cast(unpin_res.error()); +} + +TEST_F(PeerClientTest, AsyncPinKeyTwiceSameTokenThenUnpinTwice) { + const std::string key = "peer_async_pin_twice_ref"; + const std::string blob = "ref"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + + auto first = + async_simple::coro::syncAwait(peer_client_->AsyncPinKey(pin_req)); + ASSERT_TRUE(first.has_value()) + << "first AsyncPinKey failed: " << static_cast(first.error()); + auto second = + async_simple::coro::syncAwait(peer_client_->AsyncPinKey(pin_req)); + ASSERT_TRUE(second.has_value()) + << "second AsyncPinKey failed: " << static_cast(second.error()); + + EXPECT_EQ(first->pin_token, second->pin_token); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount via RPC. + UnPinKeyRequest unpin; + unpin.key = key; + unpin.pin_token = first->pin_token; + auto u1 = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(unpin)); + ASSERT_TRUE(u1.has_value()) + << "first AsyncUnPinKey failed: " << static_cast(u1.error()); + + auto u2 = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(unpin)); + ASSERT_TRUE(u2.has_value()) + << "second AsyncUnPinKey failed: " << static_cast(u2.error()); +} + +TEST_F(PeerClientTest, AsyncPinKeyAfterUnpinNewToken) { + const std::string key = "peer_async_pin_new_token_after_unpin"; + const std::string blob = "tok"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + + auto pin1 = + async_simple::coro::syncAwait(peer_client_->AsyncPinKey(pin_req)); + ASSERT_TRUE(pin1.has_value()) + << "first AsyncPinKey failed: " << static_cast(pin1.error()); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount via RPC. + UnPinKeyRequest unpin1; + unpin1.key = key; + unpin1.pin_token = pin1->pin_token; + auto un1 = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(unpin1)); + ASSERT_TRUE(un1.has_value()) + << "first AsyncUnPinKey failed: " << static_cast(un1.error()); + + auto pin2 = + async_simple::coro::syncAwait(peer_client_->AsyncPinKey(pin_req)); + ASSERT_TRUE(pin2.has_value()) + << "second AsyncPinKey after unpin failed: " + << static_cast(pin2.error()); + EXPECT_NE(pin1->pin_token, pin2->pin_token); + + UnPinKeyRequest unpin2; + unpin2.key = key; + unpin2.pin_token = pin2->pin_token; + auto un2 = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(unpin2)); + ASSERT_TRUE(un2.has_value()) + << "second AsyncUnPinKey failed: " << static_cast(un2.error()); +} + +TEST_F(PeerClientTest, AsyncUnPinKeyEmptyKey) { + UnPinKeyRequest req; + req.key = ""; + req.pin_token = {1, 2}; + + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(req)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncUnPinKeyZeroToken) { + const std::string key = "peer_async_unpin_zero_token"; + UnPinKeyRequest req; + req.key = key; + req.pin_token = {0, 0}; + + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(req)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncUnPinKeyWrongTokenAfterPin) { + const std::string key = "peer_async_unpin_wrong_token"; + const std::string blob = "x"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + auto pin_res = + async_simple::coro::syncAwait(peer_client_->AsyncPinKey(pin_req)); + ASSERT_TRUE(pin_res.has_value()); + + UnPinKeyRequest bad; + bad.key = key; + bad.pin_token = pin_res->pin_token; + bad.pin_token.first += 1; + auto bad_res = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(bad)); + ASSERT_FALSE(bad_res.has_value()); + EXPECT_EQ(bad_res.error(), ErrorCode::INVALID_READ); + + UnPinKeyRequest ok; + ok.key = key; + ok.pin_token = pin_res->pin_token; + auto ok_res = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(ok)); + ASSERT_TRUE(ok_res.has_value()) + << "AsyncUnPinKey with correct token failed: " + << static_cast(ok_res.error()); +} + // ============================================================================ // AsyncWriteRemoteData Tests // ============================================================================ @@ -354,6 +554,215 @@ TEST_F(PeerClientTest, AsyncWriteRemoteDataWithTierId) { EXPECT_EQ(result.error(), ErrorCode::INTERNAL_ERROR); } +// ============================================================================ +// PreWrite / WriteCommit / WriteRevoke (async) +// ============================================================================ + +TEST_F(PeerClientTest, AsyncPreWriteEmptyKey) { + PreWriteRequest pre; + pre.key = ""; + pre.size_bytes = 64; + // target_tier_id optional; invalid key fails before tier selection. + + auto result = async_simple::coro::syncAwait( + peer_client_->AsyncPreWrite(pre)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncPreWriteZeroSize) { + const std::string key = "peer_async_pre_zero_size"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 0; + + auto result = async_simple::coro::syncAwait( + peer_client_->AsyncPreWrite(pre)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncPreWriteValidRequest) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "peer_async_pre_valid"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 256; + pre.target_tier_id = tier_id; + + auto result = async_simple::coro::syncAwait( + peer_client_->AsyncPreWrite(pre)); + ASSERT_TRUE(result.has_value()) + << "AsyncPreWrite failed: " << static_cast(result.error()); + EXPECT_GT(result->remote_buffer.size, 0u); + EXPECT_NE(result->pending_write_token.first, 0u); + EXPECT_NE(result->pending_write_token.second, 0u); + + WriteRevokeRequest revoke; + revoke.key = key; + revoke.pending_write_token = result->pending_write_token; + auto rev = async_simple::coro::syncAwait( + peer_client_->AsyncWriteRevoke(revoke)); + ASSERT_TRUE(rev.has_value()) + << "Cleanup AsyncWriteRevoke failed: " << static_cast(rev.error()); +} + +TEST_F(PeerClientTest, AsyncPreWriteWhenObjectAlreadyExists) { + const std::string key = "peer_async_pre_key_exists"; + const std::string blob = "existing"; + auto buffer = StringToBuffer(blob); + std::vector put_slices{{buffer.get(), blob.size()}}; + auto put_result = data_manager_->Put(key, put_slices); + ASSERT_TRUE(put_result.has_value()) << "Put failed"; + put_result.value()->Wait(); + + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 128; + + auto result = async_simple::coro::syncAwait( + peer_client_->AsyncPreWrite(pre)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::OBJECT_ALREADY_EXISTS); +} + +TEST_F(PeerClientTest, AsyncWriteCommitAfterPreWrite) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "peer_async_commit_after_pre"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 256; + pre.target_tier_id = tier_id; + + auto pre_res = + async_simple::coro::syncAwait(peer_client_->AsyncPreWrite(pre)); + ASSERT_TRUE(pre_res.has_value()) + << "AsyncPreWrite failed: " << static_cast(pre_res.error()); + + // TransferEngine is not initialized in this unit test, so no real data-plane + // write occurs. Assume the access side has already filled the buffer via TE; + // this case only checks WriteCommit RPC / metadata outcome. + WriteCommitRequest commit; + commit.key = key; + commit.pending_write_token = pre_res->pending_write_token; + auto commit_res = async_simple::coro::syncAwait( + peer_client_->AsyncWriteCommit(commit)); + ASSERT_TRUE(commit_res.has_value()) + << "AsyncWriteCommit failed: " << static_cast(commit_res.error()); +} + +TEST_F(PeerClientTest, AsyncWriteCommitEmptyKey) { + WriteCommitRequest commit; + commit.key = ""; + commit.pending_write_token = {1, 2}; + + auto result = async_simple::coro::syncAwait( + peer_client_->AsyncWriteCommit(commit)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncWriteCommitZeroToken) { + const std::string key = "peer_async_commit_zero_token"; + WriteCommitRequest commit; + commit.key = key; + commit.pending_write_token = {0, 0}; + + auto result = async_simple::coro::syncAwait( + peer_client_->AsyncWriteCommit(commit)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncWriteCommitTokenMismatchAfterPreWrite) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "peer_async_commit_token_bad"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 128; + pre.target_tier_id = tier_id; + + auto pre_res = + async_simple::coro::syncAwait(peer_client_->AsyncPreWrite(pre)); + ASSERT_TRUE(pre_res.has_value()); + + WriteCommitRequest commit; + commit.key = key; + UUID wrong_token = pre_res->pending_write_token; + wrong_token.first += 1; + commit.pending_write_token = wrong_token; + + auto bad = async_simple::coro::syncAwait( + peer_client_->AsyncWriteCommit(commit)); + ASSERT_FALSE(bad.has_value()); + EXPECT_EQ(bad.error(), ErrorCode::INVALID_WRITE); + + WriteRevokeRequest revoke; + revoke.key = key; + revoke.pending_write_token = pre_res->pending_write_token; + auto rev = async_simple::coro::syncAwait( + peer_client_->AsyncWriteRevoke(revoke)); + ASSERT_TRUE(rev.has_value()); +} + +TEST_F(PeerClientTest, AsyncWriteRevokeEmptyKey) { + WriteRevokeRequest request; + request.key = ""; + request.pending_write_token = {1, 2}; + + auto result = async_simple::coro::syncAwait( + peer_client_->AsyncWriteRevoke(request)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncWriteRevokeZeroToken) { + const std::string key = "peer_async_revoke_zero_token"; + WriteRevokeRequest request; + request.key = key; + request.pending_write_token = {0, 0}; + + auto result = async_simple::coro::syncAwait( + peer_client_->AsyncWriteRevoke(request)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncPreWriteThenWriteRevokeIdempotent) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "peer_async_revoke_after_prewrite"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 256; + pre.target_tier_id = tier_id; + + auto pre_res = + async_simple::coro::syncAwait(peer_client_->AsyncPreWrite(pre)); + ASSERT_TRUE(pre_res.has_value()) + << "AsyncPreWrite failed: " << static_cast(pre_res.error()); + + WriteRevokeRequest revoke; + revoke.key = key; + revoke.pending_write_token = pre_res->pending_write_token; + auto rev_res = async_simple::coro::syncAwait( + peer_client_->AsyncWriteRevoke(revoke)); + ASSERT_TRUE(rev_res.has_value()) + << "AsyncWriteRevoke failed: " << static_cast(rev_res.error()); + + auto idem = + async_simple::coro::syncAwait(peer_client_->AsyncWriteRevoke(revoke)); + ASSERT_TRUE(idem.has_value()) + << "Second AsyncWriteRevoke should be idempotent OK"; +} + // ============================================================================ // Sync ReadRemoteData Tests (wrappers around async) // ============================================================================ @@ -412,6 +821,134 @@ TEST_F(PeerClientTest, SyncReadRemoteDataWithExistingKey) { EXPECT_EQ(result.error(), ErrorCode::INTERNAL_ERROR); } +// ============================================================================ +// Sync PinKey / UnPinKey (wrappers around async) +// ============================================================================ + +TEST_F(PeerClientTest, SyncPinKeyEmptyKey) { + PinKeyRequest req; + req.key = ""; + req.target_tier_id = std::nullopt; + + auto result = peer_client_->PinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, SyncPinKeyAfterPut) { + const std::string key = "peer_sync_pin_after_put"; + const std::string blob = "sync-payload"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest req; + req.key = key; + req.target_tier_id = std::nullopt; + + auto pin_res = peer_client_->PinKey(req); + ASSERT_TRUE(pin_res.has_value()) + << "PinKey failed: " << static_cast(pin_res.error()); + EXPECT_GT(pin_res->remote_buffer.size, 0u); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount via RPC. + UnPinKeyRequest unpin; + unpin.key = key; + unpin.pin_token = pin_res->pin_token; + auto unpin_res = peer_client_->UnPinKey(unpin); + ASSERT_TRUE(unpin_res.has_value()) + << "UnPinKey failed: " << static_cast(unpin_res.error()); +} + +TEST_F(PeerClientTest, SyncPinKeyTwiceSameTokenThenUnpinTwice) { + const std::string key = "peer_sync_pin_twice_ref"; + const std::string blob = "ref"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + + auto first = peer_client_->PinKey(pin_req); + ASSERT_TRUE(first.has_value()) + << "first PinKey failed: " << static_cast(first.error()); + auto second = peer_client_->PinKey(pin_req); + ASSERT_TRUE(second.has_value()) + << "second PinKey failed: " << static_cast(second.error()); + + EXPECT_EQ(first->pin_token, second->pin_token); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount via RPC. + UnPinKeyRequest unpin; + unpin.key = key; + unpin.pin_token = first->pin_token; + auto u1 = peer_client_->UnPinKey(unpin); + ASSERT_TRUE(u1.has_value()) + << "first UnPinKey failed: " << static_cast(u1.error()); + + auto u2 = peer_client_->UnPinKey(unpin); + ASSERT_TRUE(u2.has_value()) + << "second UnPinKey failed: " << static_cast(u2.error()); +} + +TEST_F(PeerClientTest, SyncPinKeyAfterUnpinNewToken) { + const std::string key = "peer_sync_pin_new_token_after_unpin"; + const std::string blob = "tok"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + + auto pin1 = peer_client_->PinKey(pin_req); + ASSERT_TRUE(pin1.has_value()) + << "first PinKey failed: " << static_cast(pin1.error()); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount via RPC. + UnPinKeyRequest unpin1; + unpin1.key = key; + unpin1.pin_token = pin1->pin_token; + auto un1 = peer_client_->UnPinKey(unpin1); + ASSERT_TRUE(un1.has_value()) + << "first UnPinKey failed: " << static_cast(un1.error()); + + auto pin2 = peer_client_->PinKey(pin_req); + ASSERT_TRUE(pin2.has_value()) + << "second PinKey after unpin failed: " + << static_cast(pin2.error()); + EXPECT_NE(pin1->pin_token, pin2->pin_token); + + UnPinKeyRequest unpin2; + unpin2.key = key; + unpin2.pin_token = pin2->pin_token; + auto un2 = peer_client_->UnPinKey(unpin2); + ASSERT_TRUE(un2.has_value()) + << "second UnPinKey failed: " << static_cast(un2.error()); +} + +TEST_F(PeerClientTest, SyncUnPinKeyZeroToken) { + UnPinKeyRequest req; + req.key = "peer_sync_unpin_zero"; + req.pin_token = {0, 0}; + + auto result = peer_client_->UnPinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + // ============================================================================ // Sync WriteRemoteData Tests (wrappers around async) // ============================================================================ @@ -449,6 +986,87 @@ TEST_F(PeerClientTest, SyncWriteRemoteDataValidRequest) { EXPECT_EQ(result.error(), ErrorCode::INTERNAL_ERROR); } +// ============================================================================ +// Sync PreWrite / WriteCommit / WriteRevoke +// ============================================================================ + +TEST_F(PeerClientTest, SyncPreWriteEmptyKey) { + PreWriteRequest pre; + pre.key = ""; + pre.size_bytes = 32; + + auto result = peer_client_->PreWrite(pre); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, SyncWriteCommitEmptyKey) { + WriteCommitRequest commit; + commit.key = ""; + commit.pending_write_token = {5, 6}; + + auto result = peer_client_->WriteCommit(commit); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, SyncWriteCommitAfterPreWrite) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "peer_sync_commit_after_pre"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 128; + pre.target_tier_id = tier_id; + + auto pre_res = peer_client_->PreWrite(pre); + ASSERT_TRUE(pre_res.has_value()) + << "PreWrite failed: " << static_cast(pre_res.error()); + + // TransferEngine is not initialized in this unit test, so no real data-plane + // write occurs. Assume the access side has already filled the buffer via TE; + // this case only checks WriteCommit RPC / metadata outcome. + WriteCommitRequest commit; + commit.key = key; + commit.pending_write_token = pre_res->pending_write_token; + auto commit_res = peer_client_->WriteCommit(commit); + ASSERT_TRUE(commit_res.has_value()) + << "WriteCommit failed: " << static_cast(commit_res.error()); +} + +TEST_F(PeerClientTest, SyncWriteRevokeEmptyKey) { + WriteRevokeRequest request; + request.key = ""; + request.pending_write_token = {3, 4}; + + auto result = peer_client_->WriteRevoke(request); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, SyncWriteRevokeAfterPreWrite) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "peer_sync_revoke_after_prewrite"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 128; + pre.target_tier_id = tier_id; + + auto pre_res = peer_client_->PreWrite(pre); + ASSERT_TRUE(pre_res.has_value()) + << "PreWrite failed: " << static_cast(pre_res.error()); + + WriteRevokeRequest revoke; + revoke.key = key; + revoke.pending_write_token = pre_res->pending_write_token; + auto rev_res = peer_client_->WriteRevoke(revoke); + ASSERT_TRUE(rev_res.has_value()) + << "WriteRevoke failed: " << static_cast(rev_res.error()); +} + // ============================================================================ // Not Connected Error Tests // ============================================================================ @@ -509,4 +1127,79 @@ TEST_F(PeerClientTest, SyncWriteWithoutConnect) { EXPECT_EQ(result.error(), ErrorCode::RPC_FAIL); } +TEST_F(PeerClientTest, AsyncPinKeyWithoutConnect) { + PeerClient unconnected_client; + + PinKeyRequest req; + req.key = "k"; + req.target_tier_id = std::nullopt; + + auto result = + async_simple::coro::syncAwait(unconnected_client.AsyncPinKey(req)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::RPC_FAIL); +} + +TEST_F(PeerClientTest, AsyncUnPinKeyWithoutConnect) { + PeerClient unconnected_client; + + UnPinKeyRequest req; + req.key = "k"; + req.pin_token = {1, 1}; + + auto result = + async_simple::coro::syncAwait(unconnected_client.AsyncUnPinKey(req)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::RPC_FAIL); +} + +TEST_F(PeerClientTest, SyncPinKeyWithoutConnect) { + PeerClient unconnected_client; + + PinKeyRequest req; + req.key = "k"; + req.target_tier_id = std::nullopt; + + auto result = unconnected_client.PinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::RPC_FAIL); +} + +TEST_F(PeerClientTest, SyncUnPinKeyWithoutConnect) { + PeerClient unconnected_client; + + UnPinKeyRequest req; + req.key = "k"; + req.pin_token = {1, 1}; + + auto result = unconnected_client.UnPinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::RPC_FAIL); +} + +TEST_F(PeerClientTest, AsyncWriteRevokeWithoutConnect) { + PeerClient unconnected_client; + + WriteRevokeRequest request; + request.key = "test_key"; + request.pending_write_token = {1, 1}; + + auto result = async_simple::coro::syncAwait( + unconnected_client.AsyncWriteRevoke(request)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::RPC_FAIL); +} + +TEST_F(PeerClientTest, SyncWriteRevokeWithoutConnect) { + PeerClient unconnected_client; + + WriteRevokeRequest request; + request.key = "test_key"; + request.pending_write_token = {2, 2}; + + auto result = unconnected_client.WriteRevoke(request); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::RPC_FAIL); +} + } // namespace mooncake From 8317337afb5466cb099c53e01c85b4ad8a76e70d Mon Sep 17 00:00:00 2001 From: Shichang-Zhang Date: Tue, 12 May 2026 19:51:03 +0800 Subject: [PATCH 12/14] fix: client rpc service compile error --- mooncake-store/src/client_rpc_service.cpp | 4 ++++ mooncake-store/src/p2p_client_service.cpp | 14 +++++++------- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/mooncake-store/src/client_rpc_service.cpp b/mooncake-store/src/client_rpc_service.cpp index a736500745..1de25501a6 100644 --- a/mooncake-store/src/client_rpc_service.cpp +++ b/mooncake-store/src/client_rpc_service.cpp @@ -10,6 +10,10 @@ namespace mooncake { namespace { +bool IsZeroUuid(const UUID& uuid) { + return uuid.first == 0 && uuid.second == 0; +} + size_t CalculateBufferSize(const std::vector& buffers) { size_t total = 0; for (const auto& buf : buffers) total += buf.size; diff --git a/mooncake-store/src/p2p_client_service.cpp b/mooncake-store/src/p2p_client_service.cpp index a963b2298b..0a7db34cfb 100644 --- a/mooncake-store/src/p2p_client_service.cpp +++ b/mooncake-store/src/p2p_client_service.cpp @@ -1072,13 +1072,13 @@ async_simple::coro::Lazy P2PClientService::RunForwardRemotePut( std::shared_ptr write_req, std::vector* slices) { if (!peer || !dm || !write_req || !slices) { - promise->setValue(tl::unexpected(ErrorCode::INTERNAL_ERROR)); + promise->setValue(tl::expected(tl::unexpect, ErrorCode::INTERNAL_ERROR)); co_return; } if (!SlicesAreContiguous(*slices)) { LOG(ERROR) << "Forward RDMA write requires contiguous slice buffers, key=" << write_req->key; - promise->setValue(tl::unexpected(ErrorCode::INVALID_PARAMS)); + promise->setValue(tl::expected(tl::unexpect, ErrorCode::INVALID_PARAMS)); co_return; } PreWriteRequest pre_req; @@ -1092,7 +1092,7 @@ async_simple::coro::Lazy P2PClientService::RunForwardRemotePut( LOG(ERROR) << "AsyncPreWrite failed, key=" << write_req->key << ", error=" << pre.error(); } - promise->setValue(tl::make_unexpected(pre.error())); + promise->setValue(tl::expected(tl::unexpect, pre.error())); co_return; } @@ -1127,7 +1127,7 @@ async_simple::coro::Lazy P2PClientService::RunForwardRemotePut( LOG(ERROR) << "AsyncWriteRevoke failed after TE failure, key=" << write_req->key << ", error=" << revoke_res.error(); } - promise->setValue(tl::make_unexpected(te.error())); + promise->setValue(tl::expected(tl::unexpect, te.error())); co_return; } @@ -1136,7 +1136,7 @@ async_simple::coro::Lazy P2PClientService::RunForwardRemotePut( commit.pending_write_token = pre.value().pending_write_token; auto cm = co_await peer->AsyncWriteCommit(commit); if (!cm) { - promise->setValue(tl::make_unexpected(cm.error())); + promise->setValue(tl::expected(tl::unexpect, cm.error())); co_return; } promise->setValue(tl::expected{}); @@ -1637,13 +1637,13 @@ async_simple::coro::Lazy P2PClientService::RunForwardReadOnRoute( RouteIterator& iter, ErrorCode& final_result) { if (!data_manager_.has_value()) { LOG(ERROR) << "Forward RDMA read requires DataManager"; - promise->setValue(tl::unexpected(ErrorCode::INTERNAL_ERROR)); + promise->setValue(tl::expected(tl::unexpect, ErrorCode::INTERNAL_ERROR)); co_return true; } if (!RemoteDestBuffersContiguous(req->dest_buffers)) { LOG(ERROR) << "Forward RDMA read requires contiguous dest buffers, key=" << req->key; - promise->setValue(tl::unexpected(ErrorCode::INVALID_PARAMS)); + promise->setValue(tl::expected(tl::unexpect, ErrorCode::INVALID_PARAMS)); co_return true; } PinKeyRequest pin_req; From 0871e70afc07b2c77253e1024c162c72c55c27d8 Mon Sep 17 00:00:00 2001 From: Shichang-Zhang Date: Wed, 13 May 2026 19:44:54 +0800 Subject: [PATCH 13/14] fix: code formate fix --- mooncake-integration/store/store_py.cpp | 3 +- mooncake-store/include/client_rpc_types.h | 3 +- mooncake-store/include/client_service.h | 3 +- mooncake-store/include/data_manager.h | 7 +- mooncake-store/include/p2p_rpc_types.h | 3 +- mooncake-store/include/pyclient.h | 6 +- mooncake-store/include/real_client.h | 3 +- .../src/centralized_client_service.cpp | 9 +- mooncake-store/src/data_manager.cpp | 33 +++---- mooncake-store/src/p2p_client_service.cpp | 84 ++++++++++-------- mooncake-store/src/peer_client.cpp | 3 +- .../tests/client_rpc_service_test.cpp | 10 +-- .../tests/p2p_client_integration_test.cpp | 4 +- mooncake-store/tests/peer_client_test.cpp | 86 +++++++++---------- 14 files changed, 131 insertions(+), 126 deletions(-) diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index 39cb86823a..3d289e99d3 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -1571,8 +1571,7 @@ PYBIND11_MODULE(store, m) { const std::vector>& all_buffer_ptrs, const std::vector>& all_sizes, bool aggregate_same_segment_task = false, - const std::optional& config_opt = - std::nullopt) { + const std::optional& config_opt = std::nullopt) { ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); py::gil_scoped_release release; if (self.use_dummy_client_) { diff --git a/mooncake-store/include/client_rpc_types.h b/mooncake-store/include/client_rpc_types.h index fde4587cda..c67b02564e 100644 --- a/mooncake-store/include/client_rpc_types.h +++ b/mooncake-store/include/client_rpc_types.h @@ -107,7 +107,8 @@ struct WriteCommitRequest { YLT_REFL(WriteCommitRequest, key, pending_write_token); -/** Drops a pending PreWrite allocation without committing (e.g. after TE failure). */ +/** Drops a pending PreWrite allocation without committing (e.g. after TE + * failure). */ struct WriteRevokeRequest { std::string_view key; UUID pending_write_token; diff --git a/mooncake-store/include/client_service.h b/mooncake-store/include/client_service.h index 53567571fe..8728ea0b09 100644 --- a/mooncake-store/include/client_service.h +++ b/mooncake-store/include/client_service.h @@ -158,8 +158,7 @@ class ClientService { */ virtual tl::expected Get( const std::string& key, const std::vector& buffers, - const std::vector& sizes, - const ReadConfigExt& config = {}) = 0; + const std::vector& sizes, const ReadConfigExt& config = {}) = 0; /** * @brief Batch get data into user-provided buffers diff --git a/mooncake-store/include/data_manager.h b/mooncake-store/include/data_manager.h index 452ea6a3ff..d83ba4672b 100644 --- a/mooncake-store/include/data_manager.h +++ b/mooncake-store/include/data_manager.h @@ -202,7 +202,7 @@ class DataManager { * Used when forward TE fails after PreWrite on the peer. */ tl::expected WriteRevoke(std::string_view key, - const UUID& pending_write_token); + const UUID& pending_write_token); tl::expected PinKey( std::string_view key, std::optional tier_id = std::nullopt); @@ -350,8 +350,9 @@ class DataManager { const std::vector& remote_buffers, Transport::TransferRequest::OpCode opcode); - tl::expected>, - ErrorCode> + tl::expected< + std::vector>, + ErrorCode> SubmitTeTransferBatches(void* transfer_ptr, size_t total_data_size, const std::vector& remote_buffers, Transport::TransferRequest::OpCode opcode); diff --git a/mooncake-store/include/p2p_rpc_types.h b/mooncake-store/include/p2p_rpc_types.h index a72c7108ed..c947fadc62 100644 --- a/mooncake-store/include/p2p_rpc_types.h +++ b/mooncake-store/include/p2p_rpc_types.h @@ -49,7 +49,8 @@ struct WriteConfigExt { RdmaDirectionMode rdma_direction_mode = RdmaDirectionMode::REVERSE; WriteConfigExt() = default; - /** Promotes legacy write-route-only config for variant overload resolution. */ + /** Promotes legacy write-route-only config for variant overload resolution. + */ WriteConfigExt(WriteRouteRequestConfig r) : route_config(std::move(r)) {} }; YLT_REFL(WriteConfigExt, route_config, rdma_direction_mode); diff --git a/mooncake-store/include/pyclient.h b/mooncake-store/include/pyclient.h index c909f74577..abeea65cb0 100644 --- a/mooncake-store/include/pyclient.h +++ b/mooncake-store/include/pyclient.h @@ -44,15 +44,13 @@ class PyClient { virtual std::vector batch_get_into( const std::vector& keys, const std::vector& buffers, - const std::vector& sizes, - const ReadConfigExt& config = {}) = 0; + const std::vector& sizes, const ReadConfigExt& config = {}) = 0; virtual std::vector batch_get_into_multi_buffers( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - bool aggregate_same_segment_task, - const ReadConfigExt& config = {}) = 0; + bool aggregate_same_segment_task, const ReadConfigExt& config = {}) = 0; virtual int put_from(const std::string& key, void* buffer, size_t size, const WriteConfig& config) = 0; diff --git a/mooncake-store/include/real_client.h b/mooncake-store/include/real_client.h index f49dc468d1..02108c9bfb 100644 --- a/mooncake-store/include/real_client.h +++ b/mooncake-store/include/real_client.h @@ -406,8 +406,7 @@ class RealClient : public PyClient { const ReadConfigExt& config = {}); std::vector> batch_get_buffer_internal( - const std::vector& keys, - const ReadConfigExt& config = {}); + const std::vector& keys, const ReadConfigExt& config = {}); std::map> batch_get_replica_desc(const std::vector& keys); diff --git a/mooncake-store/src/centralized_client_service.cpp b/mooncake-store/src/centralized_client_service.cpp index 1c550bd8c3..80557bf9e4 100644 --- a/mooncake-store/src/centralized_client_service.cpp +++ b/mooncake-store/src/centralized_client_service.cpp @@ -242,7 +242,8 @@ CentralizedClientService::Query(const std::string& object_key, } std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now(); - auto result = master_client_.GetReplicaList(object_key, config.route_config); + auto result = + master_client_.GetReplicaList(object_key, config.route_config); if (!result) { LOG(ERROR) << "Failed to get replica list: " << result.error(); return tl::unexpected(result.error()); @@ -262,8 +263,7 @@ CentralizedClientService::Query(const std::string& object_key, std::vector, ErrorCode>> CentralizedClientService::BatchQuery( - const std::vector& object_keys, - const ReadConfigExt& config) { + const std::vector& object_keys, const ReadConfigExt& config) { auto guard = AcquireInflightGuard(); if (!guard.is_valid()) { LOG(ERROR) << "client is shutting down"; @@ -279,7 +279,8 @@ CentralizedClientService::BatchQuery( std::chrono::steady_clock::now(); std::vector key_views(object_keys.begin(), object_keys.end()); - auto response = master_client_.BatchGetReplicaList(key_views, config.route_config); + auto response = + master_client_.BatchGetReplicaList(key_views, config.route_config); // Check if we got the expected number of responses if (response.size() != object_keys.size()) { diff --git a/mooncake-store/src/data_manager.cpp b/mooncake-store/src/data_manager.cpp index 613c158168..ee42b1c0ce 100644 --- a/mooncake-store/src/data_manager.cpp +++ b/mooncake-store/src/data_manager.cpp @@ -509,7 +509,8 @@ DataManager::PutViaTe(std::string_view key, std::vector& slices) { return tl::unexpected(validate_result.error()); } - // Local Put: allocation follows tier backend policy (not restricted to DRAM). + // Local Put: allocation follows tier backend policy (not restricted to + // DRAM). auto prewrite_result = PreWriteInternal(kctx, total_size, std::nullopt, false); if (!prewrite_result) { @@ -882,8 +883,7 @@ tl::expected DataManager::WriteRemoteData( // Reverse RDMA path: still one RPC, but internally use the 3-phase write // model (PreWrite -> transfer -> WriteCommit). Target tier may be non-DRAM. - auto prewrite_result = - PreWriteInternal(kctx, total_size, tier_id, false); + auto prewrite_result = PreWriteInternal(kctx, total_size, tier_id, false); if (!prewrite_result) { timer.LogResponse("error_code=", prewrite_result.error()); return tl::make_unexpected(prewrite_result.error()); @@ -926,10 +926,9 @@ tl::expected DataManager::PreWrite( return PreWriteInternal(BuildKeyCtx(key), size_bytes, tier_id, true); } -tl::expected -DataManager::PreWriteInternal(const KeyCtx& ctx, size_t size_bytes, - std::optional tier_id, - bool enforce_dram_allocation) { +tl::expected DataManager::PreWriteInternal( + const KeyCtx& ctx, size_t size_bytes, std::optional tier_id, + bool enforce_dram_allocation) { ScopedVLogTimer timer(1, "DataManager::PreWrite"); timer.LogRequest("key=", ctx.key, "size_bytes=", size_bytes); @@ -963,11 +962,11 @@ DataManager::PreWriteInternal(const KeyCtx& ctx, size_t size_bytes, auto handle = std::move(handle_result.value()); // When enforce_dram_allocation is true (RPC PreWrite): only DRAM is wired - // for forward TE today; non-DRAM tiers TODO. Local Put / WriteRemoteData use - // false and skip this check. - if (enforce_dram_allocation && - handle->loc.data.type != MemoryType::DRAM) { - timer.LogResponse("error_code=", ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); + // for forward TE today; non-DRAM tiers TODO. Local Put / WriteRemoteData + // use false and skip this check. + if (enforce_dram_allocation && handle->loc.data.type != MemoryType::DRAM) { + timer.LogResponse("error_code=", + ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); return tl::make_unexpected(ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); } auto list_it = shard.ordered_list.emplace(shard.ordered_list.end(), @@ -1075,7 +1074,8 @@ tl::expected DataManager::PinKeyInternal( RemoveExpiredPinnedKeyLocked(shard, ctx.key_string, now); auto record_it = shard.by_key.find(ctx.key_string); if (record_it != shard.by_key.end()) { - // PinKey forward path: DRAM-only for now; non-DRAM replica handling TODO. + // PinKey forward path: DRAM-only for now; non-DRAM replica handling + // TODO. if (record_it->second.handle->loc.data.type != MemoryType::DRAM) { timer.LogResponse("error_code=", ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); @@ -1103,7 +1103,8 @@ tl::expected DataManager::PinKeyInternal( auto handle = std::move(handle_result.value()); // PinKey forward path: DRAM-only for now; non-DRAM replica handling TODO. if (handle->loc.data.type != MemoryType::DRAM) { - timer.LogResponse("error_code=", ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); + timer.LogResponse("error_code=", + ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); return tl::make_unexpected(ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); } auto list_it = shard.ordered_list.emplace(shard.ordered_list.end(), @@ -1275,8 +1276,8 @@ DataManager::SubmitTeTransferInternal( std::move(buffer_result.value()); } - auto batches_result = SubmitTeTransferBatches( - transfer_ptr, total_data_size, remote_buffers, opcode); + auto batches_result = SubmitTeTransferBatches(transfer_ptr, total_data_size, + remote_buffers, opcode); if (!batches_result) { return tl::unexpected(batches_result.error()); } diff --git a/mooncake-store/src/p2p_client_service.cpp b/mooncake-store/src/p2p_client_service.cpp index 0a7db34cfb..4b1516ba11 100644 --- a/mooncake-store/src/p2p_client_service.cpp +++ b/mooncake-store/src/p2p_client_service.cpp @@ -21,9 +21,9 @@ namespace { // UnPin after forward read (or cleanup after TE failure): retry only for // "other" errors. INVALID_READ = token mismatch (treat as released for flow). -// RPC_FAIL = transport/timeout-like (no repeat; owner may TTL-clean). LEASE_EXPIRED -// = server already expired the pin record. -// Same max-attempt count is used for WriteRevoke after forward write TE failure. +// RPC_FAIL = transport/timeout-like (no repeat; owner may TTL-clean). +// LEASE_EXPIRED = server already expired the pin record. Same max-attempt count +// is used for WriteRevoke after forward write TE failure. constexpr int kForwardReadUnpinMaxAttempts = 3; bool UnPinErrorTreatAsEffectiveOk(ErrorCode e) { @@ -992,8 +992,8 @@ std::unique_ptr> P2PClientService::RemoteWriteOp::Dispatch() { }); } if (rdma_direction_mode == RdmaDirectionMode::FORWARD) { - return owner_service->StartForwardRemotePut( - peer_ptr, forward_dm, forward_slices, write_req); + return owner_service->StartForwardRemotePut(peer_ptr, forward_dm, + forward_slices, write_req); } return owner_service->RunReverseRemotePut(peer_ptr, write_req, proxy, route_cache); @@ -1069,16 +1069,20 @@ async_simple::coro::Lazy P2PClientService::RunForwardRemotePut( std::shared_ptr>> promise, PeerClient* peer, DataManager* dm, - std::shared_ptr write_req, - std::vector* slices) { + std::shared_ptr write_req, std::vector* slices) { if (!peer || !dm || !write_req || !slices) { - promise->setValue(tl::expected(tl::unexpect, ErrorCode::INTERNAL_ERROR)); + tl::expected err = + tl::make_unexpected(ErrorCode::INTERNAL_ERROR); + promise->setValue(std::move(err)); co_return; } if (!SlicesAreContiguous(*slices)) { - LOG(ERROR) << "Forward RDMA write requires contiguous slice buffers, key=" - << write_req->key; - promise->setValue(tl::expected(tl::unexpect, ErrorCode::INVALID_PARAMS)); + LOG(ERROR) + << "Forward RDMA write requires contiguous slice buffers, key=" + << write_req->key; + tl::expected err = + tl::make_unexpected(ErrorCode::INVALID_PARAMS); + promise->setValue(std::move(err)); co_return; } PreWriteRequest pre_req; @@ -1092,15 +1096,16 @@ async_simple::coro::Lazy P2PClientService::RunForwardRemotePut( LOG(ERROR) << "AsyncPreWrite failed, key=" << write_req->key << ", error=" << pre.error(); } - promise->setValue(tl::expected(tl::unexpect, pre.error())); + tl::expected err = tl::make_unexpected(pre.error()); + promise->setValue(std::move(err)); co_return; } std::vector dest{pre.value().remote_buffer}; void* base = slices->front().ptr; - auto te = dm->TransferWithTeNoTierStaging( - base, TotalSliceBytes(*slices), dest, - Transport::TransferRequest::WRITE); + auto te = + dm->TransferWithTeNoTierStaging(base, TotalSliceBytes(*slices), dest, + Transport::TransferRequest::WRITE); if (!te) { LOG(ERROR) << "Forward TE write failed, key=" << write_req->key << ", error=" << te.error(); @@ -1108,7 +1113,8 @@ async_simple::coro::Lazy P2PClientService::RunForwardRemotePut( revoke_req.key = write_req->key; revoke_req.pending_write_token = pre.value().pending_write_token; tl::expected revoke_res; - for (int attempt = 0; attempt < kForwardReadUnpinMaxAttempts; ++attempt) { + for (int attempt = 0; attempt < kForwardReadUnpinMaxAttempts; + ++attempt) { revoke_res = co_await peer->AsyncWriteRevoke(revoke_req); if (revoke_res) { break; @@ -1125,9 +1131,10 @@ async_simple::coro::Lazy P2PClientService::RunForwardRemotePut( } if (!revoke_res) { LOG(ERROR) << "AsyncWriteRevoke failed after TE failure, key=" - << write_req->key << ", error=" << revoke_res.error(); + << write_req->key << ", error=" << revoke_res.error(); } - promise->setValue(tl::expected(tl::unexpect, te.error())); + tl::expected err = tl::make_unexpected(te.error()); + promise->setValue(std::move(err)); co_return; } @@ -1136,7 +1143,8 @@ async_simple::coro::Lazy P2PClientService::RunForwardRemotePut( commit.pending_write_token = pre.value().pending_write_token; auto cm = co_await peer->AsyncWriteCommit(commit); if (!cm) { - promise->setValue(tl::expected(tl::unexpect, cm.error())); + tl::expected err = tl::make_unexpected(cm.error()); + promise->setValue(std::move(err)); co_return; } promise->setValue(tl::expected{}); @@ -1364,8 +1372,7 @@ P2PClientService::BatchCreateGetHandles( std::vector> P2PClientService::BatchCreateGetHandles( const std::vector& keys, - std::vector>& all_slices, - const ReadConfigExt& config) { + std::vector>& all_slices, const ReadConfigExt& config) { auto local_get = [&](std::string_view key, size_t i) -> tl::expected { if (!data_manager_.has_value()) { @@ -1468,7 +1475,8 @@ P2PClientService::BatchFetchReadRoutes( // Single batch RPC to master std::vector> responses; - responses = master_client_.BatchGetReplicaList(miss_keys, config.route_config); + responses = + master_client_.BatchGetReplicaList(miss_keys, config.route_config); for (size_t k = 0; k < responses.size(); ++k) { if (!responses[k]) { if (responses[k].error() != ErrorCode::OBJECT_NOT_FOUND) { @@ -1567,8 +1575,8 @@ tl::expected P2PClientService::CreateRemoteGetHandle( auto read_buf = std::make_shared(std::move(*alloc_result)); std::vector slices = {{read_buf->ptr(), object_size}}; - auto result = - InnerGetViaRoute(key, slices, std::move(*iter), config.rdma_direction_mode); + auto result = InnerGetViaRoute(key, slices, std::move(*iter), + config.rdma_direction_mode); if (!result) { LOG(ERROR) << "Failed to get via route, key=" << key << ", error=" << result.error(); @@ -1592,8 +1600,8 @@ tl::expected P2PClientService::CreateRemoteGetHandle( } return tl::unexpected(iter.error()); } - auto result = - InnerGetViaRoute(key, slices, std::move(*iter), config.rdma_direction_mode); + auto result = InnerGetViaRoute(key, slices, std::move(*iter), + config.rdma_direction_mode); if (!result) { LOG(ERROR) << "Failed to get via route, key=" << key << ", error=" << result.error(); @@ -1637,13 +1645,17 @@ async_simple::coro::Lazy P2PClientService::RunForwardReadOnRoute( RouteIterator& iter, ErrorCode& final_result) { if (!data_manager_.has_value()) { LOG(ERROR) << "Forward RDMA read requires DataManager"; - promise->setValue(tl::expected(tl::unexpect, ErrorCode::INTERNAL_ERROR)); + tl::expected err1 = + tl::make_unexpected(ErrorCode::INTERNAL_ERROR); + promise->setValue(std::move(err1)); co_return true; } if (!RemoteDestBuffersContiguous(req->dest_buffers)) { LOG(ERROR) << "Forward RDMA read requires contiguous dest buffers, key=" << req->key; - promise->setValue(tl::expected(tl::unexpect, ErrorCode::INVALID_PARAMS)); + tl::expected err2 = + tl::make_unexpected(ErrorCode::INVALID_PARAMS); + promise->setValue(std::move(err2)); co_return true; } PinKeyRequest pin_req; @@ -1685,14 +1697,15 @@ async_simple::coro::Lazy P2PClientService::RunForwardReadOnRoute( break; } if (attempt + 1 < kForwardReadUnpinMaxAttempts) { - LOG(WARNING) << "AsyncUnPinKey retry after TE failure, key=" - << req->key << ", attempt=" << (attempt + 1) - << ", error=" << cleanup_unpin.error(); + LOG(WARNING) + << "AsyncUnPinKey retry after TE failure, key=" << req->key + << ", attempt=" << (attempt + 1) + << ", error=" << cleanup_unpin.error(); } } if (!cleanup_unpin) { LOG(ERROR) << "AsyncUnPinKey failed after TE read failure, key=" - << req->key << ", error=" << cleanup_unpin.error(); + << req->key << ", error=" << cleanup_unpin.error(); } iter.Evict(route); co_return false; @@ -1739,8 +1752,8 @@ async_simple::coro::Lazy P2PClientService::RunReadWithRetry( while (auto route = co_await iter.AsyncNext()) { try { if (rdma_direction_mode == RdmaDirectionMode::FORWARD) { - if (co_await RunForwardReadOnRoute(*route, req, promise, iter, - final_result)) { + if (co_await RunForwardReadOnRoute(*route, req, promise, + iter, final_result)) { co_return; } continue; @@ -1980,7 +1993,8 @@ tl::expected, ErrorCode> P2PClientService::Query( } // Query master for replica list - auto result = master_client_.GetReplicaList(object_key, config.route_config); + auto result = + master_client_.GetReplicaList(object_key, config.route_config); if (!result) { LOG(WARNING) << "fail to get replica list" << ", key=" << object_key << ", error=" << result.error(); diff --git a/mooncake-store/src/peer_client.cpp b/mooncake-store/src/peer_client.cpp index 516a8d6b3c..ab558cccf9 100644 --- a/mooncake-store/src/peer_client.cpp +++ b/mooncake-store/src/peer_client.cpp @@ -193,8 +193,7 @@ PeerClient::AsyncUnPinKey(const UnPinKeyRequest& request) { auto result = co_await std::move(ret.value()); if (!result) { - LOG(ERROR) << "AsyncUnPinKey: RPC call failed: " - << result.error().msg; + LOG(ERROR) << "AsyncUnPinKey: RPC call failed: " << result.error().msg; co_return tl::make_unexpected(ErrorCode::RPC_FAIL); } diff --git a/mooncake-store/tests/client_rpc_service_test.cpp b/mooncake-store/tests/client_rpc_service_test.cpp index 4aae44c05f..4745484b04 100644 --- a/mooncake-store/tests/client_rpc_service_test.cpp +++ b/mooncake-store/tests/client_rpc_service_test.cpp @@ -323,9 +323,8 @@ TEST_F(ClientRpcServiceTest, PinKeyAfterUnpinNewToken) { << "first UnPinKey failed: " << static_cast(un1.error()); auto pin2 = rpc_service_->PinKey(pin_req); - ASSERT_TRUE(pin2.has_value()) - << "second PinKey after unpin failed: " - << static_cast(pin2.error()); + ASSERT_TRUE(pin2.has_value()) << "second PinKey after unpin failed: " + << static_cast(pin2.error()); EXPECT_NE(pin1->pin_token, pin2->pin_token); UnPinKeyRequest unpin2; @@ -383,9 +382,8 @@ TEST_F(ClientRpcServiceTest, UnPinKeyWrongTokenAfterPin) { ok.key = key; ok.pin_token = pin_res->pin_token; auto ok_res = rpc_service_->UnPinKey(ok); - ASSERT_TRUE(ok_res.has_value()) - << "UnPinKey with correct token failed: " - << static_cast(ok_res.error()); + ASSERT_TRUE(ok_res.has_value()) << "UnPinKey with correct token failed: " + << static_cast(ok_res.error()); } // ============================================================================ diff --git a/mooncake-store/tests/p2p_client_integration_test.cpp b/mooncake-store/tests/p2p_client_integration_test.cpp index 0d04b4b295..f6d52e0573 100644 --- a/mooncake-store/tests/p2p_client_integration_test.cpp +++ b/mooncake-store/tests/p2p_client_integration_test.cpp @@ -656,8 +656,8 @@ TEST_F(P2PClientIntegrationTest, ForwardRemotePutAndGet) { rcfg.rdma_direction_mode = RdmaDirectionMode::FORWARD; std::vector buf(payload.size(), 0); - auto get_res = remote_writer->Get(key, {(void*)buf.data()}, {buf.size()}, - rcfg); + auto get_res = + remote_writer->Get(key, {(void*)buf.data()}, {buf.size()}, rcfg); ASSERT_TRUE(get_res.has_value()) << "Forward Get failed mode=" << mode << " err=" << static_cast(get_res.error()); diff --git a/mooncake-store/tests/peer_client_test.cpp b/mooncake-store/tests/peer_client_test.cpp index 767b64ef3e..49b281b271 100644 --- a/mooncake-store/tests/peer_client_test.cpp +++ b/mooncake-store/tests/peer_client_test.cpp @@ -281,8 +281,7 @@ TEST_F(PeerClientTest, AsyncPinKeyEmptyKey) { req.key = ""; req.target_tier_id = std::nullopt; - auto result = - async_simple::coro::syncAwait(peer_client_->AsyncPinKey(req)); + auto result = async_simple::coro::syncAwait(peer_client_->AsyncPinKey(req)); ASSERT_FALSE(result.has_value()); EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); } @@ -292,8 +291,7 @@ TEST_F(PeerClientTest, AsyncPinKeyKeyNotFound) { req.key = "peer_async_pin_missing_key"; req.target_tier_id = std::nullopt; - auto result = - async_simple::coro::syncAwait(peer_client_->AsyncPinKey(req)); + auto result = async_simple::coro::syncAwait(peer_client_->AsyncPinKey(req)); ASSERT_FALSE(result.has_value()); EXPECT_EQ(result.error(), ErrorCode::OBJECT_NOT_FOUND); } @@ -359,13 +357,11 @@ TEST_F(PeerClientTest, AsyncPinKeyTwiceSameTokenThenUnpinTwice) { UnPinKeyRequest unpin; unpin.key = key; unpin.pin_token = first->pin_token; - auto u1 = - async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(unpin)); + auto u1 = async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(unpin)); ASSERT_TRUE(u1.has_value()) << "first AsyncUnPinKey failed: " << static_cast(u1.error()); - auto u2 = - async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(unpin)); + auto u2 = async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(unpin)); ASSERT_TRUE(u2.has_value()) << "second AsyncUnPinKey failed: " << static_cast(u2.error()); } @@ -400,9 +396,8 @@ TEST_F(PeerClientTest, AsyncPinKeyAfterUnpinNewToken) { auto pin2 = async_simple::coro::syncAwait(peer_client_->AsyncPinKey(pin_req)); - ASSERT_TRUE(pin2.has_value()) - << "second AsyncPinKey after unpin failed: " - << static_cast(pin2.error()); + ASSERT_TRUE(pin2.has_value()) << "second AsyncPinKey after unpin failed: " + << static_cast(pin2.error()); EXPECT_NE(pin1->pin_token, pin2->pin_token); UnPinKeyRequest unpin2; @@ -564,8 +559,8 @@ TEST_F(PeerClientTest, AsyncPreWriteEmptyKey) { pre.size_bytes = 64; // target_tier_id optional; invalid key fails before tier selection. - auto result = async_simple::coro::syncAwait( - peer_client_->AsyncPreWrite(pre)); + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncPreWrite(pre)); ASSERT_FALSE(result.has_value()); EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); } @@ -576,8 +571,8 @@ TEST_F(PeerClientTest, AsyncPreWriteZeroSize) { pre.key = key; pre.size_bytes = 0; - auto result = async_simple::coro::syncAwait( - peer_client_->AsyncPreWrite(pre)); + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncPreWrite(pre)); ASSERT_FALSE(result.has_value()); EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); } @@ -592,8 +587,8 @@ TEST_F(PeerClientTest, AsyncPreWriteValidRequest) { pre.size_bytes = 256; pre.target_tier_id = tier_id; - auto result = async_simple::coro::syncAwait( - peer_client_->AsyncPreWrite(pre)); + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncPreWrite(pre)); ASSERT_TRUE(result.has_value()) << "AsyncPreWrite failed: " << static_cast(result.error()); EXPECT_GT(result->remote_buffer.size, 0u); @@ -603,8 +598,8 @@ TEST_F(PeerClientTest, AsyncPreWriteValidRequest) { WriteRevokeRequest revoke; revoke.key = key; revoke.pending_write_token = result->pending_write_token; - auto rev = async_simple::coro::syncAwait( - peer_client_->AsyncWriteRevoke(revoke)); + auto rev = + async_simple::coro::syncAwait(peer_client_->AsyncWriteRevoke(revoke)); ASSERT_TRUE(rev.has_value()) << "Cleanup AsyncWriteRevoke failed: " << static_cast(rev.error()); } @@ -622,8 +617,8 @@ TEST_F(PeerClientTest, AsyncPreWriteWhenObjectAlreadyExists) { pre.key = key; pre.size_bytes = 128; - auto result = async_simple::coro::syncAwait( - peer_client_->AsyncPreWrite(pre)); + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncPreWrite(pre)); ASSERT_FALSE(result.has_value()); EXPECT_EQ(result.error(), ErrorCode::OBJECT_ALREADY_EXISTS); } @@ -643,14 +638,14 @@ TEST_F(PeerClientTest, AsyncWriteCommitAfterPreWrite) { ASSERT_TRUE(pre_res.has_value()) << "AsyncPreWrite failed: " << static_cast(pre_res.error()); - // TransferEngine is not initialized in this unit test, so no real data-plane - // write occurs. Assume the access side has already filled the buffer via TE; - // this case only checks WriteCommit RPC / metadata outcome. + // TransferEngine is not initialized in this unit test, so no real + // data-plane write occurs. Assume the access side has already filled the + // buffer via TE; this case only checks WriteCommit RPC / metadata outcome. WriteCommitRequest commit; commit.key = key; commit.pending_write_token = pre_res->pending_write_token; - auto commit_res = async_simple::coro::syncAwait( - peer_client_->AsyncWriteCommit(commit)); + auto commit_res = + async_simple::coro::syncAwait(peer_client_->AsyncWriteCommit(commit)); ASSERT_TRUE(commit_res.has_value()) << "AsyncWriteCommit failed: " << static_cast(commit_res.error()); } @@ -660,8 +655,8 @@ TEST_F(PeerClientTest, AsyncWriteCommitEmptyKey) { commit.key = ""; commit.pending_write_token = {1, 2}; - auto result = async_simple::coro::syncAwait( - peer_client_->AsyncWriteCommit(commit)); + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncWriteCommit(commit)); ASSERT_FALSE(result.has_value()); EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); } @@ -672,8 +667,8 @@ TEST_F(PeerClientTest, AsyncWriteCommitZeroToken) { commit.key = key; commit.pending_write_token = {0, 0}; - auto result = async_simple::coro::syncAwait( - peer_client_->AsyncWriteCommit(commit)); + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncWriteCommit(commit)); ASSERT_FALSE(result.has_value()); EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); } @@ -698,16 +693,16 @@ TEST_F(PeerClientTest, AsyncWriteCommitTokenMismatchAfterPreWrite) { wrong_token.first += 1; commit.pending_write_token = wrong_token; - auto bad = async_simple::coro::syncAwait( - peer_client_->AsyncWriteCommit(commit)); + auto bad = + async_simple::coro::syncAwait(peer_client_->AsyncWriteCommit(commit)); ASSERT_FALSE(bad.has_value()); EXPECT_EQ(bad.error(), ErrorCode::INVALID_WRITE); WriteRevokeRequest revoke; revoke.key = key; revoke.pending_write_token = pre_res->pending_write_token; - auto rev = async_simple::coro::syncAwait( - peer_client_->AsyncWriteRevoke(revoke)); + auto rev = + async_simple::coro::syncAwait(peer_client_->AsyncWriteRevoke(revoke)); ASSERT_TRUE(rev.has_value()); } @@ -716,8 +711,8 @@ TEST_F(PeerClientTest, AsyncWriteRevokeEmptyKey) { request.key = ""; request.pending_write_token = {1, 2}; - auto result = async_simple::coro::syncAwait( - peer_client_->AsyncWriteRevoke(request)); + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncWriteRevoke(request)); ASSERT_FALSE(result.has_value()); EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); } @@ -728,8 +723,8 @@ TEST_F(PeerClientTest, AsyncWriteRevokeZeroToken) { request.key = key; request.pending_write_token = {0, 0}; - auto result = async_simple::coro::syncAwait( - peer_client_->AsyncWriteRevoke(request)); + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncWriteRevoke(request)); ASSERT_FALSE(result.has_value()); EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); } @@ -752,8 +747,8 @@ TEST_F(PeerClientTest, AsyncPreWriteThenWriteRevokeIdempotent) { WriteRevokeRequest revoke; revoke.key = key; revoke.pending_write_token = pre_res->pending_write_token; - auto rev_res = async_simple::coro::syncAwait( - peer_client_->AsyncWriteRevoke(revoke)); + auto rev_res = + async_simple::coro::syncAwait(peer_client_->AsyncWriteRevoke(revoke)); ASSERT_TRUE(rev_res.has_value()) << "AsyncWriteRevoke failed: " << static_cast(rev_res.error()); @@ -926,9 +921,8 @@ TEST_F(PeerClientTest, SyncPinKeyAfterUnpinNewToken) { << "first UnPinKey failed: " << static_cast(un1.error()); auto pin2 = peer_client_->PinKey(pin_req); - ASSERT_TRUE(pin2.has_value()) - << "second PinKey after unpin failed: " - << static_cast(pin2.error()); + ASSERT_TRUE(pin2.has_value()) << "second PinKey after unpin failed: " + << static_cast(pin2.error()); EXPECT_NE(pin1->pin_token, pin2->pin_token); UnPinKeyRequest unpin2; @@ -1024,9 +1018,9 @@ TEST_F(PeerClientTest, SyncWriteCommitAfterPreWrite) { ASSERT_TRUE(pre_res.has_value()) << "PreWrite failed: " << static_cast(pre_res.error()); - // TransferEngine is not initialized in this unit test, so no real data-plane - // write occurs. Assume the access side has already filled the buffer via TE; - // this case only checks WriteCommit RPC / metadata outcome. + // TransferEngine is not initialized in this unit test, so no real + // data-plane write occurs. Assume the access side has already filled the + // buffer via TE; this case only checks WriteCommit RPC / metadata outcome. WriteCommitRequest commit; commit.key = key; commit.pending_write_token = pre_res->pending_write_token; From 92f08293efcc9ab3bbbe363c30dbc7daf874b66a Mon Sep 17 00:00:00 2001 From: Shichang-Zhang Date: Thu, 14 May 2026 20:43:58 +0800 Subject: [PATCH 14/14] add errorcode NON_CONTIGUOUS_BUFFER_NOT_SUPPORTED for non-contiguous buffers check --- mooncake-store/include/types.h | 2 ++ mooncake-store/src/p2p_client_service.cpp | 4 ++-- mooncake-store/src/types.cpp | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/mooncake-store/include/types.h b/mooncake-store/include/types.h index 10f56ac448..0fa1a69d12 100644 --- a/mooncake-store/include/types.h +++ b/mooncake-store/include/types.h @@ -123,6 +123,8 @@ enum class ErrorCode : int32_t { // Parameter errors (Range: -600 to -699) INVALID_PARAMS = -600, ///< Invalid parameters. ILLEGAL_CLIENT = -601, ///< Illegal client to do the operation. + NON_CONTIGUOUS_BUFFER_NOT_SUPPORTED = + -602, ///< Non-contiguous buffer not supported in forward RDMA mode. // Engine operation errors (Range: -700 to -710) INVALID_WRITE = -700, ///< Invalid write operation. diff --git a/mooncake-store/src/p2p_client_service.cpp b/mooncake-store/src/p2p_client_service.cpp index 4b1516ba11..412d71a96f 100644 --- a/mooncake-store/src/p2p_client_service.cpp +++ b/mooncake-store/src/p2p_client_service.cpp @@ -1081,7 +1081,7 @@ async_simple::coro::Lazy P2PClientService::RunForwardRemotePut( << "Forward RDMA write requires contiguous slice buffers, key=" << write_req->key; tl::expected err = - tl::make_unexpected(ErrorCode::INVALID_PARAMS); + tl::make_unexpected(ErrorCode::NON_CONTIGUOUS_BUFFER_NOT_SUPPORTED); promise->setValue(std::move(err)); co_return; } @@ -1654,7 +1654,7 @@ async_simple::coro::Lazy P2PClientService::RunForwardReadOnRoute( LOG(ERROR) << "Forward RDMA read requires contiguous dest buffers, key=" << req->key; tl::expected err2 = - tl::make_unexpected(ErrorCode::INVALID_PARAMS); + tl::make_unexpected(ErrorCode::NON_CONTIGUOUS_BUFFER_NOT_SUPPORTED); promise->setValue(std::move(err2)); co_return true; } diff --git a/mooncake-store/src/types.cpp b/mooncake-store/src/types.cpp index 2506c18738..1f8d912d02 100644 --- a/mooncake-store/src/types.cpp +++ b/mooncake-store/src/types.cpp @@ -21,6 +21,8 @@ const std::string& toString(ErrorCode errorCode) noexcept { {ErrorCode::WRITE_FAIL, "WRITE_FAIL"}, {ErrorCode::INVALID_PARAMS, "INVALID_PARAMS"}, {ErrorCode::ILLEGAL_CLIENT, "ILLEGAL_CLIENT"}, + {ErrorCode::NON_CONTIGUOUS_BUFFER_NOT_SUPPORTED, + "NON_CONTIGUOUS_BUFFER_NOT_SUPPORTED"}, {ErrorCode::INVALID_WRITE, "INVALID_WRITE"}, {ErrorCode::INVALID_READ, "INVALID_READ"}, {ErrorCode::INVALID_REPLICA, "INVALID_REPLICA"},