Skip to content

Commit f91cf9d

Browse files
committed
squash: spec-mixin
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent 361f7e3 commit f91cf9d

File tree

16 files changed

+975
-865
lines changed

16 files changed

+975
-865
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ repos:
9999
modelopt/torch/quantization/plugins/attention.py|
100100
modelopt/torch/sparsity/attention_sparsity/methods/vsa_utils.py|
101101
modelopt/torch/speculative/eagle/utils.py|
102-
modelopt/torch/speculative/plugins/transformers.py|
102+
modelopt/torch/speculative/plugins/hf_medusa.py|
103103
modelopt/torch/utils/plugins/megatron_mmlu.py|
104104
examples/chained_optimizations/bert_prune_distill_quantize.py|
105105
examples/deepseek/quantize_to_nvfp4.py|

examples/speculative_decoding/eagle_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def patched_templated_attn(*args, **kwargs):
259259
original_op = args[2]
260260

261261
# This patch is only enabled for eagle model by context manager, not base model.
262-
patch_enbabled = modelopt.torch.speculative.plugins.transformers.ENABLE_CP_TTT_PATCH
262+
patch_enbabled = modelopt.torch.speculative.plugins.hf_eagle.ENABLE_CP_TTT_PATCH
263263

264264
if patch_enbabled and original_op != torch.ops.aten._scaled_dot_product_cudnn_attention:
265265
raise ValueError(f"CP TTT only supports cudnn attention now. Got: {original_op}")

examples/speculative_decoding/main.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848

4949
import modelopt.torch.opt as mto
5050
import modelopt.torch.speculative as mtsp
51-
from modelopt.torch.speculative.config import EagleConfig
51+
from modelopt.torch.speculative.config import DFlashConfig, EagleConfig
5252
from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
5353
from modelopt.torch.utils import print_rank_0
5454

@@ -300,18 +300,9 @@ def train():
300300
model.eagle_module.d2t = torch.load(data_args.draft_vocab_cache, weights_only=True)
301301
print_rank_0(f"Loaded draft vocab cache from {data_args.draft_vocab_cache}.")
302302
elif training_args.mode == "dflash":
303-
# Auto-detect mask_token_id from tokenizer if not set
304-
if not dflash_cfg.get("dflash_mask_token_id"):
305-
if tokenizer.mask_token_id is not None:
306-
dflash_cfg["dflash_mask_token_id"] = tokenizer.mask_token_id
307-
print_rank_0(
308-
f"Auto-detected mask_token_id={tokenizer.mask_token_id} from tokenizer"
309-
)
310-
else:
311-
raise ValueError(
312-
"mask_token_id not found in tokenizer and not set in config. "
313-
"Set dflash.dflash_mask_token_id in the training YAML."
314-
)
303+
dflash_cfg = DFlashConfig.model_validate(
304+
dflash_cfg, context={"tokenizer": tokenizer, "data_args": data_args}
305+
).model_dump()
315306
mtsp.convert(model, [("dflash", dflash_cfg)])
316307
else:
317308
raise Exception(f"{training_args.mode} is not supported!")

examples/speculative_decoding/scripts/ar_validate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from transformers import AutoTokenizer
2828

2929
import modelopt.torch.opt as mto
30-
from modelopt.torch.speculative.plugins.transformers import HFARValidation
30+
from modelopt.torch.speculative.plugins.hf_eagle import HFARValidation
3131
from modelopt.torch.speculative.utils import load_vlm_or_llm
3232

3333
mto.enable_huggingface_checkpointing()

modelopt/torch/export/plugins/hf_spec_export.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,8 +165,8 @@ def _export_config(self):
165165
template_config = deepcopy(template_config)
166166

167167
def _get_config_from_draft_or_base(key: str, model: nn.Module):
168-
if getattr(model._draft_model_config, key, None) is not None:
169-
return getattr(model._draft_model_config, key)
168+
if getattr(model.eagle_config, key, None) is not None:
169+
return getattr(model.eagle_config, key)
170170
elif getattr(model.config, key, None) is not None:
171171
return getattr(model.config, key)
172172
else:

modelopt/torch/speculative/config.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,11 @@ def _get_dflash_default_config():
6767
class DFlashConfig(ModeloptBaseConfig):
6868
"""DFlash config for block-wise parallel speculative decoding."""
6969

