From 94ff6597669b41f0f25945f9b5090413810f84af Mon Sep 17 00:00:00 2001 From: Vineeth Sai Date: Sat, 4 Jul 2026 10:02:58 -0700 Subject: [PATCH 1/2] Fix TQDMProgressBar total and description after mid-epoch resume Resuming from a checkpoint saved mid-epoch intentionally skips on_train_epoch_start in the resumed process (the RestartStage rework in 2.5), but TQDMProgressBar only sets the bar's total and the 'Epoch N' description in that hook. The whole resumed epoch therefore rendered as 'n/?' with the initial 'Training' description. Lazily initialize the bar in on_train_batch_start when the epoch-start hook did not run: set the total from total_train_batches and the epoch description. Normal epochs are untouched (the total is already set), and genuinely infinite dataloaders keep an unknown total exactly as before. Fixes #20603 --- src/lightning/pytorch/CHANGELOG.md | 2 ++ .../callbacks/progress/tqdm_progress.py | 10 ++++++ .../progress/test_tqdm_progress_bar.py | 36 +++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index 8775c149fb365..1197b8184d20e 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -26,6 +26,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed PyTorch Lightning profiler not capturing dataloader worker initialization time ([#21771](https://github.com/Lightning-AI/pytorch-lightning/issues/21771)) +- Fixed `TQDMProgressBar` showing an unknown total (`n/?`) and a missing epoch description for the whole first epoch after resuming from a mid-epoch checkpoint ([#20603](https://github.com/Lightning-AI/pytorch-lightning/issues/20603)) + - Fixed `FSDPStrategy` raising `RuntimeError` under PyTorch 2.5+ when `root_device` is CPU, by passing an explicit `torch.device("cpu")` instead of `device_id=None` (relevant only when the GPU-accelerator guard is bypassed) ([#21774](https://github.com/Lightning-AI/pytorch-lightning/pull/21774)) - Fixed non-zero process exits in `CombinedLoader.reset()` with large tensors and persistent spawned workers by avoiding explicit `_shutdown_workers()` calls and relying on iterator cleanup via `del` [#21708](https://github.com/Lightning-AI/pytorch-lightning/issues/21708) diff --git a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py index 74abb8ecd850c..c0dcc649ccf65 100644 --- a/src/lightning/pytorch/callbacks/progress/tqdm_progress.py +++ b/src/lightning/pytorch/callbacks/progress/tqdm_progress.py @@ -271,6 +271,16 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", *_: Any) -> None: self.train_progress_bar.initial = 0 self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") + @override + def on_train_batch_start(self, trainer: "pl.Trainer", *_: Any) -> None: + if not self.train_progress_bar.total: + # Resuming from a mid-epoch checkpoint skips ``on_train_epoch_start``, + # so initialize the bar's total and description here instead. + total = convert_inf(self.total_train_batches) + if total is not None: + self.train_progress_bar.total = total + self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") + @override def on_train_batch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", outputs: STEP_OUTPUT, batch: Any, batch_idx: int diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index 72da9bf543155..3e60709f53e0b 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -856,3 +856,39 @@ def reset(self, total=None): assert 2 in val_bar.total_values, ( f"validation total should be set to 2 after reset(), got total_values: {val_bar.total_values}" ) + + +def test_tqdm_progress_bar_mid_epoch_resume(tmp_path): + """The bar's total and epoch description must be initialized when resuming from a mid-epoch + checkpoint, where ``on_train_epoch_start`` is not called (#20603).""" + model = BoringModel() + checkpoint = ModelCheckpoint(dirpath=tmp_path, every_n_train_steps=2, save_top_k=-1) + trainer_kwargs = { + "default_root_dir": tmp_path, + "max_epochs": 1, + "limit_train_batches": 4, + "limit_val_batches": 0, + "enable_model_summary": False, + "logger": False, + } + trainer = Trainer(callbacks=[checkpoint, TQDMProgressBar()], **trainer_kwargs) + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): + trainer.fit(model) + mid_epoch_ckpt = str(tmp_path / "epoch=0-step=2.ckpt") + + totals_seen = [] + descriptions_seen = [] + + class TrackingBar(TQDMProgressBar): + def on_train_batch_end(self, trainer, *args, **kwargs): + totals_seen.append(self.train_progress_bar.total) + descriptions_seen.append(self.train_progress_bar.desc) + super().on_train_batch_end(trainer, *args, **kwargs) + + trainer = Trainer(callbacks=[TrackingBar()], **trainer_kwargs) + with mock.patch("lightning.pytorch.callbacks.progress.tqdm_progress.Tqdm", MockTqdm): + trainer.fit(model, ckpt_path=mid_epoch_ckpt) + + # two batches remain in the resumed epoch; the bar must know the real total + assert totals_seen == [4, 4] + assert all(desc.startswith("Epoch 0") for desc in descriptions_seen) From 65c0ffb4fefbb407391bf2f7204082f667651b11 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 4 Jul 2026 17:03:55 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../callbacks/progress/test_tqdm_progress_bar.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py index 3e60709f53e0b..dc41e62eb841c 100644 --- a/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py +++ b/tests/tests_pytorch/callbacks/progress/test_tqdm_progress_bar.py @@ -859,8 +859,8 @@ def reset(self, total=None): def test_tqdm_progress_bar_mid_epoch_resume(tmp_path): - """The bar's total and epoch description must be initialized when resuming from a mid-epoch - checkpoint, where ``on_train_epoch_start`` is not called (#20603).""" + """The bar's total and epoch description must be initialized when resuming from a mid-epoch checkpoint, where + ``on_train_epoch_start`` is not called (#20603).""" model = BoringModel() checkpoint = ModelCheckpoint(dirpath=tmp_path, every_n_train_steps=2, save_top_k=-1) trainer_kwargs = {