Skip to content

Commit ddcae5e

Browse files
committed
Fixed reviews about yuanrong_client(modified strategy_tag, rename custom_name, adjusted annotations ...)
Signed-off-by: dpj135 <958208521@qq.com>
1 parent c01e64a commit ddcae5e

2 files changed

Lines changed: 39 additions & 32 deletions

File tree

tests/test_yuanrong_storage_client_e2e.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
# Here, each mockClient is implemented with independent storage using a simple dictionary,
3131
# and is only suitable for unit testing.
3232

33+
3334
class MockDsTensorClient:
3435
def __init__(self, host, port, device_id):
3536
self.storage = {}
@@ -169,8 +170,8 @@ def test_cpu_only_flow(self, config):
169170

170171
# Put & Verify Meta
171172
meta = client.put(keys, vals)
172-
# b"\x01" is a tag added by YuanrongStorageClient, indicating that it is processed via General KV path.
173-
assert all(m == b"\x02" for m in meta)
173+
# "2" is a tag added by YuanrongStorageClient, indicating that it is processed via General KV path.
174+
assert all(m == "2" for m in meta)
174175

175176
# Get & Verify Values
176177
ret = client.get(keys, shp, dt, meta)
@@ -189,8 +190,8 @@ def test_npu_only_flow(self, config):
189190
client = self.client_cls(config)
190191

191192
meta = client.put(keys, vals)
192-
# b"\x01" is a tag added by YuanrongStorageClient, indicating that it is processed via NPU path.
193-
assert all(m == b"\x01" for m in meta)
193+
# "1" is a tag added by YuanrongStorageClient, indicating that it is processed via NPU path.
194+
assert all(m == "1" for m in meta)
194195

195196
ret = client.get(keys, shp, dt, meta)
196197
for o, r in zip(vals, ret, strict=True):
@@ -203,7 +204,7 @@ def test_mixed_flow(self, config):
203204
client = self.client_cls(config)
204205

205206
meta = client.put(keys, vals)
206-
assert set(meta) == {b"\x01", b"\x02"}
207+
assert set(meta) == {"1", "2"}
207208

208209
ret = client.get(keys, shp, dt, meta)
209210
for o, r in zip(vals, ret, strict=True):

transfer_queue/storage/clients/yuanrong_client.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,15 @@ def put(self, keys: list[str], values: list[Any]):
6161
"""Store key-value pairs using this strategy."""
6262

6363
@abstractmethod
64-
def supports_get(self, custom_meta: Any) -> bool:
65-
"""Check if this strategy can retrieve data with given metadata."""
64+
def supports_get(self, strategy_tag: Any) -> bool:
65+
"""Check if this strategy can retrieve data with given tag."""
6666

6767
@abstractmethod
6868
def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]:
6969
"""Retrieve values by keys; kwargs may include shapes/dtypes."""
7070

7171
@abstractmethod
72-
def supports_clear(self, custom_meta: Any) -> bool:
72+
def supports_clear(self, strategy_tag: Any) -> bool:
7373
"""Check if this strategy owns data identified by metadata."""
7474

7575
@abstractmethod
@@ -101,15 +101,15 @@ def init(config: dict) -> Optional["StorageStrategy"]:
101101
import torch_npu # noqa: F401
102102
except ImportError:
103103
torch_npu_imported = False
104-
enable = config.get("enable_yr_npu_optimization", True)
104+
enable = config.get("enable_yr_npu_transport", True)
105105
if not (enable and torch_npu_imported and torch.npu.is_available()):
106106
return None
107107

108108
return NPUTensorKVClientAdapter(config)
109109

110-
def strategy_tag(self) -> bytes:
110+
def strategy_tag(self) -> str:
111111
"""Strategy tag for NPU tensor storage. Using a single byte is for better performance."""
112-
return b"\x01"
112+
return "1"
113113

114114
def supports_put(self, value: Any) -> bool:
115115
"""Supports contiguous NPU tensors only."""
@@ -130,15 +130,15 @@ def put(self, keys: list[str], values: list[Any]):
130130
pass
131131
self._ds_client.dev_mset(batch_keys, batch_values)
132132

133-
def supports_get(self, custom_meta: str) -> bool:
133+
def supports_get(self, strategy_tag: str) -> bool:
134134
"""Matches 'DsTensorClient' Strategy tag."""
135-
return isinstance(custom_meta, bytes) and custom_meta == self.strategy_tag()
135+
return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()
136136

137137
def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]:
138138
"""Fetch NPU tensors using pre-allocated empty buffers."""
139139
shapes = kwargs.get("shapes", None)
140140
dtypes = kwargs.get("dtypes", None)
141-
if not shapes or not dtypes:
141+
if shapes is None or dtypes is None:
142142
raise ValueError("YuanrongStorageClient needs Expected shapes and dtypes")
143143
results = []
144144
for i in range(0, len(keys), self.KEYS_LIMIT):
@@ -152,9 +152,9 @@ def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]:
152152
results.extend(batch_values)
153153
return results
154154

155-
def supports_clear(self, custom_meta: str) -> bool:
155+
def supports_clear(self, strategy_tag: str) -> bool:
156156
"""Matches 'DsTensorClient' strategy tag."""
157-
return isinstance(custom_meta, bytes) and custom_meta == self.strategy_tag()
157+
return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()
158158

159159
def clear(self, keys: list[str]):
160160
"""Delete NPU tensor keys in batches."""
@@ -211,9 +211,9 @@ def init(config: dict) -> Optional["StorageStrategy"]:
211211
"""Always enabled for general objects."""
212212
return GeneralKVClientAdapter(config)
213213

214-
def strategy_tag(self) -> bytes:
214+
def strategy_tag(self) -> str:
215215
"""Strategy tag for general KV storage. Using a single byte is for better performance."""
216-
return b"\x02"
216+
return "2"
217217

