Skip to content

Commit 988573e

Browse files
committed
fix(alias): un-break downstream testing for weight_decay default 0.0 param requirements
1 parent 6ed973a commit 988573e

2 files changed

Lines changed: 6 additions & 4 deletions

File tree

optax/transforms/_adding.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -57,8 +57,6 @@ def init_fn(params):
5757
return base.EmptyState()
5858

5959
def update_fn(updates, state, params):
60-
if params is None:
61-
raise ValueError(base.NO_PARAMS_MSG)
6260
if callable(weight_decay):
6361
new_state = WeightDecaySchedule(numerics.safe_increment(state.count))
6462
else:
@@ -68,6 +66,9 @@ def update_fn(updates, state, params):
6866
if isinstance(weight_decay, (int, float)) and weight_decay == 0.0:
6967
return updates, new_state
7068

69+
if params is None:
70+
raise ValueError(base.NO_PARAMS_MSG)
71+
7172
s = weight_decay(state.count) if callable(weight_decay) else weight_decay
7273
updates = jax.tree.map(
7374
lambda g, p: None if g is None else g + s * p,

optax/tree_utils/_state_utils_test.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -195,6 +195,7 @@ def test_map_params_to_none(self):
195195
(
196196
transform.ScaleByRssState(sum_of_squares={'a': None}),
197197
base.EmptyState(),
198+
base.EmptyState(),
198199
),
199200
)
200201

@@ -378,7 +379,7 @@ def get_learning_rate(state):
378379

379380
for i in range(4):
380381
# we simply update state, we don't care about updates.
381-
_, state = opt.update(params, state)
382+
_, state = opt.update(params, state, params)
382383
lr = get_learning_rate(state)
383384
self.assertEqual(lr, 1 / (i + 1))
384385

@@ -465,7 +466,7 @@ def set_learning_rate(state, lr):
465466
for i in range(4):
466467
modified_state = set_learning_rate(modified_state, lr / (i + 1))
467468
# we simply update state, we don't care about updates.
468-
_, modified_state = opt.update(params, modified_state)
469+
_, modified_state = opt.update(params, modified_state, params)
469470
modified_lr = _state_utils.tree_get(modified_state, 'learning_rate')
470471
self.assertEqual(modified_lr, lr / (i + 1))
471472

0 commit comments

Comments (0)