@@ -181,20 +181,17 @@ def modify(self, config):
         self.dflash_config.block_size = self.dflash_block_size
 
         # Target layer IDs
-        num_target_layers = base_config.num_hidden_layers
+        num_target_layers = (
+            base_config.num_orig_hidden_layers
+            if self.dflash_offline
+            else base_config.num_hidden_layers
+        )
         num_draft_layers = self.dflash_config.num_hidden_layers
         self.target_layer_ids = build_target_layer_ids(num_target_layers, num_draft_layers)
         self.dflash_config.target_layer_ids = self.target_layer_ids
 
-        # mask_token_id: set in DFlashConfig (or auto-detected by main.py from tokenizer)
-        mask_id = config.dflash_mask_token_id
-        if mask_id is None:
-            raise ValueError(
-                "dflash_mask_token_id is required. Set it in the config YAML "
-                "(dflash.dflash_mask_token_id=TOKEN_ID) or let main.py auto-detect "
-                "from tokenizer.mask_token_id."
-            )
-        self.mask_token_id = mask_id
+        # mask_token_id: validated by DFlashConfig, auto-detected from tokenizer context
+        self.mask_token_id = config.dflash_mask_token_id
         logger.info("DFlash mask_token_id: %s", self.mask_token_id)
 
         # Freeze base model
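For context, `build_target_layer_ids` maps each draft-model layer to the target-model layer whose hidden states it trains against. The actual implementation lives elsewhere in the repository; the sketch below is only an illustration, assuming the common choice of spacing draft layers evenly across the target stack.

# Hypothetical sketch: not the repository's build_target_layer_ids.
# Assumes draft layers are spread evenly across the target stack.
def build_target_layer_ids(num_target_layers: int, num_draft_layers: int) -> list[int]:
    if num_draft_layers > num_target_layers:
        raise ValueError("draft model cannot be deeper than the target model")
    stride = num_target_layers / num_draft_layers
    # e.g. 32 target layers, 4 draft layers -> [7, 15, 23, 31]
    return [int((i + 1) * stride) - 1 for i in range(num_draft_layers)]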
@@ -207,10 +204,17 @@ def modify(self, config):
         self.dflash_module = DFlashModule(self.dflash_config)
         # Match base model dtype/device. Skip if base is on meta (during from_pretrained
         # restore — the model will be moved to the correct device after weight loading).
-        base_device = next(self._base_model.layers[-1].parameters()).device
+        if self.dflash_offline:
+            base_device = self._base_model_lm_head.weight.device
+        else:
+            base_device = next(self._base_model.layers[-1].parameters()).device
         if base_device.type != "meta":
             self.dflash_module.to(self._base_model.dtype).to(base_device)
 
+        # Delete base model layers for offline training (save memory)
+        if self.dflash_offline:
+            self._base_model._modules.pop("layers")
+
         self.is_quantized = False
         self._num_anchors = self.dflash_num_anchors
 
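The `_modules.pop("layers")` line is the memory optimization that makes the offline path worthwhile: once hidden states come precomputed, the transformer stack is dead weight. A minimal sketch of the effect on a toy module, assuming standard PyTorch semantics (popping a key from `_modules` detaches that submodule from the parent):

# Toy demonstration of the memory-saving pattern above.
import torch.nn as nn

model = nn.ModuleDict({
    "embed": nn.Embedding(1000, 64),
    "layers": nn.Sequential(*[nn.Linear(64, 64) for _ in range(4)]),
    "lm_head": nn.Linear(64, 1000),
})
before = sum(p.numel() for p in model.parameters())
model._modules.pop("layers")  # detach the stack; its params become collectable
after = sum(p.numel() for p in model.parameters())
assert after < before  # only embed + lm_head parameters remain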
@@ -465,9 +469,17 @@ def forward(
         )
 
         # 1. Run base model → extract target hidden states
-        base_outputs = self._dflash_base_model_forward(
-            input_ids, attention_mask, freeze=self.dflash_freeze_base_model
-        )
+        if self.dflash_offline:
+            assert "base_model_outputs" in kwargs
+            base_outputs = DFlashBaseModelOutput.from_offline_dict(kwargs["base_model_outputs"])
+            if base_outputs.logits is None and self.dflash_self_logit_distillation:
+                # Compute logits from last-layer hidden states for KD loss
+                out_hiddens = kwargs["base_model_outputs"].get("base_model_hidden_states")
+                base_outputs.logits = self._base_model_lm_head(out_hiddens)
+        else:
+            base_outputs = self._dflash_base_model_forward(
+                input_ids, attention_mask, freeze=self.dflash_freeze_base_model
+            )
 
         # 2. Build loss mask.
         # When labels are provided (answer_only_loss), they already encode both
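In the offline branch, the trainer passes precomputed base-model activations through `kwargs["base_model_outputs"]`. The real schema of `DFlashBaseModelOutput` is defined elsewhere in the repository; the dataclass below is a hypothetical stand-in showing one way `from_offline_dict` could hydrate it. The `base_model_hidden_states` key is taken from the diff; the `base_model_logits` key and all field names are assumptions.

# Hypothetical stand-in for DFlashBaseModelOutput. Field and key names other
# than "base_model_hidden_states" are assumptions, not the repository's schema.
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class OfflineBaseOutput:
    hidden_states: torch.Tensor            # stacked per-target-layer states
    logits: Optional[torch.Tensor] = None  # optional; recomputed via lm_head if absent

    @classmethod
    def from_offline_dict(cls, d: dict) -> "OfflineBaseOutput":
        return cls(
            hidden_states=d["base_model_hidden_states"],
            logits=d.get("base_model_logits"),  # assumed key; may not be stored
        )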