Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion tensorrt_llm/_torch/modules/fused_moe/configurable_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,12 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
if self.use_dp and self.comm is not None:
num_rows = self._dp_padded_num_rows(all_rank_num_tokens)
else:
num_rows = sum(all_rank_num_tokens)
# non-DP: no cross-rank dispatch. The scheduler fills all_rank_num_tokens
# from [x.shape[0]] before calling here, so it must be a single-element list.
assert len(all_rank_num_tokens) == 1, (
f"non-DP path expects a single-element list, got {len(all_rank_num_tokens)}"
)
num_rows = all_rank_num_tokens[0]
return (num_rows + self.moe_max_num_tokens - 1) // self.moe_max_num_tokens
Comment thread
guqiqi marked this conversation as resolved.

def split_chunk(self, split_token_num: int, split_num_chunks: int) -> List[int]:
Expand Down
25 changes: 25 additions & 0 deletions tensorrt_llm/_torch/modules/fused_moe/moe_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,31 @@ def forward(
else:
all_rank_num_tokens_padded = all_rank_num_tokens

# ========== 0-token rank deadlock fix ==========
# When some ranks have 0 tokens in single-chunk forward with collective comm,
# those ranks hang in CUDA kernels (e.g. NVFP4 quantize_input with 0-row tensor)
# before reaching moe.comm.dispatch(), causing NCCL AllGather deadlock on
# non-zero ranks. Fix: activate DP padding uniformly across all ranks so every
# rank uses sizes=None (uniform allgather) and pads x/router_logits to max_tokens.
# Mirrors the empty-chunk substitution in _forward_multiple_chunks (line ~597-620).
# Existing truncation at line ~202 discards dummy-token outputs automatically.
if (
moe.comm is not None
and moe.use_dp
and all_rank_max_num_tokens > 0
and not use_dp_padding
and any(t == 0 for t in all_rank_num_tokens_padded)
):
use_dp_padding = True
all_rank_num_tokens_padded = [all_rank_max_num_tokens] * len(all_rank_num_tokens)
local_n = x.shape[0]
if local_n < all_rank_max_num_tokens:
pad = all_rank_max_num_tokens - local_n
x = torch.cat([x, x.new_zeros((pad, x.shape[1]))], dim=0)
router_logits = torch.cat(
[router_logits, router_logits.new_zeros((pad, router_logits.shape[1]))], dim=0
)

# ========== Step 2: Determine communication method ==========
num_chunks = moe.calculate_num_chunks(all_rank_num_tokens_padded)

Expand Down
46 changes: 29 additions & 17 deletions tests/microbenchmarks/bench_moe/case_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,7 @@ def _select_routing_inputs(
routing_plan: RoutingPlan,
rank: int,
moe_ep_size: int,
enable_attention_dp: bool,
base_router_logits: torch.Tensor,
device: torch.device,
act_dtype: torch.dtype,
Expand Down Expand Up @@ -411,6 +412,21 @@ def _select_routing_inputs(
except Exception as exc:
return None, _RoutingSkip(f"native logits projection error: {type(exc).__name__}: {exc}")

# In attention-DP + MoE-TP layouts (DTP / CUSTOM-DP), _project_router_logits
# returns logits shaped [agg_tokens, E] covering all DP shards aggregated
# onto ep_axis_rank. The MoE internally allgathers each rank's local
# router_logits before routing, so each rank must supply only its local
# slice [offset_r : offset_r + n_r] of the full projected tensor.
world_size_inferred = len(routing_plan.per_rank_num_tokens)
if enable_attention_dp and int(moe_ep_size) < world_size_inferred:
offset = sum(
routing_plan.per_rank_num_tokens[s]
for s in range(world_size_inferred)
if s % int(moe_ep_size) == ep_axis_rank and s < rank
)
local_n = routing_plan.per_rank_num_tokens[rank]
new_logits = new_logits[offset : offset + local_n]

if projection_status != "exact" and rc_spec.projection_policy == "reject":
return None, _RoutingSkip(
skip_reason=(
Expand Down Expand Up @@ -536,21 +552,6 @@ def _resolve_layout_and_plan(
except ValueError as exc:
return _short_circuit(result, "skipped", str(exc))

# Routing-control's dispatch_matrix axis is ``moe_ep_size`` while
# ``per_rank_num_tokens`` follows the world (DP source) axis. When the two
# disagree (DTP/TTP/CUSTOM with ``moe_ep_size != world_size``) the plan
# either crashes inside ``_build_routing_plan`` or silently drops the
# tokens of world ranks beyond ``moe_ep_size``. Skip cleanly.
if rc_active and int(moe_ep_size) != int(world_size):
return _short_circuit(
result,
"skipped",
f"routing-control requires moe_ep_size == world_size "
f"(got moe_ep_size={moe_ep_size}, world_size={world_size}); "
"the dispatch_matrix axis would not align with the per-rank token "
"distribution. Use parallel_mode in {DEP, TEP} or drop routing-control.",
)

routing_plan: Optional[RoutingPlan] = None
if rc_active:
try:
Expand All @@ -561,14 +562,15 @@ def _resolve_layout_and_plan(
top_k=int(model.top_k),
num_experts=int(model.num_experts),
moe_ep_size=int(moe_ep_size),
enable_dp=bool(_enable_dp),
)
except Exception as exc:
reason = f"routing plan error: {type(exc).__name__}: {exc}"
_maybe_print_rank0(f"[bench_moe] {reason}")
return _short_circuit(result, "skipped", reason)
per_rank = list(routing_plan.per_rank_num_tokens)
else:
per_rank = _per_rank_tokens(workload, world_size)
per_rank = _per_rank_tokens(workload, world_size, enable_dp=bool(_enable_dp))

return int(moe_ep_size), per_rank, routing_plan

Expand Down Expand Up @@ -663,6 +665,11 @@ def _run_one_candidate(
result.moe_tp_size = int(mapping.moe_tp_size)
result.enable_attention_dp = bool(mapping.enable_attention_dp)

# TEP/TTP (no attention DP): no cross-rank dispatch; the scheduler fills
# all_rank_num_tokens from x.shape[0]. Pass None to follow that path.
if not mapping.enable_attention_dp:
all_rank_num_tokens = None

AutoTuner.get().setup_distributed_state(mapping)
AutoTuner.get().clear_cache()

Expand All @@ -683,7 +690,11 @@ def _run_one_candidate(
mapping=mapping,
moe_backend=config.backend,
use_cuda_graph=bool(config.cuda_graph),
max_num_tokens=max(int(local_num_tokens), 1),
# Symmetric-memory comm backends (e.g. NVLINK_ONE_SIDED) size their
# workspace from max_num_tokens and require every rank to allocate the
# same size, so use the global per-rank maximum rather than this rank's
# local token count (which differs under uneven attention-DP shards).
max_num_tokens=max(int(max(per_rank)) if per_rank else 0, 1),
use_low_precision_moe_combine=bool(config.use_low_precision_moe_combine),
enable_perfect_router=enable_perfect_router,
dtype=act_dtype,
Expand Down Expand Up @@ -735,6 +746,7 @@ def _run_one_candidate(
routing_plan=routing_plan,
rank=rank,
moe_ep_size=int(moe_ep_size),
enable_attention_dp=bool(result.enable_attention_dp),
base_router_logits=router_logits,
device=device,
act_dtype=act_dtype,
Expand Down
15 changes: 12 additions & 3 deletions tests/microbenchmarks/bench_moe/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tensorrt_llm.models.modeling_utils import QuantAlgo

from .backend import MoeBackendType
from .mapping import _resolve_mapping_layout
from .routing import _per_rank_tokens
from .search import (
_coerce_str_tuple,
Expand Down Expand Up @@ -91,7 +92,14 @@ def _build_worker_header(ctx: _BenchmarkContext, launcher: str, world_size: int)
"world_size": world_size,
"analysis": list(ctx.analysis) or ["summary"],
"workloads": [
w.to_dict(per_rank_num_tokens=_per_rank_tokens(w, world_size)) for w in ctx.workloads
w.to_dict(
per_rank_num_tokens=_per_rank_tokens(
w,
world_size,
enable_dp=bool(_resolve_mapping_layout(ctx.base_config, world_size)[2]),
)
)
for w in ctx.workloads
],
"base_config": ctx.base_config.to_dict(),
}
Expand Down Expand Up @@ -214,8 +222,9 @@ def parse_args() -> argparse.Namespace:
nargs="+",
required=False,
help=(
"Global token counts to sweep. Each value is balanced across ranks "
"with any remainder on rank 0. Example: --balanced_total_num_tokens 64 256 1024."
"Global token counts to sweep. Each value is balanced across ranks, "
"spreading any remainder one token per leading rank (e.g. world_size=4, "
"tokens=2 -> [1, 1, 0, 0]). Example: --balanced_total_num_tokens 64 256 1024."
),
)

Expand Down
5 changes: 5 additions & 0 deletions tests/microbenchmarks/bench_moe/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,17 @@ def _resolve_mapping_layout(config: ConfigSpec, world_size: int) -> Tuple[int, i
def _build_mapping_from_config(config: ConfigSpec, world_size: int) -> Mapping:
    """Create the ``Mapping`` for ``config`` at ``world_size`` and stamp it with this process's MPI rank.

    The MoE EP/TP sizes and the attention-DP flag come from
    ``_resolve_mapping_layout``; ``tp_size`` is always set to ``world_size``.
    """
    ep_size, moe_tp, attention_dp = _resolve_mapping_layout(config, world_size)
    layout = Mapping(
        world_size=world_size,
        tp_size=world_size,
        moe_ep_size=ep_size,
        moe_tp_size=moe_tp,
        enable_attention_dp=attention_dp,
        # Use the number of GPUs actually visible on this node so that
        # mapping.local_rank (= rank % gpus_per_node) selects the right device.
        # Mapping's default of 8 is wrong for multi-node runs with fewer GPUs
        # per node.
        gpus_per_node=torch.cuda.device_count(),
    )
    layout.rank = mpi_rank()
    return layout
Expand Down
4 changes: 3 additions & 1 deletion tests/microbenchmarks/bench_moe/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

from tensorrt_llm._utils import mpi_allgather

from .mapping import _resolve_mapping_layout
from .routing import _per_rank_tokens
from .specs import ConfigSpec, ModelSpec, RunResult, WorkloadSpec
from .utils import _compute_stats
Expand Down Expand Up @@ -407,7 +408,8 @@ def _make_skipped_run_result(
r = RunResult(model=model, workload=workload, config=config)
r.status = "skipped"
r.skip_reason = reason
r.per_rank_num_tokens = _per_rank_tokens(workload, world_size)
_, _, _enable_dp = _resolve_mapping_layout(config, world_size)
r.per_rank_num_tokens = _per_rank_tokens(workload, world_size, enable_dp=bool(_enable_dp))
r.status_per_rank = {f"rank{i}": "skipped" for i in range(world_size)}
Comment thread
guqiqi marked this conversation as resolved.
r.instrumentation = {
"level": ",".join(sorted(analysis)) if analysis else "summary",
Expand Down
89 changes: 70 additions & 19 deletions tests/microbenchmarks/bench_moe/routing/builders.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,43 +115,88 @@ def _build_per_rank_num_tokens(
spec: RoutingControlSpec,
num_tokens: int,
world_size: int,
enable_dp: bool,
) -> List[int]:
"""Resolve ``per_rank_num_tokens`` for a workload.

Explicit ``spec.per_rank_num_tokens`` wins; otherwise tokens are split
evenly across ranks with any remainder on rank 0.
Explicit ``spec.per_rank_num_tokens`` wins; otherwise the token count per
rank depends on the attention-DP setting:

* ``enable_dp=True`` (DEP / DTP): tokens are DP-sharded across ranks, so
each rank holds ``num_tokens / world_size``.
* ``enable_dp=False`` (TEP / TTP): attention is tensor-parallel, so every
rank sees the complete batch and holds ``num_tokens``.

When an explicit list is provided its sum is validated against the expected
total (``num_tokens`` for DP modes, ``num_tokens * world_size`` for non-DP).
"""
if spec.per_rank_num_tokens is None:
if not enable_dp:
return [int(num_tokens)] * world_size
return _distribute_tokens(int(num_tokens), world_size)
expected_total = int(num_tokens) * (1 if enable_dp else world_size)
return _validate_per_rank_token_list(
spec.per_rank_num_tokens, world_size=world_size, expected_total=int(num_tokens)
spec.per_rank_num_tokens, world_size=world_size, expected_total=expected_total
)


def _per_rank_tokens(workload: WorkloadSpec, world_size: int) -> List[int]:
def _per_rank_tokens(workload: WorkloadSpec, world_size: int, enable_dp: bool) -> List[int]:
"""Materialize the ``per_rank_num_tokens`` list for a workload + world size."""
return _build_per_rank_num_tokens(
workload.routing_control, int(workload.num_tokens), world_size
workload.routing_control, int(workload.num_tokens), world_size, enable_dp
)


def _aggregate_dispatch_source_tokens(
per_rank_num_tokens: List[int],
ep_size: int,
enable_dp: bool,
) -> List[int]:
"""Project world-rank token counts onto EP-source rows.

TRT-LLM Mapping orders MoE ranks with ``moe_ep_rank = tp_rank % moe_ep_size``.
In attention-DP modes each world rank owns a distinct token shard, so TP
shards targeting the same EP row are summed. In non-DP MoE-TP modes those TP
shards carry the same logical tokens, so only the first TP shard contributes
to the logical dispatch plan.
"""
if ep_size <= 0:
return []
if len(per_rank_num_tokens) == ep_size:
return [int(v) for v in per_rank_num_tokens]

source_tokens = [0] * ep_size
if not enable_dp:
for ep_rank in range(ep_size):
if ep_rank < len(per_rank_num_tokens):
source_tokens[ep_rank] = int(per_rank_num_tokens[ep_rank])
else:
for rank, num_tokens in enumerate(per_rank_num_tokens):
source_tokens[rank % ep_size] += int(num_tokens)

return source_tokens


def _build_dispatch_matrix(
comm_pattern: str,
per_rank_num_tokens: List[int],
top_k: int,
ep_size: int,
enable_dp: bool,
seed: int = 0,
) -> List[List[int]]:
"""Build the canonical slot ``dispatch_matrix`` for ``comm_pattern``.

Row sums always equal ``per_rank_num_tokens[src] * top_k``. The matrix is
a pure planning artefact: it does not enforce per-token uniqueness yet.
That constraint is checked at materialisation time.
Row sums equal the EP-source token counts projected from
``per_rank_num_tokens`` times ``top_k``. When world ranks outnumber EP
ranks (DTP / TTP / CUSTOM MoE-TP layouts), multiple world-rank rows are
aggregated onto the same EP-source row.
"""
name, kwargs = _parse_comm_pattern(comm_pattern)
source_tokens = _aggregate_dispatch_source_tokens(per_rank_num_tokens, ep_size, enable_dp)
matrix: List[List[int]] = [[0] * ep_size for _ in range(ep_size)]
for src in range(ep_size):
row_total = int(per_rank_num_tokens[src]) * int(top_k)
row_total = int(source_tokens[src]) * int(top_k)
if row_total == 0:
continue
if name == "file":
Expand Down Expand Up @@ -309,6 +354,7 @@ def _build_routing_plan(
top_k: int,
num_experts: int,
moe_ep_size: int,
enable_dp: bool,
) -> RoutingPlan:
"""Translate a ``RoutingControlSpec`` into a canonical normalised plan."""
if moe_ep_size <= 0 or num_experts % moe_ep_size != 0:
Expand All @@ -318,11 +364,10 @@ def _build_routing_plan(
experts_per_rank = num_experts // moe_ep_size
if top_k > num_experts:
raise ValueError(f"top_k ({top_k}) must be <= num_experts ({num_experts})")
per_rank = _build_per_rank_num_tokens(spec, num_tokens, world_size)
# The dispatch matrix is indexed by EP rank on both axes. The current
# worker only calls routing-control planning when ``moe_ep_size`` equals
# ``world_size`` so that this EP-axis matrix also matches the user-visible
# per-rank token list.
per_rank = _build_per_rank_num_tokens(spec, num_tokens, world_size, enable_dp)
# The dispatch matrix stays on EP axes. When MoE-TP makes multiple world
# ranks share one EP rank, the world-rank token counts are aggregated onto
# the corresponding EP-source row before building the matrix.
if spec.routing_pattern_file:
default_patterns = {("balanced_alltoall", "balanced"), ("random", "random")}
if (spec.comm_pattern, spec.expert_pattern) not in default_patterns:
Expand All @@ -335,26 +380,32 @@ def _build_routing_plan(
)
else:
dispatch_matrix = _build_dispatch_matrix(
spec.comm_pattern, per_rank, top_k, moe_ep_size, seed=spec.seed
spec.comm_pattern,
per_rank,
top_k,
moe_ep_size,
enable_dp=enable_dp,
seed=spec.seed,
)
expert_histogram = _build_expert_histogram(
spec.expert_pattern, dispatch_matrix, experts_per_rank, moe_ep_size, seed=spec.seed
)

# Per-row sums are an invariant; emit a clearer error than the materialiser would.
source_tokens = _aggregate_dispatch_source_tokens(per_rank, moe_ep_size, enable_dp)
for src in range(moe_ep_size):
expected = int(per_rank[src]) * int(top_k) if src < len(per_rank) else 0
expected = int(source_tokens[src]) * int(top_k) if src < len(source_tokens) else 0
actual = sum(dispatch_matrix[src])
if actual != expected:
raise ValueError(
f"dispatch_matrix row {src} sums to {actual}, expected per_rank_num_tokens[{src}] * top_k = {expected}"
f"dispatch_matrix row {src} sums to {actual}, expected aggregate source tokens * top_k = {expected}"
)
# Global expert histogram total must match total slots.
total_slots = sum(int(t) for t in per_rank) * int(top_k)
total_slots = sum(int(t) for t in source_tokens) * int(top_k)
hist_total = sum(sum(row) for row in expert_histogram)
if hist_total != total_slots:
raise ValueError(
f"expert_histogram sum={hist_total} must equal sum(per_rank_num_tokens) * top_k = {total_slots}"
f"expert_histogram sum={hist_total} must equal aggregate source tokens * top_k = {total_slots}"
)

return RoutingPlan(
Expand Down
Loading
Loading