Skip to content

Commit 4ddb47d

Browse files
committed
fix: make AlltoAll watchdog stop terminal
1 parent 26639ec commit 4ddb47d

2 files changed

Lines changed: 23 additions & 1 deletion

File tree

tensorrt_llm/_torch/alltoall_watchdog.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,7 @@ def __init__(
197197

198198
self._cv = threading.Condition()
199199
self._queue: Deque[_CollectiveWatch] = deque()
200+
self._closed = False
200201
self._stopping = False
201202
self._thread: threading.Thread | None = None
202203
self._last_error: BaseException | None = None
@@ -249,6 +250,8 @@ def last_error(self) -> BaseException | None:
249250
def start(self) -> None:
250251
"""Start the background polling thread. Idempotent."""
251252
with self._cv:
253+
if self._closed:
254+
raise RuntimeError("cannot start a stopped AlltoAllWatchdog")
252255
if self._thread is not None and self._thread.is_alive():
253256
return
254257
self._stopping = False
@@ -262,6 +265,7 @@ def start(self) -> None:
262265
def stop(self, timeout_s: float | None = None) -> None:
263266
"""Stop the polling thread and wait for it to exit."""
264267
with self._cv:
268+
self._closed = True
265269
self._stopping = True
266270
self._queue.clear()
267271
self._cv.notify_all()
@@ -291,7 +295,7 @@ def watch(
291295

292296
self.start()
293297
with self._cv:
294-
if self._stopping:
298+
if self._closed:
295299
raise RuntimeError("cannot queue a stopped AlltoAllWatchdog")
296300
self._queue.append(
297301
_CollectiveWatch(

tests/unittest/_torch/modules/test_alltoall_watchdog.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,24 @@ def test_watchdog_defaults_match_design_doc() -> None:
108108
assert watchdog._poll_interval_s == DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S
109109

110110

111+
def test_watchdog_stop_is_terminal() -> None:
112+
reader = FakeCompletionFlagReader(ep_size=1)
113+
watchdog = AlltoAllWatchdog(
114+
ep_size=1,
115+
ep_rank=0,
116+
completion_reader=reader,
117+
timeout_s=0.2,
118+
poll_interval_s=0.005,
119+
)
120+
watchdog.start()
121+
watchdog.stop(timeout_s=1.0)
122+
123+
with pytest.raises(RuntimeError, match="stopped AlltoAllWatchdog"):
124+
watchdog.start()
125+
with pytest.raises(RuntimeError, match="stopped AlltoAllWatchdog"):
126+
watchdog.watch(phase="dispatch", expected_flag=1)
127+
128+
111129
def test_wide_ep_ft_options_create_shared_health_when_enabled(monkeypatch) -> None:
112130
monkeypatch.setenv("TRTLLM_ENABLE_WIDE_EP_FT", "1")
113131
model_config = SimpleNamespace(

0 commit comments

Comments
 (0)