Skip to content

Commit 32186aa

Browse files
hawkinsplearned_optimization authors
authored andcommitted
No public description
PiperOrigin-RevId: 888266025
1 parent 0f9803c commit 32186aa

1 file changed

Lines changed: 3 additions & 2 deletions

File tree

learned_optimization/optimizers/optax_opts.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,9 @@ def __init__(self,
186186
epsilon_root=1e-8):
187187
opt = optax.chain(
188188
optax.scale_by_adam(
189-
b1=beta1, b2=beta2, eps=epsilon, eps_root=epsilon_root),
190-
optax.scale_by_schedule(piecewise_linear(times, vals=lrs)),
189+
b1=beta1, b2=beta2, eps=epsilon, eps_root=epsilon_root
190+
),
191+
optax.scale_by_schedule(piecewise_linear(times, vals=lrs)), # pytype: disable=wrong-arg-types # jax-arraylike
191192
optax.scale(-1),
192193
)
193194
super().__init__(opt)

0 commit comments

Comments
 (0)