Commit 7c80d85

[1/3][Refactor]: File reorg; deprecate ParallelDraft (#1296)
### What does this PR do?

**Type of change:** refactoring

Part 1 of a 3-PR series splitting #1271:

- **[1/3] this PR**: File reorg + deprecate `ParallelDraft`
- **[2/3] #1295**: Offline DFlash training
- **[3/3] #1297**: Extract `HFSpecDecMixin`

Changes:

- **File reorg**: `transformers.py` → `hf_eagle.py`; extract `HFMedusaModel` → `hf_medusa.py`; extract `EagleModule` / `EagleBaseModelOutput` → `modeling_eagle.py`; extract `DFlashModule` / `DFlashAttention` / `DFlashDecoderLayer` / `build_target_layer_ids` / `apply_rotary_pos_emb` → `modeling_dflash.py`.
- **Deprecate `ParallelDraft`**: remove `parallel_draft_step`, `parallel_draft_heads_num_layers`, and the `ParallelDraft` module from HF Eagle; remove the `EagleMedusaExporter` branch from `HFEagleModel.get_exporter()` (the `EagleMedusaExporter` class itself still lives in `hf_spec_export.py` for Megatron parity).
- **Rename**: `_draft_model_config` → `eagle_config` in export plugin.
- Update imports in `examples/speculative_decoding/` and `modelopt/torch/speculative/utils.py` to follow the module rename.

### Testing

Validated with existing Eagle and DFlash training scripts (re-run after `9ae5302729 revert behavior change`).

### Before your PR is "*Ready for review*"

Make sure you read and follow [Contributor guidelines](https://github.com/NVIDIA/Model-Optimizer/blob/main/CONTRIBUTING.md) and your commits are signed (`git commit -s -S`). Make sure you read and follow the [Security Best Practices](https://github.com/NVIDIA/Model-Optimizer/blob/main/SECURITY.md#security-coding-practices-for-contributors) (e.g. avoiding hardcoded `trust_remote_code=True`, `torch.load(..., weights_only=False)`, `pickle`, etc.).

- Is this change backward compatible?: ❌ — renames `modelopt.torch.speculative.plugins.transformers` → `.hf_eagle`; removes `parallel_draft_step` / `parallel_draft_heads_num_layers` from Eagle config; renames `_draft_model_config` → `eagle_config` in export plugin.
- If you copied code from any other sources or added a new PIP dependency, did you follow guidance in `CONTRIBUTING.md`?: N/A
- Did you write any new necessary tests?: N/A — pure refactor; existing tests updated for the rename. `test_hf_spec_rope_export.py` assertions were also corrected to reflect the actual production path (the old assertions were masked by `MagicMock` not invoking the `_draft_model_config` `@property`).
- Did you update the [Changelog](https://github.com/NVIDIA/Model-Optimizer/blob/main/CHANGELOG.rst)?: ❌

### Additional Information

Breaking changes:

- `modelopt.torch.speculative.plugins.transformers` → `.hf_eagle`
- `parallel_draft_step` / `parallel_draft_heads_num_layers` removed from Eagle config
- `_draft_model_config` → `eagle_config` in export plugin

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **Refactoring**
  * Reorganized speculative-decoding plugins into focused modules, converting the legacy "transformers" entry into a deprecated shim that re-exports the new plugin surface.
  * Consolidated DFlash implementation into a shared modeling component and introduced a dedicated EAGLE decoder module.
* **New Features**
  * Added a Medusa speculative-decoding plugin with configurable heads and combined-loss training behavior.
* **Chores**
  * Updated pre-commit license-hook exclusion and feature-flag wiring.
* **Tests**
  * Updated export tests to expect rope-scaling fallback semantics.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
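For downstream code, the most visible breaking change is the module rename; a minimal before/after import sketch (mirroring the `ar_validate.py` update in this commit):

```python
# Old import path (pre-refactor), kept only as a comment for reference:
# from modelopt.torch.speculative.plugins.transformers import HFARValidation

# New import path after this PR (as used by examples/speculative_decoding/scripts/ar_validate.py):
from modelopt.torch.speculative.plugins.hf_eagle import HFARValidation
```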
1 parent 946639a commit 7c80d85

14 files changed

Lines changed: 1511 additions & 1464 deletions


.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -99,7 +99,7 @@ repos:
 modelopt/torch/quantization/plugins/attention.py|
 modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py|
 modelopt/torch/speculative/eagle/utils.py|
-modelopt/torch/speculative/plugins/transformers.py|
+modelopt/torch/speculative/plugins/hf_medusa.py|
 modelopt/torch/utils/plugins/megatron_mmlu.py|
 examples/chained_optimizations/bert_prune_distill_quantize.py|
 examples/deepseek/quantize_to_nvfp4.py|
```

examples/speculative_decoding/eagle_utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -358,7 +358,7 @@ def patched_templated_attn(*args, **kwargs):
     original_op = args[2]

     # This patch is only enabled for eagle model by context manager, not base model.
-    patch_enbabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH
+    patch_enbabled = modelopt.torch.speculative.plugins.hf_eagle.ENABLE_CP_TTT_PATCH

     if patch_enbabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
         raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")
```

examples/speculative_decoding/scripts/ar_validate.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -27,7 +27,7 @@
 from transformers import AutoTokenizer

 import modelopt.torch.opt as mto
-from modelopt.torch.speculative.plugins.transformers import HFARValidation
+from modelopt.torch.speculative.plugins.hf_eagle import HFARValidation
 from modelopt.torch.speculative.utils import load_vlm_or_llm

 mto.enable_huggingface_checkpointing()
```

modelopt/torch/export/plugins/hf_spec_export.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -171,8 +171,8 @@ def _export_config(self):
         template_config = deepcopy(template_config)

         def _get_config_from_draft_or_base(key: str, model: nn.Module):
-            if getattr(model._draft_model_config, key, None) is not None:
-                return getattr(model._draft_model_config, key)
+            if getattr(model.eagle_config, key, None) is not None:
+                return getattr(model.eagle_config, key)
             elif getattr(model.config, key, None) is not None:
                 return getattr(model.config, key)
             else:
```
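The renamed attribute feeds the exporter's draft-first config lookup. Below is a self-contained, hedged re-expression of that resolution order for illustration only; it is not the module's actual helper, and the final fallback is an assumption since the `else` branch is cut off in the hunk above.

```python
from types import SimpleNamespace


def lookup_draft_first(key, model, default=None):
    """Illustrative re-expression: prefer the EAGLE draft config, then fall back to the base config."""
    # eagle_config was named _draft_model_config before this PR.
    for cfg in (model.eagle_config, model.config):
        value = getattr(cfg, key, None)
        if value is not None:
            return value
    return default  # assumed fallback; the real else branch is not shown above


# Toy model object with a draft config overriding one key:
model = SimpleNamespace(
    eagle_config=SimpleNamespace(max_position_embeddings=4096),
    config=SimpleNamespace(max_position_embeddings=8192, vocab_size=32000),
)
print(lookup_draft_first("max_position_embeddings", model))  # 4096 (draft config wins)
print(lookup_draft_first("vocab_size", model))               # 32000 (falls back to base config)
```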

modelopt/torch/speculative/eagle/default_config.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -37,6 +37,7 @@
     "use_aux_hidden_state": False,
     "eagle_aux_hidden_state_layer_ids": [],
     "use_mtp_layernorm": False,
+    # Deprecated on the HF flow; TODO: remove once the Megatron flow stops reading these.
     "parallel_draft_step": 1,
     "parallel_draft_heads_num_layers": 1,
     "has_lm_head": False,
@@ -107,6 +108,7 @@
     "use_aux_hidden_state": True,
     "eagle_aux_hidden_state_layer_ids": [],
     "use_mtp_layernorm": False,
+    # Deprecated on the HF flow; TODO: remove once the Megatron flow stops reading these.
     "parallel_draft_step": 1,
     "parallel_draft_heads_num_layers": 1,
     "has_lm_head": False,
```

modelopt/torch/speculative/plugins/__init__.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -18,7 +18,7 @@
 Please check out the source code of this module for examples of how plugins work and how you can
 write your own one. Currently, we support plugins for

-- :meth:`transformers<modelopt.torch.speculative.plugins.transformers>`
+- :meth:`hf_eagle<modelopt.torch.speculative.plugins.hf_eagle>`
 """

 from modelopt.torch.utils import import_plugin
@@ -31,4 +31,5 @@

 with import_plugin("transformers"):
     from .hf_dflash import *
-    from .transformers import *
+    from .hf_eagle import *
+    from .hf_medusa import *
```
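A minimal usage sketch of the reorganized plugin surface; the class-to-module mapping follows the commit message, and treating these as the public import points is an assumption:

```python
# The old plugins/transformers.py surface is now split across focused modules.
from modelopt.torch.speculative.plugins.hf_dflash import HFDFlashModel  # DFlash plugin (module name unchanged)
from modelopt.torch.speculative.plugins.hf_eagle import HFEagleModel    # EAGLE plugin (was transformers.py)
from modelopt.torch.speculative.plugins.hf_medusa import HFMedusaModel  # Medusa plugin, extracted in this PR
```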

modelopt/torch/speculative/plugins/hf_dflash.py

Lines changed: 1 addition & 214 deletions
```diff
@@ -54,234 +54,21 @@

 import torch
 import torch.nn.functional as F
-from torch import nn
 from transformers import PreTrainedModel
-from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
 from transformers.models.qwen3.configuration_qwen3 import Qwen3Config as _Qwen3Config
-from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP as _MLP_CLS  # noqa: N814
-from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm as _NORM_CLS  # noqa: N814
-from transformers.models.qwen3.modeling_qwen3 import (
-    Qwen3RotaryEmbedding as _ROTARY_CLS,  # noqa: N814
-)
-from transformers.models.qwen3.modeling_qwen3 import rotate_half as _rotate_half
 from transformers.trainer_pt_utils import LabelSmoother
 from transformers.utils import ModelOutput

 from ..dflash.conversion import DFlashDMRegistry
 from ..dflash.dflash_model import DFlashModel
+from .modeling_dflash import DFlashAttention, DFlashModule, build_target_layer_ids  # noqa: F401
 from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS

 logger = logging.getLogger(__name__)

 __all__ = ["HFDFlashModel"]


-def build_target_layer_ids(num_target_layers, num_draft_layers):
-    """Select layers uniformly from the target model for feature extraction."""
-    if num_target_layers < num_draft_layers:
-        raise ValueError(
-            f"num_target_layers ({num_target_layers}) must be >= num_draft_layers ({num_draft_layers})"
-        )
-    if num_draft_layers == 1:
-        return [num_target_layers // 2]
-    start = min(1, num_target_layers - 1)
-    end = max(start, num_target_layers - 3)
-    span = end - start
-    return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)]
-
-
-def apply_rotary_pos_emb(q, k, cos, sin):
-    """Apply RoPE. Q uses last q_len positions, K uses all positions."""
-    cos = cos.unsqueeze(1)  # [B, 1, seq, dim]
-    sin = sin.unsqueeze(1)
-    q_len = q.size(2)
-    q_embed = (q * cos[:, :, -q_len:, :]) + (_rotate_half(q) * sin[:, :, -q_len:, :])
-    k_embed = (k * cos) + (_rotate_half(k) * sin)
-    return q_embed, k_embed
-
-
-class DFlashAttention(nn.Module):
-    """Attention with KV injection, using HF's attention dispatch."""
-
-    def __init__(self, config, layer_idx):
-        """Initialize DFlash attention with KV injection projections and QK-norm."""
-        super().__init__()
-        self.config = config
-        self.layer_idx = layer_idx
-        self.head_dim = getattr(
-            config, "head_dim", config.hidden_size // config.num_attention_heads
-        )
-        self.num_heads = config.num_attention_heads
-        self.num_kv_heads = config.num_key_value_heads
-        self.num_key_value_groups = self.num_heads // self.num_kv_heads
-        self.scaling = self.head_dim**-0.5
-        self.attention_dropout = getattr(config, "attention_dropout", 0.0)
-        self.is_causal = False
-
-        attn_bias = getattr(config, "attention_bias", False)
-        self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=attn_bias)
-        self.k_proj = nn.Linear(
-            config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias
-        )
-        self.v_proj = nn.Linear(
-            config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias
-        )
-        self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=attn_bias)
-
-        self.q_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps)
-        self.k_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps)
-
-        # Resolve HF attention function
-        self._attn_fn = None
-        # Qwen3 uses sliding window attention on some layers (config.layer_types)
-        if hasattr(config, "layer_types") and hasattr(config, "sliding_window"):
-            is_sliding = config.layer_types[layer_idx] == "sliding_attention"
-            self.sliding_window = config.sliding_window if is_sliding else None
-        else:
-            self.sliding_window = None
-
-    def _get_attn_fn(self):
-        """Lazily resolve the HF attention function (default: sdpa)."""
-        if self._attn_fn is not None:
-            return self._attn_fn
-        impl = self.config._attn_implementation  # default set in dflash/default_config.py
-        self._attn_fn = ALL_ATTENTION_FUNCTIONS.get(impl, ALL_ATTENTION_FUNCTIONS["sdpa"])
-        return self._attn_fn
-
-    def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None):
-        """Forward with KV injection.
-
-        Q is projected from the noise block (draft token embeddings: [anchor, mask, mask, ...]).
-        K and V are projected from the concatenation of target hidden states (context from the
-        base model) and noise block, so the draft can attend to both context and its own block.
-        """
-        bsz, q_len, _ = hidden_states.shape
-        ctx_len = target_hidden.shape[1]
-
-        # Q from noise block only (the draft tokens being predicted), with QK-norm
-        q = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim)
-        q = self.q_norm(q).transpose(1, 2)
-
-        # K from context + noise, with QK-norm
-        k_ctx = self.k_proj(target_hidden)
-        k_noise = self.k_proj(hidden_states)
-        k = torch.cat([k_ctx, k_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim)
-        k = self.k_norm(k).transpose(1, 2)
-
-        # V from context + noise (no norm)
-        v_ctx = self.v_proj(target_hidden)
-        v_noise = self.v_proj(hidden_states)
-        v = (
-            torch.cat([v_ctx, v_noise], dim=1)
-            .view(bsz, ctx_len + q_len, -1, self.head_dim)
-            .transpose(1, 2)
-        )
-
-        # RoPE
-        cos, sin = position_embeddings
-        q, k = apply_rotary_pos_emb(q, k, cos, sin)
-
-        # Use HF's attention dispatch (handles GQA internally)
-        attn_fn = self._get_attn_fn()
-        attn_output, _ = attn_fn(
-            self,
-            q,
-            k,
-            v,
-            attention_mask,
-            dropout=0.0 if not self.training else self.attention_dropout,
-            scaling=self.scaling,
-            sliding_window=self.sliding_window,
-        )
-        attn_output = attn_output.reshape(bsz, q_len, -1)
-        return self.o_proj(attn_output)
-
-
-class DFlashDecoderLayer(nn.Module):
-    """Draft decoder layer with KV injection."""
-
-    def __init__(self, config, layer_idx):
-        """Initialize decoder layer with attention, MLP, and layer norms."""
-        super().__init__()
-        self.self_attn = DFlashAttention(config, layer_idx)
-        self.mlp = _MLP_CLS(config)
-        self.input_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps)
-        self.post_attention_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps)
-
-    def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None):
-        """Forward pass with residual connections."""
-        residual = hidden_states
-        hidden_states = self.input_layernorm(hidden_states)
-        hidden_states = self.self_attn(
-            hidden_states, target_hidden, position_embeddings, attention_mask
-        )
-        hidden_states = residual + hidden_states
-
-        residual = hidden_states
-        hidden_states = self.post_attention_layernorm(hidden_states)
-        hidden_states = self.mlp(hidden_states)
-        hidden_states = residual + hidden_states
-        return hidden_states
-
-
-class DFlashModule(nn.Module):
-    """DFlash draft module using Qwen3 components (MLP, RMSNorm, RotaryEmbedding)."""
-
-    def __init__(self, config):
-        """Initialize DFlash module with feature fusion, decoder layers, and rotary embeddings."""
-        super().__init__()
-        self.config = config
-        self.block_size = config.block_size
-
-        # Feature fusion
-        num_fused_layers = len(config.target_layer_ids)
-        self.fc = nn.Linear(num_fused_layers * config.hidden_size, config.hidden_size, bias=False)
-        self.hidden_norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps)
-
-        # Decoder layers
-        self.layers = nn.ModuleList(
-            [DFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
-        )
-        self.norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps)
-        self._rotary_config = config  # Used by _maybe_init_rotary_emb
-
-        # Explicit weight init is needed because DFlashModule is instantiated via
-        # mtsp.convert() AFTER the base model's post_init() has already run, so HF's
-        # automatic _init_weights walk doesn't reach these new layers.
-        self._init_weights(config)
-
-    def _maybe_init_rotary_emb(self, device=None):
-        """Lazily initialize rotary embeddings on first forward call.
-
-        Same pattern as EAGLE3's _maybe_init_rope. Avoids creating rotary_emb
-        during __init__ (which runs on meta device during from_pretrained),
-        preventing the meta-tensor inv_freq issue on checkpoint resume.
-        """
-        if not hasattr(self, "rotary_emb"):
-            self.rotary_emb = _ROTARY_CLS(config=self._rotary_config, device=device)
-
-    def _init_weights(self, config):
-        """Initialize weights matching HF PreTrainedModel._init_weights."""
-        std = getattr(config, "initializer_range", 0.02)
-        for module in self.modules():
-            if isinstance(module, nn.Linear):
-                nn.init.normal_(module.weight, mean=0.0, std=std)
-                if module.bias is not None:
-                    nn.init.zeros_(module.bias)
-
-    def forward(self, noise_embedding, target_hidden, position_ids, attention_mask=None):
-        """Forward with feature fusion, KV injection, and position embeddings."""
-        hidden_states = noise_embedding
-        target_hidden = self.hidden_norm(self.fc(target_hidden))
-        self._maybe_init_rotary_emb(device=hidden_states.device)
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
-
-        for layer in self.layers:
-            hidden_states = layer(hidden_states, target_hidden, position_embeddings, attention_mask)
-
-        return self.norm(hidden_states)
-
-
 @DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
 class HFDFlashModel(DFlashModel):
     """DFlash Model for HuggingFace transformers."""
```