Skip to content

Commit 37cf9d9

Browse files
dpj135 and Copilot committed
Apply suggestions from code review
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: dpj135 <958208521@qq.com>
1 parent e1661a2 commit 37cf9d9

1 file changed

Lines changed: 15 additions & 14 deletions

File tree

transfer_queue/storage/clients/yuanrong_client.py

Lines changed: 15 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,7 @@ class StorageStrategy(ABC):
4545

4646
@staticmethod
4747
@abstractmethod
48-
def init(config: dict) -> Union["StorageStrategy", None]:
48+
def init(config: dict) -> Optional["StorageStrategy"]:
4949
"""Initialize strategy from config; return None if not applicable."""
5050

5151
@abstractmethod
@@ -115,10 +115,8 @@ def supports_put(self, value: Any) -> bool:
115115
"""Supports contiguous NPU tensors only."""
116116
if not (isinstance(value, torch.Tensor) and value.device.type == "npu"):
117117
return False
118-
# Todo(dpj): perhaps KVClient can process uncontiguous tensor
119-
if not value.is_contiguous():
120-
raise ValueError(f"NPU Tensor is not contiguous: {value}")
121-
return True
118+
# Only contiguous NPU tensors are supported by this adapter.
119+
return value.is_contiguous()
122120

123121
def put(self, keys: list[str], values: list[Any]):
124122
"""Store NPU tensors in batches; deletes before overwrite."""
@@ -150,10 +148,7 @@ def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]:
150148

151149
batch_values = self._create_empty_npu_tensorlist(batch_shapes, batch_dtypes)
152150
self._ds_client.dev_mget(batch_keys, batch_values)
153-
# Todo(dpj): should we check failed keys?
154-
# failed_keys = self._ds_client.dev_mget(batch_keys, batch_values)
155-
# if failed_keys:
156-
# logging.warning(f"YuanrongStorageClient: Querying keys using 'DsTensorClient' failed: {failed_keys}")
151+
# Todo(dpj): consider checking and logging keys that fail during dev_mget
157152
results.extend(batch_values)
158153
return results
159154

@@ -212,7 +207,7 @@ def __init__(self, config: dict):
212207
logger.info("YuanrongStorageClient: Create KVClient to connect with yuanrong-datasystem backend!")
213208

214209
@staticmethod
215-
def init(config: dict) -> Union["StorageStrategy", None]:
210+
def init(config: dict) -> Optional["StorageStrategy"]:
216211
"""Always enabled for general objects."""
217212
return KVClientAdapter(config)
218213

@@ -374,7 +369,7 @@ def __init__(self, config: dict[str, Any]):
374369
if not self._strategies:
375370
raise RuntimeError("No storage strategy available for YuanrongStorageClient")
376371

377-
def put(self, keys: list[str], values: list[Any]) -> list[str]:
372+
def put(self, keys: list[str], values: list[Any]) -> Optional[list[Any]]:
378373
"""Stores multiple key-value pairs to remote storage.
379374
380375
Automatically routes NPU tensors to high-performance tensor storage,
@@ -385,7 +380,7 @@ def put(self, keys: list[str], values: list[Any]) -> list[str]:
385380
values (List[Any]): List of values to store (tensors, scalars, dicts, etc.).
386381
387382
Returns:
388-
List[str]: custom metadata of YuanrongStorageCilent in the same order as input keys.
383+
List[str]: custom metadata of YuanrongStorageClient in the same order as input keys.
389384
"""
390385
if not isinstance(keys, list) or not isinstance(values, list):
391386
raise ValueError("keys and values must be lists")
@@ -492,7 +487,10 @@ def _route_to_strategies(
492487
routed_indexes[strategy].append(i)
493488
break
494489
else:
495-
raise ValueError(f"No strategy supports item: {item}")
490+
raise ValueError(
491+
f"No strategy supports item of type {type(item).__name__}: {item}. "
492+
f"Available strategies: {[type(s).__name__ for s in self._strategies]}"
493+
)
496494
return routed_indexes
497495

498496
@staticmethod
@@ -520,7 +518,10 @@ def _dispatch_tasks(routed_tasks: dict[StorageStrategy, list[int]], task_functio
520518
return [task_function(*active_tasks[0])]
521519

522520
# Parallel path: overlap NPU and CPU operations
523-
with ThreadPoolExecutor(max_workers=len(active_tasks)) as executor:
521+
# Cap the number of worker threads to avoid resource exhaustion if many
522+
# strategies are added in the future.
523+
max_workers = min(len(active_tasks), 4)
524+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
524525
# futures' results are from task_function
525526
futures = [executor.submit(task_function, strategy, indexes) for strategy, indexes in active_tasks]
526527
return [f.result() for f in futures]

0 commit comments

Comments
 (0)