[Feat,Refactor]: Offline Dflash; Spec Mixin; Deprecate parallel draft;#1271
Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.
📝 Walkthrough

Consolidates speculative-decoding logic into dedicated modules, adds HFSpecDecMixin for shared base-model utilities, extracts EAGLE/DFlash draft models into new modeling modules, introduces a standalone HFMedusa plugin, adds DFlash offline/config validators, and updates call sites, imports, and defaults accordingly.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller
    participant Mixin as HFSpecDecMixin
    participant Base as BaseModel
    participant Draft as DraftModule
    participant Head as LMHead
    Caller->>Mixin: forward(input_ids, attention_mask, ...)
    Mixin->>Base: _base_model_forward(input_ids, attention_mask, freeze=...)
    Base-->>Mixin: BaseOutputs (hidden_states, logits, aux)
    Mixin->>Draft: forward(noise_embedding, target_hidden, position_ids, mask)
    Draft-->>Mixin: draft_logits / draft_hidden
    Mixin->>Head: compute final logits (optionally from draft/base)
    Head-->>Mixin: logits
    Mixin-->>Caller: ModelOutput (loss, logits, draft_logits, aux)
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ Passed checks (4 passed)
Actionable comments posted: 5
🧹 Nitpick comments (1)
modelopt/torch/speculative/plugins/modeling_dflash.py (1)
118-124: Consider handling `None` for `_attn_implementation` more explicitly.

The code assumes `config._attn_implementation` is set (per the comment referencing `dflash/default_config.py`), but if it's `None`, `ALL_ATTENTION_FUNCTIONS.get(None, ...)` would still work and fall back to SDPA. However, this could be made more explicit for clarity.

♻️ Optional: More explicit None handling

```diff
 def _get_attn_fn(self):
     """Lazily resolve the HF attention function (default: sdpa)."""
     if self._attn_fn is not None:
         return self._attn_fn
-    impl = self.config._attn_implementation  # default set in dflash/default_config.py
+    impl = self.config._attn_implementation or "sdpa"
     self._attn_fn = ALL_ATTENTION_FUNCTIONS.get(impl, ALL_ATTENTION_FUNCTIONS["sdpa"])
     return self._attn_fn
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/modeling_dflash.py` around lines 118 - 124, The _get_attn_fn method should explicitly handle a None or missing config._attn_implementation instead of relying on dict.get's default; update _get_attn_fn to read impl = self.config._attn_implementation and if impl is None or impl not in ALL_ATTENTION_FUNCTIONS explicitly assign impl = "sdpa" (or the intended default) before setting self._attn_fn via ALL_ATTENTION_FUNCTIONS[impl], so callers of _get_attn_fn and readers of the code see clear, intentional fallback behavior referencing the _get_attn_fn method, config._attn_implementation, and ALL_ATTENTION_FUNCTIONS.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/hf_medusa.py`:
- Line 137: The call in hf_medusa.py is passing an incorrectly named keyword
rcache_position which causes the cache position to be ignored; update the
invocation (around the code that constructs the KV cache / calls the method
using rcache_position) to use the correct keyword name cache_position so the
value is honored (search for rcache_position in the hf_medusa.py plugin and
replace it with cache_position in that function/method call).
- Around line 162-171: The loop mutates labels in-place causing misaligned
targets for later heads; preserve the original labels (e.g., store
original_labels = labels before the for loop) and inside the loop compute a
shifted copy like shifted = original_labels[..., 1 + i :].contiguous() (instead
of reassigning labels), then use loss_logits = medusa_logits[i][:, : -(1 +
i)].contiguous(), loss_labels = shifted.view(-1), and compute loss as before
with loss_fct, medusa_decay_coefficient, and medusa_heads_coefficient to avoid
cumulative shifts across iterations.
In `@modelopt/torch/speculative/plugins/hf_spec_mixin.py`:
- Around line 157-165: The code temporarily mutates the global
torch._dynamo.config.suppress_errors but never restores it; wrap the change
around the compile loop by saving the original value of
torch._dynamo.config.suppress_errors, set it to True before iterating
self._compile_targets and compiling each target (using getattr(self, name),
torch.compile(...), setattr(self, name, ...)), and ensure you restore the
original suppress_errors value in a finally block so the global config is
returned to its prior state regardless of compilation success or exceptions.
- Around line 169-175: HFDFlashModel currently inherits
HFSpecDecMixin.get_dummy_inputs which raises NotImplementedError; add an
override in HFDFlashModel (hf_dflash.py) implementing get_dummy_inputs()
following the same shape/keys pattern used by HFEagleModel (so it returns a dict
of dummy tensors/arrays for the export forward pass rather than raising),
ensuring the method signature matches the base and provides the expected keys
consumed by unified_export_hf.py (called at the export flow around
unified_export_hf.py:386) so speculative/offline export does not fail at
runtime.
In `@modelopt/torch/speculative/plugins/modeling_eagle.py`:
- Around line 159-163: The code always assigns self._input_embeds from
self.layers[0].input_layernorm(inputs_embeds) even when the EAGLE-3 pre-hook
that consumes it is not registered; guard the assignment so it only runs when
the pre-hook will consume the value (e.g., check
self.config.use_aux_hidden_state or the same condition used when registering the
pre-hook) to avoid retaining an unnecessary activation tensor. Update the block
that sets self._input_embeds (referencing inputs_embeds,
self.layers[0].input_layernorm, and self._input_embeds) to perform the
assignment only when use_aux_hidden_state (or the hook-registered flag) is true.
---
Nitpick comments:
In `@modelopt/torch/speculative/plugins/modeling_dflash.py`:
- Around line 118-124: The _get_attn_fn method should explicitly handle a None
or missing config._attn_implementation instead of relying on dict.get's default;
update _get_attn_fn to read impl = self.config._attn_implementation and if impl
is None or impl not in ALL_ATTENTION_FUNCTIONS explicitly assign impl = "sdpa"
(or the intended default) before setting self._attn_fn via
ALL_ATTENTION_FUNCTIONS[impl], so callers of _get_attn_fn and readers of the
code see clear, intentional fallback behavior referencing the _get_attn_fn
method, config._attn_implementation, and ALL_ATTENTION_FUNCTIONS.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 0cc18e20-ee50-4dfc-bcc2-bad5a1a6e6e1
📒 Files selected for processing (16)
- .pre-commit-config.yaml
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/main.py
- examples/speculative_decoding/scripts/ar_validate.py
- modelopt/torch/export/plugins/hf_spec_export.py
- modelopt/torch/speculative/config.py
- modelopt/torch/speculative/dflash/dflash_model.py
- modelopt/torch/speculative/eagle/default_config.py
- modelopt/torch/speculative/plugins/__init__.py
- modelopt/torch/speculative/plugins/hf_dflash.py
- modelopt/torch/speculative/plugins/hf_eagle.py
- modelopt/torch/speculative/plugins/hf_medusa.py
- modelopt/torch/speculative/plugins/hf_spec_mixin.py
- modelopt/torch/speculative/plugins/modeling_dflash.py
- modelopt/torch/speculative/plugins/modeling_eagle.py
- modelopt/torch/speculative/utils.py
💤 Files with no reviewable changes (1)
- modelopt/torch/speculative/eagle/default_config.py
```python
inputs_embeds = inputs_embeds.to(hidden_states.dtype).to(hidden_states.device)
# In EAGLE-3, we save input embeddings to attribute, and use it in first decoder layer by hook function
# Also, we normalize input embeddings and hidden states before concatenating them.
# The default input norm in first layer attn will be disabled.
self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)
```
Only stash _input_embeds when the EAGLE-3 pre-hook will consume it.
These lines run for every config, but the pre-hook is only registered under use_aux_hidden_state. In the normal path this leaves one extra activation tensor hanging off self for no reason, which increases memory pressure and can keep autograd state alive longer than necessary.
♻️ Proposed fix

```diff
 inputs_embeds = inputs_embeds.to(hidden_states.dtype).to(hidden_states.device)
 # In EAGLE-3, we save input embeddings to attribute, and use it in first decoder layer by hook function
 # Also, we normalize input embeddings and hidden states before concatenating them.
 # The default input norm in first layer attn will be disabled.
-self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)
+if self.config.use_aux_hidden_state:
+    self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)
+else:
+    self._input_embeds = None
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelopt/torch/speculative/plugins/modeling_eagle.py` around lines 159 - 163,
The code always assigns self._input_embeds from
self.layers[0].input_layernorm(inputs_embeds) even when the EAGLE-3 pre-hook
that consumes it is not registered; guard the assignment so it only runs when
the pre-hook will consume the value (e.g., check
self.config.use_aux_hidden_state or the same condition used when registering the
pre-hook) to avoid retaining an unnecessary activation tensor. Update the block
that sets self._input_embeds (referencing inputs_embeds,
self.layers[0].input_layernorm, and self._input_embeds) to perform the
assignment only when use_aux_hidden_state (or the hook-registered flag) is true.
Codecov Report

❌ Patch coverage is … Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #1271      +/-   ##
==========================================
- Coverage   75.58%   72.23%    -3.36%
==========================================
  Files         459      463        +4
  Lines       48613    48661       +48
==========================================
- Hits        36745    35149    -1596
- Misses      11868    13512    +1644
```

Flags with carried forward coverage won't be shown. ☔ View full report in Codecov by Sentry.
Force-pushed from 6df234f to f91cf9d (Compare)
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Force-pushed from f91cf9d to f82a9ef (Compare)
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/__init__.py (1)
16-21: ⚠️ Potential issue | 🟡 Minor

Keep the plugin list in this docstring in sync with the exports.

The module now re-exports `hf_dflash` and `hf_medusa` at Lines 33-35, so saying the package currently supports only `hf_eagle` is misleading for readers and generated docs. Either list all three HuggingFace entrypoints here or rephrase this bullet as a non-exhaustive example.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/__init__.py` around lines 16 - 21, The docstring currently lists only hf_eagle but the module also re-exports hf_dflash and hf_medusa; update the top-level docstring in modelopt.torch.speculative.plugins (the triple-quoted module docstring) to either enumerate all three plugins (hf_eagle, hf_dflash, hf_medusa) or change the bullet to indicate it is a non-exhaustive/example list; ensure the text matches the actual exports so generated docs are accurate and consistent with the re-exports of hf_eagle, hf_dflash and hf_medusa.
♻️ Duplicate comments (5)
modelopt/torch/speculative/plugins/hf_medusa.py (2)
113-123: ⚠️ Potential issue | 🔴 Critical

Use the correct Transformers kwarg here: `cache_position`.

`rcache_position` is not a recognized forward argument. On models without a `**kwargs` escape hatch this raises immediately; otherwise the cache position is silently ignored and KV-cache behavior becomes wrong.

🐛 Proposed fix

```diff
-    rcache_position=cache_position,
+    cache_position=cache_position,
```

In transformers 4.56.0, is `rcache_position` a valid `PreTrainedModel.forward` keyword argument, or should callers use `cache_position`?

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_medusa.py` around lines 113 - 123, The forward call is passing an invalid kwarg name rcache_position to self.model when building outputs; change that argument to the correct Transformers kwarg cache_position so the model receives the cache position (i.e., replace the rcache_position keyword with cache_position in the call that produces outputs) to ensure KV-cache behavior is applied correctly in the self.model(...) invocation.
147-156: ⚠️ Potential issue | 🔴 Critical

Don't shift `labels` in-place across heads.

Each iteration reuses the already-shifted tensor, so later heads train against cumulatively truncated targets instead of a fresh `(i + 1)`-token shift from the original labels.

🐛 Proposed fix

```diff
+original_labels = labels
 # Medusa loss
 for i in range(self.medusa_num_heads):
-    labels = labels[..., 1:].contiguous()
+    shifted_labels = original_labels[..., (i + 1) :].contiguous()
     loss_logits = medusa_logits[i][:, : -(1 + i)].contiguous()
     loss_logits = loss_logits.view(-1, loss_logits.shape[-1])
-    loss_labels = labels.view(-1)
+    loss_labels = shifted_labels.view(-1)
     loss += (
         loss_fct(loss_logits, loss_labels)
         * medusa_decay_coefficient**i
         * medusa_heads_coefficient
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_medusa.py` around lines 147 - 156, The loop currently mutates labels in-place so each head uses a cumulatively shifted target; instead, capture the original labels before the loop (e.g., original_labels = labels) and inside the loop compute a per-head shifted view like head_labels = original_labels[..., 1 + i :].contiguous(); use head_labels (viewed into loss_labels) when computing loss with medusa_logits[i][:, :-(1 + i)].contiguous(), leaving the original_labels untouched and preserving correct (i+1)-token shifts per head while keeping loss_fct, medusa_decay_coefficient, and medusa_heads_coefficient usage unchanged.

modelopt/torch/speculative/plugins/modeling_eagle.py (1)
159-163: ⚠️ Potential issue | 🟠 Major

Only stash `_input_embeds` when the EAGLE-3 hook will consume it.

For non-`use_aux_hidden_state` configs, this tensor is never used but remains attached to `self` for the whole forward pass, which increases memory pressure and can retain autograd state longer than necessary.

♻️ Proposed fix

```diff
 inputs_embeds = inputs_embeds.to(hidden_states.dtype).to(hidden_states.device)
 # In EAGLE-3, we save input embeddings to attribute, and use it in first decoder layer by hook function
 # Also, we normalize input embeddings and hidden states before concatenating them.
 # The default input norm in first layer attn will be disabled.
-self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)
+if self.config.use_aux_hidden_state:
+    self._input_embeds = self.layers[0].input_layernorm(inputs_embeds)
+else:
+    self._input_embeds = None
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/modeling_eagle.py` around lines 159 - 163, The code unconditionally saves self._input_embeds which retains memory/autograd even when not used; change the assignment to only stash the normalized input embeddings when the model is configured to use the EAGLE-3 hook (check the flag use_aux_hidden_state or self.config.use_aux_hidden_state), e.g. compute inputs_embeds = self.layers[0].input_layernorm(inputs_embeds) as now but only set self._input_embeds = ... when the hook flag is true (otherwise avoid storing it or explicitly del/release it), referencing the attribute self._input_embeds and the config flag use_aux_hidden_state and the layer method layers[0].input_layernorm to locate the change.

modelopt/torch/speculative/plugins/hf_spec_mixin.py (1)
151-165: ⚠️ Potential issue | 🟠 Major

Restore `torch._dynamo.config.suppress_errors` after compilation.

This flips a process-global Dynamo flag and never restores it, so one compile attempt here can silently change later `torch.compile()` behavior elsewhere in the same process.

♻️ Proposed fix

```diff
 import torch._dynamo

-torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode
-
-for name, kwargs in self._compile_targets:
-    try:
-        setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
-    except Exception:  # noqa: PERF203
-        print(f"Disabling torch.compile for {name} due to compilation error.")
+prev_suppress_errors = torch._dynamo.config.suppress_errors
+torch._dynamo.config.suppress_errors = True  # Allow fallback to eager mode
+try:
+    for name, kwargs in self._compile_targets:
+        try:
+            setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs))
+        except Exception:  # noqa: PERF203
+            print(f"Disabling torch.compile for {name} due to compilation error.")
+finally:
+    torch._dynamo.config.suppress_errors = prev_suppress_errors
```

Verification script:

```bash
#!/bin/bash
set -euo pipefail
sed -n '151,166p' modelopt/torch/speculative/plugins/hf_spec_mixin.py
rg -n 'suppress_errors' modelopt/torch/speculative/plugins -C2
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_spec_mixin.py` around lines 151 - 165, The method _activate_torch_compile temporarily flips the process-global torch._dynamo.config.suppress_errors and never restores it; change it to save the original value, set suppress_errors=True for the compile loop, and restore the original value in a finally block (so even if setattr/torch.compile raises in the for-loop or try/except, torch._dynamo.config.suppress_errors is reset). Update the function around the existing import and for name, kwargs in self._compile_targets loop to wrap the change in a try/finally that restores the saved config flag.

modelopt/torch/speculative/plugins/hf_dflash.py (1)
159-163: ⚠️ Potential issue | 🔴 Critical

Implement `get_dummy_inputs()` before advertising export support.

`HFDFlashModel` now provides `get_exporter()`, but it still inherits `HFSpecDecMixin.get_dummy_inputs()` from `modelopt/torch/speculative/plugins/hf_spec_mixin.py` (Lines 169-171), which raises `NotImplementedError`. The first export path that requests dummy inputs will fail at runtime.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/speculative/plugins/hf_dflash.py` around lines 159 - 163, HFDFlashModel advertises export support via get_exporter() but still inherits HFSpecDecMixin.get_dummy_inputs() which raises NotImplementedError; implement a get_dummy_inputs(self, device=None, dtype=None, batch_size=1, seq_len=1) method on HFDFlashModel that returns the same dummy input structure the exporter expects (or delegate to DFlashExporter.get_dummy_inputs(self, ...)) so export callers won't fail—update the HFDFlashModel class to override get_dummy_inputs and produce tensors/inputs matching DFlashExporter input names/shapes.
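A minimal sketch of such an override, using nested lists as stand-ins for torch tensors so it runs anywhere; the key names and shapes below are assumptions and would need to match whatever the DFlash exporter and `unified_export_hf.py` actually consume:

```python
def get_dummy_inputs_sketch(batch_size=1, seq_len=8, hidden_size=16):
    """Sketch of an HFDFlashModel.get_dummy_inputs override for export.

    The real method would return torch tensors (torch.long ids, model-dtype
    hidden states) on the model's device; plain lists keep this sketch
    dependency-free.
    """

    def zeros(*shape):
        if len(shape) == 1:
            return [0] * shape[0]
        return [zeros(*shape[1:]) for _ in range(shape[0])]

    return {
        "input_ids": zeros(batch_size, seq_len),
        "attention_mask": [[1] * seq_len for _ in range(batch_size)],
        # Hypothetical key: offline DFlash trains from precomputed base-model
        # hidden states, so the exporter presumably needs a dummy for them too.
        "target_hidden_states": zeros(batch_size, seq_len, hidden_size),
    }
```

The point is only the shape of the contract: return a dict of correctly shaped dummies instead of inheriting the `NotImplementedError` stub.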
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 145-154: The code assumes every base model has a .layers
attribute; instead use the parts discovered by _find_base_model_parts() or guard
accesses: obtain base_device from self._base_model_lm_head.weight.device as the
safe default, and only override it with
next(self._base_model.layers[-1].parameters()).device when
hasattr(self._base_model, "layers") and self._base_model.layers is non-empty;
likewise only call self._base_model._modules.pop("layers") if "layers" is
actually present in self._base_model._modules. Ensure you still respect
self.dflash_offline, the dtype/device conversion of self.dflash_module, and the
existing base_device.type != "meta" check.
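The guarded lookup that prompt asks for can be sketched with duck-typed stand-ins; the attribute names (`layers`, `_modules`, `weight.device`) follow the prompt, and the surrounding offline-DFlash logic (`dflash_offline`, dtype conversion, the `meta`-device check) is omitted:

```python
def resolve_base_device(base_model, lm_head):
    """Pick a device for the DFlash module without assuming `.layers` exists.

    The lm_head weight's device is the safe default; the last decoder
    layer's device overrides it only when the base model really has layers.
    """
    device = lm_head.weight.device
    layers = getattr(base_model, "layers", None)
    if layers:  # present and non-empty
        device = next(layers[-1].parameters()).device
    return device


def pop_layers_if_present(base_model):
    """Drop decoder layers to save memory, only if they are registered."""
    if "layers" in getattr(base_model, "_modules", {}):
        base_model._modules.pop("layers")
```

Both helpers degrade to safe no-op-like behavior on models whose decoder stack is not exposed as a `.layers` attribute.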
In `@modelopt/torch/speculative/plugins/hf_medusa.py`:
- Around line 102-103: The parameters medusa_heads_coefficient and
medusa_decay_coefficient are declared as float | None but later used directly in
numeric loss computations; either change their annotations to plain float or
coerce None to a numeric default before the loss loop. Fix by updating the
signature or, if you want to keep None allowed, add normalization at the start
of the function that computes the Medusa loss (e.g., where the loss loop
runs—search for medusa_heads_coefficient/medusa_decay_coefficient usage) and set
medusa_heads_coefficient = 0.2 and medusa_decay_coefficient = 0.8 (or the
desired defaults) if they are None so downstream numeric operations never
receive None.
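That normalization is small enough to sketch directly; the 0.2 / 0.8 defaults come from the prompt above and may differ from what the project actually settles on:

```python
def normalize_medusa_coefficients(heads_coefficient=None, decay_coefficient=None):
    """Coerce None coefficients to numeric defaults before the loss loop.

    Keeping the `float | None` signature at the call site, this guarantees
    the downstream arithmetic (decay**i * heads_coefficient) never sees None.
    """
    if heads_coefficient is None:
        heads_coefficient = 0.2  # assumed default, per the review prompt
    if decay_coefficient is None:
        decay_coefficient = 0.8  # assumed default, per the review prompt
    return heads_coefficient, decay_coefficient
```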
In `@modelopt/torch/speculative/plugins/modeling_dflash.py`:
- Around line 54-65: The build_target_layer_ids function can produce duplicate
indices when num_draft_layers is close to num_target_layers; after computing the
candidate list (currently returned at the end of build_target_layer_ids)
validate that all IDs are unique (e.g., compare len(set(ids)) to
num_draft_layers) and if duplicates exist raise a ValueError rejecting the
config with a clear message (mentioning num_target_layers and num_draft_layers
and advising to reduce num_draft_layers or increase num_target_layers). Keep
this validation inside build_target_layer_ids so callers receive an explicit
error rather than silently feeding the same hidden state multiple times to the
fusion block.
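A sketch of that validation; the evenly spaced selection below is an assumed stand-in for the real layer-selection rule in `modeling_dflash.py`, and only the uniqueness check mirrors the prompt:

```python
def build_target_layer_ids(num_target_layers, num_draft_layers):
    """Pick one target-layer index per draft layer, rejecting duplicates."""
    # Assumed selection rule: spread draft layers evenly over target layers.
    ids = [
        round(i * (num_target_layers - 1) / max(num_draft_layers - 1, 1))
        for i in range(num_draft_layers)
    ]
    # Rounding can collapse two draft layers onto the same target hidden
    # state; reject the config explicitly instead of silently reusing it.
    if len(set(ids)) != num_draft_layers:
        raise ValueError(
            f"Duplicate target layer ids {ids} for "
            f"num_target_layers={num_target_layers}, num_draft_layers={num_draft_layers}; "
            "reduce num_draft_layers or use a target model with more layers."
        )
    return ids
```

Raising inside the builder means callers get an explicit config error rather than feeding the same hidden state twice to the fusion block.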
---
Outside diff comments:
In `@modelopt/torch/speculative/plugins/__init__.py`:
- Around line 16-21: The docstring currently lists only hf_eagle but the module
also re-exports hf_dflash and hf_medusa; update the top-level docstring in
modelopt.torch.speculative.plugins (the triple-quoted module docstring) to
either enumerate all three plugins (hf_eagle, hf_dflash, hf_medusa) or change
the bullet to indicate it is a non-exhaustive/example list; ensure the text
matches the actual exports so generated docs are accurate and consistent with
the re-exports of hf_eagle, hf_dflash and hf_medusa.
---
Duplicate comments:
In `@modelopt/torch/speculative/plugins/hf_dflash.py`:
- Around line 159-163: HFDFlashModel advertises export support via
get_exporter() but still inherits HFSpecDecMixin.get_dummy_inputs() which raises
NotImplementedError; implement a get_dummy_inputs(self, device=None, dtype=None,
batch_size=1, seq_len=1) method on HFDFlashModel that returns the same dummy
input structure the exporter expects (or delegate to
DFlashExporter.get_dummy_inputs(self, ...)) so export callers won't fail—update
the HFDFlashModel class to override get_dummy_inputs and produce tensors/inputs
matching DFlashExporter input names/shapes.
In `@modelopt/torch/speculative/plugins/hf_medusa.py`:
- Around line 113-123: The forward call is passing an invalid kwarg name
rcache_position to self.model when building outputs; change that argument to the
correct Transformers kwarg cache_position so the model receives the cache
position (i.e., replace the rcache_position keyword with cache_position in the
call that produces outputs) to ensure KV-cache behavior is applied correctly in
the self.model(...) invocation.
- Around line 147-156: The loop currently mutates labels in-place so each head
uses a cumulatively shifted target; instead, capture the original labels before
the loop (e.g., original_labels = labels) and inside the loop compute a per-head
shifted view like head_labels = original_labels[..., 1 + i :].contiguous(); use
head_labels (viewed into loss_labels) when computing loss with
medusa_logits[i][:, :-(1 + i)].contiguous(), leaving the original_labels
untouched and preserving correct (i+1)-token shifts per head while keeping
loss_fct, medusa_decay_coefficient, and medusa_heads_coefficient usage
unchanged.
In `@modelopt/torch/speculative/plugins/hf_spec_mixin.py`:
- Around line 151-165: The method _activate_torch_compile temporarily flips the
process-global torch._dynamo.config.suppress_errors and never restores it;
change it to save the original value, set suppress_errors=True for the compile
loop, and restore the original value in a finally block (so even if
setattr/torch.compile raises in the for-loop or try/except,
torch._dynamo.config.suppress_errors is reset). Update the function around the
existing import and for name, kwargs in self._compile_targets loop to wrap the
change in a try/finally that restores the saved config flag.
In `@modelopt/torch/speculative/plugins/modeling_eagle.py`:
- Around line 159-163: The code unconditionally saves self._input_embeds which
retains memory/autograd even when not used; change the assignment to only stash
the normalized input embeddings when the model is configured to use the EAGLE-3
hook (check the flag use_aux_hidden_state or self.config.use_aux_hidden_state),
e.g. compute inputs_embeds = self.layers[0].input_layernorm(inputs_embeds) as
now but only set self._input_embeds = ... when the hook flag is true (otherwise
avoid storing it or explicitly del/release it), referencing the attribute
self._input_embeds and the config flag use_aux_hidden_state and the layer method
layers[0].input_layernorm to locate the change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 1ffffb46-7fb0-4b7d-8ca7-5e7583948d47
📒 Files selected for processing (16)
- .pre-commit-config.yaml
- examples/speculative_decoding/eagle_utils.py
- examples/speculative_decoding/main.py
- examples/speculative_decoding/scripts/ar_validate.py
- modelopt/torch/export/plugins/hf_spec_export.py
- modelopt/torch/speculative/config.py
- modelopt/torch/speculative/dflash/dflash_model.py
- modelopt/torch/speculative/eagle/default_config.py
- modelopt/torch/speculative/plugins/__init__.py
- modelopt/torch/speculative/plugins/hf_dflash.py
- modelopt/torch/speculative/plugins/hf_eagle.py
- modelopt/torch/speculative/plugins/hf_medusa.py
- modelopt/torch/speculative/plugins/hf_spec_mixin.py
- modelopt/torch/speculative/plugins/modeling_dflash.py
- modelopt/torch/speculative/plugins/modeling_eagle.py
- modelopt/torch/speculative/utils.py
💤 Files with no reviewable changes (1)
- modelopt/torch/speculative/eagle/default_config.py
✅ Files skipped from review due to trivial changes (3)
- modelopt/torch/speculative/utils.py
- .pre-commit-config.yaml
- examples/speculative_decoding/eagle_utils.py
🚧 Files skipped from review as they are similar to previous changes (5)
- examples/speculative_decoding/scripts/ar_validate.py
- modelopt/torch/export/plugins/hf_spec_export.py
- modelopt/torch/speculative/dflash/dflash_model.py
- modelopt/torch/speculative/config.py
- modelopt/torch/speculative/plugins/hf_eagle.py
🧹 Nitpick comments (1)
tests/unit/torch/export/test_hf_spec_rope_export.py (1)
68-72: Consider parameterizing these two fallback tests to remove duplication.

Lines 61-72 can be a single `pytest.mark.parametrize` test with different `(rope_type, eagle_export_rope_scaling, expected)` tuples.

♻️ Proposed refactor

```diff
+import pytest
+
+
+@pytest.mark.parametrize(
+    "rope_type,eagle_export_rope_scaling,expected_rope_scaling",
+    [
+        ("llama3", DEFAULT_ROPE_SCALING, {"rope_type": "llama3", "rope_theta": 10000}),
+        ("default", {}, {"rope_type": "default", "rope_theta": 10000}),
+    ],
+)
+def test_rope_scaling_falls_through_to_training_config(
+    rope_type, eagle_export_rope_scaling, expected_rope_scaling
+):
+    """Export override is not applied when conditions are unmet; fallback uses training config."""
+    config = _make_exporter(
+        rope_type=rope_type,
+        eagle_export_rope_scaling=eagle_export_rope_scaling,
+    )._export_config()
+    assert config["rope_scaling"] == expected_rope_scaling
-def test_rope_not_overridden_when_non_default_training_rope():
-    """Export override is not applied when training rope_type is not 'default';
-    rope_scaling falls through to the training config."""
-    config = _make_exporter(rope_type="llama3")._export_config()
-    assert config["rope_scaling"] == {"rope_type": "llama3", "rope_theta": 10000}
-
-
-def test_rope_not_overridden_when_eagle_export_rope_scaling_is_empty():
-    """Export override is not applied when eagle_export_rope_scaling is empty;
-    rope_scaling falls through to the training config."""
-    config = _make_exporter(eagle_export_rope_scaling={})._export_config()
-    assert config["rope_scaling"] == {"rope_type": "default", "rope_theta": 10000}
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/export/test_hf_spec_rope_export.py` around lines 68 - 72, Combine the two duplicate fallback tests into a single parametrized pytest by replacing the separate test functions (e.g., test_rope_not_overridden_when_eagle_export_rope_scaling_is_empty) with one test decorated with pytest.mark.parametrize over tuples of (rope_type, eagle_export_rope_scaling, expected); inside the test call _make_exporter(...)._export_config() with the parametrized eagle_export_rope_scaling and assert config["rope_scaling"] == expected, keeping the original docstring/context as needed.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@tests/unit/torch/export/test_hf_spec_rope_export.py`:
- Around line 68-72: Combine the two duplicate fallback tests into a single
parametrized pytest by replacing the separate test functions (e.g.,
test_rope_not_overridden_when_eagle_export_rope_scaling_is_empty) with one test
decorated with pytest.mark.parametrize over tuples of (rope_type,
eagle_export_rope_scaling, expected); inside the test call
_make_exporter(...)._export_config() with the parametrized
eagle_export_rope_scaling and assert config["rope_scaling"] == expected, keeping
the original docstring/context as needed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 5c5d29d0-b9d3-4a22-8c05-21a28a3ed5e2
📒 Files selected for processing (1)
tests/unit/torch/export/test_hf_spec_rope_export.py
Thanks for working on this! I have a number of concerns with the current state of this PR. Given the volume of issues below, I don't think this is ready for merge — it needs significant rework.

1. Please split this into separate PRs

This PR bundles 5 independent changes: file renames, mixin extraction, offline DFlash, ParallelDraft removal, and config validation changes. These are hard to review together and make rollback difficult. I'd suggest at minimum:

2. The mixin is too broad

The core idea of deduplicating base-model discovery (

Suggestion: keep the mixin slim — just the base-model property discovery. Move compile, NVTX, export, and forward logic back to the individual classes.

3. Breaking changes need a migration path
4. Offline DFlash implementation has multiple bugs
5. No tests for new feature
6. Minor issues
@yeyu-nvidia, Hi Ye, did you look at the code yourself or at least go through the comments by AI? I think over half of the "concerns" (2 and 3, and partially 4) are nonsense, and it's a waste of time to explain them item by item. Let me know if you agree with any of them. There are a few points that make sense to me; I will address them.
@yeyu-nvidia Since the original comment is too long, I put the item-by-item technical feedback here: https://docs.google.com/document/d/1Uc4yj-bmvOA4zMAhh7QQTedn0_7BDwPO1LaiC1AlE7Q/edit?usp=sharing
What does this PR do?
Type of change: New feature, refactoring
- `HFSpecDecMixin`: Extract duplicated base-model discovery, forward pass, NVTX profiling, and `torch.compile` logic from `HFEagleModel`/`HFDFlashModel` into a shared mixin.
- Offline DFlash: Add a `dflash_offline` config flag for training from pre-computed hidden states; deletes base model layers to save memory.
- `ParallelDraft`: Remove `parallel_draft_step`, the `ParallelDraft` module, and all related logic from Eagle.
- Renames: `transformers.py` → `hf_eagle.py`; `HFMedusaModel` → `hf_medusa.py`; `DFlashModule` → `modeling_dflash.py`; `EagleModule` → `modeling_eagle.py`.
- Move `dflash_mask_token_id` auto-detection from `main.py` into `DFlashConfig` Pydantic validators.

Testing
Validated with existing Eagle and DFlash training scripts (online + offline modes).

Before your PR is "Ready for review"
- Make sure you read and follow Contributor guidelines and your commits are signed (`git commit -s -S`).
- Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).
- Removes the `parallel_draft_step` config; renames `transformers.py` → `hf_eagle.py`.
- `CONTRIBUTING.md`: N/A

Additional Information
Breaking changes:
- `modelopt.torch.speculative.plugins.transformers` → `.hf_eagle`
- `parallel_draft_step` / `parallel_draft_heads_num_layers` removed from Eagle config
- `_draft_model_config` → `eagle_config` in export plugin

Summary by CodeRabbit
New Features
Improvements
Chores
Tests