Skip to content

Commit 48cce16

Browse files
committed
[None][feat] WideEP FT: add AlltoAll watchdog
1 parent e776204 commit 48cce16

4 files changed

Lines changed: 778 additions & 4 deletions

File tree

Lines changed: 367 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,367 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
"""Host-side watchdog for MoE AlltoAll completion flags.

The NVLinkOneSided kernels signal each collective by writing the current
``flag_val`` into the rank-local completion flag table. A dead peer in the
silent-spin failure mode never writes its slot, so this watchdog polls the same
table from a CPU thread and reports peers whose flags do not reach the expected
generation before a bounded timeout.
"""
23+
24+
from __future__ import annotations
25+
26+
import threading
27+
import time
28+
from collections import deque
29+
from dataclasses import dataclass
30+
from typing import Callable, Deque, Mapping, Optional, Protocol, Sequence
31+
32+
import torch
33+
34+
from tensorrt_llm.logger import logger as tllm_logger
35+
36+
37+
class CompletionFlagReader(Protocol):
    """Structural interface for anything that can read a phase's flag row.

    The watchdog only requires a ``read_completion_flags`` method; any object
    providing it satisfies this protocol.
    """

    def read_completion_flags(self, phase: str) -> Sequence[int]:
        """Return ``ep_size`` flag values for ``phase``."""
        ...
42+
43+
44+
class EPGroupHealthLike(Protocol):
    """Structural interface for the health-tracking calls the watchdog makes.

    Only two operations are needed: reading the active-rank bitmask and
    marking a single rank as failed.
    """

    def get_mask(self) -> int:
        """Return the active-rank bitmask."""
        ...

    def mark_failed(self, rank: int) -> bool:
        """Mark ``rank`` failed and return whether state changed."""
        ...
52+
53+
54+
@dataclass(frozen=True)
class AlltoAllWatchdogTimeout:
    """Immutable report describing one timed-out AlltoAll phase."""

    # Name of the phase that timed out.
    phase: str
    # Flag value every active rank was expected to show.
    expected_flag: int
    # Flag values actually read at the moment of the timeout, one per rank.
    observed_flags: tuple[int, ...]
    # Active ranks whose observed flag differed from ``expected_flag``.
    missing_ranks: tuple[int, ...]
    # Ranks for which ``mark_failed`` reported a state change.
    marked_failed_ranks: tuple[int, ...]
    # Seconds between queuing the phase and handling the timeout.
    elapsed_s: float
