Skip to content

Commit d3720fd

Browse files
committed
Added 'route_to_strategy' to class 'YuanrongStorageClient' & Adjust the order of classes
Signed-off-by: dpj135 <958208521@qq.com>
1 parent 2f3d265 commit d3720fd

1 file changed

Lines changed: 135 additions & 125 deletions

File tree

transfer_queue/storage/clients/yuanrong_client.py

Lines changed: 135 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import struct
1919
from abc import ABC, abstractmethod
2020
from concurrent.futures import ThreadPoolExecutor
21-
from typing import Any, Optional, TypeAlias, Union
21+
from typing import Any, Callable, Optional, TypeAlias, Union
2222

2323
import torch
2424
from torch import Tensor
@@ -40,125 +40,6 @@
4040
YUANRONG_DATASYSTEM_IMPORTED = False
4141

4242

43-
@StorageClientFactory.register("YuanrongStorageClient")
44-
class YuanrongStorageClient(TransferQueueStorageKVClient):
45-
"""
46-
Storage client for YuanRong DataSystem.
47-
48-
Supports storing and fetching both:
49-
- NPU tensors via DsTensorClient (for high performance).
50-
- General objects (CPU tensors, str, bool, list, etc.) via KVClient with pickle serialization.
51-
"""
52-
53-
def __init__(self, config: dict[str, Any]):
54-
if not YUANRONG_DATASYSTEM_IMPORTED:
55-
raise ImportError("YuanRong DataSystem not installed.")
56-
57-
self._strategies: list[StorageStrategy] = []
58-
for strategy_cls in [DsTensorClientAdapter, KVClientAdapter]:
59-
strategy = strategy_cls.init(config)
60-
if strategy is not None:
61-
self._strategies.append(strategy)
62-
63-
if not self._strategies:
64-
raise RuntimeError("No storage strategy available for YuanrongStorageClient")
65-
66-
def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
67-
"""Stores multiple key-value pairs to remote storage.
68-
69-
Automatically routes NPU tensors to high-performance tensor storage,
70-
and other objects to general-purpose KV storage.
71-
72-
Args:
73-
keys (List[str]): List of unique string identifiers.
74-
values (List[Any]): List of values to store (tensors, scalars, dicts, etc.).
75-
"""
76-
if not isinstance(keys, list) or not isinstance(values, list):
77-
raise ValueError("keys and values must be lists")
78-
if len(keys) != len(values):
79-
raise ValueError("Number of keys must match number of values")
80-
custom_metas = []
81-
strategy_batches: dict[StorageStrategy, tuple[list[str], list[Any]]] = {s: ([], []) for s in self._strategies}
82-
83-
for key, value in zip(keys, values, strict=True):
84-
for strategy in self._strategies:
85-
if strategy.supports_put(value):
86-
custom_metas.append(strategy.custom_meta())
87-
strategy_batches[strategy][0].append(key)
88-
strategy_batches[strategy][1].append(value)
89-
break
90-
else:
91-
raise ValueError(f"No strategy supports putting value of type {type(value)}")
92-
# Todo: Parallel put
93-
for strategy, (s_keys, s_vals) in strategy_batches.items():
94-
if s_keys:
95-
strategy.put(s_keys, s_vals)
96-
97-
return custom_metas
98-
99-
def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]:
100-
"""Retrieves multiple values from remote storage with expected metadata.
101-
102-
Requires shape and dtype hints to reconstruct NPU tensors correctly.
103-
104-
Args:
105-
keys (List[str]): Keys to fetch.
106-
shapes (List[List[int]]): Expected tensor shapes (use [] for scalars).
107-
dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data.
108-
custom_meta (List[str], optional): Device type (npu/cpu) for each key
109-
110-
Returns:
111-
List[Any]: Retrieved values in the same order as input keys.
112-
"""
113-
if shapes is None or dtypes is None:
114-
raise ValueError("YuanrongStorageClient needs Expected shapes and dtypes")
115-
if not (len(keys) == len(shapes) == len(dtypes)):
116-
raise ValueError("Lengths of keys, shapes, dtypes must match")
117-
118-
if custom_meta is None:
119-
raise ValueError("custom_meta is required for YuanrongStorageClient.get()")
120-
121-
if len(custom_meta) != len(keys):
122-
raise ValueError("custom_meta length must match keys")
123-
124-
results: list[Optional[Any]] = [None] * len(keys)
125-
126-
# {strategy: ([index], [key])}
127-
strategy_batches: dict[StorageStrategy, tuple[list[int], list[str]]] = {s: ([], []) for s in self._strategies}
128-
129-
for i, (key, meta) in enumerate(zip(keys, custom_meta, strict=True)):
130-
for strategy in self._strategies:
131-
if strategy.supports_get(meta):
132-
strategy_batches[strategy][0].append(i)
133-
strategy_batches[strategy][1].append(key)
134-
break
135-
else:
136-
raise ValueError(f"No strategy supports getting with meta={meta}")
137-
# Todo: Parallel get
138-
for strategy, (indices, s_keys) in strategy_batches.items():
139-
s_shapes = [shapes[i] for i in indices]
140-
s_dtypes = [dtypes[i] for i in indices]
141-
142-
try:
143-
s_results = strategy.get(s_keys, shapes=s_shapes, dtypes=s_dtypes)
144-
except Exception as e:
145-
logger.error(f"Strategy {strategy.custom_meta()} failed to get keys: {s_keys}, error: {e}")
146-
raise
147-
148-
for idx, res in zip(indices, s_results, strict=True):
149-
results[idx] = res
150-
151-
return results
152-
153-
def clear(self, keys: list[str]):
154-
"""Deletes multiple keys from remote storage.
155-
156-
Args:
157-
keys (List[str]): List of keys to remove.
158-
"""
159-
pass
160-
161-
16243
class StorageStrategy(ABC):
16344
@abstractmethod
16445
@staticmethod
@@ -336,12 +217,10 @@ def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]:
336217
def supports_clear(self, custom_meta: str) -> bool:
337218
return isinstance(custom_meta, str) and custom_meta == self.custom_meta()
338219

