Skip to content

Commit 159d2f8

Browse files
committed
[None][fix] Address AlltoAll watchdog review findings
Signed-off-by: Chien-Chun Hung <2679986+chienchunhung@users.noreply.github.com>
1 parent 4608509 commit 159d2f8

8 files changed

Lines changed: 345 additions & 23 deletions

File tree

tensorrt_llm/_torch/alltoall_watchdog.py

Lines changed: 100 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,13 @@
3131

3232
import torch
3333

34+
from tensorrt_llm._utils import prefer_pinned
3435
from tensorrt_llm.logger import logger as tllm_logger
3536

37+
DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S = 5.0
38+
DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S = 0.1
39+
UNKNOWN_COMPLETION_FLAG = -(2**63)
40+
3641

3742
class CompletionFlagReader(Protocol):
3843
"""Reads one phase's rank-local completion flag row."""
@@ -51,6 +56,10 @@ def mark_failed(self, rank: int) -> bool:
5156
"""Mark ``rank`` failed and return whether state changed."""
5257

5358

59+
class CompletionFlagReadTimeout(TimeoutError):
60+
"""Raised when the host watchdog cannot read completion flags in time."""
61+
62+
5463
@dataclass(frozen=True)
5564
class AlltoAllWatchdogTimeout:
5665
"""Details emitted when an AlltoAll phase times out."""
@@ -61,6 +70,7 @@ class AlltoAllWatchdogTimeout:
6170
missing_ranks: tuple[int, ...]
6271
marked_failed_ranks: tuple[int, ...]
6372
elapsed_s: float
73+
poll_timed_out: bool = False
6474

6575

6676
@dataclass(frozen=True)
@@ -81,6 +91,7 @@ def __init__(
8191
ep_size: int,
8292
dispatch_completion_flags_offset: int,
8393
combine_completion_flags_offset: int,
94+
device_copy_timeout_s: float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S,
8495
) -> None:
8596
if workspace.dim() != 2:
8697
raise ValueError("workspace must be a 2D tensor [ep_size, size_per_rank]")
@@ -97,11 +108,50 @@ def __init__(
97108
"dispatch": int(dispatch_completion_flags_offset),
98109
"combine": int(combine_completion_flags_offset),
99110
}
111+
self._device_copy_timeout_s = float(device_copy_timeout_s)
112+
self._copy_stream: torch.cuda.Stream | None = None
113+
self._retired_copies: list[tuple[torch.Tensor, torch.cuda.Event]] = []
114+
if workspace.device.type == "cuda":
115+
self._copy_stream = torch.cuda.Stream(device=workspace.device)
116+
117+
def _prune_retired_copies(self) -> None:
118+
self._retired_copies = [
119+
(host_flags, event) for host_flags, event in self._retired_copies if not event.query()
120+
]
121+
122+
def _read_cuda_flags(self, flags: torch.Tensor) -> tuple[int, ...]:
123+
assert self._copy_stream is not None
124+
self._prune_retired_copies()
125+
126+
host_flags = torch.empty(
127+
(self._ep_size,),
128+
dtype=torch.int32,
129+
device="cpu",
130+
pin_memory=prefer_pinned(),
131+
)
132+
event = torch.cuda.Event(blocking=False)
133+
with torch.cuda.device(flags.device), torch.cuda.stream(self._copy_stream):
134+
host_flags.copy_(flags.detach(), non_blocking=True)
135+
event.record(self._copy_stream)
136+
137+
deadline_s = time.monotonic() + self._device_copy_timeout_s
138+
while not event.query():
139+
remaining_s = deadline_s - time.monotonic()
140+
if remaining_s <= 0:
141+
self._retired_copies.append((host_flags, event))
142+
raise CompletionFlagReadTimeout(
143+
"timed out copying AlltoAll completion flags to host"
144+
)
145+
time.sleep(min(remaining_s, 0.001))
146+
147+
return tuple(int(v) for v in host_flags.tolist())
100148

