From f488231447256f22333eb84002fb0d78f54d954f Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sun, 19 Apr 2026 21:46:56 +0000 Subject: [PATCH 1/3] reorg files Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .pre-commit-config.yaml | 2 +- examples/speculative_decoding/eagle_utils.py | 2 +- .../scripts/ar_validate.py | 2 +- .../torch/export/plugins/hf_spec_export.py | 4 +- .../torch/speculative/eagle/default_config.py | 4 - .../torch/speculative/plugins/__init__.py | 5 +- .../torch/speculative/plugins/hf_dflash.py | 299 ++-------- .../plugins/{transformers.py => hf_eagle.py} | 557 +++--------------- .../torch/speculative/plugins/hf_medusa.py | 167 ++++++ .../speculative/plugins/modeling_dflash.py | 249 ++++++++ .../speculative/plugins/modeling_eagle.py | 213 +++++++ modelopt/torch/speculative/utils.py | 6 +- .../torch/export/test_hf_spec_rope_export.py | 17 +- 13 files changed, 777 insertions(+), 750 deletions(-) rename modelopt/torch/speculative/plugins/{transformers.py => hf_eagle.py} (63%) create mode 100644 modelopt/torch/speculative/plugins/hf_medusa.py create mode 100644 modelopt/torch/speculative/plugins/modeling_dflash.py create mode 100644 modelopt/torch/speculative/plugins/modeling_eagle.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3c4c11a090..2fb8692533 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -99,7 +99,7 @@ repos: modelopt/torch/quantization/plugins/attention.py| modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py| modelopt/torch/speculative/eagle/utils.py| - modelopt/torch/speculative/plugins/transformers.py| + modelopt/torch/speculative/plugins/hf_medusa.py| modelopt/torch/utils/plugins/megatron_mmlu.py| examples/chained_optimizations/bert_prune_distill_quantize.py| examples/deepseek/quantize_to_nvfp4.py| diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 
12522b259b..dca2e0fb88 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ b/examples/speculative_decoding/eagle_utils.py @@ -358,7 +358,7 @@ def patched_templated_attn(*args, **kwargs): original_op = args[2] # This patch is only enabled for eagle model by context manager, not base model. - patch_enbabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH + patch_enbabled = modelopt.torch.speculative.plugins.hf_eagle.ENABLE_CP_TTT_PATCH if patch_enbabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention: raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}") diff --git a/examples/speculative_decoding/scripts/ar_validate.py b/examples/speculative_decoding/scripts/ar_validate.py index 8fcae87652..5699c480b7 100644 --- a/examples/speculative_decoding/scripts/ar_validate.py +++ b/examples/speculative_decoding/scripts/ar_validate.py @@ -27,7 +27,7 @@ from transformers import AutoTokenizer import modelopt.torch.opt as mto -from modelopt.torch.speculative.plugins.transformers import HFARValidation +from modelopt.torch.speculative.plugins.hf_eagle import HFARValidation from modelopt.torch.speculative.utils import load_vlm_or_llm mto.enable_huggingface_checkpointing() diff --git a/modelopt/torch/export/plugins/hf_spec_export.py b/modelopt/torch/export/plugins/hf_spec_export.py index 74d0a8e1d6..54d6e493c2 100644 --- a/modelopt/torch/export/plugins/hf_spec_export.py +++ b/modelopt/torch/export/plugins/hf_spec_export.py @@ -171,8 +171,8 @@ def _export_config(self): template_config = deepcopy(template_config) def _get_config_from_draft_or_base(key: str, model: nn.Module): - if getattr(model._draft_model_config, key, None) is not None: - return getattr(model._draft_model_config, key) + if getattr(model.eagle_config, key, None) is not None: + return getattr(model.eagle_config, key) elif getattr(model.config, key, None) is not None: return getattr(model.config, key) else: diff --git 
a/modelopt/torch/speculative/eagle/default_config.py b/modelopt/torch/speculative/eagle/default_config.py index ae6081a878..bd67cfe30e 100644 --- a/modelopt/torch/speculative/eagle/default_config.py +++ b/modelopt/torch/speculative/eagle/default_config.py @@ -37,8 +37,6 @@ "use_aux_hidden_state": False, "eagle_aux_hidden_state_layer_ids": [], "use_mtp_layernorm": False, - "parallel_draft_step": 1, - "parallel_draft_heads_num_layers": 1, "has_lm_head": False, "head_dim": 128, } @@ -107,7 +105,5 @@ "use_aux_hidden_state": True, "eagle_aux_hidden_state_layer_ids": [], "use_mtp_layernorm": False, - "parallel_draft_step": 1, - "parallel_draft_heads_num_layers": 1, "has_lm_head": False, } diff --git a/modelopt/torch/speculative/plugins/__init__.py b/modelopt/torch/speculative/plugins/__init__.py index 9b55db3af0..c30a65b2b4 100644 --- a/modelopt/torch/speculative/plugins/__init__.py +++ b/modelopt/torch/speculative/plugins/__init__.py @@ -18,7 +18,7 @@ Please check out the source code of this module for examples of how plugins work and how you can write your own one. Currently, we support plugins for -- :meth:`transformers` +- :meth:`hf_eagle` """ from modelopt.torch.utils import import_plugin @@ -31,4 +31,5 @@ with import_plugin("transformers"): from .hf_dflash import * - from .transformers import * + from .hf_eagle import * + from .hf_medusa import * diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index c2079038a2..46c96d2d30 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -50,25 +50,25 @@ lazy rope pattern needed for MLA models. 
""" +import contextlib import logging import torch import torch.nn.functional as F -from torch import nn +from torch.nn import CrossEntropyLoss from transformers import PreTrainedModel -from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS from transformers.models.qwen3.configuration_qwen3 import Qwen3Config as _Qwen3Config -from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP as _MLP_CLS # noqa: N814 -from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm as _NORM_CLS # noqa: N814 -from transformers.models.qwen3.modeling_qwen3 import ( - Qwen3RotaryEmbedding as _ROTARY_CLS, # noqa: N814 -) -from transformers.models.qwen3.modeling_qwen3 import rotate_half as _rotate_half from transformers.trainer_pt_utils import LabelSmoother from transformers.utils import ModelOutput from ..dflash.conversion import DFlashDMRegistry from ..dflash.dflash_model import DFlashModel +from .modeling_dflash import ( # noqa: F401 + DFlashAttention, + DFlashBaseModelOutput, + DFlashModule, + build_target_layer_ids, +) from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS logger = logging.getLogger(__name__) @@ -76,212 +76,6 @@ __all__ = ["HFDFlashModel"] -def build_target_layer_ids(num_target_layers, num_draft_layers): - """Select layers uniformly from the target model for feature extraction.""" - if num_target_layers < num_draft_layers: - raise ValueError( - f"num_target_layers ({num_target_layers}) must be >= num_draft_layers ({num_draft_layers})" - ) - if num_draft_layers == 1: - return [num_target_layers // 2] - start = min(1, num_target_layers - 1) - end = max(start, num_target_layers - 3) - span = end - start - return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)] - - -def apply_rotary_pos_emb(q, k, cos, sin): - """Apply RoPE. 
Q uses last q_len positions, K uses all positions.""" - cos = cos.unsqueeze(1) # [B, 1, seq, dim] - sin = sin.unsqueeze(1) - q_len = q.size(2) - q_embed = (q * cos[:, :, -q_len:, :]) + (_rotate_half(q) * sin[:, :, -q_len:, :]) - k_embed = (k * cos) + (_rotate_half(k) * sin) - return q_embed, k_embed - - -class DFlashAttention(nn.Module): - """Attention with KV injection, using HF's attention dispatch.""" - - def __init__(self, config, layer_idx): - """Initialize DFlash attention with KV injection projections and QK-norm.""" - super().__init__() - self.config = config - self.layer_idx = layer_idx - self.head_dim = getattr( - config, "head_dim", config.hidden_size // config.num_attention_heads - ) - self.num_heads = config.num_attention_heads - self.num_kv_heads = config.num_key_value_heads - self.num_key_value_groups = self.num_heads // self.num_kv_heads - self.scaling = self.head_dim**-0.5 - self.attention_dropout = getattr(config, "attention_dropout", 0.0) - self.is_causal = False - - attn_bias = getattr(config, "attention_bias", False) - self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=attn_bias) - self.k_proj = nn.Linear( - config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias - ) - self.v_proj = nn.Linear( - config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias - ) - self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=attn_bias) - - self.q_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) - self.k_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) - - # Resolve HF attention function - self._attn_fn = None - # Qwen3 uses sliding window attention on some layers (config.layer_types) - if hasattr(config, "layer_types") and hasattr(config, "sliding_window"): - is_sliding = config.layer_types[layer_idx] == "sliding_attention" - self.sliding_window = config.sliding_window if is_sliding else None - else: - self.sliding_window = None - - def _get_attn_fn(self): - 
"""Lazily resolve the HF attention function (default: sdpa).""" - if self._attn_fn is not None: - return self._attn_fn - impl = self.config._attn_implementation # default set in dflash/default_config.py - self._attn_fn = ALL_ATTENTION_FUNCTIONS.get(impl, ALL_ATTENTION_FUNCTIONS["sdpa"]) - return self._attn_fn - - def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): - """Forward with KV injection. - - Q is projected from the noise block (draft token embeddings: [anchor, mask, mask, ...]). - K and V are projected from the concatenation of target hidden states (context from the - base model) and noise block, so the draft can attend to both context and its own block. - """ - bsz, q_len, _ = hidden_states.shape - ctx_len = target_hidden.shape[1] - - # Q from noise block only (the draft tokens being predicted), with QK-norm - q = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim) - q = self.q_norm(q).transpose(1, 2) - - # K from context + noise, with QK-norm - k_ctx = self.k_proj(target_hidden) - k_noise = self.k_proj(hidden_states) - k = torch.cat([k_ctx, k_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) - k = self.k_norm(k).transpose(1, 2) - - # V from context + noise (no norm) - v_ctx = self.v_proj(target_hidden) - v_noise = self.v_proj(hidden_states) - v = ( - torch.cat([v_ctx, v_noise], dim=1) - .view(bsz, ctx_len + q_len, -1, self.head_dim) - .transpose(1, 2) - ) - - # RoPE - cos, sin = position_embeddings - q, k = apply_rotary_pos_emb(q, k, cos, sin) - - # Use HF's attention dispatch (handles GQA internally) - attn_fn = self._get_attn_fn() - attn_output, _ = attn_fn( - self, - q, - k, - v, - attention_mask, - dropout=0.0 if not self.training else self.attention_dropout, - scaling=self.scaling, - sliding_window=self.sliding_window, - ) - attn_output = attn_output.reshape(bsz, q_len, -1) - return self.o_proj(attn_output) - - -class DFlashDecoderLayer(nn.Module): - """Draft decoder layer with KV 
injection.""" - - def __init__(self, config, layer_idx): - """Initialize decoder layer with attention, MLP, and layer norms.""" - super().__init__() - self.self_attn = DFlashAttention(config, layer_idx) - self.mlp = _MLP_CLS(config) - self.input_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) - self.post_attention_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) - - def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): - """Forward pass with residual connections.""" - residual = hidden_states - hidden_states = self.input_layernorm(hidden_states) - hidden_states = self.self_attn( - hidden_states, target_hidden, position_embeddings, attention_mask - ) - hidden_states = residual + hidden_states - - residual = hidden_states - hidden_states = self.post_attention_layernorm(hidden_states) - hidden_states = self.mlp(hidden_states) - hidden_states = residual + hidden_states - return hidden_states - - -class DFlashModule(nn.Module): - """DFlash draft module using Qwen3 components (MLP, RMSNorm, RotaryEmbedding).""" - - def __init__(self, config): - """Initialize DFlash module with feature fusion, decoder layers, and rotary embeddings.""" - super().__init__() - self.config = config - self.block_size = config.block_size - - # Feature fusion - num_fused_layers = len(config.target_layer_ids) - self.fc = nn.Linear(num_fused_layers * config.hidden_size, config.hidden_size, bias=False) - self.hidden_norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) - - # Decoder layers - self.layers = nn.ModuleList( - [DFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) - self._rotary_config = config # Used by _maybe_init_rotary_emb - - # Explicit weight init is needed because DFlashModule is instantiated via - # mtsp.convert() AFTER the base model's post_init() has already run, so HF's - # automatic 
_init_weights walk doesn't reach these new layers. - self._init_weights(config) - - def _maybe_init_rotary_emb(self, device=None): - """Lazily initialize rotary embeddings on first forward call. - - Same pattern as EAGLE3's _maybe_init_rope. Avoids creating rotary_emb - during __init__ (which runs on meta device during from_pretrained), - preventing the meta-tensor inv_freq issue on checkpoint resume. - """ - if not hasattr(self, "rotary_emb"): - self.rotary_emb = _ROTARY_CLS(config=self._rotary_config, device=device) - - def _init_weights(self, config): - """Initialize weights matching HF PreTrainedModel._init_weights.""" - std = getattr(config, "initializer_range", 0.02) - for module in self.modules(): - if isinstance(module, nn.Linear): - nn.init.normal_(module.weight, mean=0.0, std=std) - if module.bias is not None: - nn.init.zeros_(module.bias) - - def forward(self, noise_embedding, target_hidden, position_ids, attention_mask=None): - """Forward with feature fusion, KV injection, and position embeddings.""" - hidden_states = noise_embedding - target_hidden = self.hidden_norm(self.fc(target_hidden)) - self._maybe_init_rotary_emb(device=hidden_states.device) - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - for layer in self.layers: - hidden_states = layer(hidden_states, target_hidden, position_embeddings, attention_mask) - - return self.norm(hidden_states) - - @DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) class HFDFlashModel(DFlashModel): """DFlash Model for HuggingFace transformers.""" @@ -327,6 +121,25 @@ def _find_base_model_parts(self): else: raise ValueError(f"Part {name} not found in model") + def _base_model_forward(self, input_ids, attention_mask, freeze=True, labels=None, **kwargs): + """Run the base model forward pass with optional freeze and base-model loss.""" + ctx = torch.no_grad() if freeze else contextlib.nullcontext() + with ctx: + outputs = super().forward( + input_ids=input_ids, + 
attention_mask=attention_mask, + output_hidden_states=True, + **kwargs, + ) + base_loss = None + if not freeze and labels is not None: + loss_fct = CrossEntropyLoss() + base_loss = loss_fct( + outputs.logits.view(-1, outputs.logits.shape[-1]), + labels.view(-1), + ) + return outputs, base_loss + def modify(self, config): """Initialize DFlash draft module.""" super().modify(config) @@ -593,6 +406,16 @@ def _compute_loss( return loss, accuracy + def _dflash_base_model_forward( + self, input_ids, attention_mask, freeze=True + ) -> DFlashBaseModelOutput: + """Run base model and extract target hidden states for DFlash.""" + outputs, _ = self._base_model_forward(input_ids, attention_mask, freeze=freeze) + # hidden_states[0] is the embedding output; layer i output is at index i+1 + selected = [outputs.hidden_states[lid + 1] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) + return DFlashBaseModelOutput(target_hidden=target_hidden, logits=outputs.logits) + def forward( self, input_ids=None, @@ -641,18 +464,10 @@ def forward( f"Adjust training_seq_len or use padding." ) - # 1. Run base model → hidden states - # TODO: For co-training the base model, remove no_grad and eval() switch. - with torch.no_grad(): - base_outputs = super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - ) - - offset = 1 - selected = [base_outputs.hidden_states[lid + offset] for lid in self.target_layer_ids] - target_hidden = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] + # 1. Run base model → extract target hidden states + base_outputs = self._dflash_base_model_forward( + input_ids, attention_mask, freeze=self.dflash_freeze_base_model + ) # 2. Build loss mask. 
# When labels are provided (answer_only_loss), they already encode both @@ -682,13 +497,18 @@ def forward( ) full_pos = self._build_position_ids(seq_len, anchor_positions, device) attn_mask = self._build_draft_attention_mask( - seq_len, anchor_positions, block_keep_mask, n_blocks, target_hidden.dtype, device + seq_len, + anchor_positions, + block_keep_mask, + n_blocks, + base_outputs.target_hidden.dtype, + device, ) # 5. Draft forward hidden = self.dflash_module( noise_embedding=noise_embedding, - target_hidden=target_hidden, + target_hidden=base_outputs.target_hidden, position_ids=full_pos, attention_mask=attn_mask, ) @@ -762,29 +582,14 @@ def pseudo_speculative_generate(self, input_ids, steps=1): base_token: Next token from base model [B, 1]. draft_tokens: Draft tokens [B, min(steps, block_size-1)] or None if steps < 1. """ - # Call the base model's inner model directly (avoids DynamicModule dispatch) - model_output = self._base_model( - input_ids=input_ids, - output_hidden_states=True, - ) - # Compute logits via lm_head - base_logits = self._base_model_lm_head(model_output.last_hidden_state) - # Build output with hidden_states - base_outputs = ModelOutput( - logits=base_logits, - hidden_states=model_output.hidden_states, - ) - base_logits = base_outputs.logits - base_token = base_logits[:, -1:, :].argmax(dim=-1).to(input_ids.device) + base_outputs = self._dflash_base_model_forward(input_ids, attention_mask=None, freeze=True) + assert base_outputs.logits is not None + base_token = base_outputs.logits[:, -1:, :].argmax(dim=-1).to(input_ids.device) if steps < 1: return base_token, None - # Extract target hidden states (raw, before FC projection) - hid_offset = 1 - selected = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] - target_hidden = torch.cat(selected, dim=-1) - + target_hidden = base_outputs.target_hidden block_size = self.dflash_block_size bsz = input_ids.shape[0] device = input_ids.device diff --git 
a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/hf_eagle.py similarity index 63% rename from modelopt/torch/speculative/plugins/transformers.py rename to modelopt/torch/speculative/plugins/hf_eagle.py index e213393191..d2af52a3e8 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/hf_eagle.py @@ -1,17 +1,3 @@ -# Adapted from: https://github.com/ctlllll/axolotl/blob/f86767e/src/axolotl/monkeypatch/medusa_utils.py -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: Apache-2.0 # @@ -31,392 +17,39 @@ import contextlib import copy -from dataclasses import dataclass from typing import Any import torch -from torch import nn -from torch.nn import CrossEntropyLoss from torch.nn.attention.flex_attention import BlockMask, create_block_mask from transformers import Cache, DynamicCache, PreTrainedModel -from transformers.models.llama.modeling_llama import ( - LlamaDecoderLayer, - LlamaRMSNorm, - LlamaRotaryEmbedding, -) -from transformers.trainer_pt_utils import LabelSmoother +from transformers.models.llama.modeling_llama import LlamaDecoderLayer from transformers.utils import ModelOutput -from ...export.plugins.hf_spec_export import ( - EagleExporter, - EagleMedusaExporter, - SpeculativeDecodingExporter, -) +from ...export.plugins.hf_spec_export import EagleExporter, SpeculativeDecodingExporter from ..eagle.conversion import EagleDMRegistry from ..eagle.eagle_model import EagleModel from ..eagle.utils import expand_mask, make_causal_mask -from ..medusa.conversion import MedusaDMRegistry -from ..medusa.medusa_model import MedusaModel from ..utils import ( AcceptanceRateValidation, - ResBlock, _setup_kimi_k2_decoder, enable_cp_ttt_patch, get_ttt_msk_func, temporary_set_config_value, ) +from .modeling_eagle import EagleBaseModelOutput, EagleModule from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS -__all__ = ["HFARValidation", "HFEagleModel", "HFMedusaModel"] +__all__ = ["HFARValidation", "HFEagleModel"] -IGNORE_TOKEN_ID = LabelSmoother.ignore_index ENABLE_CP_TTT_PATCH = False # module variable to cache attention mask for cp ttt CACHED_SHARD_TTT_MASKS = {} -@MedusaDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) -class HFMedusaModel(MedusaModel): - """Medusa Model Class for huggingface models.""" - - def modify(self, medusa_num_heads=0, medusa_num_layers=0): - """Constructor. - - Args: - medusa_num_heads: number of medusa heads. 
- medusa_num_layers: number of ResBlock layers in each head. - """ - super().modify(medusa_num_heads=medusa_num_heads, medusa_num_layers=medusa_num_layers) - self.config.medusa = { - "num_medusa_heads": medusa_num_heads, - "num_medusa_layers": medusa_num_layers, - } - - hidden_size = self.lm_head.weight.shape[-1] - vocab_size = self.lm_head.weight.shape[0] - - # Create a list of Medusa heads - self.medusa_heads = nn.ModuleList( - [ - nn.Sequential( - *([ResBlock(hidden_size) for _ in range(self.medusa_num_layers)]), - nn.Linear(hidden_size, vocab_size, bias=False), - ) - for _ in range(self.medusa_num_heads) - ] - ) - - # Ensure medusa_head's dtype and device align with the base_model - self.medusa_heads.to(self.lm_head.weight.dtype).to(self.lm_head.weight.device) - self.medusa_heads.device = self.lm_head.weight.device - if hasattr(self, "hf_device_map") and "lm_head" in self.hf_device_map: - self.hf_device_map["medusa_heads"] = self.hf_device_map["lm_head"] - - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = None, - output_hidden_states: bool | None = None, - cache_position: torch.LongTensor | None = None, - logits_to_keep: int | torch.Tensor = 0, - freeze_base_model: bool = True, - medusa_heads_coefficient: float | None = 0.2, - medusa_decay_coefficient: float | None = 0.8, - **kwargs, - ) -> Any: - """Forward pass of the MedusaModel. - - Returns: - torch.Tensor: A tensor containing predictions from all Medusa heads. 
- """ - # Pass input through the base model - with torch.no_grad() if freeze_base_model else contextlib.nullcontext(): - outputs = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - rcache_position=cache_position, - **kwargs, - ) - hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = ( - slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - ) - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - medusa_logits = [ - self.medusa_heads[i](hidden_states[:, slice_indices, :]) - for i in range(self.medusa_num_heads) - ] - - if labels is not None: - loss = 0 - loss_fct = CrossEntropyLoss() - # Base model loss - if not freeze_base_model: - loss_logits = logits.view(-1, logits.shape[-1]) - loss_labels = labels.view(-1) - base_model_loss = loss_fct(loss_logits, loss_labels) - loss += base_model_loss - # Medusa loss - for i in range(self.medusa_num_heads): - labels = labels[..., 1:].contiguous() - loss_logits = medusa_logits[i][:, : -(1 + i)].contiguous() - loss_logits = loss_logits.view(-1, loss_logits.shape[-1]) - loss_labels = labels.view(-1) - loss += ( - loss_fct(loss_logits, loss_labels) - * medusa_decay_coefficient**i - * medusa_heads_coefficient - ) - else: - loss = None - - return ModelOutput( - loss=loss, - logits=logits, - medusa_logits=medusa_logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - -class ParallelDraft(nn.Module): - """ParallelDraft module with multiple Medusa heads and a shared lm head.""" - - def __init__(self, hidden_size: int, vocab_size: int, num_heads: int = 1, num_layers: int = 1): - """Init function for 
ParallelDraft.""" - super().__init__() - - self.medusa_heads = torch.nn.ModuleList( - [ - nn.Sequential( - *([ResBlock(hidden_size) for _ in range(num_layers)]), - ) - for _ in range(num_heads) - ] - ) - self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False) - - def forward(self, x): - """Forward function.""" - output = [] - for head in self.medusa_heads: - x_head = head(x) - output.append(self.lm_head(x_head)) - return output - - -class EagleModule(nn.Module): - """Eagle module used in EAGLE model.""" - - def __init__(self, config, decoder_layer_cls, bias=False): - """Init function for EagleModule.""" - super().__init__() - self.config = config - - self.layers = nn.ModuleList( - [decoder_layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - if config.use_last_layernorm: - self.norm = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps) - - # Optionally, we use a smaller vocab table for eagle module - if config.draft_vocab_size != config.vocab_size or config.has_lm_head: - # Need an extra lm_head for eagle module since vocab size is reduced. - assert config.draft_vocab_size <= config.vocab_size, ( - "EAGLE module's vocab size should be <= base model vocab size!" - ) - # Initialize the buffers to zero. - # Their values depend on specific tokenzier and calibrate dataset, and should be set in training script. 
- if config.draft_vocab_size < config.vocab_size: - self.register_buffer("d2t", torch.zeros(config.draft_vocab_size, dtype=torch.int64)) - self.lm_head = nn.Linear( - config.hidden_size, - config.draft_vocab_size, - bias=False, - ) - - if config.use_aux_hidden_state: - # In EAGLE-3, the FC concentrate hidden states from multiple base model layers - self.fc = nn.Linear( - len(config.eagle_aux_hidden_state_layer_ids) * config.hidden_size, - config.hidden_size, - bias=bias, - ) - - first_layer_attn = self.layers[0].self_attn - - # Expand first attn input dim since it accepts cat(input_embeds, hidden_states) - self._expand_first_attn_in_dim(first_layer_attn) - - # EAGLE-3's first attention require [input_layernorm_output, aux_hidden_states] - first_layer_attn.register_forward_pre_hook( - self._eagle3_attention_forward_pre_hook, with_kwargs=True - ) - - # In EAGLE-3, input_embeds and hidden_states are normalized separately before concatenation. - self.layers[0].input_layernorm = LlamaRMSNorm( - config.hidden_size, eps=config.rms_norm_eps - ) - self.layers[0].hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) - - if self.config.parallel_draft_step > 1: - self.parallel_draft_heads = ParallelDraft( - config.hidden_size, - config.draft_vocab_size, - num_heads=self.config.parallel_draft_step - 1, - num_layers=self.config.parallel_draft_heads_num_layers, - ) - - def _maybe_init_rope(self, device=None): - if self.config.eagle_decoder_type == "llama" and not hasattr(self, "rotary_emb"): - self.rotary_emb = LlamaRotaryEmbedding(config=self.config, device=device) - - def _expand_first_attn_in_dim(self, first_layer_attn): - """Modify qkv projection in first layer to accept 2h hidden size.""" - # Find Linear modules to expand - eagle_attn_type = type(first_layer_attn) - if eagle_attn_type.__name__ == "LlamaAttention": - expand_modules = ["q_proj", "k_proj", "v_proj"] - elif eagle_attn_type.__name__ == "DeepseekV3Attention": - if first_layer_attn.q_lora_rank is 
None: - expand_modules = ["q_proj", "kv_a_proj_with_mqa"] - else: - expand_modules = ["q_a_proj", "kv_a_proj_with_mqa"] - else: - raise ValueError(f"Unsupported attention type: {eagle_attn_type}") - - # Replace Linear with 2x input dim - for module in expand_modules: - original_linear = getattr(first_layer_attn, module) - assert isinstance(original_linear, nn.Linear), f"Module {module} is not a Linear" - setattr( - first_layer_attn, - module, - nn.Linear( - original_linear.in_features * 2, - original_linear.out_features, - bias=first_layer_attn.config.attention_bias, - ), - ) - - def _eagle3_attention_forward_pre_hook(self, module, args, kwargs): - """Concat input_embeds and hidden_states for EAGLE-3's first attention layer.""" - if "hidden_states" not in kwargs: - raise ValueError("hidden_states not found in kwargs") - if self._input_embeds is None: - raise ValueError("self._input_embeds is None") - - input_embeds = self._input_embeds - self._input_embeds = None - kwargs["hidden_states"] = torch.cat( - (input_embeds, self.layers[0].hidden_norm(kwargs["hidden_states"])), dim=-1 - ) - - return args, kwargs - - def forward( - self, - hidden_states: torch.Tensor, - inputs_embeds: torch.Tensor, - attention_mask: torch.Tensor, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - use_cache: bool | None = None, - output_attentions: bool | None = False, - ): - """Forward function for EagleModule.""" - batch_size, seq_length, _ = hidden_states.shape - seq_length_with_past = seq_length - past_key_values_length = 0 - - if past_key_values is not None: - past_key_values_length = past_key_values.get_seq_length() - seq_length_with_past = seq_length_with_past + past_key_values_length - if position_ids is None: - device = hidden_states.device if hidden_states is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, - seq_length + past_key_values_length, - dtype=torch.long, - device=device, - ) - position_ids 
= position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() - - inputs_embeds = inputs_embeds.to(hidden_states.dtype).to(hidden_states.device) - # In EAGLE-3, we save input embeddings to attribute, and use it in first decoder layer by hook function - # Also, we normalize input embeddings and hidden states before concatenating them. - # The default input norm in first layer attn will be disabled. - self._input_embeds = self.layers[0].input_layernorm(inputs_embeds) - - if self.config.eagle_decoder_type == "llama": - position_embeddings = self.rotary_emb(hidden_states, position_ids) - else: - position_embeddings = None - - for decoder_layer in self.layers: - layer_outputs = decoder_layer( - hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - output_attentions=output_attentions, - use_cache=use_cache, - position_embeddings=position_embeddings, - ) - # For HF>= 4.54.0, the layer_outputs is a tensor, for older, it is a tuple. 
- if isinstance(layer_outputs, tuple): - hidden_states = layer_outputs[0] - else: - hidden_states = layer_outputs - - pre_norm_h = hidden_states - - post_norm_h = self.norm(hidden_states) if hasattr(self, "norm") else hidden_states - - return post_norm_h, pre_norm_h, past_key_values - - -@dataclass -class EagleBaseModelOutput: - out_hiddens: torch.Tensor - aux_hiddens: torch.Tensor | None = None - logits: torch.Tensor | None = None - input_embeds: torch.Tensor | None = None - loss: torch.Tensor | None = None - - @classmethod - def from_offline_dict(cls, d: dict): - return cls( - out_hiddens=d.get("base_model_hidden_states"), - aux_hiddens=d.get("aux_hidden_states"), - logits=d.get("base_model_logits"), - input_embeds=d.get("base_model_input_embeds"), - loss=None, - ) - - @EagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) class HFEagleModel(EagleModel): """Eagle Model Class for huggingface models.""" - # Use functions to get base model parts without creating tied modules. @property def _base_model(self): return self.get_submodule(self.base_model_path) @@ -438,16 +71,6 @@ def _base_llm_config(self): or self.config ) - @property - def _draft_model_config(self): - """Return the llm config for the draft model.""" - return self.eagle_config - - def _enable_cp_ttt(self): - if self.training and not self.eagle_mix_hidden_states: - return enable_cp_ttt_patch() - return contextlib.nullcontext() - def _nvtx_range(self, name): """Optionally create an NVTX range for the given name when config.eagle_enable_nvtx is set.""" if not self.eagle_enable_nvtx: @@ -460,20 +83,50 @@ def _nvtx_range(self, name): print(f"Failed to create NVTX range {name}: {e}") return contextlib.nullcontext() - def get_dummy_inputs(self) -> dict: - """Construct dummy inputs for export forward pass. 
+ def _find_base_model_parts(self): + """Find model parts from different models and set base_{part}_path attributes.""" + for name, paths in { + "base_model_path": _BASE_MODEL_PATHS, + "base_model_embeddings_path": _EMBED_TOKENS_PATHS, + "base_model_lm_head_path": _LM_HEAD_PATHS, + }.items(): + for path in paths: + try: + submodule = self.get_submodule(path) + assert isinstance(submodule, torch.nn.Module) + setattr(self, name, path) + break + except Exception: + continue + else: + raise ValueError(f"Part {name} not found in model") - Returns a dict of kwargs that can be passed to forward(). For offline EAGLE models, - this includes dummy base_model_outputs with the right tensor shapes so the export - pipeline doesn't need to thread real calibration data through multiple layers. - """ + def _activate_torch_compile(self): + import torch._dynamo + + torch._dynamo.config.suppress_errors = True # Allow fallback to eager mode + + compile_targets = [ + ("_prepare_eagle_inputs", {}), + ("_eagle_forward", {"mode": "max-autotune"}), + ("_eagle_loss", {"fullgraph": True}), + ] + for name, kwargs in compile_targets: + try: + setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs)) + except Exception: # noqa: PERF203 + print(f"Disabling torch.compile for {name} due to compilation error.") + + def get_dummy_inputs(self) -> dict: + """Construct dummy inputs for export forward pass.""" device = self.device - dtype = next(self.parameters()).dtype - hidden_size = self._base_llm_config.hidden_size dummy_inputs = { "input_ids": torch.ones(1, 2, dtype=torch.long, device=device), } if self.eagle_offline: + device = self.device + dtype = next(self.parameters()).dtype + hidden_size = self._base_llm_config.hidden_size base_model_outputs = { "base_model_hidden_states": torch.zeros( 1, 2, hidden_size, dtype=dtype, device=device @@ -492,33 +145,12 @@ def get_dummy_inputs(self) -> dict: def get_exporter(self) -> SpeculativeDecodingExporter: """Get the exporter for the 
draft model.""" - exporter_cls = ( - EagleExporter if self.eagle_config.parallel_draft_step <= 1 else EagleMedusaExporter - ) - return exporter_cls(self) + return EagleExporter(self) - def _find_base_model_parts(self): - """Find model parts from different models and set base_{part}_path attributes.""" - base_model_parts_mapping = { - "base_model_path": _BASE_MODEL_PATHS, - "base_model_embeddings_path": _EMBED_TOKENS_PATHS, - "base_model_lm_head_path": _LM_HEAD_PATHS, - } - - for name, paths in base_model_parts_mapping.items(): - found_submodule = False - for path in paths: - try: - submodule = self.get_submodule(path) - assert isinstance(submodule, torch.nn.Module) - print(f"Found {name} at {path}") - found_submodule = True - setattr(self, name, path) - break - except Exception: - continue - if not found_submodule: - raise ValueError(f"Part {name} not found in model") + def _enable_cp_ttt(self): + if self.training and not self.eagle_mix_hidden_states: + return enable_cp_ttt_patch() + return contextlib.nullcontext() def _set_default_aux_hidden_state_layers(self): # Read a custom config attribute since we override num_hidden_layers for offline training @@ -608,7 +240,9 @@ def _preservation_loss( KL(softmax(ref) || log_softmax(lora)) weighted by eagle_base_lora_preservation_loss_weight. 
""" - loss = nn.Softmax(dim=-1)(ref_logits.detach()) * nn.LogSoftmax(dim=-1)(lora_logits) + loss = torch.nn.Softmax(dim=-1)(ref_logits.detach()) * torch.nn.LogSoftmax(dim=-1)( + lora_logits + ) return -loss.sum(dim=-1).mean() * self.eagle_base_lora_preservation_loss_weight def modify( @@ -702,22 +336,6 @@ def modify( self._cached_attn_blk_masks = {} - def _activate_torch_compile(self): - import torch._dynamo - - torch._dynamo.config.suppress_errors = True # Allow fallback to eager mode - - compile_targets = [ - ("_prepare_eagle_inputs", {}), - ("_eagle_forward", {"mode": "max-autotune"}), - ("_eagle_loss", {"fullgraph": True}), - ] - for name, kwargs in compile_targets: - try: - setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs)) - except Exception: # noqa: PERF203 - print(f"Disabling torch.compile for {name} due to compilation error.") - def _get_ttt_attention_mask(self, batch_size, seq_length, ttt_step): # compile and cached flex attention masks in first call if ttt_step not in self._cached_attn_blk_masks: @@ -868,7 +486,7 @@ def _compute_ttt_attention_mask( return tensor_mask - def _base_model_forward( + def _eagle_base_model_forward( self, input_ids, attention_mask, @@ -911,7 +529,7 @@ def _run_forward(no_grad): if ref_logits is not None: base_model_loss = self._preservation_loss(ref_logits, base_model_logits) elif not freeze_base_model and labels is not None: - loss_fct = CrossEntropyLoss() + loss_fct = torch.nn.CrossEntropyLoss() base_model_loss = loss_fct( base_model_logits.view(-1, base_model_logits.shape[-1]), labels.view(-1) ) @@ -957,13 +575,7 @@ def _eagle_forward( ) eagle_logits = eagle_lm_head(eagle_postnorm_h) - draft_logits_list = [eagle_logits] - if self.eagle_config.parallel_draft_step > 1: - # Get additional draft logits from parallel draft heads - draft_logits = self.eagle_module.parallel_draft_heads(eagle_postnorm_h) - draft_logits_list += draft_logits - - return eagle_postnorm_h, eagle_prenorm_h, draft_logits_list, 
eagle_cache + return eagle_postnorm_h, eagle_prenorm_h, eagle_logits, eagle_cache def forward( self, @@ -1013,7 +625,7 @@ def forward( past_key_values = None else: with self._nvtx_range("base_model_forward"): - base_outputs, past_key_values = self._base_model_forward( + base_outputs, past_key_values = self._eagle_base_model_forward( input_ids, attention_mask, position_ids, @@ -1031,9 +643,8 @@ def forward( # ====Prepare inputs for the first eagle forward pass==== eagle_loss = None - num_parallel = self.eagle_config.parallel_draft_step num_ttt = self.eagle_ttt_steps - train_accs = torch.zeros(num_parallel, num_ttt, device=input_ids.device) + train_accs = torch.zeros(1, num_ttt, device=input_ids.device) b, seq_length, _ = base_outputs.out_hiddens.shape with self._nvtx_range("prepare_eagle_inputs"): ( @@ -1051,10 +662,9 @@ def forward( base_outputs, ) - self.eagle_module._maybe_init_rope(device=eagle_input_hiddens.device) - # ====Run eagle forward with extra training-time-test steps==== - for ttt_step in range(self.eagle_ttt_steps): + num_ttt_steps = self.eagle_ttt_steps if self.training else 1 + for ttt_step in range(num_ttt_steps): # TODO: (hg) during cp training, this mask is not used. Maybe turn it off then. 
eagle_attention_mask = ( eagle_attn_mask_0 @@ -1087,29 +697,24 @@ def forward( else: eagle_input_hiddens = eagle_output_hiddens - for i in range(self.eagle_config.parallel_draft_step): - eagle_logit = eagle_logits[i] - with self._nvtx_range("eagle_loss"): - classification_loss, acc = self._eagle_loss( - # base model predict +1 tok, while eagle predict +2 - # so we shift base model outputs compared to eagle outputs - # additionally, we mask the first n tok of eagle outputs at nth TTT step - base_output_softmax_logits[:, 1 + i + ttt_step :], - base_output_predict_tok[:, 1 + i + ttt_step :], - eagle_logit[:, ttt_step : -(1 + i)], - loss_mask[:, 1 + ttt_step :] if i == 0 else loss_mask[:, 1 + ttt_step : -i], - ) - # Apply loss decay factor to focus on early steps - classification_loss *= self.eagle_loss_decay_factor ** (ttt_step + i) - eagle_loss = ( - classification_loss if eagle_loss is None else eagle_loss + classification_loss + with self._nvtx_range("eagle_loss"): + classification_loss, acc = self._eagle_loss( + # base model predict +1 tok, while eagle predict +2 + # so we shift base model outputs compared to eagle outputs + # additionally, we mask the first n tok of eagle outputs at nth TTT step + base_output_softmax_logits[:, 1 + ttt_step :], + base_output_predict_tok[:, 1 + ttt_step :], + eagle_logits[:, ttt_step:-1], + loss_mask[:, 1 + ttt_step :], ) - train_accs[i, ttt_step] = acc - if not self.training: - break + # Apply loss decay factor to focus on early steps + classification_loss *= self.eagle_loss_decay_factor**ttt_step + eagle_loss = ( + classification_loss if eagle_loss is None else eagle_loss + classification_loss + ) + train_accs[0, ttt_step] = acc - # Slice by actual number of steps taken, in case of early return - train_accs = train_accs[:, : ttt_step + 1].tolist() + train_accs = train_accs[:, :num_ttt_steps].tolist() # Merge eagle loss and preservation loss (if LoRA co-training) if base_outputs.loss is None and eagle_loss is None: @@ -1186,7 
+791,6 @@ def pseudo_speculative_generate( else: eagle_input_hidden_states = base_model_hidden_states - self.eagle_module._maybe_init_rope(device=eagle_input_hidden_states.device) draft_tokens = [] for step in range(steps): b, seq_length = eagle_ids.shape @@ -1210,13 +814,7 @@ def pseudo_speculative_generate( None, ) - # parallel logits are only used after the last step - if step == steps - 1 and self.eagle_config.parallel_draft_step > 1: - parallel_logits = [ - eagle_logits[i][:, -1:, :] - for i in range(1, self.eagle_config.parallel_draft_step) - ] - draft_token = eagle_logits[0][:, -1:, :].argmax(dim=-1) + draft_token = eagle_logits[:, -1:, :].argmax(dim=-1) if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: draft_token += self.eagle_module.d2t[draft_token] draft_tokens.append(draft_token) @@ -1227,13 +825,6 @@ def pseudo_speculative_generate( ) draft_tokens = torch.cat(draft_tokens, dim=-1).to(base_token.device) - if self.eagle_config.parallel_draft_step > 1: - parallel_logits = torch.cat(parallel_logits, dim=1) - parallel_tokens = parallel_logits.argmax(dim=-1) - if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: - parallel_tokens += self.eagle_module.d2t[parallel_tokens] - draft_tokens = torch.cat((draft_tokens, parallel_tokens), dim=-1).to(base_token.device) - return base_token, draft_tokens diff --git a/modelopt/torch/speculative/plugins/hf_medusa.py b/modelopt/torch/speculative/plugins/hf_medusa.py new file mode 100644 index 0000000000..42ea262d8d --- /dev/null +++ b/modelopt/torch/speculative/plugins/hf_medusa.py @@ -0,0 +1,167 @@ +# Adapted from: https://github.com/ctlllll/axolotl/blob/f86767e/src/axolotl/monkeypatch/medusa_utils.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Medusa speculative decoding plugin for HuggingFace models.""" + +import contextlib +from typing import Any + +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers import Cache, PreTrainedModel +from transformers.trainer_pt_utils import LabelSmoother +from transformers.utils import ModelOutput + +from ..medusa.conversion import MedusaDMRegistry +from ..medusa.medusa_model import MedusaModel +from ..utils import ResBlock + +__all__ = ["HFMedusaModel"] + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +@MedusaDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) +class HFMedusaModel(MedusaModel): + """Medusa Model Class for huggingface models.""" + + def modify(self, medusa_num_heads=0, medusa_num_layers=0): + """Constructor. + + Args: + medusa_num_heads: number of medusa heads. 
+ medusa_num_layers: number of ResBlock layers in each head. + """ + super().modify(medusa_num_heads=medusa_num_heads, medusa_num_layers=medusa_num_layers) + self.config.medusa = { + "num_medusa_heads": medusa_num_heads, + "num_medusa_layers": medusa_num_layers, + } + + hidden_size = self.lm_head.weight.shape[-1] + vocab_size = self.lm_head.weight.shape[0] + + # Create a list of Medusa heads + self.medusa_heads = nn.ModuleList( + [ + nn.Sequential( + *([ResBlock(hidden_size) for _ in range(self.medusa_num_layers)]), + nn.Linear(hidden_size, vocab_size, bias=False), + ) + for _ in range(self.medusa_num_heads) + ] + ) + + # Ensure medusa_head's dtype and device align with the base_model + self.medusa_heads.to(self.lm_head.weight.dtype).to(self.lm_head.weight.device) + self.medusa_heads.device = self.lm_head.weight.device + if hasattr(self, "hf_device_map") and "lm_head" in self.hf_device_map: + self.hf_device_map["medusa_heads"] = self.hf_device_map["lm_head"] + + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = None, + output_hidden_states: bool | None = None, + cache_position: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + freeze_base_model: bool = True, + medusa_heads_coefficient: float | None = 0.2, + medusa_decay_coefficient: float | None = 0.8, + **kwargs, + ) -> Any: + """Forward pass of the MedusaModel. + + Returns: + torch.Tensor: A tensor containing predictions from all Medusa heads. 
+ """ + # Pass input through the base model + with torch.no_grad() if freeze_base_model else contextlib.nullcontext(): + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + rcache_position=cache_position, + **kwargs, + ) + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = ( + slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + ) + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + medusa_logits = [ + self.medusa_heads[i](hidden_states[:, slice_indices, :]) + for i in range(self.medusa_num_heads) + ] + + if labels is not None: + loss = 0 + loss_fct = CrossEntropyLoss() + # Base model loss + if not freeze_base_model: + loss_logits = logits.view(-1, logits.shape[-1]) + loss_labels = labels.view(-1) + base_model_loss = loss_fct(loss_logits, loss_labels) + loss += base_model_loss + # Medusa loss + for i in range(self.medusa_num_heads): + labels = labels[..., 1:].contiguous() + loss_logits = medusa_logits[i][:, : -(1 + i)].contiguous() + loss_logits = loss_logits.view(-1, loss_logits.shape[-1]) + loss_labels = labels.view(-1) + loss += ( + loss_fct(loss_logits, loss_labels) + * medusa_decay_coefficient**i + * medusa_heads_coefficient + ) + else: + loss = None + + return ModelOutput( + loss=loss, + logits=logits, + medusa_logits=medusa_logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/modelopt/torch/speculative/plugins/modeling_dflash.py b/modelopt/torch/speculative/plugins/modeling_dflash.py new file mode 100644 index 0000000000..4cb8684b66 --- /dev/null +++ 
b/modelopt/torch/speculative/plugins/modeling_dflash.py @@ -0,0 +1,249 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DFlash draft model architecture (DFlashModule) and related components. + +Draft model components use Qwen3 (MLP, RMSNorm, RotaryEmbedding) from +``transformers.models.qwen3``, matching z-lab's reference checkpoint format. +The draft architecture is independent of the target model. 
+""" + +from dataclasses import dataclass + +import torch +from torch import nn +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS +from transformers.models.qwen3.modeling_qwen3 import Qwen3MLP as _MLP_CLS # noqa: N814 +from transformers.models.qwen3.modeling_qwen3 import Qwen3RMSNorm as _NORM_CLS # noqa: N814 +from transformers.models.qwen3.modeling_qwen3 import ( + Qwen3RotaryEmbedding as _ROTARY_CLS, # noqa: N814 +) +from transformers.models.qwen3.modeling_qwen3 import rotate_half as _rotate_half + +__all__ = ["DFlashBaseModelOutput", "DFlashModule", "build_target_layer_ids"] + + +@dataclass +class DFlashBaseModelOutput: + """Output container for base model forward pass in DFlash training.""" + + target_hidden: torch.Tensor # concatenated hidden states from target layers [B, seq, N*H] + logits: torch.Tensor | None = None # base model logits [B, seq, vocab] + + +def build_target_layer_ids(num_target_layers, num_draft_layers): + """Select layers uniformly from the target model for feature extraction.""" + if num_target_layers < num_draft_layers: + raise ValueError( + f"num_target_layers ({num_target_layers}) must be >= num_draft_layers ({num_draft_layers})" + ) + if num_draft_layers == 1: + return [num_target_layers // 2] + start = min(1, num_target_layers - 1) + end = max(start, num_target_layers - 3) + span = end - start + return [round(start + (i * span) / (num_draft_layers - 1)) for i in range(num_draft_layers)] + + +def apply_rotary_pos_emb(q, k, cos, sin): + """Apply RoPE. 
Q uses last q_len positions, K uses all positions.""" + cos = cos.unsqueeze(1) # [B, 1, seq, dim] + sin = sin.unsqueeze(1) + q_len = q.size(2) + q_embed = (q * cos[:, :, -q_len:, :]) + (_rotate_half(q) * sin[:, :, -q_len:, :]) + k_embed = (k * cos) + (_rotate_half(k) * sin) + return q_embed, k_embed + + +class DFlashAttention(nn.Module): + """Attention with KV injection, using HF's attention dispatch.""" + + def __init__(self, config, layer_idx): + """Initialize DFlash attention with KV injection projections and QK-norm.""" + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.head_dim = getattr( + config, "head_dim", config.hidden_size // config.num_attention_heads + ) + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_kv_heads + self.scaling = self.head_dim**-0.5 + self.attention_dropout = getattr(config, "attention_dropout", 0.0) + self.is_causal = False + + attn_bias = getattr(config, "attention_bias", False) + self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.head_dim, bias=attn_bias) + self.k_proj = nn.Linear( + config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias + ) + self.v_proj = nn.Linear( + config.hidden_size, self.num_kv_heads * self.head_dim, bias=attn_bias + ) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, config.hidden_size, bias=attn_bias) + + self.q_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) + self.k_norm = _NORM_CLS(self.head_dim, eps=config.rms_norm_eps) + + # Resolve HF attention function + self._attn_fn = None + # Qwen3 uses sliding window attention on some layers (config.layer_types) + if hasattr(config, "layer_types") and hasattr(config, "sliding_window"): + is_sliding = config.layer_types[layer_idx] == "sliding_attention" + self.sliding_window = config.sliding_window if is_sliding else None + else: + self.sliding_window = None + + def _get_attn_fn(self): + 
"""Lazily resolve the HF attention function (default: sdpa).""" + if self._attn_fn is not None: + return self._attn_fn + impl = self.config._attn_implementation # default set in dflash/default_config.py + self._attn_fn = ALL_ATTENTION_FUNCTIONS.get(impl, ALL_ATTENTION_FUNCTIONS["sdpa"]) + return self._attn_fn + + def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): + """Forward with KV injection. + + Q is projected from the noise block (draft token embeddings: [anchor, mask, mask, ...]). + K and V are projected from the concatenation of target hidden states (context from the + base model) and noise block, so the draft can attend to both context and its own block. + """ + bsz, q_len, _ = hidden_states.shape + ctx_len = target_hidden.shape[1] + + # Q from noise block only (the draft tokens being predicted), with QK-norm + q = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim) + q = self.q_norm(q).transpose(1, 2) + + # K from context + noise, with QK-norm + k_ctx = self.k_proj(target_hidden) + k_noise = self.k_proj(hidden_states) + k = torch.cat([k_ctx, k_noise], dim=1).view(bsz, ctx_len + q_len, -1, self.head_dim) + k = self.k_norm(k).transpose(1, 2) + + # V from context + noise (no norm) + v_ctx = self.v_proj(target_hidden) + v_noise = self.v_proj(hidden_states) + v = ( + torch.cat([v_ctx, v_noise], dim=1) + .view(bsz, ctx_len + q_len, -1, self.head_dim) + .transpose(1, 2) + ) + + # RoPE + cos, sin = position_embeddings + q, k = apply_rotary_pos_emb(q, k, cos, sin) + + # Use HF's attention dispatch (handles GQA internally) + attn_fn = self._get_attn_fn() + attn_output, _ = attn_fn( + self, + q, + k, + v, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + ) + attn_output = attn_output.reshape(bsz, q_len, -1) + return self.o_proj(attn_output) + + +class DFlashDecoderLayer(nn.Module): + """Draft decoder layer with KV 
injection.""" + + def __init__(self, config, layer_idx): + """Initialize decoder layer with attention, MLP, and layer norms.""" + super().__init__() + self.self_attn = DFlashAttention(config, layer_idx) + self.mlp = _MLP_CLS(config) + self.input_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + + def forward(self, hidden_states, target_hidden, position_embeddings, attention_mask=None): + """Forward pass with residual connections.""" + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + hidden_states, target_hidden, position_embeddings, attention_mask + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + return hidden_states + + +class DFlashModule(nn.Module): + """DFlash draft module using Qwen3 components (MLP, RMSNorm, RotaryEmbedding).""" + + def __init__(self, config): + """Initialize DFlash module with feature fusion, decoder layers, and rotary embeddings.""" + super().__init__() + self.config = config + self.block_size = config.block_size + + # Feature fusion + num_fused_layers = len(config.target_layer_ids) + self.fc = nn.Linear(num_fused_layers * config.hidden_size, config.hidden_size, bias=False) + self.hidden_norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + + # Decoder layers + self.layers = nn.ModuleList( + [DFlashDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.norm = _NORM_CLS(config.hidden_size, eps=config.rms_norm_eps) + self._rotary_config = config # Used by _maybe_init_rotary_emb + + # Explicit weight init is needed because DFlashModule is instantiated via + # mtsp.convert() AFTER the base model's post_init() has already run, so HF's + # automatic 
_init_weights walk doesn't reach these new layers. + self._init_weights(config) + + def _maybe_init_rotary_emb(self, device=None): + """Lazily initialize rotary embeddings on first forward call. + + Same pattern as EAGLE3's _maybe_init_rope. Avoids creating rotary_emb + during __init__ (which runs on meta device during from_pretrained), + preventing the meta-tensor inv_freq issue on checkpoint resume. + """ + if not hasattr(self, "rotary_emb"): + self.rotary_emb = _ROTARY_CLS(config=self._rotary_config, device=device) + + def _init_weights(self, config): + """Initialize weights matching HF PreTrainedModel._init_weights.""" + std = getattr(config, "initializer_range", 0.02) + for module in self.modules(): + if isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=std) + if module.bias is not None: + nn.init.zeros_(module.bias) + + def forward(self, noise_embedding, target_hidden, position_ids, attention_mask=None): + """Forward with feature fusion, KV injection, and position embeddings.""" + hidden_states = noise_embedding + target_hidden = self.hidden_norm(self.fc(target_hidden)) + self._maybe_init_rotary_emb(device=hidden_states.device) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + + for layer in self.layers: + hidden_states = layer(hidden_states, target_hidden, position_embeddings, attention_mask) + + return self.norm(hidden_states) diff --git a/modelopt/torch/speculative/plugins/modeling_eagle.py b/modelopt/torch/speculative/plugins/modeling_eagle.py new file mode 100644 index 0000000000..b9df4eba1a --- /dev/null +++ b/modelopt/torch/speculative/plugins/modeling_eagle.py @@ -0,0 +1,213 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""EAGLE draft model architecture (EagleModule) and related data structures.""" + +from dataclasses import dataclass + +import torch +from torch import nn +from transformers import Cache +from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding + +__all__ = ["EagleBaseModelOutput", "EagleModule"] + + +class EagleModule(nn.Module): +    """Eagle module used in EAGLE model.""" + +    def __init__(self, config, decoder_layer_cls, bias=False): +        """Init function for EagleModule.""" +        super().__init__() +        self.config = config + +        self.layers = nn.ModuleList( +            [decoder_layer_cls(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] +        ) +        if config.use_last_layernorm: +            self.norm = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps) + +        # Optionally, we use a smaller vocab table for eagle module +        if config.draft_vocab_size != config.vocab_size or config.has_lm_head: +            # Need an extra lm_head for eagle module since vocab size is reduced. +            assert config.draft_vocab_size <= config.vocab_size, ( +                "EAGLE module's vocab size should be <= base model vocab size!" +            ) +            # Initialize the buffers to zero. +            # Their values depend on specific tokenizer and calibration dataset, and should be set in training script. 
+ if config.draft_vocab_size < config.vocab_size: + self.register_buffer("d2t", torch.zeros(config.draft_vocab_size, dtype=torch.int64)) + self.lm_head = nn.Linear( + config.hidden_size, + config.draft_vocab_size, + bias=False, + ) + + if config.use_aux_hidden_state: + # In EAGLE-3, the FC concentrate hidden states from multiple base model layers + self.fc = nn.Linear( + len(config.eagle_aux_hidden_state_layer_ids) * config.hidden_size, + config.hidden_size, + bias=bias, + ) + + first_layer_attn = self.layers[0].self_attn + + # Expand first attn input dim since it accepts cat(input_embeds, hidden_states) + self._expand_first_attn_in_dim(first_layer_attn) + + # EAGLE-3's first attention require [input_layernorm_output, aux_hidden_states] + first_layer_attn.register_forward_pre_hook( + self._eagle3_attention_forward_pre_hook, with_kwargs=True + ) + + # In EAGLE-3, input_embeds and hidden_states are normalized separately before concatenation. + self.layers[0].input_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + self.layers[0].hidden_norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def _maybe_init_rope(self, device=None): + if self.config.eagle_decoder_type == "llama" and not hasattr(self, "rotary_emb"): + self.rotary_emb = LlamaRotaryEmbedding(config=self.config, device=device) + + def _expand_first_attn_in_dim(self, first_layer_attn): + """Modify qkv projection in first layer to accept 2h hidden size.""" + # Find Linear modules to expand + eagle_attn_type = type(first_layer_attn) + if eagle_attn_type.__name__ == "LlamaAttention": + expand_modules = ["q_proj", "k_proj", "v_proj"] + elif eagle_attn_type.__name__ == "DeepseekV3Attention": + if first_layer_attn.q_lora_rank is None: + expand_modules = ["q_proj", "kv_a_proj_with_mqa"] + else: + expand_modules = ["q_a_proj", "kv_a_proj_with_mqa"] + else: + raise ValueError(f"Unsupported attention type: {eagle_attn_type}") + + # Replace Linear with 2x input dim + for module 
in expand_modules: + original_linear = getattr(first_layer_attn, module) + assert isinstance(original_linear, nn.Linear), f"Module {module} is not a Linear" + setattr( + first_layer_attn, + module, + nn.Linear( + original_linear.in_features * 2, + original_linear.out_features, + bias=first_layer_attn.config.attention_bias, + ), + ) + + def _eagle3_attention_forward_pre_hook(self, module, args, kwargs): + """Concat input_embeds and hidden_states for EAGLE-3's first attention layer.""" + if "hidden_states" not in kwargs: + raise ValueError("hidden_states not found in kwargs") + if self._input_embeds is None: + raise ValueError("self._input_embeds is None") + + input_embeds = self._input_embeds + self._input_embeds = None + kwargs["hidden_states"] = torch.cat( + (input_embeds, self.layers[0].hidden_norm(kwargs["hidden_states"])), dim=-1 + ) + + return args, kwargs + + def forward( + self, + hidden_states: torch.Tensor, + inputs_embeds: torch.Tensor, + attention_mask: torch.Tensor, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + use_cache: bool | None = None, + output_attentions: bool | None = False, + ): + """Forward function for EagleModule.""" + batch_size, seq_length, _ = hidden_states.shape + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + seq_length_with_past = seq_length_with_past + past_key_values_length + if position_ids is None: + device = hidden_states.device if hidden_states is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, + seq_length + past_key_values_length, + dtype=torch.long, + device=device, + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + inputs_embeds = inputs_embeds.to(hidden_states.dtype).to(hidden_states.device) + # In EAGLE-3, we save input embeddings to 
attribute, and use it in first decoder layer by hook function + # Also, we normalize input embeddings and hidden states before concatenating them. + # The default input norm in first layer attn will be disabled. + self._input_embeds = self.layers[0].input_layernorm(inputs_embeds) + + if self.config.eagle_decoder_type == "llama": + self._maybe_init_rope(device=hidden_states.device) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + else: + position_embeddings = None + + for decoder_layer in self.layers: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + position_embeddings=position_embeddings, + ) + # For HF>= 4.54.0, the layer_outputs is a tensor, for older, it is a tuple. + if isinstance(layer_outputs, tuple): + hidden_states = layer_outputs[0] + else: + hidden_states = layer_outputs + + pre_norm_h = hidden_states + + post_norm_h = self.norm(hidden_states) if hasattr(self, "norm") else hidden_states + + return post_norm_h, pre_norm_h, past_key_values + + +@dataclass +class EagleBaseModelOutput: + """Output container for base model forward pass in EAGLE training.""" + + out_hiddens: torch.Tensor + aux_hiddens: torch.Tensor | None = None + logits: torch.Tensor | None = None + input_embeds: torch.Tensor | None = None + loss: torch.Tensor | None = None + + @classmethod + def from_offline_dict(cls, d: dict): + """Construct from a dict of pre-computed base model outputs (offline training).""" + return cls( + out_hiddens=d.get("base_model_hidden_states"), + aux_hiddens=d.get("aux_hidden_states"), + logits=d.get("base_model_logits"), + input_embeds=d.get("base_model_input_embeds"), + loss=None, + ) diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index ab73684c8a..bb8a4010de 100644 --- a/modelopt/torch/speculative/utils.py +++ 
b/modelopt/torch/speculative/utils.py @@ -554,14 +554,14 @@ def ttt_msk_func(b, h, q_idx, kv_idx): @contextlib.contextmanager def enable_cp_ttt_patch(): """Context manager to enable CP TTT patch.""" - import modelopt.torch.speculative.plugins.transformers + import modelopt.torch.speculative.plugins.hf_eagle - modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = True + modelopt.torch.speculative.plugins.hf_eagle.ENABLE_CP_TTT_PATCH = True with sdpa_kernel([SDPBackend.CUDNN_ATTENTION, SDPBackend.MATH]): try: yield finally: - modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH = False + modelopt.torch.speculative.plugins.hf_eagle.ENABLE_CP_TTT_PATCH = False def load_vlm_or_llm( diff --git a/tests/unit/torch/export/test_hf_spec_rope_export.py b/tests/unit/torch/export/test_hf_spec_rope_export.py index 3b2c86a690..171fe6263c 100644 --- a/tests/unit/torch/export/test_hf_spec_rope_export.py +++ b/tests/unit/torch/export/test_hf_spec_rope_export.py @@ -37,6 +37,9 @@ def _make_exporter( model = MagicMock() model.eagle_config.eagle_decoder_type = "llama" model.eagle_config.rope_scaling = {"rope_type": rope_type, "rope_theta": rope_theta} + # rope_theta lives inside rope_scaling in transformers 5.x; clear the top-level attr + # so the fallback path is exercised instead of MagicMock's auto-attr. 
+ model.eagle_config.rope_theta = None model.eagle_export_rope_scaling = eagle_export_rope_scaling model._draft_model_config = None model.config.rope_scaling = None @@ -55,16 +58,18 @@ def test_yarn_rope_injected_with_correct_config(): assert config["rope_scaling"] == DEFAULT_ROPE_SCALING -def test_rope_not_injected_when_non_default_training_rope(): - """rope_scaling is not overridden when training rope_type is not 'default'.""" +def test_rope_not_overridden_when_non_default_training_rope(): + """Export override is not applied when training rope_type is not 'default'; + rope_scaling falls through to the training config.""" config = _make_exporter(rope_type="llama3")._export_config() - assert config.get("rope_scaling") is None + assert config["rope_scaling"] == {"rope_type": "llama3", "rope_theta": 10000} -def test_rope_not_injected_when_eagle_export_rope_scaling_is_empty(): - """rope_scaling is not injected when eagle_export_rope_scaling is empty dict.""" +def test_rope_not_overridden_when_eagle_export_rope_scaling_is_empty(): + """Export override is not applied when eagle_export_rope_scaling is empty; + rope_scaling falls through to the training config.""" config = _make_exporter(eagle_export_rope_scaling={})._export_config() - assert config.get("rope_scaling") is None + assert config["rope_scaling"] == {"rope_type": "default", "rope_theta": 10000} def test_rope_theta_fallback_from_rope_scaling(): From 9ae53027296dcce6f910d4bb8e761e04837354d4 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sun, 19 Apr 2026 23:01:37 +0000 Subject: [PATCH 2/3] revert behavior change Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../torch/speculative/plugins/hf_dflash.py | 86 ++++++++----------- .../torch/speculative/plugins/hf_eagle.py | 11 ++- .../speculative/plugins/modeling_dflash.py | 12 +-- 3 files changed, 43 insertions(+), 66 deletions(-) diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py 
b/modelopt/torch/speculative/plugins/hf_dflash.py index 46c96d2d30..55b5e81490 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -50,12 +50,10 @@ lazy rope pattern needed for MLA models. """ -import contextlib import logging import torch import torch.nn.functional as F -from torch.nn import CrossEntropyLoss from transformers import PreTrainedModel from transformers.models.qwen3.configuration_qwen3 import Qwen3Config as _Qwen3Config from transformers.trainer_pt_utils import LabelSmoother @@ -63,12 +61,7 @@ from ..dflash.conversion import DFlashDMRegistry from ..dflash.dflash_model import DFlashModel -from .modeling_dflash import ( # noqa: F401 - DFlashAttention, - DFlashBaseModelOutput, - DFlashModule, - build_target_layer_ids, -) +from .modeling_dflash import DFlashAttention, DFlashModule, build_target_layer_ids # noqa: F401 from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS logger = logging.getLogger(__name__) @@ -121,25 +114,6 @@ def _find_base_model_parts(self): else: raise ValueError(f"Part {name} not found in model") - def _base_model_forward(self, input_ids, attention_mask, freeze=True, labels=None, **kwargs): - """Run the base model forward pass with optional freeze and base-model loss.""" - ctx = torch.no_grad() if freeze else contextlib.nullcontext() - with ctx: - outputs = super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - **kwargs, - ) - base_loss = None - if not freeze and labels is not None: - loss_fct = CrossEntropyLoss() - base_loss = loss_fct( - outputs.logits.view(-1, outputs.logits.shape[-1]), - labels.view(-1), - ) - return outputs, base_loss - def modify(self, config): """Initialize DFlash draft module.""" super().modify(config) @@ -406,16 +380,6 @@ def _compute_loss( return loss, accuracy - def _dflash_base_model_forward( - self, input_ids, attention_mask, freeze=True - ) -> 
DFlashBaseModelOutput: - """Run base model and extract target hidden states for DFlash.""" - outputs, _ = self._base_model_forward(input_ids, attention_mask, freeze=freeze) - # hidden_states[0] is the embedding output; layer i output is at index i+1 - selected = [outputs.hidden_states[lid + 1] for lid in self.target_layer_ids] - target_hidden = torch.cat(selected, dim=-1) - return DFlashBaseModelOutput(target_hidden=target_hidden, logits=outputs.logits) - def forward( self, input_ids=None, @@ -464,10 +428,18 @@ def forward( f"Adjust training_seq_len or use padding." ) - # 1. Run base model → extract target hidden states - base_outputs = self._dflash_base_model_forward( - input_ids, attention_mask, freeze=self.dflash_freeze_base_model - ) + # 1. Run base model → hidden states + # TODO: For co-training the base model, remove no_grad and eval() switch. + with torch.no_grad(): + base_outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + ) + + offset = 1 + selected = [base_outputs.hidden_states[lid + offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) # [B, seq, num_layers * H] # 2. Build loss mask. # When labels are provided (answer_only_loss), they already encode both @@ -497,18 +469,13 @@ def forward( ) full_pos = self._build_position_ids(seq_len, anchor_positions, device) attn_mask = self._build_draft_attention_mask( - seq_len, - anchor_positions, - block_keep_mask, - n_blocks, - base_outputs.target_hidden.dtype, - device, + seq_len, anchor_positions, block_keep_mask, n_blocks, target_hidden.dtype, device ) # 5. Draft forward hidden = self.dflash_module( noise_embedding=noise_embedding, - target_hidden=base_outputs.target_hidden, + target_hidden=target_hidden, position_ids=full_pos, attention_mask=attn_mask, ) @@ -582,14 +549,29 @@ def pseudo_speculative_generate(self, input_ids, steps=1): base_token: Next token from base model [B, 1]. 
draft_tokens: Draft tokens [B, min(steps, block_size-1)] or None if steps < 1. """ - base_outputs = self._dflash_base_model_forward(input_ids, attention_mask=None, freeze=True) - assert base_outputs.logits is not None - base_token = base_outputs.logits[:, -1:, :].argmax(dim=-1).to(input_ids.device) + # Call the base model's inner model directly (avoids DynamicModule dispatch) + model_output = self._base_model( + input_ids=input_ids, + output_hidden_states=True, + ) + # Compute logits via lm_head + base_logits = self._base_model_lm_head(model_output.last_hidden_state) + # Build output with hidden_states + base_outputs = ModelOutput( + logits=base_logits, + hidden_states=model_output.hidden_states, + ) + base_logits = base_outputs.logits + base_token = base_logits[:, -1:, :].argmax(dim=-1).to(input_ids.device) if steps < 1: return base_token, None - target_hidden = base_outputs.target_hidden + # Extract target hidden states (raw, before FC projection) + hid_offset = 1 + selected = [base_outputs.hidden_states[lid + hid_offset] for lid in self.target_layer_ids] + target_hidden = torch.cat(selected, dim=-1) + block_size = self.dflash_block_size bsz = input_ids.shape[0] device = input_ids.device diff --git a/modelopt/torch/speculative/plugins/hf_eagle.py b/modelopt/torch/speculative/plugins/hf_eagle.py index d2af52a3e8..0c15439174 100644 --- a/modelopt/torch/speculative/plugins/hf_eagle.py +++ b/modelopt/torch/speculative/plugins/hf_eagle.py @@ -85,20 +85,25 @@ def _nvtx_range(self, name): def _find_base_model_parts(self): """Find model parts from different models and set base_{part}_path attributes.""" - for name, paths in { + base_model_parts_mapping = { "base_model_path": _BASE_MODEL_PATHS, "base_model_embeddings_path": _EMBED_TOKENS_PATHS, "base_model_lm_head_path": _LM_HEAD_PATHS, - }.items(): + } + + for name, paths in base_model_parts_mapping.items(): + found_submodule = False for path in paths: try: submodule = self.get_submodule(path) assert 
isinstance(submodule, torch.nn.Module) + print(f"Found {name} at {path}") + found_submodule = True setattr(self, name, path) break except Exception: continue - else: + if not found_submodule: raise ValueError(f"Part {name} not found in model") def _activate_torch_compile(self): diff --git a/modelopt/torch/speculative/plugins/modeling_dflash.py b/modelopt/torch/speculative/plugins/modeling_dflash.py index 4cb8684b66..7cb614d1e0 100644 --- a/modelopt/torch/speculative/plugins/modeling_dflash.py +++ b/modelopt/torch/speculative/plugins/modeling_dflash.py @@ -20,8 +20,6 @@ The draft architecture is independent of the target model. """ -from dataclasses import dataclass - import torch from torch import nn from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS @@ -32,15 +30,7 @@ ) from transformers.models.qwen3.modeling_qwen3 import rotate_half as _rotate_half -__all__ = ["DFlashBaseModelOutput", "DFlashModule", "build_target_layer_ids"] - - -@dataclass -class DFlashBaseModelOutput: - """Output container for base model forward pass in DFlash training.""" - - target_hidden: torch.Tensor # concatenated hidden states from target layers [B, seq, N*H] - logits: torch.Tensor | None = None # base model logits [B, seq, vocab] +__all__ = ["DFlashModule", "build_target_layer_ids"] def build_target_layer_ids(num_target_layers, num_draft_layers): From 4eaafb23e69657642edaf6f884be933a79284893 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sun, 19 Apr 2026 23:30:24 +0000 Subject: [PATCH 3/3] fix typo: tokenzier -> tokenizer Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- modelopt/torch/speculative/plugins/modeling_eagle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/plugins/modeling_eagle.py b/modelopt/torch/speculative/plugins/modeling_eagle.py index b9df4eba1a..223f86f021 100644 --- a/modelopt/torch/speculative/plugins/modeling_eagle.py +++ 
b/modelopt/torch/speculative/plugins/modeling_eagle.py @@ -46,7 +46,7 @@ def __init__(self, config, decoder_layer_cls, bias=False): "EAGLE module's vocab size should be <= base model vocab size!" ) # Initialize the buffers to zero. - # Their values depend on specific tokenzier and calibrate dataset, and should be set in training script. + # Their values depend on specific tokenizer and calibration dataset, and should be set in training script. if config.draft_vocab_size < config.vocab_size: self.register_buffer("d2t", torch.zeros(config.draft_vocab_size, dtype=torch.int64)) self.lm_head = nn.Linear(