1212
1313from xtuner .v1 .utils import get_logger
1414
15+ from .consumed_steps import ConsumedStepsTracker
1516from .jsonl import JsonlDataset
1617from .packing import MLLMPretrainHybridPackDataset , _LegacySoftPackDataset
1718from .preset_pack import PresetPackDataset
@@ -84,6 +85,7 @@ def __init__(
8485 self .epoch = 0
8586 self .step = 0
8687 self .round_up = round_up
88+ self ._consumed = ConsumedStepsTracker (dp_mesh )
8789
8890 if self .round_up :
8991 self .num_samples = math .ceil (len (self .dataset ) / global_batch_size ) * global_batch_size // world_size
@@ -131,12 +133,23 @@ def set_epoch(self, epoch: int) -> None:
131133 """
132134 self .epoch = epoch
133135
def record_consumed_samples(self, n: int) -> None:
    """Report newly consumed data to the sampler's consumed-steps tracker.

    Args:
        n (int): The amount to record. It is handed unchanged to
            ``self._consumed.record``; presumably a sample count, going by
            the method name -- confirm against the tracker implementation.
    """
    tracker = self._consumed
    tracker.record(n)
138+
def get_total_consumed_steps(self) -> int:
    """Return the consumed-steps total reported by the tracker.

    Returns:
        int: Whatever ``self._consumed.total_for_checkpoint()`` yields.
    """
    total = self._consumed.total_for_checkpoint()
    return total
141+
134142 def load_state_dict (self , state_dict ) -> None :
135143 """Load the sampler state.
136144
137145 Args:
138146 state_dict (dict): The state of the sampler.
139147 """
148+ tc = state_dict .get ("total_consumed_steps" )
149+ if tc is not None :
150+ self ._consumed .set_init_from_checkpoint (int (tc ))
151+ else :
152+ self ._consumed .set_init_from_checkpoint (0 )
140153 self .epoch = state_dict ["epoch" ]
141154 self .step = state_dict ["step" ]
142155
@@ -146,12 +159,17 @@ def load_state_dict(self, state_dict) -> None:
146159 f"is different from the current shuffle ({ self .shuffle } )."
147160 )
148161
149- def get_state_dict (self , step : int ):
162+ def get_state_dict (self , step : int | None = None ):
150163 # Attention! Do not set self.step here, or it will cause the next __iter__ to get less samples.
151- step = step % self .total_size
164+ if step is None :
165+ total_consumed = self ._consumed .total_for_checkpoint ()
166+ else :
167+ total_consumed = int (step )
168+ step_mod = total_consumed % self .total_size
152169 return {
153170 "epoch" : self .epoch ,
154- "step" : step ,
171+ "step" : step_mod ,
172+ "total_consumed_steps" : total_consumed ,
155173 "world_size" : self .world_size ,
156174 "shuffle" : self .shuffle ,
157175 "round_up" : self .round_up ,
@@ -233,6 +251,7 @@ def __init__(
233251 assert isinstance (self .max_lengths , (list , tuple , Column , np .ndarray ))
234252
235253 self .global_batch_size = global_batch_size
254+ self ._consumed = ConsumedStepsTracker (dp_mesh )
236255
237256 def __iter__ (self ) -> Iterator [int ]:
238257 """Iterate the indices."""
@@ -275,12 +294,23 @@ def set_epoch(self, epoch: int) -> None:
275294 """
276295 self .epoch = epoch
277296
def record_consumed_samples(self, n: int) -> None:
    """Pass a consumption count on to this sampler's consumed-steps tracker.

    Args:
        n (int): The value forwarded as-is to ``self._consumed.record``.
            Presumably the number of samples just consumed -- verify
            against the tracker implementation.
    """
    consumed_tracker = self._consumed
    consumed_tracker.record(n)
299+
def get_total_consumed_steps(self) -> int:
    """Fetch the tracker's consumed-steps total.

    Returns:
        int: The result of ``self._consumed.total_for_checkpoint()``.
    """
    checkpoint_total = self._consumed.total_for_checkpoint()
    return checkpoint_total
302+
278303 def load_state_dict (self , state_dict : dict ) -> None :
279304 """Load the sampler state.
280305
281306 Args:
282307 state_dict (dict): The state of the sampler.
283308 """
309+ tc = state_dict .get ("total_consumed_steps" )
310+ if tc is not None :
311+ self ._consumed .set_init_from_checkpoint (int (tc ))
312+ else :
313+ self ._consumed .set_init_from_checkpoint (0 )
284314 self .epoch = state_dict ["epoch" ]
285315 self .step = state_dict ["step" ]
286316
@@ -298,17 +328,22 @@ def load_state_dict(self, state_dict: dict) -> None:
298328 )
299329 self .group_size = origin_group_size
300330
301- def get_state_dict (self , step : int ):
331+ def get_state_dict (self , step : int | None = None ):
302332 """Get the sampler state dict.
303333
304334 Returns:
305335 dict: The state of the sampler.
306336 """
307337 # Attention! Do not set self.step here, or it will cause the next __iter__ to get less samples.
308- step = step % self .total_size
338+ if step is None :
339+ total_consumed = self ._consumed .total_for_checkpoint ()
340+ else :
341+ total_consumed = int (step )
342+ step_mod = total_consumed % self .total_size
309343 return {
310344 "epoch" : self .epoch ,
311- "step" : step ,
345+ "step" : step_mod ,
346+ "total_consumed_steps" : total_consumed ,
312347 "world_size" : self .world_size ,
313348 "round_up" : self .round_up ,
314349 "num_samples" : self .num_samples ,
0 commit comments