From f208109848b0b010d817a7a81221d175c933601d Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sun, 19 Apr 2026 21:51:01 +0000 Subject: [PATCH 1/2] [Feat]: Offline DFlash training - Add `dflash_offline` config flag for training from pre-computed hidden states; deletes base model layers to save memory. - Move `dflash_mask_token_id` auto-detection from `main.py` into `DFlashConfig` Pydantic validators; derive `dflash_offline` from `data_args.offline_data_path`. - Add `DFlashBaseModelOutput.from_offline_dict` classmethod for consuming pre-computed hidden states in the forward path. Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- examples/speculative_decoding/main.py | 17 ++------ modelopt/torch/speculative/config.py | 38 ++++++++++++++++++ .../torch/speculative/dflash/dflash_model.py | 1 + .../torch/speculative/plugins/hf_dflash.py | 40 ++++++++++++------- .../speculative/plugins/modeling_dflash.py | 8 ++++ 5 files changed, 77 insertions(+), 27 deletions(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index efc4ba82bd..6002a397ab 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -49,7 +49,7 @@ import modelopt.torch.opt as mto import modelopt.torch.speculative as mtsp -from modelopt.torch.speculative.config import EagleConfig +from modelopt.torch.speculative.config import DFlashConfig, EagleConfig from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading from modelopt.torch.utils import print_rank_0 @@ -303,18 +303,9 @@ def train(): model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True) print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.") elif training_args.mode == "dflash": - # Auto-detect mask_token_id from tokenizer if not set - if not dflash_cfg.get("dflash_mask_token_id"): - if tokenizer.mask_token_id is not None: - dflash_cfg["dflash_mask_token_id"] = tokenizer.mask_token_id - print_rank_0( - f"Auto-detected mask_token_id={tokenizer.mask_token_id} from tokenizer" - ) - else: - raise ValueError( - "mask_token_id not found in tokenizer and not set in config. " - "Set dflash.dflash_mask_token_id in the training YAML." - ) + dflash_cfg = DFlashConfig.model_validate( + dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args} + ).model_dump() mtsp.convert(model, [("dflash", dflash_cfg)]) else: raise Exception(f"{training_args.mode} is not supported!") diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 09592583d1..a7691c68c3 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -67,6 +67,11 @@ def _get_dflash_default_config(): class DFlashConfig(ModeloptBaseConfig): """DFlash config for block-wise parallel speculative decoding.""" + dflash_offline: bool = ModeloptField( + default=False, + description="Whether to use detached DFlash (offline training from pre-computed hidden states).", + ) + dflash_block_size: int = ModeloptField( default=8, description="Block size for parallel prediction. Draft predicts this many tokens per block.", @@ -110,6 +115,39 @@ class DFlashConfig(ModeloptBaseConfig): description="Whether to use torch.compile on DFlash forward/loss methods.", ) + @model_validator(mode="before") + @classmethod + def _derive_dflash_offline(cls, data: Any, info: ValidationInfo) -> Any: + """Derive ``dflash_offline`` from ``data_args.offline_data_path`` when provided in context.""" + ctx = info.context if info.context else {} + data_args = ctx.get("data_args") + if data_args is not None and isinstance(data, dict): + data["dflash_offline"] = data_args.offline_data_path is not None + return data + + @model_validator(mode="before") + @classmethod + def _resolve_mask_token_id(cls, data: Any, info: ValidationInfo) -> Any: + """Auto-detect ``dflash_mask_token_id`` from tokenizer when provided in context.""" + if not isinstance(data, dict) or data.get("dflash_mask_token_id") is not None: + return data + ctx = info.context if info.context else {} + tokenizer = ctx.get("tokenizer") + if tokenizer is not None and getattr(tokenizer, "mask_token_id", None) is not None: + data["dflash_mask_token_id"] = tokenizer.mask_token_id + return data + + @model_validator(mode="after") + def _check_mask_token_id(self) -> "DFlashConfig": + """Validate that mask_token_id is set after all resolution attempts.""" + if self.dflash_mask_token_id is None: + raise ValueError( + "dflash_mask_token_id is required. Set it in the config YAML " + "(dflash.dflash_mask_token_id=TOKEN_ID) or ensure the tokenizer " + "has a mask_token_id attribute." + ) + return self + class MedusaConfig(ModeloptBaseConfig): """Medusa config.""" diff --git a/modelopt/torch/speculative/dflash/dflash_model.py b/modelopt/torch/speculative/dflash/dflash_model.py index 0a10f065eb..a99e93c816 100644 --- a/modelopt/torch/speculative/dflash/dflash_model.py +++ b/modelopt/torch/speculative/dflash/dflash_model.py @@ -27,6 +27,7 @@ def _setup(self): def modify(self, config): """Base DFlash Model modify function. Child class should implement the details.""" + self.dflash_offline = config.dflash_offline self.dflash_block_size = config.dflash_block_size self.dflash_freeze_base_model = config.dflash_freeze_base_model self.dflash_loss_decay_factor = config.dflash_loss_decay_factor diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index 46c96d2d30..ecde8a3426 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -181,20 +181,17 @@ def modify(self, config): self.dflash_config.block_size = self.dflash_block_size # Target layer IDs - num_target_layers = base_config.num_hidden_layers + num_target_layers = ( + base_config.num_orig_hidden_layers + if self.dflash_offline + else base_config.num_hidden_layers + ) num_draft_layers = self.dflash_config.num_hidden_layers self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers) self.dflash_config.target_layer_ids = self.target_layer_ids - # mask_token_id: set in DFlashConfig (or auto-detected by main.py from tokenizer) - mask_id = config.dflash_mask_token_id - if mask_id is None: - raise ValueError( - "dflash_mask_token_id is required. Set it in the config YAML " - "(dflash.dflash_mask_token_id=TOKEN_ID) or let main.py auto-detect " - "from tokenizer.mask_token_id." - ) - self.mask_token_id = mask_id + # mask_token_id: validated by DFlashConfig, auto-detected from tokenizer context + self.mask_token_id = config.dflash_mask_token_id logger.info("DFlash mask_token_id: %s", self.mask_token_id) # Freeze base model @@ -207,10 +204,17 @@ def modify(self, config): self.dflash_module = DFlashModule(self.dflash_config) # Match base model dtype/device. Skip if base is on meta (during from_pretrained # restore — the model will be moved to the correct device after weight loading). - base_device = next(self._base_model.layers[-1].parameters()).device + if self.dflash_offline: + base_device = self._base_model_lm_head.weight.device + else: + base_device = next(self._base_model.layers[-1].parameters()).device if base_device.type != "meta": self.dflash_module.to(self._base_model.dtype).to(base_device) + # Delete base model layers for offline training (save memory) + if self.dflash_offline: + self._base_model._modules.pop("layers") + self.is_quantized = False self._num_anchors = self.dflash_num_anchors @@ -465,9 +469,17 @@ def forward( ) # 1. Run base model → extract target hidden states - base_outputs = self._dflash_base_model_forward( - input_ids, attention_mask, freeze=self.dflash_freeze_base_model - ) + if self.dflash_offline: + assert "base_model_outputs" in kwargs + base_outputs = DFlashBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"]) + if base_outputs.logits is None and self.dflash_self_logit_distillation: + # Compute logits from last-layer hidden states for KD loss + out_hiddens = kwargs["base_model_outputs"].get("base_model_hidden_states") + base_outputs.logits = self._base_model_lm_head(out_hiddens) + else: + base_outputs = self._dflash_base_model_forward( + input_ids, attention_mask, freeze=self.dflash_freeze_base_model + ) # 2. Build loss mask. # When labels are provided (answer_only_loss), they already encode both diff --git a/modelopt/torch/speculative/plugins/modeling_dflash.py b/modelopt/torch/speculative/plugins/modeling_dflash.py index 4cb8684b66..87db08e49d 100644 --- a/modelopt/torch/speculative/plugins/modeling_dflash.py +++ b/modelopt/torch/speculative/plugins/modeling_dflash.py @@ -42,6 +42,14 @@ class DFlashBaseModelOutput: target_hidden: torch.Tensor # concatenated hidden states from target layers [B, seq, N*H] logits: torch.Tensor | None = None # base model logits [B, seq, vocab] + @classmethod + def from_offline_dict(cls, d: dict): + """Construct from a dict of pre-computed base model outputs (offline training).""" + return cls( + target_hidden=d.get("aux_hidden_states"), + logits=d.get("base_model_logits"), + ) + def build_target_layer_ids(num_target_layers, num_draft_layers): """Select layers uniformly from the target model for feature extraction.""" From 9e4eeb0969204471c430411289de1b7526bfb925 Mon Sep 17 00:00:00 2001 From: h-guo18 <67671475+h-guo18@users.noreply.github.com> Date: Sun, 19 Apr 2026 21:30:33 +0000 Subject: [PATCH 2/2] [Refactor]: HFSpecDecMixin shared across HF spec-decoding plugins Extract duplicated base-model discovery, forward pass, NVTX profiling, and torch.compile logic from HFEagleModel / HFDFlashModel into a shared mixin (hf_spec_mixin.py). HFEagleModel and HFDFlashModel now inherit from (HFSpecDecMixin, EagleModel/DFlashModel). Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com> --- .../torch/speculative/plugins/hf_dflash.py | 66 +------ .../torch/speculative/plugins/hf_eagle.py | 78 +------- .../speculative/plugins/hf_spec_mixin.py | 175 ++++++++++++++++++ 3 files changed, 187 insertions(+), 132 deletions(-) create mode 100644 modelopt/torch/speculative/plugins/hf_spec_mixin.py diff --git a/modelopt/torch/speculative/plugins/hf_dflash.py b/modelopt/torch/speculative/plugins/hf_dflash.py index ecde8a3426..7aec1187d1 100644 --- a/modelopt/torch/speculative/plugins/hf_dflash.py +++ b/modelopt/torch/speculative/plugins/hf_dflash.py @@ -50,12 +50,10 @@ lazy rope pattern needed for MLA models. """ -import contextlib import logging import torch import torch.nn.functional as F -from torch.nn import CrossEntropyLoss from transformers import PreTrainedModel from transformers.models.qwen3.configuration_qwen3 import Qwen3Config as _Qwen3Config from transformers.trainer_pt_utils import LabelSmoother @@ -63,13 +61,13 @@ from ..dflash.conversion import DFlashDMRegistry from ..dflash.dflash_model import DFlashModel +from .hf_spec_mixin import HFSpecDecMixin from .modeling_dflash import ( # noqa: F401 DFlashAttention, DFlashBaseModelOutput, DFlashModule, build_target_layer_ids, ) -from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS logger = logging.getLogger(__name__) @@ -77,69 +75,9 @@ @DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) -class HFDFlashModel(DFlashModel): +class HFDFlashModel(HFSpecDecMixin, DFlashModel): """DFlash Model for HuggingFace transformers.""" - @property - def _base_model(self): - return self.get_submodule(self.base_model_path) - - @property - def _base_model_embeddings(self): - return self.get_submodule(self.base_model_embeddings_path) - - @property - def _base_model_lm_head(self): - return self.get_submodule(self.base_model_lm_head_path) - - @property - def _base_llm_config(self): - return ( - getattr(self.config, "text_config", None) - or getattr(self.config, "llm_config", None) - or self.config - ) - - def _find_base_model_parts(self): - """Locate base model submodules (backbone, embeddings, lm_head) by probing known paths. - - Reuses the shared path constants from modeling_fakebase (same as EAGLE). - """ - for name, paths in { - "base_model_path": _BASE_MODEL_PATHS, - "base_model_embeddings_path": _EMBED_TOKENS_PATHS, - "base_model_lm_head_path": _LM_HEAD_PATHS, - }.items(): - for path in paths: - try: - submodule = self.get_submodule(path) - assert isinstance(submodule, torch.nn.Module) - setattr(self, name, path) - break - except Exception: - continue - else: - raise ValueError(f"Part {name} not found in model") - - def _base_model_forward(self, input_ids, attention_mask, freeze=True, labels=None, **kwargs): - """Run the base model forward pass with optional freeze and base-model loss.""" - ctx = torch.no_grad() if freeze else contextlib.nullcontext() - with ctx: - outputs = super().forward( - input_ids=input_ids, - attention_mask=attention_mask, - output_hidden_states=True, - **kwargs, - ) - base_loss = None - if not freeze and labels is not None: - loss_fct = CrossEntropyLoss() - base_loss = loss_fct( - outputs.logits.view(-1, outputs.logits.shape[-1]), - labels.view(-1), - ) - return outputs, base_loss - def modify(self, config): """Initialize DFlash draft module.""" super().modify(config) diff --git a/modelopt/torch/speculative/plugins/hf_eagle.py b/modelopt/torch/speculative/plugins/hf_eagle.py index d2af52a3e8..ef7ba8d6c0 100644 --- a/modelopt/torch/speculative/plugins/hf_eagle.py +++ b/modelopt/torch/speculative/plugins/hf_eagle.py @@ -36,8 +36,8 @@ get_ttt_msk_func, temporary_set_config_value, ) +from .hf_spec_mixin import HFSpecDecMixin from .modeling_eagle import EagleBaseModelOutput, EagleModule -from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS __all__ = ["HFARValidation", "HFEagleModel"] @@ -47,75 +47,14 @@ @EagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) -class HFEagleModel(EagleModel): +class HFEagleModel(HFSpecDecMixin, EagleModel): """Eagle Model Class for huggingface models.""" - @property - def _base_model(self): - return self.get_submodule(self.base_model_path) - - @property - def _base_model_embeddings(self): - return self.get_submodule(self.base_model_embeddings_path) - - @property - def _base_model_lm_head(self): - return self.get_submodule(self.base_model_lm_head_path) - - @property - def _base_llm_config(self): - """Return the llm config for the base model, from LLM or VLM.""" - return ( - getattr(self.config, "text_config", None) - or getattr(self.config, "llm_config", None) - or self.config - ) - - def _nvtx_range(self, name): - """Optionally create an NVTX range for the given name when config.eagle_enable_nvtx is set.""" - if not self.eagle_enable_nvtx: - return contextlib.nullcontext() - try: - import torch.cuda.nvtx as nvtx - - return nvtx.range(name) - except Exception as e: - print(f"Failed to create NVTX range {name}: {e}") - return contextlib.nullcontext() - - def _find_base_model_parts(self): - """Find model parts from different models and set base_{part}_path attributes.""" - for name, paths in { - "base_model_path": _BASE_MODEL_PATHS, - "base_model_embeddings_path": _EMBED_TOKENS_PATHS, - "base_model_lm_head_path": _LM_HEAD_PATHS, - }.items(): - for path in paths: - try: - submodule = self.get_submodule(path) - assert isinstance(submodule, torch.nn.Module) - setattr(self, name, path) - break - except Exception: - continue - else: - raise ValueError(f"Part {name} not found in model") - - def _activate_torch_compile(self): - import torch._dynamo - - torch._dynamo.config.suppress_errors = True # Allow fallback to eager mode - - compile_targets = [ - ("_prepare_eagle_inputs", {}), - ("_eagle_forward", {"mode": "max-autotune"}), - ("_eagle_loss", {"fullgraph": True}), - ] - for name, kwargs in compile_targets: - try: - setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs)) - except Exception: # noqa: PERF203 - print(f"Disabling torch.compile for {name} due to compilation error.") + _compile_targets = [ + ("_prepare_eagle_inputs", {}), + ("_eagle_forward", {"mode": "max-autotune"}), + ("_eagle_loss", {"fullgraph": True}), + ] def get_dummy_inputs(self) -> dict: """Construct dummy inputs for export forward pass.""" @@ -285,6 +224,9 @@ def modify( if self.eagle_config._attn_implementation is None: self.eagle_config._attn_implementation = "sdpa" + # Mixin interface attribute + self._enable_nvtx = self.eagle_enable_nvtx + # Set default aux_hidden_state layers if ( self.eagle_config.use_aux_hidden_state diff --git a/modelopt/torch/speculative/plugins/hf_spec_mixin.py b/modelopt/torch/speculative/plugins/hf_spec_mixin.py new file mode 100644 index 0000000000..ea90849c19 --- /dev/null +++ b/modelopt/torch/speculative/plugins/hf_spec_mixin.py @@ -0,0 +1,175 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Shared mixin for HuggingFace speculative decoding model classes.""" + +# mypy: disable-error-code="attr-defined,misc" + +import contextlib + +import torch +from torch.nn import CrossEntropyLoss + +from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS + +__all__ = ["HFSpecDecMixin"] + + +class HFSpecDecMixin: + """Mixin providing HuggingFace base-model discovery for speculative decoding plugins. + + Provides shared properties and methods for locating base-model submodules + (backbone, embeddings, lm_head) and running the base-model forward pass. + + Must be used with multiple inheritance alongside an algorithm-specific base + (EagleModel, DFlashModel, etc.) that inherits from DynamicModule. + + Example:: + + @EagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"}) + class HFEagleModel(HFSpecDecMixin, EagleModel): ... + """ + + # -- Class attributes (subclasses may override) -- + + # List of (method_name, compile_kwargs) for _activate_torch_compile(). + # Example: [("_eagle_forward", {"mode": "max-autotune"}), ("_eagle_loss", {"fullgraph": True})] + _compile_targets: list[tuple[str, dict]] = [] + + # -- Properties: base model access -- + + @property + def _base_model(self): + return self.get_submodule(self.base_model_path) + + @property + def _base_model_embeddings(self): + return self.get_submodule(self.base_model_embeddings_path) + + @property + def _base_model_lm_head(self): + return self.get_submodule(self.base_model_lm_head_path) + + @property + def _base_llm_config(self): + """Return the LLM config for the base model, handling VLM nesting.""" + return ( + getattr(self.config, "text_config", None) + or getattr(self.config, "llm_config", None) + or self.config + ) + + # -- Methods: model discovery -- + + def _find_base_model_parts(self): + """Find model parts from different models and set base_{part}_path attributes. + + Iterates over candidate submodule paths from modeling_fakebase to locate the + base model backbone, embedding layer, and LM head. + + Raises: + ValueError: If any required model part cannot be found. + """ + for name, paths in { + "base_model_path": _BASE_MODEL_PATHS, + "base_model_embeddings_path": _EMBED_TOKENS_PATHS, + "base_model_lm_head_path": _LM_HEAD_PATHS, + }.items(): + for path in paths: + try: + submodule = self.get_submodule(path) + assert isinstance(submodule, torch.nn.Module) + setattr(self, name, path) + break + except Exception: + continue + else: + raise ValueError(f"Part {name} not found in model") + + # -- Methods: base model forward -- + + def _base_model_forward(self, input_ids, attention_mask, freeze=True, labels=None, **kwargs): + """Run the base model forward pass with optional freeze and base-model loss. + + Args: + input_ids: Input token IDs. + attention_mask: Attention mask. + freeze: If True, run under torch.no_grad(). + labels: Optional labels for computing base model CE loss. + **kwargs: Additional keyword arguments forwarded to the base model. + + Returns: + (outputs, base_loss) tuple where outputs is the raw model output and + base_loss is the cross-entropy loss (None if freeze=True or labels=None). + """ + ctx = torch.no_grad() if freeze else contextlib.nullcontext() + with ctx: + outputs = super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + output_hidden_states=True, + **kwargs, + ) + base_loss = None + if not freeze and labels is not None: + loss_fct = CrossEntropyLoss() + base_loss = loss_fct( + outputs.logits.view(-1, outputs.logits.shape[-1]), + labels.view(-1), + ) + return outputs, base_loss + + # -- Methods: profiling & compilation -- + + def _nvtx_range(self, name): + """Optionally create an NVTX range for profiling. + + Enabled when the subclass sets ``self._enable_nvtx = True`` in ``modify()``. + """ + if not getattr(self, "_enable_nvtx", False): + return contextlib.nullcontext() + try: + import torch.cuda.nvtx as nvtx + + return nvtx.range(name) + except Exception as e: + print(f"Failed to create NVTX range {name}: {e}") + return contextlib.nullcontext() + + def _activate_torch_compile(self): + """Apply ``torch.compile`` to methods listed in ``_compile_targets``. + + Each entry is ``(method_name, extra_kwargs)`` passed to ``torch.compile(..., dynamic=False)``. + Failures fall back to eager mode silently. + """ + import torch._dynamo + + torch._dynamo.config.suppress_errors = True # Allow fallback to eager mode + + for name, kwargs in self._compile_targets: + try: + setattr(self, name, torch.compile(getattr(self, name), dynamic=False, **kwargs)) + except Exception: # noqa: PERF203 + print(f"Disabling torch.compile for {name} due to compilation error.") + + # -- Methods: export interface (subclasses must override) -- + + def get_dummy_inputs(self) -> dict: + """Construct dummy inputs for export forward pass. Subclasses must override.""" + raise NotImplementedError + + def get_exporter(self): + """Return the exporter for the draft model. Subclasses must override.""" + raise NotImplementedError