Skip to content

Commit fe6b1cc

Browse files
Fix: Add weights_only parameter to LearningRateFinder checkpoint restore (#21758)
* Fix: add weights_only parameter to LearningRateFinder and propagate through LR finder call chain
* Test: ensure LearningRateFinder supports weights_only=False during checkpoint restore
* Test: ensure LearningRateFinder supports weights_only=False during checkpoint restore and apply pre-commit
* Style: apply pre-commit formatting to LR Finder patch
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent d9a9766 commit fe6b1cc

4 files changed

Lines changed: 80 additions & 2 deletions

File tree

src/lightning/pytorch/callbacks/lr_finder.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,11 @@ class LearningRateFinder(Callback):
5050
update_attr: Whether to update the learning rate attribute or not.
5151
attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
5252
automatically detected. Otherwise, set the name here.
53+
weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
54+
``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
55+
an ``nn.Module``, use ``weights_only=False``. If loading a checkpoint from an untrusted source, we
56+
recommend using ``weights_only=True``. For more information, please refer to the
57+
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
5358
5459
Example::
5560
@@ -92,6 +97,7 @@ def __init__(
9297
early_stop_threshold: Optional[float] = 4.0,
9398
update_attr: bool = True,
9499
attr_name: str = "",
100+
weights_only: Optional[bool] = None,
95101
) -> None:
96102
mode = mode.lower()
97103
if mode not in self.SUPPORTED_MODES:
@@ -104,6 +110,7 @@ def __init__(
104110
self._early_stop_threshold = early_stop_threshold
105111
self._update_attr = update_attr
106112
self._attr_name = attr_name
113+
self._weights_only = weights_only
107114

108115
self._early_exit = False
109116
self.optimal_lr: Optional[_LRFinder] = None
@@ -120,6 +127,7 @@ def lr_find(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> Non
120127
early_stop_threshold=self._early_stop_threshold,
121128
update_attr=self._update_attr,
122129
attr_name=self._attr_name,
130+
weights_only=self._weights_only,
123131
)
124132

125133
if self._early_exit:

src/lightning/pytorch/tuner/lr_finder.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,7 @@ def _lr_find(
206206
early_stop_threshold: Optional[float] = 4.0,
207207
update_attr: bool = False,
208208
attr_name: str = "",
209+
weights_only: Optional[bool] = None,
209210
) -> Optional[_LRFinder]:
210211
"""Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in picking
211212
a good starting learning rate.
@@ -227,6 +228,11 @@ def _lr_find(
227228
update_attr: Whether to update the learning rate attribute or not.
228229
attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
229230
automatically detected. Otherwise, set the name here.
231+
weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
232+
``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
233+
an ``nn.Module``, use ``weights_only=False``. If loading a checkpoint from an untrusted source, we
234+
recommend using ``weights_only=True``. For more information, please refer to the
235+
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
230236
231237
"""
232238
if trainer.fast_dev_run:
@@ -285,7 +291,7 @@ def _lr_find(
285291
raise ex
286292
finally:
287293
# Restore initial state of model (this will also restore the original optimizer state)
288-
trainer._checkpoint_connector.restore(ckpt_path)
294+
trainer._checkpoint_connector.restore(ckpt_path, weights_only=weights_only)
289295
trainer.strategy.remove_checkpoint(ckpt_path)
290296
trainer.fit_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True
291297
trainer.fit_loop.epoch_loop.restarting = False # reset restarting flag as checkpoint restoring sets it to True

src/lightning/pytorch/tuner/tuning.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ def lr_find(
131131
early_stop_threshold: Optional[float] = 4.0,
132132
update_attr: bool = True,
133133
attr_name: str = "",
134+
weights_only: Optional[bool] = None,
134135
) -> Optional["_LRFinder"]:
135136
"""Enables the user to do a range test of good initial learning rates, to reduce the amount of guesswork in
136137
picking a good starting learning rate.
@@ -159,6 +160,11 @@ def lr_find(
159160
update_attr: Whether to update the learning rate attribute or not.
160161
attr_name: Name of the attribute which stores the learning rate. The names 'learning_rate' or 'lr' get
161162
automatically detected. Otherwise, set the name here.
163+
weights_only: Defaults to ``None``. If ``True``, restricts loading to ``state_dicts`` of plain
164+
``torch.Tensor`` and other primitive types. If loading a checkpoint from a trusted source that contains
165+
an ``nn.Module``, use ``weights_only=False``. If loading a checkpoint from an untrusted source, we
166+
recommend using ``weights_only=True``. For more information, please refer to the
167+
`PyTorch Developer Notes on Serialization Semantics <https://docs.pytorch.org/docs/main/notes/serialization.html#id3>`_.
162168
163169
Raises:
164170
MisconfigurationException:
@@ -183,6 +189,7 @@ def lr_find(
183189
early_stop_threshold=early_stop_threshold,
184190
update_attr=update_attr,
185191
attr_name=attr_name,
192+
weights_only=weights_only,
186193
)
187194

188195
lr_finder_callback._early_exit = True

tests/tests_pytorch/tuner/test_lr_finder.py

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import torch
2424
from lightning_utilities.test.warning import no_warning_call
2525

26-
from lightning.pytorch import Trainer, seed_everything
26+
from lightning.pytorch import LightningModule, Trainer, seed_everything
2727
from lightning.pytorch.callbacks import EarlyStopping
2828
from lightning.pytorch.callbacks.finetuning import BackboneFinetuning
2929
from lightning.pytorch.callbacks.lr_finder import LearningRateFinder
@@ -844,3 +844,60 @@ def configure_optimizers(self):
844844
# Check that backbone was unfrozen at the correct epoch
845845
for param in model.backbone.parameters():
846846
assert param.requires_grad, "Backbone parameters should be unfrozen after epoch 1"
847+
848+
849+
def test_lr_finder_respects_weights_only(tmp_path):
    """Run ``Tuner.lr_find`` with ``weights_only=False`` on a module that saves ``nn.Module`` hyperparameters.

    The module below calls ``save_hyperparameters()`` with two ``nn.Module`` constructor arguments, so its
    hyperparameters hold more than plain tensors (presumably these end up in the temporary checkpoint that
    the LR finder writes and restores -- confirm against the checkpointing code). The test passes when
    ``lr_find`` returns a result object exposing ``results`` without raising.

    """

    class TorchCoder(torch.nn.Module):
        """Plain torch module wrapping a single linear layer."""

        def __init__(self, in_features, out_features):
            super().__init__()
            self.net = torch.nn.Linear(in_features, out_features)

        def forward(self, x):
            return self.net(x)

    class SimpleModel(LightningModule):
        """Minimal LightningModule whose constructor arguments are all recorded as hyperparameters."""

        def __init__(self, coder, loss, lr=1e-3):
            super().__init__()
            # Records ``coder``, ``loss`` and ``lr``; the first two are ``nn.Module`` instances.
            self.save_hyperparameters()
            self.layer = coder
            self.loss = loss
            self.lr = lr

        def training_step(self, batch, batch_idx):
            inputs, targets = batch
            return self.loss(self.layer(inputs), targets)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.lr)

    # Random regression data: 16 samples, 4 input features, 2 targets, served in batches of 4.
    # (Drawn before the model is built so the RNG consumption order stays the same.)
    features = torch.randn(16, 4)
    targets = torch.randn(16, 2)
    dataset = torch.utils.data.TensorDataset(features, targets)
    train_loader = torch.utils.data.DataLoader(dataset, batch_size=4)

    model = SimpleModel(TorchCoder(4, 2), loss=torch.nn.MSELoss())

    trainer = Trainer(
        default_root_dir=tmp_path,
        max_epochs=1,
        logger=False,
        enable_checkpointing=True,
    )
    tuner = Tuner(trainer)

    # The call under test: with ``weights_only=False`` forwarded, this should complete without raising.
    lr_finder = tuner.lr_find(model, train_dataloaders=train_loader, weights_only=False)

    assert lr_finder is not None
    assert hasattr(lr_finder, "results")

0 commit comments

Comments
 (0)