Skip to content

Commit 89e0ad3

Browse files
committed
add DTP/TTP support
Signed-off-by: guqiqi <29116997+guqiqi@users.noreply.github.com>
1 parent 22257b1 commit 89e0ad3

4 files changed

Lines changed: 94 additions & 35 deletions

File tree

tests/microbenchmarks/bench_moe/case_runner.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@ def _select_routing_inputs(
344344
routing_plan: RoutingPlan,
345345
rank: int,
346346
moe_ep_size: int,
347+
enable_attention_dp: bool,
347348
base_router_logits: torch.Tensor,
348349
device: torch.device,
349350
act_dtype: torch.dtype,
@@ -411,6 +412,21 @@ def _select_routing_inputs(
411412
except Exception as exc:
412413
return None, _RoutingSkip(f"native logits projection error: {type(exc).__name__}: {exc}")
413414

415+
# In attention-DP + MoE-TP layouts (DTP / CUSTOM-DP), _project_router_logits
416+
# returns logits shaped [agg_tokens, E] covering all DP shards aggregated
417+
# onto ep_axis_rank. The MoE internally allgathers each rank's local
418+
# router_logits before routing, so each rank must supply only its local
419+
# slice [offset_r : offset_r + n_r] of the full projected tensor.
420+
world_size_inferred = len(routing_plan.per_rank_num_tokens)
421+
if enable_attention_dp and int(moe_ep_size) < world_size_inferred:
422+
offset = sum(
423+
routing_plan.per_rank_num_tokens[s]
424+
for s in range(world_size_inferred)
425+
if s % int(moe_ep_size) == ep_axis_rank and s < rank
426+
)
427+
local_n = routing_plan.per_rank_num_tokens[rank]
428+
new_logits = new_logits[offset : offset + local_n]
429+
414430
if projection_status != "exact" and rc_spec.projection_policy == "reject":
415431
return None, _RoutingSkip(
416432
skip_reason=(
@@ -536,21 +552,6 @@ def _resolve_layout_and_plan(
536552
except ValueError as exc:
537553
return _short_circuit(result, "skipped", str(exc))
538554

539-
# Routing-control's dispatch_matrix axis is ``moe_ep_size`` while
540-
# ``per_rank_num_tokens`` follows the world (DP source) axis. When the two
541-
# disagree (DTP/TTP/CUSTOM with ``moe_ep_size != world_size``) the plan
542-
# either crashes inside ``_build_routing_plan`` or silently drops the
543-
# tokens of world ranks beyond ``moe_ep_size``. Skip cleanly.
544-
if rc_active and int(moe_ep_size) != int(world_size):
545-
return _short_circuit(
546-
result,
547-
"skipped",
548-
f"routing-control requires moe_ep_size == world_size "
549-
f"(got moe_ep_size={moe_ep_size}, world_size={world_size}); "
550-
"the dispatch_matrix axis would not align with the per-rank token "
551-
"distribution. Use parallel_mode in {DEP, TEP} or drop routing-control.",
552-
)
553-
554555
routing_plan: Optional[RoutingPlan] = None
555556
if rc_active:
556557
try:
@@ -741,6 +742,7 @@ def _run_one_candidate(
741742
routing_plan=routing_plan,
742743
rank=rank,
743744
moe_ep_size=int(moe_ep_size),
745+
enable_attention_dp=bool(result.enable_attention_dp),
744746
base_router_logits=router_logits,
745747
device=device,
746748
act_dtype=act_dtype,

tests/microbenchmarks/bench_moe/routing/builders.py

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,23 +147,56 @@ def _per_rank_tokens(workload: WorkloadSpec, world_size: int, enable_dp: bool) -
147147
)
148148

149149

150+
def _aggregate_dispatch_source_tokens(
151+
per_rank_num_tokens: List[int],
152+
ep_size: int,
153+
enable_dp: bool,
154+
) -> List[int]:
155+
"""Project world-rank token counts onto EP-source rows.
156+
157+
TRT-LLM Mapping orders MoE ranks with ``moe_ep_rank = tp_rank % moe_ep_size``.
158+
In attention-DP modes each world rank owns a distinct token shard, so TP
159+
shards targeting the same EP row are summed. In non-DP MoE-TP modes those TP
160+
shards carry the same logical tokens, so only the first TP shard contributes
161+
to the logical dispatch plan.
162+
"""
163+
if ep_size <= 0:
164+
return []
165+
if len(per_rank_num_tokens) == ep_size:
166+
return [int(v) for v in per_rank_num_tokens]
167+
168+
source_tokens = [0] * ep_size
169+
if not enable_dp:
170+
for ep_rank in range(ep_size):
171+
if ep_rank < len(per_rank_num_tokens):
172+
source_tokens[ep_rank] = int(per_rank_num_tokens[ep_rank])
173+
else:
174+
for rank, num_tokens in enumerate(per_rank_num_tokens):
175+
source_tokens[rank % ep_size] += int(num_tokens)
176+
177+
return source_tokens
178+
179+
150180
def _build_dispatch_matrix(
151181
comm_pattern: str,
152182
per_rank_num_tokens: List[int],
153183
top_k: int,
154184
ep_size: int,
185+
enable_dp: bool,
155186
seed: int = 0,
156187
) -> List[List[int]]:
157188
"""Build the canonical slot ``dispatch_matrix`` for ``comm_pattern``.
158189
159-
Row sums always equal ``per_rank_num_tokens[src] * top_k``. The matrix is
160-
a pure planning artefact: it does not enforce per-token uniqueness yet.
161-
That constraint is checked at materialisation time.
190+
Row sums equal the EP-source token counts projected from
191+
``per_rank_num_tokens`` times ``top_k``. When world ranks outnumber EP
192+
ranks (DTP / TTP / CUSTOM MoE-TP layouts), multiple world-rank rows are
193+
aggregated onto the same EP-source row.
162194
"""
163195
name, kwargs = _parse_comm_pattern(comm_pattern)
196+
source_tokens = _aggregate_dispatch_source_tokens(per_rank_num_tokens, ep_size, enable_dp)
164197
matrix: List[List[int]] = [[0] * ep_size for _ in range(ep_size)]
165198
for src in range(ep_size):
166-
row_total = int(per_rank_num_tokens[src]) * int(top_k)
199+
row_total = int(source_tokens[src]) * int(top_k)
167200
if row_total == 0:
168201
continue
169202
if name == "file":
@@ -332,10 +365,9 @@ def _build_routing_plan(
332365
if top_k > num_experts:
333366
raise ValueError(f"top_k ({top_k}) must be <= num_experts ({num_experts})")
334367
per_rank = _build_per_rank_num_tokens(spec, num_tokens, world_size, enable_dp)
335-
# The dispatch matrix is indexed by EP rank on both axes. The current
336-
# worker only calls routing-control planning when ``moe_ep_size`` equals
337-
# ``world_size`` so that this EP-axis matrix also matches the user-visible
338-
# per-rank token list.
368+
# The dispatch matrix stays on EP axes. When MoE-TP makes multiple world
369+
# ranks share one EP rank, the world-rank token counts are aggregated onto
370+
# the corresponding EP-source row before building the matrix.
339371
if spec.routing_pattern_file:
340372
default_patterns = {("balanced_alltoall", "balanced"), ("random", "random")}
341373
if (spec.comm_pattern, spec.expert_pattern) not in default_patterns:
@@ -348,26 +380,32 @@ def _build_routing_plan(
348380
)
349381
else:
350382
dispatch_matrix = _build_dispatch_matrix(
351-
spec.comm_pattern, per_rank, top_k, moe_ep_size, seed=spec.seed
383+
spec.comm_pattern,
384+
per_rank,
385+
top_k,
386+
moe_ep_size,
387+
enable_dp=enable_dp,
388+
seed=spec.seed,
352389
)
353390
expert_histogram = _build_expert_histogram(
354391
spec.expert_pattern, dispatch_matrix, experts_per_rank, moe_ep_size, seed=spec.seed
355392
)
356393

357394
# Per-row sums are an invariant; emit a clearer error than the materialiser would.
395+
source_tokens = _aggregate_dispatch_source_tokens(per_rank, moe_ep_size, enable_dp)
358396
for src in range(moe_ep_size):
359-
expected = int(per_rank[src]) * int(top_k) if src < len(per_rank) else 0
397+
expected = int(source_tokens[src]) * int(top_k) if src < len(source_tokens) else 0
360398
actual = sum(dispatch_matrix[src])
361399
if actual != expected:
362400
raise ValueError(
363-
f"dispatch_matrix row {src} sums to {actual}, expected per_rank_num_tokens[{src}] * top_k = {expected}"
401+
f"dispatch_matrix row {src} sums to {actual}, expected aggregate source tokens * top_k = {expected}"
364402
)
365403
# Global expert histogram total must match total slots.
366-
total_slots = sum(int(t) for t in per_rank) * int(top_k)
404+
total_slots = sum(int(t) for t in source_tokens) * int(top_k)
367405
hist_total = sum(sum(row) for row in expert_histogram)
368406
if hist_total != total_slots:
369407
raise ValueError(
370-
f"expert_histogram sum={hist_total} must equal sum(per_rank_num_tokens) * top_k = {total_slots}"
408+
f"expert_histogram sum={hist_total} must equal aggregate source tokens * top_k = {total_slots}"
371409
)
372410

373411
return RoutingPlan(

tests/microbenchmarks/bench_moe/routing/materialize.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,21 @@ def _flatten_plan_slots_for_rank(
5555
experts_per_rank: int,
5656
moe_ep_size: int,
5757
) -> List[int]:
58-
"""Flatten one plan row into expert ids while preserving slot counts."""
59-
local_num_tokens = int(plan.per_rank_num_tokens[src_rank])
58+
"""Flatten one plan row into expert ids while preserving slot counts.
59+
60+
``local_num_tokens`` is derived from the dispatch-matrix row sum rather
61+
than from ``per_rank_num_tokens[src_rank]``. In MoE-TP + attention-DP
62+
layouts (DTP / CUSTOM-DP) the dispatch matrix is EP-axis indexed while
63+
``per_rank_num_tokens`` is world-rank indexed; the row sum is always the
64+
correct EP-axis aggregate (``source_tokens[src_rank] * top_k``).
65+
"""
6066
row = list(plan.dispatch_matrix[src_rank])
61-
if sum(row) != local_num_tokens * top_k:
67+
row_sum = sum(row)
68+
if top_k > 0 and row_sum % top_k != 0:
6269
raise ValueError(
63-
f"dispatch_matrix row sum ({sum(row)}) must equal local_num_tokens*top_k "
64-
f"({local_num_tokens * top_k}) for rank {src_rank}"
70+
f"dispatch_matrix row {src_rank} sum ({row_sum}) is not divisible by top_k ({top_k})"
6571
)
72+
local_num_tokens = row_sum // top_k if top_k > 0 else 0
6673

6774
flat: List[int] = []
6875
for dst in range(moe_ep_size):
@@ -174,7 +181,13 @@ def _materialize_selected_experts_for_rank(
174181
4. Run a small repair pass that swaps duplicated expert ids between
175182
rows until each token has ``top_k`` distinct experts.
176183
"""
177-
local_num_tokens = int(plan.per_rank_num_tokens[src_rank])
184+
# Derive the effective token count from the dispatch-matrix row sum so that
185+
# MoE-TP + attention-DP layouts (DTP / CUSTOM-DP) are handled correctly.
186+
# In those layouts the row sum equals the aggregated source tokens for the
187+
# EP rank, while per_rank_num_tokens[src_rank] would only reflect one DP
188+
# shard's contribution.
189+
row_sum = sum(plan.dispatch_matrix[src_rank])
190+
local_num_tokens = row_sum // max(top_k, 1)
178191
if local_num_tokens == 0:
179192
ids = torch.zeros((0, top_k), dtype=torch.int32, device=device)
180193
scales = torch.zeros((0, top_k), dtype=scale_dtype, device=device)

tests/microbenchmarks/bench_moe/routing/native_logits.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,13 @@ def _project_router_logits_for_plan(
123123
Returns ``(router_logits, status, reason)`` where ``status`` is one of
124124
``"exact"``, ``"projected"``, or ``"rejected"``.
125125
"""
126-
local_num_tokens = int(plan.per_rank_num_tokens[src_rank])
126+
# Derive the effective token count from the dispatch-matrix row sum.
127+
# In MoE-TP + attention-DP layouts (DTP / CUSTOM-DP) the row sum equals
128+
# the aggregated source tokens for the EP rank (which is what the router
129+
# sees after the in-MoE allgather), while per_rank_num_tokens[src_rank]
130+
# would only cover one DP shard.
131+
row_sum = sum(plan.dispatch_matrix[src_rank])
132+
local_num_tokens = row_sum // max(top_k, 1) if row_sum > 0 else 0
127133
if local_num_tokens == 0:
128134
return (
129135
torch.empty((0, num_experts), dtype=dtype, device=device),

0 commit comments

Comments
 (0)