Skip to content

Commit 16e577c

Browse files
[TRTLLM-12669][fix] only route plain-TP greedy MTP-Eagle draft sampling through draft_sampler
The previous fix routed every greedy MTP-Eagle draft step through draft_sampler(), but that call does not forward mapping_lm_head_tp. For the LM-head-TP-in-ADP configuration draft_sampler() then takes its ADP branch with a None mapping and crashes during warmup with "'NoneType' object has no attribute 'tp_group'" (Executor worker returned error), e.g. DeepSeek-R1 nvfp4 latency_adp_lmtp_tp4. Only plain tensor parallelism (tp_size>1 without attention DP) shards the draft logits over the vocab dim and needs draft_sampler()'s all-gather argmax. The LM-head-TP-in-ADP case already yields full-vocab logits per rank (gathered upstream) and the no-TP / Eagle3 cases need nothing, so all of those take the plain d2t-aware argmax (_draft_sampler_greedy), restoring the pre-regression behavior for ADP while keeping the plain-TP hang fix. Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
1 parent 85468e1 commit 16e577c

1 file changed

Lines changed: 12 additions & 1 deletion

File tree

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1210,7 +1210,18 @@ def draft_decoder(
12101210
# before argmax (and falls back to a plain argmax when no TP gather is
12111211
# needed). Eagle3 (non-MTP) keeps its d2t-aware argmax.
12121212
if spec_metadata.is_all_greedy_sample:
1213-
if self.is_mtp_eagle:
1213+
# Only plain tensor parallelism (tp_size>1 without attention DP)
1214+
# shards the draft logits over the vocab dim and thus needs
1215+
# draft_sampler()'s all-gather argmax. The LM-head-TP-in-ADP case
1216+
# already produces full-vocab logits per rank (gathered upstream),
1217+
# and the no-TP / Eagle3 cases need nothing, so they take the plain
1218+
# d2t-aware argmax. (Routing ADP/LM-head-TP through draft_sampler
1219+
# without its mapping_lm_head_tp arg hits the None-mapping branch
1220+
# and crashes with 'NoneType has no attribute tp_group'.)
1221+
if (self.is_mtp_eagle and self.model_config is not None
1222+
and hasattr(self.model_config, 'mapping')
1223+
and self.model_config.mapping.tp_size > 1
1224+
and not self.model_config.mapping.enable_attention_dp):
12141225
return self.draft_sampler(logits)
12151226
return self._draft_sampler_greedy(logits, d2t)
12161227
# Non-greedy (advanced) draft sampling has the same TP hazard as the

0 commit comments

Comments
 (0)