|
23 | 23 | import torch |
24 | 24 | from lightning_utilities.test.warning import no_warning_call |
25 | 25 |
|
26 | | -from lightning.pytorch import Trainer, seed_everything |
| 26 | +from lightning.pytorch import LightningModule, Trainer, seed_everything |
27 | 27 | from lightning.pytorch.callbacks import EarlyStopping |
28 | 28 | from lightning.pytorch.callbacks.finetuning import BackboneFinetuning |
29 | 29 | from lightning.pytorch.callbacks.lr_finder import LearningRateFinder |
@@ -844,3 +844,60 @@ def configure_optimizers(self): |
844 | 844 | # Check that backbone was unfrozen at the correct epoch |
845 | 845 | for param in model.backbone.parameters(): |
846 | 846 | assert param.requires_grad, "Backbone parameters should be unfrozen after epoch 1" |
| 847 | + |
| 848 | + |
| 849 | +def test_lr_finder_respects_weights_only(tmp_path): |
| 850 | + """Test that lr_find works correctly when saving more than the weights.""" |
| 851 | + |
| 852 | + # Simple torch Module |
| 853 | + class TorchCoder(torch.nn.Module): |
| 854 | + def __init__(self, in_features, out_features): |
| 855 | + super().__init__() |
| 856 | + self.net = torch.nn.Linear(in_features, out_features) |
| 857 | + |
| 858 | + def forward(self, x): |
| 859 | + return self.net(x) |
| 860 | + |
| 861 | + # Simple model |
| 862 | + class SimpleModel(LightningModule): |
| 863 | + def __init__(self, coder, loss, lr=1e-3): |
| 864 | + super().__init__() |
| 865 | + self.save_hyperparameters() |
| 866 | + self.layer = coder |
| 867 | + self.loss = loss |
| 868 | + self.lr = lr |
| 869 | + |
| 870 | + def training_step(self, batch, batch_idx): |
| 871 | + x, y = batch |
| 872 | + y_hat = self.layer(x) |
| 873 | + return self.loss(y_hat, y) |
| 874 | + |
| 875 | + def configure_optimizers(self): |
| 876 | + return torch.optim.Adam(self.parameters(), lr=self.lr) |
| 877 | + |
| 878 | + # Dummy data |
| 879 | + x = torch.randn(16, 4) |
| 880 | + y = torch.randn(16, 2) |
| 881 | + loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(x, y), batch_size=4) |
| 882 | + |
| 883 | + model = SimpleModel( |
| 884 | + TorchCoder(4, 2), |
| 885 | + loss=torch.nn.MSELoss(), |
| 886 | + ) |
| 887 | + |
| 888 | + trainer = Trainer( |
| 889 | + default_root_dir=tmp_path, |
| 890 | + max_epochs=1, |
| 891 | + logger=False, |
| 892 | + enable_checkpointing=True, |
| 893 | + ) |
| 894 | + |
| 895 | + # This should NOT raise an exception after the fix |
| 896 | + lr_finder = Tuner(trainer).lr_find( |
| 897 | + model, |
| 898 | + train_dataloaders=loader, |
| 899 | + weights_only=False, # <-- the key part |
| 900 | + ) |
| 901 | + |
| 902 | + assert lr_finder is not None |
| 903 | + assert hasattr(lr_finder, "results") |
0 commit comments