Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 56 additions & 29 deletions mooncake-integration/store/store_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "pyclient.h"
#include "dummy_client.h"
#include "real_client.h"
#include "p2p_rpc_types.h"
#include "rpc_types.h"

#include <cstdlib> // for atexit
#include <map>
Expand Down Expand Up @@ -231,8 +233,8 @@ class MooncakeStorePyWrapper {

pybind11::bytes get(
const std::string& key,
const std::optional<ReadRouteConfig>& config_opt = std::nullopt) {
ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{});
const std::optional<ReadConfigExt>& config_opt = std::nullopt) {
ReadConfigExt config = config_opt.value_or(ReadConfigExt{});
if (!is_client_initialized()) {
LOG(ERROR) << "Client is not initialized";
return pybind11::bytes("\\0", 0);
Expand Down Expand Up @@ -268,8 +270,8 @@ class MooncakeStorePyWrapper {

std::vector<pybind11::bytes> get_batch(
const std::vector<std::string>& keys,
const std::optional<ReadRouteConfig>& config_opt = std::nullopt) {
ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{});
const std::optional<ReadConfigExt>& config_opt = std::nullopt) {
ReadConfigExt config = config_opt.value_or(ReadConfigExt{});
const auto kNullString = pybind11::bytes("\\0", 0);
if (!is_client_initialized()) {
LOG(ERROR) << "Client is not initialized";
Expand Down Expand Up @@ -302,15 +304,15 @@ class MooncakeStorePyWrapper {
pybind11::object get_tensor_with_tp(
const std::string& key, int tp_rank = 0, int tp_size = 1,
int split_dim = 0,
const std::optional<ReadRouteConfig>& config_opt = std::nullopt) {
const std::optional<ReadConfigExt>& config_opt = std::nullopt) {
if (tp_size <= 1) return get_tensor(key, config_opt);
return get_tensor(get_tp_key_name(key, tp_rank), config_opt);
}

pybind11::list batch_get_tensor_with_tp(
const std::vector<std::string>& base_keys, int tp_rank = 0,
int tp_size = 1,
const std::optional<ReadRouteConfig>& config_opt = std::nullopt) {
const std::optional<ReadConfigExt>& config_opt = std::nullopt) {
if (tp_size <= 1) return batch_get_tensor(base_keys, config_opt);

std::vector<std::string> shard_keys;
Expand All @@ -323,8 +325,8 @@ class MooncakeStorePyWrapper {

pybind11::object get_tensor(
const std::string& key,
const std::optional<ReadRouteConfig>& config_opt = std::nullopt) {
ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{});
const std::optional<ReadConfigExt>& config_opt = std::nullopt) {
ReadConfigExt config = config_opt.value_or(ReadConfigExt{});
if (!is_client_initialized() || use_dummy_client_) {
LOG(ERROR) << "Client not initialized or Dummy client not "
"supported for tensors";
Expand All @@ -342,8 +344,8 @@ class MooncakeStorePyWrapper {

pybind11::list batch_get_tensor(
const std::vector<std::string>& keys,
const std::optional<ReadRouteConfig>& config_opt = std::nullopt) {
ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{});
const std::optional<ReadConfigExt>& config_opt = std::nullopt) {
ReadConfigExt config = config_opt.value_or(ReadConfigExt{});
if (!is_client_initialized() || use_dummy_client_) {
LOG(ERROR) << "Client not initialized or Dummy client not "
"supported for tensors";
Expand All @@ -367,8 +369,8 @@ class MooncakeStorePyWrapper {

pybind11::object get_tensor_into(
const std::string& key, uintptr_t buffer_ptr, size_t size,
const std::optional<ReadRouteConfig>& config_opt = std::nullopt) {
ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{});
const std::optional<ReadConfigExt>& config_opt = std::nullopt) {
ReadConfigExt config = config_opt.value_or(ReadConfigExt{});
char* buffer = reinterpret_cast<char*>(buffer_ptr);
if (!is_client_initialized()) {
LOG(ERROR) << "Client is not initialized";
Expand Down Expand Up @@ -396,8 +398,8 @@ class MooncakeStorePyWrapper {
const std::vector<std::string>& keys,
const std::vector<uintptr_t>& buffer_ptrs,
const std::vector<size_t>& sizes,
const std::optional<ReadRouteConfig>& config_opt = std::nullopt) {
ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{});
const std::optional<ReadConfigExt>& config_opt = std::nullopt) {
ReadConfigExt config = config_opt.value_or(ReadConfigExt{});
std::vector<void*> buffers;
buffers.reserve(buffer_ptrs.size());
for (uintptr_t ptr : buffer_ptrs) {
Expand Down Expand Up @@ -460,7 +462,7 @@ class MooncakeStorePyWrapper {
pybind11::object get_tensor_with_tp_into(
const std::string& key, uintptr_t buffer_ptr, size_t size,
int tp_rank = 0, int tp_size = 1, int split_dim = 0,
const std::optional<ReadRouteConfig>& config_opt = std::nullopt) {
const std::optional<ReadConfigExt>& config_opt = std::nullopt) {
if (!is_client_initialized()) {
LOG(ERROR) << "Client is not initialized";
return pybind11::none();
Expand All @@ -487,7 +489,7 @@ class MooncakeStorePyWrapper {
const std::vector<std::string>& base_keys,
const std::vector<uintptr_t>& buffer_ptrs,
const std::vector<size_t>& sizes, int tp_rank = 0, int tp_size = 1,
const std::optional<ReadRouteConfig>& config_opt = std::nullopt) {
const std::optional<ReadConfigExt>& config_opt = std::nullopt) {
if (!is_client_initialized()) {
LOG(ERROR) << "Client is not initialized";
py::list empty_list;
Expand Down Expand Up @@ -877,6 +879,19 @@ PYBIND11_MODULE(store, m) {
.def_readwrite("max_candidates", &ReadRouteConfig::max_candidates)
.def_readwrite("p2p_config", &ReadRouteConfig::p2p_config);

py::enum_<RdmaDirectionMode>(m, "RdmaDirectionMode")
.value("REVERSE", RdmaDirectionMode::REVERSE)
.value("FORWARD", RdmaDirectionMode::FORWARD);

py::class_<ReadConfigExt>(m, "ReadConfigExt")
.def(py::init<>())
.def(py::init<ReadRouteConfig>())
.def_readwrite("route_config", &ReadConfigExt::route_config)
.def_readwrite("rdma_direction_mode",
&ReadConfigExt::rdma_direction_mode);

py::implicitly_convertible<ReadRouteConfig, ReadConfigExt>();

py::class_<WriteRouteRequestConfig>(m, "WriteRouteRequestConfig")
.def(py::init<>()) // Default constructor
.def_readwrite("max_candidates",
Expand All @@ -894,6 +909,13 @@ PYBIND11_MODULE(store, m) {
return oss.str();
});

py::class_<WriteConfigExt>(m, "WriteConfigExt")
.def(py::init<>())
.def(py::init<WriteRouteRequestConfig>())
.def_readwrite("route_config", &WriteConfigExt::route_config)
.def_readwrite("rdma_direction_mode",
&WriteConfigExt::rdma_direction_mode);

py::enum_<ReplicaStatus>(m, "ReplicaStatus")
.value("UNDEFINED", ReplicaStatus::UNDEFINED)
.value("INITIALIZED", ReplicaStatus::INITIALIZED)
Expand Down Expand Up @@ -1152,8 +1174,8 @@ PYBIND11_MODULE(store, m) {
.def(
"get_buffer",
[](MooncakeStorePyWrapper& self, const std::string& key,
const std::optional<ReadRouteConfig>& config_opt) {
ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{});
const std::optional<ReadConfigExt>& config_opt) {
ReadConfigExt config = config_opt.value_or(ReadConfigExt{});
py::gil_scoped_release release;
return self.store_->get_buffer(key, config);
},
Expand All @@ -1163,8 +1185,8 @@ PYBIND11_MODULE(store, m) {
"batch_get_buffer",
[](MooncakeStorePyWrapper& self,
const std::vector<std::string>& keys,
const std::optional<ReadRouteConfig>& config_opt) {
ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{});
const std::optional<ReadConfigExt>& config_opt) {
ReadConfigExt config = config_opt.value_or(ReadConfigExt{});
py::gil_scoped_release release;
if (self.use_dummy_client_) {
LOG(ERROR) << "batch_get_buffer is not supported for dummy "
Expand Down Expand Up @@ -1234,7 +1256,10 @@ PYBIND11_MODULE(store, m) {
" tp_size: The total tensor parallel size (default 1).\n"
" split_dim: The dimension to split the tensor along "
"(default 0).\n"
" config: ReadRouteConfig.")
" config: ReadConfigExt (optional; omit for defaults). "
"ReadRouteConfig is accepted and treated as ReadConfigExt with "
"default RdmaDirectionMode.REVERSE. Set config.rdma_direction_mode "
"to FORWARD for forward RDMA read.")
.def("batch_get_tensor_with_tp",
&MooncakeStorePyWrapper::batch_get_tensor_with_tp,
py::arg("base_keys"), py::arg("tp_rank") = 0,
Expand Down Expand Up @@ -1296,7 +1321,10 @@ PYBIND11_MODULE(store, m) {
" tp_size: The total tensor parallel size (default 1).\n"
" split_dim: The dimension to split the tensor along"
"(default 0).\n"
" config: ReadRouteConfig.")
" config: ReadConfigExt (optional; omit for defaults). "
"ReadRouteConfig is accepted and treated as ReadConfigExt with "
"default RdmaDirectionMode.REVERSE. Set config.rdma_direction_mode "
"to FORWARD for forward RDMA read.")
.def(
"batch_get_tensor_with_tp_into",
&MooncakeStorePyWrapper::batch_get_tensor_with_tp_into,
Expand Down Expand Up @@ -1331,8 +1359,8 @@ PYBIND11_MODULE(store, m) {
"get_into",
[](MooncakeStorePyWrapper& self, const std::string& key,
uintptr_t buffer_ptr, size_t size,
const std::optional<ReadRouteConfig>& config_opt) {
ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{});
const std::optional<ReadConfigExt>& config_opt) {
ReadConfigExt config = config_opt.value_or(ReadConfigExt{});
// Get data directly into user-provided buffer
void* buffer = reinterpret_cast<void*>(buffer_ptr);
py::gil_scoped_release release;
Expand All @@ -1352,8 +1380,8 @@ PYBIND11_MODULE(store, m) {
const std::vector<std::string>& keys,
const std::vector<uintptr_t>& buffer_ptrs,
const std::vector<size_t>& sizes,
const std::optional<ReadRouteConfig>& config_opt) {
ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{});
const std::optional<ReadConfigExt>& config_opt) {
ReadConfigExt config = config_opt.value_or(ReadConfigExt{});
std::vector<void*> buffers;
buffers.reserve(buffer_ptrs.size());
for (uintptr_t ptr : buffer_ptrs) {
Expand Down Expand Up @@ -1543,9 +1571,8 @@ PYBIND11_MODULE(store, m) {
const std::vector<std::vector<uintptr_t>>& all_buffer_ptrs,
const std::vector<std::vector<size_t>>& all_sizes,
bool aggregate_same_segment_task = false,
const std::optional<ReadRouteConfig>& config_opt =
std::nullopt) {
ReadRouteConfig config = config_opt.value_or(ReadRouteConfig{});
const std::optional<ReadConfigExt>& config_opt = std::nullopt) {
ReadConfigExt config = config_opt.value_or(ReadConfigExt{});
py::gil_scoped_release release;
if (self.use_dummy_client_) {
LOG(ERROR)
Expand Down
12 changes: 6 additions & 6 deletions mooncake-store/include/centralized_client_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,11 @@ class CentralizedClientService

tl::expected<std::unique_ptr<QueryResult>, ErrorCode> Query(
const std::string& object_key,
const ReadRouteConfig& config = {}) override;
const ReadConfigExt& config = {}) override;

std::vector<tl::expected<std::unique_ptr<QueryResult>, ErrorCode>>
BatchQuery(const std::vector<std::string>& object_keys,
const ReadRouteConfig& config = {}) override;
const ReadConfigExt& config = {}) override;

tl::expected<bool, ErrorCode> IsExist(const std::string& key) override;

Expand All @@ -76,24 +76,24 @@ class CentralizedClientService
tl::expected<int64_t, ErrorCode> Get(
const std::string& key, const std::vector<void*>& buffers,
const std::vector<size_t>& sizes,
const ReadRouteConfig& config = {}) override;
const ReadConfigExt& config = {}) override;

std::vector<tl::expected<int64_t, ErrorCode>> BatchGet(
const std::vector<std::string>& keys,
const std::vector<std::vector<void*>>& all_buffers,
const std::vector<std::vector<size_t>>& all_sizes,
const ReadRouteConfig& config = {},
const ReadConfigExt& config = {},
bool aggregate_same_segment_task = false) override;

tl::expected<std::shared_ptr<BufferHandle>, ErrorCode> Get(
const std::string& key,
std::shared_ptr<ClientBufferAllocator> allocator,
const ReadRouteConfig& config = {}) override;
const ReadConfigExt& config = {}) override;

std::vector<tl::expected<std::shared_ptr<BufferHandle>, ErrorCode>>
BatchGet(const std::vector<std::string>& keys,
std::shared_ptr<ClientBufferAllocator> allocator,
const ReadRouteConfig& config = {}) override;
const ReadConfigExt& config = {}) override;

tl::expected<void, ErrorCode> Put(const ObjectKey& key,
std::vector<Slice>& slices,
Expand Down
14 changes: 14 additions & 0 deletions mooncake-store/include/client_rpc_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,20 @@ class ClientRpcService {
tl::expected<UUID, ErrorCode> WriteRemoteData(
const RemoteWriteRequest& request);

tl::expected<PreWriteResponse, ErrorCode> PreWrite(
const PreWriteRequest& request);

tl::expected<void, ErrorCode> WriteCommit(
const WriteCommitRequest& request);

tl::expected<void, ErrorCode> WriteRevoke(
const WriteRevokeRequest& request);

tl::expected<PinKeyResponse, ErrorCode> PinKey(
const PinKeyRequest& request);

tl::expected<void, ErrorCode> UnPinKey(const UnPinKeyRequest& request);

private:
DataManager& data_manager_; // Reference: owned by Client, same lifetime
P2PClientMetric* metrics_; // Optional: owned by P2PClientService
Expand Down
53 changes: 53 additions & 0 deletions mooncake-store/include/client_rpc_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <string>
#include <string_view>
#include <vector>
#include <optional>
#include <cstdint>
#include "types.h"
#include "ylt/struct_json/json_reader.h"
Expand Down Expand Up @@ -84,4 +85,56 @@ struct BatchRemoteWriteRequest {

YLT_REFL(BatchRemoteWriteRequest, keys, src_buffers_list, target_tier_ids);

struct PreWriteRequest {
std::string_view key;
uint64_t size_bytes = 0;
std::optional<UUID> target_tier_id;
};

YLT_REFL(PreWriteRequest, key, size_bytes, target_tier_id);

struct PreWriteResponse {
RemoteBufferDesc remote_buffer;
UUID pending_write_token;
};

YLT_REFL(PreWriteResponse, remote_buffer, pending_write_token);

struct WriteCommitRequest {
std::string_view key;
UUID pending_write_token;
};

YLT_REFL(WriteCommitRequest, key, pending_write_token);

/** Drops a pending PreWrite allocation without committing (e.g. after TE
* failure). */
struct WriteRevokeRequest {
std::string_view key;
UUID pending_write_token;
};

YLT_REFL(WriteRevokeRequest, key, pending_write_token);

struct PinKeyRequest {
std::string_view key;
std::optional<UUID> target_tier_id;
};

YLT_REFL(PinKeyRequest, key, target_tier_id);

struct PinKeyResponse {
RemoteBufferDesc remote_buffer;
UUID pin_token;
};

YLT_REFL(PinKeyResponse, remote_buffer, pin_token);

struct UnPinKeyRequest {
std::string_view key;
UUID pin_token;
};

YLT_REFL(UnPinKeyRequest, key, pin_token);

} // namespace mooncake
Loading
Loading