Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ repos:
modelopt/torch/quantization/plugins/attention.py|
modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py|
modelopt/torch/speculative/eagle/utils.py|
modelopt/torch/speculative/plugins/transformers.py|
modelopt/torch/speculative/plugins/hf_medusa.py|
modelopt/torch/utils/plugins/megatron_mmlu.py|
examples/chained_optimizations/bert_prune_distill_quantize.py|
examples/deepseek/quantize_to_nvfp4.py|
Expand Down
2 changes: 1 addition & 1 deletion examples/speculative_decoding/eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def patched_templated_attn(*args, **kwargs):
original_op = args[2]

# This patch is only enabled for eagle model by context manager, not base model.
patch_enbabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH
patch_enbabled = modelopt.torch.speculative.plugins.hf_eagle.ENABLE_CP_TTT_PATCH

if patch_enbabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")
Expand Down
2 changes: 1 addition & 1 deletion examples/speculative_decoding/scripts/ar_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from transformers import AutoTokenizer

import modelopt.torch.opt as mto
from modelopt.torch.speculative.plugins.transformers import HFARValidation
from modelopt.torch.speculative.plugins.hf_eagle import HFARValidation
from modelopt.torch.speculative.utils import load_vlm_or_llm

mto.enable_huggingface_checkpointing()
Expand Down
4 changes: 2 additions & 2 deletions modelopt/torch/export/plugins/hf_spec_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ def _export_config(self):
template_config = deepcopy(template_config)

def _get_config_from_draft_or_base(key: str, model: nn.Module):
if getattr(model._draft_model_config, key, None) is not None:
return getattr(model._draft_model_config, key)
if getattr(model.eagle_config, key, None) is not None:
return getattr(model.eagle_config, key)
elif getattr(model.config, key, None) is not None:
return getattr(model.config, key)
else:
Expand Down
4 changes: 0 additions & 4 deletions modelopt/torch/speculative/eagle/default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
"use_aux_hidden_state": False,
"eagle_aux_hidden_state_layer_ids": [],
"use_mtp_layernorm": False,
"parallel_draft_step": 1,
"parallel_draft_heads_num_layers": 1,
"has_lm_head": False,
"head_dim": 128,
}
Expand Down Expand Up @@ -107,7 +105,5 @@
"use_aux_hidden_state": True,
"eagle_aux_hidden_state_layer_ids": [],
"use_mtp_layernorm": False,
"parallel_draft_step": 1,
"parallel_draft_heads_num_layers": 1,
"has_lm_head": False,
}
5 changes: 3 additions & 2 deletions modelopt/torch/speculative/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Please check out the source code of this module for examples of how plugins work and how you can
write your own one. Currently, we support plugins for

- :meth:`transformers<modelopt.torch.speculative.plugins.transformers>`
- :meth:`hf_eagle<modelopt.torch.speculative.plugins.hf_eagle>`
"""

from modelopt.torch.utils import import_plugin
Expand All @@ -31,4 +31,5 @@

with import_plugin("transformers"):
from .hf_dflash import *
from .transformers import *
from .hf_eagle import *
from .hf_medusa import *
215 changes: 1 addition & 214 deletions modelopt/torch/speculative/plugins/hf_dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,234 +54,21 @@

import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config as _Qwen3Config
from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP as _MLP_CLS # noqa: N814
from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm as _NORM_CLS # noqa: N814
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3RotaryEmbedding as _ROTARY_CLS, # noqa: N814
)
from transformers.models.qwen3.modeling_qwen3 import rotate_half as _rotate_half
from transformers.trainer_pt_utils import LabelSmoother
from transformers.utils import ModelOutput

from ..dflash.conversion import DFlashDMRegistry
from ..dflash.dflash_model import DFlashModel
from .modeling_dflash import DFlashAttention, DFlashModule, build_target_layer_ids # noqa: F401
from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS

logger = logging.getLogger(__name__)

__all__ = ["HFDFlashModel"]


def build_target_layer_ids(num_target_layers, num_draft_layers):
    """Pick ``num_draft_layers`` layer indices spread evenly across the target model.

    Args:
        num_target_layers: Number of layers in the target (base) model.
        num_draft_layers: Number of feature-extraction points needed by the draft.

    Returns:
        A list of ``num_draft_layers`` target-layer indices.

    Raises:
        ValueError: If the target model has fewer layers than requested.
    """
    if num_target_layers < num_draft_layers:
        raise ValueError(
            f"num_target_layers ({num_target_layers}) must be >= num_draft_layers ({num_draft_layers})"
        )
    # A single extraction point taps the middle of the target stack.
    if num_draft_layers == 1:
        return [num_target_layers // 2]
    # Otherwise spread uniformly over [lo, hi], skipping the very first layer
    # and the last few layers of the target model.
    lo = min(1, num_target_layers - 1)
    hi = max(lo, num_target_layers - 3)
    width = hi - lo
    return [round(lo + (i * width) / (num_draft_layers - 1)) for i in range(num_draft_layers)]


def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply RoPE. Q uses last q_len positions, K uses all positions."""
    # Broadcast over the head dimension: [B, seq, dim] -> [B, 1, seq, dim].
    cos, sin = cos.unsqueeze(1), sin.unsqueeze(1)
    # K spans context + noise, so it rotates with the full position range.
    k_embed = (k * cos) + (_rotate_half(k) * sin)
    # Q spans only the noise block, i.e. the trailing q_len positions.
    tail = q.size(2)
    q_embed = (q * cos[:, :, -tail:, :]) + (_rotate_half(q) * sin[:, :, -tail:, :])
    return q_embed, k_embed


