@@ -263,7 +263,6 @@ def _init_sft(self, worker_cfg: WorkerConfig):
263263
264264 self ._rollout_step = 0
265265 self ._sft_cur_epoch = 0
266- self ._sft_total_consumed_samples = 0
267266 self ._sft_total_consumed_tokens = 0
268267
269268 if self ._sft_dataloader_config is not None :
@@ -672,15 +671,13 @@ def _fit_sft(self):
672671 time_before_train_step = time .time ()
673672 data_time = time_before_train_step - time_before_get_data
674673 DEVICE_MODULE .reset_peak_memory_stats ()
675- cur_sample_num = len (data_batch )
676674
677675 train_step_info , grad_norm = self ._train_one_step_sft (data_batch )
678676
679677 time_after_train_step = time .time ()
680678 step_time = time_after_train_step - time_before_train_step
681679 step_consumed_tokens = train_step_info ["step_consumed_tokens" ]
682680
683- self ._sft_total_consumed_samples += self ._reduce_number_across_rank (cur_sample_num )
684681 reduced_step_consumed_tokens = self ._reduce_number_across_rank (step_consumed_tokens )
685682 self ._sft_total_consumed_tokens += reduced_step_consumed_tokens
686683
@@ -1391,9 +1388,13 @@ def save(self, checkpoint_path: Path | str, no_save_optimizer: bool = False):
13911388 )
13921389
13931390 # Save sft dataloader
1394- if self .rank == 0 and self . _sft_dataloader is not None :
1391+ if self ._sft_dataloader is not None :
13951392 sft_dataloader_path = checkpoint_path / self ._SAVE_SFT_DATALOADER_DIR
1396- dataloader_state = self ._sft_dataloader .get_state_dict (self ._sft_total_consumed_samples )
1393+ dataloader_state = self ._sft_dataloader .get_state_dict ()
1394+ total_consumed_samples = int (dataloader_state .get ("sampler" , {}).get ("total_consumed_steps" , 0 ))
1395+ if self .rank != 0 :
1396+ return
1397+
13971398 torch .save (dataloader_state , sft_dataloader_path )
13981399
13991400 train_state_path = checkpoint_path / self ._SAVE_SFT_TRAIN_STATE_PATH
@@ -1403,7 +1404,7 @@ def save(self, checkpoint_path: Path | str, no_save_optimizer: bool = False):
14031404 {
14041405 "cur_step" : self ._rollout_step ,
14051406 "cur_epoch" : self ._sft_cur_epoch ,
1406- "total_consumed_samples" : self . _sft_total_consumed_samples ,
1407+ "total_consumed_samples" : total_consumed_samples ,
14071408 "total_consumed_tokens" : self ._sft_total_consumed_tokens ,
14081409 }
14091410 )
@@ -1437,24 +1438,26 @@ def resume(self, load_checkpoint_cfg: LoadCheckpointConfig):
14371438 )
14381439
14391440 # Resume sft dataloader
1440- sft_dataloader_path = resume_from / self ._SAVE_SFT_DATALOADER_DIR
14411441 if self ._sft_dataloader is not None :
1442- if not sft_dataloader_path .exists ():
1443- raise FileNotFoundError (f"Dataloader path { sft_dataloader_path } does not exist." )
1444- dataloader_state = torch .load (sft_dataloader_path , map_location = DEVICE )
1445- self ._sft_dataloader .load_state_dict (dataloader_state )
1446- self .logger .info (f"Resume sft dataloader from { sft_dataloader_path } " )
1447-
14481442 train_state_path = resume_from / self ._SAVE_SFT_TRAIN_STATE_PATH
14491443 if not train_state_path .exists ():
14501444 raise FileNotFoundError (f"Train state path { train_state_path } does not exist." )
14511445 with train_state_path .open ("r" ) as f :
14521446 train_state = json .loads (f .read ())
1453- self ._rollout_step = train_state ["cur_step" ]
1454- self ._sft_cur_epoch = train_state ["cur_epoch" ]
1455- self ._sft_total_consumed_samples = train_state ["total_consumed_samples" ]
1456- self ._sft_total_consumed_tokens = train_state ["total_consumed_tokens" ]
1457- self .logger .info (f"Resume sft train state from { train_state_path } " )
1447+ self ._rollout_step = train_state ["cur_step" ]
1448+ self ._sft_cur_epoch = train_state ["cur_epoch" ]
1449+ self ._sft_total_consumed_tokens = train_state ["total_consumed_tokens" ]
1450+ self .logger .info (f"Resume sft train state from { train_state_path } " )
1451+
1452+ sft_dataloader_path = resume_from / self ._SAVE_SFT_DATALOADER_DIR
1453+ if not sft_dataloader_path .exists ():
1454+ raise FileNotFoundError (f"Dataloader path { sft_dataloader_path } does not exist." )
1455+ dataloader_state = torch .load (sft_dataloader_path , map_location = DEVICE )
1456+ self ._sft_dataloader .load_state_dict (
1457+ dataloader_state ,
1458+ train_state_total_consumed_samples = train_state .get ("total_consumed_samples" , 0 ),
1459+ )
1460+ self .logger .info (f"Resume sft dataloader from { sft_dataloader_path } " )
14581461
14591462 @ray_method
14601463 def ready (self ) -> bool :
0 commit comments