70+
dflash_offline: bool = ModeloptField(
71+
default=False,
72+
description="Whether to use detached DFlash (offline training from pre-computed hidden states).",
73+
)
74+
7075
dflash_block_size: int = ModeloptField(
7176
default=8,
7277
description="Block size for parallel prediction. Draft predicts this many tokens per block.",
@@ -110,6 +115,39 @@ class DFlashConfig(ModeloptBaseConfig):
110115
description="Whether to use torch.compile on DFlash forward/loss methods.",
111116
)
112117

118+
@model_validator(mode="before")
119+
@classmethod
120+
def _derive_dflash_offline(cls, data: Any, info: ValidationInfo) -> Any:
121+
"""Derive ``dflash_offline`` from ``data_args.offline_data_path`` when provided in context."""
122+
ctx = info.context if info.context else {}
123+
data_args = ctx.get("data_args")
124+
if data_args is not None and isinstance(data, dict):
125+
data["dflash_offline"] = data_args.offline_data_path is not None
126+
return data
127+
128+
@model_validator(mode="before")
129+
@classmethod
130+
def _resolve_mask_token_id(cls, data: Any, info: ValidationInfo) -> Any:
131+
"""Auto-detect ``dflash_mask_token_id`` from tokenizer when provided in context."""
132+
if not isinstance(data, dict) or data.get("dflash_mask_token_id") is not None:
133+
return data
134+
ctx = info.context if info.context else {}
135+
tokenizer = ctx.get("tokenizer")
136+
if tokenizer is not None and getattr(tokenizer, "mask_token_id", None) is not None:
137+
data["dflash_mask_token_id"] = tokenizer.mask_token_id
138+
return data
139+
140+
@model_validator(mode="after")
141+
def _check_mask_token_id(self) -> "DFlashConfig":
142+
"""Validate that mask_token_id is set after all resolution attempts."""
143+
if self.dflash_mask_token_id is None:
144+
raise ValueError(
145+
"dflash_mask_token_id is required. Set it in the config YAML "
146+
"(dflash.dflash_mask_token_id=TOKEN_ID) or ensure the tokenizer "
147+
"has a mask_token_id attribute."
148+
)
149+
return self
150+
113151

114152
class MedusaConfig(ModeloptBaseConfig):
115153
"""Medusa config."""

modelopt/torch/speculative/dflash/dflash_model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def _setup(self):
2727

2828
def modify(self, config):
2929
"""Base DFlash Model modify function. Child class should implement the details."""
30+
self.dflash_offline = config.dflash_offline
3031
self.dflash_block_size = config.dflash_block_size
3132
self.dflash_freeze_base_model = config.dflash_freeze_base_model
3233
self.dflash_loss_decay_factor = config.dflash_loss_decay_factor

modelopt/torch/speculative/eagle/default_config.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@
3737
"use_aux_hidden_state": False,
3838
"eagle_aux_hidden_state_layer_ids": [],
3939
"use_mtp_layernorm": False,
40-
"parallel_draft_step": 1,
41-
"parallel_draft_heads_num_layers": 1,
4240
"has_lm_head": False,
4341
"head_dim": 128,
4442
}
@@ -107,7 +105,5 @@
107105
"use_aux_hidden_state": True,
108106
"eagle_aux_hidden_state_layer_ids": [],
109107
"use_mtp_layernorm": False,
110-
"parallel_draft_step": 1,
111-
"parallel_draft_heads_num_layers": 1,
112108
"has_lm_head": False,
113109
}

modelopt/torch/speculative/plugins/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
Please check out the source code of this module for examples of how plugins work and how you can
1919
write your own one. Currently, we support plugins for
2020
21-
- :meth:`transformers<modelopt.torch.speculative.plugins.transformers>`
21+
- :meth:`hf_eagle<modelopt.torch.speculative.plugins.hf_eagle>`
2222
"""
2323

2424
from modelopt.torch.utils import import_plugin
@@ -31,4 +31,5 @@
3131

3232
with import_plugin("transformers"):
3333
from .hf_dflash import *
34-
from .transformers import *
34+
from .hf_eagle import *
35+
from .hf_medusa import *

0 commit comments

Comments (0)