Skip to content

Commit 2aa3955

Browse files
committed
Renamed adapter classes & rename 'custom_meta()' to 'strategy_tag()' & adjusted annotation related to 'custom_meta()'
Signed-off-by: dpj135 <958208521@qq.com>
1 parent 37cf9d9 commit 2aa3955

1 file changed

Lines changed: 43 additions & 43 deletions

File tree

transfer_queue/storage/clients/yuanrong_client.py

Lines changed: 43 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def init(config: dict) -> Optional["StorageStrategy"]:
4949
"""Initialize strategy from config; return None if not applicable."""
5050

5151
@abstractmethod
52-
def custom_meta(self) -> Any:
53-
"""Return metadata identifying this strategy (e.g., string name)."""
52+
def strategy_tag(self) -> Any:
53+
"""Return metadata identifying this strategy (e.g., string name, byte tag)."""
5454

5555
@abstractmethod
5656
def supports_put(self, value: Any) -> bool:
@@ -77,7 +77,7 @@ def clear(self, keys: list[str]):
7777
"""Delete keys from storage."""
7878

7979

80-
class DsTensorClientAdapter(StorageStrategy):
80+
class NPUTensorKVClientAdapter(StorageStrategy):
8181
"""Adapter for YuanRong's high-performance NPU tensor storage."""
8282

8383
KEYS_LIMIT: int = 10_000
@@ -105,11 +105,11 @@ def init(config: dict) -> Union["StorageStrategy", None]:
105105
if not (enable and torch_npu_imported and torch.npu.is_available()):
106106
return None
107107

108-
return DsTensorClientAdapter(config)
108+
return NPUTensorKVClientAdapter(config)
109109

110-
def custom_meta(self) -> Any:
111-
"""Metadata tag for NPU tensor storage."""
112-
return "DsTensorClient"
110+
def strategy_tag(self) -> bytes:
111+
"""Strategy tag for NPU tensor storage. Using a single byte is for better performance."""
112+
return b"\x01"
113113

114114
def supports_put(self, value: Any) -> bool:
115115
"""Supports contiguous NPU tensors only."""
@@ -131,8 +131,8 @@ def put(self, keys: list[str], values: list[Any]):
131131
self._ds_client.dev_mset(batch_keys, batch_values)
132132

133133
def supports_get(self, custom_meta: str) -> bool:
134-
"""Matches 'DsTensorClient' custom metadata."""
135-
return isinstance(custom_meta, str) and custom_meta == self.custom_meta()
134+
"""Matches 'DsTensorClient' Strategy tag."""
135+
return isinstance(custom_meta, bytes) and custom_meta == self.strategy_tag()
136136

137137
def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]:
138138
"""Fetch NPU tensors using pre-allocated empty buffers."""
@@ -153,8 +153,8 @@ def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]:
153153
return results
154154

155155
def supports_clear(self, custom_meta: str) -> bool:
156-
"""Matches 'DsTensorClient' metadata."""
157-
return isinstance(custom_meta, str) and custom_meta == self.custom_meta()
156+
"""Matches 'DsTensorClient' strategy tag."""
157+
return isinstance(custom_meta, bytes) and custom_meta == self.strategy_tag()
158158

159159
def clear(self, keys: list[str]):
160160
"""Delete NPU tensor keys in batches."""
@@ -180,7 +180,7 @@ def _create_empty_npu_tensorlist(self, shapes, dtypes):
180180
return tensors
181181

182182

