Skip to content

Commit f82a9ef

Browse files
committed
squash: spec-mixin
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 6ded36b commit f82a9ef

16 files changed

Lines changed: 961 additions & 840 deletions

File tree

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ repos:
9999
modelopt/torch/quantization/plugins/attention.py|
100100
modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py|
101101
modelopt/torch/speculative/eagle/utils.py|
102-
modelopt/torch/speculative/plugins/transformers.py|
102+
modelopt/torch/speculative/plugins/hf_medusa.py|
103103
modelopt/torch/utils/plugins/megatron_mmlu.py|
104104
examples/chained_optimizations/bert_prune_distill_quantize.py|
105105
examples/deepseek/quantize_to_nvfp4.py|

examples/speculative_decoding/eagle_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ def patched_templated_attn(*args, **kwargs):
358358
original_op = args[2]
359359

360360
# This patch is only enabled for eagle model by context manager, not base model.
361-
patch_enbabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH
361+
patch_enbabled = modelopt.torch.speculative.plugins.hf_eagle.ENABLE_CP_TTT_PATCH
362362

363363
if patch_enbabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
364364
raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")

examples/speculative_decoding/main.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949

5050
import modelopt.torch.opt as mto
5151
import modelopt.torch.speculative as mtsp
52-
from modelopt.torch.speculative.config import EagleConfig
52+
from modelopt.torch.speculative.config import DFlashConfig, EagleConfig
5353
from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
5454
from modelopt.torch.utils import print_rank_0
5555

@@ -303,18 +303,9 @@ def train():
303303
model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)
304304
print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.")
305305
elif training_args.mode == "dflash":
306-
# Auto-detect mask_token_id from tokenizer if not set
307-
if not dflash_cfg.get("dflash_mask_token_id"):
308-
if tokenizer.mask_token_id is not None:
309-
dflash_cfg["dflash_mask_token_id"] = tokenizer.mask_token_id
310-
print_rank_0(
311-
f"Auto-detected mask_token_id={tokenizer.mask_token_id} from tokenizer"
312-
)
313-
else:
314-
raise ValueError(
315-
"mask_token_id not found in tokenizer and not set in config. "
316-
"Set dflash.dflash_mask_token_id in the training YAML."
317-
)
306+
dflash_cfg = DFlashConfig.model_validate(
307+
dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}
308+
).model_dump()
318309
mtsp.convert(model, [("dflash", dflash_cfg)])
319310
else:
320311
raise Exception(f"{training_args.mode} is not supported!")

examples/speculative_decoding/scripts/ar_validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from transformers import AutoTokenizer
2828

2929
import modelopt.torch.opt as mto
30-
from modelopt.torch.speculative.plugins.transformers import HFARValidation
30+
from modelopt.torch.speculative.plugins.hf_eagle import HFARValidation
3131
from modelopt.torch.speculative.utils import load_vlm_or_llm
3232

3333
mto.enable_huggingface_checkpointing()

modelopt/torch/export/plugins/hf_spec_export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,8 @@ def _export_config(self):
171171
template_config = deepcopy(template_config)
172172

173173
def _get_config_from_draft_or_base(key: str, model: nn.Module):
174-
if getattr(model._draft_model_config, key, None) is not None:
175-
return getattr(model._draft_model_config, key)
174+
if getattr(model.eagle_config, key, None) is not None:
175+
return getattr(model.eagle_config, key)
176176
elif getattr(model.config, key, None) is not None:
177177
return getattr(model.config, key)
178178
else:

modelopt/torch/speculative/config.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def _get_dflash_default_config():
6767
class DFlashConfig(ModeloptBaseConfig):
6868
"""DFlash config for block-wise parallel speculative decoding."""
6969

