@@ -269,6 +269,56 @@ def __exit__(self, exc_type, exc, tb):
269269 return self ._group .__exit__ (exc_type , exc , tb )
270270
271271
272+ class _WorkerWarmupState :
273+ """
274+ Shared once-per-physical-device warmup coordination.
275+
276+ The first worker that reaches this state performs the warmup. Other workers
277+ wait on the completion event without holding the pool registry lock.
278+ """
279+
280+ def __init__ (self , warmup_fn : Callable [[torch .device ], None ]):
281+ self ._warmup_fn = warmup_fn
282+ self ._claim_lock = threading .Lock ()
283+ self ._started = False
284+ self ._done = threading .Event ()
285+ self ._error : Optional [BaseException ] = None
286+
287+ def run (self , * , device : torch .device , rwlock : _RWLock ) -> None :
288+ if self ._done .is_set ():
289+ self ._raise_if_failed ()
290+ return
291+
292+ should_run = False
293+ with self ._claim_lock :
294+ if self ._done .is_set ():
295+ pass
296+ elif not self ._started :
297+ self ._started = True
298+ should_run = True
299+
300+ if should_run :
301+ try :
302+ with ctx (rwlock .reader (), _device_ctx (device )):
303+ self ._warmup_fn (device )
304+ except BaseException as exc :
305+ with self ._claim_lock :
306+ self ._error = exc
307+ raise
308+ finally :
309+ self ._done .set ()
310+ else :
311+ self ._done .wait ()
312+
313+ self ._raise_if_failed ()
314+
315+ def _raise_if_failed (self ) -> None :
316+ with self ._claim_lock :
317+ error = self ._error
318+ if error is not None :
319+ raise error
320+
321+
272322# --------------------------- Worker Thread ---------------------------
273323# Each worker is bound to a specific device and runs a single thread. Tasks are
274324# executed under the device’s read lock; GC acquires the writer lock to keep
@@ -292,15 +342,15 @@ def __init__(
292342 name : Optional [str ] = None ,
293343 inference_mode : bool = False ,
294344 cpu_core : Optional [int ] = None ,
295- warmup_fn : Optional [Callable [[ torch . device ], None ] ] = None ,
345+ warmup_state : Optional [_WorkerWarmupState ] = None ,
296346 * ,
297347 key_override : Optional [str ] = None ,
298348 ):
299349 self .device = device
300350 self .rwlock = rwlock
301351 self ._on_task_finished = on_task_finished
302352 self ._on_worker_exit = on_worker_exit
303- self ._warmup_fn = warmup_fn
353+ self ._warmup_state = warmup_state
304354
305355 if key_override is not None :
306356 self .key = key_override
@@ -375,14 +425,11 @@ def _apply_cpu_affinity(self) -> None:
375425 self ._affinity_applied = True
376426
377427 def _run_warmup (self ) -> None :
378- warmup_fn = self ._warmup_fn
379- if warmup_fn is None :
428+ warmup_state = self ._warmup_state
429+ if warmup_state is None :
380430 return
381- try :
382- with ctx (self .rwlock .reader (), _device_ctx (self .device )):
383- warmup_fn (self .device )
384- finally :
385- self ._warmup_fn = None
431+ warmup_state .run (device = self .device , rwlock = self .rwlock )
432+ self ._warmup_state = None
386433
387434 def _run (self ):
388435 """
@@ -636,7 +683,7 @@ def __init__(
636683 {str (k ).lower (): fn for k , fn in warmups .items ()} if warmups else None
637684 )
638685 self ._warmup_lock = threading .Lock ()
639- self ._warmup_ran_keys : Set [str ] = set ()
686+ self ._warmup_states : Dict [str , _WorkerWarmupState ] = {}
640687
641688 workers_cfg = workers or {}
642689 base_workers : Dict [str , int ] = {}
@@ -890,7 +937,11 @@ def _priority(dev_type: str) -> int:
890937
891938 return plan
892939
893- def _resolve_worker_warmup (self , dev : torch .device , key : str ) -> Optional [Callable [[torch .device ], None ]]:
940+ def _resolve_worker_warmup (
941+ self ,
942+ dev : torch .device ,
943+ key : str ,
944+ ) -> Optional [_WorkerWarmupState ]:
894945 mapping = self ._worker_warmups
895946 if not mapping :
896947 return None
@@ -904,13 +955,14 @@ def _resolve_worker_warmup(self, dev: torch.device, key: str) -> Optional[Callab
904955 if warmup is None :
905956 return None
906957
907- # Map virtual workers back to their parent key so warmup runs once per physical device .
958+ # Virtual workers share the same physical-device warmup state as their parent .
908959 physical_key = self ._virtual_to_parent .get (key , key )
909960 with self ._warmup_lock :
910- if physical_key in self ._warmup_ran_keys :
911- return None
912- self ._warmup_ran_keys .add (physical_key )
913- return warmup
961+ state = self ._warmup_states .get (physical_key )
962+ if state is None :
963+ state = _WorkerWarmupState (warmup )
964+ self ._warmup_states [physical_key ] = state
965+ return state
914966
915967 def _spawn_worker (
916968 self ,
@@ -922,7 +974,7 @@ def _spawn_worker(
922974 """
923975 Create and start a worker bound to the provided device.
924976 """
925- warmup_fn = self ._resolve_worker_warmup (dev , key )
977+ warmup_state = self ._resolve_worker_warmup (dev , key )
926978 w = _DeviceWorker (
927979 device = dev ,
928980 rwlock = self ._locks [key ],
@@ -931,7 +983,7 @@ def _spawn_worker(
931983 name = name ,
932984 inference_mode = self ._inference_mode ,
933985 cpu_core = cpu_core ,
934- warmup_fn = warmup_fn ,
986+ warmup_state = warmup_state ,
935987 key_override = key ,
936988 )
937989 return w
0 commit comments