183-
class KVClientAdapter(StorageStrategy):
183+
class GeneralKVClientAdapter(StorageStrategy):
184184
"""
185185
Adapter for general-purpose KV storage with serialization.
186186
The serialization method uses '_decoder' and '_encoder' from 'transfer_queue.utils.serial_utils'.
@@ -209,11 +209,11 @@ def __init__(self, config: dict):
209209
@staticmethod
210210
def init(config: dict) -> Optional["StorageStrategy"]:
211211
"""Always enabled for general objects."""
212-
return KVClientAdapter(config)
212+
return GeneralKVClientAdapter(config)
213213

214-
def custom_meta(self) -> Any:
215-
"""Metadata tag for general KV storage."""
216-
return "KVClient"
214+
def strategy_tag(self) -> bytes:
215+
"""Strategy tag for general KV storage. Using a single byte is for better performance."""
216+
return b"\x02"
217217

218218
def supports_put(self, value: Any) -> bool:
219219
"""Accepts any Python object."""
@@ -227,8 +227,8 @@ def put(self, keys: list[str], values: list[Any]):
227227
self.mset_zero_copy(batch_keys, batch_vals)
228228

229229
def supports_get(self, custom_meta: str) -> bool:
230-
"""Matches 'KVClient' metadata."""
231-
return isinstance(custom_meta, str) and custom_meta == self.custom_meta()
230+
"""Matches 'KVClient' strategy tag."""
231+
return isinstance(custom_meta, bytes) and custom_meta == self.strategy_tag()
232232

233233
def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]:
234234
"""Retrieve and deserialize objects in batches."""
@@ -240,17 +240,17 @@ def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]:
240240
return results
241241

242242
def supports_clear(self, custom_meta: str) -> bool:
243-
"""Matches 'KVClient' metadata."""
244-
return isinstance(custom_meta, str) and custom_meta == self.custom_meta()
243+
"""Matches 'KVClient' strategy tag."""
244+
return isinstance(custom_meta, bytes) and custom_meta == self.strategy_tag()
245245

246246
def clear(self, keys: list[str]):
247247
"""Delete keys in batches."""
248248
for i in range(0, len(keys), self.GET_CLEAR_KEYS_LIMIT):
249249
batch_keys = keys[i : i + self.GET_CLEAR_KEYS_LIMIT]
250250
self._ds_client.delete(batch_keys)
251251

252-
@staticmethod
253-
def calc_packed_size(items: list[memoryview]) -> int:
252+
@classmethod
253+
def calc_packed_size(cls, items: list[memoryview]) -> int:
254254
"""
255255
Calculate the total size (in bytes) required to pack a list of memoryview items
256256
into the structured binary format used by pack_into.
@@ -261,12 +261,10 @@ def calc_packed_size(items: list[memoryview]) -> int:
261261
Returns:
262262
Total buffer size in bytes.
263263
"""
264-
return (
265-
KVClientAdapter.HEADER_SIZE + len(items) * KVClientAdapter.ENTRY_SIZE + sum(item.nbytes for item in items)
266-
)
264+
return cls.HEADER_SIZE + len(items) * cls.ENTRY_SIZE + sum(item.nbytes for item in items)
267265

268-
@staticmethod
269-
def pack_into(target: memoryview, items: list[memoryview]):
266+
@classmethod
267+
def pack_into(cls, target: memoryview, items: list[memoryview]):
270268
"""
271269
Pack multiple contiguous buffers into a single buffer.
272270
┌───────────────┐
@@ -285,22 +283,22 @@ def pack_into(target: memoryview, items: list[memoryview]):
285283
Each item must support the buffer protocol and be readable as raw bytes.
286284
287285
"""
288-
struct.pack_into(KVClientAdapter.HEADER_FMT, target, 0, len(items))
286+
struct.pack_into(cls.HEADER_FMT, target, 0, len(items))
289287

290-
entry_offset = KVClientAdapter.HEADER_SIZE
291-
payload_offset = KVClientAdapter.HEADER_SIZE + len(items) * KVClientAdapter.ENTRY_SIZE
288+
entry_offset = cls.HEADER_SIZE
289+
payload_offset = cls.HEADER_SIZE + len(items) * cls.ENTRY_SIZE
292290

293291
target_tensor = torch.frombuffer(target, dtype=torch.uint8)
294292

295293
for item in items:
296-
struct.pack_into(KVClientAdapter.ENTRY_FMT, target, entry_offset, payload_offset, item.nbytes)
294+
struct.pack_into(cls.ENTRY_FMT, target, entry_offset, payload_offset, item.nbytes)
297295
src_tensor = torch.frombuffer(item, dtype=torch.uint8)
298296
target_tensor[payload_offset : payload_offset + item.nbytes].copy_(src_tensor)
299-
entry_offset += KVClientAdapter.ENTRY_SIZE
297+
entry_offset += cls.ENTRY_SIZE
300298
payload_offset += item.nbytes
301299

302-
@staticmethod
303-
def unpack_from(source: memoryview) -> list[memoryview]:
300+
@classmethod
301+
def unpack_from(cls, source: memoryview) -> list[memoryview]:
304302
"""
305303
Unpack multiple contiguous buffers from a single packed buffer.
306304
Args:
@@ -309,12 +307,10 @@ def unpack_from(source: memoryview) -> list[memoryview]:
309307
list[memoryview]: List of unpacked contiguous buffers.
310308
"""
311309
mv = memoryview(source)
312-
item_count = struct.unpack_from(KVClientAdapter.HEADER_FMT, mv, 0)[0]
310+
item_count = struct.unpack_from(cls.HEADER_FMT, mv, 0)[0]
313311
offsets = []
314312
for i in range(item_count):
315-
offset, length = struct.unpack_from(
316-
KVClientAdapter.ENTRY_FMT, mv, KVClientAdapter.HEADER_SIZE + i * KVClientAdapter.ENTRY_SIZE
317-
)
313+
offset, length = struct.unpack_from(cls.ENTRY_FMT, mv, cls.HEADER_SIZE + i * cls.ENTRY_SIZE)
318314
offsets.append((offset, length))
319315
return [mv[offset : offset + length] for offset, length in offsets]
320316

@@ -351,17 +347,21 @@ class YuanrongStorageClient(TransferQueueStorageKVClient):
351347
"""
352348
Storage client for YuanRong DataSystem.
353349
350+
Use different storage strategies depending on the data type.
354351
Supports storing and fetching both:
355-
- NPU tensors via DsTensorClient (for high performance).
356-
- General objects (CPU tensors, str, bool, list, etc.) via KVClient with serialization.
352+
- NPU tensors via NPUTensorKVClientAdapter (for high performance).
353+
- General objects (CPU tensors, str, bool, list, etc.) via GeneralKVClientAdapter with serialization.
357354
"""
358355

359356
def __init__(self, config: dict[str, Any]):
360357
if not YUANRONG_DATASYSTEM_IMPORTED:
361358
raise ImportError("YuanRong DataSystem not installed.")
362359

360+
# Storage strategies are prioritized in ascending order of list element index.
361+
# In other words, the later in the order, the lower the priority.
362+
storage_strategies_priority = [NPUTensorKVClientAdapter, GeneralKVClientAdapter]
363363
self._strategies: list[StorageStrategy] = []
364-
for strategy_cls in [DsTensorClientAdapter, KVClientAdapter]:
364+
for strategy_cls in storage_strategies_priority:
365365
strategy = strategy_cls.init(config)
366366
if strategy is not None:
367367
self._strategies.append(strategy)
@@ -380,7 +380,7 @@ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
380380
values (List[Any]): List of values to store (tensors, scalars, dicts, etc.).
381381
382382
Returns:
383-
List[str]: custom metadata of YuanrongStorageClient in the same order as input keys.
383+
List[str]: storage strategy tag of YuanrongStorageClient in the same order as input keys.
384384
"""
385385
if not isinstance(keys, list) or not isinstance(values, list):
386386
raise ValueError("keys and values must be lists")
@@ -393,7 +393,7 @@ def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
393393
# The closure captures local 'keys' and 'values' for zero-overhead parameter passing.
394394
def put_task(strategy, indexes):
395395
strategy.put([keys[i] for i in indexes], [values[i] for i in indexes])
396-
return strategy.custom_meta(), indexes
396+
return strategy.strategy_tag(), indexes
397397

398398
# Dispatch tasks and map metadata back to original positions
399399
custom_meta: list[str] = [""] * len(keys)

0 commit comments

Comments
 (0)