@@ -67,6 +67,11 @@ def _get_dflash_default_config():
6767class DFlashConfig (ModeloptBaseConfig ):
6868 """DFlash config for block-wise parallel speculative decoding."""
6969
70+ dflash_offline : bool = ModeloptField (
71+ default = False ,
72+ description = "Whether to use detached DFlash (offline training from pre-computed hidden states)." ,
73+ )
74+
7075 dflash_block_size : int = ModeloptField (
7176 default = 8 ,
7277 description = "Block size for parallel prediction. Draft predicts this many tokens per block." ,
@@ -110,6 +115,39 @@ class DFlashConfig(ModeloptBaseConfig):
110115 description = "Whether to use torch.compile on DFlash forward/loss methods." ,
111116 )
112117
118+ @model_validator (mode = "before" )
119+ @classmethod
120+ def _derive_dflash_offline (cls , data : Any , info : ValidationInfo ) -> Any :
121+ """Derive ``dflash_offline`` from ``data_args.offline_data_path`` when provided in context."""
122+ ctx = info .context if info .context else {}
123+ data_args = ctx .get ("data_args" )
124+ if data_args is not None and isinstance (data , dict ):
125+ data ["dflash_offline" ] = data_args .offline_data_path is not None
126+ return data
127+
128+ @model_validator (mode = "before" )
129+ @classmethod
130+ def _resolve_mask_token_id (cls , data : Any , info : ValidationInfo ) -> Any :
131+ """Auto-detect ``dflash_mask_token_id`` from tokenizer when provided in context."""
132+ if not isinstance (data , dict ) or data .get ("dflash_mask_token_id" ) is not None :
133+ return data
134+ ctx = info .context if info .context else {}
135+ tokenizer = ctx .get ("tokenizer" )
136+ if tokenizer is not None and getattr (tokenizer , "mask_token_id" , None ) is not None :
137+ data ["dflash_mask_token_id" ] = tokenizer .mask_token_id
138+ return data
139+
140+ @model_validator (mode = "after" )
141+ def _check_mask_token_id (self ) -> "DFlashConfig" :
142+ """Validate that mask_token_id is set after all resolution attempts."""
143+ if self .dflash_mask_token_id is None :
144+ raise ValueError (
145+ "dflash_mask_token_id is required. Set it in the config YAML "
146+ "(dflash.dflash_mask_token_id=TOKEN_ID) or ensure the tokenizer "
147+ "has a mask_token_id attribute."
148+ )
149+ return self
150+
113151
114152class MedusaConfig (ModeloptBaseConfig ):
115153 """Medusa config."""
0 commit comments