Skip to content

Commit 646108e

Browse files
committed
Fix fsdp progress bar
1 parent 07c32b3 commit 646108e

1 file changed

Lines changed: 8 additions & 3 deletions

File tree

src/lmms_engine/train/fsdp2_trainer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,10 +240,11 @@ def train(self, resume_from_checkpoint: bool = False):
240240
# start_epoch is a float, we need to convert it to an integer
241241
start_epoch = int(start_epoch)
242242
self.global_step = int(latest_checkpoint.split("-")[1])
243+
need_update_pbar = True
243244
else:
244245
start_epoch = 0
245246
self.global_step = 0
246-
247+
need_update_pbar = False
247248
Logging.info(f"Training with {self.args.num_train_epochs} epochs")
248249

249250
for epoch in range(start_epoch, self.args.num_train_epochs):
@@ -253,7 +254,11 @@ def train(self, resume_from_checkpoint: bool = False):
253254
desc=f"Epoch {epoch + 1}",
254255
disable=dist.get_rank() != 0,
255256
)
256-
pbar.update(self.global_step)
257+
# if the checkpoint is loaded, we need to update the pbar
258+
# but we only need to update the pbar once
259+
if need_update_pbar:
260+
pbar.update(self.global_step)
261+
need_update_pbar = False
257262
for step, batch in enumerate(self.train_dataloader):
258263
# send batch to device
259264
batch = send_to_device(batch, self.fsdp2_model.device)
@@ -285,7 +290,7 @@ def train(self, resume_from_checkpoint: bool = False):
285290
)
286291
train_metrics["mfu"] = round(mfu, 2)
287292

288-
epoch_progress = f"{self.global_step / self.total_steps:.2f}"
293+
epoch_progress = f"{self.global_step / self.steps_per_epoch:.2f}"
289294
train_metrics["epoch"] = float(epoch_progress)
290295
if rank == 0:
291296
self.tracking.log(train_metrics)

0 commit comments

Comments
 (0)