Skip to content

Commit 4537b78

Browse files
Fix JaxTrainingPlan to update global_step (#1791)
* Add fake optimizer step to JaxTrainingPlan * Update scvi/train/_trainingplans.py Co-authored-by: Adam Gayoso <adamgayoso@users.noreply.github.com> Co-authored-by: Adam Gayoso <adamgayoso@users.noreply.github.com>
1 parent f6a7b40 commit 4537b78

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

scvi/train/_trainingplans.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,9 @@ def training_step(self, batch, batch_idx):
11431143
prog_bar=True,
11441144
)
11451145
self.compute_and_log_metrics(loss_output, self.train_metrics, "train")
1146+
# Update the dummy optimizer to update the global step
1147+
_opt = self.optimizers()
1148+
_opt.step()
11461149

11471150
@partial(jax.jit, static_argnums=(0,))
11481151
def jit_validation_step(

0 commit comments

Comments
 (0)