Skip to content

Commit acd7686

Browse files
committed
Added custom_meta to clear for all TransferQueueKVStorageClient
Signed-off-by: dpj135 <958208521@qq.com>
1 parent d2ebb88 commit acd7686

5 files changed

Lines changed: 37 additions & 14 deletions

File tree

transfer_queue/storage/clients/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
4444
raise NotImplementedError("Subclasses must implement put")
4545

4646
@abstractmethod
47-
def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]:
47+
def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta: Optional[list[Any]] = None) -> list[Any]:
4848
"""
4949
Retrieve values from the storage backend by key.
5050
Args:
@@ -65,6 +65,6 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> li
6565
raise NotImplementedError("Subclasses must implement get")
6666

6767
@abstractmethod
68-
def clear(self, keys: list[str]) -> None:
68+
def clear(self, keys: list[str], custom_meta: Optional[list[Any]] = None) -> None:
6969
"""Clear key-value pairs in the storage backend."""
7070
raise NotImplementedError("Subclasses must implement clear")

transfer_queue/storage/clients/mooncake_client.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> li
139139
keys (List[str]): Keys to fetch.
140140
shapes (List[List[int]]): Expected tensor shapes (use [] for scalars).
141141
dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data.
142-
custom_meta (List[str], optional): Device type (npu/cpu) for each key
142+
custom_meta (List[Any], optional): ...
143143
144144
Returns:
145145
List[Any]: Retrieved values in the same order as input keys.
@@ -216,11 +216,12 @@ def _batch_get_bytes(self, keys: list[str]) -> list[bytes]:
216216
results.extend(batch_results)
217217
return results
218218

219-
def clear(self, keys: list[str]):
219+
def clear(self, keys: list[str], custom_meta=None):
220220
"""Deletes multiple keys from MooncakeStore.
221221
222222
Args:
223223
keys (List[str]): List of keys to remove.
224+
custom_meta (List[Any], optional): ...
224225
"""
225226
for key in keys:
226227
ret = self._store.remove(key)

transfer_queue/storage/clients/ray_storage_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,11 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> li
106106
raise RuntimeError(f"Failed to retrieve value for key '{keys}': {e}") from e
107107
return values
108108

109-
def clear(self, keys: list[str]):
109+
def clear(self, keys: list[str], custom_meta=None):
110110
"""
111111
Delete entries from storage by keys.
112112
Args:
113113
keys (list): List of keys to delete
114+
custom_meta (List[Any], optional): ...
114115
"""
115116
ray.get(self.storage_actor.clear_obj_ref.remote(keys))

transfer_queue/storage/clients/yuanrong_client.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -219,8 +219,8 @@ def supports_clear(self, custom_meta: str) -> bool:
219219

220220
def clear(self, keys: list[str]):
221221
for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT):
222-
batch = keys[i : i + self.GET_CLEAR_KEYS_LIMIT]
223-
self._ds_client.delete(batch)
222+
batch_keys = keys[i : i + self.GET_CLEAR_KEYS_LIMIT]
223+
self._ds_client.delete(batch_keys)
224224

225225
@staticmethod
226226
def calc_packed_size(items: list[memoryview]) -> int:
@@ -342,7 +342,7 @@ def __init__(self, config: dict[str, Any]):
342342
if not self._strategies:
343343
raise RuntimeError("No storage strategy available for YuanrongStorageClient")
344344

345-
def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
345+
def put(self, keys: list[str], values: list[Any]) -> list[str]:
346346
"""Stores multiple key-value pairs to remote storage.
347347
348348
Automatically routes NPU tensors to high-performance tensor storage,
@@ -353,15 +353,17 @@ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
353353
values (List[Any]): List of values to store (tensors, scalars, dicts, etc.).
354354
355355
Returns:
356-
List[Any]: custom metadata of YuanrongStorageCilent in the same order as input keys.
356+
List[str]: custom metadata of YuanrongStorageCilent in the same order as input keys.
357357
"""
358358
if not isinstance(keys, list) or not isinstance(values, list):
359359
raise ValueError("keys and values must be lists")
360360
if len(keys) != len(values):
361361
raise ValueError("Number of keys must match number of values")
362362

363363
routed_indexes = self._route_to_strategies(values, lambda strategy_, item_: strategy_.supports_put(item_))
364-
custom_metas = [None] * len(keys)
364+
custom_metas: list[str] = [""] * len(keys)
365+
366+
# Todo(dpj): Parallel put
365367
for strategy, indexes in routed_indexes.items():
366368
if not indexes:
367369
continue
@@ -382,7 +384,7 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> li
382384
keys (List[str]): Keys to fetch.
383385
shapes (List[List[int]]): Expected tensor shapes (use [] for scalars).
384386
dtypes (List[Optional[torch.dtype]]): Expected dtypes; use None for non-tensor data.
385-
custom_meta (List[str], optional): Device type (npu/cpu) for each key
387+
custom_meta (List[str]): StorageStrategy type for each key
386388
387389
Returns:
388390
List[Any]: Retrieved values in the same order as input keys.
@@ -414,13 +416,29 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> li
414416

415417
return results
416418

417-
def clear(self, keys: list[str]):
419+
def clear(self, keys: list[str], custom_meta=None):
418420
"""Deletes multiple keys from remote storage.
419421
420422
Args:
421423
keys (List[str]): List of keys to remove.
424+
custom_meta (List[str]): StorageStrategy type for each key
422425
"""
423-
pass
426+
if not isinstance(keys, list):
427+
raise ValueError("keys must be a list")
428+
if not isinstance(custom_meta, list):
429+
raise ValueError("custom_meta must be a list if provided")
430+
if len(custom_meta) != len(keys):
431+
raise ValueError("custom_meta length must match keys")
432+
433+
routed_indexes = self._route_to_strategies(
434+
custom_meta, lambda strategy_, item_: strategy_.supports_clear(item_)
435+
)
436+
# Todo(dpj): Parallel clear
437+
for strategy, indexes in routed_indexes.items():
438+
if not indexes:
439+
continue
440+
strategy_keys = [keys[i] for i in indexes]
441+
strategy.clear(strategy_keys)
424442

425443
def _route_to_strategies(
426444
self,

transfer_queue/storage/managers/base.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,8 @@ async def put_data(self, data: TensorDict, metadata: BatchMeta) -> None:
553553
keys = self._generate_keys(data.keys(), metadata.global_indexes)
554554
values = self._generate_values(data)
555555
loop = asyncio.get_event_loop()
556+
557+
# put <keys, values> to storage backends
556558
custom_meta = await loop.run_in_executor(None, self.storage_client.put, keys, values)
557559

558560
per_field_dtypes: dict[int, dict[str, Any]] = {}
@@ -628,4 +630,5 @@ async def clear_data(self, metadata: BatchMeta) -> None:
628630
logger.warning("Attempted to clear data, but metadata contains no fields.")
629631
return
630632
keys = self._generate_keys(metadata.field_names, metadata.global_indexes)
631-
self.storage_client.clear(keys=keys)
633+
_, _, custom_meta = self._get_shape_type_custom_meta_list(metadata)
634+
self.storage_client.clear(keys=keys, custom_meta=custom_meta)

0 commit comments

Comments
 (0)