64+
65+
66+
@dataclass(frozen=True)
67+
class _CollectiveWatch:
68+
phase: str
69+
expected_flag: int
70+
active_mask: int
71+
start_s: float
72+
73+
74+
class _TorchCompletionFlagReader:
75+
"""Completion-flag reader backed by the MoE AlltoAll workspace tensor."""
76+
77+
def __init__(
78+
self,
79+
workspace: torch.Tensor,
80+
ep_rank: int,
81+
ep_size: int,
82+
dispatch_completion_flags_offset: int,
83+
combine_completion_flags_offset: int,
84+
) -> None:
85+
if workspace.dim() != 2:
86+
raise ValueError(
87+
"workspace must be a 2D tensor [ep_size, size_per_rank]")
88+
if not 0 <= ep_rank < ep_size:
89+
raise ValueError(
90+
f"ep_rank must be in [0, {ep_size}), got {ep_rank}")
91+
if workspace.size(0) != ep_size:
92+
raise ValueError(
93+
f"workspace first dimension must equal ep_size={ep_size}, got {workspace.size(0)}"
94+
)
95+
self._workspace = workspace
96+
self._ep_rank = ep_rank
97+
self._ep_size = ep_size
98+
self._offsets = {
99+
"dispatch": int(dispatch_completion_flags_offset),
100+
"combine": int(combine_completion_flags_offset),
101+
}
102+
103+
def read_completion_flags(self, phase: str) -> tuple[int, ...]:
104+
offset = self._offsets[phase]
105+
end = offset + self._ep_size * 4
106+
flags = self._workspace[self._ep_rank, offset:end].view(torch.int32)
107+
if flags.device.type != "cpu":
108+
flags = flags.detach().cpu()
109+
return tuple(int(v) for v in flags.tolist())
110+
111+
112+
class AlltoAllWatchdog:
    """Background host thread that watches AlltoAll completion flags.

    The watchdog is intentionally opt-in. Callers queue phases with
    :meth:`watch`; the thread polls them in FIFO order so a queued combine cannot
    hide a still-spinning dispatch.

    A phase is complete once every rank in its active mask shows the expected
    flag value. A phase still incomplete ``timeout_s`` seconds after it was
    queued is reported (warning log, optional ``health.mark_failed``, optional
    ``on_timeout`` callback) and every queued phase is dropped.

    The queue, stop flag, thread handle and last error are only accessed while
    holding ``self._cv``.
    """

    VALID_PHASES = frozenset({"dispatch", "combine"})

    def __init__(
        self,
        *,
        ep_size: int,
        ep_rank: int,
        completion_reader: CompletionFlagReader,
        timeout_s: float,
        poll_interval_s: float = 0.05,
        health: Optional[EPGroupHealthLike] = None,
        on_timeout: Optional[Callable[[AlltoAllWatchdogTimeout], None]] = None,
    ) -> None:
        """Configure the watchdog without starting its thread.

        Args:
            ep_size: Number of ranks; must be positive.
            ep_rank: Local rank; must satisfy ``0 <= ep_rank < ep_size``.
            completion_reader: Source of per-phase flag values.
            timeout_s: Seconds a queued phase may stay incomplete; must be
                positive.
            poll_interval_s: Sleep between polls of an incomplete phase; must
                be positive.
            health: Optional health tracker. Supplies the default active mask
                and is told about missing ranks on timeout.
            on_timeout: Optional callback invoked with the timeout report.

        Raises:
            ValueError: If any numeric argument is out of range.
        """
        if ep_size <= 0:
            raise ValueError(f"ep_size must be > 0, got {ep_size}")
        if not 0 <= ep_rank < ep_size:
            raise ValueError(
                f"ep_rank must be in [0, {ep_size}), got {ep_rank}")
        if timeout_s <= 0:
            raise ValueError(f"timeout_s must be > 0, got {timeout_s}")
        if poll_interval_s <= 0:
            raise ValueError(
                f"poll_interval_s must be > 0, got {poll_interval_s}")

        self._ep_size = int(ep_size)
        self._ep_rank = int(ep_rank)
        self._completion_reader = completion_reader
        self._timeout_s = float(timeout_s)
        self._poll_interval_s = float(poll_interval_s)
        self._health = health
        self._on_timeout = on_timeout

        self._cv = threading.Condition()
        self._queue: Deque[_CollectiveWatch] = deque()
        self._stopping = False
        self._thread: threading.Thread | None = None
        self._last_error: BaseException | None = None

    @classmethod
    def from_workspace(
        cls,
        *,
        workspace: torch.Tensor,
        metainfo: torch.Tensor,
        metainfo_index: Mapping[str, int],
        ep_rank: int,
        ep_size: int,
        timeout_s: float,
        poll_interval_s: float = 0.05,
        health: Optional[EPGroupHealthLike] = None,
        on_timeout: Optional[Callable[[AlltoAllWatchdogTimeout], None]] = None,
    ) -> "AlltoAllWatchdog":
        """Build a watchdog from the MoE AlltoAll workspace and metainfo.

        The dispatch and combine flag offsets are read from ``metainfo`` at the
        positions named ``DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX`` and
        ``COMBINE_COMPLETION_FLAGS_OFFSET_INDEX`` in ``metainfo_index``.
        """
        dispatch_offset = int(metainfo[
            metainfo_index["DISPATCH_COMPLETION_FLAGS_OFFSET_INDEX"]].item())
        combine_offset = int(metainfo[
            metainfo_index["COMBINE_COMPLETION_FLAGS_OFFSET_INDEX"]].item())
        reader = _TorchCompletionFlagReader(
            workspace=workspace,
            ep_rank=ep_rank,
            ep_size=ep_size,
            dispatch_completion_flags_offset=dispatch_offset,
            combine_completion_flags_offset=combine_offset,
        )
        return cls(
            ep_size=ep_size,
            ep_rank=ep_rank,
            completion_reader=reader,
            timeout_s=timeout_s,
            poll_interval_s=poll_interval_s,
            health=health,
            on_timeout=on_timeout,
        )

    @property
    def last_error(self) -> BaseException | None:
        """Return the last polling-thread error, if any."""
        with self._cv:
            return self._last_error

    def start(self) -> None:
        """Start the background polling thread. Idempotent.

        When no live thread exists, the stop flag is cleared and a new daemon
        thread is created, so a stopped watchdog can be started again.
        """
        with self._cv:
            if self._thread is not None and self._thread.is_alive():
                return
            self._stopping = False
            self._thread = threading.Thread(
                target=self._run,
                name=f"AlltoAllWatchdog-rank{self._ep_rank}",
                daemon=True,
            )
            self._thread.start()

    def stop(self, timeout_s: float | None = None) -> None:
        """Stop the polling thread and wait for it to exit.

        Queued phases are discarded. ``timeout_s`` bounds the join; ``None``
        waits indefinitely.
        """
        with self._cv:
            self._stopping = True
            self._queue.clear()
            self._cv.notify_all()
            thread = self._thread
        # Join outside the lock so the polling thread can acquire it to exit.
        if thread is not None:
            thread.join(timeout=timeout_s)

    def watch(
        self,
        *,
        phase: str,
        expected_flag: int,
        active_mask: int | None = None,
    ) -> None:
        """Queue a just-launched AlltoAll phase for watchdog polling.

        Args:
            phase: One of :attr:`VALID_PHASES`.
            expected_flag: Non-negative flag value each active rank must show.
            active_mask: Bitmask of ranks to check. Defaults to the health
                tracker's mask, or to all ``ep_size`` ranks without one.

        Raises:
            ValueError: On an unknown phase, a negative flag, or a mask that
                excludes the local rank.
            RuntimeError: If the watchdog is stopping while its previous
                thread is still alive.
        """
        if phase not in self.VALID_PHASES:
            raise ValueError(
                f"phase must be one of {sorted(self.VALID_PHASES)}, got {phase!r}"
            )
        if expected_flag < 0:
            raise ValueError(
                f"expected_flag must be non-negative, got {expected_flag}")
        if active_mask is None:
            if self._health is not None:
                active_mask = self._health.get_mask()
            else:
                active_mask = (1 << self._ep_size) - 1
        if not (active_mask >> self._ep_rank) & 1:
            raise ValueError("active_mask must include the local ep_rank")

        # Lazily (re)start the thread; this also clears a previous stop request
        # once the old thread has exited.
        self.start()
        with self._cv:
            if self._stopping:
                raise RuntimeError("cannot queue a stopped AlltoAllWatchdog")
            self._queue.append(
                _CollectiveWatch(
                    phase=phase,
                    expected_flag=int(expected_flag),
                    active_mask=int(active_mask),
                    start_s=time.monotonic(),
                ))
            self._cv.notify_all()

    def wait_until_idle(self, timeout_s: float) -> bool:
        """Wait until all queued phases complete or timeout handling clears them.

        Returns:
            ``True`` once the queue is empty, ``False`` if ``timeout_s``
            elapsed first.
        """
        deadline = time.monotonic() + timeout_s
        with self._cv:
            while self._queue:
                remaining = deadline - time.monotonic()
                if remaining <= 0:
                    return False
                self._cv.wait(timeout=remaining)
        return True

    def __enter__(self) -> "AlltoAllWatchdog":
        """Start the polling thread and return ``self``."""
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        """Stop the polling thread, waiting at most one second for it."""
        self.stop(timeout_s=1.0)

    def _active_ranks(self, active_mask: int) -> tuple[int, ...]:
        """Return the ranks below ``ep_size`` whose bit is set in the mask."""
        return tuple(rank for rank in range(self._ep_size)
                     if (active_mask >> rank) & 1)

    def _phase_complete(self, watch: _CollectiveWatch,
                        observed_flags: tuple[int, ...]) -> bool:
        """Return whether every active rank shows exactly the expected flag."""
        # NOTE(review): the comparison is strict equality, so a flag that has
        # already advanced past ``expected_flag`` counts as incomplete —
        # confirm callers never outpace the poll loop.
        return all(observed_flags[rank] == watch.expected_flag
                   for rank in self._active_ranks(watch.active_mask))

    def _missing_ranks(self, watch: _CollectiveWatch,
                       observed_flags: tuple[int, ...]) -> tuple[int, ...]:
        """Return the active ranks whose flag differs from the expected one."""
        return tuple(rank for rank in self._active_ranks(watch.active_mask)
                     if observed_flags[rank] != watch.expected_flag)

    def _handle_timeout(self, watch: _CollectiveWatch,
                        observed_flags: tuple[int, ...]) -> None:
        """Report a timed-out phase to the health tracker, log and callback.

        The local rank is never marked failed, even when its own flag is
        missing. Exceptions from ``mark_failed`` or ``on_timeout`` propagate to
        the caller.
        """
        elapsed_s = time.monotonic() - watch.start_s
        missing_ranks = self._missing_ranks(watch, observed_flags)
        marked_failed: list[int] = []
        if self._health is not None:
            for rank in missing_ranks:
                if rank == self._ep_rank:
                    continue
                if self._health.mark_failed(rank):
                    marked_failed.append(rank)

        event = AlltoAllWatchdogTimeout(
            phase=watch.phase,
            expected_flag=watch.expected_flag,
            observed_flags=observed_flags,
            missing_ranks=missing_ranks,
            marked_failed_ranks=tuple(marked_failed),
            elapsed_s=elapsed_s,
        )
        tllm_logger.warning(
            "AlltoAll watchdog timeout on rank %d during %s: expected flag %d, "
            "missing ranks %s, observed flags %s",
            self._ep_rank,
            watch.phase,
            watch.expected_flag,
            list(missing_ranks),
            list(observed_flags),
        )
        if self._on_timeout is not None:
            self._on_timeout(event)

    def _record_fatal_error(self, exc: BaseException, message: str) -> None:
        """Publish ``exc`` as the last error, drop the queue, then log it.

        State is updated and waiters are woken before logging so that
        ``last_error`` and ``wait_until_idle`` are consistent even if the log
        call itself fails.
        """
        with self._cv:
            self._last_error = exc
            self._queue.clear()
            self._cv.notify_all()
        tllm_logger.error(message, exc)

    def _run(self) -> None:
        """Polling loop executed by the background thread.

        Repeatedly inspects the head of the queue: pops it when complete,
        reports and clears the queue when it has timed out, otherwise sleeps
        for one poll interval. The loop exits on a stop request or after any
        error raised by the reader or by timeout handling.
        """
        while True:
            with self._cv:
                while not self._queue and not self._stopping:
                    self._cv.wait()
                if self._stopping:
                    return
                watch = self._queue[0]

            # Read outside the lock so watch()/stop() are never blocked by a
            # slow reader.
            try:
                observed_flags = tuple(
                    int(v) for v in
                    self._completion_reader.read_completion_flags(watch.phase))
                if len(observed_flags) != self._ep_size:
                    raise RuntimeError(
                        f"completion reader returned {len(observed_flags)} flags; "
                        f"expected ep_size={self._ep_size}")
            except BaseException as exc:  # noqa: BLE001 - keep watchdog failures visible.
                self._record_fatal_error(
                    exc, "AlltoAll watchdog stopped after polling error: %s")
                return

            if self._phase_complete(watch, observed_flags):
                with self._cv:
                    # stop() or a timeout may have cleared the queue meanwhile;
                    # only pop the entry that was actually inspected.
                    if self._queue and self._queue[0] is watch:
                        self._queue.popleft()
                    self._cv.notify_all()
                continue

            if time.monotonic() - watch.start_s >= self._timeout_s:
                try:
                    self._handle_timeout(watch, observed_flags)
                except BaseException as exc:  # noqa: BLE001 - keep watchdog failures visible.
                    # A raising health tracker or callback must not leave the
                    # stale watch queued with no recorded error.
                    self._record_fatal_error(
                        exc,
                        "AlltoAll watchdog stopped after timeout-handler error: %s"
                    )
                    return
                with self._cv:
                    # The GPU stream is no longer trustworthy once a collective
                    # times out. Drop queued follow-on phases so they do not
                    # produce duplicate or misleading reports.
                    self._queue.clear()
                    self._cv.notify_all()
                continue

            with self._cv:
                self._cv.wait(timeout=self._poll_interval_s)

0 commit comments

Comments
 (0)