3131
3232import torch
3333
34+ from tensorrt_llm ._utils import prefer_pinned
3435from tensorrt_llm .logger import logger as tllm_logger
3536
# Default wall-clock budget, in seconds, used for the watchdog's ``timeout_s``.
DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S = 5.0
# Default ``poll_interval_s`` in seconds; also the default bound on a single
# device-to-host completion-flag copy (``device_copy_timeout_s``).
DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S = 0.1
# Placeholder for "no flag value has been read yet". Completion flags are read
# as int32, so the most negative int64 cannot collide with a real flag.
UNKNOWN_COMPLETION_FLAG = -(1 << 63)
40+
3641
3742class CompletionFlagReader (Protocol ):
3843 """Reads one phase's rank-local completion flag row."""
@@ -51,6 +56,10 @@ def mark_failed(self, rank: int) -> bool:
5156 """Mark ``rank`` failed and return whether state changed."""
5257
5358
class CompletionFlagReadTimeout(TimeoutError):
    """Signal that a host-side completion-flag read did not finish in time.

    Subclasses ``TimeoutError`` so callers can handle it separately from
    other reader failures.
    """
61+
62+
5463@dataclass (frozen = True )
5564class AlltoAllWatchdogTimeout :
5665 """Details emitted when an AlltoAll phase times out."""
@@ -61,6 +70,7 @@ class AlltoAllWatchdogTimeout:
6170 missing_ranks : tuple [int , ...]
6271 marked_failed_ranks : tuple [int , ...]
6372 elapsed_s : float
73+ poll_timed_out : bool = False
6474
6575
6676@dataclass (frozen = True )
@@ -81,6 +91,7 @@ def __init__(
8191 ep_size : int ,
8292 dispatch_completion_flags_offset : int ,
8393 combine_completion_flags_offset : int ,
94+ device_copy_timeout_s : float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S ,
8495 ) -> None :
8596 if workspace .dim () != 2 :
8697 raise ValueError ("workspace must be a 2D tensor [ep_size, size_per_rank]" )
@@ -97,11 +108,50 @@ def __init__(
97108 "dispatch" : int (dispatch_completion_flags_offset ),
98109 "combine" : int (combine_completion_flags_offset ),
99110 }
111+ self ._device_copy_timeout_s = float (device_copy_timeout_s )
112+ self ._copy_stream : torch .cuda .Stream | None = None
113+ self ._retired_copies : list [tuple [torch .Tensor , torch .cuda .Event ]] = []
114+ if workspace .device .type == "cuda" :
115+ self ._copy_stream = torch .cuda .Stream (device = workspace .device )
116+
def _prune_retired_copies(self) -> None:
    """Drop retired host copies whose CUDA event now reports completion.

    Each entry of ``self._retired_copies`` is a ``(host buffer, event)`` pair;
    only pairs whose ``event.query()`` is still false are kept.
    """
    still_in_flight = []
    for pinned_buffer, copy_event in self._retired_copies:
        if copy_event.query():
            # Event has completed, so this pair no longer needs to be tracked.
            continue
        still_in_flight.append((pinned_buffer, copy_event))
    self._retired_copies = still_in_flight
