1616"""Configurations for speculative decoding modes."""
1717
1818from copy import deepcopy
19- from typing import Any
2019
21- from pydantic import ValidationInfo , model_validator
20+ from pydantic import model_validator
2221
2322from modelopt .torch .opt .config import ModeloptBaseConfig , ModeloptField
2423
@@ -102,10 +101,12 @@ class DFlashConfig(ModeloptBaseConfig):
102101 default = True , description = "Whether to report eval accuracy."
103102 )
104103
105- dflash_mask_token_id : int = ModeloptField (
104+ dflash_mask_token_id : int | None = ModeloptField (
106105 default = None ,
107- description = "Token ID used for masked (unknown) positions. "
108- "Set explicitly or auto-detected from tokenizer.mask_token_id in main.py." ,
106+ description = (
107+ "Token ID used for masked (unknown) positions. Set explicitly in the recipe YAML, "
108+ "or left unset to fall back to ``tokenizer.mask_token_id`` at training time."
109+ ),
109110 )
110111
111112 dflash_architecture_config : dict = ModeloptField (
@@ -117,37 +118,6 @@ class DFlashConfig(ModeloptBaseConfig):
117118 description = "Whether to use torch.compile on DFlash forward/loss methods." ,
118119 )
119120
120- @model_validator (mode = "before" )
121- @classmethod
122- def _resolve_mask_token_id (cls , data : Any , info : ValidationInfo ) -> Any :
123- """Auto-detect ``dflash_mask_token_id`` from tokenizer when provided in context."""
124- if not isinstance (data , dict ) or data .get ("dflash_mask_token_id" ) is not None :
125- return data
126- ctx = info .context if info .context else {}
127- tokenizer = ctx .get ("tokenizer" )
128- if tokenizer is not None and getattr (tokenizer , "mask_token_id" , None ) is not None :
129- data ["dflash_mask_token_id" ] = tokenizer .mask_token_id
130- return data
131-
132- @model_validator (mode = "after" )
133- def _check_mask_token_id (self , info : ValidationInfo ) -> "DFlashConfig" :
134- """Require ``dflash_mask_token_id`` once a tokenizer is available.
135-
136- Skipped when no tokenizer is in context (e.g., recipe-load time before the tokenizer
137- is constructed). The caller is expected to re-validate with ``context={"tokenizer": ...}``
138- once the tokenizer is loaded; that pass enforces the requirement.
139- """
140- ctx = info .context if info .context else {}
141- if ctx .get ("tokenizer" ) is None :
142- return self
143- if self .dflash_mask_token_id is None :
144- raise ValueError (
145- "dflash_mask_token_id is required. Set it in the config YAML "
146- "(dflash.dflash_mask_token_id=TOKEN_ID) or ensure the tokenizer "
147- "has a mask_token_id attribute."
148- )
149- return self
150-
151121
152122class MedusaConfig (ModeloptBaseConfig ):
153123 """Medusa config."""