@@ -240,10 +240,11 @@ def train(self, resume_from_checkpoint: bool = False):
240240 # start_epoch is a float, we need to convert it to an integer
241241 start_epoch = int (start_epoch )
242242 self .global_step = int (latest_checkpoint .split ("-" )[1 ])
243+ need_update_pbar = True
243244 else :
244245 start_epoch = 0
245246 self .global_step = 0
246-
247+ need_update_pbar = False
247248 Logging .info (f"Training with { self .args .num_train_epochs } epochs" )
248249
249250 for epoch in range (start_epoch , self .args .num_train_epochs ):
@@ -253,7 +254,11 @@ def train(self, resume_from_checkpoint: bool = False):
253254 desc = f"Epoch { epoch + 1 } " ,
254255 disable = dist .get_rank () != 0 ,
255256 )
256- pbar .update (self .global_step )
257+ # if the checkpoint is loaded, we need to update the pbar
258+ # but we only need to update the pbar once
259+ if need_update_pbar :
260+ pbar .update (self .global_step )
261+ need_update_pbar = False
257262 for step , batch in enumerate (self .train_dataloader ):
258263 # send batch to device
259264 batch = send_to_device (batch , self .fsdp2_model .device )
@@ -285,7 +290,7 @@ def train(self, resume_from_checkpoint: bool = False):
285290 )
286291 train_metrics ["mfu" ] = round (mfu , 2 )
287292
288- epoch_progress = f"{ self .global_step / self .total_steps :.2f} "
293+ epoch_progress = f"{ self .global_step / self .steps_per_epoch :.2f} "
289294 train_metrics ["epoch" ] = float (epoch_progress )
290295 if rank == 0 :
291296 self .tracking .log (train_metrics )
0 commit comments