|
18 | 18 | import struct |
19 | 19 | from abc import ABC, abstractmethod |
20 | 20 | from concurrent.futures import ThreadPoolExecutor |
21 | | -from typing import Any, Optional, TypeAlias, Union |
| 21 | +from typing import Any, Callable, Optional, TypeAlias, Union |
22 | 22 |
|
23 | 23 | import torch |
24 | 24 | from torch import Tensor |
|
40 | 40 | YUANRONG_DATASYSTEM_IMPORTED = False |
41 | 41 |
|
42 | 42 |
|
43 | | -@StorageClientFactory.register("YuanrongStorageClient") |
44 | | -class YuanrongStorageClient(TransferQueueStorageKVClient): |
45 | | - """ |
46 | | - Storage client for YuanRong DataSystem. |
47 | | -
|
48 | | - Supports storing and fetching both: |
49 | | - - NPU tensors via DsTensorClient (for high performance). |
50 | | - - General objects (CPU tensors, str, bool, list, etc.) via KVClient with pickle serialization. |
51 | | - """ |
52 | | - |
53 | | - def __init__(self, config: dict[str, Any]): |
54 | | - if not YUANRONG_DATASYSTEM_IMPORTED: |
55 | | - raise ImportError("YuanRong DataSystem not installed.") |
56 | | - |
57 | | - self._strategies: list[StorageStrategy] = [] |
58 | | - for strategy_cls in [DsTensorClientAdapter, KVClientAdapter]: |
59 | | - strategy = strategy_cls.init(config) |
60 | | - if strategy is not None: |
61 | | - self._strategies.append(strategy) |
62 | | - |
63 | | - if not self._strategies: |
64 | | - raise RuntimeError("No storage strategy available for YuanrongStorageClient") |
65 | | - |
66 | | - def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: |
67 | | - """Stores multiple key-value pairs to remote storage. |
68 | | -
|
69 | | - Automatically routes NPU tensors to high-performance tensor storage, |
70 | | - and other objects to general-purpose KV storage. |
71 | | -
|
72 | | - Args: |
73 | | - keys (List[str]): List of unique string identifiers. |
74 | | - values (List[Any]): List of values to store (tensors, scalars, dicts, etc.). |
75 | | - """ |
76 | | - if not isinstance(keys, list) or not isinstance(values, list): |
77 | | - raise ValueError("keys and values must be lists") |
78 | | - if len(keys) != len(values): |
79 | | - raise ValueError("Number of keys must match number of values") |
80 | | - custom_metas = [] |
81 | | - strategy_batches: dict[StorageStrategy, tuple[list[str], list[Any]]] = {s: ([], []) for s in self._strategies} |
82 | | - |
83 | | - for key, value in zip(keys, values, strict=True): |
84 | | - for strategy in self._strategies: |
85 | | - if strategy.supports_put(value): |
86 | | - custom_metas.append(strategy.custom_meta()) |
87 | | - strategy_batches[strategy][0].append(key) |
88 | | - strategy_batches[strategy][1].append(value) |
89 | | - break |
90 | | - else: |
91 | | - raise ValueError(f"No strategy supports putting value of type {type(value)}") |
92 | | - # Todo: Parallel put |
93 | | - for strategy, (s_keys, s_vals) in strategy_batches.items(): |
94 | | - if s_keys: |
95 | | - strategy.put(s_keys, s_vals) |
96 | | - |
97 | | - return custom_metas |
98 | | - |
99 | | - def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: |
100 | | - """Retrieves multiple values from remote storage with expected metadata. |
101 | | -
|
102 | | - Requires shape and dtype hints to reconstruct NPU tensors correctly. |
103 | | -
|
104 | | - Args: |
105 | | - keys (List[str]): Keys to fetch. |
106 | | - shapes (List[List[int]]): Expected tensor shapes (use [] for scalars). |
107 | | - dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data. |
108 | | - custom_meta (List[str], optional): Device type (npu/cpu) for each key |
109 | | -
|
110 | | - Returns: |
111 | | - List[Any]: Retrieved values in the same order as input keys. |
112 | | - """ |
113 | | - if shapes is None or dtypes is None: |
114 | | - raise ValueError("YuanrongStorageClient needs Expected shapes and dtypes") |
115 | | - if not (len(keys) == len(shapes) == len(dtypes)): |
116 | | - raise ValueError("Lengths of keys, shapes, dtypes must match") |
117 | | - |
118 | | - if custom_meta is None: |
119 | | - raise ValueError("custom_meta is required for YuanrongStorageClient.get()") |
120 | | - |
121 | | - if len(custom_meta) != len(keys): |
122 | | - raise ValueError("custom_meta length must match keys") |
123 | | - |
124 | | - results: list[Optional[Any]] = [None] * len(keys) |
125 | | - |
126 | | - # {strategy: ([index], [key])} |
127 | | - strategy_batches: dict[StorageStrategy, tuple[list[int], list[str]]] = {s: ([], []) for s in self._strategies} |
128 | | - |
129 | | - for i, (key, meta) in enumerate(zip(keys, custom_meta, strict=True)): |
130 | | - for strategy in self._strategies: |
131 | | - if strategy.supports_get(meta): |
132 | | - strategy_batches[strategy][0].append(i) |
133 | | - strategy_batches[strategy][1].append(key) |
134 | | - break |
135 | | - else: |
136 | | - raise ValueError(f"No strategy supports getting with meta={meta}") |
137 | | - # Todo: Parallel get |
138 | | - for strategy, (indices, s_keys) in strategy_batches.items(): |
139 | | - s_shapes = [shapes[i] for i in indices] |
140 | | - s_dtypes = [dtypes[i] for i in indices] |
141 | | - |
142 | | - try: |
143 | | - s_results = strategy.get(s_keys, shapes=s_shapes, dtypes=s_dtypes) |
144 | | - except Exception as e: |
145 | | - logger.error(f"Strategy {strategy.custom_meta()} failed to get keys: {s_keys}, error: {e}") |
146 | | - raise |
147 | | - |
148 | | - for idx, res in zip(indices, s_results, strict=True): |
149 | | - results[idx] = res |
150 | | - |
151 | | - return results |
152 | | - |
153 | | - def clear(self, keys: list[str]): |
154 | | - """Deletes multiple keys from remote storage. |
155 | | -
|
156 | | - Args: |
157 | | - keys (List[str]): List of keys to remove. |
158 | | - """ |
159 | | - pass |
160 | | - |
161 | | - |
162 | 43 | class StorageStrategy(ABC): |
163 | 44 | @abstractmethod |
164 | 45 | @staticmethod |
@@ -336,12 +217,10 @@ def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]: |
336 | 217 | def supports_clear(self, custom_meta: str) -> bool: |
337 | 218 | return isinstance(custom_meta, str) and custom_meta == self.custom_meta() |
338 | 219 |
|
339 | | - # Todo(wenlin): Add clear_buffer method |
340 | 220 | def clear(self, keys: list[str]): |
341 | | - pass |
342 | | - # for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT): |
343 | | - # batch = keys[i : i + self.GET_CLEAR_KEYS_LIMIT] |
344 | | - # self._ds_client.delete(batch) |
| 221 | + for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT): |
| 222 | + batch = keys[i : i + self.GET_CLEAR_KEYS_LIMIT] |
| 223 | + self._ds_client.delete(batch) |
345 | 224 |
|
346 | 225 | @staticmethod |
347 | 226 | def calc_packed_size(items: list[memoryview]) -> int: |
@@ -438,3 +317,134 @@ def mget_zero_copy(self, keys: list[str]) -> list[Any]: |
438 | 317 | """ |
439 | 318 | buffers = self._ds_client.get_buffers(keys) |
440 | 319 | return [_decoder.decode(self.unpack_from(buffer)) if buffer is not None else None for buffer in buffers] |
| 320 | + |
| 321 | + |
| 322 | +@StorageClientFactory.register("YuanrongStorageClient") |
| 323 | +class YuanrongStorageClient(TransferQueueStorageKVClient): |
| 324 | + """ |
| 325 | + Storage client for YuanRong DataSystem. |
| 326 | +
|
| 327 | + Supports storing and fetching both: |
| 328 | + - NPU tensors via DsTensorClient (for high performance). |
| 329 | + - General objects (CPU tensors, str, bool, list, etc.) via KVClient with pickle serialization. |
| 330 | + """ |
| 331 | + |
| 332 | + def __init__(self, config: dict[str, Any]): |
| 333 | + if not YUANRONG_DATASYSTEM_IMPORTED: |
| 334 | + raise ImportError("YuanRong DataSystem not installed.") |
| 335 | + |
| 336 | + self._strategies: list[StorageStrategy] = [] |
| 337 | + for strategy_cls in [DsTensorClientAdapter, KVClientAdapter]: |
| 338 | + strategy = strategy_cls.init(config) |
| 339 | + if strategy is not None: |
| 340 | + self._strategies.append(strategy) |
| 341 | + |
| 342 | + if not self._strategies: |
| 343 | + raise RuntimeError("No storage strategy available for YuanrongStorageClient") |
| 344 | + |
| 345 | + def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]: |
| 346 | + """Stores multiple key-value pairs to remote storage. |
| 347 | +
|
| 348 | + Automatically routes NPU tensors to high-performance tensor storage, |
| 349 | + and other objects to general-purpose KV storage. |
| 350 | +
|
| 351 | + Args: |
| 352 | + keys (List[str]): List of unique string identifiers. |
| 353 | + values (List[Any]): List of values to store (tensors, scalars, dicts, etc.). |
| 354 | +
|
| 355 | + Returns: |
| 356 | + List[Any]: custom metadata of YuanrongStorageClient in the same order as input keys. |
| 357 | + """ |
| 358 | + if not isinstance(keys, list) or not isinstance(values, list): |
| 359 | + raise ValueError("keys and values must be lists") |
| 360 | + if len(keys) != len(values): |
| 361 | + raise ValueError("Number of keys must match number of values") |
| 362 | + |
| 363 | + routed_indexes = self._route_to_strategies(values, lambda strategy_, item_: strategy_.supports_put(item_)) |
| 364 | + custom_metas = [None] * len(keys) |
| 365 | + for strategy, indexes in routed_indexes.items(): |
| 366 | + if not indexes: |
| 367 | + continue |
| 368 | + strategy_keys = [keys[i] for i in indexes] |
| 369 | + strategy_values = [values[i] for i in indexes] |
| 370 | + strategy.put(strategy_keys, strategy_values) |
| 371 | + for i in indexes: |
| 372 | + custom_metas[i] = strategy.custom_meta() |
| 373 | + |
| 374 | + return custom_metas |
| 375 | + |
| 376 | + def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]: |
| 377 | + """Retrieves multiple values from remote storage with expected metadata. |
| 378 | +
|
| 379 | + Requires shape and dtype hints to reconstruct NPU tensors correctly. |
| 380 | +
|
| 381 | + Args: |
| 382 | + keys (List[str]): Keys to fetch. |
| 383 | + shapes (List[List[int]]): Expected tensor shapes (use [] for scalars). |
| 384 | + dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data. |
| 385 | + custom_meta (List[str]): Device type (npu/cpu) for each key. Required: a ValueError is raised when it is None. |
| 386 | +
|
| 387 | + Returns: |
| 388 | + List[Any]: Retrieved values in the same order as input keys. |
| 389 | + """ |
| 390 | + if shapes is None or dtypes is None: |
| 391 | + raise ValueError("YuanrongStorageClient needs Expected shapes and dtypes") |
| 392 | + if not (len(keys) == len(shapes) == len(dtypes)): |
| 393 | + raise ValueError("Lengths of keys, shapes, dtypes must match") |
| 394 | + |
| 395 | + if custom_meta is None: |
| 396 | + raise ValueError("custom_meta is required for YuanrongStorageClient.get()") |
| 397 | + |
| 398 | + if len(custom_meta) != len(keys): |
| 399 | + raise ValueError("custom_meta length must match keys") |
| 400 | + |
| 401 | + routed_indexes = self._route_to_strategies(custom_meta, lambda strategy_, item_: strategy_.supports_get(item_)) |
| 402 | + |
| 403 | + # Todo(dpj): Parallel get |
| 404 | + results = [None] * len(keys) |
| 405 | + for strategy, indexes in routed_indexes.items(): |
| 406 | + if not indexes: |
| 407 | + continue |
| 408 | + strategy_keys = [keys[i] for i in indexes] |
| 409 | + strategy_shapes = [shapes[i] for i in indexes] |
| 410 | + strategy_dtypes = [dtypes[i] for i in indexes] |
| 411 | + strategy_results = strategy.get(strategy_keys, shapes=strategy_shapes, dtypes=strategy_dtypes) |
| 412 | + for j, i in enumerate(indexes): |
| 413 | + results[i] = strategy_results[j] |
| 414 | + |
| 415 | + return results |
| 416 | + |
| 417 | + def clear(self, keys: list[str]): |
| 418 | + """Deletes multiple keys from remote storage. |
| 419 | +
|
| 420 | + Args: |
| 421 | + keys (List[str]): List of keys to remove. |
| 422 | + """ |
| 423 | + pass |
| 424 | + |
| 425 | + def _route_to_strategies( |
| 426 | + self, |
| 427 | + items: list[Any], |
| 428 | + selector: Callable[[StorageStrategy, Any], bool], |
| 429 | + ) -> dict[StorageStrategy, list[int]]: |
| 430 | + """ |
| 431 | + Groups item indices by storage strategy. |
| 432 | +
|
| 433 | + Args: |
| 434 | + items: A list of items (e.g., values or custom_meta strings) to be dispatched. |
| 435 | + The order must correspond to the original keys. |
| 436 | + selector: A function that determines whether a strategy supports an item. |
| 437 | + Signature: (strategy, item) -> bool |
| 438 | +
|
| 439 | + Returns: |
| 440 | + A dictionary mapping each strategy to a list of indices in `items`. |
| 441 | + """ |
| 442 | + routed_indexes: dict[StorageStrategy, list[int]] = {s: [] for s in self._strategies} |
| 443 | + for i, item in enumerate(items): |
| 444 | + for strategy in self._strategies: |
| 445 | + if selector(strategy, item): |
| 446 | + routed_indexes[strategy].append(i) |
| 447 | + break |
| 448 | + else: |
| 449 | + raise ValueError(f"No strategy supports item: {item}") |
| 450 | + return routed_indexes |
0 commit comments