Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ repos:
modelopt/torch/quantization/plugins/attention.py|
modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py|
modelopt/torch/speculative/eagle/utils.py|
modelopt/torch/speculative/plugins/transformers.py|
modelopt/torch/speculative/plugins/hf_medusa.py|
modelopt/torch/utils/plugins/megatron_mmlu.py|
examples/chained_optimizations/bert_prune_distill_quantize.py|
examples/deepseek/quantize_to_nvfp4.py|
Expand Down
2 changes: 1 addition & 1 deletion examples/speculative_decoding/eagle_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def patched_templated_attn(*args, **kwargs):
original_op = args[2]

# This patch is only enabled for eagle model by context manager, not base model.
patch_enbabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH
patch_enbabled = modelopt.torch.speculative.plugins.hf_eagle.ENABLE_CP_TTT_PATCH

if patch_enbabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")
Expand Down
17 changes: 4 additions & 13 deletions examples/speculative_decoding/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@

import modelopt.torch.opt as mto
import modelopt.torch.speculative as mtsp
from modelopt.torch.speculative.config import EagleConfig
from modelopt.torch.speculative.config import DFlashConfig, EagleConfig
from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
from modelopt.torch.utils import print_rank_0

Expand Down Expand Up @@ -303,18 +303,9 @@ def train():
model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)
print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.")
elif training_args.mode == "dflash":
# Auto-detect mask_token_id from tokenizer if not set
if not dflash_cfg.get("dflash_mask_token_id"):
if tokenizer.mask_token_id is not None:
dflash_cfg["dflash_mask_token_id"] = tokenizer.mask_token_id
print_rank_0(
f"Auto-detected mask_token_id={tokenizer.mask_token_id} from tokenizer"
)
else:
raise ValueError(
"mask_token_id not found in tokenizer and not set in config. "
"Set dflash.dflash_mask_token_id in the training YAML."
)
dflash_cfg = DFlashConfig.model_validate(
dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}
).model_dump()
mtsp.convert(model, [("dflash", dflash_cfg)])
else:
raise Exception(f"{training_args.mode} is not supported!")
Expand Down
2 changes: 1 addition & 1 deletion examples/speculative_decoding/scripts/ar_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from transformers import AutoTokenizer

import modelopt.torch.opt as mto
from modelopt.torch.speculative.plugins.transformers import HFARValidation
from modelopt.torch.speculative.plugins.hf_eagle import HFARValidation
from modelopt.torch.speculative.utils import load_vlm_or_llm

mto.enable_huggingface_checkpointing()
Expand Down
4 changes: 2 additions & 2 deletions modelopt/torch/export/plugins/hf_spec_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ def _export_config(self):
template_config = deepcopy(template_config)

def _get_config_from_draft_or_base(key: str, model: nn.Module):
if getattr(model._draft_model_config, key, None) is not None:
return getattr(model._draft_model_config, key)
if getattr(model.eagle_config, key, None) is not None:
return getattr(model.eagle_config, key)
elif getattr(model.config, key, None) is not None:
return getattr(model.config, key)
else:
Expand Down
38 changes: 38 additions & 0 deletions modelopt/torch/speculative/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ def _get_dflash_default_config():
class DFlashConfig(ModeloptBaseConfig):
"""DFlash config for block-wise parallel speculative decoding."""

dflash_offline: bool = ModeloptField(
default=False,
description="Whether to use detached DFlash (offline training from pre-computed hidden states).",
)

dflash_block_size: int = ModeloptField(
default=8,
description="Block size for parallel prediction. Draft predicts this many tokens per block.",
Expand Down Expand Up @@ -110,6 +115,39 @@ class DFlashConfig(ModeloptBaseConfig):
description="Whether to use torch.compile on DFlash forward/loss methods.",
)

