@@ -152,10 +152,8 @@ def train():
152152 training_args = HfTrainingArguments (** recipe .training .model_dump ())
153153 init_distributed_env (training_args )
154154
155- if not recipe .data .data_path and not recipe .data .offline_data_path :
156- raise ValueError (
157- "Either data.data_path or data.offline_data_path must be set in the config."
158- )
155+ if recipe .data .mode in ("online" , "streaming" ) and not recipe .data .data_path :
156+ raise ValueError (f"data.mode={ recipe .data .mode !r} requires data.data_path." )
159157 if training_args .cp_size > 1 :
160158 patch_ring_attention_for_ttt ()
161159 # Specific patch to accelerate 1.12.0. Removable after move to 1.13.0
@@ -174,7 +172,7 @@ def train():
174172
175173 checkpoint = training_args .resume_from_checkpoint or last_checkpoint
176174
177- use_offline_training = recipe .data .offline_data_path is not None
175+ use_offline_training = recipe .data .mode != "online"
178176
179177 if checkpoint :
180178 with patch_transformers5_params_loading ():
@@ -249,6 +247,7 @@ def train():
249247 train_len = training_args .training_seq_len ,
250248 answer_only_loss = training_args .answer_only_loss ,
251249 shift_labels = not is_dflash ,
250+ seed = training_args .seed ,
252251 )
253252
254253 callbacks = [EagleTrainingPlot (training_args .ar_validate_steps , training_args .estimate_ar )]
@@ -258,6 +257,13 @@ def train():
258257 and recipe .eagle .eagle_base_lora_warmup_steps > 0
259258 ):
260259 callbacks .append (LoRAWarmupCallback (recipe .eagle .eagle_base_lora_warmup_steps ))
260+ if recipe .data .mode == "streaming" :
261+ # Skip-on-resume happens inside the dataset (no re-fetch from server);
262+ # disable HF Trainer's own data skip so the offset isn't applied twice.
263+ from modelopt .torch .speculative .plugins .hf_streaming_dataset import StreamingResumeCallback
264+
265+ training_args .ignore_data_skip = True
266+ callbacks .append (StreamingResumeCallback ())
261267
262268 trainer = EagleTrainerWithAccLog (
263269 model = model ,
0 commit comments