101149
def read_completion_flags(self, phase: str) -> tuple[int, ...]:
102150
offset = self._offsets[phase]
103151
end = offset + self._ep_size * 4
104152
flags = self._workspace[self._ep_rank, offset:end].view(torch.int32)
153+
if flags.device.type == "cuda":
154+
return self._read_cuda_flags(flags)
105155
if flags.device.type != "cpu":
106156
flags = flags.detach().cpu()
107157
return tuple(int(v) for v in flags.tolist())
@@ -123,8 +173,8 @@ def __init__(
123173
ep_size: int,
124174
ep_rank: int,
125175
completion_reader: CompletionFlagReader,
126-
timeout_s: float,
127-
poll_interval_s: float = 0.05,
176+
timeout_s: float = DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S,
177+
poll_interval_s: float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S,
128178
health: Optional[EPGroupHealthLike] = None,
129179
on_timeout: Optional[Callable[[AlltoAllWatchdogTimeout], None]] = None,
130180
) -> None:
@@ -160,8 +210,8 @@ def from_workspace(
160210
metainfo_index: Mapping[str, int],
161211
ep_rank: int,
162212
ep_size: int,
163-
timeout_s: float,
164-
poll_interval_s: float = 0.05,
213+
timeout_s: float = DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S,
214+
poll_interval_s: float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S,
165215
health: Optional[EPGroupHealthLike] = None,
166216
on_timeout: Optional[Callable[[AlltoAllWatchdogTimeout], None]] = None,
167217
) -> "AlltoAllWatchdog":
@@ -178,6 +228,7 @@ def from_workspace(
178228
ep_size=ep_size,
179229
dispatch_completion_flags_offset=dispatch_offset,
180230
combine_completion_flags_offset=combine_offset,
231+
device_copy_timeout_s=poll_interval_s,
181232
)
182233
return cls(
183234
ep_size=ep_size,
@@ -288,11 +339,18 @@ def _missing_ranks(
288339
if observed_flags[rank] != watch.expected_flag
289340
)
290341

291-
def _handle_timeout(self, watch: _CollectiveWatch, observed_flags: tuple[int, ...]) -> None:
342+
def _handle_timeout(
343+
self,
344+
watch: _CollectiveWatch,
345+
observed_flags: tuple[int, ...],
346+
*,
347+
poll_timed_out: bool = False,
348+
) -> None:
292349
elapsed_s = time.monotonic() - watch.start_s
293350
missing_ranks = self._missing_ranks(watch, observed_flags)
294351
marked_failed: list[int] = []
295-
if self._health is not None:
352+
has_known_flags = UNKNOWN_COMPLETION_FLAG not in observed_flags
353+
if self._health is not None and (has_known_flags or not poll_timed_out):
296354
for rank in missing_ranks:
297355
if rank == self._ep_rank:
298356
continue
@@ -306,20 +364,37 @@ def _handle_timeout(self, watch: _CollectiveWatch, observed_flags: tuple[int, ..
306364
missing_ranks=missing_ranks,
307365
marked_failed_ranks=tuple(marked_failed),
308366
elapsed_s=elapsed_s,
367+
poll_timed_out=poll_timed_out,
309368
)
310-
tllm_logger.warning(
311-
"AlltoAll watchdog timeout on rank %d during %s: expected flag %d, "
312-
"missing ranks %s, observed flags %s",
313-
self._ep_rank,
314-
watch.phase,
315-
watch.expected_flag,
316-
list(missing_ranks),
317-
list(observed_flags),
318-
)
369+
if poll_timed_out:
370+
tllm_logger.error(
371+
"AlltoAll watchdog could not read completion flags on rank %d "
372+
"during %s before timeout %.3fs; expected flag %d, active "
373+
"ranks %s, observed flags %s, marked ranks %s",
374+
self._ep_rank,
375+
watch.phase,
376+
elapsed_s,
377+
watch.expected_flag,
378+
list(self._active_ranks(watch.active_mask)),
379+
list(observed_flags),
380+
list(marked_failed),
381+
)
382+
else:
383+
tllm_logger.warning(
384+
"AlltoAll watchdog timeout on rank %d during %s: expected flag %d, "
385+
"missing ranks %s, observed flags %s",
386+
self._ep_rank,
387+
watch.phase,
388+
watch.expected_flag,
389+
list(missing_ranks),
390+
list(observed_flags),
391+
)
319392
if self._on_timeout is not None:
320393
self._on_timeout(event)
321394

322395
def _run(self) -> None:
396+
last_observed_flags = tuple(UNKNOWN_COMPLETION_FLAG for _ in range(self._ep_size))
397+
poll_timed_out = False
323398
while True:
324399
with self._cv:
325400
while not self._queue and not self._stopping:
@@ -337,6 +412,11 @@ def _run(self) -> None:
337412
f"completion reader returned {len(observed_flags)} flags; "
338413
f"expected ep_size={self._ep_size}"
339414
)
415+
last_observed_flags = observed_flags
416+
poll_timed_out = False
417+
except CompletionFlagReadTimeout:
418+
observed_flags = last_observed_flags
419+
poll_timed_out = True
340420
except BaseException as exc: # noqa: BLE001 - keep watchdog failures visible.
341421
with self._cv:
342422
self._last_error = exc
@@ -350,16 +430,20 @@ def _run(self) -> None:
350430
if self._queue and self._queue[0] is watch:
351431
self._queue.popleft()
352432
self._cv.notify_all()
433+
last_observed_flags = tuple(UNKNOWN_COMPLETION_FLAG for _ in range(self._ep_size))
434+
poll_timed_out = False
353435
continue
354436

355437
if time.monotonic() - watch.start_s >= self._timeout_s:
356-
self._handle_timeout(watch, observed_flags)
438+
self._handle_timeout(watch, observed_flags, poll_timed_out=poll_timed_out)
357439
with self._cv:
358440
# The GPU stream is no longer trustworthy once a collective
359441
# times out. Drop queued follow-on phases so they do not
360442
# produce duplicate or misleading reports.
361443
self._queue.clear()
362444
self._cv.notify_all()
445+
last_observed_flags = tuple(UNKNOWN_COMPLETION_FLAG for _ in range(self._ep_size))
446+
poll_timed_out = False
363447
continue
364448

365449
with self._cv:

tensorrt_llm/_torch/distributed/moe_alltoall.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,17 @@
88
# ruff: noqa: E501
99

1010
import os
11+
import sys
1112
from dataclasses import dataclass
1213
from typing import Callable, Dict, Optional
1314

1415
import torch
1516

1617
from tensorrt_llm._mnnvl_utils import MnnvlMemory
17-
from tensorrt_llm._torch.alltoall_watchdog import (AlltoAllWatchdog,
18-
AlltoAllWatchdogTimeout)
18+
from tensorrt_llm._torch.alltoall_watchdog import (
19+
DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S,
20+
DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S, AlltoAllWatchdog,
21+
AlltoAllWatchdogTimeout)
1922
from tensorrt_llm.bindings import internal as _tllm_internal
2023
from tensorrt_llm.logger import logger as tllm_logger
2124
from tensorrt_llm.mapping import Mapping
@@ -130,7 +133,8 @@ def __init__(
130133
num_experts: Optional[int] = None,
131134
ep_group_health=None,
132135
alltoall_watchdog_timeout_s: Optional[float] = None,
133-
alltoall_watchdog_poll_interval_s: float = 0.05,
136+
alltoall_watchdog_poll_interval_s:
137+
float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S,
134138
alltoall_watchdog_on_timeout: Optional[Callable[
135139
[AlltoAllWatchdogTimeout], None]] = None,
136140
):
@@ -228,8 +232,12 @@ def __init__(
228232
# Internal state
229233
self._state: _A2AState = _A2AState()
230234
self.ep_group_health = ep_group_health
235+
self._destroyed = False
231236
self._watchdog_flag_generation = 0
232237
self._alltoall_watchdog: AlltoAllWatchdog | None = None
238+
if (alltoall_watchdog_timeout_s is None
239+
and self.ep_group_health is not None):
240+
alltoall_watchdog_timeout_s = DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S
233241
if alltoall_watchdog_timeout_s is not None:
234242
self._watchdog_flag_generation = self._read_current_flag_val()
235243
self._alltoall_watchdog = AlltoAllWatchdog.from_workspace(
@@ -244,6 +252,20 @@ def __init__(
244252
on_timeout=alltoall_watchdog_on_timeout,
245253
)
246254

255+
def destroy(self) -> None:
256+
"""Stop background watchdog resources owned by this wrapper."""
257+
if getattr(self, "_destroyed", False):
258+
return
259+
self._destroyed = True
260+
watchdog = getattr(self, "_alltoall_watchdog", None)
261+
if watchdog is not None:
262+
watchdog.stop(timeout_s=1.0)
263+
self._alltoall_watchdog = None
264+
265+
def __del__(self) -> None:
266+
if not sys.is_finalizing():
267+
self.destroy()
268+
247269
def _read_current_flag_val(self) -> int:
248270
flag_val_offset = self.metainfo[
249271
self._METAINFO_INDEX["FLAG_VAL_OFFSET_INDEX"]].item()

tensorrt_llm/_torch/modules/fused_moe/communication/communication_factory.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from tensorrt_llm._torch.model_config import ModelConfig
2929
from tensorrt_llm.logger import logger
3030

31+
from ..wide_ep_ft import get_wide_ep_ft_options
3132
from .allgather_reducescatter import AllGatherReduceScatter
3233
from .base import Communication
3334
from .deep_ep import DeepEP
@@ -133,6 +134,9 @@ def create_strategy(
133134

134135
try:
135136
enable_eplb = model_config.moe_load_balancer is not None
137+
ep_group_health, watchdog_timeout_s, watchdog_poll_interval_s = get_wide_ep_ft_options(
138+
model_config
139+
)
136140
strategy = NVLinkOneSided(
137141
mapping,
138142
num_slots,
@@ -143,6 +147,9 @@ def create_strategy(
143147
dtype=act_dtype,
144148
num_experts=num_experts if enable_eplb else None,
145149
use_low_precision_combine=use_low_precision_combine,
150+
ep_group_health=ep_group_health,
151+
alltoall_watchdog_timeout_s=watchdog_timeout_s,
152+
alltoall_watchdog_poll_interval_s=watchdog_poll_interval_s,
146153
)
147154
logger.info("Selected communication strategy: NVLinkOneSided")
148155
return strategy
@@ -285,6 +292,9 @@ def _create_forced_method(
285292
)
286293
elif method in ["NVLINK_ONE_SIDED"]:
287294
enable_eplb = model_config.moe_load_balancer is not None
295+
ep_group_health, watchdog_timeout_s, watchdog_poll_interval_s = get_wide_ep_ft_options(
296+
model_config
297+
)
288298
return NVLinkOneSided(
289299
mapping,
290300
num_slots,
@@ -295,6 +305,9 @@ def _create_forced_method(
295305
dtype=act_dtype,
296306
num_experts=num_experts if enable_eplb else None,
297307
use_low_precision_combine=use_low_precision_combine,
308+
ep_group_health=ep_group_health,
309+
alltoall_watchdog_timeout_s=watchdog_timeout_s,
310+
alltoall_watchdog_poll_interval_s=watchdog_poll_interval_s,
298311
)
299312
elif method == "DEEPEP":
300313
return DeepEP(

tensorrt_llm/_torch/modules/fused_moe/communication/nvlink_one_sided.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,12 @@
3030
import torch
3131

3232
from tensorrt_llm._mnnvl_utils import MnnvlMemory
33-
from tensorrt_llm._torch.alltoall_watchdog import AlltoAllWatchdog, AlltoAllWatchdogTimeout
33+
from tensorrt_llm._torch.alltoall_watchdog import (
34+
DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S,
35+
DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S,
36+
AlltoAllWatchdog,
37+
AlltoAllWatchdogTimeout,
38+
)
3439
from tensorrt_llm.bindings import internal as _tllm_internal
3540
from tensorrt_llm.logger import logger as tllm_logger
3641
from tensorrt_llm.mapping import Mapping
@@ -154,7 +159,7 @@ def __init__(
154159
use_low_precision_combine: bool = False,
155160
ep_group_health=None,
156161
alltoall_watchdog_timeout_s: Optional[float] = None,
157-
alltoall_watchdog_poll_interval_s: float = 0.05,
162+
alltoall_watchdog_poll_interval_s: float = DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S,
158163
alltoall_watchdog_on_timeout: Optional[Callable[[AlltoAllWatchdogTimeout], None]] = None,
159164
):
160165
"""
@@ -314,6 +319,8 @@ def __init__(
314319
self.ep_group_health = ep_group_health
315320
self._watchdog_flag_generation = 0
316321
self._alltoall_watchdog: AlltoAllWatchdog | None = None
322+
if alltoall_watchdog_timeout_s is None and self.ep_group_health is not None:
323+
alltoall_watchdog_timeout_s = DEFAULT_ALLTOALL_WATCHDOG_TIMEOUT_S
317324
if alltoall_watchdog_timeout_s is not None:
318325
self._watchdog_flag_generation = self._read_current_flag_val()
319326
self._alltoall_watchdog = AlltoAllWatchdog.from_workspace(

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
Fp4QuantizedTensor)
2121
from .interface import AlltoallMethodType, MoE
2222
from .quantization import UnquantizedFusedMoEMethod
23+
from .wide_ep_ft import get_wide_ep_ft_options
2324

2425
# isort: off
2526
from .quantization import (
@@ -324,6 +325,8 @@ def __init__(
324325
dtype,
325326
self.num_experts if self.layer_load_balancer else None,
326327
)
328+
ep_group_health, watchdog_timeout_s, watchdog_poll_interval_s = (
329+
get_wide_ep_ft_options(model_config))
327330

328331
self.moe_a2a = MoeAlltoAll(
329332
mapping=self.mapping,
@@ -333,6 +336,10 @@ def __init__(
333336
workspace_size_per_rank=workspace_size,
334337
num_experts=self.num_experts
335338
if self.layer_load_balancer else None,
339+
ep_group_health=ep_group_health,
340+
alltoall_watchdog_timeout_s=watchdog_timeout_s,
341+
alltoall_watchdog_poll_interval_s=
342+
watchdog_poll_interval_s,
336343
)
337344
elif self.alltoall_method_type == AlltoallMethodType.DeepEP or self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency:
338345
raise NotImplementedError(

0 commit comments

Comments
 (0)