Skip to content

Commit b2784f7

Browse files
authored
Upgrade to PyTorch Lightning 1.8 (#1795)
* Initial commit * Update ClampCallback for 1.8
1 parent 4537b78 commit b2784f7

3 files changed

Lines changed: 5 additions & 5 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ pyro-ppl = ">=1.6.0"
6767
pytest = {version = ">=4.4", optional = true}
6868
python = ">=3.7,<4.0"
6969
python-igraph = {version = "*", optional = true}
70-
pytorch-lightning = ">=1.7.0,<1.8"
70+
pytorch-lightning = ">=1.8.0,<1.9"
7171
rich = ">=9.1.0"
7272
scanpy = {version = ">=1.6", optional = true}
7373
scikit-learn = ">=0.21.2"

scvi/external/cellassign/_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,8 @@ class ClampCallback(Callback):
271271
def __init__(self):
272272
super().__init__()
273273

274-
def on_batch_end(self, trainer, pl_module):
274+
def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
275275
"""Clamp parameters."""
276276
with torch.inference_mode():
277277
pl_module.module.delta_log.clamp_(np.log(pl_module.module.min_delta))
278-
super().on_batch_end(trainer, pl_module)
278+
super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

scvi/train/_callbacks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(self):
2121
def on_train_epoch_start(self, trainer, pl_module):
2222
"""Subsample labels at the beginning of each epoch."""
2323
trainer.train_dataloader.loaders.resample_labels()
24-
super().on_epoch_start(trainer, pl_module)
24+
super().on_train_epoch_start(trainer, pl_module)
2525

2626

2727
class SaveBestState(Callback):
@@ -86,7 +86,7 @@ def __init__(
8686
def check_monitor_top(self, current): # noqa: D102
8787
return self.monitor_op(current, self.best_module_metric_val)
8888

89-
def on_epoch_end(self, trainer, pl_module): # noqa: D102
89+
def on_val_epoch_end(self, trainer, pl_module): # noqa: D102
9090
logs = trainer.callback_metrics
9191
self.epochs_since_last_check += 1
9292

0 commit comments

Comments
 (0)