218218
def supports_put(self, value: Any) -> bool:
219219
"""Accepts any Python object."""
@@ -226,9 +226,9 @@ def put(self, keys: list[str], values: list[Any]):
226226
batch_vals = values[i : i + self.PUT_KEYS_LIMIT]
227227
self.mset_zero_copy(batch_keys, batch_vals)
228228

229-
def supports_get(self, custom_meta: str) -> bool:
229+
def supports_get(self, strategy_tag: str) -> bool:
230230
"""Matches 'KVClient' strategy tag."""
231-
return isinstance(custom_meta, bytes) and custom_meta == self.strategy_tag()
231+
return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()
232232

233233
def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]:
234234
"""Retrieve and deserialize objects in batches."""
@@ -239,9 +239,9 @@ def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]:
239239
results.extend(objects)
240240
return results
241241

242-
def supports_clear(self, custom_meta: str) -> bool:
242+
def supports_clear(self, strategy_tag: str) -> bool:
243243
"""Matches 'KVClient' strategy tag."""
244-
return isinstance(custom_meta, bytes) and custom_meta == self.strategy_tag()
244+
return isinstance(strategy_tag, str) and strategy_tag == self.strategy_tag()
245245

246246
def clear(self, keys: list[str]):
247247
"""Delete keys in batches."""
@@ -357,6 +357,8 @@ def __init__(self, config: dict[str, Any]):
357357
if not YUANRONG_DATASYSTEM_IMPORTED:
358358
raise ImportError("YuanRong DataSystem not installed.")
359359

360+
super().__init__(config)
361+
360362
# Storage strategies are prioritized in ascending order of list element index.
361363
# In other words, the later in the order, the lower the priority.
362364
storage_strategies_priority = [NPUTensorKVClientAdapter, GeneralKVClientAdapter]
@@ -369,7 +371,7 @@ def __init__(self, config: dict[str, Any]):
369371
if not self._strategies:
370372
raise RuntimeError("No storage strategy available for YuanrongStorageClient")
371373

372-
def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
374+
def put(self, keys: list[str], values: list[Any]) -> list[str]:
373375
"""Stores multiple key-value pairs to remote storage.
374376
375377
Automatically routes NPU tensors to high-performance tensor storage,
@@ -396,11 +398,11 @@ def put_task(strategy, indexes):
396398
return strategy.strategy_tag(), indexes
397399

398400
# Dispatch tasks and map metadata back to original positions
399-
custom_meta: list[str] = [""] * len(keys)
400-
for meta_str, indexes in self._dispatch_tasks(routed_indexes, put_task):
401+
strategy_tags: list[str] = [""] * len(keys)
402+
for tag, indexes in self._dispatch_tasks(routed_indexes, put_task):
401403
for i in indexes:
402-
custom_meta[i] = meta_str
403-
return custom_meta
404+
strategy_tags[i] = tag
405+
return strategy_tags
404406

405407
def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]:
406408
"""Retrieves multiple values from remote storage with expected metadata.
@@ -411,7 +413,7 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> li
411413
keys (List[str]): Keys to fetch.
412414
shapes (List[List[int]]): Expected tensor shapes (use [] for scalars).
413415
dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data.
414-
custom_meta (List[str]): StorageStrategy type for each key
416+
custom_meta (List[str]): StorageStrategy tag for each key
415417
416418
Returns:
417419
List[Any]: Retrieved values in the same order as input keys.
@@ -422,7 +424,10 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> li
422424
if not (len(keys) == len(shapes) == len(dtypes) == len(custom_meta)):
423425
raise ValueError("Lengths of keys, shapes, dtypes, custom_meta must match")
424426

425-
routed_indexes = self._route_to_strategies(custom_meta, lambda strategy_, item_: strategy_.supports_get(item_))
427+
strategy_tags = custom_meta
428+
routed_indexes = self._route_to_strategies(
429+
strategy_tags, lambda strategy_, item_: strategy_.supports_get(item_)
430+
)
426431

427432
# Work unit for 'get': handles slicing of keys, shapes, and dtypes simultaneously.
428433
def get_task(strategy, indexes):
@@ -443,16 +448,17 @@ def clear(self, keys: list[str], custom_meta=None):
443448
444449
Args:
445450
keys (List[str]): List of keys to remove.
446-
custom_meta (List[str]): StorageStrategy type for each key
451+
custom_meta (List[str]): StorageStrategy tag for each key
447452
"""
448453
if not isinstance(keys, list) or not isinstance(custom_meta, list):
449454
raise ValueError("keys and custom_meta must be a list")
450455

451456
if len(custom_meta) != len(keys):
452457
raise ValueError("custom_meta length must match keys")
453458

459+
strategy_tags = custom_meta
454460
routed_indexes = self._route_to_strategies(
455-
custom_meta, lambda strategy_, item_: strategy_.supports_clear(item_)
461+
strategy_tags, lambda strategy_, item_: strategy_.supports_clear(item_)
456462
)
457463

458464
def clear_task(strategy, indexes):
@@ -473,7 +479,7 @@ def _route_to_strategies(
473479
Args:
474480
items: A list used to distinguish which storage strategy the data is routed to.
475481
e.g., route <keys, values> for put based on types of values,
476-
or route <keys, Optional[shapes], Optional[dtypes]> for get/clear based on strategy_tag.
482+
or route <keys, Optional[shapes], Optional[dtypes]> for get/clear based on strategy_tags.
477483
The order must correspond to the original keys.
478484
selector: A function that determines whether a strategy supports an item.
479485
Signature: `(strategy: StorageStrategy, item: Any) -> bool`.

0 commit comments

Comments
 (0)