Skip to content

Commit 0580b70

Browse files
committed
_save_dataloader returns total_consumed_samples
1 parent 6ce0ad1 commit 0580b70

4 files changed

Lines changed: 3 additions & 18 deletions

File tree

xtuner/v1/datasets/dataloader.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,12 +57,6 @@ def get_state_dict(self, consumed_samples: int = -1) -> dict:
5757
dataloader_state = get_dataloader_state(self, consumed_samples)
5858
return cast(dict, dataloader_state)
5959

60-
def get_total_consumed_samples(self) -> int:
61-
sampler = self.sampler
62-
if hasattr(sampler, "get_total_consumed_steps"):
63-
return int(sampler.get_total_consumed_steps())
64-
return 0
65-
6660
# __iter__ is inherited from torch.utils.data.DataLoader
6761

6862
# Streaming dataloaders may not have `set_epoch` and `__len__` methods, so we add them here.

xtuner/v1/datasets/preset_sampler.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,6 @@ def __len__(self) -> int:
174174
def set_epoch(self, epoch: int) -> None:
175175
self.epoch = epoch
176176

177-
def get_total_consumed_steps(self) -> int:
178-
return self._consumed.total_for_checkpoint()
179-
180177
def get_state_dict(self, step: int | None = None) -> dict:
181178
# Same convention as :class:`LengthGroupedSampler`: ``step`` is the global pack offset
182179
# (modulo ``total_size``) into ``global_order``, shared across all ranks in the checkpoint.

xtuner/v1/datasets/sampler.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,6 @@ def set_epoch(self, epoch: int) -> None:
135135
"""
136136
self.epoch = epoch
137137

138-
def get_total_consumed_steps(self) -> int:
139-
return self._consumed.total_for_checkpoint()
140-
141138
def load_state_dict(self, state_dict) -> None:
142139
"""Load the sampler state.
143140
@@ -295,9 +292,6 @@ def set_epoch(self, epoch: int) -> None:
295292
"""
296293
self.epoch = epoch
297294

298-
def get_total_consumed_steps(self) -> int:
299-
return self._consumed.total_for_checkpoint()
300-
301295
def load_state_dict(self, state_dict: dict) -> None:
302296
"""Load the sampler state.
303297

xtuner/v1/train/trainer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1129,10 +1129,9 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
11291129
total_consumed_tokens = (
11301130
self._reduce_number_across_rank(self._local_total_consumed_tokens) + self._init_total_tokens
11311131
)
1132-
total_consumed_samples = self._dataloader.get_total_consumed_samples()
11331132

11341133
# Save dataloader
1135-
self._save_dataloader(dataloader_path)
1134+
total_consumed_samples = self._save_dataloader(dataloader_path)
11361135

11371136
DEVICE_MODULE.empty_cache()
11381137

@@ -1211,10 +1210,11 @@ def _maybe_save(self, is_snapshot: bool = False) -> bool:
12111210

12121211
return True
12131212

1214-
def _save_dataloader(self, dataloader_path: Path | str):
1213+
def _save_dataloader(self, dataloader_path: Path | str) -> int:
12151214
dataloader_state = self._dataloader.get_state_dict()
12161215
if self.rank == 0:
12171216
torch.save(dataloader_state, dataloader_path)
1217+
return int(dataloader_state.get("sampler", {}).get("total_consumed_steps", 0))
12181218

12191219
@property
12201220
def work_dir(self) -> Path:

0 commit comments

Comments
 (0)