@@ -361,19 +361,21 @@ def put(self, keys: list[str], values: list[Any]) -> list[str]:
361361 raise ValueError ("Number of keys must match number of values" )
362362
363363 routed_indexes = self ._route_to_strategies (values , lambda strategy_ , item_ : strategy_ .supports_put (item_ ))
364- custom_metas : list [str ] = ["" ] * len (keys )
365-
366- # Todo(dpj): Parallel put
367- for strategy , indexes in routed_indexes .items ():
368- if not indexes :
369- continue
370- strategy_keys = [keys [i ] for i in indexes ]
371- strategy_values = [values [i ] for i in indexes ]
372- strategy .put (strategy_keys , strategy_values )
373- for i in indexes :
374- custom_metas [i ] = strategy .custom_meta ()
375364
376- return custom_metas
365+ # Define the work unit: Slicing the input list and calling the backend strategy.
366+ # The closure captures local 'keys' and 'values' for zero-overhead parameter passing.
367+ def put_task (strategy , indexes ):
368+ strategy .put ([keys [i ] for i in indexes ], [values [i ] for i in indexes ])
369+ return strategy .custom_meta (), indexes
370+
371+ # Call the orchestrator to run the tasks (Parallel or Sequential).
372+ # We then iterate through the results to map strategy-specific metadata back
373+ # to the original global index order.
374+ custom_meta : list [str ] = ["" ] * len (keys )
375+ for meta_str , indexes in self ._dispatch_tasks (routed_indexes , put_task ):
376+ for i in indexes :
377+ custom_meta [i ] = meta_str
378+ return custom_meta
377379
378380 def get (self , keys : list [str ], shapes = None , dtypes = None , custom_meta = None ) -> list [Any ]:
379381 """Retrieves multiple values from remote storage with expected metadata.
@@ -402,18 +404,20 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> li
402404
403405 routed_indexes = self ._route_to_strategies (custom_meta , lambda strategy_ , item_ : strategy_ .supports_get (item_ ))
404406
405- # Todo(dpj): Parallel get
406- results = [None ] * len (keys )
407- for strategy , indexes in routed_indexes .items ():
408- if not indexes :
409- continue
410- strategy_keys = [keys [i ] for i in indexes ]
411- strategy_shapes = [shapes [i ] for i in indexes ]
412- strategy_dtypes = [dtypes [i ] for i in indexes ]
413- strategy_results = strategy .get (strategy_keys , shapes = strategy_shapes , dtypes = strategy_dtypes )
414- for j , i in enumerate (indexes ):
415- results [i ] = strategy_results [j ]
407+ # Work unit for 'get': handles slicing of keys, shapes, and dtypes simultaneously.
408+ def get_task (strategy , indexes ):
409+ res = strategy .get (
410+ [keys [i ] for i in indexes ], shapes = [shapes [i ] for i in indexes ], dtypes = [dtypes [i ] for i in indexes ]
411+ )
412+ return res , indexes
416413
414+ # Dispatch the 'get' requests. Multiple backends (e.g. NPU and SSD) will fetch
415+ # in parallel if needed. Results are merged back into the 'results' list
416+ # according to their original positions.
417+ results = [None ] * len (keys )
418+ for strategy_res , indexes in self ._dispatch_tasks (routed_indexes , get_task ):
419+ for value , original_index in zip (strategy_res , indexes , strict = True ):
420+ results [original_index ] = value
417421 return results
418422
419423 def clear (self , keys : list [str ], custom_meta = None ):
@@ -433,12 +437,14 @@ def clear(self, keys: list[str], custom_meta=None):
433437 routed_indexes = self ._route_to_strategies (
434438 custom_meta , lambda strategy_ , item_ : strategy_ .supports_clear (item_ )
435439 )
436- # Todo(dpj): Parallel clear
437- for strategy , indexes in routed_indexes .items ():
438- if not indexes :
439- continue
440- strategy_keys = [keys [i ] for i in indexes ]
441- strategy .clear (strategy_keys )
440+
441+ # Cleanup work unit: Does not return values, just executes deletion.
442+ def clear_task (strategy , indexes ):
443+ strategy .clear ([keys [i ] for i in indexes ])
444+
445+ # Parallelize deletion across strategies.
446+ # The 'with' context in _dispatch_tasks will wait for all deletions to finish.
447+ self ._dispatch_tasks (routed_indexes , clear_task )
442448
443449 def _route_to_strategies (
444450 self ,
@@ -466,3 +472,29 @@ def _route_to_strategies(
466472 else :
467473 raise ValueError (f"No strategy supports item: { item } " )
468474 return routed_indexes
475+
476+ def _dispatch_tasks (self , routed_tasks : dict [StorageStrategy , list [int ]], task_function : Callable ) -> list [Any ]:
477+ """
478+ Orchestrates task execution across multiple strategies.
479+
480+ Logic:
481+ 1. If no tasks are present, return immediately.
482+ 2. If only one strategy is active, execute synchronously in the main thread (Fast Path)
483+ to avoid the overhead of thread creation and context switching.
484+ 3. If multiple strategies are active, execute in parallel using a ThreadPoolExecutor.
485+ """
486+ active_tasks = [(strategy , indexes ) for strategy , indexes in routed_tasks .items () if indexes ]
487+
488+ if not active_tasks :
489+ return []
490+
491+ # Fast Path: Execute directly if only one backend is targeted.
492+ # This significantly reduces latency for homogeneous batches (e.g., NPU-only).
493+ if len (active_tasks ) == 1 :
494+ return [task_function (* active_tasks [0 ])]
495+
496+ # Parallel Path: Maximize throughput by overlapping NPU and CPU operations.
497+ # The 'with' statement ensures immediate thread teardown and resource release.
498+ with ThreadPoolExecutor (max_workers = len (active_tasks )) as executor :
499+ futures = [executor .submit (task_function , strategy , indexes ) for strategy , indexes in active_tasks ]
500+ return [f .result () for f in futures ]
0 commit comments