Feat: Speculatice Decoding export with quantization support#913

Merged
h-guo18 merged 2 commits into main from haoguo/eagle-export
Mar 4, 2026

Conversation

@h-guo18
Contributor

@h-guo18 h-guo18 commented Feb 21, 2026

What does this PR do?

Type of change: ?

Overview:

Main changes:

  • Refactored the speculative decoding export logic into the EagleExporter class to improve cohesion;

  • Separated the speculative decoding export entry point from quantization export (export_hf_checkpoint()) due to their fundamental differences:

    • Quantization export saves the base model's state_dict and config, while speculative decoding export saves only the drafter's.
    • Most of the model-specific logic in quantization export (e.g. diffusers, VLMs) is not needed for speculative decoding export.
    • Quantization export produces a different format than a speculative decoding checkpoint. (The former produces a tokenizer config, generation config, etc., which the latter does not need.)
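The drafter-only extraction described in the bullets above can be sketched in a few lines. The class names follow the PR, but the prefix handling and state-dict shapes are illustrative assumptions, not the actual implementation:

```python
class SpeculativeDecodingExporter:
    """Base exporter: exports only the drafter, never the base model."""

    def extract_state_dict(self, full_state_dict: dict) -> dict:
        raise NotImplementedError


class EagleExporter(SpeculativeDecodingExporter):
    # Assumed prefix under which drafter weights live in the full model.
    PREFIX = "eagle_module."

    def extract_state_dict(self, full_state_dict: dict) -> dict:
        # Keep only drafter weights and strip the module prefix for deployment.
        return {
            key[len(self.PREFIX):]: value
            for key, value in full_state_dict.items()
            if key.startswith(self.PREFIX)
        }


full_sd = {
    "model.layers.0.self_attn.q_proj.weight": "base",  # base model: dropped
    "eagle_module.fc.weight": "drafter",               # drafter: kept
    "eagle_module.layers.0.input_layernorm.weight": "drafter",
}
drafter_sd = EagleExporter().extract_state_dict(full_sd)
print(sorted(drafter_sd))  # ['fc.weight', 'layers.0.input_layernorm.weight']
```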

Usage

To export a regular bf16 eagle checkpoint without quantization, the command is unchanged:

python scripts/export_hf_checkpoint.py --model_path <x> --export_path <x>

To run PTQ on an online-trained eagle checkpoint and export it:

python hf_ptq.py --pyt_ckpt_path <x> --qformat fp8 --export_path <x>

The above two commands produce a drafter checkpoint for deployment in the same format.

Testing

Tested setting:

  • Base model: llama3.1-8b
  • Algorithms: eagle
  • Export path tested:
    • (Unquantized online ckpt) python scripts/export_hf_checkpoint.py --model_path <x> --export_path <x>
    • (PTQ export) python hf_ptq.py --pyt_ckpt_path <x> --qformat fp8 --export_path <x>
  • Tested deployment on vLLM; observed a normal acceptance rate (AR).

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

  • New Features

    • Added dedicated export flow for speculative-decoding-optimized models, including multi-architecture support and deployment-ready templates.
    • Automatic detection and selection of the appropriate speculative-decoding exporter, including parallel-draft export support.
  • Refactor

    • Consolidated and simplified export logic into a unified, modular speculative-decoding export path.
  • Tests

    • Tightened export validation tests to ensure required weight entries are present.

@copy-pr-bot

copy-pr-bot Bot commented Feb 21, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@coderabbitai
Contributor

coderabbitai Bot commented Feb 21, 2026

📝 Walkthrough


Adds a class-based speculative decoding export pipeline (SpeculativeDecodingExporter, EagleExporter, EagleMedusaExporter), template configs, model hooks to obtain an exporter, a new public API export_speculative_decoding, and updates example scripts and tests to use the speculative export path.

Changes

Cohort / File(s) Summary
Speculative export core
modelopt/torch/export/plugins/hf_spec_export.py, modelopt/torch/export/plugins/hf_spec_configs.py
Introduces SpeculativeDecodingExporter, EagleExporter, EagleMedusaExporter, adds has_spec_opt and has_quant_opt, migrates to layered key schema, safetensors export, and adds llama_eagle_template_config and kimik2_eagle_template_config.
Unified HF export API
modelopt/torch/export/unified_export_hf.py
Adds export_speculative_decoding to public API and all, replaces inline speculative-special-case handling with explicit exporter delegation.
Model integration
modelopt/torch/speculative/plugins/transformers.py
Adds get_exporter() and _draft_model_config to HFEagleModel; selects EagleExporter or EagleMedusaExporter based on draft parallelism.
Examples / scripts
examples/speculative_decoding/scripts/export_hf_checkpoint.py, examples/llm_ptq/hf_ptq.py
Example scripts updated to import and call export_speculative_decoding; hf_ptq.py short-circuits export to speculative path when has_spec_opt(full_model) is true.
Tests
tests/examples/speculative_decoding/test_eagle.py
Adjusted assertion to expect .weight-suffixed keys in exported state_dict entries.

Sequence Diagram

sequenceDiagram
    participant CLI as User/Script
    participant HFExport as unified_export_hf.export_quantized
    participant Check as has_spec_opt()
    participant Model as HFEagleModel
    participant Exporter as EagleExporter/EagleMedusaExporter
    participant Storage as Filesystem

    CLI->>HFExport: request export(model, export_dir, dtype)
    HFExport->>Check: has_spec_opt(model)?
    alt spec-optimized
        Check-->>HFExport: true
        HFExport->>Model: model.get_exporter()
        Model-->>HFExport: Exporter instance
        HFExport->>Exporter: exporter.export(export_dir, dtype)
        Exporter->>Exporter: _extract_state_dict(full_state_dict)
        Exporter->>Storage: write model.safetensors
        Exporter->>Storage: write config.json (and hf_quant_config.json?)
        Exporter-->>HFExport: done
    else not spec-optimized
        Check-->>HFExport: false
        HFExport->>HFExport: continue standard HF export flow
    end
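A minimal sketch of the dispatch shown in the diagram above; the attribute and function names follow the walkthrough, but the mode-state shape is an assumption:

```python
class FakeEagleModel:
    """Stand-in for a speculative-decoding-converted model."""
    _modelopt_state = [("eagle", {})]


def has_spec_opt(model) -> bool:
    # Assumed check: one of the applied modelopt modes is a speculative one.
    modes = getattr(model, "_modelopt_state", None) or []
    return any(name == "eagle" for name, _ in modes)


def choose_export_path(model) -> str:
    # Mirrors the alt branches in the diagram above.
    return "speculative" if has_spec_opt(model) else "standard-hf"


print(choose_export_path(FakeEagleModel()))  # speculative
print(choose_export_path(object()))          # standard-hf
```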

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Important

Pre-merge checks failed

Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error)

  • Security Anti-Patterns (Error): A public API function uses assert for input validation, which can be disabled with Python -O/-OO flags, violating security best practices. Resolution: replace the assert statement with an explicit if/raise pattern to ensure deterministic validation in all deployment scenarios.

✅ Passed checks (3 passed)

  • Description Check (Passed): Check skipped; CodeRabbit's high-level summary is enabled.
  • Title check (Passed): The title mentions 'Speculatice Decoding' (misspelled) and references quantization support, which aligns with the main refactoring of speculative decoding export logic and new quantization support. However, the title is somewhat generic and doesn't capture the core change: refactoring export into the EagleExporter class and separating speculative decoding from quantization export paths.
  • Docstring Coverage (Passed): Docstring coverage is 86.96%, which meets the required 80.00% threshold.

@h-guo18 h-guo18 changed the title Feat: quantized eagle export Feat: Eagle export with quantization support Feb 21, 2026
@codecov

codecov Bot commented Feb 21, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 72.10%. Comparing base (a4fde49) to head (5371477).
⚠️ Report is 1 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #913      +/-   ##
==========================================
- Coverage   72.12%   72.10%   -0.03%     
==========================================
  Files         209      209              
  Lines       23628    23628              
==========================================
- Hits        17042    17036       -6     
- Misses       6586     6592       +6     

☔ View full report in Codecov by Sentry.

@h-guo18 h-guo18 changed the title Feat: Eagle export with quantization support Feat: Speculatice Decoding export with quantization support Feb 21, 2026
@h-guo18 h-guo18 force-pushed the haoguo/eagle-export branch from d9926e9 to 1b73de3 Compare February 21, 2026 18:36
@h-guo18 h-guo18 marked this pull request as ready for review February 21, 2026 18:37
@h-guo18 h-guo18 requested review from a team as code owners February 21, 2026 18:37
@h-guo18 h-guo18 marked this pull request as draft February 21, 2026 18:37
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 4

🧹 Nitpick comments (1)
modelopt/torch/export/plugins/hf_spec_export.py (1)

185-214: Validation bypass is documented as temporary.

The _check_valid_sd = lambda *args, **kwargs: None on line 194 effectively disables state dict validation for parallel draft exports. The NOTE: tmp: comment indicates this is intentional but temporary.

Consider tracking this with a TODO or issue reference to ensure validation is properly implemented for parallel draft exports before the feature is considered stable.
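One way to keep the override without silently dropping validation, as this nitpick suggests, is to warn loudly and track the gap. This is a sketch with stand-in schema checks, not the project's real validation:

```python
import warnings


class EagleExporter:
    def _check_valid_sd(self, export_sd: dict) -> None:
        # Stand-in for the real schema checks.
        if "fc.weight" not in export_sd:
            raise KeyError("Missing required key: fc.weight")


class EagleMedusaExporter(EagleExporter):
    def _check_valid_sd(self, export_sd: dict) -> None:
        # TODO(<issue-id>): implement parallel-draft schema validation.
        # Until then, skip loudly instead of silently no-op'ing.
        warnings.warn(
            "State-dict validation is skipped for parallel draft exports",
            stacklevel=2,
        )


EagleMedusaExporter()._check_valid_sd({})  # warns instead of passing silently
```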

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 185 - 214, The
code currently disables state-dict validation by setting self._check_valid_sd =
lambda *args, **kwargs: None in the EagleMedusaExporter __init__, which is
marked only as a temporary NOTE; replace this silent bypass with a tracked TODO
and a visible reminder: restore validation by implementing proper checks for
parallel_draft_step in extract_state_dict and call the original
EagleExporter._check_valid_sd (or raise/log a clear warning/error) until full
validation is implemented; specifically update the EagleMedusaExporter class to
remove the no-op lambda, add a TODO/issue-ID comment referencing the missing
validation work, and ensure any call sites (e.g., extract_state_dict) invoke the
proper _check_valid_sd behavior so state-dict validation is not permanently
skipped.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 571-574: The early return after calling has_spec_opt(full_model)
and export_speculative_decoding(full_model, export_dir=export_path) skips the
subsequent tokenizer save and timing/export message; update the
speculative-decoding branch so it either (a) calls the same tokenizer save
routine (e.g., tokenizer.save_pretrained or the existing tokenizer save logic)
and prints the export/timing confirmation before returning, or (b) moves the
return to after those steps, and if skipping is intentional add a concise
comment explaining why; reference has_spec_opt, export_speculative_decoding,
full_model and export_path so the change is applied to the correct branch.

In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 180-182: Fix the typo in the docstring of export_quant_config:
change "hf_quant_coinfig.json" to "hf_quant_config.json" in the docstring for
the function export_quant_config which returns copy(self.hf_quant_config).
- Around line 144-178: In export_config, using copy(template_config) creates
only a shallow copy so nested dicts (e.g., eagle config data) are mutated on
assignment; replace the shallow copy with a deep copy (use copy.deepcopy) when
copying the selected template (referencing template_config,
llama_eagle_template_config, kimik2_eagle_template_config in the export_config
method) so modifications to nested keys do not alter the original imported
templates across multiple calls; ensure the copy module's deepcopy is
imported/used accordingly.

In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 994-996: The comment above the state-dict export is incorrect:
change the misleading "Export config.json" comment that precedes the lines using
exporter.extract_state_dict(), drafter_sd, and save_file(...,
"model.safetensors") to accurately describe exporting the model state dict
(e.g., "Export model state dict to model.safetensors"), leaving the actual
config.json export block (using save_file for config.json) unchanged.
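The shallow-copy hazard flagged in the comment above is easy to reproduce; the template contents here are made-up placeholders:

```python
from copy import copy, deepcopy

# Hypothetical template; the nested dict stands in for the eagle config data.
llama_eagle_template_config = {
    "architectures": ["EagleModel"],
    "eagle": {"num_hidden_layers": 1},
}

shallow = copy(llama_eagle_template_config)
shallow["eagle"]["num_hidden_layers"] = 4  # mutates the shared nested dict
print(llama_eagle_template_config["eagle"]["num_hidden_layers"])  # 4: corrupted

llama_eagle_template_config["eagle"]["num_hidden_layers"] = 1  # reset
deep = deepcopy(llama_eagle_template_config)
deep["eagle"]["num_hidden_layers"] = 4
print(llama_eagle_template_config["eagle"]["num_hidden_layers"])  # 1: intact
```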

---

Nitpick comments:
In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 185-214: The code currently disables state-dict validation by
setting self._check_valid_sd = lambda *args, **kwargs: None in the
EagleMedusaExporter __init__, which is marked only as a temporary NOTE; replace
this silent bypass with a tracked TODO and a visible reminder: restore
validation by implementing proper checks for parallel_draft_step in
extract_state_dict and call the original EagleExporter._check_valid_sd (or
raise/log a clear warning/error) until full validation is implemented;
specifically update the EagleMedusaExporter class to remove the no-op lambda,
add a TODO/issue-ID comment referencing the missing validation work, and ensure
any call sites (e.g., extract_state_dict) invoke the proper _check_valid_sd
behavior so state-dict validation is not permanently skipped.

@h-guo18 h-guo18 marked this pull request as ready for review February 21, 2026 19:55
@yeyu-nvidia
Contributor

Does TRTLLM need to be patched to support quantized eagle3?


@h-guo18 h-guo18 requested a review from yeyu-nvidia February 27, 2026 00:58
@h-guo18
Contributor Author

h-guo18 commented Feb 27, 2026

Does TRTLLM need to be patched to support quantized eagle3?

It is supported now after this PR.

Contributor

Copilot AI left a comment


Pull request overview

This PR refactors the speculative decoding export logic into a proper class hierarchy (SpeculativeDecodingExporter, EagleExporter, EagleMedusaExporter) and separates the speculative decoding export path from the standard quantization export path (export_hf_checkpoint). It also adds quantization support for speculative decoding checkpoints via a new export_speculative_decoding public API.

Changes:

  • Introduced EagleExporter/EagleMedusaExporter classes in hf_spec_export.py and extracted template configs into hf_spec_configs.py
  • Added export_speculative_decoding() as a new public API and removed the old spec_opt_only/export_spec_ckpt_* functions from export_hf_checkpoint's control flow
  • Updated example scripts and tests to use the new API and updated key name format (layers.0.* prefix instead of midlayer.*)

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
modelopt/torch/export/plugins/hf_spec_export.py Core refactoring: exporter class hierarchy, key format change, validation logic
modelopt/torch/export/plugins/hf_spec_configs.py New file: extracted Eagle template configs for llama and kimi-k2
modelopt/torch/export/unified_export_hf.py Added export_speculative_decoding public API, removed old spec-decoding early-exit
modelopt/torch/speculative/plugins/transformers.py Added get_exporter() and _draft_model_config to HFEagleModel
examples/speculative_decoding/scripts/export_hf_checkpoint.py Updated to use export_speculative_decoding
examples/llm_ptq/hf_ptq.py Added early-exit for spec-decoding models in export_quantized
tests/examples/speculative_decoding/test_eagle.py Updated required key assertions to match new key format


Contributor

@Edwardf0t1 Edwardf0t1 left a comment


@h-guo18 Please check Copilot's reviews, I think many of them make sense.

@h-guo18 h-guo18 force-pushed the haoguo/eagle-export branch 2 times, most recently from 24d88c3 to 533f400 Compare March 2, 2026 23:48
@h-guo18 h-guo18 requested a review from Edwardf0t1 March 2, 2026 23:51
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
modelopt/torch/export/plugins/hf_spec_export.py (1)

144-146: Use explicit prefix matching for eagle_module key extraction.

Line 144 currently matches any key containing "eagle_module". Using a strict prefix avoids accidental rewrites of unrelated keys.

🔧 Suggested change
-            if "eagle_module" in key:
-                export_key = key.replace("eagle_module.", "")
+            prefix = "eagle_module."
+            if key.startswith(prefix):
+                export_key = key[len(prefix):]
                 export_sd[export_key] = full_state_dict[key].clone()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 144 - 146, The
code currently checks if "eagle_module" is anywhere in key which can
inadvertently match unrelated keys; change the condition to require the
"eagle_module." prefix (e.g., use key.startswith("eagle_module.")) and derive
export_key by stripping only that prefix (e.g., slice off len("eagle_module.")
from key) before assigning export_sd[export_key] = full_state_dict[key].clone();
update the block that references key, export_key, full_state_dict, and export_sd
accordingly.
tests/examples/speculative_decoding/test_eagle.py (1)

135-149: Add coverage for parallel_draft_step > 1 export keys.

This test now validates .weight keys for single-layer Eagle. Please add a Medusa/parallel-draft export assertion path as well, since the new exporter remaps parallel_draft_heads.* keys.

As per coding guidelines, tests/**/*.py: Add tests for any new features or examples using pytest, ensuring coverage check passes.
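A sketch of the suggested extra branch; the helper name and the parallel_draft_heads key shape are assumptions drawn from the surrounding discussion:

```python
def check_export_keys(state_dict: dict, required: set) -> None:
    # Existing single-layer check: every required name has a .weight entry.
    for name in required:
        assert f"{name}.weight" in state_dict, f"Missing: {name}.weight"
    # Parallel-draft (Medusa-style) exports remap extra heads under this
    # prefix; require those keys to be .weight entries as well.
    for key in state_dict:
        if key.startswith("parallel_draft_heads."):
            assert key.endswith(".weight"), f"Unexpected parallel-draft key: {key}"


check_export_keys(
    {"fc.weight": 0, "parallel_draft_heads.0.fc.weight": 0},
    required={"fc"},
)
print("export keys ok")
```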

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/examples/speculative_decoding/test_eagle.py` around lines 135 - 149,
Extend test_export_hf_checkpoint to also validate Medusa/parallel-draft exports
by checking for remapped parallel-draft keys: after loading state_dict (as done
now), keep the existing loop over LLAMA_EAGLE_SINGLE_LAYER["required"] but also
assert that either (a) there exists at least one key in state_dict that starts
with "parallel_draft_heads." (to detect a parallel-draft export) and that those
keys end with ".weight", or (b) if you have a constant listing Medusa-required
names, iterate that constant and assert f"{required}.weight" in state_dict;
update the test_export_hf_checkpoint function to perform this additional branch
so parallel_draft_step > 1 exports are covered.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 227-229: The change replaces EagleMedusaExporter._check_valid_sd
with a no-op lambda, which disables schema validation for Medusa exports; revert
this by removing the lambda override and either call the superclass
implementation (e.g., super()._check_valid_sd(*args, **kwargs>) or implement a
proper validation routine that enforces the expected schema and raises
informative errors; ensure the replacement in class EagleMedusaExporter accepts
the same signature and logs/raises validation failures instead of silencing them
so malformed checkpoints are caught early.
- Around line 117-138: The current schema checks using assert (in the block
referencing export_sd, expected_keys_single_layer, and self.num_hidden_layers)
must be replaced with explicit exceptions so they aren't skipped under python
-O; change each assert that verifies presence of f"{key}.weight" and the
unexpected-key check to raise a clear exception (e.g., ValueError or
RuntimeError) with the same message, including the relevant context (the missing
key or unexpected key). Update the checks inside the first loop (assert
f"{key}.weight" in export_sd), the nested loop over range(1,
self.num_hidden_layers) (assert f"{key}.weight".replace(... ) in export_sd), and
the final unexpected-key validation (the assert around re.sub(... ) in
allowed_keys_single_layer) to use explicit raise statements carrying the
identical error text.

---

Nitpick comments:
In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 144-146: The code currently checks if "eagle_module" is anywhere
in key which can inadvertently match unrelated keys; change the condition to
require the "eagle_module." prefix (e.g., use key.startswith("eagle_module."))
and derive export_key by stripping only that prefix (e.g., slice off
len("eagle_module.") from key) before assigning export_sd[export_key] =
full_state_dict[key].clone(); update the block that references key, export_key,
full_state_dict, and export_sd accordingly.

In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 135-149: Extend test_export_hf_checkpoint to also validate
Medusa/parallel-draft exports by checking for remapped parallel-draft keys:
after loading state_dict (as done now), keep the existing loop over
LLAMA_EAGLE_SINGLE_LAYER["required"] but also assert that either (a) there
exists at least one key in state_dict that starts with "parallel_draft_heads."
(to detect a parallel-draft export) and that those keys end with ".weight", or
(b) if you have a constant listing Medusa-required names, iterate that constant
and assert f"{required}.weight" in state_dict; update the
test_export_hf_checkpoint function to perform this additional branch so
parallel_draft_step > 1 exports are covered.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1b73de3 and 533f400.

📒 Files selected for processing (7)
  • examples/llm_ptq/hf_ptq.py
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py
  • modelopt/torch/export/plugins/hf_spec_configs.py
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/speculative/plugins/transformers.py
  • tests/examples/speculative_decoding/test_eagle.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • examples/llm_ptq/hf_ptq.py
  • modelopt/torch/export/plugins/hf_spec_configs.py
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py

Comment on lines +117 to +138
assert f"{key}.weight" in export_sd, f"Missing required key: {key}.weight"
for i in range(1, self.num_hidden_layers):
for key in expected_keys_single_layer["required"] - {
"midlayer.hidden_norm.weight",
"midlayer.input_layernorm.weight",
"norm.weight",
"fc.weight",
"layers.0.hidden_norm",
"layers.0.input_layernorm",
"norm",
"fc",
}:
assert key.replace("midlayer", f"midlayer.{i}") in state_dict, (
f"Missing required key: {key}"
assert f"{key}.weight".replace("layers.0", f"layers.{i}") in export_sd, (
f"Missing required key: {key}.weight"
)

# Check that export sd has no unexpected keys
allowed_keys_single_layer = (
expected_keys_single_layer["required"] | expected_keys_single_layer["optional"]
)
if num_hidden_layers == 1:
for key in state_dict:
assert key in allowed_keys_single_layer, f"Unexpected key: {key}"
else:
for key in state_dict:
assert re.sub(r"midlayers\.\d+\.", "", key) in {
k.replace("midlayer.", "") for k in allowed_keys_single_layer
}, f"Unexpected key: {key}"


def spec_opt_only(model: nn.Module):
"""Check if the model have only speculative decoding optimization."""
opt_modes = getattr(model, "_modelopt_state", None)
return (
isinstance(opt_modes, (list, tuple)) and len(opt_modes) == 1 and opt_modes[0][0] == "eagle"
)


def export_spec_ckpt_state_dict(model: nn.Module):
"""Only return the state dict of the draft model in official format and ignore the base model."""
# check the model has only speculative decoding
assert spec_opt_only(model), "Not purely eagle model."

# Rename layers to midlayer
if model.eagle_config.num_hidden_layers == 1:
model.eagle_module.midlayer = model.eagle_module._modules.pop("layers")[0]
else:
model.eagle_module.midlayer = model.eagle_module._modules.pop("layers")
export_sd = copy(model.eagle_module.state_dict())

# Use base model's lm head if draft model doesn't have one
if "lm_head.weight" not in export_sd:
export_sd["lm_head.weight"] = model.state_dict()["lm_head.weight"]

# Rename parallel draft weights
if model.eagle_config.parallel_draft_step > 1:
for i in range(model.eagle_config.parallel_draft_step - 1):
for j in range(model.eagle_config.parallel_draft_heads_num_layers):
# Check that export sd has no unexpected keys
# Note that quantized eagle are allowed to have scales
allowed_keys_single_layer = (
expected_keys_single_layer["required"] | expected_keys_single_layer["optional"]
)
for key in export_sd:
assert (
re.sub(r"layers\.\d+\.", "layers.0.", key.rsplit(".", 1)[0])
in allowed_keys_single_layer
), f"Unexpected key: {key}"
Contributor


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🌐 Web query:

Python assert statement disabled PYTHONOPTIMIZE -O optimization

💡 Result:

  • In CPython, assert is debug-only: it’s compiled roughly as if __debug__: ... raise AssertionError .... When optimization is enabled, the compiler emits no bytecode for assert statements. [1]
  • Running Python with -O (“optimize”) will remove assert statements and any code conditional on __debug__, and sets __debug__ = False (also changes .pyc naming to include an opt-1 tag). [2]
  • Running with -OO does everything -O does and also discards docstrings (producing opt-2 bytecode). [2][3]
  • Setting PYTHONOPTIMIZE is the environment-variable equivalent of -O: if it’s a non-empty string, it’s like -O; if it’s an integer n, it’s like passing -O n times (so 2 corresponds to -OO). [4]

Practical checks:

  • At runtime, check whether asserts are disabled via __debug__ (False under -O/-OO) or sys.flags.optimize (0/1/2). [2]

Implication:

  • Don’t use assert for required validation/business logic/security checks, because it may be compiled out under -O/PYTHONOPTIMIZE. [1][2]

Sources:
[1] Python docs – assert statement semantics and compilation under -O (reference: simple statements) (docs.python.org)
[2] Python docs – -O / -OO behavior (command line and environment) (docs.python.org)
[3] Python docs – -OO removes __doc__ strings (tutorial: modules) (docs.python.org)
[4] Python docs – PYTHONOPTIMIZE equivalence to -O / repeated -O (command line and environment) (docs.python.org)


🏁 Script executed:

find . -type f -name "hf_spec_export.py" | head -5

Repository: NVIDIA/Model-Optimizer

Length of output: 115


🏁 Script executed:

sed -n '110,145p' ./modelopt/torch/export/plugins/hf_spec_export.py | cat -n

Repository: NVIDIA/Model-Optimizer

Length of output: 1950


Replace assert-based schema checks with explicit exceptions.

Export schema validation at lines 117, 125, and 135 uses assert statements. In optimized Python runs (python -O or PYTHONOPTIMIZE set), these asserts are compiled out entirely, allowing malformed checkpoints to bypass validation. Runtime validation for checkpoint correctness must use explicit exceptions.

🔧 Suggested change
-            assert f"{key}.weight" in export_sd, f"Missing required key: {key}.weight"
+            if f"{key}.weight" not in export_sd:
+                raise KeyError(f"Missing required key: {key}.weight")
...
-                assert f"{key}.weight".replace("layers.0", f"layers.{i}") in export_sd, (
-                    f"Missing required key: {key}.weight"
-                )
+                if f"{key}.weight".replace("layers.0", f"layers.{i}") not in export_sd:
+                    raise KeyError(f"Missing required key: {key}.weight")
...
-            assert (
-                re.sub(r"layers\.\d+\.", "layers.0.", key.rsplit(".", 1)[0])
-                in allowed_keys_single_layer
-            ), f"Unexpected key: {key}"
+            if re.sub(r"layers\.\d+\.", "layers.0.", key.rsplit(".", 1)[0]) not in allowed_keys_single_layer:
+                raise KeyError(f"Unexpected key: {key}")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 117 - 138, The
current schema checks using assert (in the block referencing export_sd,
expected_keys_single_layer, and self.num_hidden_layers) must be replaced with
explicit exceptions so they aren't skipped under python -O; change each assert
that verifies presence of f"{key}.weight" and the unexpected-key check to raise
a clear exception (e.g., ValueError or RuntimeError) with the same message,
including the relevant context (the missing key or unexpected key). Update the
checks inside the first loop (assert f"{key}.weight" in export_sd), the nested
loop over range(1, self.num_hidden_layers) (assert f"{key}.weight".replace(... )
in export_sd), and the final unexpected-key validation (the assert around
re.sub(... ) in allowed_keys_single_layer) to use explicit raise statements
carrying the identical error text.

Comment on lines +227 to +229
# NOTE: tmp: bypassing format check for parallel draft
self._check_valid_sd = lambda *args, **kwargs: None

Contributor


⚠️ Potential issue | 🟠 Major

Avoid disabling validation entirely in EagleMedusaExporter.

Line 228 replaces _check_valid_sd with a no-op lambda. That removes schema guarantees for Medusa exports and can hide malformed checkpoints until deployment.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 227 - 229, The
change replaces EagleMedusaExporter._check_valid_sd with a no-op lambda, which
disables schema validation for Medusa exports; revert this by removing the
lambda override and either call the superclass implementation (e.g.,
super()._check_valid_sd(*args, **kwargs>) or implement a proper validation
routine that enforces the expected schema and raises informative errors; ensure
the replacement in class EagleMedusaExporter accepts the same signature and
logs/raises validation failures instead of silencing them so malformed
checkpoints are caught early.
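
The guarded-wrapper alternative the prompt suggests can be sketched like this. The class names and the `"drafter.weight"` schema key are hypothetical stand-ins; the point is the shape of the override, which runs the parent's real checks and adds context instead of silencing them:

```python
class BaseExporter:
    def _check_valid_sd(self, export_sd: dict) -> None:
        # Stand-in for the real schema checks in the base exporter.
        if "drafter.weight" not in export_sd:
            raise KeyError("Missing required key: drafter.weight")

class MedusaExporter(BaseExporter):
    def _check_valid_sd(self, export_sd: dict) -> None:
        # Keep validation enabled: delegate to the parent and chain any failure
        # with Medusa-specific context rather than replacing it with a no-op.
        try:
            super()._check_valid_sd(export_sd)
        except Exception as err:
            raise RuntimeError("Medusa export failed schema validation") from err
```

The `raise ... from err` chaining preserves the original traceback, so the deployment-time symptom ("malformed checkpoint") is reported at export time with its root cause attached.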

@Edwardf0t1 Edwardf0t1 left a comment
The reason will be displayed to describe this comment to others. Learn more.

LGTM.

h-guo18 added 2 commits March 3, 2026 23:10
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
@h-guo18 h-guo18 force-pushed the haoguo/eagle-export branch from 533f400 to 5371477 Compare March 3, 2026 23:11
@h-guo18 h-guo18 enabled auto-merge (squash) March 3, 2026 23:12
@coderabbitai coderabbitai Bot left a comment
The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (2)
modelopt/torch/export/plugins/hf_spec_export.py (2)

227-229: ⚠️ Potential issue | 🟠 Major

Do not disable Medusa export validation entirely.

Overriding _check_valid_sd with a no-op can hide broken exports until deployment.

🔧 Proposed direction
 class EagleMedusaExporter(EagleExporter):
@@
     def __init__(self, model: nn.Module):
         """Initialize the EagleMedusaExporter."""
         super().__init__(model)
         self.parallel_draft_step = model.eagle_config.parallel_draft_step
         self.parallel_draft_heads_num_layers = model.eagle_config.parallel_draft_heads_num_layers
-        # NOTE: tmp: bypassing format check for parallel draft
-        self._check_valid_sd = lambda *args, **kwargs: None
+        # Keep validation enabled; add/override schema checks for parallel draft keys as needed.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 227-229: the
no-op override of _check_valid_sd hides Medusa export validation failures;
remove the lambda override and either call the original validation or implement
a guarded wrapper that runs the real checks and surfaces/logs errors (e.g., call
super()._check_valid_sd(...) or the original function stored beforehand),
optionally catching exceptions to add context but not swallowing them, so
exports still fail or report detailed errors when validation fails in
hf_spec_export._check_valid_sd.

109-138: ⚠️ Potential issue | 🟠 Major

Avoid assert for checkpoint schema enforcement.

Required schema validation should not rely on assert; optimized Python runs can skip these checks entirely.

🔧 Proposed fix
     def _check_valid_sd(self, export_sd: dict):
         """Check the export state dict is valid, otherwise raise Exception."""
@@
         for key in expected_keys_single_layer["required"]:
-            assert f"{key}.weight" in export_sd, f"Missing required key: {key}.weight"
+            if f"{key}.weight" not in export_sd:
+                raise KeyError(f"Missing required key: {key}.weight")
@@
-                assert f"{key}.weight".replace("layers.0", f"layers.{i}") in export_sd, (
-                    f"Missing required key: {key}.weight"
-                )
+                if f"{key}.weight".replace("layers.0", f"layers.{i}") not in export_sd:
+                    raise KeyError(f"Missing required key: {key}.weight")
@@
-            assert (
-                re.sub(r"layers\.\d+\.", "layers.0.", key.rsplit(".", 1)[0])
-                in allowed_keys_single_layer
-            ), f"Unexpected key: {key}"
+            if re.sub(r"layers\.\d+\.", "layers.0.", key.rsplit(".", 1)[0]) not in allowed_keys_single_layer:
+                raise KeyError(f"Unexpected key: {key}")
#!/bin/bash
# Verify assert-removal behavior in optimized Python.
python - <<'PY'
import dis
src = 'def validate(x):\n    assert x, "missing"\n    return True\n'
for opt in (0, 1):
    ns = {}
    exec(compile(src, "<src>", "exec", optimize=opt), ns)
    print(f"optimize={opt}")
    dis.dis(ns["validate"])
PY
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 109-138: the
_check_valid_sd method uses assert statements for schema validation which can be
stripped in optimized Python; replace each assert with explicit conditional
checks that raise a clear exception (e.g., ValueError or KeyError) including the
problematic key and context. Specifically, in _check_valid_sd, change the
required-key checks (both single-layer and per-layer loop) to if key not in
export_sd: raise ValueError(f"Missing required key: {key}.weight") using the
same key construction logic, and change the unexpected-key check at the end to:
if re.sub(... ) not in allowed_keys_single_layer: raise ValueError(f"Unexpected
key: {key}"). Keep the existing variables expected_keys_single_layer,
allowed_keys_single_layer, export_sd, and self.num_hidden_layers so the logic
location remains identical.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 1089-1096: The export_speculative_decoding function currently uses
an assert for public API validation which can be skipped under -O flags; replace
the assert has_spec_opt(model) with an explicit check that raises a suitable
exception (e.g., ValueError or TypeError) when the model is not optimized for
speculative decoding, include a clear message like "Model is not optimized for
speculative decoding.", and keep this check at the start of
export_speculative_decoding to ensure deterministic runtime validation.

---

Duplicate comments:
In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 227-229: The no-op override of _check_valid_sd hides Medusa export
validation failures; remove the lambda override and either call the original
validation or implement a guarded wrapper that runs the real checks and
surfaces/logs errors (e.g., call super()._check_valid_sd(...) or the original
function stored beforehand), optionally catching exceptions to add context but
not swallowing them, so exports still fail or report detailed errors when
validation fails in hf_spec_export._check_valid_sd.
- Around line 109-138: The _check_valid_sd method uses assert statements for
schema validation which can be stripped in optimized Python; replace each assert
with explicit conditional checks that raise a clear exception (e.g., ValueError
or KeyError) including the problematic key and context. Specifically, in
_check_valid_sd, change the required-key checks (both single-layer and per-layer
loop) to if key not in export_sd: raise ValueError(f"Missing required key:
{key}.weight") using the same key construction logic, and change the
unexpected-key check at the end to: if re.sub(... ) not in
allowed_keys_single_layer: raise ValueError(f"Unexpected key: {key}"). Keep the
existing variables expected_keys_single_layer, allowed_keys_single_layer,
export_sd, and self.num_hidden_layers so the logic location remains identical.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 533f400 and 5371477.

📒 Files selected for processing (7)
  • examples/llm_ptq/hf_ptq.py
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py
  • modelopt/torch/export/plugins/hf_spec_configs.py
  • modelopt/torch/export/plugins/hf_spec_export.py
  • modelopt/torch/export/unified_export_hf.py
  • modelopt/torch/speculative/plugins/transformers.py
  • tests/examples/speculative_decoding/test_eagle.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • tests/examples/speculative_decoding/test_eagle.py
  • examples/llm_ptq/hf_ptq.py
  • examples/speculative_decoding/scripts/export_hf_checkpoint.py

Comment on lines +1089 to +1096
def export_speculative_decoding(
model: torch.nn.Module,
dtype: torch.dtype | None = None,
export_dir: Path | str = tempfile.gettempdir(),
) -> None:
"""Export speculative decoding HuggingFace model checkpoint."""
assert has_spec_opt(model), "Model is not optimized for speculative decoding."


⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# First, verify the concern about assert being compiled out under optimization
python3 - <<'PY'
import dis
src = 'def guard(x):\n    assert x, "bad"\n    return 1\n'
for opt in (0, 1):
    ns = {}
    exec(compile(src, "<src>", "exec", optimize=opt), ns)
    print(f"optimize={opt}")
    dis.dis(ns["guard"])
    print()
PY

Repository: NVIDIA/Model-Optimizer



🏁 Script executed:

#!/bin/bash
# Verify the actual code in the file
head -1100 modelopt/torch/export/unified_export_hf.py | tail -20



🏁 Script executed:

#!/bin/bash
# Check if this function is exported in __all__
rg -A 5 -B 5 "export_speculative_decoding" modelopt/torch/export/unified_export_hf.py | head -50



Replace assert with explicit exception in public API validation.

The guard statement can be compiled out when Python is invoked with -O or -OO flags, making validation non-deterministic. Input validation in public APIs must use explicit if/raise patterns.

🔧 Proposed fix
 def export_speculative_decoding(
     model: torch.nn.Module,
     dtype: torch.dtype | None = None,
     export_dir: Path | str = tempfile.gettempdir(),
 ) -> None:
     """Export speculative decoding HuggingFace model checkpoint."""
-    assert has_spec_opt(model), "Model is not optimized for speculative decoding."
+    if not has_spec_opt(model):
+        raise ValueError("Model is not optimized for speculative decoding.")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@modelopt/torch/export/unified_export_hf.py` around lines 1089-1096: the
export_speculative_decoding function currently uses an assert for public API
validation which can be skipped under -O flags; replace the assert
has_spec_opt(model) with an explicit check that raises a suitable exception
(e.g., ValueError or TypeError) when the model is not optimized for speculative
decoding, include a clear message like "Model is not optimized for speculative
decoding.", and keep this check at the start of export_speculative_decoding to
ensure deterministic runtime validation.
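
The guard described above reduces to a one-line explicit check. In this sketch the boolean parameter is a stand-in for the real `has_spec_opt(model)` call so the snippet stays self-contained:

```python
def require_spec_opt(model_is_spec_optimized: bool) -> None:
    """Public-API guard sketch: an explicit raise, unlike `assert`, is not
    compiled out under `python -O`, so the check is deterministic."""
    if not model_is_spec_optimized:
        raise ValueError("Model is not optimized for speculative decoding.")
```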

@h-guo18 h-guo18 merged commit a34d613 into main Mar 4, 2026
64 of 68 checks passed
@h-guo18 h-guo18 deleted the haoguo/eagle-export branch March 4, 2026 01:46
