54 changes: 44 additions & 10 deletions optax/_src/alias.py
@@ -39,6 +39,8 @@ def adabelief(
     b2: jax.typing.ArrayLike = 0.999,
     eps: jax.typing.ArrayLike = 1e-16,
     eps_root: jax.typing.ArrayLike = 1e-16,
+    weight_decay: base.ScalarOrSchedule = 0.0,
+    weight_decay_mask: MaskOrFn = None,
     *,
     nesterov: bool = False,
 ) -> base.GradientTransformationExtraArgs:
@@ -136,6 +138,7 @@ def adabelief(
           eps_root=eps_root,
           nesterov=nesterov,
       ),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_learning_rate(learning_rate),
   )

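Note: with this change the decay transform sits inside the chain, so parameters are threaded through update. A minimal usage sketch of the new arguments (values illustrative; assumes a build of optax that includes this branch):

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.ones((3, 2)), 'b': jnp.zeros(2)}
opt = optax.adabelief(learning_rate=1e-3, weight_decay=1e-4)
state = opt.init(params)

grads = jax.tree.map(jnp.ones_like, params)  # stand-in gradients
updates, state = opt.update(grads, state, params)  # params feed the decay term
params = optax.apply_updates(params, updates)
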
@@ -330,6 +333,8 @@ def adagrad(
     learning_rate: base.ScalarOrSchedule,
     initial_accumulator_value: jax.typing.ArrayLike = 0.1,
     eps: jax.typing.ArrayLike = 1e-7,
+    weight_decay: base.ScalarOrSchedule = 0.0,
+    weight_decay_mask: MaskOrFn = None,
 ) -> base.GradientTransformationExtraArgs:
   r"""The Adagrad optimizer.

@@ -407,6 +412,7 @@ def adagrad(
       transform.scale_by_rss(
           initial_accumulator_value=initial_accumulator_value, eps=eps
       ),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_learning_rate(learning_rate),
   )

@@ -603,7 +609,7 @@ def adamw(
     eps_root: jax.typing.ArrayLike = 0.0,
     mu_dtype: Optional[Any] = None,
     weight_decay: base.ScalarOrSchedule = 1e-4,
-    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
+    weight_decay_mask: MaskOrFn = None,
     *,
     nesterov: bool = False,
 ) -> base.GradientTransformationExtraArgs:
@@ -722,7 +728,7 @@ def adamw(
           mu_dtype=mu_dtype,
           nesterov=nesterov,
       ),
-      transform.add_decayed_weights(weight_decay, mask),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_learning_rate(learning_rate),
   )

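The keyword here changes from mask to weight_decay_mask; call sites keep the same semantics under the new name. A hedged sketch of the common pattern of decaying only weight matrices (the decay_mask helper is illustrative, not part of optax):

import jax
import optax

def decay_mask(params):
  # True = apply decay. Skip rank-<=1 leaves such as biases and scales.
  return jax.tree.map(lambda p: p.ndim > 1, params)

opt = optax.adamw(1e-3, weight_decay=1e-4, weight_decay_mask=decay_mask)
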
@@ -814,7 +820,7 @@ def adan(
     eps: jax.typing.ArrayLike = 1e-8,
     eps_root: jax.typing.ArrayLike = 1e-8,
     weight_decay: base.ScalarOrSchedule = 0.0,
-    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
+    weight_decay_mask: MaskOrFn = None,
 ) -> base.GradientTransformationExtraArgs:
   r"""The ADAptive Nesterov momentum algorithm (Adan).

@@ -919,7 +925,7 @@ def adan(
           eps=eps,
           eps_root=eps_root,
       ),
-      transform.add_decayed_weights(weight_decay, mask),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_learning_rate(learning_rate),
   )

@@ -930,7 +936,7 @@ def lion(
     b2: jax.typing.ArrayLike = 0.99,
     mu_dtype: Optional[Any] = None,
     weight_decay: base.ScalarOrSchedule = 1e-3,
-    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
+    weight_decay_mask: MaskOrFn = None,
 ) -> base.GradientTransformationExtraArgs:
   r"""The Lion optimizer.

@@ -1015,7 +1021,7 @@ def lion(
   """
   return combine.chain(
       transform.scale_by_lion(b1=b1, b2=b2, mu_dtype=mu_dtype),
-      transform.add_decayed_weights(weight_decay, mask),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_learning_rate(learning_rate),
   )

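weight_decay is typed base.ScalarOrSchedule throughout this diff, and _adding.py keeps a schedule count (see that file below), so a schedule should be accepted wherever a scalar is. A sketch, assuming schedule support works end to end:

import optax

wd_schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)
opt = optax.lion(learning_rate=1e-4, weight_decay=wd_schedule)
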
@@ -1029,6 +1035,8 @@ def amsgrad(
     mu_dtype: Optional[Any] = None,
     bias_correction_mu: bool = True,
     bias_correction_nu: bool = True,
+    weight_decay: base.ScalarOrSchedule = 0.0,
+    weight_decay_mask: MaskOrFn = None,
 ) -> base.GradientTransformationExtraArgs:
   """The AMSGrad optimizer.

@@ -1091,6 +1099,7 @@ def amsgrad(
           bias_correction_mu=bias_correction_mu,
           bias_correction_nu=bias_correction_nu,
       ),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_learning_rate(learning_rate),
   )

@@ -1238,7 +1247,7 @@ def lamb(
     eps: jax.typing.ArrayLike = 1e-6,
     eps_root: jax.typing.ArrayLike = 0.0,
     weight_decay: base.ScalarOrSchedule = 0.0,
-    mask: MaskOrFn = None,
+    weight_decay_mask: MaskOrFn = None,
 ) -> base.GradientTransformationExtraArgs:
   """The LAMB optimizer.

@@ -1294,7 +1303,7 @@ def lamb(
   """
   return combine.chain(
       transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root),
-      transform.add_decayed_weights(weight_decay=weight_decay, mask=mask),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_trust_ratio(),
       transform.scale_by_learning_rate(learning_rate),
   )
@@ -1552,6 +1561,8 @@ def optimistic_gradient_descent(
     learning_rate: base.ScalarOrSchedule,
     alpha: base.ScalarOrSchedule = 1.0,
     beta: base.ScalarOrSchedule = 1.0,
+    weight_decay: base.ScalarOrSchedule = 0.0,
+    weight_decay_mask: MaskOrFn = None,
 ) -> base.GradientTransformationExtraArgs:
   r"""An Optimistic Gradient Descent optimizer.

@@ -1611,6 +1622,7 @@ def optimistic_gradient_descent(
   """
   return combine.chain(
       transform.scale_by_optimistic_gradient(alpha=alpha, beta=beta),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_learning_rate(learning_rate),
   )

@@ -1623,6 +1635,8 @@ def optimistic_adam(
     eps: jax.typing.ArrayLike = 1e-08,
     eps_root: jax.typing.ArrayLike = 0.0,
     mu_dtype: Optional[Any] = None,
+    weight_decay: base.ScalarOrSchedule = 0.0,
+    weight_decay_mask: MaskOrFn = None,
     *,
     nesterov: bool = True,
 ) -> base.GradientTransformationExtraArgs:
@@ -1735,12 +1749,15 @@ def optimistic_adam(
       ),
       transform.scale_by_optimistic_gradient(alpha=learning_rate,
                                              beta=optimism),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_learning_rate(1.0),  # flips the sign
   )


 def optimistic_adam_v2(
     learning_rate: base.ScalarOrSchedule,
+    weight_decay: base.ScalarOrSchedule = 0.0,
+    weight_decay_mask: MaskOrFn = None,
     *,
     alpha: jax.typing.ArrayLike = 1.0,
     beta: jax.typing.ArrayLike = 1.0,
@@ -1857,6 +1874,7 @@ def optimistic_adam_v2(
           nesterov=nesterov,
       ),
       transform.scale_by_optimistic_gradient(alpha=alpha, beta=beta),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_learning_rate(learning_rate),
   )

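Reading the two chains as written: in optimistic_adam the learning rate is consumed by scale_by_optimistic_gradient, and the trailing scale_by_learning_rate(1.0) only flips the sign, so the decay term enters the update unscaled; in optimistic_adam_v2 the decay is added before scale_by_learning_rate(learning_rate) and is scaled by it. Schematically (comments only, not executable optax code):

# optimistic_adam:    update = -(optimistic(adam_t; alpha=lr, beta=optimism)
#                                + weight_decay * params)
# optimistic_adam_v2: update = -lr * (optimistic(adam_t; alpha, beta)
#                                     + weight_decay * params)
# i.e. only v2 multiplies the decay term by the learning rate.
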
@@ -1868,6 +1886,8 @@ def radam(
     eps: jax.typing.ArrayLike = 1e-8,
     eps_root: jax.typing.ArrayLike = 0.0,
     threshold: jax.typing.ArrayLike = 5.0,
+    weight_decay: base.ScalarOrSchedule = 0.0,
+    weight_decay_mask: MaskOrFn = None,
     *,
     nesterov: bool = False,
 ) -> base.GradientTransformationExtraArgs:
@@ -1928,6 +1948,7 @@ def radam(
           threshold=threshold,
           nesterov=nesterov,
       ),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_learning_rate(learning_rate),
   )

@@ -1942,6 +1963,8 @@ def rmsprop(
     momentum: Optional[jax.typing.ArrayLike] = None,
     nesterov: bool = False,
     bias_correction: bool = False,
+    weight_decay: base.ScalarOrSchedule = 0.0,
+    weight_decay_mask: MaskOrFn = None,
 ) -> base.GradientTransformationExtraArgs:
   r"""A flexible RMSProp optimizer.

@@ -2022,6 +2045,7 @@ def rmsprop(
           eps_in_sqrt=eps_in_sqrt,
           bias_correction=bias_correction,
       ),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_learning_rate(learning_rate),
       (
           transform.trace(decay=momentum, nesterov=nesterov)
@@ -2051,6 +2075,8 @@ def sgd(
     momentum: Optional[jax.typing.ArrayLike] = None,
     nesterov: bool = False,
     accumulator_dtype: Optional[Any] = None,
+    weight_decay: base.ScalarOrSchedule = 0.0,
+    weight_decay_mask: MaskOrFn = None,
 ) -> base.GradientTransformationExtraArgs:
   r"""A canonical Stochastic Gradient Descent optimizer.

@@ -2132,12 +2158,16 @@ def sgd(
   opt = base.identity()
   return combine.chain(
       opt,
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_learning_rate(learning_rate),
   )

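With momentum=None the sgd chain above is identity -> add_decayed_weights -> scale_by_learning_rate, so per leaf the step is update = -lr * (g + weight_decay * p), i.e. classic L2 regularization folded into the step. A small numeric check, hedged on this branch's API:

import jax.numpy as jnp
import optax

opt = optax.sgd(learning_rate=0.1, weight_decay=1e-4)
params = {'w': jnp.ones(3)}
state = opt.init(params)
grads = {'w': jnp.full(3, 0.5)}
updates, state = opt.update(grads, state, params)
# Each entry of updates['w'] should be -0.1 * (0.5 + 1e-4 * 1.0).
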
 def sm3(
-    learning_rate: jax.typing.ArrayLike, momentum: jax.typing.ArrayLike = 0.9
+    learning_rate: jax.typing.ArrayLike,
+    momentum: jax.typing.ArrayLike = 0.9,
+    weight_decay: base.ScalarOrSchedule = 0.0,
+    weight_decay_mask: MaskOrFn = None,
 ) -> base.GradientTransformationExtraArgs:
   r"""The SM3 optimizer.

@@ -2242,6 +2272,7 @@ def sm3(
   """
   return combine.chain(
       transform.scale_by_sm3(momentum),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale(-learning_rate),
   )

@@ -2251,6 +2282,8 @@ def yogi(
     b1: jax.typing.ArrayLike = 0.9,
     b2: jax.typing.ArrayLike = 0.999,
     eps: jax.typing.ArrayLike = 1e-3,
+    weight_decay: base.ScalarOrSchedule = 0.0,
+    weight_decay_mask: MaskOrFn = None,
 ) -> base.GradientTransformationExtraArgs:
   # pylint: disable=line-too-long
   """The Yogi optimizer.
@@ -2301,6 +2334,7 @@ def yogi(
   # pylint: enable=line-too-long
   return combine.chain(
       transform.scale_by_yogi(b1=b1, b2=b2, eps=eps),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_learning_rate(learning_rate),
   )

@@ -2396,7 +2430,7 @@ def adamaxw(
     b2: jax.typing.ArrayLike = 0.999,
     eps: jax.typing.ArrayLike = 1e-8,
     weight_decay: base.ScalarOrSchedule = 1e-4,
-    mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
+    weight_decay_mask: MaskOrFn = None,
 ) -> base.GradientTransformationExtraArgs:
   """Adamax with weight decay regularization.

@@ -2460,7 +2494,7 @@ def adamaxw(
   """
   return combine.chain(
       transform.scale_by_adamax(b1=b1, b2=b2, eps=eps),
-      transform.add_decayed_weights(weight_decay, mask),
+      transform.add_decayed_weights(weight_decay, mask=weight_decay_mask),
       transform.scale_by_learning_rate(learning_rate),
   )

5 changes: 3 additions & 2 deletions optax/transforms/_adding.py
@@ -57,8 +57,6 @@ def init_fn(params):
     return base.EmptyState()

   def update_fn(updates, state, params):
-    if params is None:
-      raise ValueError(base.NO_PARAMS_MSG)
     if callable(weight_decay):
       new_state = WeightDecaySchedule(numerics.safe_increment(state.count))
     else:
@@ -68,6 +66,9 @@ def update_fn(updates, state, params):
     if isinstance(weight_decay, (int, float)) and weight_decay == 0.0:
       return updates, new_state

+    if params is None:
+      raise ValueError(base.NO_PARAMS_MSG)
+
     s = weight_decay(state.count) if callable(weight_decay) else weight_decay
     updates = jax.tree.map(
         lambda g, p: None if g is None else g + s * p,
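The reordering above means a static weight decay of exactly 0.0 no longer requires parameters at update time. A sketch of the resulting control flow, reconstructed from the visible hunks (the elided else branch and the tail of the jax.tree.map call are assumptions):

def update_fn(updates, state, params):
  # Advance the schedule count only when weight_decay is a schedule.
  if callable(weight_decay):
    new_state = WeightDecaySchedule(numerics.safe_increment(state.count))
  else:
    new_state = state  # elided in the hunk; assumed pass-through

  # A static decay of exactly 0.0 is a no-op, so params may be None here.
  if isinstance(weight_decay, (int, float)) and weight_decay == 0.0:
    return updates, new_state

  # Past this point the decay term genuinely needs the parameter values.
  if params is None:
    raise ValueError(base.NO_PARAMS_MSG)

  s = weight_decay(state.count) if callable(weight_decay) else weight_decay
  updates = jax.tree.map(
      lambda g, p: None if g is None else g + s * p,
      updates,
      params,
      is_leaf=lambda x: x is None,
  )
  return updates, new_state
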
4 changes: 2 additions & 2 deletions optax/tree_utils/_state_utils.py
@@ -437,13 +437,13 @@ def tree_set(
     ...  )
     >>> state = opt.init(params)
     >>> print(state)
-    InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
+    InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'weight_decay': Array(0., dtype=float32), 'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState(), EmptyState()))
     >>> filtering = lambda path, value: isinstance(value, jnp.ndarray)
     >>> new_state = optax.tree_utils.tree_set(
     ...  state, filtering, learning_rate=jnp.asarray(0.1)
     ...  )
     >>> print(new_state)
-    InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState()))
+    InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'weight_decay': Array(0., dtype=float32), 'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState(), EmptyState()))

     .. note:: The recommended usage to inject hyperparameters schedules is through
       :func:`optax.inject_hyperparams`. This function is a helper for other
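Since weight_decay now surfaces in hyperparams (per the updated doctest), it can be read and overridden the same way as learning_rate. A sketch reusing state and filtering from the doctest above:

wd = optax.tree_utils.tree_get(state, 'weight_decay')
new_state = optax.tree_utils.tree_set(
    state, filtering, weight_decay=jnp.asarray(1e-4)
)
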
5 changes: 3 additions & 2 deletions optax/tree_utils/_state_utils_test.py
@@ -195,6 +195,7 @@ def test_map_params_to_none(self):
         (
             transform.ScaleByRssState(sum_of_squares={'a': None}),
             base.EmptyState(),
+            base.EmptyState(),
         ),
     )

@@ -378,7 +379,7 @@ def get_learning_rate(state):

     for i in range(4):
       # we simply update state, we don't care about updates.
-      _, state = opt.update(params, state)
+      _, state = opt.update(params, state, params)
       lr = get_learning_rate(state)
       self.assertEqual(lr, 1 / (i + 1))

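The extra positional argument in these tests reflects that the chained optimizers can now include add_decayed_weights, which reads the parameters; per the _adding.py change, omitting params stays valid only while the static decay is exactly 0.0. The call shape after this diff:

updates, state = opt.update(grads, state, params)
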
@@ -465,7 +466,7 @@ def set_learning_rate(state, lr):
     for i in range(4):
       modified_state = set_learning_rate(modified_state, lr / (i + 1))
       # we simply update state, we don't care about updates.
-      _, modified_state = opt.update(params, modified_state)
+      _, modified_state = opt.update(params, modified_state, params)
       modified_lr = _state_utils.tree_get(modified_state, 'learning_rate')
       self.assertEqual(modified_lr, lr / (i + 1))