class DFlashAttention(nn.Module):
    """Attention with KV injection, using HF's attention dispatch."""

    def __init__(self, config, layer_idx):
        """Initialize DFlash attention with KV injection projections and QK-norm."""
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        n_heads = config.num_attention_heads
        n_kv_heads = config.num_key_value_heads
        dim_per_head = getattr(config, "head_dim", config.hidden_size // n_heads)

        self.head_dim = dim_per_head
        self.num_heads = n_heads
        self.num_kv_heads = n_kv_heads
        # Read by HF attention functions for GQA expansion.
        self.num_key_value_groups = n_heads // n_kv_heads
        self.scaling = dim_per_head**-0.5
        self.attention_dropout = getattr(config, "attention_dropout", 0.0)
        # Draft attends bidirectionally within its block; HF attn fns read this flag.
        self.is_causal = False

        use_bias = getattr(config, "attention_bias", False)
        self.q_proj = nn.Linear(config.hidden_size, n_heads * dim_per_head, bias=use_bias)
        self.k_proj = nn.Linear(config.hidden_size, n_kv_heads * dim_per_head, bias=use_bias)
        self.v_proj = nn.Linear(config.hidden_size, n_kv_heads * dim_per_head, bias=use_bias)
        self.o_proj = nn.Linear(n_heads * dim_per_head, config.hidden_size, bias=use_bias)

        self.q_norm = _NORM_CLS(dim_per_head, eps=config.rms_norm_eps)
        self.k_norm = _NORM_CLS(dim_per_head, eps=config.rms_norm_eps)

        # HF attention function is resolved lazily on first use.
        self._attn_fn = None
        # Qwen3 marks some layers as sliding-window attention via config.layer_types.
        if hasattr(config, "layer_types") and hasattr(config, "sliding_window"):
            uses_sliding = config.layer_types[layer_idx] == "sliding_attention"
            self.sliding_window = config.sliding_window if uses_sliding else None
        else:
            self.sliding_window = None

    def _get_attn_fn(self):
        """Lazily resolve the HF attention function (default: sdpa)."""
        if self._attn_fn is None:
            impl = self.config._attn_implementation  # default set in dflash/default_config.py
            self._attn_fn = ALL_ATTENTION_FUNCTIONS.get(impl, ALL_ATTENTION_FUNCTIONS["sdpa"])
        return self._attn_fn

    def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None):
        """Forward with KV injection.

        Q is projected from the noise block (draft token embeddings: [anchor, mask, mask, ...]).
        K and V are projected from the concatenation of target hidden states (context from the
        base model) and noise block, so the draft can attend to both context and its own block.
        """
        batch, noise_len, _ = hidden_states.shape
        ctx_len = target_hidden.shape[1]
        total_len = ctx_len + noise_len

        # Q: noise block only (the draft tokens being predicted), with QK-norm.
        query = self.q_proj(hidden_states).view(batch, noise_len, -1, self.head_dim)
        query = self.q_norm(query).transpose(1, 2)

        # K: context followed by noise, with QK-norm.
        key = torch.cat([self.k_proj(target_hidden), self.k_proj(hidden_states)], dim=1)
        key = self.k_norm(key.view(batch, total_len, -1, self.head_dim)).transpose(1, 2)

        # V: context followed by noise (no norm).
        value = torch.cat([self.v_proj(target_hidden), self.v_proj(hidden_states)], dim=1)
        value = value.view(batch, total_len, -1, self.head_dim).transpose(1, 2)

        # RoPE: Q rotates with the trailing positions only, K with all positions.
        cos, sin = position_embeddings
        query, key = apply_rotary_pos_emb(query, key, cos, sin)

        # HF's attention dispatch handles GQA expansion internally.
        attn_fn = self._get_attn_fn()
        attn_output, _ = attn_fn(
            self,
            query,
            key,
            value,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
        )
        return self.o_proj(attn_output.reshape(batch, noise_len, -1))


