Feat: Speculative Decoding export with quantization support #913
Conversation
📝 Walkthrough

Adds a class-based speculative decoding export pipeline (SpeculativeDecodingExporter, EagleExporter, EagleMedusaExporter), template configs, model hooks to obtain an exporter, a new public API export_speculative_decoding, and updates example scripts and tests to use the speculative export path.

Sequence Diagram

sequenceDiagram
participant CLI as User/Script
participant HFExport as unified_export_hf.export_quantized
participant Check as has_spec_opt()
participant Model as HFEagleModel
participant Exporter as EagleExporter/EagleMedusaExporter
participant Storage as Filesystem
CLI->>HFExport: request export(model, export_dir, dtype)
HFExport->>Check: has_spec_opt(model)?
alt spec-optimized
Check-->>HFExport: true
HFExport->>Model: model.get_exporter()
Model-->>HFExport: Exporter instance
HFExport->>Exporter: exporter.export(export_dir, dtype)
Exporter->>Exporter: _extract_state_dict(full_state_dict)
Exporter->>Storage: write model.safetensors
Exporter->>Storage: write config.json (and hf_quant_config.json?)
Exporter-->>HFExport: done
else not spec-optimized
Check-->>HFExport: false
HFExport->>HFExport: continue standard HF export flow
end
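The dispatch in the diagram above can be sketched in plain Python. This is a simplified illustration only: `FakeEagleModel`, `FakeExporter`, and the function bodies here are stand-ins, not the real ModelOpt implementations.

```python
# Simplified sketch of the dispatch shown in the sequence diagram.
# All class names and bodies are hypothetical stand-ins.

class FakeEagleModel:
    """Stand-in for a spec-optimized model exposing get_exporter()."""

    _modelopt_state = [("eagle", {})]

    def get_exporter(self):
        return FakeExporter()


class FakeExporter:
    def export(self, export_dir, dtype=None):
        # The real exporters write model.safetensors and config.json here.
        return f"exported drafter to {export_dir}"


def has_spec_opt(model):
    # Heuristic mirroring the PR: any "eagle" entry in the modelopt state.
    modes = getattr(model, "_modelopt_state", None) or []
    return any(mode == "eagle" for mode, _ in modes)


def export_quantized(model, export_dir, dtype=None):
    if has_spec_opt(model):
        # Speculative path: delegate to the model-provided exporter.
        return model.get_exporter().export(export_dir, dtype)
    # Otherwise: continue the standard HF export flow.
    return "standard HF export"


print(export_quantized(FakeEagleModel(), "/tmp/out"))  # exported drafter to /tmp/out
print(export_quantized(object(), "/tmp/out"))          # standard HF export
```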
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

Important: Pre-merge checks failed. Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error)
✅ Passed checks (3 passed)
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:

@@ Coverage Diff @@
## main #913 +/- ##
==========================================
- Coverage 72.12% 72.10% -0.03%
==========================================
Files 209 209
Lines 23628 23628
==========================================
- Hits 17042 17036 -6
- Misses      6586     6592       +6

☔ View full report in Codecov by Sentry.
Force-pushed from d9926e9 to 1b73de3.
Actionable comments posted: 4
🧹 Nitpick comments (1)
modelopt/torch/export/plugins/hf_spec_export.py (1)
Lines 185-214: Validation bypass is documented as temporary.

The `_check_valid_sd = lambda *args, **kwargs: None` on line 194 effectively disables state dict validation for parallel draft exports. The `NOTE: tmp:` comment indicates this is intentional but temporary. Consider tracking this with a TODO or issue reference to ensure validation is properly implemented for parallel draft exports before the feature is considered stable.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 185 - 214, The code currently disables state-dict validation by setting self._check_valid_sd = lambda *args, **kwargs: None in the EagleMedusaExporter __init__, which is marked only as a temporary NOTE; replace this silent bypass with a tracked TODO and a visible reminder: restore validation by implementing proper checks for parallel_draft_step in extract_state_dict and call the original EagleExporter._check_valid_sd (or raise/log a clear warning/error) until full validation is implemented; specifically update the EagleMedusaExporter class to remove the no-op lambda, add a TODO/issue-ID comment referencing the missing validation work, and ensure any call sites (e.g., extract_state_dict) invoke the proper _check_valid_sd behavior so state-dict validation is not permanently skipped.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/llm_ptq/hf_ptq.py`:
- Around line 571-574: The early return after calling has_spec_opt(full_model)
and export_speculative_decoding(full_model, export_dir=export_path) skips the
subsequent tokenizer save and timing/export message; update the
speculative-decoding branch so it either (a) calls the same tokenizer save
routine (e.g., tokenizer.save_pretrained or the existing tokenizer save logic)
and prints the export/timing confirmation before returning, or (b) moves the
return to after those steps, and if skipping is intentional add a concise
comment explaining why; reference has_spec_opt, export_speculative_decoding,
full_model and export_path so the change is applied to the correct branch.
In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 180-182: Fix the typo in the docstring of export_quant_config:
change "hf_quant_coinfig.json" to "hf_quant_config.json" in the docstring for
the function export_quant_config which returns copy(self.hf_quant_config).
- Around line 144-178: In export_config, using copy(template_config) creates
only a shallow copy so nested dicts (e.g., eagle config data) are mutated on
assignment; replace the shallow copy with a deep copy (use copy.deepcopy) when
copying the selected template (referencing template_config,
llama_eagle_template_config, kimik2_eagle_template_config in the export_config
method) so modifications to nested keys do not alter the original imported
templates across multiple calls; ensure the copy module's deepcopy is
imported/used accordingly.
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 994-996: The comment above the state-dict export is incorrect:
change the misleading "Export config.json" comment that precedes the lines using
exporter.extract_state_dict(), drafter_sd, and save_file(...,
"model.safetensors") to accurately describe exporting the model state dict
(e.g., "Export model state dict to model.safetensors"), leaving the actual
config.json export block (using save_file for config.json) unchanged.
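The shallow-copy finding above is easy to reproduce with a toy template. The dict below is illustrative only; the real template configs live in `hf_spec_configs.py`.

```python
from copy import copy, deepcopy

# Toy stand-in for a template config with a nested dict, like the eagle
# template configs this review refers to.
template_config = {"model_type": "llama", "eagle": {"num_layers": 1}}

# Shallow copy: the nested "eagle" dict is shared with the template,
# so mutating the copy silently mutates the module-level template.
shallow = copy(template_config)
shallow["eagle"]["num_layers"] = 4
print(template_config["eagle"]["num_layers"])  # 4 -- template was mutated!

# Reset, then use deepcopy: nested dicts become independent copies.
template_config["eagle"]["num_layers"] = 1
deep = deepcopy(template_config)
deep["eagle"]["num_layers"] = 4
print(template_config["eagle"]["num_layers"])  # 1 -- template untouched
```

This is why the review asks for `copy.deepcopy` in `export_config`: the imported templates are module-level singletons shared across calls.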
---
Nitpick comments:
In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 185-214: The code currently disables state-dict validation by
setting self._check_valid_sd = lambda *args, **kwargs: None in the
EagleMedusaExporter __init__, which is marked only as a temporary NOTE; replace
this silent bypass with a tracked TODO and a visible reminder: restore
validation by implementing proper checks for parallel_draft_step in
extract_state_dict and call the original EagleExporter._check_valid_sd (or
raise/log a clear warning/error) until full validation is implemented;
specifically update the EagleMedusaExporter class to remove the no-op lambda,
add a TODO/issue-ID comment referencing the missing validation work, and ensure
any call sites (e.g., extract_state_dict) invoke the proper _check_valid_sd
behavior so state-dict validation is not permanently skipped.
Does TRTLLM need to be patched to support quantized eagle3?
It is supported now after this PR.
Pull request overview
This PR refactors the speculative decoding export logic into a proper class hierarchy (SpeculativeDecodingExporter, EagleExporter, EagleMedusaExporter) and separates the speculative decoding export path from the standard quantization export path (export_hf_checkpoint). It also adds quantization support for speculative decoding checkpoints via a new export_speculative_decoding public API.
Changes:
- Introduced `EagleExporter`/`EagleMedusaExporter` classes in `hf_spec_export.py` and extracted template configs into `hf_spec_configs.py`
- Added `export_speculative_decoding()` as a new public API and removed the old `spec_opt_only`/`export_spec_ckpt_*` functions from `export_hf_checkpoint`'s control flow
- Updated example scripts and tests to use the new API and updated key name format (`layers.0.*` prefix instead of `midlayer.*`)
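The class hierarchy described above could look roughly like the sketch below. The names mirror the PR (`SpeculativeDecodingExporter`, `EagleExporter`), but the bodies are hypothetical; the real classes in `modelopt/torch/export/plugins/hf_spec_export.py` have different signatures and work on torch state dicts.

```python
# Hedged sketch of a base/derived exporter split: the base class owns the
# export flow (template method), subclasses supply key extraction.

class SpeculativeDecodingExporter:
    """Base class for speculative decoding exporters (illustrative)."""

    def extract_state_dict(self, full_state_dict: dict) -> dict:
        raise NotImplementedError

    def export(self, full_state_dict: dict) -> dict:
        # Template method: extract the drafter weights, then validate.
        sd = self.extract_state_dict(full_state_dict)
        self._check_valid_sd(sd)
        return sd

    def _check_valid_sd(self, export_sd: dict) -> None:
        if not export_sd:
            raise ValueError("Empty draft state dict")


class EagleExporter(SpeculativeDecodingExporter):
    PREFIX = "eagle_module."

    def extract_state_dict(self, full_state_dict: dict) -> dict:
        # Keep only drafter weights, stripping the eagle_module. prefix.
        return {
            k[len(self.PREFIX):]: v
            for k, v in full_state_dict.items()
            if k.startswith(self.PREFIX)
        }


full = {"eagle_module.fc.weight": 1, "lm_head.weight": 2}
print(EagleExporter().export(full))  # {'fc.weight': 1}
```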
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 8 comments.
Show a summary per file
| File | Description |
|---|---|
| modelopt/torch/export/plugins/hf_spec_export.py | Core refactoring: exporter class hierarchy, key format change, validation logic |
| modelopt/torch/export/plugins/hf_spec_configs.py | New file: extracted Eagle template configs for llama and kimi-k2 |
| modelopt/torch/export/unified_export_hf.py | Added export_speculative_decoding public API, removed old spec-decoding early-exit |
| modelopt/torch/speculative/plugins/transformers.py | Added get_exporter() and _draft_model_config to HFEagleModel |
| examples/speculative_decoding/scripts/export_hf_checkpoint.py | Updated to use export_speculative_decoding |
| examples/llm_ptq/hf_ptq.py | Added early-exit for spec-decoding models in export_quantized |
| tests/examples/speculative_decoding/test_eagle.py | Updated required key assertions to match new key format |
Edwardf0t1
left a comment
@h-guo18 Please check Copilot's reviews, I think many of them make sense.
Force-pushed from 24d88c3 to 533f400.
Actionable comments posted: 2
🧹 Nitpick comments (2)
modelopt/torch/export/plugins/hf_spec_export.py (1)
Lines 144-146: Use explicit prefix matching for `eagle_module` key extraction.

Line 144 currently matches any key containing `"eagle_module"`. Using a strict prefix avoids accidental rewrites of unrelated keys.

🔧 Suggested change

-        if "eagle_module" in key:
-            export_key = key.replace("eagle_module.", "")
+        prefix = "eagle_module."
+        if key.startswith(prefix):
+            export_key = key[len(prefix):]
             export_sd[export_key] = full_state_dict[key].clone()

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 144 - 146: The code currently checks if "eagle_module" is anywhere in key, which can inadvertently match unrelated keys; change the condition to require the "eagle_module." prefix (e.g., use key.startswith("eagle_module.")) and derive export_key by stripping only that prefix (e.g., slice off len("eagle_module.") from key) before assigning export_sd[export_key] = full_state_dict[key].clone(); update the block that references key, export_key, full_state_dict, and export_sd accordingly.

tests/examples/speculative_decoding/test_eagle.py (1)

Lines 135-149: Add coverage for `parallel_draft_step > 1` export keys.

This test now validates `.weight` keys for single-layer Eagle. Please add a Medusa/parallel-draft export assertion path as well, since the new exporter remaps `parallel_draft_heads.*` keys.

As per coding guidelines, `tests/**/*.py`: Add tests for any new features or examples using pytest, ensuring coverage check passes.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/examples/speculative_decoding/test_eagle.py` around lines 135 - 149, Extend test_export_hf_checkpoint to also validate Medusa/parallel-draft exports by checking for remapped parallel-draft keys: after loading state_dict (as done now), keep the existing loop over LLAMA_EAGLE_SINGLE_LAYER["required"] but also assert that either (a) there exists at least one key in state_dict that starts with "parallel_draft_heads." (to detect a parallel-draft export) and that those keys end with ".weight", or (b) if you have a constant listing Medusa-required names, iterate that constant and assert f"{required}.weight" in state_dict; update the test_export_hf_checkpoint function to perform this additional branch so parallel_draft_step > 1 exports are covered.
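The substring-vs-prefix distinction flagged above is concrete: a key that merely contains the substring passes an `in` check but not a `startswith` check. The keys below are toy examples chosen to show the difference, not keys from a real checkpoint.

```python
# Why prefix matching matters for eagle_module key extraction.
full_state_dict = {
    "eagle_module.fc.weight": "drafter weight",
    "base.eagle_module_stats.weight": "unrelated key containing the substring",
}

# Substring check (the flagged pattern): matches both keys.
substring_hits = [k for k in full_state_dict if "eagle_module" in k]
print(substring_hits)

# Prefix check (the suggested fix): matches only true drafter keys.
prefix = "eagle_module."
export_sd = {
    k[len(prefix):]: v for k, v in full_state_dict.items() if k.startswith(prefix)
}
print(export_sd)  # {'fc.weight': 'drafter weight'}
```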
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 227-229: The change replaces EagleMedusaExporter._check_valid_sd
with a no-op lambda, which disables schema validation for Medusa exports; revert
this by removing the lambda override and either call the superclass
implementation (e.g., super()._check_valid_sd(*args, **kwargs)) or implement a
proper validation routine that enforces the expected schema and raises
informative errors; ensure the replacement in class EagleMedusaExporter accepts
the same signature and logs/raises validation failures instead of silencing them
so malformed checkpoints are caught early.
- Around line 117-138: The current schema checks using assert (in the block
referencing export_sd, expected_keys_single_layer, and self.num_hidden_layers)
must be replaced with explicit exceptions so they aren't skipped under python
-O; change each assert that verifies presence of f"{key}.weight" and the
unexpected-key check to raise a clear exception (e.g., ValueError or
RuntimeError) with the same message, including the relevant context (the missing
key or unexpected key). Update the checks inside the first loop (assert
f"{key}.weight" in export_sd), the nested loop over range(1,
self.num_hidden_layers) (assert f"{key}.weight".replace(... ) in export_sd), and
the final unexpected-key validation (the assert around re.sub(... ) in
allowed_keys_single_layer) to use explicit raise statements carrying the
identical error text.
---
Nitpick comments:
In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 144-146: The code currently checks if "eagle_module" is anywhere
in key which can inadvertently match unrelated keys; change the condition to
require the "eagle_module." prefix (e.g., use key.startswith("eagle_module."))
and derive export_key by stripping only that prefix (e.g., slice off
len("eagle_module.") from key) before assigning export_sd[export_key] =
full_state_dict[key].clone(); update the block that references key, export_key,
full_state_dict, and export_sd accordingly.
In `@tests/examples/speculative_decoding/test_eagle.py`:
- Around line 135-149: Extend test_export_hf_checkpoint to also validate
Medusa/parallel-draft exports by checking for remapped parallel-draft keys:
after loading state_dict (as done now), keep the existing loop over
LLAMA_EAGLE_SINGLE_LAYER["required"] but also assert that either (a) there
exists at least one key in state_dict that starts with "parallel_draft_heads."
(to detect a parallel-draft export) and that those keys end with ".weight", or
(b) if you have a constant listing Medusa-required names, iterate that constant
and assert f"{required}.weight" in state_dict; update the
test_export_hf_checkpoint function to perform this additional branch so
parallel_draft_step > 1 exports are covered.
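The "guarded wrapper" these prompts describe — run the parent validation and add context instead of silencing it — could look like the sketch below. The classes are stand-ins; the real `EagleExporter._check_valid_sd` validates a full drafter state dict.

```python
class EagleExporter:
    """Stand-in base class holding the real validation logic."""

    def _check_valid_sd(self, export_sd: dict) -> None:
        if "fc.weight" not in export_sd:
            raise KeyError("Missing required key: fc.weight")


class EagleMedusaExporter(EagleExporter):
    def _check_valid_sd(self, export_sd: dict) -> None:
        # Run the parent checks and re-raise with parallel-draft context,
        # rather than replacing validation with a no-op lambda.
        try:
            super()._check_valid_sd(export_sd)
        except Exception as err:
            raise type(err)(f"Medusa/parallel-draft export invalid: {err}") from err
        # TODO(<issue-id>): add parallel_draft_step-specific key checks here.


exporter = EagleMedusaExporter()
exporter._check_valid_sd({"fc.weight": 1})  # passes silently
try:
    exporter._check_valid_sd({})
except KeyError as e:
    print(e)
```

This keeps malformed checkpoints failing at export time while still letting the Medusa subclass extend the schema later.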
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- examples/llm_ptq/hf_ptq.py
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
- modelopt/torch/export/plugins/hf_spec_configs.py
- modelopt/torch/export/plugins/hf_spec_export.py
- modelopt/torch/export/unified_export_hf.py
- modelopt/torch/speculative/plugins/transformers.py
- tests/examples/speculative_decoding/test_eagle.py
🚧 Files skipped from review as they are similar to previous changes (3)
- examples/llm_ptq/hf_ptq.py
- modelopt/torch/export/plugins/hf_spec_configs.py
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
  assert f"{key}.weight" in export_sd, f"Missing required key: {key}.weight"
  for i in range(1, self.num_hidden_layers):
      for key in expected_keys_single_layer["required"] - {
-         "midlayer.hidden_norm.weight",
-         "midlayer.input_layernorm.weight",
-         "norm.weight",
-         "fc.weight",
+         "layers.0.hidden_norm",
+         "layers.0.input_layernorm",
+         "norm",
+         "fc",
      }:
-         assert key.replace("midlayer", f"midlayer.{i}") in state_dict, (
-             f"Missing required key: {key}"
+         assert f"{key}.weight".replace("layers.0", f"layers.{i}") in export_sd, (
+             f"Missing required key: {key}.weight"
          )

- # Check that export sd has no unexpected keys
- allowed_keys_single_layer = (
-     expected_keys_single_layer["required"] | expected_keys_single_layer["optional"]
- )
- if num_hidden_layers == 1:
-     for key in state_dict:
-         assert key in allowed_keys_single_layer, f"Unexpected key: {key}"
- else:
-     for key in state_dict:
-         assert re.sub(r"midlayers\.\d+\.", "", key) in {
-             k.replace("midlayer.", "") for k in allowed_keys_single_layer
-         }, f"Unexpected key: {key}"

- def spec_opt_only(model: nn.Module):
-     """Check if the model have only speculative decoding optimization."""
-     opt_modes = getattr(model, "_modelopt_state", None)
-     return (
-         isinstance(opt_modes, (list, tuple)) and len(opt_modes) == 1 and opt_modes[0][0] == "eagle"
-     )

- def export_spec_ckpt_state_dict(model: nn.Module):
-     """Only return the state dict of the draft model in official format and ignore the base model."""
-     # check the model has only speculative decoding
-     assert spec_opt_only(model), "Not purely eagle model."
-
-     # Rename layers to midlayer
-     if model.eagle_config.num_hidden_layers == 1:
-         model.eagle_module.midlayer = model.eagle_module._modules.pop("layers")[0]
-     else:
-         model.eagle_module.midlayer = model.eagle_module._modules.pop("layers")
-     export_sd = copy(model.eagle_module.state_dict())
-
-     # Use base model's lm head if draft model doesn't have one
-     if "lm_head.weight" not in export_sd:
-         export_sd["lm_head.weight"] = model.state_dict()["lm_head.weight"]
-
-     # Rename parallel draft weights
-     if model.eagle_config.parallel_draft_step > 1:
-         for i in range(model.eagle_config.parallel_draft_step - 1):
-             for j in range(model.eagle_config.parallel_draft_heads_num_layers):
+ # Check that export sd has no unexpected keys
+ # Note that quantized eagle are allowed to have scales
+ allowed_keys_single_layer = (
+     expected_keys_single_layer["required"] | expected_keys_single_layer["optional"]
+ )
+ for key in export_sd:
+     assert (
+         re.sub(r"layers\.\d+\.", "layers.0.", key.rsplit(".", 1)[0])
+         in allowed_keys_single_layer
+     ), f"Unexpected key: {key}"
🧩 Analysis chain
🌐 Web query:
Python assert statement disabled PYTHONOPTIMIZE -O optimization
💡 Result:
- In CPython, `assert` is debug-only: it is compiled roughly as `if __debug__: ... raise AssertionError ...`. When optimization is enabled, the compiler emits no bytecode for `assert` statements. [1]
- Running Python with `-O` ("optimize") will remove `assert` statements and any code conditional on `__debug__`, and sets `__debug__ = False` (it also changes `.pyc` naming to include an `opt-1` tag). [2]
- Running with `-OO` does everything `-O` does and also discards docstrings (producing `opt-2` bytecode). [2][3]
- Setting `PYTHONOPTIMIZE` is the environment-variable equivalent of `-O`: if it is a non-empty string, it is like `-O`; if it is an integer n, it is like passing `-O` n times (so 2 corresponds to `-OO`). [4]

Practical checks:

- At runtime, check whether asserts are disabled via `__debug__` (False under `-O`/`-OO`) or `sys.flags.optimize` (0/1/2). [2]

Implication:

- Don't use `assert` for required validation/business logic/security checks, because it may be compiled out under `-O`/`PYTHONOPTIMIZE`. [1][2]
Sources:
[1] Python docs – assert statement semantics and compilation under -O (reference: simple statements) (docs.python.org)
[2] Python docs – -O / -OO behavior (command line and environment) (docs.python.org)
[3] Python docs – -OO removes __doc__ strings (tutorial: modules) (docs.python.org)
[4] Python docs – PYTHONOPTIMIZE equivalence to -O / repeated -O (command line and environment) (docs.python.org)
Citations:
- 1: https://docs.python.org/fr/3.11/reference/simple_stmts.html?utm_source=openai
- 2: https://docs.python.org/pl/3.6/using/cmdline.html?utm_source=openai
- 3: https://docs.python.org/bn-in/3.13/tutorial/modules.html?utm_source=openai
- 4: https://docs.python.org/uk/3.10/using/cmdline.html?utm_source=openai
🏁 Script executed:

find . -type f -name "hf_spec_export.py" | head -5

Repository: NVIDIA/Model-Optimizer
Length of output: 115

🏁 Script executed:

sed -n '110,145p' ./modelopt/torch/export/plugins/hf_spec_export.py | cat -n

Repository: NVIDIA/Model-Optimizer
Length of output: 1950
Replace assert-based schema checks with explicit exceptions.
Export schema validation at lines 117, 125, and 135 uses assert statements. In optimized Python runs (python -O or PYTHONOPTIMIZE set), these asserts are compiled out entirely, allowing malformed checkpoints to bypass validation. Runtime validation for checkpoint correctness must use explicit exceptions.
🔧 Suggested change
- assert f"{key}.weight" in export_sd, f"Missing required key: {key}.weight"
+ if f"{key}.weight" not in export_sd:
+ raise KeyError(f"Missing required key: {key}.weight")
...
- assert f"{key}.weight".replace("layers.0", f"layers.{i}") in export_sd, (
- f"Missing required key: {key}.weight"
- )
+ if f"{key}.weight".replace("layers.0", f"layers.{i}") not in export_sd:
+ raise KeyError(f"Missing required key: {key}.weight")
...
- assert (
- re.sub(r"layers\.\d+\.", "layers.0.", key.rsplit(".", 1)[0])
- in allowed_keys_single_layer
- ), f"Unexpected key: {key}"
+    if re.sub(r"layers\.\d+\.", "layers.0.", key.rsplit(".", 1)[0]) not in allowed_keys_single_layer:
+        raise KeyError(f"Unexpected key: {key}")

🤖 Prompt for AI Agents
+ raise KeyError(f"Unexpected key: {key}")🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 117 - 138, The
current schema checks using assert (in the block referencing export_sd,
expected_keys_single_layer, and self.num_hidden_layers) must be replaced with
explicit exceptions so they aren't skipped under python -O; change each assert
that verifies presence of f"{key}.weight" and the unexpected-key check to raise
a clear exception (e.g., ValueError or RuntimeError) with the same message,
including the relevant context (the missing key or unexpected key). Update the
checks inside the first loop (assert f"{key}.weight" in export_sd), the nested
loop over range(1, self.num_hidden_layers) (assert f"{key}.weight".replace(... )
in export_sd), and the final unexpected-key validation (the assert around
re.sub(... ) in allowed_keys_single_layer) to use explicit raise statements
carrying the identical error text.
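The behavior this finding relies on can be confirmed directly: under `compile(..., optimize=1)` (what `python -O` uses), the `assert` simply vanishes from the bytecode. This is a standalone illustration, not code from the PR.

```python
# Demonstrates why checkpoint validation should not rely on `assert`:
# under `python -O` (or PYTHONOPTIMIZE=1), assert statements are compiled out.
import dis
import io

src = 'def validate(sd):\n    assert "fc.weight" in sd, "Missing required key"\n'

assert_in_bytecode = {}
for opt in (0, 1):  # 0 = normal run, 1 = equivalent of `python -O`
    ns = {}
    exec(compile(src, "<src>", "exec", optimize=opt), ns)
    buf = io.StringIO()
    dis.dis(ns["validate"], file=buf)
    # The assert lowers to an AssertionError-raising bytecode sequence;
    # under optimize=1 no such bytecode is emitted at all.
    assert_in_bytecode[opt] = "assertion" in buf.getvalue().lower()

print(assert_in_bytecode)  # {0: True, 1: False}

# An explicit raise survives optimization regardless of interpreter flags:
def validate_strict(sd):
    if "fc.weight" not in sd:
        raise KeyError("Missing required key: fc.weight")
```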
# NOTE: tmp: bypassing format check for parallel draft
self._check_valid_sd = lambda *args, **kwargs: None
Avoid disabling validation entirely in EagleMedusaExporter.
Line 228 replaces _check_valid_sd with a no-op lambda. That removes schema guarantees for Medusa exports and can hide malformed checkpoints until deployment.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 227 - 229, The
change replaces EagleMedusaExporter._check_valid_sd with a no-op lambda, which
disables schema validation for Medusa exports; revert this by removing the
lambda override and either call the superclass implementation (e.g.,
super()._check_valid_sd(*args, **kwargs>) or implement a proper validation
routine that enforces the expected schema and raises informative errors; ensure
the replacement in class EagleMedusaExporter accepts the same signature and
logs/raises validation failures instead of silencing them so malformed
checkpoints are caught early.
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Force-pushed from 533f400 to 5371477.
Actionable comments posted: 1
♻️ Duplicate comments (2)
modelopt/torch/export/plugins/hf_spec_export.py (2)
Lines 227-229: ⚠️ Potential issue | 🟠 Major

Do not disable Medusa export validation entirely.

Overriding `_check_valid_sd` with a no-op can hide broken exports until deployment.

🔧 Proposed direction

 class EagleMedusaExporter(EagleExporter):
     def __init__(self, model: nn.Module):
         """Initialize the EagleMedusaExporter."""
         super().__init__(model)
         self.parallel_draft_step = model.eagle_config.parallel_draft_step
         self.parallel_draft_heads_num_layers = model.eagle_config.parallel_draft_heads_num_layers
-        # NOTE: tmp: bypassing format check for parallel draft
-        self._check_valid_sd = lambda *args, **kwargs: None
+        # Keep validation enabled; add/override schema checks for parallel draft keys as needed.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 227 - 229, The no-op override of _check_valid_sd hides Medusa export validation failures; remove the lambda override and either call the original validation or implement a guarded wrapper that runs the real checks and surfaces/logs errors (e.g., call super()._check_valid_sd(...) or the original function stored beforehand), optionally catching exceptions to add context but not swallowing them, so exports still fail or report detailed errors when validation fails in hf_spec_export._check_valid_sd.
Lines 109-138: ⚠️ Potential issue | 🟠 Major

Avoid `assert` for checkpoint schema enforcement.

Required schema validation should not rely on `assert`; optimized Python runs can skip these checks entirely.

🔧 Proposed fix

 def _check_valid_sd(self, export_sd: dict):
     """Check the export state dict is valid, otherwise raise Exception."""
 @@
     for key in expected_keys_single_layer["required"]:
-        assert f"{key}.weight" in export_sd, f"Missing required key: {key}.weight"
+        if f"{key}.weight" not in export_sd:
+            raise KeyError(f"Missing required key: {key}.weight")
 @@
-        assert f"{key}.weight".replace("layers.0", f"layers.{i}") in export_sd, (
-            f"Missing required key: {key}.weight"
-        )
+        if f"{key}.weight".replace("layers.0", f"layers.{i}") not in export_sd:
+            raise KeyError(f"Missing required key: {key}.weight")
 @@
-        assert (
-            re.sub(r"layers\.\d+\.", "layers.0.", key.rsplit(".", 1)[0])
-            in allowed_keys_single_layer
-        ), f"Unexpected key: {key}"
+        if re.sub(r"layers\.\d+\.", "layers.0.", key.rsplit(".", 1)[0]) not in allowed_keys_single_layer:
+            raise KeyError(f"Unexpected key: {key}")

#!/bin/bash
# Verify assert-removal behavior in optimized Python.
python - <<'PY'
import dis
src = 'def validate(x):\n    assert x, "missing"\n    return True\n'
for opt in (0, 1):
    ns = {}
    exec(compile(src, "<src>", "exec", optimize=opt), ns)
    print(f"optimize={opt}")
    dis.dis(ns["validate"])
PY

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/export/plugins/hf_spec_export.py` around lines 109 - 138, The _check_valid_sd method uses assert statements for schema validation which can be stripped in optimized Python; replace each assert with explicit conditional checks that raise a clear exception (e.g., ValueError or KeyError) including the problematic key and context. Specifically, in _check_valid_sd, change the required-key checks (both single-layer and per-layer loop) to if key not in export_sd: raise ValueError(f"Missing required key: {key}.weight") using the same key construction logic, and change the unexpected-key check at the end to: if re.sub(... ) not in allowed_keys_single_layer: raise ValueError(f"Unexpected key: {key}"). Keep the existing variables expected_keys_single_layer, allowed_keys_single_layer, export_sd, and self.num_hidden_layers so the logic location remains identical.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/export/unified_export_hf.py`:
- Around line 1089-1096: The export_speculative_decoding function currently uses
an assert for public API validation which can be skipped under -O flags; replace
the assert has_spec_opt(model) with an explicit check that raises a suitable
exception (e.g., ValueError or TypeError) when the model is not optimized for
speculative decoding, include a clear message like "Model is not optimized for
speculative decoding.", and keep this check at the start of
export_speculative_decoding to ensure deterministic runtime validation.
---
Duplicate comments:
In `@modelopt/torch/export/plugins/hf_spec_export.py`:
- Around line 227-229: The no-op override of _check_valid_sd hides Medusa export
validation failures; remove the lambda override and either call the original
validation or implement a guarded wrapper that runs the real checks and
surfaces/logs errors (e.g., call super()._check_valid_sd(...) or the original
function stored beforehand), optionally catching exceptions to add context but
not swallowing them, so exports still fail or report detailed errors when
validation fails in hf_spec_export._check_valid_sd.
- Around line 109-138: The _check_valid_sd method uses assert statements for
schema validation which can be stripped in optimized Python; replace each assert
with explicit conditional checks that raise a clear exception (e.g., ValueError
or KeyError) including the problematic key and context. Specifically, in
_check_valid_sd, change the required-key checks (both single-layer and per-layer
loop) to if key not in export_sd: raise ValueError(f"Missing required key:
{key}.weight") using the same key construction logic, and change the
unexpected-key check at the end to: if re.sub(... ) not in
allowed_keys_single_layer: raise ValueError(f"Unexpected key: {key}"). Keep the
existing variables expected_keys_single_layer, allowed_keys_single_layer,
export_sd, and self.num_hidden_layers so the logic location remains identical.
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- examples/llm_ptq/hf_ptq.py
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
- modelopt/torch/export/plugins/hf_spec_configs.py
- modelopt/torch/export/plugins/hf_spec_export.py
- modelopt/torch/export/unified_export_hf.py
- modelopt/torch/speculative/plugins/transformers.py
- tests/examples/speculative_decoding/test_eagle.py
🚧 Files skipped from review as they are similar to previous changes (3)
- tests/examples/speculative_decoding/test_eagle.py
- examples/llm_ptq/hf_ptq.py
- examples/speculative_decoding/scripts/export_hf_checkpoint.py
def export_speculative_decoding(
    model: torch.nn.Module,
    dtype: torch.dtype | None = None,
    export_dir: Path | str = tempfile.gettempdir(),
) -> None:
    """Export speculative decoding HuggingFace model checkpoint."""
    assert has_spec_opt(model), "Model is not optimized for speculative decoding."
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# First, verify the concern about assert being compiled out under optimization
python3 - <<'PY'
import dis
src = 'def guard(x):\n assert x, "bad"\n return 1\n'
for opt in (0, 1):
ns = {}
exec(compile(src, "<src>", "exec", optimize=opt), ns)
print(f"optimize={opt}")
dis.dis(ns["guard"])
print()
PY

Repository: NVIDIA/Model-Optimizer
Length of output: 649
🏁 Script executed:
#!/bin/bash
# Verify the actual code in the file
head -1100 modelopt/torch/export/unified_export_hf.py | tail -20

Repository: NVIDIA/Model-Optimizer
Length of output: 717
🏁 Script executed:
#!/bin/bash
# Check if this function is exported in __all__
rg -A 5 -B 5 "export_speculative_decoding" modelopt/torch/export/unified_export_hf.py | head -50

Repository: NVIDIA/Model-Optimizer
Length of output: 727
Replace assert with explicit exception in public API validation.
The guard statement can be compiled out when Python is invoked with -O or -OO flags, making validation non-deterministic. Input validation in public APIs must use explicit if/raise patterns.
🔧 Proposed fix
def export_speculative_decoding(
model: torch.nn.Module,
dtype: torch.dtype | None = None,
export_dir: Path | str = tempfile.gettempdir(),
) -> None:
"""Export speculative decoding HuggingFace model checkpoint."""
- assert has_spec_opt(model), "Model is not optimized for speculative decoding."
+ if not has_spec_opt(model):
+ raise ValueError("Model is not optimized for speculative decoding.")🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/export/unified_export_hf.py` around lines 1089 - 1096, The
export_speculative_decoding function currently uses an assert for public API
validation which can be skipped under -O flags; replace the assert
has_spec_opt(model) with an explicit check that raises a suitable exception
(e.g., ValueError or TypeError) when the model is not optimized for speculative
decoding, include a clear message like "Model is not optimized for speculative
decoding.", and keep this check at the start of export_speculative_decoding to
ensure deterministic runtime validation.
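The if/raise pattern the fix proposes keeps validation active under any interpreter flags. Sketched below with a stand-in `has_spec_opt` and a trivial body; the real function in `unified_export_hf.py` performs the actual checkpoint export.

```python
import tempfile
from pathlib import Path


def has_spec_opt(model) -> bool:
    # Stand-in for the real check on the modelopt speculative-decoding state.
    return getattr(model, "is_spec_optimized", False)


def export_speculative_decoding(model, dtype=None, export_dir=tempfile.gettempdir()):
    """Sketch of the public API with an explicit guard instead of assert."""
    if not has_spec_opt(model):
        raise ValueError("Model is not optimized for speculative decoding.")
    # Real implementation writes the drafter checkpoint under export_dir.
    return Path(export_dir)


class SpecModel:
    is_spec_optimized = True


print(export_speculative_decoding(SpecModel()))
try:
    export_speculative_decoding(object())
except ValueError as e:
    print(e)  # Model is not optimized for speculative decoding.
```

Unlike the original `assert`, this guard still fires when the module runs under `python -O`.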
What does this PR do?
Type of change: ?
Overview:
Main changes:
- Refactored speculative decoding export logic into class `EagleExporter` to improve cohesion;
- Separated the speculative decoding export entry point from the quantization export (`export_hf_checkpoint()`) due to their fundamental differences.

Usage

To export a regular bf16 eagle checkpoint without quantization, the commands are the same:

To run PTQ on an online-trained eagle checkpoint and export it:

The above two commands will produce drafter checkpoints for deployment, in the same format.
Testing
Tested setting:
python scripts/export_hf_checkpoint.py --model_path <x> --export_path <x>
python hf_ptq.py --pyt_ckpt_path <x> --qformat fp8 --export_path <x>

Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Refactor
Tests