2828import time
2929import traceback
3030from pathlib import Path
31- from typing import Any , Dict , List , Optional , Tuple
31+ from typing import Any , Callable , Dict , List , Optional , Tuple
3232
3333import torch
3434from mpi4py import MPI
@@ -70,6 +70,18 @@ def _try_import(module_path: str, attr: Optional[str] = None, default: Any = Non
7070POISON_HERE_PREFIX = "cuda_context_poisoned_after_success"
7171POISON_UPSTREAM_PREFIX = "cuda_context_poisoned_upstream"
7272WATCHDOG_UPSTREAM_PREFIX = "watchdog_timeout_upstream"
73+ # Terminal (status="failed") marker for a candidate the watchdog killed for
74+ # exceeding its wall-clock budget. NOT suffixed "_upstream": is_completed_for_resume
75+ # treats status="failed" as terminal, so --resume_from SKIPS it (does not re-attempt
76+ # and re-hang) while still surfacing the hang as a result row with a clear reason.
77+ WATCHDOG_TIMEOUT_PREFIX = "watchdog_timeout"
78+ # Terminal (status="failed") placeholder pre-written for the in-flight candidate
79+ # BEFORE it runs. If the process dies mid-candidate in a way nothing else can
80+ # record -- a CUDA device-side assert that aborts the MPI step, OOM-kill,
81+ # SIGSEGV, node loss -- this persisted row makes the candidate terminal so
82+ # --resume_from skips it and advances. Replaced with the real result on normal
83+ # completion. Like WATCHDOG_TIMEOUT_PREFIX, NOT suffixed "_upstream".
84+ INCOMPLETE_PREFIX = "incomplete"
7385BENCH_MOE_POISON_EXIT_CODE = 75
7486
7587
@@ -182,11 +194,29 @@ def allreduce_poison_reason(local_reason: Optional[str]) -> Optional[str]:
182194
183195
184196class CandidateWatchdog :
185- """Hard wall-clock guard around one candidate; SIGKILLs the process on timeout."""
197+ """Hard wall-clock guard around one candidate; SIGKILLs the process on timeout.
198+
199+ On timeout the guard first invokes ``on_timeout`` (used to record the hung
200+ candidate as a terminal ``failed`` result + checkpoint so it is not silently
201+ lost and is skipped on ``--resume_from`` rather than re-attempted), then
202+ SIGKILLs to break the wedged CUDA/NCCL state. A genuine hang cannot be
203+ recovered in-process, so the kill is unavoidable; ``on_timeout`` makes it a
204+ recorded outcome instead of a vanished one.
205+ """
186206
187- def __init__ (self , budget_s : float , label : str ):
207+ def __init__ (
208+ self ,
209+ budget_s : float ,
210+ label : str ,
211+ on_timeout : Optional [Callable [[], None ]] = None ,
212+ rank0_persist_grace_s : float = 8.0 ,
213+ ):
188214 self ._budget_s = float (budget_s )
189215 self ._label = label
216+ self ._on_timeout = on_timeout
217+ # Non-rank-0 ranks wait this long before SIGKILL so rank 0 can persist the
218+ # checkpoint before the first task exit tears down the whole srun step.
219+ self ._rank0_persist_grace_s = float (rank0_persist_grace_s )
190220 self ._cancelled = threading .Event ()
191221 self ._thread : Optional [threading .Thread ] = None
192222
@@ -211,16 +241,35 @@ def __exit__(self, exc_type, exc, tb) -> bool:
211241 def _guard (self ) -> None :
212242 if self ._cancelled .wait (self ._budget_s ):
213243 return
244+ rank = mpi_rank ()
214245 try :
215246 sys .stderr .write (
216247 f"[bench_moe watchdog] candidate '{ self ._label } ' exceeded "
217248 f"{ self ._budget_s :.1f} s budget on pid={ os .getpid ()} "
218- f"rank={ mpi_rank () } ; sending SIGKILL to break suspected "
219- f"NCCL deadlock or CUDA hang.\n "
249+ f"rank={ rank } ; recording it as a failed (timeout) result, then "
250+ f"sending SIGKILL to break suspected NCCL deadlock or CUDA hang.\n "
220251 )
221252 sys .stderr .flush ()
222253 except Exception :
223254 pass
255+ # Record the hung candidate as a terminal failed row + checkpoint. Rank 0
256+ # writes; other ranks no-op (see _emit_checkpoint_report). The main thread
257+ # is blocked in a GIL-releasing CUDA call, so this watchdog thread can run.
258+ if self ._on_timeout is not None :
259+ try :
260+ self ._on_timeout ()
261+ except Exception as exc : # never let bookkeeping block the kill
262+ try :
263+ sys .stderr .write (
264+ f"[bench_moe watchdog] on_timeout callback failed "
265+ f"({ type (exc ).__name__ } : { exc } ); killing anyway.\n "
266+ )
267+ sys .stderr .flush ()
268+ except Exception :
269+ pass
270+ # Let rank 0 flush its checkpoint before the first SIGKILL aborts the step.
271+ if rank != 0 and self ._rank0_persist_grace_s > 0 :
272+ time .sleep (self ._rank0_persist_grace_s )
224273 os .kill (os .getpid (), signal .SIGKILL )
225274
226275
@@ -518,8 +567,52 @@ def _run_benchmark_worker_under_current_mpi(args: argparse.Namespace, launcher:
518567 )
519568 _maybe_print_rank0 (f"[bench_moe] running { case_label } " )
520569
570+ # Pre-write a terminal "failed" placeholder for THIS candidate and
571+ # checkpoint it BEFORE running. If the process then dies mid-candidate in
572+ # a way nothing else can catch -- a CUDA device-side assert that aborts
573+ # the MPI step, OOM-kill, SIGSEGV, node loss -- this persisted row keeps
574+ # the candidate terminal, so --resume_from skips it and advances to the
575+ # next one instead of re-attempting (and re-crashing on) the same
576+ # candidate forever. On normal completion it is replaced with the real
577+ # result below. Only the in-flight candidate gets a placeholder;
578+ # not-yet-run candidates have no row and are still attempted on resume.
579+ placeholder = _make_skipped_run_result (
580+ model = ctx .model ,
581+ workload = workload ,
582+ config = cand ,
583+ world_size = world_size ,
584+ analysis = ctx .analysis ,
585+ reason = (
586+ f"{ INCOMPLETE_PREFIX } : process died before this candidate "
587+ f"finished (crash/abort/OOM/kill) ({ case_label } )"
588+ ),
589+ )
590+ placeholder .status = "failed"
591+ placeholder .status_per_rank = {f"rank{ i } " : "incomplete" for i in range (world_size )}
592+ accumulated_rows .append (_runresult_to_row (placeholder ))
593+ _emit_checkpoint_report (args = args , ctx = ctx , rows = accumulated_rows , world_size = world_size )
594+
595+ # If the watchdog fires (suspected hang), overwrite the placeholder's
596+ # reason with the precise timeout text and re-checkpoint before SIGKILL,
597+ # so the hang is surfaced as a clear result (the row is already terminal).
598+ def _record_watchdog_timeout (
599+ _label : str = case_label ,
600+ _budget_s : float = watchdog_budget_s ,
601+ ) -> None :
602+ if accumulated_rows :
603+ accumulated_rows [- 1 ]["skip_reason" ] = (
604+ f"{ WATCHDOG_TIMEOUT_PREFIX } : exceeded { _budget_s :.0f} s; "
605+ f"suspected NCCL/CUDA hang ({ _label } )"
606+ )
607+ accumulated_rows [- 1 ]["status_per_rank" ] = {
608+ f"rank{ i } " : "timeout" for i in range (world_size )
609+ }
610+ _emit_checkpoint_report (
611+ args = args , ctx = ctx , rows = accumulated_rows , world_size = world_size
612+ )
613+
521614 # Hard wall-clock guard around the actual candidate execution.
522- with CandidateWatchdog (watchdog_budget_s , case_label ):
615+ with CandidateWatchdog (watchdog_budget_s , case_label , on_timeout = _record_watchdog_timeout ):
523616 with torch .device (device ):
524617 r = _run_one_candidate (
525618 model = ctx .model ,
@@ -539,8 +632,10 @@ def _run_benchmark_worker_under_current_mpi(args: argparse.Namespace, launcher:
539632 input_cache = input_cache ,
540633 enable_perfect_router_requested = bool (args .enable_perfect_router ),
541634 )
635+ # Candidate finished normally: replace the pre-written placeholder (the
636+ # last row) with the real result.
542637 row = _runresult_to_row (r )
543- accumulated_rows . append ( row )
638+ accumulated_rows [ - 1 ] = row
544639 if rank == 0 :
545640 print (json .dumps (row , indent = 2 ), flush = True )
546641
0 commit comments