339-
# Todo(wenlin): Add clear_buffer method
340220
def clear(self, keys: list[str]):
341-
pass
342-
# for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT):
343-
# batch = keys[i : i + self.GET_CLEAR_KEYS_LIMIT]
344-
# self._ds_client.delete(batch)
221+
for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT):
222+
batch = keys[i : i + self.GET_CLEAR_KEYS_LIMIT]
223+
self._ds_client.delete(batch)
345224

346225
@staticmethod
347226
def calc_packed_size(items: list[memoryview]) -> int:
@@ -438,3 +317,134 @@ def mget_zero_copy(self, keys: list[str]) -> list[Any]:
438317
"""
439318
buffers = self._ds_client.get_buffers(keys)
440319
return [_decoder.decode(self.unpack_from(buffer)) if buffer is not None else None for buffer in buffers]
320+
321+
322+
@StorageClientFactory.register("YuanrongStorageClient")
323+
class YuanrongStorageClient(TransferQueueStorageKVClient):
324+
"""
325+
Storage client for YuanRong DataSystem.
326+
327+
Supports storing and fetching both:
328+
- NPU tensors via DsTensorClient (for high performance).
329+
- General objects (CPU tensors, str, bool, list, etc.) via KVClient with pickle serialization.
330+
"""
331+
332+
def __init__(self, config: dict[str, Any]):
333+
if not YUANRONG_DATASYSTEM_IMPORTED:
334+
raise ImportError("YuanRong DataSystem not installed.")
335+
336+
self._strategies: list[StorageStrategy] = []
337+
for strategy_cls in [DsTensorClientAdapter, KVClientAdapter]:
338+
strategy = strategy_cls.init(config)
339+
if strategy is not None:
340+
self._strategies.append(strategy)
341+
342+
if not self._strategies:
343+
raise RuntimeError("No storage strategy available for YuanrongStorageClient")
344+
345+
def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
346+
"""Stores multiple key-value pairs to remote storage.
347+
348+
Automatically routes NPU tensors to high-performance tensor storage,
349+
and other objects to general-purpose KV storage.
350+
351+
Args:
352+
keys (List[str]): List of unique string identifiers.
353+
values (List[Any]): List of values to store (tensors, scalars, dicts, etc.).
354+
355+
Returns:
356+
List[Any]: custom metadata of YuanrongStorageCilent in the same order as input keys.
357+
"""
358+
if not isinstance(keys, list) or not isinstance(values, list):
359+
raise ValueError("keys and values must be lists")
360+
if len(keys) != len(values):
361+
raise ValueError("Number of keys must match number of values")
362+
363+
routed_indexes = self._route_to_strategies(values, lambda strategy_, item_: strategy_.supports_put(item_))
364+
custom_metas = [None] * len(keys)
365+
for strategy, indexes in routed_indexes.items():
366+
if not indexes:
367+
continue
368+
strategy_keys = [keys[i] for i in indexes]
369+
strategy_values = [values[i] for i in indexes]
370+
strategy.put(strategy_keys, strategy_values)
371+
for i in indexes:
372+
custom_metas[i] = strategy.custom_meta()
373+
374+
return custom_metas
375+
376+
def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]:
377+
"""Retrieves multiple values from remote storage with expected metadata.
378+
379+
Requires shape and dtype hints to reconstruct NPU tensors correctly.
380+
381+
Args:
382+
keys (List[str]): Keys to fetch.
383+
shapes (List[List[int]]): Expected tensor shapes (use [] for scalars).
384+
dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data.
385+
custom_meta (List[str], optional): Device type (npu/cpu) for each key
386+
387+
Returns:
388+
List[Any]: Retrieved values in the same order as input keys.
389+
"""
390+
if shapes is None or dtypes is None:
391+
raise ValueError("YuanrongStorageClient needs Expected shapes and dtypes")
392+
if not (len(keys) == len(shapes) == len(dtypes)):
393+
raise ValueError("Lengths of keys, shapes, dtypes must match")
394+
395+
if custom_meta is None:
396+
raise ValueError("custom_meta is required for YuanrongStorageClient.get()")
397+
398+
if len(custom_meta) != len(keys):
399+
raise ValueError("custom_meta length must match keys")
400+
401+
routed_indexes = self._route_to_strategies(custom_meta, lambda strategy_, item_: strategy_.supports_get(item_))
402+
403+
# Todo(dpj): Parallel get
404+
results = [None] * len(keys)
405+
for strategy, indexes in routed_indexes.items():
406+
if not indexes:
407+
continue
408+
strategy_keys = [keys[i] for i in indexes]
409+
strategy_shapes = [shapes[i] for i in indexes]
410+
strategy_dtypes = [dtypes[i] for i in indexes]
411+
strategy_results = strategy.get(strategy_keys, shapes=strategy_shapes, dtypes=strategy_dtypes)
412+
for j, i in enumerate(indexes):
413+
results[i] = strategy_results[j]
414+
415+
return results
416+
417+
def clear(self, keys: list[str]):
418+
"""Deletes multiple keys from remote storage.
419+
420+
Args:
421+
keys (List[str]): List of keys to remove.
422+
"""
423+
pass
424+
425+
def _route_to_strategies(
426+
self,
427+
items: list[Any],
428+
selector: Callable[[StorageStrategy, Any], bool],
429+
) -> dict[StorageStrategy, list[int]]:
430+
"""
431+
Groups item indices by storage strategy.
432+
433+
Args:
434+
items: A list of items (e.g., values or custom_meta strings) to be dispatched.
435+
The order must correspond to the original keys.
436+
selector: A function that determines whether a strategy supports an item.
437+
Signature: (strategy, item) -> bool
438+
439+
Returns:
440+
A dictionary mapping each strategy to a list of indices in `items`.
441+
"""
442+
routed_indexes: dict[StorageStrategy, list[int]] = {s: [] for s in self._strategies}
443+
for i, item in enumerate(items):
444+
for strategy in self._strategies:
445+
if selector(strategy, item):
446+
routed_indexes[strategy].append(i)
447+
break
448+
else:
449+
raise ValueError(f"No strategy supports item: {item}")
450+
return routed_indexes

0 commit comments

Comments
 (0)