@@ -45,7 +45,7 @@ class StorageStrategy(ABC):
4545
4646 @staticmethod
4747 @abstractmethod
48- def init (config : dict ) -> Union ["StorageStrategy" , None ]:
48+ def init (config : dict ) -> Optional ["StorageStrategy" ]:
4949 """Initialize strategy from config; return None if not applicable."""
5050
5151 @abstractmethod
@@ -115,10 +115,8 @@ def supports_put(self, value: Any) -> bool:
115115 """Supports contiguous NPU tensors only."""
116116 if not (isinstance (value , torch .Tensor ) and value .device .type == "npu" ):
117117 return False
118- # Todo(dpj): perhaps KVClient can process uncontiguous tensor
119- if not value .is_contiguous ():
120- raise ValueError (f"NPU Tensor is not contiguous: { value } " )
121- return True
118+ # Only contiguous NPU tensors are supported by this adapter.
119+ return value .is_contiguous ()
122120
123121 def put (self , keys : list [str ], values : list [Any ]):
124122 """Store NPU tensors in batches; deletes before overwrite."""
@@ -150,10 +148,7 @@ def get(self, keys: list[str], **kwargs) -> list[Optional[Any]]:
150148
151149 batch_values = self ._create_empty_npu_tensorlist (batch_shapes , batch_dtypes )
152150 self ._ds_client .dev_mget (batch_keys , batch_values )
153- # Todo(dpj): should we check failed keys?
154- # failed_keys = self._ds_client.dev_mget(batch_keys, batch_values)
155- # if failed_keys:
156- # logging.warning(f"YuanrongStorageClient: Querying keys using 'DsTensorClient' failed: {failed_keys}")
151+ # Todo(dpj): consider checking and logging keys that fail during dev_mget
157152 results .extend (batch_values )
158153 return results
159154
@@ -212,7 +207,7 @@ def __init__(self, config: dict):
212207 logger .info ("YuanrongStorageClient: Create KVClient to connect with yuanrong-datasystem backend!" )
213208
214209 @staticmethod
215- def init (config : dict ) -> Union ["StorageStrategy" , None ]:
210+ def init (config : dict ) -> Optional ["StorageStrategy" ]:
216211 """Always enabled for general objects."""
217212 return KVClientAdapter (config )
218213
@@ -374,7 +369,7 @@ def __init__(self, config: dict[str, Any]):
374369 if not self ._strategies :
375370 raise RuntimeError ("No storage strategy available for YuanrongStorageClient" )
376371
377- def put (self , keys : list [str ], values : list [Any ]) -> list [str ]:
372+ def put (self , keys : list [str ], values : list [Any ]) -> Optional [ list [Any ] ]:
378373 """Stores multiple key-value pairs to remote storage.
379374
380375 Automatically routes NPU tensors to high-performance tensor storage,
@@ -385,7 +380,7 @@ def put(self, keys: list[str], values: list[Any]) -> list[str]:
385380 values (List[Any]): List of values to store (tensors, scalars, dicts, etc.).
386381
387382 Returns:
388- List[str]: custom metadata of YuanrongStorageCilent in the same order as input keys.
383+ List[str]: custom metadata of YuanrongStorageClient in the same order as input keys.
389384 """
390385 if not isinstance (keys , list ) or not isinstance (values , list ):
391386 raise ValueError ("keys and values must be lists" )
@@ -492,7 +487,10 @@ def _route_to_strategies(
492487 routed_indexes [strategy ].append (i )
493488 break
494489 else :
495- raise ValueError (f"No strategy supports item: { item } " )
490+ raise ValueError (
491+ f"No strategy supports item of type { type (item ).__name__ } : { item } . "
492+ f"Available strategies: { [type (s ).__name__ for s in self ._strategies ]} "
493+ )
496494 return routed_indexes
497495
498496 @staticmethod
@@ -520,7 +518,10 @@ def _dispatch_tasks(routed_tasks: dict[StorageStrategy, list[int]], task_functio
520518 return [task_function (* active_tasks [0 ])]
521519
522520 # Parallel path: overlap NPU and CPU operations
523- with ThreadPoolExecutor (max_workers = len (active_tasks )) as executor :
521+ # Cap the number of worker threads to avoid resource exhaustion if many
522+ # strategies are added in the future.
523+ max_workers = min (len (active_tasks ), 4 )
524+ with ThreadPoolExecutor (max_workers = max_workers ) as executor :
524525 # futures' results are from task_function
525526 futures = [executor .submit (task_function , strategy , indexes ) for strategy , indexes in active_tasks ]
526527 return [f .result () for f in futures ]
0 commit comments