Skip to content

Commit fbdb58e

Browse files
dpj135 and Copilot authored
[refactor] Refactor yuanrong_client (#18)
## Background: - Currently, the code structure of _yuanrong_client.py_ is complex. The `put/get` operations of `npu_ds_client` and `cpu_ds_client` depend on different **utility functions**, **global constants**, and **a large number of external dependencies**. - In addition, _yuanrong_client.py_ may be optimized in the future, or the interfaces of the data system may be updated or adjusted. As a result, any such change to `yuanrong_client.YuanrongStorageClient` would require shotgun surgery, and the class would suffer from divergent change. ## Description **Refactor using the Adapter and Strategy patterns.** <img width="882" height="474" alt="image" src="https://github.com/user-attachments/assets/26adb8f9-8b80-45ef-bd5f-85eb291631de" /> - The **interface** `StorageStrategy` is added, which provides a series of abstract methods for encapsulating the `yr.datasystem` interface. - The two storage paths of the original `YuanrongStorageClient` are extracted into two new **adapter classes**: `DsTensorClientAdapter` and `KVClientAdapter`. - `YuanrongStorageClient` is now responsible for dynamically routing data and scheduling `DsTensorClientAdapter` and `KVClientAdapter` using the strategy pattern. ## Todo: - [x] Implement **interface** `StorageStrategy`. - [x] Implement **adapter** classes `DsTensorClientAdapter` and `KVClientAdapter`. - [x] Add a parameter '**custom_backend_meta**' to the `clear` method of every **TransferQueueStorageClient**. - [x] Implement **parallelism** for the methods `YuanrongStorageClient::put`, `YuanrongStorageClient::get` and `YuanrongStorageClient::clear`. - [x] Add a **unit test**. --------- Signed-off-by: dpj135 <958208521@qq.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent f048354 commit fbdb58e

7 files changed

Lines changed: 636 additions & 313 deletions

File tree

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
sys.path.append(str(parent_dir))
2626

2727
from transfer_queue.storage.clients.yuanrong_client import ( # noqa: E402
28-
YuanrongStorageClient,
28+
GeneralKVClientAdapter,
2929
)
3030

3131

@@ -37,21 +37,20 @@ def MutableData(self):
3737
return self.data
3838

3939

40-
class TestYuanrongStorageZCopy:
40+
class TestYuanrongKVClientZCopy:
4141
@pytest.fixture
4242
def mock_kv_client(self, mocker):
4343
mock_client = MagicMock()
4444
mock_client.init.return_value = None
4545

4646
mocker.patch("yr.datasystem.KVClient", return_value=mock_client)
4747
mocker.patch("yr.datasystem.DsTensorClient")
48-
mocker.patch("transfer_queue.storage.clients.yuanrong_client.TORCH_NPU_IMPORTED", False)
4948

5049
return mock_client
5150

5251
@pytest.fixture
5352
def storage_client(self, mock_kv_client):
54-
return YuanrongStorageClient({"host": "127.0.0.1", "port": 31501})
53+
return GeneralKVClientAdapter({"host": "127.0.0.1", "port": 31501})
5554

5655
def test_mset_mget_p2p(self, storage_client, mocker):
5756
# Mock serialization/deserialization
@@ -80,13 +79,13 @@ def side_effect_mcreate(keys, sizes):
8079
stored_raw_buffers.append(b.MutableData())
8180
return buffers
8281

83-
storage_client._cpu_ds_client.mcreate.side_effect = side_effect_mcreate
84-
storage_client._cpu_ds_client.get_buffers.return_value = stored_raw_buffers
82+
storage_client._ds_client.mcreate.side_effect = side_effect_mcreate
83+
storage_client._ds_client.get_buffers.return_value = stored_raw_buffers
8584

86-
storage_client.mset_zcopy(
85+
storage_client.mset_zero_copy(
8786
["tensor_key", "string_key"], [torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32), "hello yuanrong"]
8887
)
89-
results = storage_client.mget_zcopy(["tensor_key", "string_key"])
88+
results = storage_client.mget_zero_copy(["tensor_key", "string_key"])
9089

9190
assert torch.allclose(results[0], torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32))
9291
assert results[1] == "hello yuanrong"
Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
# Copyright 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
2+
# Copyright 2025 The TransferQueue Team
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import sys
17+
from unittest import mock
18+
19+
import pytest
20+
import torch
21+
22+
try:
23+
import torch_npu # noqa: F401
24+
except ImportError:
25+
pass
26+
27+
28+
# --- Mock Backend Implementation ---
29+
# In real scenarios, multiple DsTensorClients or KVClients share storage.
30+
# Here, each mockClient is implemented with independent storage using a simple dictionary,
31+
# and is only suitable for unit testing.
32+
33+
34+
class MockDsTensorClient:
35+
def __init__(self, host, port, device_id):
36+
self.storage = {}
37+
38+
def init(self):
39+
pass
40+
41+
def dev_mset(self, keys, values):
42+
for k, v in zip(keys, values, strict=True):
43+
assert v.device.type == "npu"
44+
self.storage[k] = v
45+
46+
def dev_mget(self, keys, out_tensors):
47+
for i, k in enumerate(keys):
48+
# Note: If key is missing, tensor remains unchanged (mock limitation)
49+
if k in self.storage:
50+
out_tensors[i].copy_(self.storage[k])
51+
52+
def dev_delete(self, keys):
53+
for k in keys:
54+
self.storage.pop(k, None)
55+
56+
57+
class MockKVClient:
58+
def __init__(self, host, port):
59+
self.storage = {}
60+
61+
def init(self):
62+
pass
63+
64+
def mcreate(self, keys, sizes):
65+
class MockBuffer:
66+
def __init__(self, size):
67+
self._data = bytearray(size)
68+
69+
def MutableData(self):
70+
return memoryview(self._data)
71+
72+
self._current_keys = keys
73+
return [MockBuffer(s) for s in sizes]
74+
75+
def mset_buffer(self, buffers):
76+
for key, buf in zip(self._current_keys, buffers, strict=True):
77+
self.storage[key] = bytes(buf.MutableData())
78+
79+
def get_buffers(self, keys):
80+
return [memoryview(self.storage[k]) if k in self.storage else None for k in keys]
81+
82+
def delete(self, keys):
83+
for k in keys:
84+
self.storage.pop(k, None)
85+
86+
87+
# --- Fixtures ---
88+
89+
90+
@pytest.fixture
91+
def mock_yr_datasystem():
92+
"""Wipe real 'yr' modules and inject mocks."""
93+
94+
# 1. Clean up sys.modules to force a fresh import under mock conditions
95+
# This ensures top-level code in yuanrong_client.py is re-executed
96+
to_delete = [k for k in sys.modules if k.startswith("yr")]
97+
for mod in to_delete:
98+
del sys.modules[mod]
99+
100+
# 2. Setup Mock Objects
101+
ds_mock = mock.MagicMock()
102+
ds_mock.DsTensorClient = MockDsTensorClient
103+
ds_mock.KVClient = MockKVClient
104+
105+
yr_mock = mock.MagicMock(datasystem=ds_mock)
106+
107+
# 3. Apply patches
108+
# - sys.modules: Redirects 'import yr' to our mocks
109+
# - YUANRONG_DATASYSTEM_IMPORTED: Forces the existence check to True so that the client initializes successfully
110+
# - datasystem: Direct attribute patch for the module
111+
with (
112+
mock.patch.dict("sys.modules", {"yr": yr_mock, "yr.datasystem": ds_mock}),
113+
mock.patch("transfer_queue.storage.clients.yuanrong_client.YUANRONG_DATASYSTEM_IMPORTED", True, create=True),
114+
mock.patch("transfer_queue.storage.clients.yuanrong_client.datasystem", ds_mock),
115+
):
116+
yield
117+
118+
119+
@pytest.fixture
120+
def config():
121+
return {"host": "127.0.0.1", "port": 12345, "enable_yr_npu_optimization": True}
122+
123+
124+
def assert_tensors_equal(a: torch.Tensor, b: torch.Tensor):
125+
assert a.shape == b.shape and a.dtype == b.dtype
126+
# Move to CPU for cross-device comparison
127+
assert torch.equal(a.cpu(), b.cpu())
128+
129+
130+
# --- Test Suite ---
131+
132+
133+
class TestYuanrongStorageE2E:
134+
@pytest.fixture(autouse=True)
135+
def setup_client(self, mock_yr_datasystem, config):
136+
# Lazy import to ensure mocks are active
137+
from transfer_queue.storage.clients.yuanrong_client import YuanrongStorageClient
138+
139+
self.client_cls = YuanrongStorageClient
140+
self.config = config
141+
142+
def _create_data(self, mode="cpu"):
143+
if mode == "cpu":
144+
keys = ["t", "s", "i"]
145+
vals = [torch.randn(2), "hi", 1]
146+
elif mode == "npu":
147+
if not (hasattr(torch, "npu") and torch.npu.is_available()):
148+
pytest.skip("NPU required")
149+
keys = ["n1", "n2"]
150+
vals = [torch.randn(2).npu(), torch.tensor([1]).npu()]
151+
else: # mixed
152+
if not (hasattr(torch, "npu") and torch.npu.is_available()):
153+
pytest.skip("NPU required")
154+
keys = ["n1", "c1"]
155+
vals = [torch.randn(2).npu(), "cpu"]
156+
157+
shapes = [list(v.shape) if isinstance(v, torch.Tensor) else [] for v in vals]
158+
dtypes = [v.dtype if isinstance(v, torch.Tensor) else None for v in vals]
159+
return keys, vals, shapes, dtypes
160+
161+
def test_mock_can_work(self, config):
162+
mock_class = (MockDsTensorClient, MockKVClient)
163+
client = self.client_cls(config)
164+
for strategy in client._strategies:
165+
assert isinstance(strategy._ds_client, mock_class)
166+
167+
def test_cpu_only_flow(self, config):
168+
client = self.client_cls(config)
169+
keys, vals, shp, dt = self._create_data("cpu")
170+
171+
# Put & Verify Meta
172+
meta = client.put(keys, vals)
173+
# "2" is a tag added by YuanrongStorageClient, indicating that it is processed via General KV path.
174+
assert all(m == "2" for m in meta)
175+
176+
# Get & Verify Values
177+
ret = client.get(keys, shp, dt, meta)
178+
for o, r in zip(vals, ret, strict=True):
179+
if isinstance(o, torch.Tensor):
180+
assert_tensors_equal(o, r)
181+
else:
182+
assert o == r
183+
184+
# Clear & Verify
185+
client.clear(keys, meta)
186+
assert all(v is None for v in client.get(keys, shp, dt, meta))
187+
188+
def test_npu_only_flow(self, config):
189+
keys, vals, shp, dt = self._create_data("npu")
190+
client = self.client_cls(config)
191+
192+
meta = client.put(keys, vals)
193+
# "1" is a tag added by YuanrongStorageClient, indicating that it is processed via NPU path.
194+
assert all(m == "1" for m in meta)
195+
196+
ret = client.get(keys, shp, dt, meta)
197+
for o, r in zip(vals, ret, strict=True):
198+
assert_tensors_equal(o, r)
199+
200+
client.clear(keys, meta)
201+
202+
def test_mixed_flow(self, config):
203+
keys, vals, shp, dt = self._create_data("mixed")
204+
client = self.client_cls(config)
205+
206+
meta = client.put(keys, vals)
207+
assert set(meta) == {"1", "2"}
208+
209+
ret = client.get(keys, shp, dt, meta)
210+
for o, r in zip(vals, ret, strict=True):
211+
if isinstance(o, torch.Tensor):
212+
assert_tensors_equal(o, r)
213+
else:
214+
assert o == r