121+
def _read_cuda_flags(self, flags: torch.Tensor) -> tuple[int, ...]:
    """Copy CUDA-resident completion flags to the host with a bounded wait.

    The copy is issued asynchronously on ``self._copy_stream`` into a freshly
    allocated host buffer of ``self._ep_size`` int32 values, and an event
    recorded on that stream is polled until it completes or
    ``self._device_copy_timeout_s`` seconds have elapsed.

    Args:
        flags: int32 view of this rank's completion-flag row on a CUDA device.

    Returns:
        The flag values as a tuple of Python ints.

    Raises:
        CompletionFlagReadTimeout: the copy event did not complete before the
            deadline. The host buffer and its event are appended to
            ``self._retired_copies`` first (presumably so the buffer outlives
            the still-pending copy -- confirm against allocator semantics).
    """
    assert self._copy_stream is not None
    # Forget earlier timed-out copies that have finished in the meantime.
    self._prune_retired_copies()

    staging = torch.empty(
        (self._ep_size,),
        dtype=torch.int32,
        device="cpu",
        pin_memory=prefer_pinned(),
    )
    copy_done = torch.cuda.Event(blocking=False)
    with torch.cuda.device(flags.device), torch.cuda.stream(self._copy_stream):
        staging.copy_(flags.detach(), non_blocking=True)
        copy_done.record(self._copy_stream)

    give_up_at = time.monotonic() + self._device_copy_timeout_s
    while True:
        if copy_done.query():
            break
        time_left = give_up_at - time.monotonic()
        if time_left <= 0:
            self._retired_copies.append((staging, copy_done))
            raise CompletionFlagReadTimeout(
                "timed out copying AlltoAll completion flags to host"
            )
        # Poll with short sleeps (at most 1 ms) instead of blocking on the event.
        time.sleep(min(time_left, 0.001))

    return tuple(map(int, staging.tolist()))
100148
def read_completion_flags(self, phase: str) -> tuple[int, ...]:
    """Return this rank's completion flags for ``phase`` as Python ints.

    The flags are the ``self._ep_size`` int32 values that start at byte
    offset ``self._offsets[phase]`` in row ``self._ep_rank`` of the workspace.

    Args:
        phase: key into ``self._offsets`` (``"dispatch"`` or ``"combine"``).

    Returns:
        One int per rank, in rank order.

    Raises:
        KeyError: ``phase`` has no registered offset.
        CompletionFlagReadTimeout: may propagate from the CUDA read path.
    """
    start = self._offsets[phase]
    stop = start + 4 * self._ep_size  # 4 bytes per int32 flag
    row = self._workspace[self._ep_rank, start:stop].view(torch.int32)
    device_type = row.device.type
    if device_type == "cuda":
        # CUDA rows go through the bounded asynchronous copy path.
        return self._read_cuda_flags(row)
    host_row = row if device_type == "cpu" else row.detach().cpu()
    return tuple(map(int, host_row.tolist()))
