Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed PyTorch Lightning profiler not capturing dataloader worker initialization time ([#21771](https://github.com/Lightning-AI/pytorch-lightning/issues/21771))

- Fixed `TQDMProgressBar` showing an unknown total (`n/?`) and a missing epoch description for the whole first epoch after resuming from a mid-epoch checkpoint ([#20603](https://github.com/Lightning-AI/pytorch-lightning/issues/20603))

- Fixed `FSDPStrategy` raising `RuntimeError` under PyTorch 2.5+ when `root_device` is CPU, by passing an explicit `torch.device("cpu")` instead of `device_id=None` (relevant only when the GPU-accelerator guard is bypassed) ([#21774](https://github.com/Lightning-AI/pytorch-lightning/pull/21774))

- Fixed non-zero process exits in `CombinedLoader.reset()` with large tensors and persistent spawned workers by avoiding explicit `_shutdown_workers()` calls and relying on iterator cleanup via `del` ([#21708](https://github.com/Lightning-AI/pytorch-lightning/issues/21708))
Expand Down
10 changes: 10 additions & 0 deletions src/lightning/pytorch/callbacks/progress/tqdm_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,16 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
self.train_progress_bar.initial = 0
self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}")

@override
def on_train_batch_start(self, trainer: "pl.Trainer", *_: Any) -> None:
    """Lazily initialize the train bar's total and epoch description.

    When training resumes from a mid-epoch checkpoint, ``on_train_epoch_start``
    is skipped, so the bar would otherwise never receive its total or its
    ``Epoch N`` description. This hook fills both in on the first batch that
    finds the bar without a (truthy) total.

    Args:
        trainer: The trainer driving the loop; its ``current_epoch`` is used
            for the bar description.
        *_: Remaining hook arguments, unused here.

    """
    bar = self.train_progress_bar
    if bar.total:
        # The bar already knows its total; nothing to initialize.
        return

    batches_in_epoch = convert_inf(self.total_train_batches)
    if batches_in_epoch is None:
        # ``convert_inf`` yielded no usable total (presumably an unbounded
        # number of batches -- confirm against the helper); leave the bar as is.
        return

    bar.total = batches_in_epoch
    bar.set_description(f"Epoch {trainer.current_epoch}")

@override
def on_train_batch_end(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -856,3 +856,39 @@ def reset(self, total=None):
assert 2 in val_bar.total_values, (
f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}"
)


def test_tqdm_progress_bar_mid_epoch_resume(tmp_path):
    """Check that a mid-epoch resume still yields a known total and an epoch description on the train bar.

    ``on_train_epoch_start`` is not called when fitting restarts from a checkpoint taken in the middle of an epoch,
    so the bar has to be initialized elsewhere (#20603).

    """
    tqdm_target = "lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm"
    shared_kwargs = {
        "default_root_dir": tmp_path,
        "max_epochs": 1,
        "limit_train_batches": 4,
        "limit_val_batches": 0,
        "enable_model_summary": False,
        "logger": False,
    }
    model = BoringModel()

    # First run: train the full epoch while saving a checkpoint every two steps.
    ckpt_callback = ModelCheckpoint(dirpath=tmp_path, every_n_train_steps=2, save_top_k=-1)
    first_run = Trainer(callbacks=[ckpt_callback, TQDMProgressBar()], **shared_kwargs)
    with mock.patch(tqdm_target, MockTqdm):
        first_run.fit(model)
    resume_path = str(tmp_path / "epoch=0-step=2.ckpt")

    observed_totals = []
    observed_descs = []

    class RecordingBar(TQDMProgressBar):
        """Progress bar that records the train bar's state before each batch-end update."""

        def on_train_batch_end(self, trainer, *args, **kwargs):
            bar = self.train_progress_bar
            observed_totals.append(bar.total)
            observed_descs.append(bar.desc)
            super().on_train_batch_end(trainer, *args, **kwargs)

    # Second run: resume from the checkpoint written in the middle of the epoch.
    resumed_run = Trainer(callbacks=[RecordingBar()], **shared_kwargs)
    with mock.patch(tqdm_target, MockTqdm):
        resumed_run.fit(model, ckpt_path=resume_path)

    # two batches remain in the resumed epoch; the bar must know the real total
    assert observed_totals == [4, 4]
    assert all(desc.startswith("Epoch 0") for desc in observed_descs)