Skip to content

Commit 2e959a0

Browse files
committed
fix: bench-moe timeout handler, kernel error handler, workload computation for non-DP attention path, and MoE logic for non-DP attention path
Signed-off-by: guqiqi <29116997+guqiqi@users.noreply.github.com>
1 parent 07e3c3f commit 2e959a0

6 files changed

Lines changed: 146 additions & 17 deletions

File tree

tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,12 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
429429
if self.use_dp and self.comm is not None:
430430
num_rows = self._dp_padded_num_rows(all_rank_num_tokens)
431431
else:
432-
num_rows = sum(all_rank_num_tokens)
432+
# non-DP: no cross-rank dispatch. The scheduler fills all_rank_num_tokens
433+
# from [x.shape[0]] before calling here, so it must be a single-element list.
434+
assert len(all_rank_num_tokens) == 1, (
435+
f"non-DP path expects a single-element list, got {len(all_rank_num_tokens)}"
436+
)
437+
num_rows = all_rank_num_tokens[0]
433438
return (num_rows + self.moe_max_num_tokens - 1) // self.moe_max_num_tokens
434439

435440
def split_chunk(self, split_token_num: int, split_num_chunks: int) -> List[int]:

tests/microbenchmarks/bench_moe/case_runner.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,14 +561,15 @@ def _resolve_layout_and_plan(
561561
top_k=int(model.top_k),
562562
num_experts=int(model.num_experts),
563563
moe_ep_size=int(moe_ep_size),
564+
enable_dp=bool(_enable_dp),
564565
)
565566
except Exception as exc:
566567
reason = f"routing plan error: {type(exc).__name__}: {exc}"
567568
_maybe_print_rank0(f"[bench_moe] {reason}")
568569
return _short_circuit(result, "skipped", reason)
569570
per_rank = list(routing_plan.per_rank_num_tokens)
570571
else:
571-
per_rank = _per_rank_tokens(workload, world_size)
572+
per_rank = _per_rank_tokens(workload, world_size, enable_dp=bool(_enable_dp))
572573

573574
return int(moe_ep_size), per_rank, routing_plan
574575

