Skip to content

Commit 3931e83

Browse files
committed
fix resuming init total tokens and samples
1 parent 6362a64 commit 3931e83

1 file changed

Lines changed: 7 additions & 3 deletions

File tree

xtuner/v1/train/trainer.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1810,11 +1810,15 @@ def _load_checkpoint(self):
18101810
self._cur_step = train_state["cur_step"]
18111811
self._cur_epoch = train_state["cur_epoch"]
18121812

1813-
self._init_total_tokens = train_state.get("total_consumed_tokens", 0) # default 0 for BC
1814-
self._init_total_samples = train_state.get("total_consumed_samples", 0) # default 0 for BC
1815-
18161813
if load_checkpoint_cfg.load_dataset:
18171814
self._train_time_offset = train_state["train_time_offset"]
1815+
self._init_total_tokens = train_state.get("total_consumed_tokens", 0) # default 0 for BC
1816+
# TODO: total_samples 由 Dataloader 维护, 包括 save/resume
1817+
# self._init_total_samples 会影响 save dcp时 dataloader.get_state_dict的状态。
1818+
# 1) 如果加载 dataset,应该恢复_total_consumed_samples为checkpoint中的值。
1819+
# 2) 如果不加载 dataset,应该保持 self._init_total_samples为初始值0,否则如果加载上旧dataloader的total_consumed_samples
1820+
# 会导致存储新dataloader时 total_consumed_samples 是不正确的值。
1821+
self._init_total_samples = train_state.get("total_consumed_samples", 0) # default 0 for BC
18181822

18191823
dataloader_path = resume_from / self._SAVE_DATALOADER_DIR
18201824
self._resume_dataloader(dataloader_path)

0 commit comments

Comments
 (0)