Skip to content

Commit 4db6344

Browse files
committed
Added multi-threads optimization to 'put/get/clear' of 'YuanrongStorageClient'
Signed-off-by: dpj135 <958208521@qq.com>
1 parent acd7686 commit 4db6344

1 file changed

Lines changed: 61 additions & 29 deletions

File tree

transfer_queue/storage/clients/yuanrong_client.py

Lines changed: 61 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -361,19 +361,21 @@ def put(self, keys: list[str], values: list[Any]) -> list[str]:
361361
raise ValueError("Number of keys must match number of values")
362362

363363
routed_indexes = self._route_to_strategies(values, lambda strategy_, item_: strategy_.supports_put(item_))
364-
custom_metas: list[str] = [""] * len(keys)
365-
366-
# Todo(dpj): Parallel put
367-
for strategy, indexes in routed_indexes.items():
368-
if not indexes:
369-
continue
370-
strategy_keys = [keys[i] for i in indexes]
371-
strategy_values = [values[i] for i in indexes]
372-
strategy.put(strategy_keys, strategy_values)
373-
for i in indexes:
374-
custom_metas[i] = strategy.custom_meta()
375364

376-
return custom_metas
365+
# Define the work unit: Slicing the input list and calling the backend strategy.
366+
# The closure captures local 'keys' and 'values' for zero-overhead parameter passing.
367+
def put_task(strategy, indexes):
368+
strategy.put([keys[i] for i in indexes], [values[i] for i in indexes])
369+
return strategy.custom_meta(), indexes
370+
371+
# Call the orchestrator to run the tasks (Parallel or Sequential).
372+
# We then iterate through the results to map strategy-specific metadata back
373+
# to the original global index order.
374+
custom_meta: list[str] = [""] * len(keys)
375+
for meta_str, indexes in self._dispatch_tasks(routed_indexes, put_task):
376+
for i in indexes:
377+
custom_meta[i] = meta_str
378+
return custom_meta
377379

378380
def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> list[Any]:
379381
"""Retrieves multiple values from remote storage with expected metadata.
@@ -402,18 +404,20 @@ def get(self, keys: list[str], shapes=None, dtypes=None, custom_meta=None) -> li
402404

403405
routed_indexes = self._route_to_strategies(custom_meta, lambda strategy_, item_: strategy_.supports_get(item_))
404406

405-
# Todo(dpj): Parallel get
406-
results = [None] * len(keys)
407-
for strategy, indexes in routed_indexes.items():
408-
if not indexes:
409-
continue
410-
strategy_keys = [keys[i] for i in indexes]
411-
strategy_shapes = [shapes[i] for i in indexes]
412-
strategy_dtypes = [dtypes[i] for i in indexes]
413-
strategy_results = strategy.get(strategy_keys, shapes=strategy_shapes, dtypes=strategy_dtypes)
414-
for j, i in enumerate(indexes):
415-
results[i] = strategy_results[j]
407+
# Work unit for 'get': handles slicing of keys, shapes, and dtypes simultaneously.
408+
def get_task(strategy, indexes):
409+
res = strategy.get(
410+
[keys[i] for i in indexes], shapes=[shapes[i] for i in indexes], dtypes=[dtypes[i] for i in indexes]
411+
)
412+
return res, indexes
416413

414+
# Dispatch the 'get' requests. Multiple backends (e.g. NPU and SSD) will fetch
415+
# in parallel if needed. Results are merged back into the 'results' list
416+
# according to their original positions.
417+
results = [None] * len(keys)
418+
for strategy_res, indexes in self._dispatch_tasks(routed_indexes, get_task):
419+
for value, original_index in zip(strategy_res, indexes, strict=True):
420+
results[original_index] = value
417421
return results
418422

419423
def clear(self, keys: list[str], custom_meta=None):
@@ -433,12 +437,14 @@ def clear(self, keys: list[str], custom_meta=None):
433437
routed_indexes = self._route_to_strategies(
434438
custom_meta, lambda strategy_, item_: strategy_.supports_clear(item_)
435439
)
436-
# Todo(dpj): Parallel clear
437-
for strategy, indexes in routed_indexes.items():
438-
if not indexes:
439-
continue
440-
strategy_keys = [keys[i] for i in indexes]
441-
strategy.clear(strategy_keys)
440+
441+
# Cleanup work unit: Does not return values, just executes deletion.
442+
def clear_task(strategy, indexes):
443+
strategy.clear([keys[i] for i in indexes])
444+
445+
# Parallelize deletion across strategies.
446+
# The 'with' context in _dispatch_tasks will wait for all deletions to finish.
447+
self._dispatch_tasks(routed_indexes, clear_task)
442448

443449
def _route_to_strategies(
444450
self,
@@ -466,3 +472,29 @@ def _route_to_strategies(
466472
else:
467473
raise ValueError(f"No strategy supports item: {item}")
468474
return routed_indexes
475+
476+
def _dispatch_tasks(self, routed_tasks: dict[StorageStrategy, list[int]], task_function: Callable) -> list[Any]:
477+
"""
478+
Orchestrates task execution across multiple strategies.
479+
480+
Logic:
481+
1. If no tasks are present, return immediately.
482+
2. If only one strategy is active, execute synchronously in the main thread (Fast Path)
483+
to avoid the overhead of thread creation and context switching.
484+
3. If multiple strategies are active, execute in parallel using a ThreadPoolExecutor.
485+
"""
486+
active_tasks = [(strategy, indexes) for strategy, indexes in routed_tasks.items() if indexes]
487+
488+
if not active_tasks:
489+
return []
490+
491+
# Fast Path: Execute directly if only one backend is targeted.
492+
# This significantly reduces latency for homogeneous batches (e.g., NPU-only).
493+
if len(active_tasks) == 1:
494+
return [task_function(*active_tasks[0])]
495+
496+
# Parallel Path: Maximize throughput by overlapping NPU and CPU operations.
497+
# The 'with' statement ensures immediate thread teardown and resource release.
498+
with ThreadPoolExecutor(max_workers=len(active_tasks)) as executor:
499+
futures = [executor.submit(task_function, strategy, indexes) for strategy, indexes in active_tasks]
500+
return [f.result() for f in futures]

0 commit comments

Comments
 (0)