diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 8775c149fb365..1197b8184d20e 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed PyTorch Lightning profiler not capturing dataloader worker initialization time ([#21771](https://github.com/Lightning-AI/pytorch-lightning/issues/21771)) +- Fixed `TQDMProgressBar` showing an unknown total (`n/?`) and a missing epoch description for the whole first epoch after resuming from a mid-epoch checkpoint ([#20603](https://github.com/Lightning-AI/pytorch-lightning/issues/20603)) + - Fixed `FSDPStrategy` raising `RuntimeError` under PyTorch 2.5+ when `root_device` is CPU, by passing an explicit `torch.device("cpu")` instead of `device_id=None` (relevant only when the GPU-accelerator guard is bypassed) ([#21774](https://github.com/Lightning-AI/pytorch-lightning/pull/21774)) - Fixed non-zero process exits in `CombinedLoader.reset()` with large tensors and persistent spawned workers by avoiding explicit `_shutdown_workers()` calls and relying on iterator cleanup via `del` [#21708](https://github.com/Lightning-AI/pytorch-lightning/issues/21708) diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index 74abb8ecd850c..c0dcc649ccf65 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -271,6 +271,16 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: self.train_progress_bar.initial = 0 self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") + @override + def on_train_batch_start(self, trainer: "pl.Trainer", *_: Any) -> None: + if not self.train_progress_bar.total: + # Resuming from a mid-epoch checkpoint skips ``on_train_epoch_start``, + # so initialize the bar's total and description here instead. + total = convert_inf(self.total_train_batches) + if total is not None: + self.train_progress_bar.total = total + self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") + @override def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index 72da9bf543155..dc41e62eb841c 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -856,3 +856,39 @@ def reset(self, total=None): assert 2 in val_bar.total_values, ( f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}" ) + + +def test_tqdm_progress_bar_mid_epoch_resume(tmp_path): + """The bar's total and epoch description must be initialized when resuming from a mid-epoch checkpoint, where + ``on_train_epoch_start`` is not called (#20603).""" + model = BoringModel() + checkpoint = ModelCheckpoint(dirpath=tmp_path, every_n_train_steps=2, save_top_k=-1) + trainer_kwargs = { + "default_root_dir": tmp_path, + "max_epochs": 1, + "limit_train_batches": 4, + "limit_val_batches": 0, + "enable_model_summary": False, + "logger": False, + } + trainer = Trainer(callbacks=[checkpoint, TQDMProgressBar()], **trainer_kwargs) + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): + trainer.fit(model) + mid_epoch_ckpt = str(tmp_path / "epoch=0-step=2.ckpt") + + totals_seen = [] + descriptions_seen = [] + + class TrackingBar(TQDMProgressBar): + def on_train_batch_end(self, trainer, *args, **kwargs): + totals_seen.append(self.train_progress_bar.total) + descriptions_seen.append(self.train_progress_bar.desc) + super().on_train_batch_end(trainer, *args, **kwargs) + + trainer = Trainer(callbacks=[TrackingBar()], **trainer_kwargs) + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): + trainer.fit(model, ckpt_path=mid_epoch_ckpt) + + # two batches remain in the resumed epoch; the bar must know the real total + assert totals_seen == [4, 4] + assert all(desc.startswith("Epoch 0") for desc in descriptions_seen)