class DFlashDecoderLayer(nn.Module):
    """Draft decoder layer with KV injection."""

    def __init__(self, config, layer_idx):
        """Initialize decoder layer with attention, MLP, and layer norms."""
        super().__init__()
        self.self_attn = DFlashAttention(config, layer_idx)
        self.mlp = _MLP_CLS(config)
        self.input_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None):
        """Forward pass with residual connections."""
        # Attention sub-block: pre-norm, attend (with KV injection), add residual.
        attn_out = self.self_attn(
            self.input_layernorm(hidden_states),
            target_hidden,
            position_embeddings,
            attention_mask,
        )
        hidden_states = hidden_states + attn_out

        # MLP sub-block: pre-norm, transform, add residual.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out


class DFlashModule(nn.Module):
    """DFlash draft module using Qwen3 components (MLP, RMSNorm, RotaryEmbedding)."""

    def __init__(self, config):
        """Initialize DFlash module with feature fusion, decoder layers, and rotary embeddings."""
        super().__init__()
        self.config = config
        self.block_size = config.block_size

        # Feature fusion: concatenated target-layer hiddens projected to hidden_size.
        fused_dim = len(config.target_layer_ids) * config.hidden_size
        self.fc = nn.Linear(fused_dim, config.hidden_size, bias=False)
        self.hidden_norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps)

        # Draft decoder stack.
        self.layers = nn.ModuleList(
            DFlashDecoderLayer(config, idx) for idx in range(config.num_hidden_layers)
        )
        self.norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps)
        self._rotary_config = config  # Used by _maybe_init_rotary_emb

        # Explicit weight init is needed because DFlashModule is instantiated via
        # mtsp.convert() AFTER the base model's post_init() has already run, so HF's
        # automatic _init_weights walk doesn't reach these new layers.
        self._init_weights(config)

    def _maybe_init_rotary_emb(self, device=None):
        """Lazily initialize rotary embeddings on first forward call.

        Same pattern as EAGLE3's _maybe_init_rope. Avoids creating rotary_emb
        during __init__ (which runs on meta device during from_pretrained),
        preventing the meta-tensor inv_freq issue on checkpoint resume.
        """
        if hasattr(self, "rotary_emb"):
            return
        self.rotary_emb = _ROTARY_CLS(config=self._rotary_config, device=device)

    def _init_weights(self, config):
        """Initialize weights matching HF PreTrainedModel._init_weights."""
        std = getattr(config, "initializer_range", 0.02)
        for submodule in self.modules():
            if not isinstance(submodule, nn.Linear):
                continue
            nn.init.normal_(submodule.weight, mean=0.0, std=std)
            if submodule.bias is not None:
                nn.init.zeros_(submodule.bias)

    def forward(self, noise_embedding, target_hidden, position_ids, attention_mask=None):
        """Forward with feature fusion, KV injection, and position embeddings."""
        # Fuse stacked target features down to hidden_size, then normalize.
        fused_target = self.hidden_norm(self.fc(target_hidden))

        self._maybe_init_rotary_emb(device=noise_embedding.device)
        position_embeddings = self.rotary_emb(noise_embedding, position_ids)

        hidden_states = noise_embedding
        for decoder_layer in self.layers:
            hidden_states = decoder_layer(
                hidden_states, fused_target, position_embeddings, attention_mask
            )
        return self.norm(hidden_states)


@DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
class HFDFlashModel(DFlashModel):
"""DFlash Model for HuggingFace transformers."""
Expand Down
Loading
Loading