[None][feat] Support post-norm and per-aux fc_norm for Eagle3 draft models#14988
[None][feat] Support post-norm and per-aux fc_norm for Eagle3 draft models#14988Dogacel wants to merge 11 commits into
Conversation
…odels Enable SGLang-style Eagle3 draft checkpoints in the PyTorch backend: - norm_output (post-norm): return the post-final-norm hidden state as the auxiliary feature fed to the next draft step, in addition to the existing eagle_config return_hidden_post_norm flag. - fc_norm: per-aux-layer RMSNorm applied to each captured hidden state before the fc projection. Unlike the existing single-norm norm_before_fc, this normalizes each of the num_capture_layers features independently. Combined with the configurable num_capture_layers (eagle3_layers_to_capture), this allows running drafters with 5 aux capture layers such as dogacel/specdrift-gpt-oss-120b-eagle3. Signed-off-by: Doğaç Eldenk <dogacel@gmail.com>
📝 WalkthroughWalkthroughEagle3 draft model implementation adds optional per-capture-layer normalization through ChangesEagle3 per-capture-layer normalization
🎯 2 (Simple) | ⏱️ ~12 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Warning Review ran into problems🔥 ProblemsStopped waiting for pipeline failures after 30000ms. One of your pipelines takes longer than our 30000ms fetch window to run, so review may not consider pipeline-failure results for inline comments if any failures occurred after the fetch window. Increase the timeout if you want to wait longer or run a Comment |
There was a problem hiding this comment.
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/models/modeling_speculative.py`:
- Around line 606-613: The zip used when applying per-chunk normalization over
self.model.fc_norm and chunks can silently truncate if lengths diverge; update
the comprehension in the fc_norm branch (where hidden_states is split with
hidden_states.chunk and norms applied via for norm, chunk in
zip(self.model.fc_norm, chunks)) to call zip(self.model.fc_norm, chunks,
strict=True) so mismatched lengths raise an error, ensuring strict pairing
between norm layers and chunks.
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: b4434610-3b52-4e63-83bc-0523f95a676f
📒 Files selected for processing (1)
tensorrt_llm/_torch/models/modeling_speculative.py
|
Will check. For now I think this change is fine to land as code changes are straightforward. |
|
We should run a round of SPEED-bench with both the new drafters and https://huggingface.co/nvidia/gpt-oss-120b-Eagle3-v3. https://huggingface.co/nvidia/gpt-oss-120b-Eagle3-v3 has expected ALs in different categories. We can compare to those numbers to figure out if there are bugs (I don't expect any issues though). |
I've re-run the benchmark on MT-Bench, scoring 2.8 AL matching other implementations & expectation. |
|
2.8 AL on https://huggingface.co/nvidia/gpt-oss-120b-Eagle3-v3, right? What does the new drafter get? |
Oh sorry for not clarifying. I've only tested my drafter (specdrift). I run the benchmark to validate the implementation is correct and we get the expected acceptance length. I think testing NVIDIA's is out of scope for this PR. I've only run it as a dry run to validate things are not entirely broken. I previously tested that in vLLM and results were similar to our model in standard benchmarks. The difference is visible in OOD cases. |
|
Hi @mikeiovine , let's focus on merging this PR only. As the codes are all open sourced, please feel free to run any tests that you are interested in at your side, as different machines may vary the performance too. Thank you. |
|
/bot run --disable-fail-fast |
|
PR_Github #53410 [ run ] triggered by Bot. Commit: |
|
PR_Github #53410 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #53688 [ run ] triggered by Bot. Commit: |
|
PR_Github #53688 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #53906 [ run ] triggered by Bot. Commit: |
|
/bot run --disable-fail-fast |
|
PR_Github #53941 [ run ] triggered by Bot. Commit: |
|
PR_Github #53906 [ run ] completed with state |
|
PR_Github #53941 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #54064 [ run ] triggered by Bot. Commit: |
|
PR_Github #54064 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
/bot run --disable-fail-fast |
|
PR_Github #54383 [ run ] triggered by Bot. Commit: |
|
PR_Github #54384 [ run ] triggered by Bot. Commit: |
|
PR_Github #54383 [ run ] completed with state |
|
PR_Github #54384 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #54624 [ run ] triggered by Bot. Commit: |
|
PR_Github #54624 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #55045 [ run ] triggered by Bot. Commit: |
|
PR_Github #55045 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #55254 [ run ] triggered by Bot. Commit: |
|
PR_Github #55254 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #55326 [ run ] triggered by Bot. Commit: |
Description
Adds support for EAGLE-3.1 architecture & draft checkpoints. Related sources:
Core changes:
norm_output(post-norm): return the post-final-norm hidden state as the auxiliary feature fed to the next draft step.fc_norm(per-aux norm): apply a separateRMSNormto each captured hidden state before thefcprojection. Unlike the existing single-normnorm_before_fc(one norm over the full concatenated vector), this normalizes each of thenum_capture_layersfeatures independently so they contribute equally regardless of raw scale.The change is additive and behavior-preserving for existing drafters.
Results
gpt-oss-120b, 2×H100,
max_draft_len: 7, greedy. Values are tokens / forward pass (accepted_len = value − 1;1.00= no speculation,8.00= theoretical max).Models:
Testing Strategy
Both models tested using
openai/gpt-oss-120bon 2×H100 with--tp_size 2, differing only in the drafter + config.regular.sh(baseline — stock NVIDIA 3-layer drafter, used to confirm no regression):trtllm-serve openai/gpt-oss-120b \ --host 0.0.0.0 --port 8888 \ --backend pytorch \ --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 \ --tp_size 2 \ --extra_llm_api_options /workdir/extra-llm-api-config.ymlextra-llm-api-config.yml:new.sh(new model architecture supported using this PR):trtllm-serve openai/gpt-oss-120b \ --host 0.0.0.0 --port 8888 \ --backend pytorch \ --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 \ --tp_size 2 \ --extra_llm_api_options /workdir/new-eagle3-config.ymlnew-eagle3-config.yml(the 5 capture layers come from the drafter'seagle_aux_hidden_state_layer_ids; wiring them here makesnum_capture_layers == 5, which drives thefcinput dim, thefc_normcount, and the target-model capture points):Speculative decoding acceptance length testing script used (Mostly AI generated):
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either
api-compatibleorapi-breaking. Forapi-breaking, includeBREAKINGin the PR title.Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
Summary by CodeRabbit
New Features
Improvements