diff --git a/optax/_src/alias.py b/optax/_src/alias.py index ac2ff572d..88812738a 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -39,6 +39,8 @@ def adabelief( b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-16, eps_root: jax.typing.ArrayLike = 1e-16, + weight_decay: base.ScalarOrSchedule = 0.0, + weight_decay_mask: MaskOrFn = None, *, nesterov: bool = False, ) -> base.GradientTransformationExtraArgs: @@ -136,6 +138,7 @@ def adabelief( eps_root=eps_root, nesterov=nesterov, ), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), ) @@ -330,6 +333,8 @@ def adagrad( learning_rate: base.ScalarOrSchedule, initial_accumulator_value: jax.typing.ArrayLike = 0.1, eps: jax.typing.ArrayLike = 1e-7, + weight_decay: base.ScalarOrSchedule = 0.0, + weight_decay_mask: MaskOrFn = None, ) -> base.GradientTransformationExtraArgs: r"""The Adagrad optimizer. @@ -407,6 +412,7 @@ def adagrad( transform.scale_by_rss( initial_accumulator_value=initial_accumulator_value, eps=eps ), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), ) @@ -603,7 +609,7 @@ def adamw( eps_root: jax.typing.ArrayLike = 0.0, mu_dtype: Optional[Any] = None, weight_decay: base.ScalarOrSchedule = 1e-4, - mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, + weight_decay_mask: MaskOrFn = None, *, nesterov: bool = False, ) -> base.GradientTransformationExtraArgs: @@ -722,7 +728,7 @@ def adamw( mu_dtype=mu_dtype, nesterov=nesterov, ), - transform.add_decayed_weights(weight_decay, mask), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), ) @@ -814,7 +820,7 @@ def adan( eps: jax.typing.ArrayLike = 1e-8, eps_root: jax.typing.ArrayLike = 1e-8, weight_decay: base.ScalarOrSchedule = 0.0, - mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, + weight_decay_mask: MaskOrFn 
= None, ) -> base.GradientTransformationExtraArgs: r"""The ADAptive Nesterov momentum algorithm (Adan). @@ -919,7 +925,7 @@ def adan( eps=eps, eps_root=eps_root, ), - transform.add_decayed_weights(weight_decay, mask), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), ) @@ -930,7 +936,7 @@ def lion( b2: jax.typing.ArrayLike = 0.99, mu_dtype: Optional[Any] = None, weight_decay: base.ScalarOrSchedule = 1e-3, - mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, + weight_decay_mask: MaskOrFn = None, ) -> base.GradientTransformationExtraArgs: r"""The Lion optimizer. @@ -1015,7 +1021,7 @@ def lion( """ return combine.chain( transform.scale_by_lion(b1=b1, b2=b2, mu_dtype=mu_dtype), - transform.add_decayed_weights(weight_decay, mask), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), ) @@ -1029,6 +1035,8 @@ def amsgrad( mu_dtype: Optional[Any] = None, bias_correction_mu: bool = True, bias_correction_nu: bool = True, + weight_decay: base.ScalarOrSchedule = 0.0, + weight_decay_mask: MaskOrFn = None, ) -> base.GradientTransformationExtraArgs: """The AMSGrad optimizer. @@ -1091,6 +1099,7 @@ def amsgrad( bias_correction_mu=bias_correction_mu, bias_correction_nu=bias_correction_nu, ), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), ) @@ -1238,7 +1247,7 @@ def lamb( eps: jax.typing.ArrayLike = 1e-6, eps_root: jax.typing.ArrayLike = 0.0, weight_decay: base.ScalarOrSchedule = 0.0, - mask: MaskOrFn = None, + weight_decay_mask: MaskOrFn = None, ) -> base.GradientTransformationExtraArgs: """The LAMB optimizer. 
@@ -1294,7 +1303,7 @@ def lamb( """ return combine.chain( transform.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root), - transform.add_decayed_weights(weight_decay=weight_decay, mask=mask), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_trust_ratio(), transform.scale_by_learning_rate(learning_rate), ) @@ -1552,6 +1561,8 @@ def optimistic_gradient_descent( learning_rate: base.ScalarOrSchedule, alpha: base.ScalarOrSchedule = 1.0, beta: base.ScalarOrSchedule = 1.0, + weight_decay: base.ScalarOrSchedule = 0.0, + weight_decay_mask: MaskOrFn = None, ) -> base.GradientTransformationExtraArgs: r"""An Optimistic Gradient Descent optimizer. @@ -1611,6 +1622,7 @@ def optimistic_gradient_descent( """ return combine.chain( transform.scale_by_optimistic_gradient(alpha=alpha, beta=beta), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), ) @@ -1623,6 +1635,8 @@ def optimistic_adam( eps: jax.typing.ArrayLike = 1e-08, eps_root: jax.typing.ArrayLike = 0.0, mu_dtype: Optional[Any] = None, + weight_decay: base.ScalarOrSchedule = 0.0, + weight_decay_mask: MaskOrFn = None, *, nesterov: bool = True, ) -> base.GradientTransformationExtraArgs: @@ -1735,12 +1749,15 @@ def optimistic_adam( ), transform.scale_by_optimistic_gradient(alpha=learning_rate, beta=optimism), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(1.0), # flips the sign ) def optimistic_adam_v2( learning_rate: base.ScalarOrSchedule, + weight_decay: base.ScalarOrSchedule = 0.0, + weight_decay_mask: MaskOrFn = None, *, alpha: jax.typing.ArrayLike = 1.0, beta: jax.typing.ArrayLike = 1.0, @@ -1857,6 +1874,7 @@ def optimistic_adam_v2( nesterov=nesterov, ), transform.scale_by_optimistic_gradient(alpha=alpha, beta=beta), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), ) @@ -1868,6 +1886,8 
@@ def radam( eps: jax.typing.ArrayLike = 1e-8, eps_root: jax.typing.ArrayLike = 0.0, threshold: jax.typing.ArrayLike = 5.0, + weight_decay: base.ScalarOrSchedule = 0.0, + weight_decay_mask: MaskOrFn = None, *, nesterov: bool = False, ) -> base.GradientTransformationExtraArgs: @@ -1928,6 +1948,7 @@ def radam( threshold=threshold, nesterov=nesterov, ), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), ) @@ -1942,6 +1963,8 @@ def rmsprop( momentum: Optional[jax.typing.ArrayLike] = None, nesterov: bool = False, bias_correction: bool = False, + weight_decay: base.ScalarOrSchedule = 0.0, + weight_decay_mask: MaskOrFn = None, ) -> base.GradientTransformationExtraArgs: r"""A flexible RMSProp optimizer. @@ -2022,6 +2045,7 @@ def rmsprop( eps_in_sqrt=eps_in_sqrt, bias_correction=bias_correction, ), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), ( transform.trace(decay=momentum, nesterov=nesterov) @@ -2051,6 +2075,8 @@ def sgd( momentum: Optional[jax.typing.ArrayLike] = None, nesterov: bool = False, accumulator_dtype: Optional[Any] = None, + weight_decay: base.ScalarOrSchedule = 0.0, + weight_decay_mask: MaskOrFn = None, ) -> base.GradientTransformationExtraArgs: r"""A canonical Stochastic Gradient Descent optimizer. @@ -2132,12 +2158,16 @@ def sgd( opt = base.identity() return combine.chain( opt, + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), ) def sm3( - learning_rate: jax.typing.ArrayLike, momentum: jax.typing.ArrayLike = 0.9 + learning_rate: jax.typing.ArrayLike, + momentum: jax.typing.ArrayLike = 0.9, + weight_decay: base.ScalarOrSchedule = 0.0, + weight_decay_mask: MaskOrFn = None, ) -> base.GradientTransformationExtraArgs: r"""The SM3 optimizer. 
@@ -2242,6 +2272,7 @@ def sm3( """ return combine.chain( transform.scale_by_sm3(momentum), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale(-learning_rate), ) @@ -2251,6 +2282,8 @@ def yogi( b1: jax.typing.ArrayLike = 0.9, b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-3, + weight_decay: base.ScalarOrSchedule = 0.0, + weight_decay_mask: MaskOrFn = None, ) -> base.GradientTransformationExtraArgs: # pylint: disable=line-too-long """The Yogi optimizer. @@ -2301,6 +2334,7 @@ def yogi( # pylint: enable=line-too-long return combine.chain( transform.scale_by_yogi(b1=b1, b2=b2, eps=eps), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), ) @@ -2396,7 +2430,7 @@ def adamaxw( b2: jax.typing.ArrayLike = 0.999, eps: jax.typing.ArrayLike = 1e-8, weight_decay: base.ScalarOrSchedule = 1e-4, - mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None, + weight_decay_mask: MaskOrFn = None, ) -> base.GradientTransformationExtraArgs: """Adamax with weight decay regularization. 
@@ -2460,7 +2494,7 @@ def adamaxw( """ return combine.chain( transform.scale_by_adamax(b1=b1, b2=b2, eps=eps), - transform.add_decayed_weights(weight_decay, mask), + transform.add_decayed_weights(weight_decay, mask=weight_decay_mask), transform.scale_by_learning_rate(learning_rate), ) diff --git a/optax/transforms/_adding.py b/optax/transforms/_adding.py index 5e894cccb..903364d3d 100644 --- a/optax/transforms/_adding.py +++ b/optax/transforms/_adding.py @@ -57,8 +57,6 @@ def init_fn(params): return base.EmptyState() def update_fn(updates, state, params): - if params is None: - raise ValueError(base.NO_PARAMS_MSG) if callable(weight_decay): new_state = WeightDecaySchedule(numerics.safe_increment(state.count)) else: @@ -68,6 +66,9 @@ def update_fn(updates, state, params): if isinstance(weight_decay, (int, float)) and weight_decay == 0.0: return updates, new_state + if params is None: + raise ValueError(base.NO_PARAMS_MSG) + s = weight_decay(state.count) if callable(weight_decay) else weight_decay updates = jax.tree.map( lambda g, p: None if g is None else g + s * p, diff --git a/optax/tree_utils/_state_utils.py b/optax/tree_utils/_state_utils.py index 2028805c8..b622a0237 100644 --- a/optax/tree_utils/_state_utils.py +++ b/optax/tree_utils/_state_utils.py @@ -437,13 +437,13 @@ def tree_set( ... 
) >>> state = opt.init(params) >>> print(state) - InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState())) + InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'weight_decay': Array(0., dtype=float32), 'learning_rate': Array(1., dtype=float32)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState(), EmptyState())) >>> filtering = lambda path, value: isinstance(value, jnp.ndarray) >>> new_state = optax.tree_utils.tree_set( ... state, filtering, learning_rate=jnp.asarray(0.1) ... ) >>> print(new_state) - InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState())) + InjectStatefulHyperparamsState(count=Array(0, dtype=int32), hyperparams={'weight_decay': Array(0., dtype=float32), 'learning_rate': Array(0.1, dtype=float32, weak_type=True)}, hyperparams_states={'learning_rate': WrappedScheduleState(count=Array(0, dtype=int32))}, inner_state=(EmptyState(), EmptyState(), EmptyState())) .. note:: The recommended usage to inject hyperparameters schedules is through :func:`optax.inject_hyperparams`. 
This function is a helper for other diff --git a/optax/tree_utils/_state_utils_test.py b/optax/tree_utils/_state_utils_test.py index a55884d91..0c5676d32 100644 --- a/optax/tree_utils/_state_utils_test.py +++ b/optax/tree_utils/_state_utils_test.py @@ -195,6 +195,7 @@ def test_map_params_to_none(self): ( transform.ScaleByRssState(sum_of_squares={'a': None}), base.EmptyState(), + base.EmptyState(), ), ) @@ -378,7 +379,7 @@ def get_learning_rate(state): for i in range(4): # we simply update state, we don't care about updates. - _, state = opt.update(params, state) + _, state = opt.update(params, state, params) lr = get_learning_rate(state) self.assertEqual(lr, 1 / (i + 1)) @@ -465,7 +466,7 @@ def set_learning_rate(state, lr): for i in range(4): modified_state = set_learning_rate(modified_state, lr / (i + 1)) # we simply update state, we don't care about updates. - _, modified_state = opt.update(params, modified_state) + _, modified_state = opt.update(params, modified_state, params) modified_lr = _state_utils.tree_get(modified_state, 'learning_rate') self.assertEqual(modified_lr, lr / (i + 1))