@model_validator(mode="before")
@classmethod
def _derive_dflash_offline(cls, data: Any, info: ValidationInfo) -> Any:
    """Set ``dflash_offline`` from ``data_args.offline_data_path`` found in the validation context.

    Runs before field validation. When the validation context carries a ``data_args``
    entry and the raw input is a dict, ``dflash_offline`` is overwritten in place with
    whether ``data_args.offline_data_path`` is not ``None``. In every other case the
    input is returned untouched.
    """
    context = info.context or {}
    data_args = context.get("data_args")
    # Nothing to derive without data_args, and only dict input can be updated.
    if data_args is None or not isinstance(data, dict):
        return data
    data["dflash_offline"] = data_args.offline_data_path is not None
    return data

@model_validator(mode="before")
@classmethod
def _resolve_mask_token_id(cls, data: Any, info: ValidationInfo) -> Any:
    """Fill in ``dflash_mask_token_id`` from the context tokenizer when it is not already set.

    Runs before field validation. Non-dict input, and dict input that already has a
    non-``None`` ``dflash_mask_token_id``, is returned unchanged. Otherwise, if the
    validation context provides a ``tokenizer`` whose ``mask_token_id`` is not ``None``,
    that value is written into ``data`` in place.
    """
    if not isinstance(data, dict):
        return data
    # An explicitly configured value always wins over auto-detection.
    if data.get("dflash_mask_token_id") is not None:
        return data
    tokenizer = (info.context or {}).get("tokenizer")
    has_mask_token = (
        tokenizer is not None and getattr(tokenizer, "mask_token_id", None) is not None
    )
    if has_mask_token:
        data["dflash_mask_token_id"] = tokenizer.mask_token_id
    return data

@model_validator(mode="after")
def _check_mask_token_id(self) -> "DFlashConfig":
    """Ensure ``dflash_mask_token_id`` ended up set once the model has been built.

    Returns:
        The validated config instance itself.

    Raises:
        ValueError: If ``dflash_mask_token_id`` is still ``None``.
    """
    if self.dflash_mask_token_id is not None:
        return self
    message = (
        "dflash_mask_token_id is required. Set it in the config YAML "
        "(dflash.dflash_mask_token_id=TOKEN_ID) or ensure the tokenizer "
        "has a mask_token_id attribute."
    )
    raise ValueError(message)


class MedusaConfig(ModeloptBaseConfig):
"""Medusa config."""
Expand Down
1 change: 1 addition & 0 deletions modelopt/torch/speculative/dflash/dflash_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def _setup(self):

def modify(self, config):
"""Base DFlash Model modify function. Child class should implement the details."""
self.dflash_offline = config.dflash_offline
self.dflash_block_size = config.dflash_block_size
self.dflash_freeze_base_model = config.dflash_freeze_base_model
self.dflash_loss_decay_factor = config.dflash_loss_decay_factor
Expand Down
4 changes: 0 additions & 4 deletions modelopt/torch/speculative/eagle/default_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
"use_aux_hidden_state": False,
"eagle_aux_hidden_state_layer_ids": [],
"use_mtp_layernorm": False,
"parallel_draft_step": 1,
"parallel_draft_heads_num_layers": 1,
"has_lm_head": False,
"head_dim": 128,
}
Expand Down Expand Up @@ -107,7 +105,5 @@
"use_aux_hidden_state": True,
"eagle_aux_hidden_state_layer_ids": [],
"use_mtp_layernorm": False,
"parallel_draft_step": 1,
"parallel_draft_heads_num_layers": 1,
"has_lm_head": False,
}
5 changes: 3 additions & 2 deletions modelopt/torch/speculative/plugins/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
Please check out the source code of this module for examples of how plugins work and how you can
write your own. Currently, we support plugins for

- :meth:`transformers<modelopt.torch.speculative.plugins.transformers>`
- :meth:`hf_eagle<modelopt.torch.speculative.plugins.hf_eagle>`
"""

from modelopt.torch.utils import import_plugin
Expand All @@ -31,4 +31,5 @@

with import_plugin("transformers"):
from .hf_dflash import *
from .transformers import *
from .hf_eagle import *
from .hf_medusa import *
Loading
Loading