diff --git a/mooncake-integration/store/store_py.cpp b/mooncake-integration/store/store_py.cpp index c5494c807d..3d289e99d3 100644 --- a/mooncake-integration/store/store_py.cpp +++ b/mooncake-integration/store/store_py.cpp @@ -5,6 +5,8 @@ #include "pyclient.h" #include "dummy_client.h" #include "real_client.h" +#include "p2p_rpc_types.h" +#include "rpc_types.h" #include // for atexit #include @@ -231,8 +233,8 @@ class MooncakeStorePyWrapper { pybind11::bytes get( const std::string& key, - const std::optional& config_opt = std::nullopt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt = std::nullopt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); if (!is_client_initialized()) { LOG(ERROR) << "Client is not initialized"; return pybind11::bytes("\\0", 0); @@ -268,8 +270,8 @@ class MooncakeStorePyWrapper { std::vector get_batch( const std::vector& keys, - const std::optional& config_opt = std::nullopt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt = std::nullopt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); const auto kNullString = pybind11::bytes("\\0", 0); if (!is_client_initialized()) { LOG(ERROR) << "Client is not initialized"; @@ -302,7 +304,7 @@ class MooncakeStorePyWrapper { pybind11::object get_tensor_with_tp( const std::string& key, int tp_rank = 0, int tp_size = 1, int split_dim = 0, - const std::optional& config_opt = std::nullopt) { + const std::optional& config_opt = std::nullopt) { if (tp_size <= 1) return get_tensor(key, config_opt); return get_tensor(get_tp_key_name(key, tp_rank), config_opt); } @@ -310,7 +312,7 @@ class MooncakeStorePyWrapper { pybind11::list batch_get_tensor_with_tp( const std::vector& base_keys, int tp_rank = 0, int tp_size = 1, - const std::optional& config_opt = std::nullopt) { + const std::optional& config_opt = std::nullopt) { if (tp_size <= 1) return batch_get_tensor(base_keys, config_opt); std::vector shard_keys; @@ -323,8 +325,8 @@ class MooncakeStorePyWrapper { pybind11::object get_tensor( const std::string& key, - const std::optional& config_opt = std::nullopt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt = std::nullopt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); if (!is_client_initialized() || use_dummy_client_) { LOG(ERROR) << "Client not initialized or Dummy client not " "supported for tensors"; @@ -342,8 +344,8 @@ class MooncakeStorePyWrapper { pybind11::list batch_get_tensor( const std::vector& keys, - const std::optional& config_opt = std::nullopt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt = std::nullopt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); if (!is_client_initialized() || use_dummy_client_) { LOG(ERROR) << "Client not initialized or Dummy client not " "supported for tensors"; @@ -367,8 +369,8 @@ class MooncakeStorePyWrapper { pybind11::object get_tensor_into( const std::string& key, uintptr_t buffer_ptr, size_t size, - const std::optional& config_opt = std::nullopt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt = std::nullopt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); char* buffer = reinterpret_cast(buffer_ptr); if (!is_client_initialized()) { LOG(ERROR) << "Client is not initialized"; @@ -396,8 +398,8 @@ class MooncakeStorePyWrapper { const std::vector& keys, const std::vector& buffer_ptrs, const std::vector& sizes, - const std::optional& config_opt = std::nullopt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt = std::nullopt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); std::vector buffers; buffers.reserve(buffer_ptrs.size()); for (uintptr_t ptr : buffer_ptrs) { @@ -460,7 +462,7 @@ class MooncakeStorePyWrapper { pybind11::object get_tensor_with_tp_into( const std::string& key, uintptr_t buffer_ptr, size_t size, int tp_rank = 0, int tp_size = 1, int split_dim = 0, - const std::optional& config_opt = std::nullopt) { + const std::optional& config_opt = std::nullopt) { if (!is_client_initialized()) { LOG(ERROR) << "Client is not initialized"; return pybind11::none(); @@ -487,7 +489,7 @@ class MooncakeStorePyWrapper { const std::vector& base_keys, const std::vector& buffer_ptrs, const std::vector& sizes, int tp_rank = 0, int tp_size = 1, - const std::optional& config_opt = std::nullopt) { + const std::optional& config_opt = std::nullopt) { if (!is_client_initialized()) { LOG(ERROR) << "Client is not initialized"; py::list empty_list; @@ -877,6 +879,19 @@ PYBIND11_MODULE(store, m) { .def_readwrite("max_candidates", &ReadRouteConfig::max_candidates) .def_readwrite("p2p_config", &ReadRouteConfig::p2p_config); + py::enum_(m, "RdmaDirectionMode") + .value("REVERSE", RdmaDirectionMode::REVERSE) + .value("FORWARD", RdmaDirectionMode::FORWARD); + + py::class_(m, "ReadConfigExt") + .def(py::init<>()) + .def(py::init()) + .def_readwrite("route_config", &ReadConfigExt::route_config) + .def_readwrite("rdma_direction_mode", + &ReadConfigExt::rdma_direction_mode); + + py::implicitly_convertible(); + py::class_(m, "WriteRouteRequestConfig") .def(py::init<>()) // Default constructor .def_readwrite("max_candidates", @@ -894,6 +909,13 @@ PYBIND11_MODULE(store, m) { return oss.str(); }); + py::class_(m, "WriteConfigExt") + .def(py::init<>()) + .def(py::init()) + .def_readwrite("route_config", &WriteConfigExt::route_config) + .def_readwrite("rdma_direction_mode", + &WriteConfigExt::rdma_direction_mode); + py::enum_(m, "ReplicaStatus") .value("UNDEFINED", ReplicaStatus::UNDEFINED) .value("INITIALIZED", ReplicaStatus::INITIALIZED) @@ -1152,8 +1174,8 @@ PYBIND11_MODULE(store, m) { .def( "get_buffer", [](MooncakeStorePyWrapper& self, const std::string& key, - const std::optional& config_opt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); py::gil_scoped_release release; return self.store_->get_buffer(key, config); }, @@ -1163,8 +1185,8 @@ PYBIND11_MODULE(store, m) { "batch_get_buffer", [](MooncakeStorePyWrapper& self, const std::vector& keys, - const std::optional& config_opt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); py::gil_scoped_release release; if (self.use_dummy_client_) { LOG(ERROR) << "batch_get_buffer is not supported for dummy " @@ -1234,7 +1256,10 @@ PYBIND11_MODULE(store, m) { " tp_size: The total tensor parallel size (default 1).\n" " split_dim: The dimension to split the tensor along " "(default 0).\n" - " config: ReadRouteConfig.") + " config: ReadConfigExt (optional; omit for defaults). " + "ReadRouteConfig is accepted and treated as ReadConfigExt with " + "default RdmaDirectionMode.REVERSE. Set config.rdma_direction_mode " + "to FORWARD for forward RDMA read.") .def("batch_get_tensor_with_tp", &MooncakeStorePyWrapper::batch_get_tensor_with_tp, py::arg("base_keys"), py::arg("tp_rank") = 0, @@ -1296,7 +1321,10 @@ PYBIND11_MODULE(store, m) { " tp_size: The total tensor parallel size (default 1).\n" " split_dim: The dimension to split the tensor along" "(default 0).\n" - " config: ReadRouteConfig.") + " config: ReadConfigExt (optional; omit for defaults). " + "ReadRouteConfig is accepted and treated as ReadConfigExt with " + "default RdmaDirectionMode.REVERSE. Set config.rdma_direction_mode " + "to FORWARD for forward RDMA read.") .def( "batch_get_tensor_with_tp_into", &MooncakeStorePyWrapper::batch_get_tensor_with_tp_into, @@ -1331,8 +1359,8 @@ PYBIND11_MODULE(store, m) { "get_into", [](MooncakeStorePyWrapper& self, const std::string& key, uintptr_t buffer_ptr, size_t size, - const std::optional& config_opt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); // Get data directly into user-provided buffer void* buffer = reinterpret_cast(buffer_ptr); py::gil_scoped_release release; @@ -1352,8 +1380,8 @@ PYBIND11_MODULE(store, m) { const std::vector& keys, const std::vector& buffer_ptrs, const std::vector& sizes, - const std::optional& config_opt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); std::vector buffers; buffers.reserve(buffer_ptrs.size()); for (uintptr_t ptr : buffer_ptrs) { @@ -1543,9 +1571,8 @@ PYBIND11_MODULE(store, m) { const std::vector>& all_buffer_ptrs, const std::vector>& all_sizes, bool aggregate_same_segment_task = false, - const std::optional& config_opt = - std::nullopt) { - ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{}); + const std::optional& config_opt = std::nullopt) { + ReadConfigExt config = config_opt.value_or(ReadConfigExt{}); py::gil_scoped_release release; if (self.use_dummy_client_) { LOG(ERROR) diff --git a/mooncake-store/include/centralized_client_service.h b/mooncake-store/include/centralized_client_service.h index 6df407acc0..a840959732 100644 --- a/mooncake-store/include/centralized_client_service.h +++ b/mooncake-store/include/centralized_client_service.h @@ -54,11 +54,11 @@ class CentralizedClientService tl::expected, ErrorCode> Query( const std::string& object_key, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; std::vector, ErrorCode>> BatchQuery(const std::vector& object_keys, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; tl::expected IsExist(const std::string& key) override; @@ -76,24 +76,24 @@ class CentralizedClientService tl::expected Get( const std::string& key, const std::vector& buffers, const std::vector& sizes, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; std::vector> BatchGet( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - const ReadRouteConfig& config = {}, + const ReadConfigExt& config = {}, bool aggregate_same_segment_task = false) override; tl::expected, ErrorCode> Get( const std::string& key, std::shared_ptr allocator, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; std::vector, ErrorCode>> BatchGet(const std::vector& keys, std::shared_ptr allocator, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; tl::expected Put(const ObjectKey& key, std::vector& slices, diff --git a/mooncake-store/include/client_rpc_service.h b/mooncake-store/include/client_rpc_service.h index 288df6ee23..9b0bf4e476 100644 --- a/mooncake-store/include/client_rpc_service.h +++ b/mooncake-store/include/client_rpc_service.h @@ -55,6 +55,20 @@ class ClientRpcService { tl::expected WriteRemoteData( const RemoteWriteRequest& request); + tl::expected PreWrite( + const PreWriteRequest& request); + + tl::expected WriteCommit( + const WriteCommitRequest& request); + + tl::expected WriteRevoke( + const WriteRevokeRequest& request); + + tl::expected PinKey( + const PinKeyRequest& request); + + tl::expected UnPinKey(const UnPinKeyRequest& request); + private: DataManager& data_manager_; // Reference: owned by Client, same lifetime P2PClientMetric* metrics_; // Optional: owned by P2PClientService diff --git a/mooncake-store/include/client_rpc_types.h b/mooncake-store/include/client_rpc_types.h index a9f2263caf..c67b02564e 100644 --- a/mooncake-store/include/client_rpc_types.h +++ b/mooncake-store/include/client_rpc_types.h @@ -3,6 +3,7 @@ #include #include #include +#include #include #include "types.h" #include "ylt/struct_json/json_reader.h" @@ -84,4 +85,56 @@ struct BatchRemoteWriteRequest { YLT_REFL(BatchRemoteWriteRequest, keys, src_buffers_list, target_tier_ids); +struct PreWriteRequest { + std::string_view key; + uint64_t size_bytes = 0; + std::optional target_tier_id; +}; + +YLT_REFL(PreWriteRequest, key, size_bytes, target_tier_id); + +struct PreWriteResponse { + RemoteBufferDesc remote_buffer; + UUID pending_write_token; +}; + +YLT_REFL(PreWriteResponse, remote_buffer, pending_write_token); + +struct WriteCommitRequest { + std::string_view key; + UUID pending_write_token; +}; + +YLT_REFL(WriteCommitRequest, key, pending_write_token); + +/** Drops a pending PreWrite allocation without committing (e.g. after TE + * failure). */ +struct WriteRevokeRequest { + std::string_view key; + UUID pending_write_token; +}; + +YLT_REFL(WriteRevokeRequest, key, pending_write_token); + +struct PinKeyRequest { + std::string_view key; + std::optional target_tier_id; +}; + +YLT_REFL(PinKeyRequest, key, target_tier_id); + +struct PinKeyResponse { + RemoteBufferDesc remote_buffer; + UUID pin_token; +}; + +YLT_REFL(PinKeyResponse, remote_buffer, pin_token); + +struct UnPinKeyRequest { + std::string_view key; + UUID pin_token; +}; + +YLT_REFL(UnPinKeyRequest, key, pin_token); + } // namespace mooncake diff --git a/mooncake-store/include/client_service.h b/mooncake-store/include/client_service.h index 955cdc39cb..8728ea0b09 100644 --- a/mooncake-store/include/client_service.h +++ b/mooncake-store/include/client_service.h @@ -18,6 +18,7 @@ #include "transfer_engine.h" #include "types.h" #include "p2p_rpc_types.h" +#include "rpc_types.h" #include "replica.h" #include "master_client.h" #include @@ -27,7 +28,8 @@ namespace mooncake { -using WriteConfig = std::variant; +using WriteConfig = + std::variant; /** * @brief Result of a query operation containing replica information @@ -117,7 +119,7 @@ class ClientService { * indicating failure */ virtual tl::expected, ErrorCode> Query( - const std::string& object_key, const ReadRouteConfig& config = {}) = 0; + const std::string& object_key, const ReadConfigExt& config = {}) = 0; /** * @brief Batch query object metadata without transferring data @@ -126,7 +128,7 @@ class ClientService { */ virtual std::vector, ErrorCode>> BatchQuery(const std::vector& object_keys, - const ReadRouteConfig& config = {}) = 0; + const ReadConfigExt& config = {}) = 0; /** * @brief Gets data with memory allocation @@ -139,12 +141,12 @@ class ClientService { virtual tl::expected, ErrorCode> Get( const std::string& key, std::shared_ptr allocator, - const ReadRouteConfig& config = {}) = 0; + const ReadConfigExt& config = {}) = 0; virtual std::vector, ErrorCode>> BatchGet(const std::vector& keys, std::shared_ptr allocator, - const ReadRouteConfig& config = {}) = 0; + const ReadConfigExt& config = {}) = 0; /** * @brief Gets data into user-provided buffers without memory allocation @@ -156,8 +158,7 @@ class ClientService { */ virtual tl::expected Get( const std::string& key, const std::vector& buffers, - const std::vector& sizes, - const ReadRouteConfig& config = {}) = 0; + const std::vector& sizes, const ReadConfigExt& config = {}) = 0; /** * @brief Batch get data into user-provided buffers @@ -175,7 +176,7 @@ class ClientService { const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - const ReadRouteConfig& config = {}, + const ReadConfigExt& config = {}, bool aggregate_same_segment_task = false) = 0; /** diff --git a/mooncake-store/include/data_manager.h b/mooncake-store/include/data_manager.h index a34f67dbab..d83ba4672b 100644 --- a/mooncake-store/include/data_manager.h +++ b/mooncake-store/include/data_manager.h @@ -1,11 +1,18 @@ #pragma once +#include +#include +#include +#include +#include +#include +#include #include -#include #include #include -#include -#include +#include +#include +#include #include #include "async_memcpy_executor.h" #include "client_buffer.hpp" @@ -58,6 +65,8 @@ class DataManager { friend class DataManagerTest; public: + using TimePoint = std::chrono::time_point; + /** * @brief Constructor * @param tiered_backend Unique pointer to TieredBackend instance (takes @@ -70,14 +79,9 @@ class DataManager { size_t lock_shard_count = 1024, const LocalTransferConfig& local_transfer_config = {}); - void Stop() { - if (async_memcpy_executor_) { - async_memcpy_executor_->Shutdown(); - } - if (tiered_backend_) { - tiered_backend_->Stop(); - } - } + ~DataManager(); + + void Stop(); /** * @brief Cleanup: delegates to TieredBackend::Destroy(). @@ -186,6 +190,40 @@ class DataManager { std::string_view key, const std::vector& src_buffers, std::optional tier_id = std::nullopt); + tl::expected PreWrite( + std::string_view key, size_t size_bytes, + std::optional tier_id = std::nullopt); + + tl::expected WriteCommit(std::string_view key, + const UUID& pending_write_token); + + /** + * @brief Remove pending write record for key + token (no tier Commit). + * Used when forward TE fails after PreWrite on the peer. + */ + tl::expected WriteRevoke(std::string_view key, + const UUID& pending_write_token); + + tl::expected PinKey( + std::string_view key, std::optional tier_id = std::nullopt); + + tl::expected UnPinKey(std::string_view key, + const UUID& pin_token); + + /** + * @brief TE transfer without tier DRAM staging (PrepareDRAM*). + * + * Caller guarantees `local_transfer_base` covers a contiguous layout of + * `total_size` bytes that is valid for TransferEngine (typically registered + * DRAM). Used by forward RDMA paths where buffers are already TE-ready. + * + * @param opcode WRITE: local -> peer_buffers; READ: peer_buffers -> local + */ + tl::expected TransferWithTeNoTierStaging( + void* local_transfer_base, size_t total_size, + const std::vector& peer_buffers, + Transport::TransferRequest::OpCode opcode); + // ================================================================ // Utilities // ================================================================ @@ -212,6 +250,41 @@ class DataManager { std::optional tier_id = std::nullopt) const; private: + void ClearLeaseRecords(); + + // Forward declarations for nested shard structs used by internal helpers. + struct PendingWriteShard; + struct PinnedKeyShard; + + struct KeyCtx { + std::string_view key; + std::string key_string; + size_t hash = 0; + size_t pending_write_shard_idx = 0; + size_t pinned_key_shard_idx = 0; + }; + + KeyCtx BuildKeyCtx(std::string_view key) const; + PendingWriteShard& GetPendingWriteShard(const KeyCtx& ctx); + PinnedKeyShard& GetPinnedKeyShard(const KeyCtx& ctx); + + tl::expected PreWriteInternal( + const KeyCtx& ctx, size_t size_bytes, std::optional tier_id, + bool enforce_dram_allocation); + tl::expected WriteCommitInternal( + const KeyCtx& ctx, const UUID& pending_write_token); + tl::expected WriteRevokeInternal( + const KeyCtx& ctx, const UUID& pending_write_token); + tl::expected PinKeyInternal( + const KeyCtx& ctx, std::optional tier_id); + tl::expected UnPinKeyInternal(const KeyCtx& ctx, + const UUID& pin_token); + + tl::expected LookupPendingWriteHandleInternal( + const KeyCtx& ctx, const UUID& pending_write_token); + tl::expected LookupPinnedKeyHandleInternal( + const KeyCtx& ctx, const UUID& pin_token); + std::shared_mutex& GetKeyLock(std::string_view key) { size_t hash = std::hash{}(key); return lock_shards_[hash % lock_shard_count_]; @@ -277,6 +350,13 @@ class DataManager { const std::vector& remote_buffers, Transport::TransferRequest::OpCode opcode); + tl::expected< + std::vector>, + ErrorCode> + SubmitTeTransferBatches(void* transfer_ptr, size_t total_data_size, + const std::vector& remote_buffers, + Transport::TransferRequest::OpCode opcode); + /** * @brief Helper to wait for a transfer batch to complete * @param batch_id Batch ID to poll @@ -360,6 +440,65 @@ class DataManager { // Wait for all tasks to reach a terminal state, then free the batch. void CancelBatchTETask(Transport::BatchID batch_id, size_t num_tasks); + using OrderedDeadlineList = std::list>; + using OrderedDeadlineListIt = OrderedDeadlineList::iterator; + + struct PendingWriteRecord { + UUID pending_write_token{0, 0}; + TimePoint deadline{}; + AllocationHandle handle; + OrderedDeadlineListIt list_it; + }; + + struct PinnedKeyRecord { + UUID pin_token{0, 0}; + TimePoint deadline{}; + AllocationHandle handle; + uint32_t ref_count = 1; + OrderedDeadlineListIt list_it; + }; + + struct PendingWriteShard { + mutable std::shared_mutex mutex; + std::unordered_map by_key; + OrderedDeadlineList ordered_list; + }; + + struct PinnedKeyShard { + mutable std::shared_mutex mutex; + std::unordered_map by_key; + OrderedDeadlineList ordered_list; + }; + + size_t HashKey(std::string_view key) const; + PendingWriteShard& GetPendingWriteShard(std::string_view key); + PinnedKeyShard& GetPinnedKeyShard(std::string_view key); + const std::chrono::milliseconds& lease_duration() const { + return lease_duration_; + } + bool IsExpired(TimePoint deadline) const; + RemoteBufferDesc BuildRemoteBufferDesc( + const AllocationHandle& handle) const; + void LeaseScannerMain(); + void ShutdownLeaseScanner(); + size_t ScanExpiredPendingWrites(PendingWriteShard& shard, TimePoint now); + size_t ScanExpiredPinnedKeys(PinnedKeyShard& shard, TimePoint now); + bool ErasePendingWriteLocked(PendingWriteShard& shard, + const std::string& key); + bool ErasePinnedKeyLocked(PinnedKeyShard& shard, const std::string& key); + bool RemoveExpiredPendingWriteLocked(PendingWriteShard& shard, + const std::string& key, TimePoint now); + bool RemoveExpiredPinnedKeyLocked(PinnedKeyShard& shard, + const std::string& key, TimePoint now); + void TouchOrderedDeadlineNode(OrderedDeadlineList& ordered_list, + OrderedDeadlineListIt it, + const std::string& key, TimePoint deadline); + + tl::expected LookupPendingWriteHandle( + std::string_view key, const UUID& pending_write_token); + tl::expected LookupPinnedKeyHandle( + std::string_view key, const UUID& pin_token); + private: std::unique_ptr tiered_backend_; // Owned by DataManager std::shared_ptr transfer_engine_; // Shared with Client @@ -369,6 +508,8 @@ class DataManager { // (default: 1024) size_t lock_shard_count_; std::vector lock_shards_; + std::vector pending_write_shards_; + std::vector pinned_key_shards_; // Callback for rectifying stale read routes std::function)> @@ -376,6 +517,12 @@ class DataManager { LocalTransferConfig local_transfer_config_; std::unique_ptr async_memcpy_executor_; + std::chrono::milliseconds lease_duration_; + std::chrono::milliseconds lease_scan_interval_; + std::atomic lease_scanner_stop_requested_{false}; + std::condition_variable lease_scanner_cv_; + std::mutex lease_scanner_mutex_; + std::thread lease_scanner_thread_; }; } // namespace mooncake diff --git a/mooncake-store/include/dummy_client.h b/mooncake-store/include/dummy_client.h index c7efd5330d..eadf884054 100644 --- a/mooncake-store/include/dummy_client.h +++ b/mooncake-store/include/dummy_client.h @@ -75,19 +75,19 @@ class DummyClient : public PyClient { int unregister_buffer(void* buffer) override; int64_t get_into(const std::string& key, void* buffer, size_t size, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; std::vector batch_get_into( const std::vector& keys, const std::vector& buffers, const std::vector& sizes, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; std::vector batch_get_into_multi_buffers( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, bool aggregate_same_segment_task, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; int put_from(const std::string& key, void* buffer, size_t size, const WriteConfig& config) override; @@ -109,14 +109,14 @@ class DummyClient : public PyClient { const WriteConfig& config) override; std::shared_ptr get_buffer( - const std::string& key, const ReadRouteConfig& config = {}) override; + const std::string& key, const ReadConfigExt& config = {}) override; std::tuple get_buffer_info( - const std::string& key, const ReadRouteConfig& config = {}) override; + const std::string& key, const ReadConfigExt& config = {}) override; std::vector> batch_get_buffer( const std::vector& keys, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; int put_parts(const std::string& key, std::vector> values, diff --git a/mooncake-store/include/p2p_client_service.h b/mooncake-store/include/p2p_client_service.h index 5571c1a251..38fba02261 100644 --- a/mooncake-store/include/p2p_client_service.h +++ b/mooncake-store/include/p2p_client_service.h @@ -91,7 +91,7 @@ class P2PClientService final : public ClientService { */ tl::expected, ErrorCode> Query( const std::string& object_key, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; /** * @brief Batch query object metadata without transferring data @@ -100,7 +100,7 @@ class P2PClientService final : public ClientService { */ std::vector, ErrorCode>> BatchQuery(const std::vector& object_keys, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; tl::expected IsExist(const std::string& key) override; @@ -114,23 +114,23 @@ class P2PClientService final : public ClientService { tl::expected, ErrorCode> Get( const std::string& key, std::shared_ptr allocator, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; std::vector, ErrorCode>> BatchGet(const std::vector& keys, std::shared_ptr allocator, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; tl::expected Get( const std::string& key, const std::vector& buffers, const std::vector& sizes, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; std::vector> BatchGet( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - const ReadRouteConfig& config = {}, + const ReadConfigExt& config = {}, bool aggregate_same_segment_task = false) override; /** @@ -245,7 +245,8 @@ class P2PClientService final : public ClientService { std::vector> InnerBatchPut( const std::vector& keys, std::vector>& batched_slices, - const WriteRouteRequestConfig& route_config); + const WriteRouteRequestConfig& route_config, + RdmaDirectionMode rdma_direction_mode); std::vector> InnerBatchPutDegraded( const std::vector& keys, @@ -254,12 +255,14 @@ class P2PClientService final : public ClientService { std::vector> InnerBatchPutNormal( const std::vector& keys, std::vector>& batched_slices, - const WriteRouteRequestConfig& route_config); + const WriteRouteRequestConfig& route_config, + RdmaDirectionMode rdma_direction_mode); std::vector>, ErrorCode>> CreatePutHandlesFromRoute(const std::vector& keys, std::vector>& batched_slices, const WriteRouteRequestConfig& route_config, + RdmaDirectionMode rdma_direction_mode, BatchGetWriteRouteResponse& batch_resp); tl::expected>, ErrorCode> @@ -295,19 +298,30 @@ class P2PClientService final : public ClientService { }; struct RemoteWriteOp : WriteOp { + P2PClientService* owner_service = nullptr; PeerClient* peer_ptr; std::shared_ptr write_req; P2PProxyDescriptor proxy; RouteCache* route_cache; std::string endpoint; - - RemoteWriteOp(PeerClient* p, std::shared_ptr wr, - P2PProxyDescriptor px, RouteCache* rc, std::string ep) - : peer_ptr(p), + DataManager* forward_dm = nullptr; + std::vector* forward_slices = nullptr; + RdmaDirectionMode rdma_direction_mode = RdmaDirectionMode::REVERSE; + + RemoteWriteOp(P2PClientService* owner, PeerClient* p, + std::shared_ptr wr, + P2PProxyDescriptor px, RouteCache* rc, std::string ep, + DataManager* fwd_dm, std::vector* fwd_slices, + RdmaDirectionMode rdma_mode) + : owner_service(owner), + peer_ptr(p), write_req(std::move(wr)), proxy(std::move(px)), route_cache(rc), - endpoint(std::move(ep)) {} + endpoint(std::move(ep)), + forward_dm(fwd_dm), + forward_slices(fwd_slices), + rdma_direction_mode(rdma_mode) {} std::string_view route() const override { return endpoint; } std::unique_ptr> Dispatch() override; @@ -316,6 +330,7 @@ class P2PClientService final : public ClientService { tl::expected>, ErrorCode> BuildWriteOps(std::string_view key, std::vector& slices, const WriteRouteRequestConfig& config, + RdmaDirectionMode rdma_direction_mode, std::vector candidates); async_simple::coro::Lazy RunWriteWithRetry( @@ -370,10 +385,10 @@ class P2PClientService final : public ClientService { const std::vector& replicas); tl::expected BuildRouteIter( - std::string_view key, const ReadRouteConfig& config); + std::string_view key, const ReadConfigExt& config); tl::expected BuildRouteIter( - std::string_view key, const ReadRouteConfig& config, + std::string_view key, const ReadConfigExt& config, std::vector pre_fetched); private: @@ -385,30 +400,30 @@ class P2PClientService final : public ClientService { std::vector> BatchCreateGetHandles( const std::vector& keys, std::shared_ptr allocator, - const ReadRouteConfig& config); + const ReadConfigExt& config); std::vector> BatchCreateGetHandles( const std::vector& keys, std::vector>& all_slices, - const ReadRouteConfig& config); + const ReadConfigExt& config); template std::vector> BatchCreateGetHandlesImpl(const std::vector& keys, - const ReadRouteConfig& config, + const ReadConfigExt& config, LocalGetFn&& local_get, RemoteGetFn&& remote_get); std::vector, ErrorCode>> BatchFetchReadRoutes(const std::vector& keys, - const ReadRouteConfig& config); + const ReadConfigExt& config); tl::expected CreateRemoteGetHandle( std::string_view key, std::shared_ptr allocator, - const ReadRouteConfig& config, std::vector pre_fetched); + const ReadConfigExt& config, std::vector pre_fetched); tl::expected CreateRemoteGetHandle( std::string_view key, std::vector& slices, - const ReadRouteConfig& config, std::vector pre_fetched); + const ReadConfigExt& config, std::vector pre_fetched); /** * @brief Launch async reads driven by a RouteIterator. @@ -417,16 +432,44 @@ class P2PClientService final : public ClientService { * and chains subsequent candidates on failure (no stack recursion). */ tl::expected InnerGetViaRoute( - std::string_view key, std::vector& slices, RouteIterator iter); + std::string_view key, std::vector& slices, RouteIterator iter, + RdmaDirectionMode rdma_direction_mode); async_simple::coro::Lazy RunReadWithRetry( RouteIterator iter, std::shared_ptr req, std::shared_ptr>> - promise); + promise, + RdmaDirectionMode rdma_direction_mode); + + // true = promise set, caller co_return; false = try next route (incl. hard + // UnPin failure after TE success on this replica). + async_simple::coro::Lazy RunForwardReadOnRoute( + const ResolvedRoute& route, std::shared_ptr req, + std::shared_ptr>> + promise, + RouteIterator& iter, ErrorCode& final_result); + + async_simple::coro::Lazy RunForwardRemotePut( + std::shared_ptr>> + promise, + PeerClient* peer, DataManager* dm, + std::shared_ptr write_req, + std::vector* slices); + + // RemoteWriteOp forward path: promise + RunForwardRemotePut coroutine. + std::unique_ptr> StartForwardRemotePut( + PeerClient* peer, DataManager* forward_dm, + std::vector* forward_slices, + std::shared_ptr write_req); + + // RemoteWriteOp reverse path: AsyncWriteRemoteData + route cache upsert. + std::unique_ptr> RunReverseRemotePut( + PeerClient* peer, std::shared_ptr write_req, + const P2PProxyDescriptor& proxy, RouteCache* route_cache); async_simple::coro::Lazy> AsyncResolveRoutesFromMaster(std::string_view key, - const ReadRouteConfig& config); + const ReadConfigExt& config); /** * @brief Get or create a PeerClient for the given endpoint. diff --git a/mooncake-store/include/p2p_rpc_types.h b/mooncake-store/include/p2p_rpc_types.h index be03e2c741..c947fadc62 100644 --- a/mooncake-store/include/p2p_rpc_types.h +++ b/mooncake-store/include/p2p_rpc_types.h @@ -2,6 +2,7 @@ #include #include +#include #include #include "replica.h" @@ -43,6 +44,24 @@ inline std::ostream& operator<<(std::ostream& os, return os; } +struct WriteConfigExt { + WriteRouteRequestConfig route_config{}; + RdmaDirectionMode rdma_direction_mode = RdmaDirectionMode::REVERSE; + + WriteConfigExt() = default; + /** Promotes legacy write-route-only config for variant overload resolution. + */ + WriteConfigExt(WriteRouteRequestConfig r) : route_config(std::move(r)) {} +}; +YLT_REFL(WriteConfigExt, route_config, rdma_direction_mode); + +inline std::ostream& operator<<(std::ostream& os, + const WriteConfigExt& config) { + os << "WriteConfigExt: { route_config: [" << config.route_config + << "], rdma_direction_mode: " << config.rdma_direction_mode << " }"; + return os; +} + /** * @brief Request structure for getting write route. */ diff --git a/mooncake-store/include/peer_client.h b/mooncake-store/include/peer_client.h index 068591c146..58be645af1 100644 --- a/mooncake-store/include/peer_client.h +++ b/mooncake-store/include/peer_client.h @@ -27,11 +27,35 @@ class PeerClient { async_simple::coro::Lazy> AsyncWriteRemoteData(const RemoteWriteRequest& request); + async_simple::coro::Lazy> + AsyncPreWrite(const PreWriteRequest& request); + + async_simple::coro::Lazy> AsyncWriteCommit( + const WriteCommitRequest& request); + + async_simple::coro::Lazy> AsyncWriteRevoke( + const WriteRevokeRequest& request); + + async_simple::coro::Lazy> + AsyncPinKey(const PinKeyRequest& request); + + async_simple::coro::Lazy> AsyncUnPinKey( + const UnPinKeyRequest& request); + // --- Sync single-key interfaces --- tl::expected ReadRemoteData( const RemoteReadRequest& request); tl::expected WriteRemoteData( const RemoteWriteRequest& request); + tl::expected PreWrite( + const PreWriteRequest& request); + tl::expected WriteCommit( + const WriteCommitRequest& request); + tl::expected WriteRevoke( + const WriteRevokeRequest& request); + tl::expected PinKey( + const PinKeyRequest& request); + tl::expected UnPinKey(const UnPinKeyRequest& request); private: std::shared_ptr> diff --git a/mooncake-store/include/pyclient.h b/mooncake-store/include/pyclient.h index 98de68138c..abeea65cb0 100644 --- a/mooncake-store/include/pyclient.h +++ b/mooncake-store/include/pyclient.h @@ -40,19 +40,17 @@ class PyClient { virtual int unregister_buffer(void* buffer) = 0; virtual int64_t get_into(const std::string& key, void* buffer, size_t size, - const ReadRouteConfig& config = {}) = 0; + const ReadConfigExt& config = {}) = 0; virtual std::vector batch_get_into( const std::vector& keys, const std::vector& buffers, - const std::vector& sizes, - const ReadRouteConfig& config = {}) = 0; + const std::vector& sizes, const ReadConfigExt& config = {}) = 0; virtual std::vector batch_get_into_multi_buffers( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - bool aggregate_same_segment_task, - const ReadRouteConfig& config = {}) = 0; + bool aggregate_same_segment_task, const ReadConfigExt& config = {}) = 0; virtual int put_from(const std::string& key, void* buffer, size_t size, const WriteConfig& config) = 0; @@ -73,14 +71,14 @@ class PyClient { const WriteConfig& config) = 0; virtual std::shared_ptr get_buffer( - const std::string& key, const ReadRouteConfig& config = {}) = 0; + const std::string& key, const ReadConfigExt& config = {}) = 0; virtual std::tuple get_buffer_info( - const std::string& key, const ReadRouteConfig& config = {}) = 0; + const std::string& key, const ReadConfigExt& config = {}) = 0; virtual std::vector> batch_get_buffer( const std::vector& keys, - const ReadRouteConfig& config = {}) = 0; + const ReadConfigExt& config = {}) = 0; virtual int put_parts(const std::string& key, std::vector> values, diff --git a/mooncake-store/include/real_client.h b/mooncake-store/include/real_client.h index 58e34a0283..02108c9bfb 100644 --- a/mooncake-store/include/real_client.h +++ b/mooncake-store/include/real_client.h @@ -98,7 +98,7 @@ class RealClient : public PyClient { * register_buffer() for zero-copy operations */ int64_t get_into(const std::string& key, void* buffer, size_t size, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; /** * @brief Get object data directly into pre-allocated buffers for multiple @@ -114,7 +114,7 @@ class RealClient : public PyClient { std::vector batch_get_into( const std::vector& keys, const std::vector& buffers, const std::vector& sizes, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; /** * @brief Get object data directly into pre-allocated buffers for multiple @@ -133,7 +133,7 @@ class RealClient : public PyClient { const std::vector>& all_buffers, const std::vector>& all_sizes, bool aggregate_same_segment_task, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; /** * @brief Put object data directly from a pre-allocated buffer @@ -222,7 +222,7 @@ class RealClient : public PyClient { * nullptr if error */ std::shared_ptr get_buffer( - const std::string& key, const ReadRouteConfig& config = {}) override; + const std::string& key, const ReadConfigExt& config = {}) override; /** * @brief Get buffer information (address and size) for a key @@ -230,7 +230,7 @@ class RealClient : public PyClient { * @return Tuple containing buffer address and size, or (0, 0) if error */ std::tuple get_buffer_info( - const std::string& key, const ReadRouteConfig& config = {}) override; + const std::string& key, const ReadConfigExt& config = {}) override; /** * @brief Get buffers containing the data for multiple keys (batch version) @@ -240,7 +240,7 @@ class RealClient : public PyClient { */ std::vector> batch_get_buffer( const std::vector& keys, - const ReadRouteConfig& config = {}) override; + const ReadConfigExt& config = {}) override; int remove(const std::string& key) override; @@ -276,7 +276,7 @@ class RealClient : public PyClient { // Dummy client helper functions that return tl::expected tl::expected, ErrorCode> get_buffer_info_dummy_helper(const std::string& key, - const ReadRouteConfig& config, + const ReadConfigExt& config, const UUID& client_id); tl::expected put_dummy_helper(const std::string& key, @@ -296,7 +296,7 @@ class RealClient : public PyClient { std::vector> batch_get_into_dummy_helper( const std::vector& keys, const std::vector& buffers, const std::vector& sizes, - const ReadRouteConfig& config, const UUID& client_id); + const ReadConfigExt& config, const UUID& client_id); std::vector> batch_put_from_dummy_helper( const std::vector& keys, @@ -342,18 +342,18 @@ class RealClient : public PyClient { tl::expected get_into_internal( const std::string& key, void* buffer, size_t size, - const ReadRouteConfig& config = {}); + const ReadConfigExt& config = {}); std::vector> batch_get_into_internal( const std::vector& keys, const std::vector& buffers, - const std::vector& sizes, const ReadRouteConfig& config = {}); + const std::vector& sizes, const ReadConfigExt& config = {}); std::vector> batch_get_into_multi_buffers_internal( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - bool aggregate_same_segment_task, const ReadRouteConfig& config = {}); + bool aggregate_same_segment_task, const ReadConfigExt& config = {}); tl::expected put_from_internal(const std::string& key, void* buffer, size_t size, @@ -403,11 +403,10 @@ class RealClient : public PyClient { const std::string& key, std::shared_ptr client_buffer_allocator = nullptr, - const ReadRouteConfig& config = {}); + const ReadConfigExt& config = {}); std::vector> batch_get_buffer_internal( - const std::vector& keys, - const ReadRouteConfig& config = {}); + const std::vector& keys, const ReadConfigExt& config = {}); std::map> batch_get_replica_desc(const std::vector& keys); diff --git a/mooncake-store/include/rpc_types.h b/mooncake-store/include/rpc_types.h index 09067217b7..7ad47d6aff 100644 --- a/mooncake-store/include/rpc_types.h +++ b/mooncake-store/include/rpc_types.h @@ -1,5 +1,7 @@ #pragma once +#include + #include "types.h" #include "replica.h" #include "heartbeat_type.h" @@ -36,6 +38,19 @@ YLT_REFL(GetReplicaListRequestConfig, max_candidates, p2p_config); typedef GetReplicaListRequestConfig ReadRouteConfig; typedef P2PGetReplicaListConfigExtra P2PReadRouteConfigExtra; +/** + * @brief P2P read API wrapper: master routing + optional RDMA direction. + */ +struct ReadConfigExt { + ReadRouteConfig route_config{}; + RdmaDirectionMode rdma_direction_mode = RdmaDirectionMode::REVERSE; + + ReadConfigExt() = default; + /** @brief Backward-compatible promotion from read-route-only config. */ + ReadConfigExt(ReadRouteConfig r) : route_config(std::move(r)) {} +}; +YLT_REFL(ReadConfigExt, route_config, rdma_direction_mode); + /** * @brief Extra info for centralized read route response (Internal use) */ diff --git a/mooncake-store/include/types.h b/mooncake-store/include/types.h index e812954c70..0fa1a69d12 100644 --- a/mooncake-store/include/types.h +++ b/mooncake-store/include/types.h @@ -123,6 +123,8 @@ enum class ErrorCode : int32_t { // Parameter errors (Range: -600 to -699) INVALID_PARAMS = -600, ///< Invalid parameters. ILLEGAL_CLIENT = -601, ///< Illegal client to do the operation. + NON_CONTIGUOUS_BUFFER_NOT_SUPPORTED = + -602, ///< Non-contiguous buffer not supported in forward RDMA mode. // Engine operation errors (Range: -700 to -710) INVALID_WRITE = -700, ///< Invalid write operation. @@ -467,6 +469,31 @@ inline std::ostream& operator<<( return os; } +// Who initiates TE/RDMA for the data plane: REVERSE matches the historical +// target-initiated path and is the conventional default when unset optional or +// client-level config omits an explicit override. +enum class RdmaDirectionMode : uint8_t { + REVERSE = 0, + FORWARD = 1, +}; + +// Logging only: prints REVERSE / FORWARD / UNKNOWN for invalid numeric values. +inline std::ostream& operator<<(std::ostream& os, + const RdmaDirectionMode& mode) noexcept { + switch (mode) { + case RdmaDirectionMode::REVERSE: + os << "REVERSE"; + break; + case RdmaDirectionMode::FORWARD: + os << "FORWARD"; + break; + default: + os << "UNKNOWN"; + break; + } + return os; +} + } // namespace mooncake namespace std { diff --git a/mooncake-store/src/centralized_client_service.cpp b/mooncake-store/src/centralized_client_service.cpp index 24d7752dbc..80557bf9e4 100644 --- a/mooncake-store/src/centralized_client_service.cpp +++ b/mooncake-store/src/centralized_client_service.cpp @@ -234,7 +234,7 @@ void CentralizedClientService::InitTransferSubmitter() { tl::expected, ErrorCode> CentralizedClientService::Query(const std::string& object_key, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { auto guard = AcquireInflightGuard(); if (!guard.is_valid()) { LOG(ERROR) << "client is shutting down"; @@ -242,7 +242,8 @@ CentralizedClientService::Query(const std::string& object_key, } std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now(); - auto result = master_client_.GetReplicaList(object_key, config); + auto result = + master_client_.GetReplicaList(object_key, config.route_config); if (!result) { LOG(ERROR) << "Failed to get replica list: " << result.error(); return tl::unexpected(result.error()); @@ -262,8 +263,7 @@ CentralizedClientService::Query(const std::string& object_key, std::vector, ErrorCode>> CentralizedClientService::BatchQuery( - const std::vector& object_keys, - const ReadRouteConfig& config) { + const std::vector& object_keys, const ReadConfigExt& config) { auto guard = AcquireInflightGuard(); if (!guard.is_valid()) { LOG(ERROR) << "client is shutting down"; @@ -279,7 +279,8 @@ CentralizedClientService::BatchQuery( std::chrono::steady_clock::now(); std::vector key_views(object_keys.begin(), object_keys.end()); - auto response = master_client_.BatchGetReplicaList(key_views, config); + auto response = + master_client_.BatchGetReplicaList(key_views, config.route_config); // Check if we got the expected number of responses if (response.size() != object_keys.size()) { @@ -370,7 +371,7 @@ CentralizedClientService::BatchReplicaClear( tl::expected, ErrorCode> CentralizedClientService::Get(const std::string& key, std::shared_ptr allocator, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { if (!allocator) { LOG(ERROR) << "Client buffer allocator is not provided"; return tl::unexpected(ErrorCode::INVALID_PARAMS); @@ -425,7 +426,7 @@ std::vector, ErrorCode>> CentralizedClientService::BatchGet( const std::vector& keys, std::shared_ptr allocator, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { std::vector, ErrorCode>> results( keys.size(), tl::unexpected(ErrorCode::INTERNAL_ERROR)); @@ -527,7 +528,7 @@ CentralizedClientService::BatchGet( tl::expected CentralizedClientService::Get( const std::string& key, const std::vector& buffers, - const std::vector& sizes, const ReadRouteConfig& config) { + const std::vector& sizes, const ReadConfigExt& config) { // Step 1: Query metadata from master auto query_result = Query(key, config); if (!query_result) { @@ -572,7 +573,7 @@ CentralizedClientService::BatchGet( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - const ReadRouteConfig& config, bool aggregate_same_segment_task) { + const ReadConfigExt& config, bool aggregate_same_segment_task) { if (keys.size() != all_buffers.size() || keys.size() != all_sizes.size()) { LOG(ERROR) << "Input vector sizes mismatch"; return std::vector>( diff --git a/mooncake-store/src/client_rpc_service.cpp b/mooncake-store/src/client_rpc_service.cpp index d3a62fd5f0..1de25501a6 100644 --- a/mooncake-store/src/client_rpc_service.cpp +++ b/mooncake-store/src/client_rpc_service.cpp @@ -1,4 +1,7 @@ #include "client_rpc_service.h" + +#include + #include #include #include "utils/scoped_vlog_timer.h" @@ -7,6 +10,10 @@ namespace mooncake { namespace { +bool IsZeroUuid(const UUID& uuid) { + return uuid.first == 0 && uuid.second == 0; +} + size_t CalculateBufferSize(const std::vector& buffers) { size_t total = 0; for (const auto& buf : buffers) total += buf.size; @@ -51,6 +58,46 @@ bool IsValidRequest(const RemoteWriteRequest& request) { return true; } +bool IsValidRequest(const PreWriteRequest& request) { + if (request.key.empty() || request.size_bytes == 0) { + LOG(ERROR) << "PreWriteRequest: invalid key or size"; + return false; + } + return true; +} + +bool IsValidRequest(const WriteCommitRequest& request) { + if (request.key.empty() || IsZeroUuid(request.pending_write_token)) { + LOG(ERROR) << "WriteCommitRequest: invalid key or token"; + return false; + } + return true; +} + +bool IsValidRequest(const WriteRevokeRequest& request) { + if (request.key.empty() || IsZeroUuid(request.pending_write_token)) { + LOG(ERROR) << "WriteRevokeRequest: invalid key or token"; + return false; + } + return true; +} + +bool IsValidRequest(const PinKeyRequest& request) { + if (request.key.empty()) { + LOG(ERROR) << "PinKeyRequest: empty key"; + return false; + } + return true; +} + +bool IsValidRequest(const UnPinKeyRequest& request) { + if (request.key.empty() || IsZeroUuid(request.pin_token)) { + LOG(ERROR) << "UnPinKeyRequest: invalid key or token"; + return false; + } + return true; +} + } // anonymous namespace ClientRpcService::ClientRpcService(DataManager& data_manager, @@ -157,10 +204,128 @@ tl::expected ClientRpcService::WriteRemoteData( return result; } +tl::expected ClientRpcService::PreWrite( + const PreWriteRequest& request) { + ScopedVLogTimer timer(1, "ClientRpcService::PreWrite"); + timer.LogRequest("key=", request.key, "size_bytes=", request.size_bytes); + + if (!IsValidRequest(request)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + auto result = data_manager_.PreWrite(request.key, request.size_bytes, + request.target_tier_id); + if (!result) { + LOG(ERROR) << "PreWrite failed for key: " << request.key + << ", error: " << toString(result.error()); + timer.LogResponse("error_code=", result.error()); + return tl::make_unexpected(result.error()); + } + + timer.LogResponse("error_code=", ErrorCode::OK); + return std::move(*result); +} + +tl::expected ClientRpcService::WriteCommit( + const WriteCommitRequest& request) { + ScopedVLogTimer timer(1, "ClientRpcService::WriteCommit"); + timer.LogRequest("key=", request.key); + + if (!IsValidRequest(request)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + auto result = + data_manager_.WriteCommit(request.key, request.pending_write_token); + if (!result) { + LOG(ERROR) << "WriteCommit failed for key: " << request.key + << ", error: " << toString(result.error()); + timer.LogResponse("error_code=", result.error()); + return result; + } + + timer.LogResponse("error_code=", ErrorCode::OK); + return {}; +} + +tl::expected ClientRpcService::WriteRevoke( + const WriteRevokeRequest& request) { + ScopedVLogTimer timer(1, "ClientRpcService::WriteRevoke"); + timer.LogRequest("key=", request.key); + + if (!IsValidRequest(request)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + auto result = + data_manager_.WriteRevoke(request.key, request.pending_write_token); + if (!result) { + LOG(ERROR) << "WriteRevoke failed for key: " << request.key + << ", error: " << toString(result.error()); + timer.LogResponse("error_code=", result.error()); + return result; + } + + timer.LogResponse("error_code=", ErrorCode::OK); + return {}; +} + +tl::expected ClientRpcService::PinKey( + const PinKeyRequest& request) { + ScopedVLogTimer timer(1, "ClientRpcService::PinKey"); + timer.LogRequest("key=", request.key); + + if (!IsValidRequest(request)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + auto result = data_manager_.PinKey(request.key, request.target_tier_id); + if (!result) { + LOG(ERROR) << "PinKey failed for key: " << request.key + << ", error: " << toString(result.error()); + timer.LogResponse("error_code=", result.error()); + return tl::make_unexpected(result.error()); + } + + timer.LogResponse("error_code=", ErrorCode::OK); + return std::move(*result); +} + +tl::expected ClientRpcService::UnPinKey( + const UnPinKeyRequest& request) { + ScopedVLogTimer timer(1, "ClientRpcService::UnPinKey"); + timer.LogRequest("key=", request.key); + + if (!IsValidRequest(request)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + auto result = data_manager_.UnPinKey(request.key, request.pin_token); + if (!result) { + LOG(ERROR) << "UnPinKey failed for key: " << request.key + << ", error: " << toString(result.error()); + timer.LogResponse("error_code=", result.error()); + return result; + } + + timer.LogResponse("error_code=", ErrorCode::OK); + return {}; +} + void RegisterClientRpcService(coro_rpc::coro_rpc_server& server, ClientRpcService& service) { server.register_handler<&ClientRpcService::ReadRemoteData>(&service); server.register_handler<&ClientRpcService::WriteRemoteData>(&service); + server.register_handler<&ClientRpcService::PreWrite>(&service); + server.register_handler<&ClientRpcService::WriteCommit>(&service); + server.register_handler<&ClientRpcService::WriteRevoke>(&service); + server.register_handler<&ClientRpcService::PinKey>(&service); + server.register_handler<&ClientRpcService::UnPinKey>(&service); } } // namespace mooncake diff --git a/mooncake-store/src/data_manager.cpp b/mooncake-store/src/data_manager.cpp index 87fb89bd54..ee42b1c0ce 100644 --- a/mooncake-store/src/data_manager.cpp +++ b/mooncake-store/src/data_manager.cpp @@ -22,6 +22,9 @@ namespace mooncake { namespace { +constexpr uint32_t kDefaultLeaseDurationMs = 5000; +constexpr uint32_t kDefaultLeaseScanIntervalMs = 1000; + struct LocalCopyPlan { AllocationHandle source_handle; const char* source_ptr = nullptr; @@ -112,6 +115,10 @@ ErrorCode ExecuteLocalCopyPlan(const LocalCopyPlan& plan) { return ErrorCode::OK; } +bool IsZeroUuid(const UUID& uuid) { + return uuid.first == 0 && uuid.second == 0; +} + } // namespace // ================================================================ @@ -126,6 +133,8 @@ DataManager::DataManager(std::unique_ptr tiered_backend, transfer_engine_(transfer_engine), lock_shard_count_(lock_shard_count > 0 ? lock_shard_count : 1024), lock_shards_(lock_shard_count_), + pending_write_shards_(lock_shard_count_), + pinned_key_shards_(lock_shard_count_), local_transfer_config_(local_transfer_config) { if (!tiered_backend_) { LOG(FATAL) << "TieredBackend cannot be null"; @@ -140,6 +149,13 @@ DataManager::DataManager(std::unique_ptr tiered_backend, local_transfer_config_.local_memcpy_async_worker_num); } + lease_duration_ = std::chrono::milliseconds(GetEnvOr( + "P2P_RPC_LEASE_DURATION_MS", kDefaultLeaseDurationMs)); + lease_scan_interval_ = std::chrono::milliseconds( + std::max(1, GetEnvOr("P2P_RPC_LEASE_SCAN_INTERVAL", + kDefaultLeaseScanIntervalMs))); + lease_scanner_thread_ = std::thread(&DataManager::LeaseScannerMain, this); + LOG(INFO) << "DataManager initialized with " << lock_shard_count_ << " lock shards, local_transfer_mode=" << (local_transfer_config_.mode == LocalTransferMode::TE @@ -147,7 +163,302 @@ DataManager::DataManager(std::unique_ptr tiered_backend, : "MEMCPY") << ", te_endpoint=" << local_transfer_config_.te_endpoint << ", async_memcpy_workers=" - << local_transfer_config_.local_memcpy_async_worker_num; + << local_transfer_config_.local_memcpy_async_worker_num + << ", lease_duration_ms=" << lease_duration_.count() + << ", lease_scan_interval_ms=" << lease_scan_interval_.count(); +} + +DataManager::~DataManager() { Stop(); } + +void DataManager::Stop() { + ShutdownLeaseScanner(); + ClearLeaseRecords(); + if (async_memcpy_executor_) { + async_memcpy_executor_->Shutdown(); + } + if (tiered_backend_) { + tiered_backend_->Stop(); + } +} + +void DataManager::ClearLeaseRecords() { + for (auto& shard : pending_write_shards_) { + std::unique_lock lock(shard.mutex); + shard.by_key.clear(); + shard.ordered_list.clear(); + } + for (auto& shard : pinned_key_shards_) { + std::unique_lock lock(shard.mutex); + shard.by_key.clear(); + shard.ordered_list.clear(); + } +} + +size_t DataManager::HashKey(std::string_view key) const { + return std::hash{}(key); +} + +DataManager::KeyCtx DataManager::BuildKeyCtx(std::string_view key) const { + KeyCtx ctx; + ctx.key = key; + ctx.key_string = std::string(key); + ctx.hash = HashKey(key); + ctx.pending_write_shard_idx = + pending_write_shards_.empty() + ? 0 + : (ctx.hash % pending_write_shards_.size()); + ctx.pinned_key_shard_idx = + pinned_key_shards_.empty() ? 0 : (ctx.hash % pinned_key_shards_.size()); + return ctx; +} + +DataManager::PendingWriteShard& DataManager::GetPendingWriteShard( + const KeyCtx& ctx) { + return pending_write_shards_[ctx.pending_write_shard_idx]; +} + +DataManager::PinnedKeyShard& DataManager::GetPinnedKeyShard(const KeyCtx& ctx) { + return pinned_key_shards_[ctx.pinned_key_shard_idx]; +} + +DataManager::PendingWriteShard& DataManager::GetPendingWriteShard( + std::string_view key) { + return pending_write_shards_[HashKey(key) % pending_write_shards_.size()]; +} + +DataManager::PinnedKeyShard& DataManager::GetPinnedKeyShard( + std::string_view key) { + return pinned_key_shards_[HashKey(key) % pinned_key_shards_.size()]; +} + +bool DataManager::IsExpired(TimePoint deadline) const { + return deadline <= std::chrono::steady_clock::now(); +} + +RemoteBufferDesc DataManager::BuildRemoteBufferDesc( + const AllocationHandle& handle) const { + const auto& loc_data = handle->loc.data; + RemoteBufferDesc remote_buffer; + remote_buffer.segment_endpoint = local_transfer_config_.te_endpoint; + remote_buffer.addr = 0; + remote_buffer.size = 0; + if (loc_data.buffer) { + remote_buffer.addr = + reinterpret_cast(loc_data.buffer->data()); + remote_buffer.size = loc_data.buffer->size(); + } + return remote_buffer; +} + +void DataManager::TouchOrderedDeadlineNode(OrderedDeadlineList& ordered_list, + OrderedDeadlineListIt it, + const std::string& key, + TimePoint deadline) { + it->first = key; + it->second = deadline; + ordered_list.splice(ordered_list.end(), ordered_list, it); +} + +bool DataManager::ErasePendingWriteLocked(PendingWriteShard& shard, + const std::string& key) { + auto it = shard.by_key.find(key); + if (it == shard.by_key.end()) { + return false; + } + shard.ordered_list.erase(it->second.list_it); + shard.by_key.erase(it); + return true; +} + +bool DataManager::ErasePinnedKeyLocked(PinnedKeyShard& shard, + const std::string& key) { + auto it = shard.by_key.find(key); + if (it == shard.by_key.end()) { + return false; + } + shard.ordered_list.erase(it->second.list_it); + shard.by_key.erase(it); + return true; +} + +bool DataManager::RemoveExpiredPendingWriteLocked(PendingWriteShard& shard, + const std::string& key, + TimePoint now) { + auto it = shard.by_key.find(key); + if (it == shard.by_key.end() || it->second.deadline > now) { + return false; + } + shard.ordered_list.erase(it->second.list_it); + shard.by_key.erase(it); + return true; +} + +bool DataManager::RemoveExpiredPinnedKeyLocked(PinnedKeyShard& shard, + const std::string& key, + TimePoint now) { + auto it = shard.by_key.find(key); + if (it == shard.by_key.end() || it->second.deadline > now) { + return false; + } + shard.ordered_list.erase(it->second.list_it); + shard.by_key.erase(it); + return true; +} + +size_t DataManager::ScanExpiredPendingWrites(PendingWriteShard& shard, + TimePoint now) { + size_t removed = 0; + while (!shard.ordered_list.empty()) { + auto list_it = shard.ordered_list.begin(); + if (list_it->second > now) { + break; + } + auto record_it = shard.by_key.find(list_it->first); + if (record_it == shard.by_key.end() || + record_it->second.list_it != list_it) { + shard.ordered_list.erase(list_it); + continue; + } + shard.by_key.erase(record_it); + shard.ordered_list.erase(list_it); + ++removed; + } + return removed; +} + +size_t DataManager::ScanExpiredPinnedKeys(PinnedKeyShard& shard, + TimePoint now) { + size_t removed = 0; + while (!shard.ordered_list.empty()) { + auto list_it = shard.ordered_list.begin(); + if (list_it->second > now) { + break; + } + auto record_it = shard.by_key.find(list_it->first); + if (record_it == shard.by_key.end() || + record_it->second.list_it != list_it) { + shard.ordered_list.erase(list_it); + continue; + } + shard.by_key.erase(record_it); + shard.ordered_list.erase(list_it); + ++removed; + } + return removed; +} + +tl::expected DataManager::LookupPendingWriteHandle( + std::string_view key, const UUID& pending_write_token) { + return LookupPendingWriteHandleInternal(BuildKeyCtx(key), + pending_write_token); +} + +tl::expected +DataManager::LookupPendingWriteHandleInternal(const KeyCtx& ctx, + const UUID& pending_write_token) { + const auto now = std::chrono::steady_clock::now(); + auto& shard = GetPendingWriteShard(ctx); + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(ctx.key_string); + if (it == shard.by_key.end()) { + return tl::unexpected(ErrorCode::OBJECT_NOT_FOUND); + } + if (it->second.deadline <= now) { + return tl::unexpected(ErrorCode::LEASE_EXPIRED); + } + if (it->second.pending_write_token != pending_write_token) { + return tl::unexpected(ErrorCode::INVALID_WRITE); + } + return it->second.handle; +} + +tl::expected DataManager::LookupPinnedKeyHandle( + std::string_view key, const UUID& pin_token) { + return LookupPinnedKeyHandleInternal(BuildKeyCtx(key), pin_token); +} + +tl::expected +DataManager::LookupPinnedKeyHandleInternal(const KeyCtx& ctx, + const UUID& pin_token) { + const auto now = std::chrono::steady_clock::now(); + auto& shard = GetPinnedKeyShard(ctx); + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(ctx.key_string); + if (it == shard.by_key.end()) { + return tl::unexpected(ErrorCode::OBJECT_NOT_FOUND); + } + if (it->second.deadline <= now) { + return tl::unexpected(ErrorCode::LEASE_EXPIRED); + } + if (it->second.pin_token != pin_token) { + return tl::unexpected(ErrorCode::INVALID_READ); + } + return it->second.handle; +} + +tl::expected DataManager::WriteRevoke( + std::string_view key, const UUID& pending_write_token) { + return WriteRevokeInternal(BuildKeyCtx(key), pending_write_token); +} + +tl::expected DataManager::WriteRevokeInternal( + const KeyCtx& ctx, const UUID& pending_write_token) { + ScopedVLogTimer timer(1, "DataManager::WriteRevoke"); + timer.LogRequest("key=", ctx.key); + + if (ctx.key.empty() || IsZeroUuid(pending_write_token)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + std::unique_lock key_lock(GetKeyLock(ctx.key)); + auto& shard = GetPendingWriteShard(ctx); + std::unique_lock shard_lock(shard.mutex); + + auto record_it = shard.by_key.find(ctx.key_string); + if (record_it == shard.by_key.end()) { + timer.LogResponse("error_code=", ErrorCode::OK, "idempotent=", true); + return {}; + } + if (record_it->second.pending_write_token != pending_write_token) { + timer.LogResponse("error_code=", ErrorCode::INVALID_WRITE); + return tl::make_unexpected(ErrorCode::INVALID_WRITE); + } + shard.ordered_list.erase(record_it->second.list_it); + shard.by_key.erase(record_it); + timer.LogResponse("error_code=", ErrorCode::OK, "record_erased=", true); + return {}; +} + +void DataManager::ShutdownLeaseScanner() { + lease_scanner_stop_requested_.store(true); + lease_scanner_cv_.notify_all(); + if (lease_scanner_thread_.joinable()) { + lease_scanner_thread_.join(); + } +} + +void DataManager::LeaseScannerMain() { + std::unique_lock wait_lock(lease_scanner_mutex_); + while (!lease_scanner_stop_requested_.load()) { + lease_scanner_cv_.wait_for(wait_lock, lease_scan_interval_, [this]() { + return lease_scanner_stop_requested_.load(); + }); + if (lease_scanner_stop_requested_.load()) { + break; + } + const auto now = std::chrono::steady_clock::now(); + wait_lock.unlock(); + for (auto& shard : pending_write_shards_) { + std::unique_lock shard_lock(shard.mutex); + ScanExpiredPendingWrites(shard, now); + } + for (auto& shard : pinned_key_shards_) { + std::unique_lock shard_lock(shard.mutex); + ScanExpiredPinnedKeys(shard, now); + } + wait_lock.lock(); + } } // ================================================================ @@ -187,6 +498,7 @@ tl::expected>, ErrorCode> DataManager::Put( tl::expected>, ErrorCode> DataManager::PutViaTe(std::string_view key, std::vector& slices) { // using Te, treat local memory as remote memory + const KeyCtx kctx = BuildKeyCtx(key); size_t total_size = 0; for (const auto& s : slices) total_size += s.size; auto src_buffers = SlicesToRemoteBufferDescs(slices); @@ -197,28 +509,22 @@ DataManager::PutViaTe(std::string_view key, std::vector& slices) { return tl::unexpected(validate_result.error()); } - AllocationHandle alloc_handle; - { - std::unique_lock lock(GetKeyLock(key)); - - if (tiered_backend_->Exist(key)) { - LOG(WARNING) << "key already exists: " << key; - return tl::make_unexpected(ErrorCode::OBJECT_ALREADY_EXISTS); - } + // Local Put: allocation follows tier backend policy (not restricted to + // DRAM). + auto prewrite_result = + PreWriteInternal(kctx, total_size, std::nullopt, false); + if (!prewrite_result) { + return tl::unexpected(prewrite_result.error()); + } + const UUID pending_write_token = prewrite_result->pending_write_token; - auto alloc_result = tiered_backend_->Allocate(total_size); - if (!alloc_result) { - auto err = alloc_result.error(); - if (err == ErrorCode::REPLICA_NUM_EXCEEDED || - err == ErrorCode::OBJECT_ALREADY_EXISTS || - err == ErrorCode::REPLICA_ALREADY_EXISTS) { - LOG(WARNING) << "object already exists for key: " << key - << ", error code: " << err; - } - return tl::unexpected(err); - } - alloc_handle = alloc_result.value(); + auto handle_result = + LookupPendingWriteHandleInternal(kctx, pending_write_token); + if (!handle_result) { + (void)WriteRevokeInternal(kctx, pending_write_token); + return tl::unexpected(handle_result.error()); } + AllocationHandle alloc_handle = handle_result.value(); auto submit_result = SubmitTeTransferInternal( alloc_handle, src_buffers, Transport::TransferRequest::READ); @@ -226,20 +532,22 @@ DataManager::PutViaTe(std::string_view key, std::vector& slices) { LOG(ERROR) << "SubmitTeTransferInternal failed" << ", key=" << key << ", error_code=" << toString(submit_result.error()); + (void)WriteRevokeInternal(kctx, pending_write_token); return tl::unexpected(submit_result.error()); } return CallableTaskHandle::Create( - [this, ctx = std::move(*submit_result), alloc_handle, - key]() mutable -> tl::expected { + [this, ctx = std::move(*submit_result), alloc_handle, kctx, + pending_write_token]() mutable -> tl::expected { ScopedVLogTimer timer(1, "DataManager::PutViaTe"); - timer.LogRequest("key=", key); + timer.LogRequest("key=", kctx.key); auto wait_result = WaitAllTransferBatches(ctx.transfer_batches); if (!wait_result) { LOG(ERROR) << "WaitAllTransferBatches failed" - << ", key=" << key + << ", key=" << kctx.key << ", error_code=" << toString(wait_result.error()); + (void)WriteRevokeInternal(kctx, pending_write_token); return tl::unexpected(wait_result.error()); } @@ -254,16 +562,16 @@ DataManager::PutViaTe(std::string_view key, std::vector& slices) { if (!copy_result) { LOG(ERROR) << "CopyFromDRAMBuffer failed" - << ", key=" << key + << ", key=" << kctx.key << ", error_code=" << toString(copy_result.error()); + (void)WriteRevokeInternal(kctx, pending_write_token); return tl::unexpected(copy_result.error()); } } - std::unique_lock lock(GetKeyLock(key)); - auto commit_result = tiered_backend_->Commit(key, alloc_handle); + auto commit_result = WriteCommitInternal(kctx, pending_write_token); if (!commit_result) { - LOG(WARNING) << "commit race for key: " << key; + return tl::unexpected(commit_result.error()); } timer.LogResponse("error_code=", ErrorCode::OK); return {}; @@ -276,54 +584,48 @@ DataManager::PutViaMemcpy(std::string_view key, std::vector& slices) { LOG(ERROR) << "PutLocal in memcpy mode only supports a single slice"; return tl::unexpected(ErrorCode::NOT_IMPLEMENTED); } + const KeyCtx kctx = BuildKeyCtx(key); Slice slice = slices[0]; - AllocationHandle alloc_handle; - { - std::unique_lock lock(GetKeyLock(key)); - - if (tiered_backend_->Exist(key)) { - LOG(WARNING) << "Key already exists: " << key; - return tl::make_unexpected(ErrorCode::OBJECT_ALREADY_EXISTS); - } + // Same allocation policy as PutViaTe. + auto prewrite_result = + PreWriteInternal(kctx, slice.size, std::nullopt, false); + if (!prewrite_result) { + return tl::unexpected(prewrite_result.error()); + } + const UUID pending_write_token = prewrite_result->pending_write_token; - auto handle = tiered_backend_->Allocate(slice.size); - if (!handle.has_value()) { - LOG(ERROR) << "Failed to allocate space for key: " << key - << ", error: " << handle.error(); - return tl::make_unexpected(handle.error()); - } - alloc_handle = handle.value(); + auto handle_result = + LookupPendingWriteHandleInternal(kctx, pending_write_token); + if (!handle_result) { + (void)WriteRevokeInternal(kctx, pending_write_token); + return tl::unexpected(handle_result.error()); } + AllocationHandle alloc_handle = handle_result.value(); - auto write_fn = [this, key, slice, - alloc_handle]() -> tl::expected { + auto write_fn = [this, kctx, slice, alloc_handle, + pending_write_token]() -> tl::expected { DataSource source; source.buffer = std::make_unique(slice.ptr, slice.size); source.type = MemoryType::DRAM; auto write_result = tiered_backend_->Write(source, alloc_handle); if (!write_result.has_value()) { - LOG(ERROR) << "Failed to write data for key: " << key + LOG(ERROR) << "Failed to write data for key: " << kctx.key << ", error: " << write_result.error(); + (void)WriteRevokeInternal(kctx, pending_write_token); return tl::make_unexpected(write_result.error()); } return {}; }; - auto commit_fn = [this, key, - alloc_handle]() -> tl::expected { - std::unique_lock lock(GetKeyLock(key)); - auto commit_result = tiered_backend_->Commit(key, alloc_handle); - if (!commit_result.has_value()) { - auto err = commit_result.error(); - if (err != ErrorCode::REPLICA_NUM_EXCEEDED && - err != ErrorCode::OBJECT_ALREADY_EXISTS && - err != ErrorCode::REPLICA_ALREADY_EXISTS) { - LOG(ERROR) << "Failed to commit data for key: " << key - << ", error: " << commit_result.error(); - return tl::make_unexpected(commit_result.error()); - } + auto commit_fn = [this, kctx, + pending_write_token]() -> tl::expected { + auto commit_result = WriteCommitInternal(kctx, pending_write_token); + if (!commit_result) { + LOG(ERROR) << "Failed to commit data for key: " << kctx.key + << ", error: " << commit_result.error(); + return tl::make_unexpected(commit_result.error()); } return {}; }; @@ -514,8 +816,6 @@ tl::expected DataManager::BuildDataCopierViaMemcpy( // Remote data transfer — called by RPC service layer // ================================================================ -// Attention!!! -// This method runs without key lock. tl::expected DataManager::ReadRemoteData( std::string_view key, const std::vector& dest_buffers) { ScopedVLogTimer timer(1, "DataManager::ReadRemoteData"); @@ -529,15 +829,22 @@ tl::expected DataManager::ReadRemoteData( return tl::make_unexpected(validate_result.error()); } + // Reverse RDMA read stays on the direct object-handle path. Only forward + // RDMA read uses the 3-phase PinKey -> TE Read -> UnPinKey flow. auto handle_result = tiered_backend_->Get(key); - if (!handle_result.has_value()) { - LOG(ERROR) << "ReadRemoteData: Failed to get data for key: " << key - << ", error: " << toString(handle_result.error()); + if (!handle_result) { timer.LogResponse("error_code=", handle_result.error()); return tl::make_unexpected(handle_result.error()); } - return TransferDataToRemote(handle_result.value(), dest_buffers); + auto transfer_result = + TransferDataToRemote(handle_result.value(), dest_buffers); + if (!transfer_result) { + timer.LogResponse("error_code=", transfer_result.error()); + return tl::make_unexpected(transfer_result.error()); + } + timer.LogResponse("error_code=", ErrorCode::OK); + return {}; } tl::expected DataManager::TransferDataToRemote( @@ -561,6 +868,7 @@ tl::expected DataManager::WriteRemoteData( std::optional tier_id) { ScopedVLogTimer timer(1, "DataManager::WriteRemoteData"); timer.LogRequest("key=", key, "buffer_count=", src_buffers.size()); + const KeyCtx kctx = BuildKeyCtx(key); auto validate_result = ValidateRemoteBuffers(src_buffers); if (!validate_result) { @@ -573,47 +881,37 @@ tl::expected DataManager::WriteRemoteData( size_t total_size = 0; for (const auto& buf : src_buffers) total_size += buf.size; - AllocationHandle handle; - { - std::unique_lock lock(GetKeyLock(key)); - - if (tiered_backend_->Exist(key)) { - LOG(WARNING) << "Key already exists: " << key; - timer.LogResponse("error_code=", ErrorCode::OBJECT_ALREADY_EXISTS); - return tl::make_unexpected(ErrorCode::OBJECT_ALREADY_EXISTS); - } + // Reverse RDMA path: still one RPC, but internally use the 3-phase write + // model (PreWrite -> transfer -> WriteCommit). Target tier may be non-DRAM. + auto prewrite_result = PreWriteInternal(kctx, total_size, tier_id, false); + if (!prewrite_result) { + timer.LogResponse("error_code=", prewrite_result.error()); + return tl::make_unexpected(prewrite_result.error()); + } + const UUID pending_write_token = prewrite_result->pending_write_token; - auto handle_result = tiered_backend_->Allocate(total_size, tier_id); - if (!handle_result.has_value()) { - LOG(ERROR) << "WriteRemoteData: Failed to allocate space for key: " - << key; - timer.LogResponse("error_code=", handle_result.error()); - return tl::make_unexpected(handle_result.error()); - } - handle = handle_result.value(); + auto handle_result = + LookupPendingWriteHandleInternal(kctx, pending_write_token); + if (!handle_result) { + (void)WriteRevokeInternal(kctx, pending_write_token); + timer.LogResponse("error_code=", handle_result.error()); + return tl::make_unexpected(handle_result.error()); } + AllocationHandle handle = handle_result.value(); + UUID result_tier_id = handle->loc.tier->GetTierId(); - // Transfer phase — no lock held, RDMA runs concurrently. + // Transfer phase — no long key lock held. auto transfer_result = TransferDataFromRemote(handle, src_buffers); - if (!transfer_result.has_value()) { - LOG(ERROR) << "WriteRemoteData: Transfer failed for key: " << key - << ", error: " << toString(transfer_result.error()); + if (!transfer_result) { + (void)WriteRevokeInternal(kctx, pending_write_token); timer.LogResponse("error_code=", transfer_result.error()); return tl::make_unexpected(transfer_result.error()); } - // Commit phase: re-acquire lock, commit the handle. - UUID result_tier_id; - { - std::unique_lock lock(GetKeyLock(key)); - auto commit_result = tiered_backend_->Commit(key, handle); - if (!commit_result.has_value()) { - LOG(ERROR) << "WriteRemoteData: Failed to commit data for key: " - << key; - timer.LogResponse("error_code=", commit_result.error()); - return tl::make_unexpected(commit_result.error()); - } - result_tier_id = handle->loc.tier->GetTierId(); + auto commit_result = WriteCommitInternal(kctx, pending_write_token); + if (!commit_result) { + timer.LogResponse("error_code=", commit_result.error()); + return tl::make_unexpected(commit_result.error()); } timer.LogResponse("error_code=", ErrorCode::OK, @@ -621,6 +919,263 @@ tl::expected DataManager::WriteRemoteData( return result_tier_id; } +tl::expected DataManager::PreWrite( + std::string_view key, size_t size_bytes, std::optional tier_id) { + // RPC PreWrite: forward path is wired for DRAM only for now; other tiers + // TODO (staging / TE registration). + return PreWriteInternal(BuildKeyCtx(key), size_bytes, tier_id, true); +} + +tl::expected DataManager::PreWriteInternal( + const KeyCtx& ctx, size_t size_bytes, std::optional tier_id, + bool enforce_dram_allocation) { + ScopedVLogTimer timer(1, "DataManager::PreWrite"); + timer.LogRequest("key=", ctx.key, "size_bytes=", size_bytes); + + if (ctx.key.empty() || size_bytes == 0) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + const auto now = std::chrono::steady_clock::now(); + const auto deadline = now + lease_duration_; + + std::unique_lock key_lock(GetKeyLock(ctx.key)); + auto& shard = GetPendingWriteShard(ctx); + std::unique_lock shard_lock(shard.mutex); + + RemoveExpiredPendingWriteLocked(shard, ctx.key_string, now); + if (shard.by_key.find(ctx.key_string) != shard.by_key.end()) { + timer.LogResponse("error_code=", ErrorCode::OBJECT_HAS_LEASE); + return tl::make_unexpected(ErrorCode::OBJECT_HAS_LEASE); + } + if (tiered_backend_->Exist(ctx.key)) { + timer.LogResponse("error_code=", ErrorCode::OBJECT_ALREADY_EXISTS); + return tl::make_unexpected(ErrorCode::OBJECT_ALREADY_EXISTS); + } + + auto handle_result = tiered_backend_->Allocate(size_bytes, tier_id); + if (!handle_result) { + timer.LogResponse("error_code=", handle_result.error()); + return tl::make_unexpected(handle_result.error()); + } + + auto handle = std::move(handle_result.value()); + // When enforce_dram_allocation is true (RPC PreWrite): only DRAM is wired + // for forward TE today; non-DRAM tiers TODO. Local Put / WriteRemoteData + // use false and skip this check. + if (enforce_dram_allocation && handle->loc.data.type != MemoryType::DRAM) { + timer.LogResponse("error_code=", + ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); + return tl::make_unexpected(ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); + } + auto list_it = shard.ordered_list.emplace(shard.ordered_list.end(), + ctx.key_string, deadline); + + const UUID pending_write_token = generate_uuid(); + PendingWriteRecord record; + record.pending_write_token = pending_write_token; + record.deadline = deadline; + record.handle = handle; + record.list_it = list_it; + shard.by_key.insert_or_assign(ctx.key_string, std::move(record)); + + PreWriteResponse result; + result.remote_buffer = BuildRemoteBufferDesc(handle); + result.pending_write_token = pending_write_token; + timer.LogResponse("error_code=", ErrorCode::OK); + return result; +} + +tl::expected DataManager::WriteCommit( + std::string_view key, const UUID& pending_write_token) { + return WriteCommitInternal(BuildKeyCtx(key), pending_write_token); +} + +tl::expected DataManager::WriteCommitInternal( + const KeyCtx& ctx, const UUID& pending_write_token) { + ScopedVLogTimer timer(1, "DataManager::WriteCommit"); + timer.LogRequest("key=", ctx.key); + + if (ctx.key.empty() || IsZeroUuid(pending_write_token)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + const auto now = std::chrono::steady_clock::now(); + + std::unique_lock key_lock(GetKeyLock(ctx.key)); + auto& shard = GetPendingWriteShard(ctx); + std::unique_lock shard_lock(shard.mutex); + + auto record_it = shard.by_key.find(ctx.key_string); + if (record_it == shard.by_key.end()) { + timer.LogResponse("error_code=", ErrorCode::OK, "idempotent=", true); + return {}; + } + if (record_it->second.deadline <= now) { + shard.ordered_list.erase(record_it->second.list_it); + shard.by_key.erase(record_it); + timer.LogResponse("error_code=", ErrorCode::LEASE_EXPIRED); + return tl::make_unexpected(ErrorCode::LEASE_EXPIRED); + } + if (record_it->second.pending_write_token != pending_write_token) { + timer.LogResponse("error_code=", ErrorCode::INVALID_WRITE); + return tl::make_unexpected(ErrorCode::INVALID_WRITE); + } + + auto handle = record_it->second.handle; + // Once we attempt the commit, always erase the pending record regardless of + // the commit result. + // + // Rationale: most commit failures are not transient and cannot be resolved + // by retrying WriteCommit alone (e.g. tier commit failure or master + // callback failure). The caller should restart a full write flow instead of + // reusing a potentially stale prewrite context. + // + // Note: after erasing the record, subsequent WriteCommit calls become + // idempotent no-ops (return OK) due to record absence. + auto commit_result = tiered_backend_->Commit(ctx.key, handle); + shard.ordered_list.erase(record_it->second.list_it); + shard.by_key.erase(record_it); + + if (!commit_result) { + timer.LogResponse("error_code=", commit_result.error(), + "record_erased=", true); + return tl::make_unexpected(commit_result.error()); + } + + timer.LogResponse("error_code=", ErrorCode::OK, "record_erased=", true); + return {}; +} + +tl::expected DataManager::PinKey( + std::string_view key, std::optional tier_id) { + return PinKeyInternal(BuildKeyCtx(key), tier_id); +} + +tl::expected DataManager::PinKeyInternal( + const KeyCtx& ctx, std::optional tier_id) { + ScopedVLogTimer timer(1, "DataManager::PinKey"); + timer.LogRequest("key=", ctx.key); + + if (ctx.key.empty()) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + const auto now = std::chrono::steady_clock::now(); + const auto deadline = now + lease_duration_; + + std::shared_lock key_lock(GetKeyLock(ctx.key)); + auto& shard = GetPinnedKeyShard(ctx); + std::unique_lock shard_lock(shard.mutex); + + RemoveExpiredPinnedKeyLocked(shard, ctx.key_string, now); + auto record_it = shard.by_key.find(ctx.key_string); + if (record_it != shard.by_key.end()) { + // PinKey forward path: DRAM-only for now; non-DRAM replica handling + // TODO. + if (record_it->second.handle->loc.data.type != MemoryType::DRAM) { + timer.LogResponse("error_code=", + ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); + return tl::make_unexpected(ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); + } + record_it->second.ref_count++; + record_it->second.deadline = deadline; + TouchOrderedDeadlineNode(shard.ordered_list, record_it->second.list_it, + ctx.key_string, deadline); + + PinKeyResponse result; + result.remote_buffer = BuildRemoteBufferDesc(record_it->second.handle); + result.pin_token = record_it->second.pin_token; + timer.LogResponse("error_code=", ErrorCode::OK, + "ref_count=", record_it->second.ref_count); + return result; + } + + auto handle_result = tiered_backend_->Get(ctx.key, tier_id); + if (!handle_result) { + timer.LogResponse("error_code=", handle_result.error()); + return tl::make_unexpected(handle_result.error()); + } + + auto handle = std::move(handle_result.value()); + // PinKey forward path: DRAM-only for now; non-DRAM replica handling TODO. + if (handle->loc.data.type != MemoryType::DRAM) { + timer.LogResponse("error_code=", + ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); + return tl::make_unexpected(ErrorCode::UNAVAILABLE_IN_CURRENT_MODE); + } + auto list_it = shard.ordered_list.emplace(shard.ordered_list.end(), + ctx.key_string, deadline); + + const UUID pin_token_value = generate_uuid(); + PinnedKeyRecord record; + record.pin_token = pin_token_value; + record.deadline = deadline; + record.handle = handle; + record.ref_count = 1; + record.list_it = list_it; + shard.by_key.insert_or_assign(ctx.key_string, std::move(record)); + + PinKeyResponse result; + result.remote_buffer = BuildRemoteBufferDesc(handle); + result.pin_token = pin_token_value; + timer.LogResponse("error_code=", ErrorCode::OK); + return result; +} + +tl::expected DataManager::UnPinKey(std::string_view key, + const UUID& pin_token) { + return UnPinKeyInternal(BuildKeyCtx(key), pin_token); +} + +tl::expected DataManager::UnPinKeyInternal( + const KeyCtx& ctx, const UUID& pin_token) { + ScopedVLogTimer timer(1, "DataManager::UnPinKey"); + timer.LogRequest("key=", ctx.key); + + if (ctx.key.empty() || IsZeroUuid(pin_token)) { + timer.LogResponse("error_code=", ErrorCode::INVALID_PARAMS); + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + + const auto now = std::chrono::steady_clock::now(); + + std::shared_lock key_lock(GetKeyLock(ctx.key)); + auto& shard = GetPinnedKeyShard(ctx); + std::unique_lock shard_lock(shard.mutex); + + auto record_it = shard.by_key.find(ctx.key_string); + if (record_it == shard.by_key.end()) { + timer.LogResponse("error_code=", ErrorCode::OK, "idempotent=", true); + return {}; + } + if (record_it->second.deadline <= now) { + shard.ordered_list.erase(record_it->second.list_it); + shard.by_key.erase(record_it); + timer.LogResponse("error_code=", ErrorCode::LEASE_EXPIRED); + return tl::make_unexpected(ErrorCode::LEASE_EXPIRED); + } + if (record_it->second.pin_token != pin_token) { + timer.LogResponse("error_code=", ErrorCode::INVALID_READ); + return tl::make_unexpected(ErrorCode::INVALID_READ); + } + + if (record_it->second.ref_count > 1) { + record_it->second.ref_count--; + timer.LogResponse("error_code=", ErrorCode::OK, + "ref_count=", record_it->second.ref_count); + return {}; + } + + shard.ordered_list.erase(record_it->second.list_it); + shard.by_key.erase(record_it); + timer.LogResponse("error_code=", ErrorCode::OK, "ref_count=", 0); + return {}; +} + tl::expected DataManager::TransferDataFromRemote( AllocationHandle handle, const std::vector& src_buffers) { auto submit_result = SubmitTeTransferInternal( @@ -721,6 +1276,30 @@ DataManager::SubmitTeTransferInternal( std::move(buffer_result.value()); } + auto batches_result = SubmitTeTransferBatches(transfer_ptr, total_data_size, + remote_buffers, opcode); + if (!batches_result) { + return tl::unexpected(batches_result.error()); + } + + TeSubmitResult result; + result.transfer_batches = std::move(batches_result.value()); + result.temp_buffer = std::move(temp_buffer_owner); + result.handle = handle; + return result; +} + +tl::expected>, + ErrorCode> +DataManager::SubmitTeTransferBatches( + void* transfer_ptr, size_t total_data_size, + const std::vector& remote_buffers, + Transport::TransferRequest::OpCode opcode) { + if (!transfer_ptr || total_data_size == 0) { + LOG(ERROR) << "SubmitTeTransferBatches: invalid local transfer pointer"; + return tl::unexpected(ErrorCode::INVALID_PARAMS); + } + std::unordered_map> segment_buffers; for (size_t i = 0; i < remote_buffers.size(); ++i) { segment_buffers[remote_buffers[i].segment_endpoint].push_back(i); @@ -793,11 +1372,41 @@ DataManager::SubmitTeTransferInternal( return tl::unexpected(ErrorCode::TRANSFER_FAIL); } - TeSubmitResult result; - result.transfer_batches = std::move(submitted_batches); - result.temp_buffer = std::move(temp_buffer_owner); - result.handle = handle; - return result; + return submitted_batches; +} + +tl::expected DataManager::TransferWithTeNoTierStaging( + void* local_transfer_base, size_t total_size, + const std::vector& peer_buffers, + Transport::TransferRequest::OpCode opcode) { + auto validate_result = ValidateRemoteBuffers(peer_buffers); + if (!validate_result) { + return tl::make_unexpected(validate_result.error()); + } + size_t total_remote_size = 0; + for (const auto& buf : peer_buffers) total_remote_size += buf.size; + if (total_remote_size != total_size) { + LOG(ERROR) << "TransferWithTeNoTierStaging: peer buffer size mismatch (" + << total_remote_size << " vs " << total_size << ")"; + return tl::make_unexpected(ErrorCode::INVALID_PARAMS); + } + if (!transfer_engine_->getMetadata()) { + LOG(ERROR) << "TransferEngine not initialized"; + return tl::make_unexpected(ErrorCode::INTERNAL_ERROR); + } + auto batches = SubmitTeTransferBatches(local_transfer_base, total_size, + peer_buffers, opcode); + if (!batches) { + return tl::unexpected(batches.error()); + } + auto wait_result = WaitAllTransferBatches(batches.value()); + if (!wait_result) { + LOG(ERROR) << "TransferWithTeNoTierStaging: WaitAllTransferBatches " + "failed: " + << toString(wait_result.error()); + return wait_result; + } + return {}; } tl::expected DataManager::ValidateRemoteBuffers( @@ -1134,6 +1743,17 @@ tl::expected DataManager::Delete(std::string_view key, ScopedVLogTimer timer(1, "DataManager::Delete"); timer.LogRequest("key=", key); + // NOTE (weak delete semantics): + // TieredBackend::Delete only removes the metadata entry (or a replica + // entry) from the in-memory index. It does NOT directly free underlying + // memory. The actual buffer lifetime is still governed by + // AllocationHandle's RAII reference counting. + // + // We still guard Delete against in-flight 3-phase contexts: + // - PendingWriteRecord holds a strong handle reference until WriteCommit or + // lease cleanup. + // - PinnedKeyRecord holds a strong handle reference until UnPinKey reaches + // ref_count==0 or lease cleanup. std::unique_lock lock(GetKeyLock(key)); auto result = tiered_backend_->Delete(key, tier_id); diff --git a/mooncake-store/src/dummy_client.cpp b/mooncake-store/src/dummy_client.cpp index 746e4b0c6c..f560c75de6 100644 --- a/mooncake-store/src/dummy_client.cpp +++ b/mooncake-store/src/dummy_client.cpp @@ -542,13 +542,13 @@ int64_t DummyClient::getSize(const std::string& key) { } std::shared_ptr DummyClient::get_buffer( - const std::string& key, const ReadRouteConfig& config) { + const std::string& key, const ReadConfigExt& config) { // Dummy client does not use BufferHandle, so we return nullptr return nullptr; } std::tuple DummyClient::get_buffer_info( - const std::string& key, const ReadRouteConfig& config) { + const std::string& key, const ReadConfigExt& config) { auto result = invoke_rpc<&RealClient::get_buffer_info_dummy_helper, std::tuple>(key, config, client_id_); @@ -560,13 +560,13 @@ std::tuple DummyClient::get_buffer_info( } std::vector> DummyClient::batch_get_buffer( - const std::vector& keys, const ReadRouteConfig& config) { + const std::vector& keys, const ReadConfigExt& config) { // TODO: implement this function return std::vector>(); } int64_t DummyClient::get_into(const std::string& key, void* buffer, size_t size, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { // TODO: implement this function return -1; } @@ -604,7 +604,7 @@ int DummyClient::put_from(const std::string& key, void* buffer, size_t size, std::vector DummyClient::batch_get_into( const std::vector& keys, const std::vector& buffer_ptrs, - const std::vector& sizes, const ReadRouteConfig& config) { + const std::vector& sizes, const ReadConfigExt& config) { std::vector buffers; for (auto ptr : buffer_ptrs) { buffers.push_back(reinterpret_cast(ptr)); @@ -644,7 +644,7 @@ std::vector DummyClient::batch_get_into_multi_buffers( const std::vector& keys, const std::vector>& all_buffer_ptrs, const std::vector>& all_sizes, - bool aggregate_same_segment_task, const ReadRouteConfig& config) { + bool aggregate_same_segment_task, const ReadConfigExt& config) { // TODO: implement this function std::vector vec(keys.size(), -1); return vec; diff --git a/mooncake-store/src/p2p_client_service.cpp b/mooncake-store/src/p2p_client_service.cpp index 46dfe3f808..412d71a96f 100644 --- a/mooncake-store/src/p2p_client_service.cpp +++ b/mooncake-store/src/p2p_client_service.cpp @@ -17,6 +17,73 @@ namespace mooncake { +namespace { + +// UnPin after forward read (or cleanup after TE failure): retry only for +// "other" errors. INVALID_READ = token mismatch (treat as released for flow). +// RPC_FAIL = transport/timeout-like (no repeat; owner may TTL-clean). +// LEASE_EXPIRED = server already expired the pin record. Same max-attempt count +// is used for WriteRevoke after forward write TE failure. +constexpr int kForwardReadUnpinMaxAttempts = 3; + +bool UnPinErrorTreatAsEffectiveOk(ErrorCode e) { + switch (e) { + case ErrorCode::INVALID_READ: + case ErrorCode::RPC_FAIL: + case ErrorCode::LEASE_EXPIRED: + return true; + default: + return false; + } +} + +bool WriteRevokeErrorTreatAsEffectiveOk(ErrorCode e) { + switch (e) { + case ErrorCode::INVALID_WRITE: + case ErrorCode::RPC_FAIL: + return true; + default: + return false; + } +} + +bool SlicesAreContiguous(const std::vector& slices) { + if (slices.empty()) { + return false; + } + for (size_t i = 1; i < slices.size(); ++i) { + const char* prev_end = + static_cast(slices[i - 1].ptr) + slices[i - 1].size; + if (prev_end != static_cast(slices[i].ptr)) { + return false; + } + } + return true; +} + +size_t TotalSliceBytes(const std::vector& slices) { + size_t t = 0; + for (const auto& s : slices) { + t += s.size; + } + return t; +} + +bool RemoteDestBuffersContiguous(const std::vector& bufs) { + if (bufs.empty()) { + return false; + } + for (size_t i = 1; i < bufs.size(); ++i) { + uintptr_t prev_end = bufs[i - 1].addr + bufs[i - 1].size; + if (prev_end != bufs[i].addr) { + return false; + } + } + return true; +} + +} // namespace + // ============================================================================ // Construction / Destruction // ============================================================================ @@ -248,8 +315,8 @@ ErrorCode P2PClientService::InitStorage(const P2PClientConfig& config) { config.local_memcpy_async_worker_num; } - data_manager_ = DataManager(std::move(tiered_backend), transfer_engine_, - config.lock_shard_count, local_transfer_config); + data_manager_.emplace(std::move(tiered_backend), transfer_engine_, + config.lock_shard_count, local_transfer_config); // Set rectify callback on DataManager to remove stale replicas from master data_manager_->SetRectifyCallback([this](std::string_view key, std::optional tier_id) { @@ -558,7 +625,16 @@ std::vector> P2PClientService::BatchPut( } auto guard = AcquireInflightGuard(); - const auto* route_config = std::get_if(&config); + const auto* plain_route = std::get_if(&config); + const auto* ext_route = std::get_if(&config); + const WriteRouteRequestConfig* route_cfg_ptr = nullptr; + RdmaDirectionMode rdma_write_mode = RdmaDirectionMode::REVERSE; + if (plain_route) { + route_cfg_ptr = plain_route; + } else if (ext_route) { + route_cfg_ptr = &ext_route->route_config; + rdma_write_mode = ext_route->rdma_direction_mode; + } if (!guard.is_valid()) { LOG(ERROR) << "client is shutting down"; std::fill(results.begin(), results.end(), @@ -567,13 +643,14 @@ std::vector> P2PClientService::BatchPut( LOG(ERROR) << "BatchPut input size mismatch"; std::fill(results.begin(), results.end(), tl::unexpected(ErrorCode::INVALID_PARAMS)); - } else if (!route_config) { - LOG(ERROR) << "P2PClientService currently only supports " - "WriteRouteRequestConfig"; + } else if (!route_cfg_ptr) { + LOG(ERROR) << "P2PClientService expects WriteRouteRequestConfig or " + "WriteConfigExt"; std::fill(results.begin(), results.end(), tl::unexpected(ErrorCode::INVALID_PARAMS)); } else { - results = InnerBatchPut(keys, batched_slices, *route_config); + results = InnerBatchPut(keys, batched_slices, *route_cfg_ptr, + rdma_write_mode); } size_t success_count = 0; @@ -612,11 +689,13 @@ std::vector> P2PClientService::BatchPut( std::vector> P2PClientService::InnerBatchPut( const std::vector& keys, std::vector>& batched_slices, - const WriteRouteRequestConfig& route_config) { + const WriteRouteRequestConfig& route_config, + RdmaDirectionMode rdma_direction_mode) { if (ha_manager_ && ha_manager_->IsDegraded()) { return InnerBatchPutDegraded(keys, batched_slices); } - return InnerBatchPutNormal(keys, batched_slices, route_config); + return InnerBatchPutNormal(keys, batched_slices, route_config, + rdma_direction_mode); } std::vector> @@ -691,7 +770,8 @@ std::vector> P2PClientService::InnerBatchPutNormal( const std::vector& keys, std::vector>& batched_slices, - const WriteRouteRequestConfig& route_config) { + const WriteRouteRequestConfig& route_config, + RdmaDirectionMode rdma_direction_mode) { // Phase 1: fetch write routes from master. auto batch_routes = BatchFetchWriteRoutes(keys, batched_slices, route_config); @@ -704,8 +784,9 @@ P2PClientService::InnerBatchPutNormal( // Phase 2: // 2.1: async dispatch first-candidate writes for each key // 2.2: wrap each key in a retry chain based on rotute - auto handles = CreatePutHandlesFromRoute(keys, batched_slices, route_config, - batch_routes.value()); + auto handles = + CreatePutHandlesFromRoute(keys, batched_slices, route_config, + rdma_direction_mode, batch_routes.value()); // Phase 3: wait every retry chain and collect results. return CollectResults(handles, keys); @@ -740,6 +821,7 @@ P2PClientService::CreatePutHandlesFromRoute( const std::vector& keys, std::vector>& batched_slices, const WriteRouteRequestConfig& route_config, + RdmaDirectionMode rdma_direction_mode, BatchGetWriteRouteResponse& batch_resp) { struct WriteTask { std::unique_ptr> first_task; @@ -757,6 +839,7 @@ P2PClientService::CreatePutHandlesFromRoute( continue; } auto ops = BuildWriteOps(keys[i], batched_slices[i], route_config, + rdma_direction_mode, std::move(batch_resp.responses[i].candidates)); if (!ops) { LOG(ERROR) << "fail to build write ops" @@ -812,6 +895,7 @@ P2PClientService::CreatePutHandlesFromRoute( auto P2PClientService::BuildWriteOps(std::string_view key, std::vector& slices, const WriteRouteRequestConfig& config, + RdmaDirectionMode rdma_direction_mode, std::vector candidates) -> tl::expected>, ErrorCode> { if (candidates.empty()) { @@ -848,9 +932,15 @@ auto P2PClientService::BuildWriteOps(std::string_view key, } else { std::string endpoint = proxy.ip_address + ":" + std::to_string(proxy.rpc_port); + DataManager* fwd_dm = + (rdma_direction_mode == RdmaDirectionMode::FORWARD && + data_manager_.has_value()) + ? &*data_manager_ + : nullptr; write_ops.push_back(std::make_unique( - &GetOrCreatePeerClient(endpoint), write_req, std::move(proxy), - route_cache_ ? &*route_cache_ : nullptr, endpoint)); + this, &GetOrCreatePeerClient(endpoint), write_req, + std::move(proxy), route_cache_ ? &*route_cache_ : nullptr, + endpoint, fwd_dm, &slices, rdma_direction_mode)); } } @@ -894,6 +984,44 @@ std::unique_ptr> P2PClientService::LocalWriteOp::Dispatch() { } std::unique_ptr> P2PClientService::RemoteWriteOp::Dispatch() { + if (!owner_service) { + LOG(ERROR) << "Remote write requires P2PClientService"; + return CallableTaskHandle::Create( + []() -> tl::expected { + return tl::unexpected(ErrorCode::INTERNAL_ERROR); + }); + } + if (rdma_direction_mode == RdmaDirectionMode::FORWARD) { + return owner_service->StartForwardRemotePut(peer_ptr, forward_dm, + forward_slices, write_req); + } + return owner_service->RunReverseRemotePut(peer_ptr, write_req, proxy, + route_cache); +} + +std::unique_ptr> P2PClientService::StartForwardRemotePut( + PeerClient* peer, DataManager* forward_dm, + std::vector* forward_slices, + std::shared_ptr write_req) { + if (!forward_dm || !forward_slices) { + LOG(ERROR) << "Forward RDMA write requires local DataManager"; + return CallableTaskHandle::Create( + []() -> tl::expected { + return tl::unexpected(ErrorCode::INTERNAL_ERROR); + }); + } + auto promise = std::make_shared< + async_simple::Promise>>(); + auto future = promise->getFuture(); + RunForwardRemotePut(std::move(promise), peer, forward_dm, write_req, + forward_slices) + .start([](auto&&) {}); + return FutureHandle::Create(write_req, std::move(future)); +} + +std::unique_ptr> P2PClientService::RunReverseRemotePut( + PeerClient* peer, std::shared_ptr write_req, + const P2PProxyDescriptor& proxy, RouteCache* route_cache) { auto promise = std::make_shared< async_simple::Promise>>(); auto future = promise->getFuture(); @@ -902,7 +1030,7 @@ std::unique_ptr> P2PClientService::RemoteWriteOp::Dispatch() { auto cached_proxy = proxy; auto* cache = route_cache; - peer_ptr->AsyncWriteRemoteData(*write_req) + peer->AsyncWriteRemoteData(*write_req) .start([promise, req, cached_proxy, cache](async_simple::Try>&& remote_res) mutable { @@ -937,6 +1065,91 @@ std::unique_ptr> P2PClientService::RemoteWriteOp::Dispatch() { return FutureHandle::Create(req, std::move(future)); } +async_simple::coro::Lazy P2PClientService::RunForwardRemotePut( + std::shared_ptr>> + promise, + PeerClient* peer, DataManager* dm, + std::shared_ptr write_req, std::vector* slices) { + if (!peer || !dm || !write_req || !slices) { + tl::expected err = + tl::make_unexpected(ErrorCode::INTERNAL_ERROR); + promise->setValue(std::move(err)); + co_return; + } + if (!SlicesAreContiguous(*slices)) { + LOG(ERROR) + << "Forward RDMA write requires contiguous slice buffers, key=" + << write_req->key; + tl::expected err = + tl::make_unexpected(ErrorCode::NON_CONTIGUOUS_BUFFER_NOT_SUPPORTED); + promise->setValue(std::move(err)); + co_return; + } + PreWriteRequest pre_req; + pre_req.key = write_req->key; + pre_req.size_bytes = TotalSliceBytes(*slices); + pre_req.target_tier_id = write_req->target_tier_id; + + auto pre = co_await peer->AsyncPreWrite(pre_req); + if (!pre) { + if (!IsAlreadyExistsError(pre.error())) { + LOG(ERROR) << "AsyncPreWrite failed, key=" << write_req->key + << ", error=" << pre.error(); + } + tl::expected err = tl::make_unexpected(pre.error()); + promise->setValue(std::move(err)); + co_return; + } + + std::vector dest{pre.value().remote_buffer}; + void* base = slices->front().ptr; + auto te = + dm->TransferWithTeNoTierStaging(base, TotalSliceBytes(*slices), dest, + Transport::TransferRequest::WRITE); + if (!te) { + LOG(ERROR) << "Forward TE write failed, key=" << write_req->key + << ", error=" << te.error(); + WriteRevokeRequest revoke_req; + revoke_req.key = write_req->key; + revoke_req.pending_write_token = pre.value().pending_write_token; + tl::expected revoke_res; + for (int attempt = 0; attempt < kForwardReadUnpinMaxAttempts; + ++attempt) { + revoke_res = co_await peer->AsyncWriteRevoke(revoke_req); + if (revoke_res) { + break; + } + if (WriteRevokeErrorTreatAsEffectiveOk(revoke_res.error())) { + revoke_res = tl::expected{}; + break; + } + if (attempt + 1 < kForwardReadUnpinMaxAttempts) { + LOG(WARNING) << "AsyncWriteRevoke retry after TE failure, key=" + << write_req->key << ", attempt=" << (attempt + 1) + << ", error=" << revoke_res.error(); + } + } + if (!revoke_res) { + LOG(ERROR) << "AsyncWriteRevoke failed after TE failure, key=" + << write_req->key << ", error=" << revoke_res.error(); + } + tl::expected err = tl::make_unexpected(te.error()); + promise->setValue(std::move(err)); + co_return; + } + + WriteCommitRequest commit; + commit.key = write_req->key; + commit.pending_write_token = pre.value().pending_write_token; + auto cm = co_await peer->AsyncWriteCommit(commit); + if (!cm) { + tl::expected err = tl::make_unexpected(cm.error()); + promise->setValue(std::move(err)); + co_return; + } + promise->setValue(tl::expected{}); +} + async_simple::coro::Lazy P2PClientService::RunWriteWithRetry( std::shared_ptr>> promise, @@ -996,20 +1209,20 @@ async_simple::coro::Lazy P2PClientService::RunWriteWithRetry( tl::expected, ErrorCode> P2PClientService::Get( const std::string& key, std::shared_ptr allocator, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { return std::move(BatchGet({key}, allocator, config)[0]); } tl::expected P2PClientService::Get( const std::string& key, const std::vector& buffers, - const std::vector& sizes, const ReadRouteConfig& config) { + const std::vector& sizes, const ReadConfigExt& config) { return std::move(BatchGet({key}, {buffers}, {sizes}, config)[0]); } std::vector, ErrorCode>> P2PClientService::BatchGet(const std::vector& keys, std::shared_ptr allocator, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { if (!allocator) { LOG(ERROR) << "Client buffer allocator is not provided"; return std::vector< @@ -1030,7 +1243,7 @@ std::vector> P2PClientService::BatchGet( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - const ReadRouteConfig& config, bool /*aggregate_same_segment_task*/) { + const ReadConfigExt& config, bool /*aggregate_same_segment_task*/) { if (keys.size() != all_buffers.size() || keys.size() != all_sizes.size()) { LOG(ERROR) << "Input vector sizes mismatch"; return std::vector>( @@ -1140,7 +1353,7 @@ std::vector> P2PClientService::BatchCreateGetHandles( const std::vector& keys, std::shared_ptr allocator, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { auto local_get = [&](std::string_view key, size_t) -> tl::expected { if (!data_manager_.has_value()) { @@ -1159,8 +1372,7 @@ P2PClientService::BatchCreateGetHandles( std::vector> P2PClientService::BatchCreateGetHandles( const std::vector& keys, - std::vector>& all_slices, - const ReadRouteConfig& config) { + std::vector>& all_slices, const ReadConfigExt& config) { auto local_get = [&](std::string_view key, size_t i) -> tl::expected { if (!data_manager_.has_value()) { @@ -1180,7 +1392,7 @@ P2PClientService::BatchCreateGetHandles( template std::vector> P2PClientService::BatchCreateGetHandlesImpl( - const std::vector& keys, const ReadRouteConfig& config, + const std::vector& keys, const ReadConfigExt& config, LocalGetFn&& local_get, RemoteGetFn&& remote_get) { std::vector> handles; handles.reserve(keys.size()); @@ -1240,7 +1452,7 @@ P2PClientService::BatchCreateGetHandlesImpl( std::vector< tl::expected, ErrorCode>> P2PClientService::BatchFetchReadRoutes( - const std::vector& keys, const ReadRouteConfig& config) { + const std::vector& keys, const ReadConfigExt& config) { std::vector, ErrorCode>> result( keys.size(), std::vector{}); @@ -1263,7 +1475,8 @@ P2PClientService::BatchFetchReadRoutes( // Single batch RPC to master std::vector> responses; - responses = master_client_.BatchGetReplicaList(miss_keys, config); + responses = + master_client_.BatchGetReplicaList(miss_keys, config.route_config); for (size_t k = 0; k < responses.size(); ++k) { if (!responses[k]) { if (responses[k].error() != ErrorCode::OBJECT_NOT_FOUND) { @@ -1345,7 +1558,7 @@ std::vector P2PClientService::ReplicasToRoutes( tl::expected P2PClientService::CreateRemoteGetHandle( std::string_view key, std::shared_ptr allocator, - const ReadRouteConfig& config, std::vector pre_fetched) { + const ReadConfigExt& config, std::vector pre_fetched) { auto iter = BuildRouteIter(key, config, std::move(pre_fetched)); if (!iter) { LOG(ERROR) << "Failed to build route iterator, key=" << key @@ -1362,7 +1575,8 @@ tl::expected P2PClientService::CreateRemoteGetHandle( auto read_buf = std::make_shared(std::move(*alloc_result)); std::vector slices = {{read_buf->ptr(), object_size}}; - auto result = InnerGetViaRoute(key, slices, std::move(*iter)); + auto result = InnerGetViaRoute(key, slices, std::move(*iter), + config.rdma_direction_mode); if (!result) { LOG(ERROR) << "Failed to get via route, key=" << key << ", error=" << result.error(); @@ -1377,7 +1591,7 @@ tl::expected P2PClientService::CreateRemoteGetHandle( tl::expected P2PClientService::CreateRemoteGetHandle( std::string_view key, std::vector& slices, - const ReadRouteConfig& config, std::vector pre_fetched) { + const ReadConfigExt& config, std::vector pre_fetched) { auto iter = BuildRouteIter(key, config, std::move(pre_fetched)); if (!iter) { if (iter.error() != ErrorCode::OBJECT_NOT_FOUND) { @@ -1386,7 +1600,8 @@ tl::expected P2PClientService::CreateRemoteGetHandle( } return tl::unexpected(iter.error()); } - auto result = InnerGetViaRoute(key, slices, std::move(*iter)); + auto result = InnerGetViaRoute(key, slices, std::move(*iter), + config.rdma_direction_mode); if (!result) { LOG(ERROR) << "Failed to get via route, key=" << key << ", error=" << result.error(); @@ -1396,7 +1611,8 @@ tl::expected P2PClientService::CreateRemoteGetHandle( } tl::expected P2PClientService::InnerGetViaRoute( - std::string_view key, std::vector& slices, RouteIterator iter) { + std::string_view key, std::vector& slices, RouteIterator iter, + RdmaDirectionMode rdma_direction_mode) { auto req = std::make_shared(); req->key = key; for (const auto& s : slices) { @@ -1412,7 +1628,8 @@ tl::expected P2PClientService::InnerGetViaRoute( auto future = promise->getFuture(); const uint64_t object_size = iter.object_size(); - RunReadWithRetry(std::move(iter), req, promise).start([](auto&&) {}); + RunReadWithRetry(std::move(iter), req, promise, rdma_direction_mode) + .start([](auto&&) {}); ReadTaskHandle res; res.data_size = object_size; @@ -1421,15 +1638,127 @@ tl::expected P2PClientService::InnerGetViaRoute( return res; } +async_simple::coro::Lazy P2PClientService::RunForwardReadOnRoute( + const ResolvedRoute& route, std::shared_ptr req, + std::shared_ptr>> + promise, + RouteIterator& iter, ErrorCode& final_result) { + if (!data_manager_.has_value()) { + LOG(ERROR) << "Forward RDMA read requires DataManager"; + tl::expected err1 = + tl::make_unexpected(ErrorCode::INTERNAL_ERROR); + promise->setValue(std::move(err1)); + co_return true; + } + if (!RemoteDestBuffersContiguous(req->dest_buffers)) { + LOG(ERROR) << "Forward RDMA read requires contiguous dest buffers, key=" + << req->key; + tl::expected err2 = + tl::make_unexpected(ErrorCode::NON_CONTIGUOUS_BUFFER_NOT_SUPPORTED); + promise->setValue(std::move(err2)); + co_return true; + } + PinKeyRequest pin_req; + pin_req.key = req->key; + auto pin = co_await route.peer->AsyncPinKey(pin_req); + if (!pin) { + if (pin.error() != ErrorCode::OBJECT_NOT_FOUND) { + LOG(ERROR) << "AsyncPinKey failed, key=" << req->key + << ", error=" << pin.error(); + } else { + final_result = pin.error(); + } + iter.Evict(route); + co_return false; + } + void* base = reinterpret_cast(req->dest_buffers[0].addr); + size_t total = 0; + for (const auto& d : req->dest_buffers) { + total += d.size; + } + auto tr = data_manager_->TransferWithTeNoTierStaging( + base, total, {pin.value().remote_buffer}, + Transport::TransferRequest::READ); + if (!tr) { + LOG(ERROR) << "Forward TE read failed, key=" << req->key + << ", error=" << tr.error(); + UnPinKeyRequest cleanup; + cleanup.key = req->key; + cleanup.pin_token = pin.value().pin_token; + tl::expected cleanup_unpin; + for (int attempt = 0; attempt < kForwardReadUnpinMaxAttempts; + ++attempt) { + cleanup_unpin = co_await route.peer->AsyncUnPinKey(cleanup); + if (cleanup_unpin) { + break; + } + if (UnPinErrorTreatAsEffectiveOk(cleanup_unpin.error())) { + cleanup_unpin = tl::expected{}; + break; + } + if (attempt + 1 < kForwardReadUnpinMaxAttempts) { + LOG(WARNING) + << "AsyncUnPinKey retry after TE failure, key=" << req->key + << ", attempt=" << (attempt + 1) + << ", error=" << cleanup_unpin.error(); + } + } + if (!cleanup_unpin) { + LOG(ERROR) << "AsyncUnPinKey failed after TE read failure, key=" + << req->key << ", error=" << cleanup_unpin.error(); + } + iter.Evict(route); + co_return false; + } + UnPinKeyRequest unpin_req; + unpin_req.key = req->key; + unpin_req.pin_token = pin.value().pin_token; + tl::expected unpin_res; + for (int attempt = 0; attempt < kForwardReadUnpinMaxAttempts; ++attempt) { + unpin_res = co_await route.peer->AsyncUnPinKey(unpin_req); + if (unpin_res) { + break; + } + if (UnPinErrorTreatAsEffectiveOk(unpin_res.error())) { + unpin_res = tl::expected{}; + break; + } + if (attempt + 1 < kForwardReadUnpinMaxAttempts) { + LOG(WARNING) << "AsyncUnPinKey retry after forward read, key=" + << req->key << ", attempt=" << (attempt + 1) + << ", error=" << unpin_res.error(); + } + } + if (!unpin_res) { + LOG(ERROR) << "AsyncUnPinKey failed after forward read, key=" + << req->key << ", error=" << unpin_res.error(); + final_result = unpin_res.error(); + iter.Evict(route); + co_return false; + } + tl::expected ok; + promise->setValue(std::move(ok)); + co_return true; +} + // Coroutine iterates route candidates and retries on failure. async_simple::coro::Lazy P2PClientService::RunReadWithRetry( RouteIterator iter, std::shared_ptr req, std::shared_ptr>> - promise) { + promise, + RdmaDirectionMode rdma_direction_mode) { ErrorCode final_result = ErrorCode::OBJECT_NOT_FOUND; try { while (auto route = co_await iter.AsyncNext()) { try { + if (rdma_direction_mode == RdmaDirectionMode::FORWARD) { + if (co_await RunForwardReadOnRoute(*route, req, promise, + iter, final_result)) { + co_return; + } + continue; + } + auto result = co_await route->peer->AsyncReadRemoteData(*req); if (result.has_value()) { tl::expected ok; @@ -1539,13 +1868,13 @@ void P2PClientService::RouteIterator::Evict(const ResolvedRoute& route) { tl::expected P2PClientService::BuildRouteIter(std::string_view key, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { return BuildRouteIter(key, config, LoadCachedRoutes(key)); } tl::expected P2PClientService::BuildRouteIter(std::string_view key, - const ReadRouteConfig& config, + const ReadConfigExt& config, std::vector pre_fetched) { auto routes = std::move(pre_fetched); uint64_t object_size = routes.empty() ? 0 : routes.front().object_size; @@ -1565,9 +1894,9 @@ P2PClientService::BuildRouteIter(std::string_view key, async_simple::coro::Lazy> P2PClientService::AsyncResolveRoutesFromMaster(std::string_view key, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { auto replica_result = - co_await master_client_.AsyncGetReplicaList(key, config); + co_await master_client_.AsyncGetReplicaList(key, config.route_config); if (!replica_result) { if (replica_result.error() != ErrorCode::OBJECT_NOT_FOUND) { LOG(ERROR) << "Failed to query replica list, key=" << key @@ -1649,7 +1978,7 @@ std::vector> P2PClientService::BatchIsExist( // ============================================================================ tl::expected, ErrorCode> P2PClientService::Query( - const std::string& object_key, const ReadRouteConfig& config) { + const std::string& object_key, const ReadConfigExt& config) { auto guard = AcquireInflightGuard(); if (!guard.is_valid()) { LOG(ERROR) << "client is shutting down"; @@ -1664,7 +1993,8 @@ tl::expected, ErrorCode> P2PClientService::Query( } // Query master for replica list - auto result = master_client_.GetReplicaList(object_key, config); + auto result = + master_client_.GetReplicaList(object_key, config.route_config); if (!result) { LOG(WARNING) << "fail to get replica list" << ", key=" << object_key << ", error=" << result.error(); @@ -1676,7 +2006,7 @@ tl::expected, ErrorCode> P2PClientService::Query( std::vector, ErrorCode>> P2PClientService::BatchQuery(const std::vector& object_keys, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { auto guard = AcquireInflightGuard(); if (!guard.is_valid()) { LOG(ERROR) << "client is shutting down"; @@ -1690,7 +2020,8 @@ P2PClientService::BatchQuery(const std::vector& object_keys, } std::vector key_views(object_keys.begin(), object_keys.end()); - auto responses = master_client_.BatchGetReplicaList(key_views, config); + auto responses = + master_client_.BatchGetReplicaList(key_views, config.route_config); std::vector, ErrorCode>> results; results.reserve(responses.size()); for (size_t i = 0; i < responses.size(); ++i) { diff --git a/mooncake-store/src/peer_client.cpp b/mooncake-store/src/peer_client.cpp index ad44b78b0c..ab558cccf9 100644 --- a/mooncake-store/src/peer_client.cpp +++ b/mooncake-store/src/peer_client.cpp @@ -78,6 +78,128 @@ PeerClient::AsyncWriteRemoteData(const RemoteWriteRequest& request) { co_return result->result(); } +async_simple::coro::Lazy> +PeerClient::AsyncPreWrite(const PreWriteRequest& request) { + if (!client_pool_) { + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto ret = co_await client_pool_->send_request( + [&](coro_io::client_reuse_hint, coro_rpc::coro_rpc_client& client) { + return client.send_request<&ClientRpcService::PreWrite>(request); + }); + if (!ret.has_value()) { + LOG(ERROR) << "AsyncPreWrite: client not available"; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto result = co_await std::move(ret.value()); + if (!result) { + LOG(ERROR) << "AsyncPreWrite: RPC call failed: " << result.error().msg; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + co_return result->result(); +} + +async_simple::coro::Lazy> +PeerClient::AsyncWriteCommit(const WriteCommitRequest& request) { + if (!client_pool_) { + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto ret = co_await client_pool_->send_request( + [&](coro_io::client_reuse_hint, coro_rpc::coro_rpc_client& client) { + return client.send_request<&ClientRpcService::WriteCommit>(request); + }); + if (!ret.has_value()) { + LOG(ERROR) << "AsyncWriteCommit: client not available"; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto result = co_await std::move(ret.value()); + if (!result) { + LOG(ERROR) << "AsyncWriteCommit: RPC call failed: " + << result.error().msg; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + co_return result->result(); +} + +async_simple::coro::Lazy> +PeerClient::AsyncWriteRevoke(const WriteRevokeRequest& request) { + if (!client_pool_) { + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto ret = co_await client_pool_->send_request( + [&](coro_io::client_reuse_hint, coro_rpc::coro_rpc_client& client) { + return client.send_request<&ClientRpcService::WriteRevoke>(request); + }); + if (!ret.has_value()) { + LOG(ERROR) << "AsyncWriteRevoke: client not available"; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto result = co_await std::move(ret.value()); + if (!result) { + LOG(ERROR) << "AsyncWriteRevoke: RPC call failed: " + << result.error().msg; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + co_return result->result(); +} + +async_simple::coro::Lazy> +PeerClient::AsyncPinKey(const PinKeyRequest& request) { + if (!client_pool_) { + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto ret = co_await client_pool_->send_request( + [&](coro_io::client_reuse_hint, coro_rpc::coro_rpc_client& client) { + return client.send_request<&ClientRpcService::PinKey>(request); + }); + if (!ret.has_value()) { + LOG(ERROR) << "AsyncPinKey: client not available"; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto result = co_await std::move(ret.value()); + if (!result) { + LOG(ERROR) << "AsyncPinKey: RPC call failed: " << result.error().msg; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + co_return result->result(); +} + +async_simple::coro::Lazy> +PeerClient::AsyncUnPinKey(const UnPinKeyRequest& request) { + if (!client_pool_) { + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto ret = co_await client_pool_->send_request( + [&](coro_io::client_reuse_hint, coro_rpc::coro_rpc_client& client) { + return client.send_request<&ClientRpcService::UnPinKey>(request); + }); + if (!ret.has_value()) { + LOG(ERROR) << "AsyncUnPinKey: client not available"; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + auto result = co_await std::move(ret.value()); + if (!result) { + LOG(ERROR) << "AsyncUnPinKey: RPC call failed: " << result.error().msg; + co_return tl::make_unexpected(ErrorCode::RPC_FAIL); + } + + co_return result->result(); +} + tl::expected PeerClient::ReadRemoteData( const RemoteReadRequest& request) { return async_simple::coro::syncAwait(AsyncReadRemoteData(request)); @@ -88,4 +210,29 @@ tl::expected PeerClient::WriteRemoteData( return async_simple::coro::syncAwait(AsyncWriteRemoteData(request)); } +tl::expected PeerClient::PreWrite( + const PreWriteRequest& request) { + return async_simple::coro::syncAwait(AsyncPreWrite(request)); +} + +tl::expected PeerClient::WriteCommit( + const WriteCommitRequest& request) { + return async_simple::coro::syncAwait(AsyncWriteCommit(request)); +} + +tl::expected PeerClient::WriteRevoke( + const WriteRevokeRequest& request) { + return async_simple::coro::syncAwait(AsyncWriteRevoke(request)); +} + +tl::expected PeerClient::PinKey( + const PinKeyRequest& request) { + return async_simple::coro::syncAwait(AsyncPinKey(request)); +} + +tl::expected PeerClient::UnPinKey( + const UnPinKeyRequest& request) { + return async_simple::coro::syncAwait(AsyncUnPinKey(request)); +} + } // namespace mooncake diff --git a/mooncake-store/src/real_client.cpp b/mooncake-store/src/real_client.cpp index 0fe9a34ab0..42f6881441 100644 --- a/mooncake-store/src/real_client.cpp +++ b/mooncake-store/src/real_client.cpp @@ -824,7 +824,7 @@ tl::expected RealClient::unregister_shm_buffer_internal( std::shared_ptr RealClient::get_buffer_internal( const std::string& key, std::shared_ptr client_buffer_allocator, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { if (!client_service_) { LOG(ERROR) << "Client is not initialized"; return nullptr; @@ -848,12 +848,12 @@ std::shared_ptr RealClient::get_buffer_internal( // Implementation of get_buffer method std::shared_ptr RealClient::get_buffer( - const std::string& key, const ReadRouteConfig& config) { + const std::string& key, const ReadConfigExt& config) { return get_buffer_internal(key, client_buffer_allocator_, config); } std::tuple RealClient::get_buffer_info( - const std::string& key, const ReadRouteConfig& config) { + const std::string& key, const ReadConfigExt& config) { auto buffer_handle = get_buffer_internal(key, client_buffer_allocator_, config); if (!buffer_handle) { @@ -867,7 +867,7 @@ std::tuple RealClient::get_buffer_info( tl::expected, ErrorCode> RealClient::get_buffer_info_dummy_helper(const std::string& key, - const ReadRouteConfig& config, + const ReadConfigExt& config, const UUID& client_id) { std::shared_lock lock(dummy_client_mutex_); auto it = shm_contexts_.find(client_id); @@ -905,7 +905,7 @@ RealClient::get_buffer_info_dummy_helper(const std::string& key, // Implementation of batch_get_buffer_internal method std::vector> RealClient::batch_get_buffer_internal(const std::vector& keys, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { std::vector> final_results(keys.size(), nullptr); @@ -938,7 +938,7 @@ RealClient::batch_get_buffer_internal(const std::vector& keys, // Implementation of batch_get_buffer method std::vector> RealClient::batch_get_buffer( - const std::vector& keys, const ReadRouteConfig& config) { + const std::vector& keys, const ReadConfigExt& config) { return batch_get_buffer_internal(keys, config); } @@ -978,7 +978,7 @@ int RealClient::unregister_buffer(void* buffer) { tl::expected RealClient::get_into_internal( const std::string& key, void* buffer, size_t size, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { // NOTE: The buffer address must be previously registered with // register_buffer() for zero-copy RDMA operations to work correctly if (!client_service_) { @@ -990,7 +990,7 @@ tl::expected RealClient::get_into_internal( } int64_t RealClient::get_into(const std::string& key, void* buffer, size_t size, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { return to_py_ret(get_into_internal(key, buffer, size, config)); } @@ -1173,7 +1173,7 @@ int RealClient::put_from(const std::string& key, void* buffer, size_t size, std::vector RealClient::batch_get_into( const std::vector& keys, const std::vector& buffers, - const std::vector& sizes, const ReadRouteConfig& config) { + const std::vector& sizes, const ReadConfigExt& config) { auto internal_results = batch_get_into_internal(keys, buffers, sizes, config); std::vector results; @@ -1190,7 +1190,7 @@ std::vector> RealClient::batch_get_into_dummy_helper( const std::vector& keys, const std::vector& dummy_buffers, - const std::vector& sizes, const ReadRouteConfig& config, + const std::vector& sizes, const ReadConfigExt& config, const UUID& client_id) { std::shared_lock lock(dummy_client_mutex_); auto it = shm_contexts_.find(client_id); @@ -1245,7 +1245,7 @@ std::vector> RealClient::batch_get_into_internal(const std::vector& keys, const std::vector& buffers, const std::vector& sizes, - const ReadRouteConfig& config) { + const ReadConfigExt& config) { // Validate preconditions if (!client_service_) { LOG(ERROR) << "Client is not initialized"; @@ -1405,7 +1405,7 @@ std::vector RealClient::batch_get_into_multi_buffers( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - bool prefer_alloc_in_same_node, const ReadRouteConfig& config) { + bool prefer_alloc_in_same_node, const ReadConfigExt& config) { auto start = std::chrono::steady_clock::now(); auto internal_results = batch_get_into_multi_buffers_internal( keys, all_buffers, all_sizes, prefer_alloc_in_same_node, config); @@ -1427,7 +1427,7 @@ RealClient::batch_get_into_multi_buffers_internal( const std::vector& keys, const std::vector>& all_buffers, const std::vector>& all_sizes, - bool prefer_alloc_in_same_node, const ReadRouteConfig& config) { + bool prefer_alloc_in_same_node, const ReadConfigExt& config) { // Validate preconditions if (!client_service_) { LOG(ERROR) << "Client is not initialized"; diff --git a/mooncake-store/src/types.cpp b/mooncake-store/src/types.cpp index 2506c18738..1f8d912d02 100644 --- a/mooncake-store/src/types.cpp +++ b/mooncake-store/src/types.cpp @@ -21,6 +21,8 @@ const std::string& toString(ErrorCode errorCode) noexcept { {ErrorCode::WRITE_FAIL, "WRITE_FAIL"}, {ErrorCode::INVALID_PARAMS, "INVALID_PARAMS"}, {ErrorCode::ILLEGAL_CLIENT, "ILLEGAL_CLIENT"}, + {ErrorCode::NON_CONTIGUOUS_BUFFER_NOT_SUPPORTED, + "NON_CONTIGUOUS_BUFFER_NOT_SUPPORTED"}, {ErrorCode::INVALID_WRITE, "INVALID_WRITE"}, {ErrorCode::INVALID_READ, "INVALID_READ"}, {ErrorCode::INVALID_REPLICA, "INVALID_REPLICA"}, diff --git a/mooncake-store/tests/client_rpc_service_test.cpp b/mooncake-store/tests/client_rpc_service_test.cpp index 63130adb42..4745484b04 100644 --- a/mooncake-store/tests/client_rpc_service_test.cpp +++ b/mooncake-store/tests/client_rpc_service_test.cpp @@ -203,6 +203,187 @@ TEST_F(ClientRpcServiceTest, ReadRemoteDataKeyNotFound) { auto result = rpc_service_->ReadRemoteData(request); ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::OBJECT_NOT_FOUND); +} + +// ============================================================================ +// PinKey / UnPinKey (forward read control plane) +// ============================================================================ + +TEST_F(ClientRpcServiceTest, PinKeyEmptyKey) { + PinKeyRequest req; + req.key = ""; + req.target_tier_id = std::nullopt; + + auto result = rpc_service_->PinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(ClientRpcServiceTest, PinKeyObjectNotFound) { + PinKeyRequest req; + req.key = "rpc_svc_pin_missing_key"; + req.target_tier_id = std::nullopt; + + auto result = rpc_service_->PinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::OBJECT_NOT_FOUND); +} + +TEST_F(ClientRpcServiceTest, PinKeyAfterPutThenUnPin) { + const std::string key = "rpc_svc_pin_after_put"; + const std::string blob = "payload"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()) << "Put failed"; + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + + auto pin_res = rpc_service_->PinKey(pin_req); + ASSERT_TRUE(pin_res.has_value()) + << "PinKey failed: " << static_cast(pin_res.error()); + EXPECT_GT(pin_res->remote_buffer.size, 0u); + EXPECT_NE(pin_res->pin_token.first, 0u); + EXPECT_NE(pin_res->pin_token.second, 0u); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount. + UnPinKeyRequest unpin; + unpin.key = key; + unpin.pin_token = pin_res->pin_token; + auto unpin_res = rpc_service_->UnPinKey(unpin); + ASSERT_TRUE(unpin_res.has_value()) + << "UnPinKey failed: " << static_cast(unpin_res.error()); +} + +TEST_F(ClientRpcServiceTest, PinKeyTwiceSameTokenThenUnpinTwice) { + const std::string key = "rpc_svc_pin_twice_ref"; + const std::string blob = "ref"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + + auto first = rpc_service_->PinKey(pin_req); + ASSERT_TRUE(first.has_value()) + << "first PinKey failed: " << static_cast(first.error()); + auto second = rpc_service_->PinKey(pin_req); + ASSERT_TRUE(second.has_value()) + << "second PinKey failed: " << static_cast(second.error()); + + EXPECT_EQ(first->pin_token, second->pin_token); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount. + UnPinKeyRequest unpin; + unpin.key = key; + unpin.pin_token = first->pin_token; + auto u1 = rpc_service_->UnPinKey(unpin); + ASSERT_TRUE(u1.has_value()) + << "first UnPinKey failed: " << static_cast(u1.error()); + + auto u2 = rpc_service_->UnPinKey(unpin); + ASSERT_TRUE(u2.has_value()) + << "second UnPinKey failed: " << static_cast(u2.error()); +} + +TEST_F(ClientRpcServiceTest, PinKeyAfterUnpinNewToken) { + const std::string key = "rpc_svc_pin_new_token_after_unpin"; + const std::string blob = "tok"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + + auto pin1 = rpc_service_->PinKey(pin_req); + ASSERT_TRUE(pin1.has_value()) + << "first PinKey failed: " << static_cast(pin1.error()); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount. + UnPinKeyRequest unpin1; + unpin1.key = key; + unpin1.pin_token = pin1->pin_token; + auto un1 = rpc_service_->UnPinKey(unpin1); + ASSERT_TRUE(un1.has_value()) + << "first UnPinKey failed: " << static_cast(un1.error()); + + auto pin2 = rpc_service_->PinKey(pin_req); + ASSERT_TRUE(pin2.has_value()) << "second PinKey after unpin failed: " + << static_cast(pin2.error()); + EXPECT_NE(pin1->pin_token, pin2->pin_token); + + UnPinKeyRequest unpin2; + unpin2.key = key; + unpin2.pin_token = pin2->pin_token; + auto un2 = rpc_service_->UnPinKey(unpin2); + ASSERT_TRUE(un2.has_value()) + << "second UnPinKey failed: " << static_cast(un2.error()); +} + +TEST_F(ClientRpcServiceTest, UnPinKeyEmptyKey) { + UnPinKeyRequest req; + req.key = ""; + req.pin_token = {1, 2}; + + auto result = rpc_service_->UnPinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(ClientRpcServiceTest, UnPinKeyZeroToken) { + UnPinKeyRequest req; + req.key = "rpc_svc_unpin_zero"; + req.pin_token = {0, 0}; + + auto result = rpc_service_->UnPinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(ClientRpcServiceTest, UnPinKeyWrongTokenAfterPin) { + const std::string key = "rpc_svc_unpin_wrong_token"; + const std::string blob = "x"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + auto pin_res = rpc_service_->PinKey(pin_req); + ASSERT_TRUE(pin_res.has_value()); + + UnPinKeyRequest bad; + bad.key = key; + bad.pin_token = pin_res->pin_token; + bad.pin_token.first += 1; + auto bad_res = rpc_service_->UnPinKey(bad); + ASSERT_FALSE(bad_res.has_value()); + EXPECT_EQ(bad_res.error(), ErrorCode::INVALID_READ); + + UnPinKeyRequest ok; + ok.key = key; + ok.pin_token = pin_res->pin_token; + auto ok_res = rpc_service_->UnPinKey(ok); + ASSERT_TRUE(ok_res.has_value()) << "UnPinKey with correct token failed: " + << static_cast(ok_res.error()); } // ============================================================================ @@ -289,4 +470,96 @@ TEST_F(ClientRpcServiceTest, WriteRemoteDataWithTierId) { EXPECT_EQ(result.error(), ErrorCode::INTERNAL_ERROR); } +// ============================================================================ +// WriteRevoke +// ============================================================================ + +TEST_F(ClientRpcServiceTest, WriteRevokeInvalidKey) { + WriteRevokeRequest request; + request.key = ""; + request.pending_write_token = {1, 0}; + + auto result = rpc_service_->WriteRevoke(request); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(ClientRpcServiceTest, WriteRevokeInvalidZeroToken) { + const std::string key = "rpc_svc_revoke_zero_token"; + WriteRevokeRequest request; + request.key = key; + request.pending_write_token = {0, 0}; + + auto result = rpc_service_->WriteRevoke(request); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(ClientRpcServiceTest, WriteRevokeIdempotentNoPendingRecord) { + const std::string key = "rpc_svc_revoke_no_pending"; + WriteRevokeRequest request; + request.key = key; + request.pending_write_token = {100, 200}; + + auto result = rpc_service_->WriteRevoke(request); + ASSERT_TRUE(result.has_value()) + << "WriteRevoke on missing pending should be OK (idempotent)"; +} + +TEST_F(ClientRpcServiceTest, WriteRevokeAfterPreWrite) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "rpc_svc_revoke_after_prewrite"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 256; + pre.target_tier_id = tier_id; + + auto pre_res = rpc_service_->PreWrite(pre); + ASSERT_TRUE(pre_res.has_value()) + << "PreWrite failed: " << static_cast(pre_res.error()); + + WriteRevokeRequest revoke; + revoke.key = key; + revoke.pending_write_token = pre_res->pending_write_token; + auto rev_res = rpc_service_->WriteRevoke(revoke); + ASSERT_TRUE(rev_res.has_value()) + << "WriteRevoke failed: " << static_cast(rev_res.error()); + + auto again = rpc_service_->WriteRevoke(revoke); + ASSERT_TRUE(again.has_value()) + << "Second WriteRevoke on same key/token should be idempotent OK"; +} + +TEST_F(ClientRpcServiceTest, WriteRevokeTokenMismatch) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "rpc_svc_revoke_token_mismatch"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 64; + pre.target_tier_id = tier_id; + + auto pre_res = rpc_service_->PreWrite(pre); + ASSERT_TRUE(pre_res.has_value()); + + UUID wrong_token = pre_res->pending_write_token; + wrong_token.first += 1; + + WriteRevokeRequest bad; + bad.key = key; + bad.pending_write_token = wrong_token; + auto bad_res = rpc_service_->WriteRevoke(bad); + ASSERT_FALSE(bad_res.has_value()); + EXPECT_EQ(bad_res.error(), ErrorCode::INVALID_WRITE); + + WriteRevokeRequest good; + good.key = key; + good.pending_write_token = pre_res->pending_write_token; + auto good_res = rpc_service_->WriteRevoke(good); + ASSERT_TRUE(good_res.has_value()); +} + } // namespace mooncake diff --git a/mooncake-store/tests/data_manager_test.cpp b/mooncake-store/tests/data_manager_test.cpp index 18a189a921..1849cdf7fd 100644 --- a/mooncake-store/tests/data_manager_test.cpp +++ b/mooncake-store/tests/data_manager_test.cpp @@ -284,6 +284,257 @@ TEST_F(DataManagerTest, DeleteWithTierId) { ASSERT_FALSE(data_manager_->Exist(key)); } +// Test PreWrite: concurrent PreWrite should be rejected by a pending lease. +TEST_F(DataManagerTest, PreWriteRejectsConcurrentLease) { + const std::string key = "prewrite_lifecycle_key"; + auto prewrite_result = data_manager_->PreWrite(key, 256, GetTierId()); + ASSERT_TRUE(prewrite_result.has_value()) + << "PreWrite failed: " << toString(prewrite_result.error()); + + auto second_prewrite = data_manager_->PreWrite(key, 256, GetTierId()); + ASSERT_FALSE(second_prewrite.has_value()); + EXPECT_EQ(second_prewrite.error(), ErrorCode::OBJECT_HAS_LEASE); + + auto& shard = data_manager_->GetPendingWriteShard(key); + { + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + EXPECT_EQ(it->second.pending_write_token, + prewrite_result->pending_write_token); + } +} + +// Test WriteCommit: successful commit should erase the pending write record. +TEST_F(DataManagerTest, WriteCommitErasesPendingWriteRecord) { + const std::string key = "write_commit_erases_record_key"; + auto prewrite_result = data_manager_->PreWrite(key, 256, GetTierId()); + ASSERT_TRUE(prewrite_result.has_value()) + << "PreWrite failed: " << toString(prewrite_result.error()); + + auto commit_result = + data_manager_->WriteCommit(key, prewrite_result->pending_write_token); + ASSERT_TRUE(commit_result.has_value()) + << "WriteCommit failed: " << toString(commit_result.error()); + EXPECT_TRUE(data_manager_->Exist(key)); + + auto& shard = data_manager_->GetPendingWriteShard(key); + { + std::shared_lock shard_lock(shard.mutex); + EXPECT_EQ(shard.by_key.count(key), 0U); + } +} + +// Test WriteCommit: token mismatch should fail without erasing the record. +TEST_F(DataManagerTest, WriteCommitTokenMismatchKeepsPendingWriteRecord) { + const std::string key = "write_commit_token_mismatch_key"; + auto prewrite_result = data_manager_->PreWrite(key, 256, GetTierId()); + ASSERT_TRUE(prewrite_result.has_value()) + << "PreWrite failed: " << toString(prewrite_result.error()); + + UUID wrong_token = prewrite_result->pending_write_token; + wrong_token.first += 1; + auto wrong_commit = data_manager_->WriteCommit(key, wrong_token); + ASSERT_FALSE(wrong_commit.has_value()); + EXPECT_EQ(wrong_commit.error(), ErrorCode::INVALID_WRITE); + + auto& shard = data_manager_->GetPendingWriteShard(key); + { + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + EXPECT_EQ(it->second.pending_write_token, + prewrite_result->pending_write_token); + } +} + +// WriteRevoke removes pending allocation without committing the object. +TEST_F(DataManagerTest, WriteRevokeErasesPendingWriteRecord) { + const std::string key = "write_revoke_erases_record_key"; + auto prewrite_result = data_manager_->PreWrite(key, 256, GetTierId()); + ASSERT_TRUE(prewrite_result.has_value()) + << "PreWrite failed: " << toString(prewrite_result.error()); + + auto revoke_result = + data_manager_->WriteRevoke(key, prewrite_result->pending_write_token); + ASSERT_TRUE(revoke_result.has_value()) + << "WriteRevoke failed: " << toString(revoke_result.error()); + + auto& shard = data_manager_->GetPendingWriteShard(key); + { + std::shared_lock shard_lock(shard.mutex); + EXPECT_EQ(shard.by_key.count(key), 0U); + } +} + +// WriteRevoke: wrong token should not erase the pending record. +TEST_F(DataManagerTest, WriteRevokeTokenMismatchKeepsPendingWriteRecord) { + const std::string key = "write_revoke_token_mismatch_key"; + auto prewrite_result = data_manager_->PreWrite(key, 256, GetTierId()); + ASSERT_TRUE(prewrite_result.has_value()) + << "PreWrite failed: " << toString(prewrite_result.error()); + + UUID wrong_token = prewrite_result->pending_write_token; + wrong_token.first += 1; + auto wrong_revoke = data_manager_->WriteRevoke(key, wrong_token); + ASSERT_FALSE(wrong_revoke.has_value()); + EXPECT_EQ(wrong_revoke.error(), ErrorCode::INVALID_WRITE); + + auto& shard = data_manager_->GetPendingWriteShard(key); + { + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + EXPECT_EQ(it->second.pending_write_token, + prewrite_result->pending_write_token); + } +} + +// Test Pin/Unpin: ref_count increments on PinKey and reaches zero on final +// UnPinKey. +TEST_F(DataManagerTest, PinKeyTracksRefCountUntilFinalUnpin) { + const std::string key = "pin_ref_count_key"; + const std::string test_data = "Pin key ref count payload"; + + auto buffer = StringToBuffer(test_data); + ASSERT_TRUE(DoPut(key, buffer.get(), test_data.size()).has_value()); + + auto first_pin = data_manager_->PinKey(key, GetTierId()); + ASSERT_TRUE(first_pin.has_value()) + << "First PinKey failed: " << toString(first_pin.error()); + + auto second_pin = data_manager_->PinKey(key, GetTierId()); + ASSERT_TRUE(second_pin.has_value()) + << "Second PinKey failed: " << toString(second_pin.error()); + EXPECT_EQ(first_pin->pin_token, second_pin->pin_token); + + auto& shard = data_manager_->GetPinnedKeyShard(key); + { + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + EXPECT_EQ(it->second.ref_count, 2U); + } + + auto first_unpin = data_manager_->UnPinKey(key, first_pin->pin_token); + ASSERT_TRUE(first_unpin.has_value()) + << "First UnPinKey failed: " << toString(first_unpin.error()); + { + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + EXPECT_EQ(it->second.ref_count, 1U); + } + + auto second_unpin = data_manager_->UnPinKey(key, second_pin->pin_token); + ASSERT_TRUE(second_unpin.has_value()) + << "Second UnPinKey failed: " << toString(second_unpin.error()); + { + std::shared_lock shard_lock(shard.mutex); + EXPECT_EQ(shard.by_key.count(key), 0U); + } +} + +// Test UnPinKey: token mismatch should fail without erasing or decrementing the +// record. +TEST_F(DataManagerTest, UnPinKeyTokenMismatchKeepsPinnedRecord) { + const std::string key = "unpin_token_mismatch_key"; + const std::string test_data = "Unpin token mismatch payload"; + + auto buffer = StringToBuffer(test_data); + ASSERT_TRUE(DoPut(key, buffer.get(), test_data.size()).has_value()); + + auto pin_result = data_manager_->PinKey(key, GetTierId()); + ASSERT_TRUE(pin_result.has_value()) + << "PinKey failed: " << toString(pin_result.error()); + + UUID wrong_token = pin_result->pin_token; + wrong_token.first += 1; + auto wrong_unpin = data_manager_->UnPinKey(key, wrong_token); + ASSERT_FALSE(wrong_unpin.has_value()); + EXPECT_EQ(wrong_unpin.error(), ErrorCode::INVALID_READ); + + auto& shard = data_manager_->GetPinnedKeyShard(key); + { + std::shared_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + EXPECT_EQ(it->second.ref_count, 1U); + EXPECT_EQ(it->second.pin_token, pin_result->pin_token); + } +} + +// Test WriteCommit: lease expiry causes commit failure and allows subsequent +// PreWrite. +TEST_F(DataManagerTest, + WriteCommitFailsAfterLeaseExpiryAndAllowsRetryPreWrite) { + const std::string key = "expired_prewrite_key"; + auto prewrite_result = data_manager_->PreWrite(key, 512, GetTierId()); + ASSERT_TRUE(prewrite_result.has_value()) + << "PreWrite failed: " << toString(prewrite_result.error()); + + auto& shard = data_manager_->GetPendingWriteShard(key); + { + std::unique_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + const auto expired_deadline = + std::chrono::steady_clock::now() - std::chrono::milliseconds(1); + it->second.deadline = expired_deadline; + it->second.list_it->second = expired_deadline; + } + + auto commit_result = + data_manager_->WriteCommit(key, prewrite_result->pending_write_token); + ASSERT_FALSE(commit_result.has_value()); + EXPECT_EQ(commit_result.error(), ErrorCode::LEASE_EXPIRED); + + { + std::shared_lock shard_lock(shard.mutex); + EXPECT_EQ(shard.by_key.count(key), 0U); + } + + auto retry_prewrite = data_manager_->PreWrite(key, 512, GetTierId()); + ASSERT_TRUE(retry_prewrite.has_value()) + << "PreWrite should succeed after expired record is cleaned: " + << toString(retry_prewrite.error()); +} + +// Test UnPinKey: lease expiry causes unpin failure and cleans up the pin +// record. +TEST_F(DataManagerTest, + ExpiredPinnedLeaseBlocksUnpinButDeleteCanProceedAfterCleanup) { + const std::string key = "expired_pin_key"; + const std::string test_data = "Expired pin payload"; + + auto buffer = StringToBuffer(test_data); + ASSERT_TRUE(DoPut(key, buffer.get(), test_data.size()).has_value()); + + auto pin_result = data_manager_->PinKey(key, GetTierId()); + ASSERT_TRUE(pin_result.has_value()) + << "PinKey failed: " << toString(pin_result.error()); + + auto& shard = data_manager_->GetPinnedKeyShard(key); + { + std::unique_lock shard_lock(shard.mutex); + auto it = shard.by_key.find(key); + ASSERT_NE(it, shard.by_key.end()); + const auto expired_deadline = + std::chrono::steady_clock::now() - std::chrono::milliseconds(1); + it->second.deadline = expired_deadline; + it->second.list_it->second = expired_deadline; + } + + auto unpin_result = data_manager_->UnPinKey(key, pin_result->pin_token); + ASSERT_FALSE(unpin_result.has_value()); + EXPECT_EQ(unpin_result.error(), ErrorCode::LEASE_EXPIRED); + + { + std::shared_lock shard_lock(shard.mutex); + EXPECT_EQ(shard.by_key.count(key), 0U); + } +} + // Test concurrent Put operations TEST_F(DataManagerTest, ConcurrentPut) { const int num_keys = 10; diff --git a/mooncake-store/tests/p2p_client_integration_test.cpp b/mooncake-store/tests/p2p_client_integration_test.cpp index 32d711331f..f6d52e0573 100644 --- a/mooncake-store/tests/p2p_client_integration_test.cpp +++ b/mooncake-store/tests/p2p_client_integration_test.cpp @@ -622,5 +622,49 @@ TEST_F(P2PClientIntegrationTest, LocalGetBufferHandleWithTeTransferMode) { EXPECT_TRUE(unreg_src.has_value()); } +TEST_F(P2PClientIntegrationTest, ForwardRemotePutAndGet) { + const std::vector transfer_modes = {"te", "memcpy"}; + for (const auto& mode : transfer_modes) { + SCOPED_TRACE("local_transfer_mode=" + mode); + + std::string host = "localhost:" + std::to_string(getFreeTcpPort()); + auto remote_writer = CreateP2PClient(host, /*rpc_port=*/0, mode); + ASSERT_NE(remote_writer, nullptr); + + const std::string key = "p2p_fwd_put_get_" + mode + "_" + host; + const std::string payload = "forward_payload_" + mode + "_data"; + + WriteRouteRequestConfig route; + route.allow_local = false; + route.prefer_local = false; + route.max_candidates = WriteRouteRequestConfig::RETURN_ALL_CANDIDATES; + + WriteConfigExt wcfg(route); + wcfg.rdma_direction_mode = RdmaDirectionMode::FORWARD; + + std::vector slices; + slices.emplace_back( + Slice{const_cast(payload.data()), payload.size()}); + auto put_res = remote_writer->Put(key, slices, wcfg); + ASSERT_TRUE(put_res.has_value()) + << "Forward Put failed mode=" << mode + << " err=" << static_cast(put_res.error()); + + ReadConfigExt rcfg; + rcfg.route_config.max_candidates = + GetReplicaListRequestConfig::RETURN_ALL_CANDIDATES; + rcfg.rdma_direction_mode = RdmaDirectionMode::FORWARD; + + std::vector buf(payload.size(), 0); + auto get_res = + remote_writer->Get(key, {(void*)buf.data()}, {buf.size()}, rcfg); + ASSERT_TRUE(get_res.has_value()) + << "Forward Get failed mode=" << mode + << " err=" << static_cast(get_res.error()); + EXPECT_EQ(static_cast(get_res.value()), payload.size()); + EXPECT_EQ(std::string(buf.data(), buf.size()), payload); + } +} + } // namespace testing } // namespace mooncake diff --git a/mooncake-store/tests/peer_client_test.cpp b/mooncake-store/tests/peer_client_test.cpp index e931b5f829..49b281b271 100644 --- a/mooncake-store/tests/peer_client_test.cpp +++ b/mooncake-store/tests/peer_client_test.cpp @@ -272,6 +272,201 @@ TEST_F(PeerClientTest, AsyncReadRemoteDataWithExistingKey) { EXPECT_EQ(result.error(), ErrorCode::INTERNAL_ERROR); } +// ============================================================================ +// PinKey / UnPinKey (async) — forward read control plane over PeerClient +// ============================================================================ + +TEST_F(PeerClientTest, AsyncPinKeyEmptyKey) { + PinKeyRequest req; + req.key = ""; + req.target_tier_id = std::nullopt; + + auto result = async_simple::coro::syncAwait(peer_client_->AsyncPinKey(req)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncPinKeyKeyNotFound) { + PinKeyRequest req; + req.key = "peer_async_pin_missing_key"; + req.target_tier_id = std::nullopt; + + auto result = async_simple::coro::syncAwait(peer_client_->AsyncPinKey(req)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::OBJECT_NOT_FOUND); +} + +TEST_F(PeerClientTest, AsyncPinKeyAfterPut) { + const std::string key = "peer_async_pin_after_put"; + const std::string blob = "payload"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()) << "Put failed"; + put.value()->Wait(); + + PinKeyRequest req; + req.key = key; + req.target_tier_id = std::nullopt; + + auto pin_res = + async_simple::coro::syncAwait(peer_client_->AsyncPinKey(req)); + ASSERT_TRUE(pin_res.has_value()) + << "AsyncPinKey failed: " << static_cast(pin_res.error()); + EXPECT_GT(pin_res->remote_buffer.size, 0u); + EXPECT_NE(pin_res->pin_token.first, 0u); + EXPECT_NE(pin_res->pin_token.second, 0u); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount via RPC. + UnPinKeyRequest unpin; + unpin.key = key; + unpin.pin_token = pin_res->pin_token; + auto unpin_res = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(unpin)); + ASSERT_TRUE(unpin_res.has_value()) + << "AsyncUnPinKey failed: " << static_cast(unpin_res.error()); +} + +TEST_F(PeerClientTest, AsyncPinKeyTwiceSameTokenThenUnpinTwice) { + const std::string key = "peer_async_pin_twice_ref"; + const std::string blob = "ref"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + + auto first = + async_simple::coro::syncAwait(peer_client_->AsyncPinKey(pin_req)); + ASSERT_TRUE(first.has_value()) + << "first AsyncPinKey failed: " << static_cast(first.error()); + auto second = + async_simple::coro::syncAwait(peer_client_->AsyncPinKey(pin_req)); + ASSERT_TRUE(second.has_value()) + << "second AsyncPinKey failed: " << static_cast(second.error()); + + EXPECT_EQ(first->pin_token, second->pin_token); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount via RPC. + UnPinKeyRequest unpin; + unpin.key = key; + unpin.pin_token = first->pin_token; + auto u1 = async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(unpin)); + ASSERT_TRUE(u1.has_value()) + << "first AsyncUnPinKey failed: " << static_cast(u1.error()); + + auto u2 = async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(unpin)); + ASSERT_TRUE(u2.has_value()) + << "second AsyncUnPinKey failed: " << static_cast(u2.error()); +} + +TEST_F(PeerClientTest, AsyncPinKeyAfterUnpinNewToken) { + const std::string key = "peer_async_pin_new_token_after_unpin"; + const std::string blob = "tok"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + + auto pin1 = + async_simple::coro::syncAwait(peer_client_->AsyncPinKey(pin_req)); + ASSERT_TRUE(pin1.has_value()) + << "first AsyncPinKey failed: " << static_cast(pin1.error()); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount via RPC. + UnPinKeyRequest unpin1; + unpin1.key = key; + unpin1.pin_token = pin1->pin_token; + auto un1 = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(unpin1)); + ASSERT_TRUE(un1.has_value()) + << "first AsyncUnPinKey failed: " << static_cast(un1.error()); + + auto pin2 = + async_simple::coro::syncAwait(peer_client_->AsyncPinKey(pin_req)); + ASSERT_TRUE(pin2.has_value()) << "second AsyncPinKey after unpin failed: " + << static_cast(pin2.error()); + EXPECT_NE(pin1->pin_token, pin2->pin_token); + + UnPinKeyRequest unpin2; + unpin2.key = key; + unpin2.pin_token = pin2->pin_token; + auto un2 = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(unpin2)); + ASSERT_TRUE(un2.has_value()) + << "second AsyncUnPinKey failed: " << static_cast(un2.error()); +} + +TEST_F(PeerClientTest, AsyncUnPinKeyEmptyKey) { + UnPinKeyRequest req; + req.key = ""; + req.pin_token = {1, 2}; + + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(req)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncUnPinKeyZeroToken) { + const std::string key = "peer_async_unpin_zero_token"; + UnPinKeyRequest req; + req.key = key; + req.pin_token = {0, 0}; + + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(req)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncUnPinKeyWrongTokenAfterPin) { + const std::string key = "peer_async_unpin_wrong_token"; + const std::string blob = "x"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + auto pin_res = + async_simple::coro::syncAwait(peer_client_->AsyncPinKey(pin_req)); + ASSERT_TRUE(pin_res.has_value()); + + UnPinKeyRequest bad; + bad.key = key; + bad.pin_token = pin_res->pin_token; + bad.pin_token.first += 1; + auto bad_res = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(bad)); + ASSERT_FALSE(bad_res.has_value()); + EXPECT_EQ(bad_res.error(), ErrorCode::INVALID_READ); + + UnPinKeyRequest ok; + ok.key = key; + ok.pin_token = pin_res->pin_token; + auto ok_res = + async_simple::coro::syncAwait(peer_client_->AsyncUnPinKey(ok)); + ASSERT_TRUE(ok_res.has_value()) + << "AsyncUnPinKey with correct token failed: " + << static_cast(ok_res.error()); +} + // ============================================================================ // AsyncWriteRemoteData Tests // ============================================================================ @@ -354,6 +549,215 @@ TEST_F(PeerClientTest, AsyncWriteRemoteDataWithTierId) { EXPECT_EQ(result.error(), ErrorCode::INTERNAL_ERROR); } +// ============================================================================ +// PreWrite / WriteCommit / WriteRevoke (async) +// ============================================================================ + +TEST_F(PeerClientTest, AsyncPreWriteEmptyKey) { + PreWriteRequest pre; + pre.key = ""; + pre.size_bytes = 64; + // target_tier_id optional; invalid key fails before tier selection. + + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncPreWrite(pre)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncPreWriteZeroSize) { + const std::string key = "peer_async_pre_zero_size"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 0; + + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncPreWrite(pre)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncPreWriteValidRequest) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "peer_async_pre_valid"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 256; + pre.target_tier_id = tier_id; + + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncPreWrite(pre)); + ASSERT_TRUE(result.has_value()) + << "AsyncPreWrite failed: " << static_cast(result.error()); + EXPECT_GT(result->remote_buffer.size, 0u); + EXPECT_NE(result->pending_write_token.first, 0u); + EXPECT_NE(result->pending_write_token.second, 0u); + + WriteRevokeRequest revoke; + revoke.key = key; + revoke.pending_write_token = result->pending_write_token; + auto rev = + async_simple::coro::syncAwait(peer_client_->AsyncWriteRevoke(revoke)); + ASSERT_TRUE(rev.has_value()) + << "Cleanup AsyncWriteRevoke failed: " << static_cast(rev.error()); +} + +TEST_F(PeerClientTest, AsyncPreWriteWhenObjectAlreadyExists) { + const std::string key = "peer_async_pre_key_exists"; + const std::string blob = "existing"; + auto buffer = StringToBuffer(blob); + std::vector put_slices{{buffer.get(), blob.size()}}; + auto put_result = data_manager_->Put(key, put_slices); + ASSERT_TRUE(put_result.has_value()) << "Put failed"; + put_result.value()->Wait(); + + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 128; + + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncPreWrite(pre)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::OBJECT_ALREADY_EXISTS); +} + +TEST_F(PeerClientTest, AsyncWriteCommitAfterPreWrite) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "peer_async_commit_after_pre"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 256; + pre.target_tier_id = tier_id; + + auto pre_res = + async_simple::coro::syncAwait(peer_client_->AsyncPreWrite(pre)); + ASSERT_TRUE(pre_res.has_value()) + << "AsyncPreWrite failed: " << static_cast(pre_res.error()); + + // TransferEngine is not initialized in this unit test, so no real + // data-plane write occurs. Assume the access side has already filled the + // buffer via TE; this case only checks WriteCommit RPC / metadata outcome. + WriteCommitRequest commit; + commit.key = key; + commit.pending_write_token = pre_res->pending_write_token; + auto commit_res = + async_simple::coro::syncAwait(peer_client_->AsyncWriteCommit(commit)); + ASSERT_TRUE(commit_res.has_value()) + << "AsyncWriteCommit failed: " << static_cast(commit_res.error()); +} + +TEST_F(PeerClientTest, AsyncWriteCommitEmptyKey) { + WriteCommitRequest commit; + commit.key = ""; + commit.pending_write_token = {1, 2}; + + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncWriteCommit(commit)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncWriteCommitZeroToken) { + const std::string key = "peer_async_commit_zero_token"; + WriteCommitRequest commit; + commit.key = key; + commit.pending_write_token = {0, 0}; + + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncWriteCommit(commit)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncWriteCommitTokenMismatchAfterPreWrite) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "peer_async_commit_token_bad"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 128; + pre.target_tier_id = tier_id; + + auto pre_res = + async_simple::coro::syncAwait(peer_client_->AsyncPreWrite(pre)); + ASSERT_TRUE(pre_res.has_value()); + + WriteCommitRequest commit; + commit.key = key; + UUID wrong_token = pre_res->pending_write_token; + wrong_token.first += 1; + commit.pending_write_token = wrong_token; + + auto bad = + async_simple::coro::syncAwait(peer_client_->AsyncWriteCommit(commit)); + ASSERT_FALSE(bad.has_value()); + EXPECT_EQ(bad.error(), ErrorCode::INVALID_WRITE); + + WriteRevokeRequest revoke; + revoke.key = key; + revoke.pending_write_token = pre_res->pending_write_token; + auto rev = + async_simple::coro::syncAwait(peer_client_->AsyncWriteRevoke(revoke)); + ASSERT_TRUE(rev.has_value()); +} + +TEST_F(PeerClientTest, AsyncWriteRevokeEmptyKey) { + WriteRevokeRequest request; + request.key = ""; + request.pending_write_token = {1, 2}; + + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncWriteRevoke(request)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncWriteRevokeZeroToken) { + const std::string key = "peer_async_revoke_zero_token"; + WriteRevokeRequest request; + request.key = key; + request.pending_write_token = {0, 0}; + + auto result = + async_simple::coro::syncAwait(peer_client_->AsyncWriteRevoke(request)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, AsyncPreWriteThenWriteRevokeIdempotent) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "peer_async_revoke_after_prewrite"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 256; + pre.target_tier_id = tier_id; + + auto pre_res = + async_simple::coro::syncAwait(peer_client_->AsyncPreWrite(pre)); + ASSERT_TRUE(pre_res.has_value()) + << "AsyncPreWrite failed: " << static_cast(pre_res.error()); + + WriteRevokeRequest revoke; + revoke.key = key; + revoke.pending_write_token = pre_res->pending_write_token; + auto rev_res = + async_simple::coro::syncAwait(peer_client_->AsyncWriteRevoke(revoke)); + ASSERT_TRUE(rev_res.has_value()) + << "AsyncWriteRevoke failed: " << static_cast(rev_res.error()); + + auto idem = + async_simple::coro::syncAwait(peer_client_->AsyncWriteRevoke(revoke)); + ASSERT_TRUE(idem.has_value()) + << "Second AsyncWriteRevoke should be idempotent OK"; +} + // ============================================================================ // Sync ReadRemoteData Tests (wrappers around async) // ============================================================================ @@ -412,6 +816,133 @@ TEST_F(PeerClientTest, SyncReadRemoteDataWithExistingKey) { EXPECT_EQ(result.error(), ErrorCode::INTERNAL_ERROR); } +// ============================================================================ +// Sync PinKey / UnPinKey (wrappers around async) +// ============================================================================ + +TEST_F(PeerClientTest, SyncPinKeyEmptyKey) { + PinKeyRequest req; + req.key = ""; + req.target_tier_id = std::nullopt; + + auto result = peer_client_->PinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, SyncPinKeyAfterPut) { + const std::string key = "peer_sync_pin_after_put"; + const std::string blob = "sync-payload"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest req; + req.key = key; + req.target_tier_id = std::nullopt; + + auto pin_res = peer_client_->PinKey(req); + ASSERT_TRUE(pin_res.has_value()) + << "PinKey failed: " << static_cast(pin_res.error()); + EXPECT_GT(pin_res->remote_buffer.size, 0u); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount via RPC. + UnPinKeyRequest unpin; + unpin.key = key; + unpin.pin_token = pin_res->pin_token; + auto unpin_res = peer_client_->UnPinKey(unpin); + ASSERT_TRUE(unpin_res.has_value()) + << "UnPinKey failed: " << static_cast(unpin_res.error()); +} + +TEST_F(PeerClientTest, SyncPinKeyTwiceSameTokenThenUnpinTwice) { + const std::string key = "peer_sync_pin_twice_ref"; + const std::string blob = "ref"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + + auto first = peer_client_->PinKey(pin_req); + ASSERT_TRUE(first.has_value()) + << "first PinKey failed: " << static_cast(first.error()); + auto second = peer_client_->PinKey(pin_req); + ASSERT_TRUE(second.has_value()) + << "second PinKey failed: " << static_cast(second.error()); + + EXPECT_EQ(first->pin_token, second->pin_token); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount via RPC. + UnPinKeyRequest unpin; + unpin.key = key; + unpin.pin_token = first->pin_token; + auto u1 = peer_client_->UnPinKey(unpin); + ASSERT_TRUE(u1.has_value()) + << "first UnPinKey failed: " << static_cast(u1.error()); + + auto u2 = peer_client_->UnPinKey(unpin); + ASSERT_TRUE(u2.has_value()) + << "second UnPinKey failed: " << static_cast(u2.error()); +} + +TEST_F(PeerClientTest, SyncPinKeyAfterUnpinNewToken) { + const std::string key = "peer_sync_pin_new_token_after_unpin"; + const std::string blob = "tok"; + auto buf = StringToBuffer(blob); + std::vector slices{{buf.get(), blob.size()}}; + auto put = data_manager_->Put(key, slices); + ASSERT_TRUE(put.has_value()); + put.value()->Wait(); + + PinKeyRequest pin_req; + pin_req.key = key; + pin_req.target_tier_id = std::nullopt; + + auto pin1 = peer_client_->PinKey(pin_req); + ASSERT_TRUE(pin1.has_value()) + << "first PinKey failed: " << static_cast(pin1.error()); + + // TransferEngine is not initialized in this unit test, so no TE read + // occurs; UnPinKey only drives DataManager pin refcount via RPC. + UnPinKeyRequest unpin1; + unpin1.key = key; + unpin1.pin_token = pin1->pin_token; + auto un1 = peer_client_->UnPinKey(unpin1); + ASSERT_TRUE(un1.has_value()) + << "first UnPinKey failed: " << static_cast(un1.error()); + + auto pin2 = peer_client_->PinKey(pin_req); + ASSERT_TRUE(pin2.has_value()) << "second PinKey after unpin failed: " + << static_cast(pin2.error()); + EXPECT_NE(pin1->pin_token, pin2->pin_token); + + UnPinKeyRequest unpin2; + unpin2.key = key; + unpin2.pin_token = pin2->pin_token; + auto un2 = peer_client_->UnPinKey(unpin2); + ASSERT_TRUE(un2.has_value()) + << "second UnPinKey failed: " << static_cast(un2.error()); +} + +TEST_F(PeerClientTest, SyncUnPinKeyZeroToken) { + UnPinKeyRequest req; + req.key = "peer_sync_unpin_zero"; + req.pin_token = {0, 0}; + + auto result = peer_client_->UnPinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + // ============================================================================ // Sync WriteRemoteData Tests (wrappers around async) // ============================================================================ @@ -449,6 +980,87 @@ TEST_F(PeerClientTest, SyncWriteRemoteDataValidRequest) { EXPECT_EQ(result.error(), ErrorCode::INTERNAL_ERROR); } +// ============================================================================ +// Sync PreWrite / WriteCommit / WriteRevoke +// ============================================================================ + +TEST_F(PeerClientTest, SyncPreWriteEmptyKey) { + PreWriteRequest pre; + pre.key = ""; + pre.size_bytes = 32; + + auto result = peer_client_->PreWrite(pre); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, SyncWriteCommitEmptyKey) { + WriteCommitRequest commit; + commit.key = ""; + commit.pending_write_token = {5, 6}; + + auto result = peer_client_->WriteCommit(commit); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, SyncWriteCommitAfterPreWrite) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "peer_sync_commit_after_pre"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 128; + pre.target_tier_id = tier_id; + + auto pre_res = peer_client_->PreWrite(pre); + ASSERT_TRUE(pre_res.has_value()) + << "PreWrite failed: " << static_cast(pre_res.error()); + + // TransferEngine is not initialized in this unit test, so no real + // data-plane write occurs. Assume the access side has already filled the + // buffer via TE; this case only checks WriteCommit RPC / metadata outcome. + WriteCommitRequest commit; + commit.key = key; + commit.pending_write_token = pre_res->pending_write_token; + auto commit_res = peer_client_->WriteCommit(commit); + ASSERT_TRUE(commit_res.has_value()) + << "WriteCommit failed: " << static_cast(commit_res.error()); +} + +TEST_F(PeerClientTest, SyncWriteRevokeEmptyKey) { + WriteRevokeRequest request; + request.key = ""; + request.pending_write_token = {3, 4}; + + auto result = peer_client_->WriteRevoke(request); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::INVALID_PARAMS); +} + +TEST_F(PeerClientTest, SyncWriteRevokeAfterPreWrite) { + auto tier_id = GetTierId(); + ASSERT_TRUE(tier_id.has_value()) << "No tier available"; + + const std::string key = "peer_sync_revoke_after_prewrite"; + PreWriteRequest pre; + pre.key = key; + pre.size_bytes = 128; + pre.target_tier_id = tier_id; + + auto pre_res = peer_client_->PreWrite(pre); + ASSERT_TRUE(pre_res.has_value()) + << "PreWrite failed: " << static_cast(pre_res.error()); + + WriteRevokeRequest revoke; + revoke.key = key; + revoke.pending_write_token = pre_res->pending_write_token; + auto rev_res = peer_client_->WriteRevoke(revoke); + ASSERT_TRUE(rev_res.has_value()) + << "WriteRevoke failed: " << static_cast(rev_res.error()); +} + // ============================================================================ // Not Connected Error Tests // ============================================================================ @@ -509,4 +1121,79 @@ TEST_F(PeerClientTest, SyncWriteWithoutConnect) { EXPECT_EQ(result.error(), ErrorCode::RPC_FAIL); } +TEST_F(PeerClientTest, AsyncPinKeyWithoutConnect) { + PeerClient unconnected_client; + + PinKeyRequest req; + req.key = "k"; + req.target_tier_id = std::nullopt; + + auto result = + async_simple::coro::syncAwait(unconnected_client.AsyncPinKey(req)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::RPC_FAIL); +} + +TEST_F(PeerClientTest, AsyncUnPinKeyWithoutConnect) { + PeerClient unconnected_client; + + UnPinKeyRequest req; + req.key = "k"; + req.pin_token = {1, 1}; + + auto result = + async_simple::coro::syncAwait(unconnected_client.AsyncUnPinKey(req)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::RPC_FAIL); +} + +TEST_F(PeerClientTest, SyncPinKeyWithoutConnect) { + PeerClient unconnected_client; + + PinKeyRequest req; + req.key = "k"; + req.target_tier_id = std::nullopt; + + auto result = unconnected_client.PinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::RPC_FAIL); +} + +TEST_F(PeerClientTest, SyncUnPinKeyWithoutConnect) { + PeerClient unconnected_client; + + UnPinKeyRequest req; + req.key = "k"; + req.pin_token = {1, 1}; + + auto result = unconnected_client.UnPinKey(req); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::RPC_FAIL); +} + +TEST_F(PeerClientTest, AsyncWriteRevokeWithoutConnect) { + PeerClient unconnected_client; + + WriteRevokeRequest request; + request.key = "test_key"; + request.pending_write_token = {1, 1}; + + auto result = async_simple::coro::syncAwait( + unconnected_client.AsyncWriteRevoke(request)); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::RPC_FAIL); +} + +TEST_F(PeerClientTest, SyncWriteRevokeWithoutConnect) { + PeerClient unconnected_client; + + WriteRevokeRequest request; + request.key = "test_key"; + request.pending_write_token = {2, 2}; + + auto result = unconnected_client.WriteRevoke(request); + ASSERT_FALSE(result.has_value()); + EXPECT_EQ(result.error(), ErrorCode::RPC_FAIL); +} + } // namespace mooncake