File tree Expand file tree Collapse file tree
tests/unittest/_torch/modules Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -197,6 +197,7 @@ def __init__(
197197
198198 self ._cv = threading .Condition ()
199199 self ._queue : Deque [_CollectiveWatch ] = deque ()
200+ self ._closed = False
200201 self ._stopping = False
201202 self ._thread : threading .Thread | None = None
202203 self ._last_error : BaseException | None = None
@@ -249,6 +250,8 @@ def last_error(self) -> BaseException | None:
249250 def start (self ) -> None :
250251 """Start the background polling thread. Idempotent."""
251252 with self ._cv :
253+ if self ._closed :
254+ raise RuntimeError ("cannot start a stopped AlltoAllWatchdog" )
252255 if self ._thread is not None and self ._thread .is_alive ():
253256 return
254257 self ._stopping = False
@@ -262,6 +265,7 @@ def start(self) -> None:
262265 def stop (self , timeout_s : float | None = None ) -> None :
263266 """Stop the polling thread and wait for it to exit."""
264267 with self ._cv :
268+ self ._closed = True
265269 self ._stopping = True
266270 self ._queue .clear ()
267271 self ._cv .notify_all ()
@@ -291,7 +295,7 @@ def watch(
291295
292296 self .start ()
293297 with self ._cv :
294- if self ._stopping :
298+ if self ._closed :
295299 raise RuntimeError ("cannot queue a stopped AlltoAllWatchdog" )
296300 self ._queue .append (
297301 _CollectiveWatch (
Original file line number Diff line number Diff line change @@ -108,6 +108,24 @@ def test_watchdog_defaults_match_design_doc() -> None:
108108 assert watchdog ._poll_interval_s == DEFAULT_ALLTOALL_WATCHDOG_POLL_INTERVAL_S
109109
110110
111+ def test_watchdog_stop_is_terminal () -> None :
112+ reader = FakeCompletionFlagReader (ep_size = 1 )
113+ watchdog = AlltoAllWatchdog (
114+ ep_size = 1 ,
115+ ep_rank = 0 ,
116+ completion_reader = reader ,
117+ timeout_s = 0.2 ,
118+ poll_interval_s = 0.005 ,
119+ )
120+ watchdog .start ()
121+ watchdog .stop (timeout_s = 1.0 )
122+
123+ with pytest .raises (RuntimeError , match = "stopped AlltoAllWatchdog" ):
124+ watchdog .start ()
125+ with pytest .raises (RuntimeError , match = "stopped AlltoAllWatchdog" ):
126+ watchdog .watch (phase = "dispatch" , expected_flag = 1 )
127+
128+
111129def test_wide_ep_ft_options_create_shared_health_when_enabled (monkeypatch ) -> None :
112130 monkeypatch .setenv ("TRTLLM_ENABLE_WIDE_EP_FT" , "1" )
113131 model_config = SimpleNamespace (
You can’t perform that action at this time.
0 commit comments