@@ -123,8 +173,8 @@ def __init__(
123173 ep_size : int ,
124174 ep_rank : int ,
125175 completion_reader : CompletionFlagReader ,
126- timeout_s : float ,
127- poll_interval_s : float = 0.05 ,
176+ timeout_s : float = DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S ,
177+ poll_interval_s : float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S ,
128178 health : Optional [EPGroupHealthLike ] = None ,
129179 on_timeout : Optional [Callable [[AlltoAllWatchdogTimeout ], None ]] = None ,
130180 ) -> None :
@@ -160,8 +210,8 @@ def from_workspace(
160210 metainfo_index : Mapping [str , int ],
161211 ep_rank : int ,
162212 ep_size : int ,
163- timeout_s : float ,
164- poll_interval_s : float = 0.05 ,
213+ timeout_s : float = DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S ,
214+ poll_interval_s : float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S ,
165215 health : Optional [EPGroupHealthLike ] = None ,
166216 on_timeout : Optional [Callable [[AlltoAllWatchdogTimeout ], None ]] = None ,
167217 ) -> "AlltoAllWatchdog" :
@@ -178,6 +228,7 @@ def from_workspace(
178228 ep_size = ep_size ,
179229 dispatch_completion_flags_offset = dispatch_offset ,
180230 combine_completion_flags_offset = combine_offset ,
231+ device_copy_timeout_s = poll_interval_s ,
181232 )
182233 return cls (
183234 ep_size = ep_size ,
@@ -288,11 +339,18 @@ def _missing_ranks(
288339 if observed_flags [rank ] != watch .expected_flag
289340 )
290341
291- def _handle_timeout (self , watch : _CollectiveWatch , observed_flags : tuple [int , ...]) -> None :
342+ def _handle_timeout (
343+ self ,
344+ watch : _CollectiveWatch ,
345+ observed_flags : tuple [int , ...],
346+ * ,
347+ poll_timed_out : bool = False ,
348+ ) -> None :
292349 elapsed_s = time .monotonic () - watch .start_s
293350 missing_ranks = self ._missing_ranks (watch , observed_flags )
294351 marked_failed : list [int ] = []
295- if self ._health is not None :
352+ has_known_flags = UNKNOWN_COMPLETION_FLAG not in observed_flags
353+ if self ._health is not None and (has_known_flags or not poll_timed_out ):
296354 for rank in missing_ranks :
297355 if rank == self ._ep_rank :
298356 continue
@@ -306,20 +364,37 @@ def _handle_timeout(self, watch: _CollectiveWatch, observed_flags: tuple[int, ..
306364 missing_ranks = missing_ranks ,
307365 marked_failed_ranks = tuple (marked_failed ),
308366 elapsed_s = elapsed_s ,
367+ poll_timed_out = poll_timed_out ,
309368 )
310- tllm_logger .warning (
311- "AlltoAll watchdog timeout on rank %d during %s: expected flag %d, "
312- "missing ranks %s, observed flags %s" ,
313- self ._ep_rank ,
314- watch .phase ,
315- watch .expected_flag ,
316- list (missing_ranks ),
317- list (observed_flags ),
318- )
369+ if poll_timed_out :
370+ tllm_logger .error (
371+ "AlltoAll watchdog could not read completion flags on rank %d "
372+ "during %s before timeout %.3fs; expected flag %d, active "
373+ "ranks %s, observed flags %s, marked ranks %s" ,
374+ self ._ep_rank ,
375+ watch .phase ,
376+ elapsed_s ,
377+ watch .expected_flag ,
378+ list (self ._active_ranks (watch .active_mask )),
379+ list (observed_flags ),
380+ list (marked_failed ),
381+ )
382+ else :
383+ tllm_logger .warning (
384+ "AlltoAll watchdog timeout on rank %d during %s: expected flag %d, "
385+ "missing ranks %s, observed flags %s" ,
386+ self ._ep_rank ,
387+ watch .phase ,
388+ watch .expected_flag ,
389+ list (missing_ranks ),
390+ list (observed_flags ),
391+ )
319392 if self ._on_timeout is not None :
320393 self ._on_timeout (event )
321394
322395 def _run (self ) -> None :
396+ last_observed_flags = tuple (UNKNOWN_COMPLETION_FLAG for _ in range (self ._ep_size ))
397+ poll_timed_out = False
323398 while True :
324399 with self ._cv :
325400 while not self ._queue and not self ._stopping :
@@ -337,6 +412,11 @@ def _run(self) -> None:
337412 f"completion reader returned { len (observed_flags )} flags; "
338413 f"expected ep_size={ self ._ep_size } "
339414 )
415+ last_observed_flags = observed_flags
416+ poll_timed_out = False
417+ except CompletionFlagReadTimeout :
418+ observed_flags = last_observed_flags
419+ poll_timed_out = True
340420 except BaseException as exc : # noqa: BLE001 - keep watchdog failures visible.
341421 with self ._cv :
342422 self ._last_error = exc
@@ -350,16 +430,20 @@ def _run(self) -> None:
350430 if self ._queue and self ._queue [0 ] is watch :
351431 self ._queue .popleft ()
352432 self ._cv .notify_all ()
433+ last_observed_flags = tuple (UNKNOWN_COMPLETION_FLAG for _ in range (self ._ep_size ))
434+ poll_timed_out = False
353435 continue
354436
355437 if time .monotonic () - watch .start_s >= self ._timeout_s :
356- self ._handle_timeout (watch , observed_flags )
438+ self ._handle_timeout (watch , observed_flags , poll_timed_out = poll_timed_out )
357439 with self ._cv :
358440 # The GPU stream is no longer trustworthy once a collective
359441 # times out. Drop queued follow-on phases so they do not
360442 # produce duplicate or misleading reports.
361443 self ._queue .clear ()
362444 self ._cv .notify_all ()
445+ last_observed_flags = tuple (UNKNOWN_COMPLETION_FLAG for _ in range (self ._ep_size ))
446+ poll_timed_out = False
363447 continue
364448
365449 with self ._cv :
0 commit comments