Skip to content

Commit 5c783a7

Browse files
committed
fix RL worker's sft_dataloader save
1 parent 6e7e617 commit 5c783a7

1 file changed

Lines changed: 21 additions & 18 deletions

File tree

xtuner/v1/rl/base/worker.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,6 @@ def _init_sft(self, worker_cfg: WorkerConfig):
263263

264264
self._rollout_step = 0
265265
self._sft_cur_epoch = 0
266-
self._sft_total_consumed_samples = 0
267266
self._sft_total_consumed_tokens = 0
268267

269268
if self._sft_dataloader_config is not None:
@@ -672,15 +671,13 @@ def _fit_sft(self):
672671
time_before_train_step = time.time()
673672
data_time = time_before_train_step - time_before_get_data
674673
DEVICE_MODULE.reset_peak_memory_stats()
675-
cur_sample_num = len(data_batch)
676674

677675
train_step_info, grad_norm = self._train_one_step_sft(data_batch)
678676

679677
time_after_train_step = time.time()
680678
step_time = time_after_train_step - time_before_train_step
681679
step_consumed_tokens = train_step_info["step_consumed_tokens"]
682680

683-
self._sft_total_consumed_samples += self._reduce_number_across_rank(cur_sample_num)
684681
reduced_step_consumed_tokens = self._reduce_number_across_rank(step_consumed_tokens)
685682
self._sft_total_consumed_tokens += reduced_step_consumed_tokens
686683

@@ -1391,9 +1388,13 @@ def save(self, checkpoint_path: Path | str, no_save_optimizer: bool = False):
13911388
)
13921389

13931390
# Save sft dataloader
1394-
if self.rank == 0 and self._sft_dataloader is not None:
1391+
if self._sft_dataloader is not None:
13951392
sft_dataloader_path = checkpoint_path / self._SAVE_SFT_DATALOADER_DIR
1396-
dataloader_state = self._sft_dataloader.get_state_dict(self._sft_total_consumed_samples)
1393+
dataloader_state = self._sft_dataloader.get_state_dict()
1394+
total_consumed_samples = int(dataloader_state.get("sampler", {}).get("total_consumed_steps", 0))
1395+
if self.rank != 0:
1396+
return
1397+
13971398
torch.save(dataloader_state, sft_dataloader_path)
13981399

13991400
train_state_path = checkpoint_path / self._SAVE_SFT_TRAIN_STATE_PATH
@@ -1403,7 +1404,7 @@ def save(self, checkpoint_path: Path | str, no_save_optimizer: bool = False):
14031404
{
14041405
"cur_step": self._rollout_step,
14051406
"cur_epoch": self._sft_cur_epoch,
1406-
"total_consumed_samples": self._sft_total_consumed_samples,
1407+
"total_consumed_samples": total_consumed_samples,
14071408
"total_consumed_tokens": self._sft_total_consumed_tokens,
14081409
}
14091410
)
@@ -1437,24 +1438,26 @@ def resume(self, load_checkpoint_cfg: LoadCheckpointConfig):
14371438
)
14381439

14391440
# Resume sft dataloader
1440-
sft_dataloader_path = resume_from / self._SAVE_SFT_DATALOADER_DIR
14411441
if self._sft_dataloader is not None:
1442-
if not sft_dataloader_path.exists():
1443-
raise FileNotFoundError(f"Dataloader path {sft_dataloader_path} does not exist.")
1444-
dataloader_state = torch.load(sft_dataloader_path, map_location=DEVICE)
1445-
self._sft_dataloader.load_state_dict(dataloader_state)
1446-
self.logger.info(f"Resume sft dataloader from {sft_dataloader_path}")
1447-
14481442
train_state_path = resume_from / self._SAVE_SFT_TRAIN_STATE_PATH
14491443
if not train_state_path.exists():
14501444
raise FileNotFoundError(f"Train state path {train_state_path} does not exist.")
14511445
with train_state_path.open("r") as f:
14521446
train_state = json.loads(f.read())
1453-
self._rollout_step = train_state["cur_step"]
1454-
self._sft_cur_epoch = train_state["cur_epoch"]
1455-
self._sft_total_consumed_samples = train_state["total_consumed_samples"]
1456-
self._sft_total_consumed_tokens = train_state["total_consumed_tokens"]
1457-
self.logger.info(f"Resume sft train state from {train_state_path}")
1447+
self._rollout_step = train_state["cur_step"]
1448+
self._sft_cur_epoch = train_state["cur_epoch"]
1449+
self._sft_total_consumed_tokens = train_state["total_consumed_tokens"]
1450+
self.logger.info(f"Resume sft train state from {train_state_path}")
1451+
1452+
sft_dataloader_path = resume_from / self._SAVE_SFT_DATALOADER_DIR
1453+
if not sft_dataloader_path.exists():
1454+
raise FileNotFoundError(f"Dataloader path {sft_dataloader_path} does not exist.")
1455+
dataloader_state = torch.load(sft_dataloader_path, map_location=DEVICE)
1456+
self._sft_dataloader.load_state_dict(
1457+
dataloader_state,
1458+
train_state_total_consumed_samples=train_state.get("total_consumed_samples", 0),
1459+
)
1460+
self.logger.info(f"Resume sft dataloader from {sft_dataloader_path}")
14581461

14591462
@ray_method
14601463
def ready(self) -> bool:

0 commit comments

Comments
 (0)