transfer_queue/storage/clients/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,6 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non
6565
raise NotImplementedError("Subclasses must implement get")
6666

6767
@abstractmethod
68-
def clear(self, keys: list[str]) -> None:
68+
def clear(self, keys: list[str], custom_backend_meta=None) -> None:
6969
"""Clear key-value pairs in the storage backend."""
7070
raise NotImplementedError("Subclasses must implement clear")

transfer_queue/storage/clients/mooncake_client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non
139139
keys (List[str]): Keys to fetch.
140140
shapes (List[List[int]]): Expected tensor shapes (use [] for scalars).
141141
dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data.
142-
custom_backend_meta (List[str], optional): Device type (npu/cpu) for each key
142+
custom_backend_meta (List[str], optional): ...
143143
144144
Returns:
145145
List[Any]: Retrieved values in the same order as input keys.
@@ -216,11 +216,12 @@ def _batch_get_bytes(self, keys: list[str]) -> list[bytes]:
216216
results.extend(batch_results)
217217
return results
218218

219-
def clear(self, keys: list[str]):
219+
def clear(self, keys: list[str], custom_backend_meta=None):
220220
"""Deletes multiple keys from MooncakeStore.
221221
222222
Args:
223223
keys (List[str]): List of keys to remove.
224+
custom_backend_meta (List[Any], optional): ...
224225
"""
225226
for key in keys:
226227
ret = self._store.remove(key)

transfer_queue/storage/clients/ray_storage_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,11 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_backend_meta=Non
106106
raise RuntimeError(f"Failed to retrieve value for key '{keys}': {e}") from e
107107
return values
108108

109-
def clear(self, keys: list[str]):
109+
def clear(self, keys: list[str], custom_backend_meta=None):
110110
"""
111111
Delete entries from storage by keys.
112112
Args:
113113
keys (list): List of keys to delete
114+
custom_backend_meta (List[Any], optional): ...
114115
"""
115116
ray.get(self.storage_actor.clear_obj_ref.remote(keys))

0 commit comments

Comments
 (0)