Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

import modelopt.torch.opt as mto
import modelopt.torch.speculative as mtsp
from modelopt.torch.speculative.config import EagleConfig
from modelopt.torch.speculative.config import DFlashConfig, EagleConfig
from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
from modelopt.torch.utils import print_rank_0

Expand Down Expand Up @@ -303,18 +303,9 @@ def train():
model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)
print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.")
elif training_args.mode == "dflash":
# Auto-detect mask_token_id from tokenizer if not set
if not dflash_cfg.get("dflash_mask_token_id"):
if tokenizer.mask_token_id is not None:
dflash_cfg["dflash_mask_token_id"] = tokenizer.mask_token_id
print_rank_0(
f"Auto-detected mask_token_id={tokenizer.mask_token_id} from tokenizer"
)
else:
raise ValueError(
"mask_token_id not found in tokenizer and not set in config. "
"Set dflash.dflash_mask_token_id in the training YAML."
)
dflash_cfg = DFlashConfig.model_validate(
dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}
).model_dump()
mtsp.convert(model, [("dflash", dflash_cfg)])
else:
raise Exception(f"{training_args.mode} is not supported!")
Expand Down
38 changes: 38 additions & 0 deletions modelopt/torch/speculative/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ def _get_dflash_default_config():
class DFlashConfig(ModeloptBaseConfig):
"""DFlash config for block-wise parallel speculative decoding."""

dflash_offline: bool = ModeloptField(
default=False,
description="Whether to use detached DFlash (offline training from pre-computed hidden states).",
)

dflash_block_size: int = ModeloptField(
default=8,
description="Block size for parallel prediction. Draft predicts this many tokens per block.",
Expand Down Expand Up @@ -110,6 +115,39 @@ class DFlashConfig(ModeloptBaseConfig):
description="Whether to use torch.compile on DFlash forward/loss methods.",
)

@model_validator(mode="before")
@classmethod
def _derive_dflash_offline(cls, data: Any, info: ValidationInfo) -> Any:
    """Derive ``dflash_offline`` from ``data_args.offline_data_path`` when provided in context.

    Offline training is considered active exactly when an offline data path was given.
    """
    context = info.context or {}
    args = context.get("data_args")
    if isinstance(data, dict) and args is not None:
        data["dflash_offline"] = args.offline_data_path is not None
    return data

@model_validator(mode="before")
@classmethod
def _resolve_mask_token_id(cls, data: Any, info: ValidationInfo) -> Any:
    """Auto-detect ``dflash_mask_token_id`` from tokenizer when provided in context."""
    if not isinstance(data, dict):
        return data
    if data.get("dflash_mask_token_id") is not None:
        # An explicitly configured value always wins over auto-detection.
        return data
    tokenizer = (info.context or {}).get("tokenizer")
    # getattr on None safely yields None, covering the missing-tokenizer case.
    mask_id = getattr(tokenizer, "mask_token_id", None)
    if mask_id is not None:
        data["dflash_mask_token_id"] = mask_id
    return data

@model_validator(mode="after")
def _check_mask_token_id(self) -> "DFlashConfig":
    """Validate that mask_token_id is set after all resolution attempts."""
    if self.dflash_mask_token_id is not None:
        return self
    raise ValueError(
        "dflash_mask_token_id is required. Set it in the config YAML "
        "(dflash.dflash_mask_token_id=TOKEN_ID) or ensure the tokenizer "
        "has a mask_token_id attribute."
    )


class MedusaConfig(ModeloptBaseConfig):
"""Medusa config."""
Expand Down
1 change: 1 addition & 0 deletions modelopt/torch/speculative/dflash/dflash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def _setup(self):

