Skip to content

Commit cde575c

Browse files
authored
[fix,feat] Support MooncakeStore easy init (Ascend#45)
## New Features - Added `MooncakeStore` Configurations: Introduced related configuration options for `MooncakeStore` in `config.py`. - Easy Initialization: Implemented support for `tq.init()` when using the MooncakeStore backend. - E2E CI Coverage: Added end-to-end continuous integration tests specifically for the MooncakeStore backend. ## Bug Fixes - **`KVStorageManager` Check**: Removed an outdated validation check in `KVStorageManager` that previously caused issues during put operations. - **Metadata Update Tracking**: Fixed a metadata update issue in `TransferQueueController`. Now, when a field transforms between a normal tensor and a nested tensor, the system correctly recomputes and updates the `per_sample_shape`, `is_nested`, and `shape` information. - **ZMQ Related**: Set `recv_multipart(copy=False)` by default. ## Known Issues - **Graceful Shutdown Limitations**: We cannot gracefully shut down `mooncake_master` because the distributed `TransferQueueClient` holding `MooncakeDistributedStore()` will raise heartbeat error. As a workaround, we currently launch `mooncake_master` when setting `auto_init=true` but bypass shutting it down. To minimize possible influence, we call `remove_all()` to delete all the keys in `mooncake_master`. - **Uneven BatchMeta Fields**: `TransferQueueController` currently cannot handle non-uniform `BatchMeta` instances where samples do not have equal fields. This prevents key-value-based backends from accurately clearing all keys. In `MooncakeStore`, we are temporarily using `remove_by_regex` to mitigate this issue. - **1D Tensor Handling**: When a user inputs a 1D tensor, previous refactoring populated an empty `torch.Size([])` which could mislead key-value-based backends during zero-copy operations. Since these backends must perform fine-grained splits on the input TensorDict, distinguishing between 1D and 2D input tensors is difficult. We have now added a warning for this type of input and manually populate the shape with `torch.Size([1])`. - ```python3 # in AsyncTransferQueueClient.async_put() for field_name, field_data in data.items(): if isinstance(field_data, torch.Tensor) and field_data.ndim == 1: logger.warning( f"[{self.client_id}]: Data field '{field_name}' is a tensor with only one dimension. " f"You may receive 2D tensors in key-value based backend." ) ``` --- ## Configuration Reference The config structure for `MooncakeStore` looks like this: ```yml backend: # Pluggable storage/transport backend of TransferQueue. Choose from: # SimpleStorage, Yuanrong, MooncakeStore, ... storage_backend: MooncakeStore # For MooncakeStore: MooncakeStore: # Auto init metadata_server auto_init: true # Address of the HTTP metadata server metadata_server: localhost:50050 # Address of master server master_server_address: localhost:50051 # Address of local host local_hostname: localhost # Protocol for transmission. Choose from: tcp, rdma. (default: tcp) protocol: tcp # Memory segment size in bytes for mounting (default: 4GB) global_segment_size: 4294967296 # Local buffer size in bytes (default: 1GB) local_buffer_size: 1073741824 # Network device name. Set to "" to let Mooncake to auto-picks devices device_name: "" ``` CC:@zhaohaidao @dpj135 @mpb159753 --------- Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent 0945d28 commit cde575c

18 files changed

Lines changed: 560 additions & 209 deletions

.github/workflows/python-package.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ jobs:
3333
python -m pip install --upgrade pip
3434
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
3535
pip install -e ".[test,build,yuanrong]"
36+
pip install mooncake-transfer-engine-non-cuda
3637
- name: Lint with flake8
3738
run: |
3839
# stop the build if there are Python syntax errors or undefined names
@@ -43,11 +44,10 @@ jobs:
4344
run: |
4445
python -m build --wheel
4546
pip install dist/*.whl --force-reinstall
46-
- name: Test with pytest (TQ_ZERO_COPY_SERIALIZATION=False)
47+
- name: Test with pytest
4748
run: |
4849
pytest tests
49-
- name: Test with pytest (TQ_ZERO_COPY_SERIALIZATION=True)
50-
run: |
51-
ray stop --force
52-
export TQ_ZERO_COPY_SERIALIZATION=True
53-
pytest tests
50+
TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_e2e_lifecycle_consistency.py
51+
pkill -f "mooncake_master"
52+
TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_kv_interface_e2e.py
53+
pkill -f "mooncake_master"

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ test = [
117117
yuanrong = [
118118
"openyuanrong-datasystem"
119119
]
120+
mooncake = [
121+
"mooncake-transfer-engine"
122+
]
120123

121124
# If you need to mimic `package_dir={'': '.'}`:
122125
[tool.setuptools.package-dir]

tests/e2e/test_e2e_lifecycle_consistency.py

Lines changed: 92 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16-
"""E2E lifecycle consistency tests for TransferQueue."""
17-
16+
import os
1817
import sys
1918
import time
2019
from pathlib import Path
@@ -23,6 +22,7 @@
2322
import pytest
2423
import ray
2524
import torch
25+
from omegaconf import OmegaConf
2626
from tensordict import TensorDict
2727
from tensordict.tensorclass import NonTensorData
2828

@@ -48,6 +48,38 @@
4848
"non_tensor_stack",
4949
]
5050

51+
# Backend configurations for E2E tests
52+
BACKEND_CONFIGS = {
53+
"SimpleStorage": {
54+
"controller": {
55+
"polling_mode": True,
56+
},
57+
"backend": {
58+
"storage_backend": "SimpleStorage",
59+
"SimpleStorage": {
60+
"total_storage_size": 200,
61+
"num_data_storage_units": 2,
62+
},
63+
},
64+
},
65+
"MooncakeStore": {
66+
"controller": {
67+
"polling_mode": True,
68+
},
69+
"backend": {
70+
"storage_backend": "MooncakeStore",
71+
"MooncakeStore": {
72+
"global_segment_size": 134217728, # 128MB
73+
"local_buffer_size": 134217728, # 128MB
74+
"metadata_server": "localhost:50050",
75+
"master_server_address": "localhost:50051",
76+
"protocol": "tcp",
77+
"device_name": "",
78+
},
79+
},
80+
},
81+
}
82+
5183

5284
@pytest.fixture(scope="module")
5385
def ray_cluster():
@@ -59,24 +91,33 @@ def ray_cluster():
5991

6092

6193
@pytest.fixture(scope="module")
62-
def e2e_client(ray_cluster):
63-
"""Create a client using transfer_queue.init() for lifecycle testing."""
64-
from omegaconf import OmegaConf
94+
def backend_name():
95+
"""Get the backend name from environment variable.
96+
97+
Environment variables:
98+
TQ_TEST_BACKEND: Backend name (SimpleStorage or MooncakeStore)
99+
100+
To run tests for a specific backend:
101+
TQ_TEST_BACKEND=SimpleStorage pytest tests/e2e/test_e2e_lifecycle_consistency.py
102+
TQ_TEST_BACKEND=MooncakeStore pytest tests/e2e/test_e2e_lifecycle_consistency.py
103+
"""
104+
return os.environ.get("TQ_TEST_BACKEND", "SimpleStorage")
105+
106+
107+
@pytest.fixture(scope="module")
108+
def e2e_client(ray_cluster, backend_name):
109+
"""Create a client using transfer_queue.init() for lifecycle testing.
65110
111+
Args:
112+
ray_cluster: Ray cluster fixture
113+
backend_name: Backend name from TQ_TEST_BACKEND env var
114+
"""
66115
import transfer_queue
67116

68-
config = {
69-
"controller": {
70-
"polling_mode": True,
71-
},
72-
"backend": {
73-
"storage_backend": "SimpleStorage",
74-
"SimpleStorage": {
75-
"total_storage_size": 200,
76-
"num_data_storage_units": 2,
77-
},
78-
},
79-
}
117+
if backend_name not in BACKEND_CONFIGS:
118+
raise ValueError(f"Unknown backend: {backend_name}. Available: {list(BACKEND_CONFIGS.keys())}")
119+
120+
config = BACKEND_CONFIGS[backend_name]
80121
transfer_queue.init(OmegaConf.create(config))
81122
client = transfer_queue.get_client()
82123
yield client
@@ -244,7 +285,7 @@ def verify_list_equal(retrieved, expected) -> bool:
244285
if isinstance(retrieved, NonTensorStack):
245286
retrieved = retrieved.tolist()
246287
elif isinstance(retrieved, torch.Tensor):
247-
retrieved = retrieved.tolist()
288+
retrieved = retrieved.reshape(-1).tolist() # may get 2D tensor back using key-value based backend
248289
if isinstance(expected, NonTensorStack):
249290
expected = expected.tolist()
250291
elif isinstance(expected, torch.Tensor):
@@ -283,9 +324,21 @@ def _reorder_tensordict(td: TensorDict, order: list[int]) -> TensorDict:
283324
return TensorDict(reordered, batch_size=td.batch_size)
284325

285326

327+
def recover_local_index(global_index_order, new_global_index_order):
328+
value_to_new_index = {}
329+
for idx, val in enumerate(new_global_index_order):
330+
value_to_new_index[val] = idx
331+
332+
local_index_order_to_recover = []
333+
for val in global_index_order:
334+
local_index_order_to_recover.append(value_to_new_index[val])
335+
336+
return local_index_order_to_recover
337+
338+
286339
# Scenario One: Core Read/Write Consistency
287340
def test_core_consistency(e2e_client):
288-
"""Put full complex data then get verify all field types are correctly round-tripped."""
341+
"""Put full complex data then get - verify all field types are correctly round-tripped."""
289342
client = e2e_client
290343
partition_id = "test_core_consistency"
291344
batch_size = 20
@@ -362,6 +415,12 @@ def test_core_consistency(e2e_client):
362415
# Scenario Two: Cross-Shard Update
363416
def test_cross_shard_complex_update(e2e_client):
364417
"""Cross-shard update: put A + put B, update overlapping region, verify all regions."""
418+
419+
# FIXME: Add data update test to MooncakeStore after Upsert function is ready
420+
# https://github.com/kvcache-ai/Mooncake/issues/1645
421+
if os.environ.get("TQ_TEST_BACKEND", "SimpleStorage") == "MooncakeStore":
422+
return
423+
365424
client = e2e_client
366425
partition_id = "test_cross_shard_update"
367426
task_name = "cross_shard_task"
@@ -744,12 +803,19 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client):
744803

745804
indices = list(range(batch_size))
746805
original_data = generate_complex_data(indices)
747-
client.put(data=original_data, partition_id=partition_id)
806+
original_meta = client.put(data=original_data, partition_id=partition_id)
748807

808+
global_index_order = original_meta.global_indexes
749809
try:
750810
# === Phase 1: Retrieve and verify writability ===
751811
meta = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch")
752812
assert meta is not None and meta.size == batch_size
813+
814+
# the global_index_order in retrieved meta is different from the original one.
815+
# we need to reorder first.
816+
local_index_order = recover_local_index(global_index_order, meta.global_indexes)
817+
meta = meta.select_samples(local_index_order)
818+
753819
retrieved = client.get_data(meta)
754820

755821
# 1. tensor_f32: writable
@@ -793,6 +859,12 @@ def test_retrieved_data_writability_and_memory_safety(e2e_client):
793859
# Re-retrieve the same data — modifications above should NOT have affected storage
794860
meta2 = poll_for_meta(client, partition_id, fields, batch_size, task_name, mode="force_fetch")
795861
assert meta2 is not None and meta2.size == batch_size
862+
863+
# the global_index_order in retrieved meta is different from the original one.
864+
# we need to reorder first.
865+
local_index_order = recover_local_index(global_index_order, meta2.global_indexes)
866+
meta2 = meta2.select_samples(local_index_order)
867+
796868
retrieved2 = client.get_data(meta2)
797869

798870
# tensor_f32[0,0] should be the original value, not 99999.0

0 commit comments

Comments
 (0)