[TRTLLM-12669][refactor] Eagle3 sampling: auto-detect greedy fast-path, mixed-batch rejection sampling, draft honors target params #14745
Force-pushed from 903b453 to d237690
/bot run
Hi @mikeiovine please help to review this PR, thanks~
PR_Github #51043 [ run ] triggered by Bot. Commit:
PR_Github #51043 [ run ] completed with state
Hi @NVIDIA/trt-llm-doc-owners @NVIDIA/trt-llm-llmapi-devs @NVIDIA/trt-llm-qa-function @NVIDIA/trt-llm-torch-models-devs @NVIDIA/trt-llm-torch-runtime-devs @NVIDIA/trt-llm-torch-spec-decoding please help to review this PR, thanks a lot.
/bot run
PR_Github #51297 [ run ] triggered by Bot. Commit:
PR_Github #51297 [ run ] completed with state
mikeiovine left a comment
Are any other spec algos always using greedy for draft sampling? That will need to be fixed too in a follow up
Force-pushed from 775bae6 to d6fd852
'MTPForCausalLM' does not store its constructor's 'model' argument as self.model, so getattr(draft_model.model, "d2t", None) raised AttributeError when draft_decoder was called in MTP Eagle mode. Use nested getattr to safely return None when draft_model has no 'model' attribute (MTP Eagle never uses a compressed vocabulary so d2t is always None for that mode anyway).
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
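The nested lookup described in this commit can be sketched as follows. The two classes are hypothetical stand-ins (they are not the real `MTPForCausalLM` or Eagle3 draft classes), and `lookup_d2t` is an illustrative helper name; only the `getattr(getattr(...))` pattern is taken from the commit message.

```python
from types import SimpleNamespace


class Eagle3DraftStandIn:
    """Hypothetical draft model that keeps an inner `model` with a `d2t` table."""

    def __init__(self, d2t):
        self.model = SimpleNamespace(d2t=d2t)


class MTPDraftStandIn:
    """Hypothetical draft model that has no `model` attribute at all."""


def lookup_d2t(draft_model):
    # The inner getattr returns None when `model` is missing. The outer
    # getattr then looks up `d2t` on None and falls back to its default,
    # so no AttributeError is raised.
    return getattr(getattr(draft_model, "model", None), "d2t", None)


d2t_eagle3 = lookup_d2t(Eagle3DraftStandIn(d2t=[0, 3, 7]))
d2t_mtp = lookup_d2t(MTPDraftStandIn())
```

By contrast, `getattr(draft_model.model, "d2t", None)` evaluates `draft_model.model` before `getattr` runs, so the default never gets a chance to apply.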
…ble top_k_max
Two bugs exposed by the new forced non-greedy CUDA graph capture pass:
1. SpecMetadata.populate_sampling_params_for_one_model: buffer size is
tokens_per_request * max_num_requests, but warmup batches can have
more total tokens when batch_size > max_num_requests. Fix by using
max(static_required, actual_flat_size) for buffer allocation.
2. Eagle3 dynamic tree rejection: verify_dynamic_tree_rejection_from_logits_out
computed top_k_max via boolean tensor indexing + .item(), both
CUDA-graph-incompatible. Fix by:
- Pre-computing top_k_max CPU-side in populate_sampling_params_for_one_model
- Passing top_k_max=0 during stream capture (forces full-sort path,
always correct) and the pre-computed value during eager execution
- Adding top_k_max optional param to verify_dynamic_tree_rejection_from_logits_out
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
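A minimal sketch of the two fixes above, in plain Python rather than tensors. The function names are hypothetical, and the convention that a `top_k` of `None` or `0` means "disabled" is an assumption made for illustration, not something the commit states.

```python
def required_buffer_len(tokens_per_request, max_num_requests, actual_flat_size):
    # Static sizing can be too small for warmup batches, so take the larger
    # of the static requirement and the size actually seen.
    static_required = tokens_per_request * max_num_requests
    return max(static_required, actual_flat_size)


def precompute_top_k_max(top_ks):
    # Computed on the CPU from Python values, so no device-side boolean
    # indexing or .item() call is needed.
    # Assumption: None or 0 means top-k is disabled for that request.
    enabled = [k for k in top_ks if k]
    return max(enabled) if enabled else 0


def top_k_max_for_call(top_ks, is_capturing):
    # During stream capture pass 0 (full-sort path); otherwise pass the
    # precomputed CPU-side value.
    return 0 if is_capturing else precompute_top_k_max(top_ks)


warmup_len = required_buffer_len(4, 8, actual_flat_size=40)
steady_len = required_buffer_len(4, 8, actual_flat_size=16)
eager_top_k_max = top_k_max_for_call([50, 0, 20], is_capturing=False)
capture_top_k_max = top_k_max_for_call([50, 0, 20], is_capturing=True)
```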
…ure and fix PARD num_tokens shape mismatch
- eagle3_dynamic_tree.py: _can_use_rejection_sampling now returns False when spec_metadata.is_cuda_graph is True. The rejection ops (compute_draft_probs_for_dynamic_tree_rejection_op) use a full-sort fallback with dynamic allocation that is incompatible with CUDA stream capture, causing cudaErrorStreamCaptureUnsupported.
- interface.py: _sample_tokens_for_batch now derives num_tokens from logits.shape[0] instead of computing it from runtime_draft_len. For PARD under CUDA graph capture runtime_draft_len can be the PARD-max while the graph was built for a shorter draft_len, causing a shape mismatch in the torch.compiled sampling_batch_spec_dec_one_model.
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
…tp_in_adp
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
…g in ADP+LM-head-TP
In the MTP-Eagle ADP + LM-head-TP path, draft logits are zero-padded to max_num_requests so every TP rank produces an identically-shaped tensor for the LM-head-TP all-gather. The refactored draft sampler applies per-request temperature/top_k/top_p tensors sized to token_count (== batch_size), so the padded logits ([max_num_requests, vocab]) failed to broadcast against the [batch_size, 1] temperature in apply_temperature, crashing torch.compile fake tensor tracing during executor worker init.
Drop the padded rows before sampling (logits = logits[:token_count]) instead of trimming the sampled tokens afterwards. This keeps logits, next_draft_tokens and the draft_probs buffer token_count-sized and lets the per-request sampling params broadcast correctly.
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
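The shape problem can be illustrated without tensors. In this sketch nested lists stand in for the `[rows, vocab]` logits, and `apply_temperature` is a simplified stand-in that raises on a row-count mismatch where the real op would fail to broadcast; the values are made up.

```python
def apply_temperature(logits, temperatures):
    # Stand-in for a per-request scale: one temperature per logits row.
    if len(logits) != len(temperatures):
        raise ValueError(
            f"cannot broadcast {len(logits)} rows against "
            f"{len(temperatures)} temperatures"
        )
    return [[x / t for x in row] for row, t in zip(logits, temperatures)]


max_num_requests, token_count = 4, 2
real_rows = [[2.0, 4.0], [1.0, 3.0]]
# Zero-padding up to max_num_requests, as done for the LM-head-TP all-gather.
padded = real_rows + [[0.0, 0.0] for _ in range(max_num_requests - token_count)]
temperatures = [2.0, 0.5]  # sized to token_count, not max_num_requests

try:
    apply_temperature(padded, temperatures)
    mismatch_detected = False
except ValueError:
    mismatch_detected = True

# The fix: drop the padded rows before sampling.
trimmed = padded[:token_count]
scaled = apply_temperature(trimmed, temperatures)
```

Trimming before sampling, rather than after, means every later buffer is already `token_count`-sized.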
…y selection
The one-engine CUDA graph key includes is_all_greedy_sample to dispatch between the argmax fast-path and the advanced-sampling graph variant. The flag was only (re)computed inside populate_sampling_params_for_one_model, which runs in _prepare_inputs AFTER maybe_get_cuda_graph has already built the key. The key therefore used the previous iteration's stale flag, and warmup left it False (from the advanced-sampling capture pass). On the first real decode iteration a greedy batch would then replay the advanced-sampling graph while populate skips filling the sampling/draft_probs buffers, reading uninitialized slot-indexed data. For MTP with num_nextn>=2 this hung the executor (Hang detected on rank 0).
Fix:
- Extract the greediness detection into _scan_one_model_sampling (single source of truth) and add update_is_all_greedy_sample, called before the graph key is built so the key matches the buffers populate fills. populate now reuses the same scan.
- Defensively reset spec_metadata.is_all_greedy_sample to True after CUDA graph warmup so the stale capture-only False does not seed the first iteration.
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
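The ordering bug is easy to reproduce in miniature. In the sketch below `SpecMetaStandIn`, `scan_all_greedy`, `step_stale` and `step_fixed` are hypothetical names, and a batch is reduced to a list of per-request temperatures where `None` is assumed to mean "greedy".

```python
class SpecMetaStandIn:
    """Hypothetical stand-in for the metadata object that holds the flag."""

    def __init__(self, is_all_greedy_sample):
        self.is_all_greedy_sample = is_all_greedy_sample


def scan_all_greedy(temperatures):
    # Assumption: a request with no temperature set is sampled greedily.
    return all(t is None for t in temperatures)


def graph_key(batch_size, meta):
    return (batch_size, meta.is_all_greedy_sample)


def step_stale(meta, temperatures):
    # Buggy order: the key is built first, so it sees the previous flag.
    key = graph_key(len(temperatures), meta)
    meta.is_all_greedy_sample = scan_all_greedy(temperatures)
    return key


def step_fixed(meta, temperatures):
    # Fixed order: refresh the flag, then build the key from it.
    meta.is_all_greedy_sample = scan_all_greedy(temperatures)
    return graph_key(len(temperatures), meta)


greedy_batch = [None, None]
# The flag starts as False, as it would after an advanced-sampling warmup pass.
stale_key = step_stale(SpecMetaStandIn(False), greedy_batch)
fixed_key = step_fixed(SpecMetaStandIn(False), greedy_batch)
```

With the stale order a greedy batch selects the advanced-sampling variant, which is the mismatch the commit describes.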
Force-pushed from 9f39541 to 764edb7
/bot run --disable-fail-fast
PR_Github #54001 [ run ] triggered by Bot. Commit:
PR_Github #54001 [ run ] completed with state
…avoid multi-GPU hang
draft_decoder routed the all-greedy fast path to _draft_sampler_greedy, a plain torch.argmax. For MTP-Eagle with a tensor-parallel draft LM head (tp_size>1 without attention DP, or LM-head-TP in ADP) the draft logits are sharded along the vocab dim, so a per-rank argmax selects a different token on each rank. The ranks then desync on the speculative-decoding control flow and the next collective deadlocks, observed as a generation hang on rank 0 (e.g. DeepSeek-V3-Lite tp4 + mtp_nextn>=2 + cuda_graph + torch_compile).
Restore the TP-aware path: for the all-greedy case, MTP-Eagle now uses draft_sampler(), which all-gathers the sharded draft logits before argmax (and falls back to a plain argmax when no TP gather is needed). Eagle3 (non-MTP) keeps its d2t-aware argmax. This matches the pre-refactor behavior.
Root-caused and verified by local reproduction (DeepSeek-V3-Lite, tp4, mtp_nextn=2, cuda_graph, torch_compile): baseline passes, the refactor hangs, and this fix restores passing.
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
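Why a per-rank argmax diverges can be shown with a simulated four-way vocab split. This is a single-process sketch: the "all-gather" is just a list of per-rank candidates, the logits are made up, and `tp_argmax` is a hypothetical name, not the actual `draft_sampler()` implementation.

```python
def local_argmax(shard):
    # Index and value of the largest logit within one rank's vocab shard.
    idx = max(range(len(shard)), key=shard.__getitem__)
    return idx, shard[idx]


def tp_argmax(shards):
    # Each rank contributes (max value, global index); every rank would see
    # the same candidate list after a gather and so picks the same winner.
    candidates = []
    offset = 0
    for shard in shards:
        idx, val = local_argmax(shard)
        candidates.append((val, offset + idx))
        offset += len(shard)
    # Highest value wins; ties resolve to the lowest global index.
    _, best_idx = max(candidates, key=lambda c: (c[0], -c[1]))
    return best_idx


full_logits = [0.1, 2.5, 0.3, 0.9, 3.7, 0.2, 1.1, 0.4]
shards = [full_logits[i:i + 2] for i in range(0, 8, 2)]  # tp_size = 4

# What each rank would pick if it ran argmax on its own shard only.
per_rank_global = [2 * rank + local_argmax(s)[0] for rank, s in enumerate(shards)]
gathered_choice = tp_argmax(shards)
```

Here every rank picks a different token from its own shard, while the gathered variant agrees with the argmax over the full vocabulary.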
…non-greedy) sampling under TP
The non-greedy draft sampling path (_draft_sampler_advanced) has the same multi-GPU hazard as the greedy path that was just fixed. With a plain tensor-parallel draft LM head (tp_size>1 without attention DP) each rank only holds a vocab shard of the draft logits, so per-rank random sampling draws a different token on each rank, desyncs the speculative-decoding control flow and deadlocks the next collective (generation hang).
Greedy could be repaired with draft_sampler()'s lightweight max+index all-gather, but random sampling needs the full distribution, so all-gather the sharded draft logits into the full vocab before advanced sampling. Every rank then samples from the same distribution with the shared seed. The LM-head-TP-in-ADP path is gathered upstream and is intentionally excluded.
Verified by local reproduction (DeepSeek-V3-Lite, tp4, mtp_nextn=2, cuda_graph, torch_compile, non-greedy temperature/top_k/top_p): hangs without this gather, passes with it.
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
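The same idea for random sampling, again simulated in one process. The helper names, the logits, the temperature and the seed value are all illustrative assumptions; Python's `random.Random` stands in for whatever seeded generator the real sampler uses.

```python
import math
import random


def softmax(logits, temperature):
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def all_gather(shards):
    # Stand-in for the collective: concatenate vocab shards in rank order.
    return [x for shard in shards for x in shard]


def sample(probs, rng):
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


shards = [[0.1, 2.5], [0.3, 0.9], [3.7, 0.2], [1.1, 0.4]]  # tp_size = 4
shared_seed = 1234

tokens = []
for rank in range(len(shards)):
    full = all_gather(shards)          # identical full-vocab logits per rank
    rng = random.Random(shared_seed)   # identical generator state per rank
    tokens.append(sample(softmax(full, temperature=0.7), rng))
```

Because both the distribution and the generator state match on every rank, the sampled token matches too, which keeps the ranks in step for the next collective.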
/bot run --disable-fail-fast
PR_Github #54084 [ run ] triggered by Bot. Commit:
PR_Github #54084 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #54110 [ run ] triggered by Bot. Commit:
PR_Github #54110 [ run ] completed with state
…ng through draft_sampler
The previous fix routed every greedy MTP-Eagle draft step through draft_sampler(), but that call does not forward mapping_lm_head_tp. For the LM-head-TP-in-ADP configuration draft_sampler() then takes its ADP branch with a None mapping and crashes during warmup with "'NoneType' object has no attribute 'tp_group'" (Executor worker returned error), e.g. DeepSeek-R1 nvfp4 latency_adp_lmtp_tp4.
Only plain tensor parallelism (tp_size>1 without attention DP) shards the draft logits over the vocab dim and needs draft_sampler()'s all-gather argmax. The LM-head-TP-in-ADP case already yields full-vocab logits per rank (gathered upstream) and the no-TP / Eagle3 cases need nothing, so all of those take the plain d2t-aware argmax (_draft_sampler_greedy), restoring the pre-regression behavior for ADP while keeping the plain-TP hang fix.
Signed-off-by: ZhaoyangWang <zhaoyangw@nvidia.com>
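The resulting dispatch rule for the all-greedy case can be written as a small decision function. This is a sketch of the rule as stated in the commit message: the function name and its boolean parameters are hypothetical, and the returned strings only name the two paths.

```python
def choose_greedy_draft_path(is_mtp_eagle, tp_size, attention_dp):
    # Only plain tensor parallelism (tp_size > 1 without attention DP)
    # leaves the draft logits sharded over the vocab dimension.
    plain_tp = is_mtp_eagle and tp_size > 1 and not attention_dp
    if plain_tp:
        return "draft_sampler"  # all-gather argmax across TP ranks
    # LM-head-TP in ADP (gathered upstream), no TP, and Eagle3 all land here.
    return "_draft_sampler_greedy"


plain_tp_path = choose_greedy_draft_path(True, tp_size=4, attention_dp=False)
adp_path = choose_greedy_draft_path(True, tp_size=4, attention_dp=True)
single_gpu_path = choose_greedy_draft_path(True, tp_size=1, attention_dp=False)
eagle3_path = choose_greedy_draft_path(False, tp_size=4, attention_dp=False)
```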
/bot run --disable-fail-fast
PR_Github #54140 [ run ] triggered by Bot. Commit:
PR_Github #54140 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #54165 [ run ] triggered by Bot. Commit:
PR_Github #54165 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #54182 [ run ] triggered by Bot. Commit:
/bot run --disable-fail-fast
PR_Github #54250 [ run ] triggered by Bot. Commit:
PR_Github #54182 [ run ] completed with state
Replace static config flag with auto-detected per-step uses_advanced_sampling based on actual sampling params. Include this in CUDA graph key so we lazily capture two graph variants (argmax fast-path vs advanced sampling kernel) and dispatch by replaying the right one.
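A minimal sketch of lazy two-variant capture keyed on the detected flag. `GraphCacheStandIn` and the dict-based requests are hypothetical; a "captured graph" is represented by a string, and treating a missing `temperature`/`top_k`/`top_p` as greedy is an assumption for illustration.

```python
def uses_advanced_sampling(batch):
    # Assumption: a request is greedy when none of the three params is set.
    names = ("temperature", "top_k", "top_p")
    return any(req.get(name) is not None for req in batch for name in names)


class GraphCacheStandIn:
    """Hypothetical cache that captures a graph variant on first use."""

    def __init__(self):
        self._graphs = {}
        self.captures = 0

    def replay(self, batch):
        advanced = uses_advanced_sampling(batch)
        key = (len(batch), advanced)
        if key not in self._graphs:
            # Lazy capture: only variants that are actually needed get built.
            self.captures += 1
            self._graphs[key] = "advanced_sampling" if advanced else "argmax"
        return self._graphs[key]


cache = GraphCacheStandIn()
greedy_batch = [{}, {}]
mixed_batch = [{}, {"temperature": 0.7, "top_k": 50, "top_p": 0.9}]

first = cache.replay(greedy_batch)
second = cache.replay(mixed_batch)
third = cache.replay(greedy_batch)  # served from the cache, no new capture
```

A workload that never sends non-greedy requests pays for only one capture per batch size.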
Description
Refactors Eagle3 one-model speculative decoding sampling. Four logical changes, sequenced as separate commits:
1. Replace `allow_advanced_sampling` with auto-detected `is_all_greedy_sample`
Removes the `allow_advanced_sampling` config flag from `DecodingBaseConfig`. It is replaced with a per-step `is_all_greedy_sample` derived from the actual `temperature`/`top_k`/`top_p` of requests in the batch. The flag is included in the CUDA graph cache key, so two graph variants are lazily captured (argmax fast-path vs. advanced-sampling kernel) and dispatched at replay time based on batch composition.
2. Eagle3 drafter honors target's sampling params
Previously the Eagle3 draft model always ran greedy regardless of the target's sampling configuration. This change propagates the target's `temperature`/`top_k`/`top_p` into the draft loop so that draft samples come from the same distribution as the target's reference distribution. This is a correctness prerequisite for non-greedy rejection sampling (the Leviathan formula `u * p_draft(x) < p_target(x)` only holds when both probabilities come from the same conditioning).
3. Slot-indexed `draft_probs` to support mixed batches
Previously `_can_use_rejection_sampling` bailed out when the batch contained context requests, falling back to exact-match for the whole batch. Root cause: `draft_probs` was indexed by batch position, but batch position is unstable across iterations (chunked-prefill, finishing gens, new ctx joins all shift positions). Fix:
- Change `draft_probs` from flat `[total_draft_tokens, vocab]` to slot-indexed `[max_num_requests, max_draft_len, vocab]`, keyed by stable `py_seq_slot`.
- Scatter on write (`_compute_and_store_draft_probs`), gather on read (`_accept_draft_tokens`) using a precomputed `batch_slot_ids` tensor.
- Remove the `num_contexts == 0` constraint in `_can_use_rejection_sampling`: the ctx subset goes through `_sample_tokens_for_batch`, the gen subset goes through the rejection kernel.
- Set `draft_probs_valid = False` whenever the draft loop writes no probs, so stale data is never read.
Mixed-batch rejection captures ~18% sys-tps on llama70b bs=32 vs. the exact-match fallback.
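The slot-indexed layout described in this section can be sketched with nested lists in place of a `[max_num_requests, max_draft_len, vocab]` tensor. `DraftProbStoreStandIn` and its method names are hypothetical; the slot numbers and probabilities are made up.

```python
class DraftProbStoreStandIn:
    """Hypothetical slot-indexed store for draft probabilities."""

    def __init__(self, max_num_requests, max_draft_len, vocab):
        self.buf = [
            [[0.0] * vocab for _ in range(max_draft_len)]
            for _ in range(max_num_requests)
        ]
        self.valid = False

    def scatter(self, slot_ids, probs_per_request):
        # Write each request's rows at its stable slot, not its batch index.
        for slot, rows in zip(slot_ids, probs_per_request):
            for step, row in enumerate(rows):
                self.buf[slot][step] = list(row)
        self.valid = True

    def mark_no_probs_written(self):
        # Called when a draft loop writes nothing, so old rows are not reused.
        self.valid = False

    def gather(self, batch_slot_ids):
        if not self.valid:
            raise RuntimeError("draft probs are stale")
        return [self.buf[slot] for slot in batch_slot_ids]


store = DraftProbStoreStandIn(max_num_requests=8, max_draft_len=1, vocab=2)

# Iteration 1: batch order is [request in slot 5, request in slot 2].
store.scatter([5, 2], [[[0.9, 0.1]], [[0.3, 0.7]]])

# Iteration 2: batch positions have shifted, but the slots have not.
read_back = store.gather([2, 5])

store.mark_no_probs_written()
try:
    store.gather([2])
    stale_read_blocked = False
except RuntimeError:
    stale_read_blocked = True
```

A position-indexed buffer would have returned the rows in the old order here, pairing each request with the other request's draft probabilities.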
Test Coverage
Unit tests (B200)
End-to-end correctness — Qwen3-8B (H100 SXM5 80G)
Greedy path verification (no temp / top_p / top_k → both paths take greedy branch):
- Old (`83ec591`, `allow_advanced_sampling=False`) vs new (`d237690`, auto-detected `is_all_greedy_sample=True`)
- `total_output_tokens`: 113,432 vs 113,432, Δ=0
- `mean_acceptance_rate`: 0.4918 vs 0.4918, Δ=+0.00%
- `mean_acceptance_length`: 2.475 vs 2.475, Δ=0
Performance: rejection sampling ON vs OFF (non-greedy)
CUDA graph enabled. mtbench dataset. Sampling params: `temperature=0.7, top_k=50, top_p=0.9`.
Llama-3.3-70B-Instruct + EAGLE-3 (mean over 3 rounds)
Per-round detail (3 rounds, llama-70b)
Qwen3-235B-A22B + EAGLE-3
Observation
Acceptance rate improves across all tested configurations under non-greedy sampling (+2–4% on llama-70b, +12–15% on qwen-235b). The AR uplift translates to a TPS win on qwen-235b at every batch size, but not on llama-70b, where rejection sampling currently costs 0–9% TPS. Lower batch sizes also show enough run-to-run noise that some signs flip across rounds.
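For intuition on the acceptance-rate numbers, the accept test quoted in the description, `u * p_draft(x) < p_target(x)`, can be sketched on a toy two-token vocabulary. The distributions below are invented, and the closed form for the per-token acceptance probability (the sum over tokens of `min(p_draft, p_target)`) is the standard speculative-sampling result rather than something measured in this PR.

```python
def accept_draft_token(token, p_draft, p_target, u):
    # Accept test from the description, with u drawn uniformly from [0, 1).
    return u * p_draft[token] < p_target[token]


def acceptance_probability(p_draft, p_target):
    # A token x is drafted with probability p_draft[x] and then accepted with
    # probability min(1, p_target[x] / p_draft[x]); summing over x gives
    # sum(min(p_draft[x], p_target[x])).
    return sum(min(d, t) for d, t in zip(p_draft, p_target))


p_target = [0.8, 0.2]
p_draft_mismatched = [0.5, 0.5]   # e.g. draft sampled under other settings
p_draft_matched = [0.8, 0.2]      # draft follows the target's distribution

accepted = accept_draft_token(0, p_draft_mismatched, p_target, u=0.5)
rejected = accept_draft_token(1, p_draft_mismatched, p_target, u=0.5)
rate_mismatched = acceptance_probability(p_draft_mismatched, p_target)
rate_matched = acceptance_probability(p_draft_matched, p_target)
```

The closer the draft distribution is to the target's, the higher the expected acceptance, which is consistent with the direction of the AR changes reported above.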
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either `api-compatible` or `api-breaking`. For `api-breaking`, include `BREAKING` in the PR title.
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.