Skip to content

Commit 8719802

Browse files
committed
polish
Signed-off-by: h-guo18 <67671475+h-guo18@users.noreply.github.com>
1 parent d9fe3d7 commit 8719802

2 files changed

Lines changed: 15 additions & 41 deletions

File tree

examples/speculative_decoding/main.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@
4949
import modelopt.torch.speculative as mtsp
5050
from modelopt.recipe import load_recipe
5151
from modelopt.recipe.config import ModelOptDFlashRecipe, ModelOptEagleRecipe, ModelOptMedusaRecipe
52-
from modelopt.torch.speculative.config import DFlashConfig
5352
from modelopt.torch.speculative.utils import load_vlm_or_llm, patch_transformers5_params_loading
5453
from modelopt.torch.utils import print_rank_0
5554
from modelopt.torch.utils.distributed import is_master
@@ -159,10 +158,15 @@ def train():
159158
# Load draft vocab cache
160159
mtsp.plugins.HFEagleModel.load_draft_vocab_cache(model, recipe.data.draft_vocab_cache)
161160
elif isinstance(recipe, ModelOptDFlashRecipe):
162-
# Re-validate with tokenizer to resolve dflash_mask_token_id and enforce its presence.
163-
dflash_cfg: dict = DFlashConfig.model_validate(
164-
recipe.dflash.model_dump(), context={"tokenizer": tokenizer}
165-
).model_dump()
161+
# Fall back to tokenizer.mask_token_id when not set in the recipe; require one of the two.
162+
if recipe.dflash.dflash_mask_token_id is None:
163+
recipe.dflash.dflash_mask_token_id = getattr(tokenizer, "mask_token_id", None)
164+
if recipe.dflash.dflash_mask_token_id is None:
165+
raise ValueError(
166+
"dflash.dflash_mask_token_id is required: set it in the recipe YAML "
167+
"or use a tokenizer that defines mask_token_id."
168+
)
169+
dflash_cfg: dict = recipe.dflash.model_dump()
166170
mtsp.convert(model, [("dflash", dflash_cfg)])
167171
else:
168172
raise ValueError(f"Unsupported speculative recipe type: {type(recipe).__name__}")

modelopt/torch/speculative/config.py

Lines changed: 6 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,8 @@
1616
"""Configurations for speculative decoding modes."""
1717

1818
from copy import deepcopy
19-
from typing import Any
2019

21-
from pydantic import ValidationInfo, model_validator
20+
from pydantic import model_validator
2221

2322
from modelopt.torch.opt.config import ModeloptBaseConfig, ModeloptField
2423

@@ -102,10 +101,12 @@ class DFlashConfig(ModeloptBaseConfig):
102101
default=True, description="Whether to report eval accuracy."
103102
)
104103

105-
dflash_mask_token_id: int = ModeloptField(
104+
dflash_mask_token_id: int | None = ModeloptField(
106105
default=None,
107-
description="Token ID used for masked (unknown) positions. "
108-
"Set explicitly or auto-detected from tokenizer.mask_token_id in main.py.",
106+
description=(
107+
"Token ID used for masked (unknown) positions. Set explicitly in the recipe YAML, "
108+
"or left unset to fall back to ``tokenizer.mask_token_id`` at training time."
109+
),
109110
)
110111

111112
dflash_architecture_config: dict = ModeloptField(
@@ -117,37 +118,6 @@ class DFlashConfig(ModeloptBaseConfig):
117118
description="Whether to use torch.compile on DFlash forward/loss methods.",
118119
)
119120

120-
@model_validator(mode="before")
121-
@classmethod
122-
def _resolve_mask_token_id(cls, data: Any, info: ValidationInfo) -> Any:
123-
"""Auto-detect ``dflash_mask_token_id`` from tokenizer when provided in context."""
124-
if not isinstance(data, dict) or data.get("dflash_mask_token_id") is not None:
125-
return data
126-
ctx = info.context if info.context else {}
127-
tokenizer = ctx.get("tokenizer")
128-
if tokenizer is not None and getattr(tokenizer, "mask_token_id", None) is not None:
129-
data["dflash_mask_token_id"] = tokenizer.mask_token_id
130-
return data
131-
132-
@model_validator(mode="after")
133-
def _check_mask_token_id(self, info: ValidationInfo) -> "DFlashConfig":
134-
"""Require ``dflash_mask_token_id`` once a tokenizer is available.
135-
136-
Skipped when no tokenizer is in context (e.g., recipe-load time before the tokenizer
137-
is constructed). The caller is expected to re-validate with ``context={"tokenizer": ...}``
138-
once the tokenizer is loaded; that pass enforces the requirement.
139-
"""
140-
ctx = info.context if info.context else {}
141-
if ctx.get("tokenizer") is None:
142-
return self
143-
if self.dflash_mask_token_id is None:
144-
raise ValueError(
145-
"dflash_mask_token_id is required. Set it in the config YAML "
146-
"(dflash.dflash_mask_token_id=TOKEN_ID) or ensure the tokenizer "
147-
"has a mask_token_id attribute."
148-
)
149-
return self
150-
151121

152122
class MedusaConfig(ModeloptBaseConfig):
153123
"""Medusa config."""

0 commit comments

Comments (0)