@@ -663,6 +664,11 @@ def _run_one_candidate(
663664
result.moe_tp_size = int(mapping.moe_tp_size)
664665
result.enable_attention_dp = bool(mapping.enable_attention_dp)
665666

667+
# TEP/TTP (no attention DP): no cross-rank dispatch; the scheduler fills
668+
# all_rank_num_tokens from x.shape[0]. Pass None to follow that path.
669+
if not mapping.enable_attention_dp:
670+
all_rank_num_tokens = None
671+
666672
AutoTuner.get().setup_distributed_state(mapping)
667673
AutoTuner.get().clear_cache()
668674

tests/microbenchmarks/bench_moe/cli.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from tensorrt_llm.models.modeling_utils import QuantAlgo
3030

3131
from .backend import MoeBackendType
32+
from .mapping import _resolve_mapping_layout
3233
from .routing import _per_rank_tokens
3334
from .search import (
3435
_coerce_str_tuple,
@@ -91,7 +92,14 @@ def _build_worker_header(ctx: _BenchmarkContext, launcher: str, world_size: int)
9192
"world_size": world_size,
9293
"analysis": list(ctx.analysis) or ["summary"],
9394
"workloads": [
94-
w.to_dict(per_rank_num_tokens=_per_rank_tokens(w, world_size)) for w in ctx.workloads
95+
w.to_dict(
96+
per_rank_num_tokens=_per_rank_tokens(
97+
w,
98+
world_size,
99+
enable_dp=bool(_resolve_mapping_layout(ctx.base_config, world_size)[2]),
100+
)
101+
)
102+
for w in ctx.workloads
95103
],
96104
"base_config": ctx.base_config.to_dict(),
97105
}

tests/microbenchmarks/bench_moe/results.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
from tensorrt_llm._utils import mpi_allgather
2323

24+
from .mapping import _resolve_mapping_layout
2425
from .routing import _per_rank_tokens
2526
from .specs import ConfigSpec, ModelSpec, RunResult, WorkloadSpec
2627
from .utils import _compute_stats
@@ -407,7 +408,8 @@ def _make_skipped_run_result(
407408
r = RunResult(model=model, workload=workload, config=config)
408409
r.status = "skipped"
409410
r.skip_reason = reason
410-
r.per_rank_num_tokens = _per_rank_tokens(workload, world_size)
411+
_, _, _enable_dp = _resolve_mapping_layout(config, world_size)
412+
r.per_rank_num_tokens = _per_rank_tokens(workload, world_size, enable_dp=bool(_enable_dp))
411413
r.status_per_rank = {f"rank{i}": "skipped" for i in range(world_size)}
412414
r.instrumentation = {
413415
"level": ",".join(sorted(analysis)) if analysis else "summary",

tests/microbenchmarks/bench_moe/routing/builders.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -115,23 +115,35 @@ def _build_per_rank_num_tokens(
115115
spec: RoutingControlSpec,
116116
num_tokens: int,
117117
world_size: int,
118+
enable_dp: bool,
118119
) -> List[int]:
119120
"""Resolve ``per_rank_num_tokens`` for a workload.
120121
121-
Explicit ``spec.per_rank_num_tokens`` wins; otherwise tokens are split
122-
evenly across ranks with any remainder on rank 0.
122+
Explicit ``spec.per_rank_num_tokens`` wins; otherwise the token count per
123+
rank depends on the attention-DP setting:
124+
125+
* ``enable_dp=True`` (DEP / DTP): tokens are DP-sharded across ranks, so
126+
each rank holds ``num_tokens // world_size`` tokens, with any remainder on rank 0.
127+
* ``enable_dp=False`` (TEP / TTP): attention is tensor-parallel, so every
128+
rank sees the complete batch and holds ``num_tokens``.
129+
130+
When an explicit list is provided, its sum is validated against the expected
131+
total (``num_tokens`` for DP modes, ``num_tokens * world_size`` for non-DP).
123132
"""
124133
if spec.per_rank_num_tokens is None:
134+
if not enable_dp:
135+
return [int(num_tokens)] * world_size
125136
return _distribute_tokens(int(num_tokens), world_size)
137+
expected_total = int(num_tokens) * (1 if enable_dp else world_size)
126138
return _validate_per_rank_token_list(
127-
spec.per_rank_num_tokens, world_size=world_size, expected_total=int(num_tokens)
139+
spec.per_rank_num_tokens, world_size=world_size, expected_total=expected_total
128140
)
129141

130142

131-
def _per_rank_tokens(workload: WorkloadSpec, world_size: int) -> List[int]:
143+
def _per_rank_tokens(workload: WorkloadSpec, world_size: int, enable_dp: bool) -> List[int]:
132144
"""Materialize the ``per_rank_num_tokens`` list for a workload + world size."""
133145
return _build_per_rank_num_tokens(
134-
workload.routing_control, int(workload.num_tokens), world_size
146+
workload.routing_control, int(workload.num_tokens), world_size, enable_dp
135147
)
136148

137149

@@ -309,6 +321,7 @@ def _build_routing_plan(
309321
top_k: int,
310322
num_experts: int,
311323
moe_ep_size: int,
324+
enable_dp: bool,
312325
) -> RoutingPlan:
313326
"""Translate a ``RoutingControlSpec`` into a canonical normalised plan."""
314327
if moe_ep_size <= 0 or num_experts % moe_ep_size != 0:
@@ -318,7 +331,7 @@ def _build_routing_plan(
318331
experts_per_rank = num_experts // moe_ep_size
319332
if top_k > num_experts:
320333
raise ValueError(f"top_k ({top_k}) must be <= num_experts ({num_experts})")
321-
per_rank = _build_per_rank_num_tokens(spec, num_tokens, world_size)
334+
per_rank = _build_per_rank_num_tokens(spec, num_tokens, world_size, enable_dp)
322335
# The dispatch matrix is indexed by EP rank on both axes. The current
323336
# worker only calls routing-control planning when ``moe_ep_size`` equals
324337
# ``world_size`` so that this EP-axis matrix also matches the user-visible

tests/microbenchmarks/bench_moe/worker.py

Lines changed: 102 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import time
2929
import traceback
3030
from pathlib import Path
31-
from typing import Any, Dict, List, Optional, Tuple
31+
from typing import Any, Callable, Dict, List, Optional, Tuple
3232

3333
import torch
3434
from mpi4py import MPI
@@ -70,6 +70,18 @@ def _try_import(module_path: str, attr: Optional[str] = None, default: Any = Non
7070
POISON_HERE_PREFIX = "cuda_context_poisoned_after_success"
7171
POISON_UPSTREAM_PREFIX = "cuda_context_poisoned_upstream"
7272
WATCHDOG_UPSTREAM_PREFIX = "watchdog_timeout_upstream"
73+
# Terminal (status="failed") marker for a candidate the watchdog killed for
74+
# exceeding its wall-clock budget. NOT suffixed "_upstream": is_completed_for_resume
75+
# treats status="failed" as terminal, so --resume_from SKIPS it (does not re-attempt
76+
# and re-hang) while still surfacing the hang as a result row with a clear reason.
77+
WATCHDOG_TIMEOUT_PREFIX = "watchdog_timeout"
78+
# Terminal (status="failed") placeholder pre-written for the in-flight candidate
79+
# BEFORE it runs. If the process dies mid-candidate in a way nothing else can
80+
# record -- a CUDA device-side assert that aborts the MPI step, OOM-kill,
81+
# SIGSEGV, node loss -- this persisted row makes the candidate terminal so
82+
# --resume_from skips it and advances. Replaced with the real result on normal
83+
# completion. Like WATCHDOG_TIMEOUT_PREFIX, NOT suffixed "_upstream".
84+
INCOMPLETE_PREFIX = "incomplete"
7385
BENCH_MOE_POISON_EXIT_CODE = 75
7486

7587

@@ -182,11 +194,29 @@ def allreduce_poison_reason(local_reason: Optional[str]) -> Optional[str]:
182194

183195

184196
class CandidateWatchdog:
185-
"""Hard wall-clock guard around one candidate; SIGKILLs the process on timeout."""
197+
"""Hard wall-clock guard around one candidate; SIGKILLs the process on timeout.
198+
199+
On timeout the guard first invokes ``on_timeout`` (used to record the hung
200+
candidate as a terminal ``failed`` result + checkpoint so it is not silently
201+
lost and is skipped on ``--resume_from`` rather than re-attempted), then
202+
SIGKILLs to break the wedged CUDA/NCCL state. A genuine hang cannot be
203+
recovered in-process, so the kill is unavoidable; ``on_timeout`` makes it a
204+
recorded outcome instead of a vanished one.
205+
"""
186206

187-
def __init__(self, budget_s: float, label: str):
207+
def __init__(
208+
self,
209+
budget_s: float,
210+
label: str,
211+
on_timeout: Optional[Callable[[], None]] = None,
212+
rank0_persist_grace_s: float = 8.0,
213+
):
188214
self._budget_s = float(budget_s)
189215
self._label = label
216+
self._on_timeout = on_timeout
217+
# Non-rank-0 ranks wait this long before SIGKILL so rank 0 can persist the
218+
# checkpoint before the first task exit tears down the whole srun step.
219+
self._rank0_persist_grace_s = float(rank0_persist_grace_s)
190220
self._cancelled = threading.Event()
191221
self._thread: Optional[threading.Thread] = None
192222

@@ -211,16 +241,35 @@ def __exit__(self, exc_type, exc, tb) -> bool:
211241
def _guard(self) -> None:
212242
if self._cancelled.wait(self._budget_s):
213243
return
244+
rank = mpi_rank()
214245
try:
215246
sys.stderr.write(
216247
f"[bench_moe watchdog] candidate '{self._label}' exceeded "
217248
f"{self._budget_s:.1f}s budget on pid={os.getpid()} "
218-
f"rank={mpi_rank()}; sending SIGKILL to break suspected "
219-
f"NCCL deadlock or CUDA hang.\n"
249+
f"rank={rank}; recording it as a failed (timeout) result, then "
250+
f"sending SIGKILL to break suspected NCCL deadlock or CUDA hang.\n"
220251
)
221252
sys.stderr.flush()
222253
except Exception:
223254
pass
255+
# Record the hung candidate as a terminal failed row + checkpoint. Rank 0
256+
# writes; other ranks no-op (see _emit_checkpoint_report). The main thread
257+
# is blocked in a GIL-releasing CUDA call, so this watchdog thread can run.
258+
if self._on_timeout is not None:
259+
try:
260+
self._on_timeout()
261+
except Exception as exc: # never let bookkeeping block the kill
262+
try:
263+
sys.stderr.write(
264+
f"[bench_moe watchdog] on_timeout callback failed "
265+
f"({type(exc).__name__}: {exc}); killing anyway.\n"
266+
)
267+
sys.stderr.flush()
268+
except Exception:
269+
pass
270+
# Let rank 0 flush its checkpoint before the first SIGKILL aborts the step.
271+
if rank != 0 and self._rank0_persist_grace_s > 0:
272+
time.sleep(self._rank0_persist_grace_s)
224273
os.kill(os.getpid(), signal.SIGKILL)
225274

226275

@@ -518,8 +567,52 @@ def _run_benchmark_worker_under_current_mpi(args: argparse.Namespace, launcher:
518567
)
519568
_maybe_print_rank0(f"[bench_moe] running {case_label}")
520569

570+
# Pre-write a terminal "failed" placeholder for THIS candidate and
571+
# checkpoint it BEFORE running. If the process then dies mid-candidate in
572+
# a way nothing else can catch -- a CUDA device-side assert that aborts
573+
# the MPI step, OOM-kill, SIGSEGV, node loss -- this persisted row keeps
574+
# the candidate terminal, so --resume_from skips it and advances to the
575+
# next one instead of re-attempting (and re-crashing on) the same
576+
# candidate forever. On normal completion it is replaced with the real
577+
# result below. Only the in-flight candidate gets a placeholder;
578+
# not-yet-run candidates have no row and are still attempted on resume.
579+
placeholder = _make_skipped_run_result(
580+
model=ctx.model,
581+
workload=workload,
582+
config=cand,
583+
world_size=world_size,
584+
analysis=ctx.analysis,
585+
reason=(
586+
f"{INCOMPLETE_PREFIX}: process died before this candidate "
587+
f"finished (crash/abort/OOM/kill) ({case_label})"
588+
),
589+
)
590+
placeholder.status = "failed"
591+
placeholder.status_per_rank = {f"rank{i}": "incomplete" for i in range(world_size)}
592+
accumulated_rows.append(_runresult_to_row(placeholder))
593+
_emit_checkpoint_report(args=args, ctx=ctx, rows=accumulated_rows, world_size=world_size)
594+
595+
# If the watchdog fires (suspected hang), overwrite the placeholder's
596+
# reason with the precise timeout text and re-checkpoint before SIGKILL,
597+
# so the hang is surfaced as a clear result (the row is already terminal).
598+
def _record_watchdog_timeout(
599+
_label: str = case_label,
600+
_budget_s: float = watchdog_budget_s,
601+
) -> None:
602+
if accumulated_rows:
603+
accumulated_rows[-1]["skip_reason"] = (
604+
f"{WATCHDOG_TIMEOUT_PREFIX}: exceeded {_budget_s:.0f}s; "
605+
f"suspected NCCL/CUDA hang ({_label})"
606+
)
607+
accumulated_rows[-1]["status_per_rank"] = {
608+
f"rank{i}": "timeout" for i in range(world_size)
609+
}
610+
_emit_checkpoint_report(
611+
args=args, ctx=ctx, rows=accumulated_rows, world_size=world_size
612+
)
613+
521614
# Hard wall-clock guard around the actual candidate execution.
522-
with CandidateWatchdog(watchdog_budget_s, case_label):
615+
with CandidateWatchdog(watchdog_budget_s, case_label, on_timeout=_record_watchdog_timeout):
523616
with torch.device(device):
524617
r = _run_one_candidate(
525618
model=ctx.model,
@@ -539,8 +632,10 @@ def _run_benchmark_worker_under_current_mpi(args: argparse.Namespace, launcher:
539632
input_cache=input_cache,
540633
enable_perfect_router_requested=bool(args.enable_perfect_router),
541634
)
635+
# Candidate finished normally: replace the pre-written placeholder (the
636+
# last row) with the real result.
542637
row = _runresult_to_row(r)
543-
accumulated_rows.append(row)
638+
accumulated_rows[-1] = row
544639
if rank == 0:
545640
print(json.dumps(row, indent=2), flush=True)
546641

0 commit comments

Comments
 (0)