Skip to content

Commit 6e7e617

Browse files
committed
refine code
1 parent 0580b70 commit 6e7e617

4 files changed

Lines changed: 9 additions & 16 deletions

File tree

xtuner/v1/datasets/dataloader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def get_state_dict(self, consumed_samples: int = -1) -> dict:
5454
"Dataloader.get_state_dict(consumed_samples=...) is deprecated; use the default (-1). "
5555
"Consumed samples are tracked on the sampler."
5656
)
57+
# TODO: remove the `consumed_samples` parameter of `get_dataloader_state` in the next major release
5758
dataloader_state = get_dataloader_state(self, consumed_samples)
5859
return cast(dict, dataloader_state)
5960

xtuner/v1/datasets/preset_sampler.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -192,11 +192,8 @@ def get_state_dict(self, step: int | None = None) -> dict:
192192
}
193193

194194
def load_state_dict(self, state_dict: dict) -> None:
195-
tc = state_dict.get("total_consumed_steps")
196-
if tc is not None:
197-
self._consumed.set_init_from_checkpoint(int(tc))
198-
else:
199-
self._consumed.set_init_from_checkpoint(0)
195+
tc = int(state_dict.get("total_consumed_steps", 0))
196+
self._consumed.set_init_from_checkpoint(tc)
200197
if self.world_size != state_dict.get("world_size"):
201198
logger.warning(
202199
f"PresetSampler: world_size mismatch: checkpoint has "

xtuner/v1/datasets/resume.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from .consumed_steps import apply_old_ckpt_init_steps
77
from .packing import ExpandSoftPackDataset, _LegacySoftPackDataset
8+
from .preset_sampler import PresetSampler
89
from .sampler import LengthGroupedSampler, ParallelSampler
910

1011

@@ -45,7 +46,7 @@ def load_dataloader_state(
4546
state: dict,
4647
train_state_total_consumed_samples: int | None = None,
4748
):
48-
sampler = dataloader.sampler
49+
sampler: ParallelSampler | LengthGroupedSampler | PresetSampler = dataloader.sampler # type: ignore[assignment]
4950
dataset = dataloader.dataset
5051

5152
# The sampler requires `load_state_dict` to restore the training progress since the sampler state will

xtuner/v1/datasets/sampler.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -141,11 +141,8 @@ def load_state_dict(self, state_dict) -> None:
141141
Args:
142142
state_dict (dict): The state of the sampler.
143143
"""
144-
tc = state_dict.get("total_consumed_steps")
145-
if tc is not None:
146-
self._consumed.set_init_from_checkpoint(int(tc))
147-
else:
148-
self._consumed.set_init_from_checkpoint(0)
144+
tc = int(state_dict.get("total_consumed_steps", 0))
145+
self._consumed.set_init_from_checkpoint(tc)
149146
self.epoch = state_dict["epoch"]
150147
self.step = state_dict["step"]
151148

@@ -298,11 +295,8 @@ def load_state_dict(self, state_dict: dict) -> None:
298295
Args:
299296
state_dict (dict): The state of the sampler.
300297
"""
301-
tc = state_dict.get("total_consumed_steps")
302-
if tc is not None:
303-
self._consumed.set_init_from_checkpoint(int(tc))
304-
else:
305-
self._consumed.set_init_from_checkpoint(0)
298+
tc = int(state_dict.get("total_consumed_steps", 0))
299+
self._consumed.set_init_from_checkpoint(tc)
306300
self.epoch = state_dict["epoch"]
307301
self.step = state_dict["step"]
308302

0 commit comments

Comments
 (0)