70+
dflash_offline: bool = ModeloptField(
71+
default=False,
72+
description="Whether to use detached DFlash (offline training from pre-computed hidden states).",
73+
)
74+
7075
dflash_block_size: int = ModeloptField(
7176
default=8,
7277
description="Block size for parallel prediction. Draft predicts this many tokens per block.",
@@ -110,6 +115,39 @@ class DFlashConfig(ModeloptBaseConfig):
110115
description="Whether to use torch.compile on DFlash forward/loss methods.",
111116
)
112117

118+
@model_validator(mode="before")
119+
@classmethod
120+
def _derive_dflash_offline(cls, data: Any, info: ValidationInfo) -> Any:
121+
"""Derive ``dflash_offline`` from ``data_args.offline_data_path`` when provided in context."""
122+
ctx = info.context if info.context else {}
123+
data_args = ctx.get("data_args")
124+
if data_args is not None and isinstance(data, dict):
125+
data["dflash_offline"] = data_args.offline_data_path is not None
126+
return data
127+
128+
@model_validator(mode="before")
129+
@classmethod
130+
def _resolve_mask_token_id(cls, data: Any, info: ValidationInfo) -> Any:
131+
"""Auto-detect ``dflash_mask_token_id`` from tokenizer when provided in context."""
132+
if not isinstance(data, dict) or data.get("dflash_mask_token_id") is not None:
133+
return data
134+
ctx = info.context if info.context else {}
135+
tokenizer = ctx.get("tokenizer")
136+
if tokenizer is not None and getattr(tokenizer, "mask_token_id", None) is not None:
137+
data["dflash_mask_token_id"] = tokenizer.mask_token_id
138+
return data
139+
140+
@model_validator(mode="after")
141+
def _check_mask_token_id(self) -> "DFlashConfig":
142+
"""Validate that mask_token_id is set after all resolution attempts."""
143+
if self.dflash_mask_token_id is None:
144+
raise ValueError(
145+
"dflash_mask_token_id is required. Set it in the config YAML "
146+
"(dflash.dflash_mask_token_id=TOKEN_ID) or ensure the tokenizer "
147+
"has a mask_token_id attribute."
148+
)
149+
return self
150+
113151

114152
class MedusaConfig(ModeloptBaseConfig):
115153
"""Medusa config."""

modelopt/torch/speculative/dflash/dflash_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def _setup(self):
2727

2828
def modify(self, config):
2929
"""Base DFlash Model modify function. Child class should implement the details."""
30+
self.dflash_offline = config.dflash_offline
3031
self.dflash_block_size = config.dflash_block_size
3132
self.dflash_freeze_base_model = config.dflash_freeze_base_model
3233
self.dflash_loss_decay_factor = config.dflash_loss_decay_factor

modelopt/torch/speculative/eagle/default_config.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@
3737
"use_aux_hidden_state": False,
3838
"eagle_aux_hidden_state_layer_ids": [],
3939
"use_mtp_layernorm": False,
40-
"parallel_draft_step": 1,
41-
"parallel_draft_heads_num_layers": 1,
4240
"has_lm_head": False,
4341
"head_dim": 128,
4442
}
@@ -107,7 +105,5 @@
107105
"use_aux_hidden_state": True,
108106
"eagle_aux_hidden_state_layer_ids": [],
109107
"use_mtp_layernorm": False,
110-
"parallel_draft_step": 1,
111-
"parallel_draft_heads_num_layers": 1,
112108
"has_lm_head": False,
113109
}

modelopt/torch/speculative/plugins/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
Please check out the source code of this module for examples of how plugins work and how you can
1919
write your own one. Currently, we support plugins for
2020
21-
- :meth:`transformers<modelopt.torch.speculative.plugins.transformers>`
21+
- :mod:`hf_eagle<modelopt.torch.speculative.plugins.hf_eagle>`
2222
"""
2323

2424
from modelopt.torch.utils import import_plugin
@@ -31,4 +31,5 @@
3131

3232
with import_plugin("transformers"):
3333
from .hf_dflash import *
34-
from .transformers import *
34+
from .hf_eagle import *
35+
from .hf_medusa import *

0 commit comments

Comments
 (0)