def modify(self, config):
"""Base DFlash Model modify function. Child class should implement the details."""
self.dflash_offline = config.dflash_offline
self.dflash_block_size = config.dflash_block_size
self.dflash_freeze_base_model = config.dflash_freeze_base_model
self.dflash_loss_decay_factor = config.dflash_loss_decay_factor
Expand Down
106 changes: 28 additions & 78 deletions modelopt/torch/speculative/plugins/hf_dflash.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,96 +50,34 @@
lazy rope pattern needed for MLA models.
"""

import contextlib
import logging

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers import PreTrainedModel
from transformers.models.qwen3.configuration_qwen3 import Qwen3Config as _Qwen3Config
from transformers.trainer_pt_utils import LabelSmoother
from transformers.utils import ModelOutput

from ..dflash.conversion import DFlashDMRegistry
from ..dflash.dflash_model import DFlashModel
from .hf_spec_mixin import HFSpecDecMixin
from .modeling_dflash import ( # noqa: F401
DFlashAttention,
DFlashBaseModelOutput,
DFlashModule,
build_target_layer_ids,
)
from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS

logger = logging.getLogger(__name__)

__all__ = ["HFDFlashModel"]


@DFlashDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
class HFDFlashModel(DFlashModel):
class HFDFlashModel(HFSpecDecMixin, DFlashModel):
"""DFlash Model for HuggingFace transformers."""

@property
def _base_model(self):
    """Base-model backbone submodule, resolved via ``self.base_model_path``."""
    path = self.base_model_path
    return self.get_submodule(path)

@property
def _base_model_embeddings(self):
    """Input-embedding submodule, resolved via ``self.base_model_embeddings_path``."""
    path = self.base_model_embeddings_path
    return self.get_submodule(path)

@property
def _base_model_lm_head(self):
    """LM-head submodule, resolved via ``self.base_model_lm_head_path``."""
    path = self.base_model_lm_head_path
    return self.get_submodule(path)

@property
def _base_llm_config(self):
    """LLM sub-config: ``text_config``/``llm_config`` for VLMs, else the top-level config."""
    for attr in ("text_config", "llm_config"):
        candidate = getattr(self.config, attr, None)
        # Truthiness test mirrors the original `or` chain: a falsy sub-config falls through.
        if candidate:
            return candidate
    return self.config

def _find_base_model_parts(self):
    """Locate base model submodules (backbone, embeddings, lm_head) by probing known paths.

    Reuses the shared path constants from modeling_fakebase (same as EAGLE).
    Sets ``base_model_path``, ``base_model_embeddings_path`` and
    ``base_model_lm_head_path`` attributes on success.

    Raises:
        ValueError: if any part cannot be found under the known candidate paths.
    """
    probe_table = {
        "base_model_path": _BASE_MODEL_PATHS,
        "base_model_embeddings_path": _EMBED_TOKENS_PATHS,
        "base_model_lm_head_path": _LM_HEAD_PATHS,
    }
    for attr_name, candidate_paths in probe_table.items():
        resolved_path = None
        for path in candidate_paths:
            try:
                module = self.get_submodule(path)
                assert isinstance(module, torch.nn.Module)
            except Exception:
                # Path absent on this architecture; try the next candidate.
                continue
            resolved_path = path
            break
        if resolved_path is None:
            raise ValueError(f"Part {attr_name} not found in model")
        setattr(self, attr_name, resolved_path)

def _base_model_forward(self, input_ids, attention_mask, freeze=True, labels=None, **kwargs):
    """Run the base model forward pass with optional freeze and base-model loss.

    Returns:
        Tuple of ``(outputs, base_loss)``. ``base_loss`` is ``None`` unless the base
        model is trainable (``freeze=False``) and ``labels`` are supplied.
    """
    grad_ctx = torch.no_grad() if freeze else contextlib.nullcontext()
    with grad_ctx:
        outputs = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs,
        )
    base_loss = None
    if labels is not None and not freeze:
        vocab_size = outputs.logits.shape[-1]
        base_loss = CrossEntropyLoss()(
            outputs.logits.view(-1, vocab_size),
            labels.view(-1),
        )
    return outputs, base_loss

def modify(self, config):
"""Initialize DFlash draft module."""
super().modify(config)
Expand Down Expand Up @@ -181,20 +119,17 @@ def modify(self, config):
self.dflash_config.block_size = self.dflash_block_size

# Target layer IDs
num_target_layers = base_config.num_hidden_layers
num_target_layers = (
base_config.num_orig_hidden_layers
if self.dflash_offline
else base_config.num_hidden_layers
)
num_draft_layers = self.dflash_config.num_hidden_layers
self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers)
self.dflash_config.target_layer_ids = self.target_layer_ids

# mask_token_id: set in DFlashConfig (or auto-detected by main.py from tokenizer)
mask_id = config.dflash_mask_token_id
if mask_id is None:
raise ValueError(
"dflash_mask_token_id is required. Set it in the config YAML "
"(dflash.dflash_mask_token_id=TOKEN_ID) or let main.py auto-detect "
"from tokenizer.mask_token_id."
)
self.mask_token_id = mask_id
# mask_token_id: validated by DFlashConfig, auto-detected from tokenizer context
self.mask_token_id = config.dflash_mask_token_id
logger.info("DFlash mask_token_id: %s", self.mask_token_id)

# Freeze base model
Expand All @@ -207,10 +142,17 @@ def modify(self, config):
self.dflash_module = DFlashModule(self.dflash_config)
# Match base model dtype/device. Skip if base is on meta (during from_pretrained
# restore — the model will be moved to the correct device after weight loading).
base_device = next(self._base_model.layers[-1].parameters()).device
if self.dflash_offline:
base_device = self._base_model_lm_head.weight.device
else:
base_device = next(self._base_model.layers[-1].parameters()).device
if base_device.type != "meta":
self.dflash_module.to(self._base_model.dtype).to(base_device)

# Delete base model layers for offline training (save memory)
if self.dflash_offline:
self._base_model._modules.pop("layers")

self.is_quantized = False
self._num_anchors = self.dflash_num_anchors

Expand Down Expand Up @@ -465,9 +407,17 @@ def forward(
)

# 1. Run base model → extract target hidden states
base_outputs = self._dflash_base_model_forward(
input_ids, attention_mask, freeze=self.dflash_freeze_base_model
)
if self.dflash_offline:
assert "base_model_outputs" in kwargs
base_outputs = DFlashBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"])
if base_outputs.logits is None and self.dflash_self_logit_distillation:
# Compute logits from last-layer hidden states for KD loss
out_hiddens = kwargs["base_model_outputs"].get("base_model_hidden_states")
base_outputs.logits = self._base_model_lm_head(out_hiddens)
else:
base_outputs = self._dflash_base_model_forward(
input_ids, attention_mask, freeze=self.dflash_freeze_base_model
)

# 2. Build loss mask.
# When labels are provided (answer_only_loss), they already encode both
Expand Down
78 changes: 10 additions & 68 deletions modelopt/torch/speculative/plugins/hf_eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@
get_ttt_msk_func,
temporary_set_config_value,
)
from .hf_spec_mixin import HFSpecDecMixin
from .modeling_eagle import EagleBaseModelOutput, EagleModule
from .modeling_fakebase import _BASE_MODEL_PATHS, _EMBED_TOKENS_PATHS, _LM_HEAD_PATHS

__all__ = ["HFARValidation", "HFEagleModel"]

Expand All @@ -47,75 +47,14 @@


@EagleDMRegistry.register({PreTrainedModel: "hf.PreTrainedModel"})
class HFEagleModel(EagleModel):
class HFEagleModel(HFSpecDecMixin, EagleModel):
"""Eagle Model Class for huggingface models."""

@property
def _base_model(self):
    """Backbone submodule of the base model (path probed by ``_find_base_model_parts``)."""
    target = self.base_model_path
    return self.get_submodule(target)

@property
def _base_model_embeddings(self):
    """Token-embedding submodule of the base model."""
    target = self.base_model_embeddings_path
    return self.get_submodule(target)

@property
def _base_model_lm_head(self):
    """LM-head submodule of the base model."""
    target = self.base_model_lm_head_path
    return self.get_submodule(target)

@property
def _base_llm_config(self):
    """Return the llm config for the base model, from LLM or VLM."""
    for attr in ("text_config", "llm_config"):
        candidate = getattr(self.config, attr, None)
        if candidate:  # truthiness, matching the original `or` chain
            return candidate
    return self.config

def _nvtx_range(self, name):
    """Optionally create an NVTX range for the given name when config.eagle_enable_nvtx is set.

    Returns a no-op context manager when NVTX is disabled or unavailable.
    """
    if not self.eagle_enable_nvtx:
        return contextlib.nullcontext()
    try:
        import torch.cuda.nvtx as nvtx

        return nvtx.range(name)
    except Exception as e:
        # NVTX is best-effort profiling only; never let it break the forward pass.
        print(f"Failed to create NVTX range {name}: {e}")
        return contextlib.nullcontext()

def _find_base_model_parts(self):
    """Find model parts from different models and set base_{part}_path attributes.

    Probes the shared candidate-path constants from modeling_fakebase and records
    the first path that resolves to a submodule for each part.

    Raises:
        ValueError: if any part cannot be located under the known paths.
    """
    probe_table = {
        "base_model_path": _BASE_MODEL_PATHS,
        "base_model_embeddings_path": _EMBED_TOKENS_PATHS,
        "base_model_lm_head_path": _LM_HEAD_PATHS,
    }
    for attr_name, candidate_paths in probe_table.items():
        resolved_path = None
        for path in candidate_paths:
            try:
                module = self.get_submodule(path)
                assert isinstance(module, torch.nn.Module)
            except Exception:
                # Path absent on this architecture; keep probing.
                continue
            resolved_path = path
            break
        if resolved_path is None:
            raise ValueError(f"Part {attr_name} not found in model")
        setattr(self, attr_name, resolved_path)

def _activate_torch_compile(self):
    """Wrap the hot eagle methods with torch.compile, falling back to eager on failure."""
    import torch._dynamo

    # Allow graph breaks to fall back to eager mode instead of raising.
    torch._dynamo.config.suppress_errors = True

    targets = (
        ("_prepare_eagle_inputs", {}),
        ("_eagle_forward", {"mode": "max-autotune"}),
        ("_eagle_loss", {"fullgraph": True}),
    )
    for method_name, extra_kwargs in targets:
        try:
            compiled = torch.compile(
                getattr(self, method_name), dynamic=False, **extra_kwargs
            )
        except Exception:  # noqa: PERF203
            print(f"Disabling torch.compile for {method_name} due to compilation error.")
        else:
            setattr(self, method_name, compiled)
# Targets for torch.compile: (method name, extra compile kwargs).
# NOTE(review): presumably consumed by HFSpecDecMixin's compile activation
# (the mixin is not visible in this view) — confirm against hf_spec_mixin.py.
_compile_targets = [
    ("_prepare_eagle_inputs", {}),
    ("_eagle_forward", {"mode": "max-autotune"}),
    ("_eagle_loss", {"fullgraph": True}),
]

def get_dummy_inputs(self) -> dict:
"""Construct dummy inputs for export forward pass."""
Expand Down Expand Up @@ -285,6 +224,9 @@ def modify(
if self.eagle_config._attn_implementation is None:
self.eagle_config._attn_implementation = "sdpa"

# Mixin interface attribute
self._enable_nvtx = self.eagle_enable_nvtx

# Set default aux_hidden_state layers
if (
self.eagle_config.use_aux_hidden_state
Expand Down
Loading
Loading