@@ -147,23 +147,56 @@ def _per_rank_tokens(workload: WorkloadSpec, world_size: int, enable_dp: bool) -
147147 )
148148
149149
150+ def _aggregate_dispatch_source_tokens (
151+ per_rank_num_tokens : List [int ],
152+ ep_size : int ,
153+ enable_dp : bool ,
154+ ) -> List [int ]:
155+ """Project world-rank token counts onto EP-source rows.
156+
157+ TRT-LLM Mapping orders MoE ranks with ``moe_ep_rank = tp_rank % moe_ep_size``.
158+ In attention-DP modes each world rank owns a distinct token shard, so TP
159+ shards targeting the same EP row are summed. In non-DP MoE-TP modes those TP
160+ shards carry the same logical tokens, so only the first TP shard contributes
161+ to the logical dispatch plan.
162+ """
163+ if ep_size <= 0 :
164+ return []
165+ if len (per_rank_num_tokens ) == ep_size :
166+ return [int (v ) for v in per_rank_num_tokens ]
167+
168+ source_tokens = [0 ] * ep_size
169+ if not enable_dp :
170+ for ep_rank in range (ep_size ):
171+ if ep_rank < len (per_rank_num_tokens ):
172+ source_tokens [ep_rank ] = int (per_rank_num_tokens [ep_rank ])
173+ else :
174+ for rank , num_tokens in enumerate (per_rank_num_tokens ):
175+ source_tokens [rank % ep_size ] += int (num_tokens )
176+
177+ return source_tokens
178+
179+
150180def _build_dispatch_matrix (
151181 comm_pattern : str ,
152182 per_rank_num_tokens : List [int ],
153183 top_k : int ,
154184 ep_size : int ,
185+ enable_dp : bool ,
155186 seed : int = 0 ,
156187) -> List [List [int ]]:
157188 """Build the canonical slot ``dispatch_matrix`` for ``comm_pattern``.
158189
159- Row sums always equal ``per_rank_num_tokens[src] * top_k``. The matrix is
160- a pure planning artefact: it does not enforce per-token uniqueness yet.
161- That constraint is checked at materialisation time.
190+ Row sums equal the EP-source token counts projected from
191+ ``per_rank_num_tokens`` times ``top_k``. When world ranks outnumber EP
192+ ranks (DTP / TTP / CUSTOM MoE-TP layouts), multiple world-rank rows are
193+ aggregated onto the same EP-source row.
162194 """
163195 name , kwargs = _parse_comm_pattern (comm_pattern )
196+ source_tokens = _aggregate_dispatch_source_tokens (per_rank_num_tokens , ep_size , enable_dp )
164197 matrix : List [List [int ]] = [[0 ] * ep_size for _ in range (ep_size )]
165198 for src in range (ep_size ):
166- row_total = int (per_rank_num_tokens [src ]) * int (top_k )
199+ row_total = int (source_tokens [src ]) * int (top_k )
167200 if row_total == 0 :
168201 continue
169202 if name == "file" :
@@ -332,10 +365,9 @@ def _build_routing_plan(
332365 if top_k > num_experts :
333366 raise ValueError (f"top_k ({ top_k } ) must be <= num_experts ({ num_experts } )" )
334367 per_rank = _build_per_rank_num_tokens (spec , num_tokens , world_size , enable_dp )
335- # The dispatch matrix is indexed by EP rank on both axes. The current
336- # worker only calls routing-control planning when ``moe_ep_size`` equals
337- # ``world_size`` so that this EP-axis matrix also matches the user-visible
338- # per-rank token list.
368+ # The dispatch matrix stays on EP axes. When MoE-TP makes multiple world
369+ # ranks share one EP rank, the world-rank token counts are aggregated onto
370+ # the corresponding EP-source row before building the matrix.
339371 if spec .routing_pattern_file :
340372 default_patterns = {("balanced_alltoall" , "balanced" ), ("random" , "random" )}
341373 if (spec .comm_pattern , spec .expert_pattern ) not in default_patterns :
@@ -348,26 +380,32 @@ def _build_routing_plan(
348380 )
349381 else :
350382 dispatch_matrix = _build_dispatch_matrix (
351- spec .comm_pattern , per_rank , top_k , moe_ep_size , seed = spec .seed
383+ spec .comm_pattern ,
384+ per_rank ,
385+ top_k ,
386+ moe_ep_size ,
387+ enable_dp = enable_dp ,
388+ seed = spec .seed ,
352389 )
353390 expert_histogram = _build_expert_histogram (
354391 spec .expert_pattern , dispatch_matrix , experts_per_rank , moe_ep_size , seed = spec .seed
355392 )
356393
357394 # Per-row sums are an invariant; emit a clearer error than the materialiser would.
395+ source_tokens = _aggregate_dispatch_source_tokens (per_rank , moe_ep_size , enable_dp )
358396 for src in range (moe_ep_size ):
359- expected = int (per_rank [src ]) * int (top_k ) if src < len (per_rank ) else 0
397+ expected = int (source_tokens [src ]) * int (top_k ) if src < len (source_tokens ) else 0
360398 actual = sum (dispatch_matrix [src ])
361399 if actual != expected :
362400 raise ValueError (
363- f"dispatch_matrix row { src } sums to { actual } , expected per_rank_num_tokens[ { src } ] * top_k = { expected } "
401+ f"dispatch_matrix row { src } sums to { actual } , expected aggregate source tokens * top_k = { expected } "
364402 )
365403 # Global expert histogram total must match total slots.
366- total_slots = sum (int (t ) for t in per_rank ) * int (top_k )
404+ total_slots = sum (int (t ) for t in source_tokens ) * int (top_k )
367405 hist_total = sum (sum (row ) for row in expert_histogram )
368406 if hist_total != total_slots :
369407 raise ValueError (
370- f"expert_histogram sum={ hist_total } must equal sum(per_rank_num_tokens) * top_k = { total_slots } "
408+ f"expert_histogram sum={ hist_total } must equal aggregate source tokens * top_k = { total_slots } "
371409 )
372410
373411 return RoutingPlan (
0 commit comments