
[None][feat] Support post-norm and per-aux fc_norm for Eagle3 draft models#14988

Open
Dogacel wants to merge 11 commits into NVIDIA:main from Dogacel:eagle3_1-postnorm

Conversation

@Dogacel

@Dogacel Dogacel commented Jun 5, 2026


Description

Adds support for the EAGLE-3.1 architecture and its draft checkpoints. Related sources:

  1. Paper describing the architectural changes and rationale: https://arxiv.org/abs/2605.09992
  2. vLLM blog post on EAGLE 3.1: https://vllm.ai/blog/2026-05-26-eagle-3-1
  3. Newly supported models: https://huggingface.co/Dogacel/specdrift-gpt-oss-120b-eagle3, https://huggingface.co/Dogacel/specdrift-gpt-oss-20b-eagle3, https://huggingface.co/lightseekorg/kimi-k2.6-eagle3.1-mla (not tested)

Core changes:

  • norm_output (post-norm): return the post-final-norm hidden state as the auxiliary feature fed to the next draft step.
  • fc_norm (per-aux norm): apply a separate RMSNorm to each captured hidden state before the fc projection. Unlike the existing single-norm norm_before_fc (one norm over the full concatenated vector), this normalizes each of the num_capture_layers features independently so they contribute equally regardless of raw scale.

The change is additive and behavior-preserving for existing drafters.
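
To make the difference between the two normalization modes concrete, here is a minimal pure-Python sketch. It is not the TensorRT-LLM code: `rms_norm`, `single_norm_before_fc`, and `per_aux_fc_norm` are hypothetical names, the learned RMSNorm weights are omitted (treated as 1), and the toy features are made up for illustration.

```python
import math


def rms_norm(x, eps=1e-6):
    """RMSNorm without a learned weight: x / sqrt(mean(x^2) + eps)."""
    scale = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / scale for v in x]


def rms(x):
    """Root-mean-square magnitude of a vector."""
    return math.sqrt(sum(v * v for v in x) / len(x))


def single_norm_before_fc(features):
    """One norm over the full concatenated vector (norm_before_fc style)."""
    concat = [v for feat in features for v in feat]
    normed = rms_norm(concat)
    size = len(features[0])
    # Split back only so the per-feature scales can be inspected.
    return [normed[i:i + size] for i in range(0, len(normed), size)]


def per_aux_fc_norm(features):
    """A separate norm per captured feature (fc_norm style)."""
    return [rms_norm(feat) for feat in features]


# Three toy captured hidden states with very different raw scales.
features = [
    [0.1, -0.2, 0.1, 0.2],
    [1.0, -2.0, 1.0, 2.0],
    [10.0, -20.0, 10.0, 20.0],
]

single = [rms(f) for f in single_norm_before_fc(features)]
per_aux = [rms(f) for f in per_aux_fc_norm(features)]
print("single norm RMS per feature:", [round(s, 3) for s in single])
print("per-aux norm RMS per feature:", [round(s, 3) for s in per_aux])
```

With the single norm, the relative scales of the features are preserved, so the largest one dominates the input to fc. With per-aux norms, each feature arrives at roughly unit RMS.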

Results

gpt-oss-120b, 2×H100, max_draft_len: 7, greedy. Values are tokens / forward pass (accepted_len = value − 1; 1.00 = no speculation, 8.00 = theoretical max).

| prompt | baseline (NVIDIA) low / high | EAGLE3.1 (ours) low / high |
| --- | --- | --- |
| trivia_1line | 2.40 / 2.48 | 5.00 / 3.93 |
| factual_para | 1.59 / 1.81 | 1.98 / 2.53 |
| code_gen | 2.29 / 1.99 | 3.35 / 2.71 |
| math_reason | 2.44 / 2.07 | 4.03 / 3.85 |
| logic_puzzle | 2.00 / 2.21 | 3.36 / 4.34 |
| mean tok/forward | 2.13 | 3.51 |
| mean accepted_len | ~1.13 | ~2.51 |
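
The two summary rows follow directly from the per-prompt values. As a small check, this sketch recomputes them from the numbers reported above; the helper names `mean_tokens_per_forward` and `accepted_len` are introduced here for illustration only.

```python
# (low, high) tokens per forward pass, copied from the results above.
baseline = {
    "trivia_1line": (2.40, 2.48),
    "factual_para": (1.59, 1.81),
    "code_gen": (2.29, 1.99),
    "math_reason": (2.44, 2.07),
    "logic_puzzle": (2.00, 2.21),
}
eagle31 = {
    "trivia_1line": (5.00, 3.93),
    "factual_para": (1.98, 2.53),
    "code_gen": (3.35, 2.71),
    "math_reason": (4.03, 3.85),
    "logic_puzzle": (3.36, 4.34),
}


def mean_tokens_per_forward(results):
    """Average over every (prompt, effort) cell."""
    values = [v for pair in results.values() for v in pair]
    return sum(values) / len(values)


def accepted_len(tokens_per_forward):
    """Each forward pass always yields one token; the rest are accepted drafts."""
    return tokens_per_forward - 1.0


for name, results in (("baseline", baseline), ("EAGLE3.1", eagle31)):
    tpf = mean_tokens_per_forward(results)
    print(f"{name}: {tpf:.2f} tok/forward, accepted_len ~{accepted_len(tpf):.2f}")
```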

Models:

Testing Strategy

Both drafters were tested with openai/gpt-oss-120b as the target model on 2×H100 with --tp_size 2; the two runs differ only in the drafter and its config.

regular.sh (baseline — stock NVIDIA 3-layer drafter, used to confirm no regression):

trtllm-serve openai/gpt-oss-120b \
    --host 0.0.0.0 --port 8888 \
    --backend pytorch \
    --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 \
    --tp_size 2 \
    --extra_llm_api_options /workdir/extra-llm-api-config.yml

extra-llm-api-config.yml:

enable_attention_dp: false
disable_overlap_scheduler: true
enable_autotuner: false
return_perf_metrics: true
perf_metrics_max_requests: 4096
cuda_graph_config:
  max_batch_size: 1
speculative_config:
  decoding_type: Eagle
  max_draft_len: 7
  speculative_model_dir: nvidia/gpt-oss-120b-Eagle3-v3
kv_cache_config:
  enable_block_reuse: false

new.sh (new model architecture supported using this PR):

trtllm-serve openai/gpt-oss-120b \
    --host 0.0.0.0 --port 8888 \
    --backend pytorch \
    --max_batch_size 32 --max_num_tokens 8192 --max_seq_len 8192 \
    --tp_size 2 \
    --extra_llm_api_options /workdir/new-eagle3-config.yml

new-eagle3-config.yml (the 5 capture layers come from the drafter's eagle_aux_hidden_state_layer_ids; wiring them here makes num_capture_layers == 5, which drives the fc input dim, the fc_norm count, and the target-model capture points):

enable_attention_dp: false
disable_overlap_scheduler: true
enable_autotuner: false
return_perf_metrics: true
perf_metrics_max_requests: 4096
cuda_graph_config:
  max_batch_size: 1
speculative_config:
  decoding_type: Eagle
  max_draft_len: 7
  speculative_model_dir: dogacel/specdrift-gpt-oss-120b-eagle3
  eagle3_layers_to_capture: [1, 9, 17, 25, 33]
kv_cache_config:
  enable_block_reuse: false
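
The dependency on the capture-layer list described above can be sketched as follows. This is an illustration under stated assumptions, not the PR's code: `derive_capture_shapes` and `Eagle3CaptureShapes` are hypothetical names, the `hidden_size` value is a placeholder rather than something read from the checkpoint, and the sketch assumes the captured hidden states are concatenated along the feature dimension before fc.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Eagle3CaptureShapes:
    num_capture_layers: int
    fc_in_features: int
    num_fc_norms: int


def derive_capture_shapes(layers_to_capture, hidden_size):
    """Derive the quantities that depend on the capture-layer list.

    Assumes the captured hidden states are concatenated along the feature
    dimension before fc, so fc sees num_capture_layers * hidden_size inputs.
    """
    if len(set(layers_to_capture)) != len(layers_to_capture):
        raise ValueError("capture layer ids must be unique")
    n = len(layers_to_capture)
    return Eagle3CaptureShapes(
        num_capture_layers=n,
        fc_in_features=n * hidden_size,
        num_fc_norms=n,  # one RMSNorm per captured feature
    )


# Layer ids from new-eagle3-config.yml; hidden_size is an illustrative value.
shapes = derive_capture_shapes([1, 9, 17, 25, 33], hidden_size=2880)
print(shapes)
```

If the list in the YAML and the drafter's eagle_aux_hidden_state_layer_ids disagree in length, the fc weight shape in the checkpoint would presumably no longer match, which is why the config wires them explicitly.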

Script used to measure speculative-decoding acceptance length (mostly AI-generated):

#!/usr/bin/env python3
"""Sweep prompt type x reasoning_effort and report Eagle3 acceptance (tokens/forward)."""
import json
import sys
import time

import requests

URL = "http://localhost:8888"
MODEL = "openai/gpt-oss-120b"
MAX_TOKENS = 512

PROMPTS = {
    "trivia_1line":  "What is the capital of France?",
    "factual_para":  "Write 150 words about the history of the bicycle.",
    "code_gen":      "Write a Python function `binary_search(arr, target)` with a "
                     "docstring, type hints, and handle the empty-list case.",
    "math_reason":   "Compute the determinant of [[2,1,3],[0,4,1],[5,2,1]] showing "
                     "every step of cofactor expansion.",
    "logic_puzzle":  "A farmer has 17 sheep. All but 9 run away. How many remain? "
                     "Reason carefully and double-check before answering.",
}
EFFORTS = ["low", "high"]


def drain():
    try:
        requests.get(f"{URL}/perf_metrics", timeout=5)
    except requests.RequestException:
        pass


def probe(prompt, effort):
    drain()
    body = {
        "model": MODEL, "temperature": 0, "max_tokens": MAX_TOKENS,
        "reasoning_effort": effort,
        "messages": [{"role": "user", "content": prompt}],
    }
    t0 = time.perf_counter()
    r = requests.post(f"{URL}/v1/chat/completions", json=body, timeout=300)
    dt = time.perf_counter() - t0
    r.raise_for_status()
    d = r.json()
    msg = d["choices"][0]["message"]
    ctok = d["usage"]["completion_tokens"]
    reasoning_tok = len((msg.get("reasoning_content") or "").split())
    time.sleep(0.3)
    pm = requests.get(f"{URL}/perf_metrics", timeout=5).json()
    steps = 0
    for e in pm:
        sm = (e.get("time_breakdown_metrics") or {}).get("step_metrics")
        if sm:
            steps += len(sm)
        else:
            p = e.get("perf_metrics") or {}
            if p.get("last_iter") is not None:
                steps += p["last_iter"] - p["first_iter"]
    tpf = ctok / steps if steps else float("nan")
    return tpf, ctok, steps, reasoning_tok, ctok / dt if dt else 0


def main():
    try:
        if requests.get(f"{URL}/health", timeout=3).status_code != 200:
            raise RuntimeError
    except Exception:
        print(f"[!] no healthy server at {URL}")
        sys.exit(1)

    print(f"{'prompt':<14}{'effort':<7}{'tok/fwd':>8}{'acc_len':>8}"
          f"{'~reason_wd':>11}{'ctok':>7}{'steps':>7}{'tok/s':>8}")
    print("-" * 70)
    best = (None, -1)
    rows = []
    for name, prompt in PROMPTS.items():
        for effort in EFFORTS:
            try:
                tpf, ctok, steps, rtok, tps = probe(prompt, effort)
            except requests.RequestException as e:
                print(f"{name:<14}{effort:<7}  ERROR {e}")
                continue
            rows.append((name, effort, tpf))
            print(f"{name:<14}{effort:<7}{tpf:>8.2f}{tpf-1:>8.2f}"
                  f"{rtok:>11}{ctok:>7}{steps:>7}{tps:>8.1f}")
            if tpf > best[1]:
                best = (f"{name}/{effort}", tpf)

    print("-" * 70)
    if rows:
        avg = sum(r[2] for r in rows) / len(rows)
        print(f"mean tok/forward across configs: {avg:.2f}")
        print(f"best: {best[0]} @ {best[1]:.2f} tok/forward "
              f"(accepted_len {best[1]-1:.2f})")


if __name__ == "__main__":
    main()

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • If PR introduces API changes, an appropriate PR label is added - either api-compatible or api-breaking. For api-breaking, include BREAKING in the PR title.

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

Summary by CodeRabbit

  • New Features

    • Added optional per-capture-layer normalization support for Eagle3 draft models, configurable through settings.
  • Improvements

    • Enhanced hidden state processing during forward pass with conditional normalization application based on model configuration.

…odels

Enable SGLang-style Eagle3 draft checkpoints in the PyTorch backend:

- norm_output (post-norm): return the post-final-norm hidden state as the
  auxiliary feature fed to the next draft step, in addition to the existing
  eagle_config return_hidden_post_norm flag.
- fc_norm: per-aux-layer RMSNorm applied to each captured hidden state before
  the fc projection. Unlike the existing single-norm norm_before_fc, this
  normalizes each of the num_capture_layers features independently.

Combined with the configurable num_capture_layers (eagle3_layers_to_capture),
this allows running drafters with 5 aux capture layers such as
dogacel/specdrift-gpt-oss-120b-eagle3.

Signed-off-by: Doğaç Eldenk <dogacel@gmail.com>
@Dogacel Dogacel marked this pull request as ready for review June 5, 2026 04:55
@Dogacel Dogacel requested a review from a team as a code owner June 5, 2026 04:55
@Dogacel Dogacel requested a review from yechank-nvidia June 5, 2026 04:55
@coderabbitai

coderabbitai Bot commented Jun 5, 2026

Contributor


📝 Walkthrough


Eagle3 draft model implementation adds optional per-capture-layer normalization through fc_norm. Configuration determines when normalization applies; initialization creates RMSNorm layers per capture layer when enabled; the forward pass chunks hidden states and applies corresponding normalization before projection.

Changes

Eagle3 per-capture-layer normalization

| Layer / File(s) | Summary |
| --- | --- |
| fc_norm configuration and module initialization (tensorrt_llm/_torch/models/modeling_speculative.py) | Eagle3DraftModel computes _return_hidden_post_norm from eagle_config or config.norm_output, and conditionally creates fc_norm as a ModuleList of RMSNorm instances (one per spec_config.num_capture_layers) when config.fc_norm is enabled. |
| fc_norm application in forward path (tensorrt_llm/_torch/models/modeling_speculative.py) | Eagle3ForCausalLM.apply_eagle3_fc() checks for model.fc_norm presence; when set and hidden-state projection is required, it chunks hidden_states along the last dimension per fc_norm length, applies the corresponding RMSNorm to each chunk, and concatenates normalized chunks before projection. |
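
The configuration logic summarized above can be sketched like this. It is a guess at the shape of the logic, not the code in modeling_speculative.py: `resolve_draft_norm_options` is a hypothetical helper, and combining the two post-norm flags with a plain `or` is an assumption.

```python
from types import SimpleNamespace


def resolve_draft_norm_options(config, eagle_config, num_capture_layers):
    """Return (return_hidden_post_norm, number_of_fc_norm_layers)."""
    # Either the existing eagle_config flag or the new norm_output flag
    # turns on post-norm hidden-state return (assumed OR semantics).
    return_hidden_post_norm = bool(
        eagle_config.get("return_hidden_post_norm", False)
        or getattr(config, "norm_output", False)
    )
    # fc_norm layers exist only when the drafter config asks for them.
    num_fc_norms = num_capture_layers if getattr(config, "fc_norm", False) else 0
    return return_hidden_post_norm, num_fc_norms


legacy = SimpleNamespace()  # existing drafter: neither new key present
new_style = SimpleNamespace(norm_output=True, fc_norm=True)

print(resolve_draft_norm_options(legacy, {}, 3))
print(resolve_draft_norm_options(new_style, {}, 5))
```

Because both new keys default to off when absent, a drafter config that predates them resolves to the same options as before, which is what "additive and behavior-preserving" would require.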

🎯 2 (Simple) | ⏱️ ~12 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 50.00%, which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (4 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title is fully related to the main change in the changeset, clearly summarizing the two core features (post-norm and per-aux fc_norm) added to support Eagle3 draft models. |
| Description check | ✅ Passed | The description is comprehensive and well-structured, covering the motivation (with references), core changes, performance results, testing strategy with detailed examples, and addressing the PR checklist items. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

Warning

Review ran into problems

🔥 Problems

Stopped waiting for pipeline failures after 30000ms. One of your pipelines takes longer than our 30000ms fetch window to run, so review may not consider pipeline-failure results for inline comments if any failures occurred after the fetch window. Increase the timeout if you want to wait longer or run a @coderabbit review after the pipeline has finished.



@coderabbitai coderabbitai Bot left a comment

Contributor


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tensorrt_llm/_torch/models/modeling_speculative.py`:
- Around line 606-613: The zip used when applying per-chunk normalization over
self.model.fc_norm and chunks can silently truncate if lengths diverge; update
the comprehension in the fc_norm branch (where hidden_states is split with
hidden_states.chunk and norms applied via for norm, chunk in
zip(self.model.fc_norm, chunks)) to call zip(self.model.fc_norm, chunks,
strict=True) so mismatched lengths raise an error, ensuring strict pairing
between norm layers and chunks.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: b4434610-3b52-4e63-83bc-0523f95a676f

📥 Commits

Reviewing files that changed from the base of the PR and between 81e86a5 and 0ac8947.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/models/modeling_speculative.py

Comment thread tensorrt_llm/_torch/models/modeling_speculative.py
@benchislett

  1. @mikeiovine Is there a bug in the support for https://huggingface.co/nvidia/gpt-oss-120b-Eagle3-v3? The reported mean acceptance length seems extremely low
  2. @Dogacel Could you please run evals on an actual dataset, MTBench and/or SPEED-Bench, so that we can compare the results against the model cards?

@mikeiovine

Collaborator

> @mikeiovine Is there a bug in the support for https://huggingface.co/nvidia/gpt-oss-120b-Eagle3-v3? The reported mean acceptance length seems extremely low

Will check. For now I think this change is fine to land as code changes are straightforward.

@mikeiovine

Collaborator

We should run a round of SPEED-bench with both the new drafters and https://huggingface.co/nvidia/gpt-oss-120b-Eagle3-v3.

https://huggingface.co/nvidia/gpt-oss-120b-Eagle3-v3 has expected ALs in different categories. We can compare to those numbers to figure out if there are bugs (I don't expect any issues though).

@Dogacel

Dogacel commented Jun 5, 2026

Author
> 1. @mikeiovine Is there a bug in the support for https://huggingface.co/nvidia/gpt-oss-120b-Eagle3-v3? The reported mean acceptance length seems extremely low
> 2. @Dogacel Could you please run evals on an actual dataset, MTBench and/or SPEED-Bench, so that we can compare the results against the model cards?

I've re-run the benchmark on MT-Bench. It scores 2.8 AL, matching other implementations and expectations.

@mikeiovine

Collaborator

2.8 AL on https://huggingface.co/nvidia/gpt-oss-120b-Eagle3-v3, right? What does the new drafter get?

@Dogacel

Dogacel commented Jun 5, 2026

Author

> 2.8 AL on https://huggingface.co/nvidia/gpt-oss-120b-Eagle3-v3, right? What does the new drafter get?

Oh, sorry for not clarifying. I've only tested my drafter (specdrift). I ran the benchmark to validate that the implementation is correct and that we get the expected acceptance length.

I think testing NVIDIA's drafter is out of scope for this PR. I only ran it as a dry run to confirm that nothing is entirely broken. I previously tested it in vLLM, and its results were similar to our model's on standard benchmarks. The difference is visible in OOD cases.

@hongyanz

hongyanz commented Jun 5, 2026


Hi @mikeiovine, let's focus on merging this PR only. As the code is all open source, please feel free to run any tests you are interested in on your side, since performance may also vary across machines. Thank you.

@laikhtewari laikhtewari enabled auto-merge (squash) June 10, 2026 23:03
@mikeiovine

Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #53410 [ run ] triggered by Bot. Commit: d7b4ac1 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #53410 [ run ] completed with state SUCCESS. Commit: d7b4ac1
/LLM/main/L0_MergeRequest_PR pipeline #42583 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@mikeiovine

Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #53688 [ run ] triggered by Bot. Commit: 3f390e6 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #53688 [ run ] completed with state SUCCESS. Commit: 3f390e6
/LLM/main/L0_MergeRequest_PR pipeline #42825 completed with status: 'FAILURE'


@mikeiovine

Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #53906 [ run ] triggered by Bot. Commit: 6d774b9 Link to invocation

@mikeiovine

Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #53941 [ run ] triggered by Bot. Commit: af33ab9 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #53906 [ run ] completed with state ABORTED. Commit: 6d774b9

Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #53941 [ run ] completed with state FAILURE. Commit: af33ab9
/LLM/main/L0_MergeRequest_PR pipeline #43033 completed with status: 'FAILURE'


@mikeiovine

Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #54064 [ run ] triggered by Bot. Commit: d61c269 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #54064 [ run ] completed with state SUCCESS. Commit: d61c269
/LLM/main/L0_MergeRequest_PR pipeline #43148 completed with status: 'FAILURE'


@mikeiovine

Collaborator

/bot run --disable-fail-fast

@mikeiovine

Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #54383 [ run ] triggered by Bot. Commit: 7fa16ca Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #54384 [ run ] triggered by Bot. Commit: 7fa16ca Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #54383 [ run ] completed with state ABORTED. Commit: 7fa16ca

Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #54384 [ run ] completed with state FAILURE. Commit: 7fa16ca
/LLM/main/L0_MergeRequest_PR pipeline #43454 completed with status: 'FAILURE'


@mikeiovine

Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #54624 [ run ] triggered by Bot. Commit: 0f9ef99 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #54624 [ run ] completed with state FAILURE. Commit: 0f9ef99
/LLM/main/L0_MergeRequest_PR pipeline #43673 completed with status: 'FAILURE'


@mikeiovine

Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #55045 [ run ] triggered by Bot. Commit: 03f46b9 Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #55045 [ run ] completed with state SUCCESS. Commit: 03f46b9
/LLM/main/L0_MergeRequest_PR pipeline #44036 completed with status: 'FAILURE'


@mikeiovine

Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #55254 [ run ] triggered by Bot. Commit: 36da7de Link to invocation

@tensorrt-cicd

Collaborator

PR_Github #55254 [ run ] completed with state SUCCESS. Commit: 36da7de
/LLM/main/L0_MergeRequest_PR pipeline #44213 completed with status: 'FAILURE'


@mikeiovine

Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd

Collaborator

PR_Github #55326 [ run ] triggered by Bot. Commit: 9b0dd64 Link to invocation



7 participants