diff --git a/emerging_optimizers/scalar_optimizers/__init__.py b/emerging_optimizers/scalar_optimizers/__init__.py index e7d6e379..45cbd3cd 100644 --- a/emerging_optimizers/scalar_optimizers/__init__.py +++ b/emerging_optimizers/scalar_optimizers/__init__.py @@ -12,5 +12,183 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from emerging_optimizers.scalar_optimizers.laprop import LaProp -from emerging_optimizers.scalar_optimizers.lion import Lion +from typing import Any, override + +import torch +from absl import logging +from torch.optim.optimizer import ParamsT + +from emerging_optimizers.mixin import WeightDecayT +from emerging_optimizers.registry import register_optimizer +from emerging_optimizers.scalar_optimizers.base import ( + SingleMomentumOptimizer, + TwoMomentsOptimizer, + _validate_common_hparams, +) +from emerging_optimizers.scalar_optimizers.update_functions import ( + calculate_laprop_update, + calculate_lion_update, + calculate_signum_update, + calculate_sim_ademamix_update, +) + + +__all__ = [ + "LaProp", + "Lion", + "Signum", + "SimplifiedAdEMAMix", + "SingleMomentumOptimizer", + "TwoMomentsOptimizer", +] + + +@register_optimizer("lion") +class Lion(SingleMomentumOptimizer): + """Lion optimizer (Chen et al., 2023): sign-based update with a single first-moment EMA.""" + + def __init__( + self, + params: ParamsT, + lr: float = 1e-4, + betas: tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.01, + *, + weight_decay_method: WeightDecayT = "decoupled", + ) -> None: + _validate_common_hparams(lr=lr, betas=betas, weight_decay=weight_decay) + super().__init__( + params, + defaults=dict(lr=lr, betas=betas, weight_decay=weight_decay), + update_fn=calculate_lion_update, + update_kwarg_names=("betas",), + weight_decay_method=weight_decay_method, + ) + + +@register_optimizer("signum") +class Signum(SingleMomentumOptimizer): + """Sign-SGD / Signum optimizer (Bernstein et al., 2018): sign of a bias-corrected single-moment EMA.""" + + def __init__( + self, + params: ParamsT, + lr: float = 1e-3, + momentum: float = 0.9, + weight_decay: float = 0.0, + *, + correct_bias: bool = True, + nesterov: bool = False, + use_shape_scaling: bool = False, + weight_decay_method: WeightDecayT = "decoupled", + ) -> None: + _validate_common_hparams(lr=lr, weight_decay=weight_decay) + if not 0.0 <= momentum < 1.0: + raise ValueError(f"Invalid momentum: {momentum}") + super().__init__( + params, + defaults=dict( + lr=lr, + momentum=momentum, + weight_decay=weight_decay, + correct_bias=correct_bias, + nesterov=nesterov, + use_shape_scaling=use_shape_scaling, + ), + update_fn=calculate_signum_update, + update_kwarg_names=("momentum", "correct_bias", "nesterov", "use_shape_scaling"), + weight_decay_method=weight_decay_method, + ) + + +@register_optimizer("laprop") +class LaProp(TwoMomentsOptimizer): + """LaProp optimizer (Ziyin et al., 2020): Adam with the gradient normalized before the first-moment update.""" + + def __init__( + self, + params: ParamsT, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + correct_bias: bool = True, + frob_normalize: bool = False, + weight_decay_method: WeightDecayT = "decoupled", + ) -> None: + _validate_common_hparams(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + if frob_normalize and weight_decay != 0.0: + logging.error("LaProp with frob_normalize=True is intended to be used with weight_decay=0.0.") + self.frob_normalize = frob_normalize + super().__init__( + params, + defaults=dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + correct_bias=correct_bias, + ), + update_fn=calculate_laprop_update, + update_kwarg_names=("betas", "eps", "correct_bias"), + weight_decay_method=weight_decay_method, + ) + + @override + def pre_step_inplace(self, p: torch.Tensor, group: dict) -> Any: + return p.data.norm() if self.frob_normalize else None + + @override + def post_step_inplace(self, p: torch.Tensor, group: dict, ctx: Any) -> None: + if self.frob_normalize: + pre_norm = ctx + p.data.mul_(pre_norm / p.data.norm().clamp_min(group["eps"])) + + +@register_optimizer("sim_ademamix") +class SimplifiedAdEMAMix(TwoMomentsOptimizer): + """Simplified AdEMAMix: two-buffer variant mixing alpha-scaled current gradient into a theory-style first-moment EMA.""" + + def __init__( + self, + params: ParamsT, + lr: float = 1e-3, + betas: tuple[float, float] = (0.9999, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + *, + correct_bias: bool = True, + num_beta_fast_warmup_steps: int | None = None, + min_beta_fast: float = 0.9, + alpha: float = 2.0, + weight_decay_method: WeightDecayT = "decoupled", + ) -> None: + _validate_common_hparams(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) + if not 0.0 <= min_beta_fast < 1.0: + raise ValueError(f"Invalid min_beta_fast: {min_beta_fast}") + if num_beta_fast_warmup_steps is not None and num_beta_fast_warmup_steps <= 0: + raise ValueError(f"Invalid num_beta_fast_warmup_steps: {num_beta_fast_warmup_steps}") + super().__init__( + params, + defaults=dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + correct_bias=correct_bias, + num_beta_fast_warmup_steps=num_beta_fast_warmup_steps, + min_beta_fast=min_beta_fast, + alpha=alpha, + ), + update_fn=calculate_sim_ademamix_update, + update_kwarg_names=( + "betas", + "eps", + "correct_bias", + "num_beta_fast_warmup_steps", + "min_beta_fast", + "alpha", + ), + weight_decay_method=weight_decay_method, + ) diff --git a/emerging_optimizers/scalar_optimizers/base.py b/emerging_optimizers/scalar_optimizers/base.py new file mode 100644 index 00000000..ed97955c --- /dev/null +++ b/emerging_optimizers/scalar_optimizers/base.py @@ -0,0 +1,177 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, ClassVar, override + + +if TYPE_CHECKING: + from typing import overload + +import torch +from torch.optim.optimizer import ParamsT + +from emerging_optimizers.mixin import WeightDecayMixin, WeightDecayT + + +__all__ = [ + "SingleMomentumOptimizer", + "TwoMomentsOptimizer", +] + + +def _validate_common_hparams( + *, + lr: float | None = None, + betas: tuple[float, ...] | None = None, + eps: float | None = None, + weight_decay: float | None = None, +) -> None: + """Validates the hyperparameters shared by most scalar optimizers.""" + if lr is not None and lr < 0.0: + raise ValueError(f"Invalid learning rate: {lr}") + if betas is not None: + for i, b in enumerate(betas): + if not 0.0 <= b < 1.0: + raise ValueError(f"Invalid beta at index {i}: {b}") + if eps is not None and eps < 0.0: + raise ValueError(f"Invalid epsilon value: {eps}") + if weight_decay is not None and weight_decay < 0.0: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + +class _ScalarOptimizerBase(WeightDecayMixin, torch.optim.Optimizer): + """Shared implementation for scalar optimizers grouped by state shape. + + Subclasses set ``state_keys`` as a ``ClassVar``. The base lazily allocates one + zero-initialized buffer per name plus a per-parameter ``step`` counter, then + dispatches each step to a constructor-supplied ``update_fn`` whose signature is + ``update_fn(grad, *buffers, **kwargs) -> Tensor``. + + Hyperparameters forwarded into ``update_fn`` are selected from the parameter + group via ``update_kwarg_names`` (a tuple of dict keys present in the + ``defaults`` mapping). The per-parameter ``step`` is always forwarded as + ``step=state["step"]``, so every update function must accept a ``step`` kwarg. + + Subclasses can additionally override :meth:`pre_step_inplace` / + :meth:`post_step_inplace` to bracket the per-parameter update with custom + logic (e.g. norm preservation). + """ + + state_keys: ClassVar[tuple[str, ...]] + + def __init__( + self, + params: ParamsT, + defaults: dict[str, Any], + *, + update_fn: Callable[..., torch.Tensor], + update_kwarg_names: tuple[str, ...], + weight_decay_method: WeightDecayT = "decoupled", + ) -> None: + missing = set(update_kwarg_names) - set(defaults.keys()) + if missing: + raise ValueError( + f"update_kwarg_names {sorted(missing)} not present in defaults (keys: {sorted(defaults.keys())})" + ) + self.update_fn = update_fn + self.update_kwarg_names = update_kwarg_names + self.weight_decay_method = weight_decay_method + super().__init__(params, defaults) + + @torch.no_grad() + def _init_group( + self, + group: dict, + skip_non_grad_params: bool = True, + ) -> None: + """Performs lazy state initialization for parameters.""" + for p in group["params"]: + if skip_non_grad_params and p.grad is None: + continue + state = self.state[p] + if len(state) == 0: + for key in self.state_keys: + state[key] = torch.zeros_like(p.data) + state["step"] = 0 + + def pre_step_inplace(self, p: torch.Tensor, group: dict) -> Any: + """Hook called before weight decay and the update. Return value is forwarded to ``post_step_inplace``.""" + return None + + def post_step_inplace(self, p: torch.Tensor, group: dict, ctx: Any) -> None: + """Hook called after the update. Receives the value returned by ``pre_step_inplace``.""" + return None + + if TYPE_CHECKING: + + @overload + def step(self, closure: None = ...) -> None: ... + + @overload + def step(self, closure: Callable[[], float]) -> float: ... + + @torch.no_grad() # type: ignore[misc] + @override + def step(self, closure: Callable[[], float] | None = None) -> float | None: + """Perform a single optimization step. + + Note: + When ``weight_decay_method="l2"``, ``p.grad`` is modified in-place + (the L2 penalty ``weight_decay * p`` is added to the gradient). + If you need the original gradient after this call, clone it beforehand. + + Args: + closure: Unsupported; must be ``None``. + """ + if closure is not None: + raise ValueError("closure is not supported") + + for group in self.param_groups: + self._init_group(group) + + lr = group["lr"] + weight_decay = group["weight_decay"] + update_kwargs = {key: group[key] for key in self.update_kwarg_names} + + for p in group["params"]: + if p.grad is None: + continue # pragma: no cover + + state = self.state[p] + state["step"] += 1 + update_kwargs["step"] = state["step"] + + ctx = self.pre_step_inplace(p, group) + self._apply_weight_decay_inplace(p.data, p.grad, lr, weight_decay) + + buffers = tuple(state[key] for key in self.state_keys) + update = self.update_fn(p.grad, *buffers, **update_kwargs) + p.data.add_(update, alpha=-lr) + + self.post_step_inplace(p, group, ctx) + + return None + + +class SingleMomentumOptimizer(_ScalarOptimizerBase): + """Base for scalar optimizers tracking a single first-moment EMA buffer.""" + + state_keys: ClassVar[tuple[str, ...]] = ("exp_avg",) + + +class TwoMomentsOptimizer(_ScalarOptimizerBase): + """Base for Adam-style scalar optimizers tracking first + second moment buffers.""" + + state_keys: ClassVar[tuple[str, ...]] = ("exp_avg", "exp_avg_sq") diff --git a/emerging_optimizers/scalar_optimizers/laprop.py b/emerging_optimizers/scalar_optimizers/laprop.py deleted file mode 100644 index 1bf72ee0..00000000 --- a/emerging_optimizers/scalar_optimizers/laprop.py +++ /dev/null @@ -1,182 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from collections.abc import Callable -from typing import TYPE_CHECKING, override - - -if TYPE_CHECKING: - from typing import overload - -import torch -from absl import logging -from torch.optim.optimizer import ParamsT - -from emerging_optimizers import registry -from emerging_optimizers.mixin import WeightDecayMixin, WeightDecayT -from emerging_optimizers.scalar_optimizers.update_functions import calculate_laprop_update - - -__all__ = [ - "LaProp", -] - - -@registry.register_optimizer("laprop") -class LaProp(WeightDecayMixin, torch.optim.Optimizer): - """LaProp optimizer. - - LAProp can be seen as RMSProp with a momentum term, or normalized SGD with momentum. - This optimizer tracks Adam-style first and second moments, but normalizes the gradient - before the first-moment update. - - The update rule below assumes ``weight_decay_method="decoupled"`` (the default). - See :class:`~emerging_optimizers.mixin.WeightDecayMixin` for the other modes. - - .. math:: - p = p \\cdot (1 - \\text{lr} \\cdot \\lambda) \\\\ - v_t = \\beta_2 v_{t-1} + (1 - \\beta_2) g_t^2 \\\\ - \\hat{v}_t = \\frac{v_t}{1 - \\beta_2^t} \\\\ - g'_t = \\frac{g_t}{\\sqrt{\\hat{v}_t} + \\epsilon} \\\\ - m_t = \\beta_1 m_{t-1} + (1 - \\beta_1) g'_t \\\\ - \\hat{m}_t = \\frac{m_t}{1 - \\beta_1^t} \\\\ - p = p - \\text{lr} \\cdot \\hat{m}_t - - References: - - Ziyin, L., Wang, Z. T., & Ueda, M. *LaProp: Separating Momentum and - Adaptivity in Adam.* arXiv:2002.04839 (2020). - [`arXiv:2002.04839 `_] - - Args: - params: Iterable of parameters to optimize or dicts defining parameter groups. - lr: Learning rate. - betas: Coefficients (beta1, beta2) for first and second moment EMAs. - eps: Term added to the denominator for numerical stability. - weight_decay: Weight decay coefficient. - correct_bias: Whether to apply bias correction to the first and second moment EMAs. - frob_normalize: Whether to normalize each updated parameter back to its pre-update Frobenius norm. - weight_decay_method: Method to apply weight decay, see - :class:`~emerging_optimizers.mixin.WeightDecayMixin` for more details. - """ - - def __init__( - self, - params: ParamsT, - lr: float = 1e-3, - betas: tuple[float, float] = (0.9, 0.999), - eps: float = 1e-8, - weight_decay: float = 0.0, - *, - correct_bias: bool = True, - frob_normalize: bool = False, - weight_decay_method: WeightDecayT = "decoupled", - ) -> None: - if lr < 0.0: - raise ValueError(f"Invalid learning rate: {lr}") - if not 0.0 <= betas[0] < 1.0: - raise ValueError(f"Invalid beta at index 0: {betas[0]}") - if not 0.0 <= betas[1] < 1.0: - raise ValueError(f"Invalid beta at index 1: {betas[1]}") - if eps < 0.0: - raise ValueError(f"Invalid epsilon value: {eps}") - if weight_decay < 0.0: - raise ValueError(f"Invalid weight_decay value: {weight_decay}") - if frob_normalize and weight_decay != 0.0: - logging.warning("LaProp with frob_normalize=True is intended to be used with weight_decay=0.0.") - defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias) - self.weight_decay_method = weight_decay_method - self.frob_normalize = frob_normalize - super().__init__(params, defaults) - - @torch.no_grad() - def _init_group( - self, - group: dict, - skip_non_grad_params: bool = True, - ) -> None: - """Performs lazy state initialization for parameters. - - Args: - group: Parameter group dictionary. - skip_non_grad_params: If True, skip parameters without gradients. - """ - for p in group["params"]: - if skip_non_grad_params and p.grad is None: - continue - state = self.state[p] - - if len(state) == 0: - state["step"] = 0 - state["exp_avg"] = torch.zeros_like(p.data) - state["exp_avg_sq"] = torch.zeros_like(p.data) - - if TYPE_CHECKING: - - @overload - def step(self, closure: None = ...) -> None: ... - - @overload - def step(self, closure: Callable[[], float]) -> float: ... - - @torch.no_grad() # type: ignore[misc] - @override - def step(self, closure: Callable[[], float] | None = None) -> float | None: - """Perform a single optimization step. - - Note: - When ``weight_decay_method="l2"``, ``p.grad`` is modified in-place - (the L2 penalty ``weight_decay * p`` is added to the gradient). - If you need the original gradient after this call, clone it beforehand. - - Args: - closure: Unsupported; must be ``None``. - """ - if closure is not None: - raise ValueError("closure is not supported") - - for group in self.param_groups: - self._init_group(group) - - lr = group["lr"] - betas = group["betas"] - eps = group["eps"] - weight_decay = group["weight_decay"] - correct_bias = group["correct_bias"] - - for p in group["params"]: - if p.grad is None: - continue # pragma: no cover - - grad = p.grad - state = self.state[p] - state["step"] += 1 - pre_norm = p.data.norm() if self.frob_normalize else None - - self._apply_weight_decay_inplace(p.data, grad, lr, weight_decay) - - update = calculate_laprop_update( - grad, - state["exp_avg"], - state["exp_avg_sq"], - betas=betas, - eps=eps, - correct_bias=correct_bias, - step=state["step"], - ) - p.data.add_(update, alpha=-lr) - if self.frob_normalize: - assert pre_norm is not None - p.data.mul_(pre_norm / p.data.norm().clamp_min(eps)) - - return None diff --git a/emerging_optimizers/scalar_optimizers/lion.py b/emerging_optimizers/scalar_optimizers/lion.py deleted file mode 100644 index ae2fb913..00000000 --- a/emerging_optimizers/scalar_optimizers/lion.py +++ /dev/null @@ -1,156 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from collections.abc import Callable -from typing import TYPE_CHECKING, override - - -if TYPE_CHECKING: - from typing import overload - -import torch -from torch.optim.optimizer import ParamsT - -from emerging_optimizers import registry -from emerging_optimizers.mixin import WeightDecayMixin, WeightDecayT -from emerging_optimizers.scalar_optimizers.update_functions import calculate_lion_update - - -__all__ = [ - "Lion", -] - - -@registry.register_optimizer("lion") -class Lion(WeightDecayMixin, torch.optim.Optimizer): - """Lion optimizer (Chen et al., 2023). - - A memory-efficient optimizer that uses only sign updates and tracks a single - exponential moving average (no second moment), resulting in lower memory usage - than Adam. - - The update rule below assumes ``weight_decay_method="decoupled"`` (the default). - See :class:`~emerging_optimizers.mixin.WeightDecayMixin` for the other modes. - - .. math:: - p = p \\cdot (1 - \\text{lr} \\cdot \\lambda) \\\\ - \\text{update} = \\text{sign}(\\beta_1 m_{t-1} + (1 - \\beta_1) g_t) \\\\ - m_t = \\beta_2 m_{t-1} + (1 - \\beta_2) g_t \\\\ - p = p - \\text{lr} \\cdot \\text{update} - - References: - - Chen, X., Liang, C., Huang, D., Real, E., Wang, K., Liu, Y., Pham, H., Dong, X., - Luber, T., Cho, T., Le, Q. V., & Henaff, O. J. *Symbolic Discovery of Optimization Algorithms.* - arXiv:2302.06675 (2023). - [`arXiv:2302.06675 `_] - - Args: - params: Iterable of parameters to optimize or dicts defining parameter groups. - lr: Learning rate. - betas: Coefficients (beta1, beta2) for computing the update and running average. - beta1 is used for the sign update interpolation, beta2 for the momentum EMA update. - weight_decay: Weight decay coefficient. - weight_decay_method: Method to apply weight decay, see - :class:`~emerging_optimizers.mixin.WeightDecayMixin` for more details. - """ - - def __init__( - self, - params: ParamsT, - lr: float = 1e-4, - betas: tuple[float, float] = (0.9, 0.99), - weight_decay: float = 0.01, - *, - weight_decay_method: WeightDecayT = "decoupled", - ) -> None: - if lr < 0.0: - raise ValueError(f"Invalid learning rate: {lr}") - if not 0.0 <= betas[0] < 1.0: - raise ValueError(f"Invalid beta at index 0: {betas[0]}") - if not 0.0 <= betas[1] < 1.0: - raise ValueError(f"Invalid beta at index 1: {betas[1]}") - if weight_decay < 0.0: - raise ValueError(f"Invalid weight_decay value: {weight_decay}") - defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) - self.weight_decay_method = weight_decay_method - super().__init__(params, defaults) - - @torch.no_grad() - def _init_group( - self, - group: dict, - skip_non_grad_params: bool = True, - ) -> None: - """Performs lazy state initialization for parameters. - - Args: - group: Parameter group dictionary. - skip_non_grad_params: If True, skip parameters without gradients. - """ - for p in group["params"]: - if skip_non_grad_params and p.grad is None: - continue - state = self.state[p] - - if len(state) == 0: - state["exp_avg"] = torch.zeros_like(p.data) - - if TYPE_CHECKING: - - @overload - def step(self, closure: None = ...) -> None: ... - - @overload - def step(self, closure: Callable[[], float]) -> float: ... - - @torch.no_grad() # type: ignore[misc] - @override - def step(self, closure: Callable[[], float] | None = None) -> float | None: - """Perform a single optimization step. - - Note: - When ``weight_decay_method="l2"``, ``p.grad`` is modified in-place - (the L2 penalty ``weight_decay * p`` is added to the gradient). - If you need the original gradient after this call, clone it beforehand. - - Args: - closure: Unsupported; must be ``None``. - """ - if closure is not None: - raise ValueError("closure is not supported") - - for group in self.param_groups: - self._init_group(group) - - lr = group["lr"] - betas = group["betas"] - weight_decay = group["weight_decay"] - - for p in group["params"]: - if p.grad is None: - continue # pragma: no cover - - grad = p.grad - exp_avg = self.state[p]["exp_avg"] - - # Weight decay - self._apply_weight_decay_inplace(p.data, grad, lr, weight_decay) - - # Lion update: sign(beta1 * m + (1-beta1) * g) - # Note: different betas per param-group will each trigger a one-time - # torch.compile recompilation of calculate_lion_update. - update = calculate_lion_update(grad, exp_avg, betas=betas) - p.data.add_(update, alpha=-lr) - - return None diff --git a/emerging_optimizers/scalar_optimizers/update_functions/lion.py b/emerging_optimizers/scalar_optimizers/update_functions/lion.py index 3a15e53c..0572bec8 100644 --- a/emerging_optimizers/scalar_optimizers/update_functions/lion.py +++ b/emerging_optimizers/scalar_optimizers/update_functions/lion.py @@ -27,6 +27,7 @@ def calculate_lion_update( exp_avg: torch.Tensor, *, betas: tuple[float, float], + step: int, ) -> torch.Tensor: """Performs the Lion update. @@ -42,10 +43,12 @@ def calculate_lion_update( grad: The gradient tensor. exp_avg: The accumulated first moment of the gradient (modified in place). betas: The EMA beta coefficients ``(beta1, beta2)``. ``beta1`` controls the sign-update interpolation; ``beta2`` controls the momentum EMA. + step: Current optimizer step (1-based). Accepted for signature uniformity with the other ``calculate_*_update`` functions; Lion has no time-dependent bias correction so the value is unused. Returns: The Lion update. """ + del step # unused; accepted for signature uniformity with other update fns beta1, beta2 = betas diff --git a/tests/test_laprop.py b/tests/test_laprop.py deleted file mode 100644 index 3c874f79..00000000 --- a/tests/test_laprop.py +++ /dev/null @@ -1,181 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from absl import flags, logging -from absl.testing import absltest, parameterized - -from emerging_optimizers.scalar_optimizers import LaProp -from emerging_optimizers.scalar_optimizers.update_functions import calculate_laprop_update - - -flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on") -flags.DEFINE_integer("seed", None, "Random seed for reproducible tests") -FLAGS = flags.FLAGS - - -def setUpModule() -> None: - if FLAGS.seed is not None: - logging.info("Setting random seed to %d", FLAGS.seed) - torch.manual_seed(FLAGS.seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(FLAGS.seed) - - -class LaPropOptimizerTest(parameterized.TestCase): - def setUp(self): - self.device = FLAGS.device - - @parameterized.parameters( - {"shape": (3, 3)}, - {"shape": (15, 31)}, - {"shape": (127, 255)}, - ) - def test_smoke(self, shape) -> None: - """LaProp optimizer can be instantiated and stepped.""" - param = torch.nn.Parameter(torch.randn(*shape, device=self.device)) - optimizer = LaProp([param], lr=1e-4) - param.grad = torch.randn_like(param) - optimizer.step() - - @parameterized.parameters( - {"shape": (3, 3)}, - {"shape": (15, 31)}, - {"shape": (127, 255)}, - ) - def test_state_initialization(self, shape) -> None: - """LaProp initializes first moment, second moment, and step state.""" - beta1, beta2 = 0.5, 0.75 - param = torch.nn.Parameter(torch.randint(1, 5, shape, device=self.device, dtype=torch.float32)) - optimizer = LaProp([param], lr=0.25, betas=(beta1, beta2), weight_decay=0.0, correct_bias=True) - grad = torch.randint_like(param, 1, 5) - param.grad = grad.clone() - optimizer.step() - - self.assertEqual(optimizer.state[param]["step"], 1) - self.assertIn("exp_avg", optimizer.state[param]) - self.assertIn("exp_avg_sq", optimizer.state[param]) - - expected_exp_avg_sq = (1 - beta2) * grad.square() - normalized_grad = grad / (grad.abs() + optimizer.param_groups[0]["eps"]) - expected_exp_avg = (1 - beta1) * normalized_grad - torch.testing.assert_close(optimizer.state[param]["exp_avg_sq"], expected_exp_avg_sq, atol=0, rtol=0) - torch.testing.assert_close(optimizer.state[param]["exp_avg"], expected_exp_avg, atol=0, rtol=0) - - @parameterized.parameters( - {"shape": (3, 3)}, - {"shape": (15, 31)}, - {"shape": (127, 255)}, - ) - def test_optimizer_step_matches_update_function(self, shape) -> None: - """LaProp optimizer delegates update math to calculate_laprop_update.""" - lr = 0.25 - betas = (0.5, 0.75) - eps = 1e-8 - param = torch.nn.Parameter(torch.randint(-5, 5, shape, device=self.device, dtype=torch.float32)) - grad = torch.randint(-5, 5, shape, device=self.device, dtype=torch.float32) - optimizer = LaProp([param], lr=lr, betas=betas, eps=eps, weight_decay=0.0) - - old_param = param.detach().clone() - exp_avg = torch.zeros_like(param) - exp_avg_sq = torch.zeros_like(param) - expected_update = calculate_laprop_update( - grad, exp_avg, exp_avg_sq, betas=betas, eps=eps, correct_bias=True, step=1 - ) - - param.grad = grad.clone() - optimizer.step() - - torch.testing.assert_close(param, old_param - lr * expected_update, atol=0, rtol=0) - torch.testing.assert_close(optimizer.state[param]["exp_avg"], exp_avg, atol=0, rtol=0) - torch.testing.assert_close(optimizer.state[param]["exp_avg_sq"], exp_avg_sq, atol=0, rtol=0) - - @parameterized.parameters( - {"shape": (3, 3)}, - {"shape": (15, 31)}, - {"shape": (127, 255)}, - ) - def test_no_grad_no_update_params_unchanged(self, shape) -> None: - """Parameters without gradients are not updated.""" - param = torch.nn.Parameter(torch.randn(*shape, device=self.device)) - original = param.detach().clone() - optimizer = LaProp([param], lr=1e-4) - optimizer.step() - torch.testing.assert_close(param, original, atol=0, rtol=0) - - @parameterized.parameters( - {"shape": (3, 3)}, - {"shape": (15, 31)}, - {"shape": (127, 255)}, - ) - def test_frob_normalize_preserves_parameter_norm(self, shape) -> None: - """LaProp can normalize updated parameters back to their pre-update norm.""" - param = torch.nn.Parameter(torch.randint(1, 5, shape, device=self.device, dtype=torch.float32)) - optimizer = LaProp([param], lr=0.25, weight_decay=0.0, frob_normalize=True) - param.grad = torch.randint(1, 5, shape, device=self.device, dtype=torch.float32) - original_norm = param.norm() - - optimizer.step() - - torch.testing.assert_close(param.norm(), original_norm) - - @parameterized.parameters(True, False) - def test_init_group_skip_non_grad_params(self, skip_non_grad_params) -> None: - """Test _init_group with skip_non_grad_params flag.""" - param_with_grad = torch.nn.Parameter(torch.randn(5, 7, device=self.device)) - param_without_grad = torch.nn.Parameter(torch.randn(5, 7, device=self.device)) - param_with_grad.grad = torch.randn_like(param_with_grad) - - opt = LaProp([param_with_grad, param_without_grad], lr=1e-4) - opt._init_group(opt.param_groups[0], skip_non_grad_params=skip_non_grad_params) - - self.assertIn("exp_avg", opt.state[param_with_grad]) - self.assertIn("exp_avg_sq", opt.state[param_with_grad]) - self.assertEqual(opt.state[param_with_grad]["step"], 0) - self.assertEqual("exp_avg" in opt.state[param_without_grad], not skip_non_grad_params) - - def test_negative_lr_raises_value_error(self) -> None: - """Test that LaProp raises ValueError for negative learning rate.""" - param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) - with self.assertRaisesRegex(ValueError, "Invalid learning rate"): - LaProp([param], lr=-1.0) - - def test_beta0_out_of_range_raises_value_error(self) -> None: - """Test that LaProp raises ValueError for invalid beta at index 0.""" - param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) - with self.assertRaisesRegex(ValueError, "Invalid beta at index 0"): - LaProp([param], betas=(1.0, 0.999)) - - def test_beta1_out_of_range_raises_value_error(self) -> None: - """Test that LaProp raises ValueError for invalid beta at index 1.""" - param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) - with self.assertRaisesRegex(ValueError, "Invalid beta at index 1"): - LaProp([param], betas=(0.9, 1.0)) - - def test_negative_eps_raises_value_error(self) -> None: - """Test that LaProp raises ValueError for negative eps.""" - param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) - with self.assertRaisesRegex(ValueError, "Invalid epsilon"): - LaProp([param], eps=-1e-8) - - def test_negative_weight_decay_raises_value_error(self) -> None: - """Test that LaProp raises ValueError for negative weight_decay.""" - param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) - with self.assertRaisesRegex(ValueError, "Invalid weight_decay"): - LaProp([param], weight_decay=-0.1) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/test_lion.py b/tests/test_lion.py deleted file mode 100644 index 1b59285e..00000000 --- a/tests/test_lion.py +++ /dev/null @@ -1,287 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import torch -from absl import flags, logging -from absl.testing import absltest, parameterized - -from emerging_optimizers.scalar_optimizers import Lion - - -flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on") -flags.DEFINE_integer("seed", None, "Random seed for reproducible tests") -FLAGS = flags.FLAGS - - -def setUpModule() -> None: - if FLAGS.seed is not None: - logging.info("Setting random seed to %d", FLAGS.seed) - torch.manual_seed(FLAGS.seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed_all(FLAGS.seed) - - -class LionOptimizerTest(parameterized.TestCase): - def setUp(self): - self.device = FLAGS.device - - @parameterized.parameters( - {"shape": (3, 3)}, - {"shape": (15, 31)}, - {"shape": (127, 255)}, - ) - def test_smoke(self, shape) -> None: - """Lion optimizer can be instantiated and stepped.""" - param = torch.nn.Parameter(torch.randn(*shape, device=self.device)) - optimizer = Lion([param], lr=1e-4) - param.grad = torch.randn_like(param) - optimizer.step() - - @parameterized.parameters( - {"shape": (3, 3)}, - {"shape": (15, 31)}, - {"shape": (127, 255)}, - ) - def test_state_initialization(self, shape) -> None: - """Lion initializes exp_avg state to zeros on first step.""" - beta2 = 0.75 - param = torch.nn.Parameter(torch.randint(-3, 5, shape, device=self.device, dtype=torch.float32)) - optimizer = Lion([param], lr=0.25, betas=(0.5, beta2), weight_decay=0.0) - grad = torch.randint_like(param, -3, 5) - param.grad = grad.clone() - optimizer.step() - self.assertIn("exp_avg", optimizer.state[param]) - # exp_avg is initialized to zero then updated: 0 * beta2 + (1 - beta2) * grad - expected = (1 - beta2) * grad - torch.testing.assert_close(optimizer.state[param]["exp_avg"], expected, atol=0, rtol=0) - - @parameterized.parameters( - {"shape": (3, 3)}, - {"shape": (15, 31)}, - {"shape": (127, 255)}, - ) - def test_no_grad_no_update_params_unchanged(self, shape) -> None: - """Parameters without gradients are not updated.""" - param = torch.nn.Parameter(torch.randn(*shape, device=self.device)) - original = param.data.clone() - optimizer = Lion([param], lr=1e-4) - optimizer.step() - torch.testing.assert_close(param.data, original, atol=0, rtol=0) - - @parameterized.product( - betas=[(0.9, 0.99), (0.95, 0.98)], - shape=[(3, 3), (15, 31), (127, 255)], - ) - def test_update_is_sign_based(self, betas, shape) -> None: - """Lion updates should be +/- lr (sign-based).""" - param = torch.nn.Parameter(torch.randint(-5, 5, shape, device=self.device, dtype=torch.float32)) - optimizer = Lion([param], lr=0.25, betas=betas, weight_decay=0.0) - # Use a fixed, non-zero gradient to guarantee sign(g) != 0 for every element. - param.grad = torch.randint(1, 5, shape, device=self.device, dtype=torch.float32) - old_param = param.data.clone() - optimizer.step() - - # The change should be exactly +/- lr since Lion uses sign updates - diff = old_param - param.data - torch.testing.assert_close(diff.abs(), torch.full_like(diff, 0.25), atol=0, rtol=0) - - @parameterized.parameters( - {"shape": (3, 3)}, - {"shape": (15, 31)}, - {"shape": (127, 255)}, - ) - def test_weight_decay_decoupled_matches_analytical(self, shape) -> None: - """Decoupled weight decay shrinks parameters toward zero.""" - lr = 0.25 - wd = 0.5 - param = torch.nn.Parameter(torch.randint(1, 5, shape, device=self.device, dtype=torch.float32)) - optimizer = Lion([param], lr=lr, weight_decay=wd, weight_decay_method="decoupled") - param.grad = torch.zeros(*shape, device=self.device) - - old_param = param.data.clone() - optimizer.step() - - # With zero grad, sign update is 0. Decoupled weight decay: p = p * (1 - lr * wd) - expected = old_param * (1 - lr * wd) - torch.testing.assert_close(param.data, expected, atol=0, rtol=0) - - @parameterized.parameters( - {"shape": (3, 3)}, - {"shape": (15, 31)}, - {"shape": (127, 255)}, - ) - def test_weight_decay_l2(self, shape) -> None: - """L2 weight decay folds into gradient before sign(), so it can be masked.""" - lr = 0.25 - wd = 0.5 - param = torch.nn.Parameter(torch.randint(1, 5, shape, device=self.device, dtype=torch.float32)) - optimizer = Lion([param], lr=lr, weight_decay=wd, weight_decay_method="l2") - # Use zero gradient so that the only gradient contribution is from L2: grad += wd * p - param.grad = torch.zeros(*shape, device=self.device) - - old_param = param.data.clone() - optimizer.step() - - # After L2, grad becomes 0 + wd * p (all positive since p > 0). - # First step: exp_avg is zero, so update = sign(beta1 * 0 + (1-beta1) * wd * p) = sign(positive) = 1 - # p = p - lr * sign = p - lr - expected = old_param - lr - torch.testing.assert_close(param.data, expected, atol=0, rtol=0) - - @parameterized.parameters( - {"shape": (3, 3)}, - {"shape": (15, 31)}, - {"shape": (127, 255)}, - ) - def test_weight_decay_l2_masked_by_gradient(self, shape) -> None: - """L2 decay penalty can be masked when the gradient dominates the sign.""" - lr = 0.25 - wd = 0.125 - param = torch.nn.Parameter(torch.randint(1, 5, shape, device=self.device, dtype=torch.float32)) - optimizer = Lion([param], lr=lr, weight_decay=wd, weight_decay_method="l2") - # Large negative gradient dominates: grad + wd*p is still negative, sign = -1 - param.grad = torch.randint(-10, -5, shape, device=self.device, dtype=torch.float32) - - old_param = param.data.clone() - optimizer.step() - - # sign(negative) = -1, so p = p - lr * (-1) = p + lr, parameter grows - # L2 cannot guarantee shrinkage when gradient dominates - expected = old_param + lr - torch.testing.assert_close(param.data, expected, atol=0, rtol=0) - - @parameterized.parameters( - {"shape": (3, 3)}, - {"shape": (15, 31)}, - {"shape": (127, 255)}, - ) - def test_weight_decay_independent_matches_analytical(self, shape) -> None: - """Independent weight decay shrinks params without lr scaling.""" - lr = 0.25 - wd = 0.5 - param = torch.nn.Parameter(torch.randint(1, 5, shape, device=self.device, dtype=torch.float32)) - optimizer = Lion([param], lr=lr, weight_decay=wd, weight_decay_method="independent") - param.grad = torch.zeros(*shape, device=self.device) - - old_param = param.data.clone() - optimizer.step() - - # Independent: p = p * (1 - wd). With zero grad, sign update is 0. - expected = old_param * (1 - wd) - torch.testing.assert_close(param.data, expected, atol=0, rtol=0) - - @parameterized.parameters( - {"shape": (3, 3)}, - {"shape": (15, 31)}, - {"shape": (127, 255)}, - ) - def test_exp_avg_evolves_correctly(self, shape) -> None: - """Verify exp_avg state matches analytical values after deterministic steps.""" - beta1, beta2 = 0.9, 0.99 - param = torch.nn.Parameter(torch.randint(-5, 5, shape, device=self.device, dtype=torch.float32)) - optimizer = Lion([param], lr=0.01, betas=(beta1, beta2), weight_decay=0.0) - - grads = [ - torch.randint(-3, 3, shape, device=self.device, dtype=torch.float32), - torch.randint(-3, 3, shape, device=self.device, dtype=torch.float32), - torch.randint(-3, 3, shape, device=self.device, dtype=torch.float32), - ] - - # exp_avg starts at 0. Each step: exp_avg = lerp(exp_avg, grad, 1 - beta2) - # i.e. exp_avg = beta2 * exp_avg + (1 - beta2) * grad - expected_exp_avg = torch.zeros(*shape, device=self.device) - for grad in grads: - param.grad = grad.clone() - optimizer.step() - expected_exp_avg = beta2 * expected_exp_avg + (1 - beta2) * grad - - torch.testing.assert_close(optimizer.state[param]["exp_avg"], expected_exp_avg, atol=1e-6, rtol=1e-6) - - @parameterized.parameters(True, False) - def test_init_group_skip_non_grad_params(self, skip_non_grad_params) -> None: - """Test _init_group with skip_non_grad_params flag.""" - param_with_grad = torch.nn.Parameter(torch.randn(5, 7, device=self.device)) - param_without_grad = torch.nn.Parameter(torch.randn(5, 7, device=self.device)) - param_with_grad.grad = torch.randn_like(param_with_grad) - - opt = Lion([param_with_grad, param_without_grad], lr=1e-4) - - opt._init_group(opt.param_groups[0], skip_non_grad_params=skip_non_grad_params) - - self.assertIn("exp_avg", opt.state[param_with_grad]) - self.assertEqual(opt.state[param_with_grad]["exp_avg"].shape, param_with_grad.data.shape) - - self.assertEqual("exp_avg" in opt.state[param_without_grad], not skip_non_grad_params) - - @parameterized.parameters( - {"shape": (3, 3)}, - {"shape": (15, 31)}, - {"shape": (127, 255)}, - ) - def test_param_groups_large_lr_moves_more(self, shape) -> None: - """Lion supports multiple parameter groups with different hyperparameters.""" - p1 = torch.nn.Parameter(torch.randint(-5, 5, shape, device=self.device, dtype=torch.float32)) - p2 = torch.nn.Parameter(torch.randint(-5, 5, shape, device=self.device, dtype=torch.float32)) - optimizer = Lion( - [ - {"params": [p1], "lr": 0.01}, - {"params": [p2], "lr": 0.001}, - ], - betas=(0.9, 0.99), - weight_decay=0.0, - ) - p1_original = p1.data.clone() - p2_original = p2.data.clone() - grad = torch.randint(1, 5, shape, device=self.device, dtype=torch.float32) - p1.grad = grad.clone() - p2.grad = grad.clone() - optimizer.step() - - # Both should have state initialized - self.assertIn("exp_avg", optimizer.state[p1]) - self.assertIn("exp_avg", optimizer.state[p2]) - - # p1 (lr=0.01) should have moved more than p2 (lr=0.001) - p1_change = (p1.data - p1_original).abs().mean() - p2_change = (p2.data - p2_original).abs().mean() - self.assertGreater(p1_change.item(), p2_change.item()) - - def test_negative_lr_raises_value_error(self) -> None: - """Test that Lion raises ValueError for negative learning rate.""" - param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) - with self.assertRaisesRegex(ValueError, "Invalid learning rate"): - Lion([param], lr=-1.0) - - def test_beta0_out_of_range_raises_value_error(self) -> None: - """Test that Lion raises ValueError for invalid beta at index 0.""" - param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) - with self.assertRaisesRegex(ValueError, "Invalid beta at index 0"): - Lion([param], betas=(1.0, 0.99)) - - def test_beta1_out_of_range_raises_value_error(self) -> None: - """Test that Lion raises ValueError for invalid beta at index 1.""" - param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) - with self.assertRaisesRegex(ValueError, "Invalid beta at index 1"): - Lion([param], betas=(0.9, 1.0)) - - def test_negative_weight_decay_raises_value_error(self) -> None: - """Test that Lion raises ValueError for negative weight_decay.""" - param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) - with self.assertRaisesRegex(ValueError, "Invalid weight_decay"): - Lion([param], weight_decay=-0.1) - - -if __name__ == "__main__": - absltest.main() diff --git a/tests/test_scalar_optimizers.py b/tests/test_scalar_optimizers.py index 8148bc6c..b6b6dd46 100644 --- a/tests/test_scalar_optimizers.py +++ b/tests/test_scalar_optimizers.py @@ -16,7 +16,13 @@ from absl import flags, logging, testing from absl.testing import parameterized -from emerging_optimizers.scalar_optimizers import update_functions +from emerging_optimizers.scalar_optimizers import ( + LaProp, + Lion, + Signum, + SimplifiedAdEMAMix, + update_functions, +) flags.DEFINE_enum("device", "cpu", ["cpu", "cuda"], "Device to run tests on") @@ -304,7 +310,7 @@ def test_calculate_lion_update_returns_sign(self) -> None: exp_avg = torch.randn(shape, device=self.device) exp_avg_clone = exp_avg.clone() - update = update_functions.calculate_lion_update(grad, exp_avg, betas=(beta, beta)) + update = update_functions.calculate_lion_update(grad, exp_avg, betas=(beta, beta), step=1) # Update should be sign(beta * m + (1 - beta) * g) expected_update = torch.sign(beta * exp_avg_clone + (1 - beta) * grad) @@ -322,7 +328,7 @@ def test_calculate_lion_update_with_separate_betas(self) -> None: exp_avg = torch.randn(shape, device=self.device) exp_avg_clone = exp_avg.clone() - update = update_functions.calculate_lion_update(grad, exp_avg, betas=(beta1, beta2)) + update = update_functions.calculate_lion_update(grad, exp_avg, betas=(beta1, beta2), step=1) expected_update = torch.sign(beta1 * exp_avg_clone + (1 - beta1) * grad) torch.testing.assert_close(update, expected_update, atol=0, rtol=0) @@ -453,5 +459,390 @@ def test_calculate_madam_update_5steps_zero_masked_is_finite(self) -> None: self.assertTrue((nonzero_cols.abs() > 0).all()) +class _CommonScalarOptimizerTests: + """Tests applied via mixin to every scalar optimizer test class. + + Concrete subclasses must additionally inherit from ``parameterized.TestCase`` + and set: + + - ``OPTIMIZER_CLS``: the optimizer class under test + - ``STATE_KEYS``: tuple of per-parameter state keys (must include ``"step"``) + """ + + OPTIMIZER_CLS: type + STATE_KEYS: tuple[str, ...] + + def setUp(self) -> None: + self.device = FLAGS.device + + def test_smoke(self) -> None: + param = torch.nn.Parameter(torch.randn(15, 31, device=self.device)) + opt = self.OPTIMIZER_CLS([param], lr=1e-4) + param.grad = torch.randn_like(param) + opt.step() + + def test_no_grad_no_update_params_unchanged(self) -> None: + """Parameters without gradients are not updated.""" + param = torch.nn.Parameter(torch.randn(15, 31, device=self.device)) + original = param.detach().clone() + opt = self.OPTIMIZER_CLS([param], lr=1e-4) + opt.step() + torch.testing.assert_close(param, original, atol=0, rtol=0) + + def test_state_keys_after_first_step(self) -> None: + """First step populates exactly the expected state keys, with step==1 and matching-shape buffers.""" + param = torch.nn.Parameter(torch.randn(5, 7, device=self.device)) + param.grad = torch.randn_like(param) + opt = self.OPTIMIZER_CLS([param], lr=1e-4) + opt.step() + self.assertEqual(set(opt.state[param].keys()), set(self.STATE_KEYS)) + self.assertEqual(opt.state[param]["step"], 1) + for key in self.STATE_KEYS: + if key == "step": + continue + self.assertEqual(opt.state[param][key].shape, param.data.shape) + + def test_init_group_skip_non_grad_params(self) -> None: + """``_init_group(..., skip_non_grad_params=...)`` honors the flag.""" + with_grad = torch.nn.Parameter(torch.randn(5, 7, device=self.device)) + without_grad = torch.nn.Parameter(torch.randn(5, 7, device=self.device)) + with_grad.grad = torch.randn_like(with_grad) + + for skip in (True, False): + with self.subTest(skip_non_grad_params=skip): + opt = self.OPTIMIZER_CLS([with_grad, without_grad], lr=1e-4) + opt._init_group(opt.param_groups[0], skip_non_grad_params=skip) + for key in self.STATE_KEYS: + self.assertIn(key, opt.state[with_grad]) + self.assertEqual(opt.state[with_grad]["step"], 0) + self.assertEqual("exp_avg" in opt.state[without_grad], not skip) + + def test_param_groups_large_lr_moves_more(self) -> None: + """A param group with larger lr moves farther after one step.""" + shape = (15, 31) + p1 = torch.nn.Parameter(torch.randint(-5, 5, shape, device=self.device, dtype=torch.float32)) + p2 = torch.nn.Parameter(torch.randint(-5, 5, shape, device=self.device, dtype=torch.float32)) + opt = self.OPTIMIZER_CLS( + [{"params": [p1], "lr": 0.01}, {"params": [p2], "lr": 0.001}], + weight_decay=0.0, + ) + p1_orig = p1.detach().clone() + p2_orig = p2.detach().clone() + grad = torch.randint(1, 5, shape, device=self.device, dtype=torch.float32) + p1.grad = grad.clone() + p2.grad = grad.clone() + opt.step() + self.assertGreater( + (p1.data - p1_orig).abs().mean().item(), + (p2.data - p2_orig).abs().mean().item(), + ) + + def test_closure_unsupported(self) -> None: + param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) + param.grad = torch.randn_like(param) + opt = self.OPTIMIZER_CLS([param], lr=1e-4) + with self.assertRaisesRegex(ValueError, "closure is not supported"): + opt.step(closure=lambda: 0.0) + + def test_negative_lr_raises_value_error(self) -> None: + param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) + with self.assertRaisesRegex(ValueError, "Invalid learning rate"): + self.OPTIMIZER_CLS([param], lr=-1.0) + + def test_negative_weight_decay_raises_value_error(self) -> None: + param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) + with self.assertRaisesRegex(ValueError, "Invalid weight_decay"): + self.OPTIMIZER_CLS([param], weight_decay=-0.1) + + +class _HasBetasTests: + """Mixed into optimizer test classes whose ``__init__`` accepts ``betas=(b1, b2)``.""" + + OPTIMIZER_CLS: type + + def test_beta0_out_of_range_raises_value_error(self) -> None: + param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) + with self.assertRaisesRegex(ValueError, "Invalid beta at index 0"): + self.OPTIMIZER_CLS([param], betas=(1.0, 0.99)) + + def test_beta1_out_of_range_raises_value_error(self) -> None: + param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) + with self.assertRaisesRegex(ValueError, "Invalid beta at index 1"): + self.OPTIMIZER_CLS([param], betas=(0.9, 1.0)) + + +class _HasEpsTests: + """Mixed into optimizer test classes whose ``__init__`` accepts ``eps``.""" + + OPTIMIZER_CLS: type + + def test_negative_eps_raises_value_error(self) -> None: + param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) + with self.assertRaisesRegex(ValueError, "Invalid epsilon"): + self.OPTIMIZER_CLS([param], eps=-1e-8) + + +class LionOptimizerTest(_CommonScalarOptimizerTests, _HasBetasTests, parameterized.TestCase): + OPTIMIZER_CLS = Lion + STATE_KEYS = ("exp_avg", "step") + + @parameterized.product( + betas=[(0.9, 0.99), (0.95, 0.98)], + shape=[(3, 3), (15, 31), (127, 255)], + ) + def test_update_is_sign_based(self, betas, shape) -> None: + """Lion updates should be +/- lr (sign-based).""" + param = torch.nn.Parameter(torch.randint(-5, 5, shape, device=self.device, dtype=torch.float32)) + opt = Lion([param], lr=0.25, betas=betas, weight_decay=0.0) + param.grad = torch.randint(1, 5, shape, device=self.device, dtype=torch.float32) + old_param = param.data.clone() + opt.step() + diff = old_param - param.data + torch.testing.assert_close(diff.abs(), torch.full_like(diff, 0.25), atol=0, rtol=0) + + @parameterized.parameters( + {"shape": (3, 3)}, + {"shape": (15, 31)}, + {"shape": (127, 255)}, + ) + def test_exp_avg_evolves_correctly(self, shape) -> None: + """``exp_avg`` matches the analytical EMA after three deterministic steps.""" + beta1, beta2 = 0.9, 0.99 + param = torch.nn.Parameter(torch.randint(-5, 5, shape, device=self.device, dtype=torch.float32)) + opt = Lion([param], lr=0.01, betas=(beta1, beta2), weight_decay=0.0) + grads = [ + torch.randint(-3, 3, shape, device=self.device, dtype=torch.float32), + torch.randint(-3, 3, shape, device=self.device, dtype=torch.float32), + torch.randint(-3, 3, shape, device=self.device, dtype=torch.float32), + ] + expected_exp_avg = torch.zeros(*shape, device=self.device) + for grad in grads: + param.grad = grad.clone() + opt.step() + expected_exp_avg = beta2 * expected_exp_avg + (1 - beta2) * grad + torch.testing.assert_close(opt.state[param]["exp_avg"], expected_exp_avg, atol=1e-6, rtol=1e-6) + + @parameterized.parameters( + {"shape": (3, 3)}, + {"shape": (15, 31)}, + ) + def test_weight_decay_decoupled_matches_analytical(self, shape) -> None: + lr, wd = 0.25, 0.5 + param = torch.nn.Parameter(torch.randint(1, 5, shape, device=self.device, dtype=torch.float32)) + opt = Lion([param], lr=lr, weight_decay=wd, weight_decay_method="decoupled") + param.grad = torch.zeros(*shape, device=self.device) + old_param = param.data.clone() + opt.step() + torch.testing.assert_close(param.data, old_param * (1 - lr * wd), atol=0, rtol=0) + + @parameterized.parameters( + {"shape": (3, 3)}, + {"shape": (15, 31)}, + ) + def test_weight_decay_independent_matches_analytical(self, shape) -> None: + lr, wd = 0.25, 0.5 + param = torch.nn.Parameter(torch.randint(1, 5, shape, device=self.device, dtype=torch.float32)) + opt = Lion([param], lr=lr, weight_decay=wd, weight_decay_method="independent") + param.grad = torch.zeros(*shape, device=self.device) + old_param = param.data.clone() + opt.step() + torch.testing.assert_close(param.data, old_param * (1 - wd), atol=0, rtol=0) + + @parameterized.parameters( + {"shape": (3, 3)}, + {"shape": (15, 31)}, + ) + def test_weight_decay_l2(self, shape) -> None: + """L2 weight decay folds into the gradient before sign(); with zero grad it shrinks via WD.""" + lr, wd = 0.25, 0.5 + param = torch.nn.Parameter(torch.randint(1, 5, shape, device=self.device, dtype=torch.float32)) + opt = Lion([param], lr=lr, weight_decay=wd, weight_decay_method="l2") + param.grad = torch.zeros(*shape, device=self.device) + old_param = param.data.clone() + opt.step() + torch.testing.assert_close(param.data, old_param - lr, atol=0, rtol=0) + + @parameterized.parameters( + {"shape": (3, 3)}, + {"shape": (15, 31)}, + ) + def test_weight_decay_l2_masked_by_gradient(self, shape) -> None: + """A large negative grad dominates the sign so L2 cannot guarantee shrinkage.""" + lr, wd = 0.25, 0.125 + param = torch.nn.Parameter(torch.randint(1, 5, shape, device=self.device, dtype=torch.float32)) + opt = Lion([param], lr=lr, weight_decay=wd, weight_decay_method="l2") + param.grad = torch.randint(-10, -5, shape, device=self.device, dtype=torch.float32) + old_param = param.data.clone() + opt.step() + torch.testing.assert_close(param.data, old_param + lr, atol=0, rtol=0) + + +class SignumOptimizerTest(_CommonScalarOptimizerTests, parameterized.TestCase): + OPTIMIZER_CLS = Signum + STATE_KEYS = ("exp_avg", "step") + + def test_invalid_momentum_raises_value_error(self) -> None: + param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) + for invalid in (-0.1, 1.0, 1.5): + with self.subTest(momentum=invalid): + with self.assertRaisesRegex(ValueError, "Invalid momentum"): + Signum([param], momentum=invalid) + + @parameterized.parameters( + {"shape": (3, 3)}, + {"shape": (15, 31)}, + ) + def test_update_is_sign_based(self, shape) -> None: + """With ``use_shape_scaling=False`` and a positive gradient, Signum updates are +/- lr.""" + param = torch.nn.Parameter(torch.randint(-5, 5, shape, device=self.device, dtype=torch.float32)) + opt = Signum( + [param], + lr=0.25, + momentum=0.9, + weight_decay=0.0, + correct_bias=True, + nesterov=False, + use_shape_scaling=False, + ) + param.grad = torch.randint(1, 5, shape, device=self.device, dtype=torch.float32) + old_param = param.data.clone() + opt.step() + # bias_correction at step 1 cancels the (1-momentum) factor, so sign(corrected) = sign(grad) = +1. + torch.testing.assert_close(old_param - param.data, torch.full(shape, 0.25, device=self.device), atol=0, rtol=0) + + +class LaPropOptimizerTest(_CommonScalarOptimizerTests, _HasBetasTests, _HasEpsTests, parameterized.TestCase): + OPTIMIZER_CLS = LaProp + STATE_KEYS = ("exp_avg", "exp_avg_sq", "step") + + def test_frob_normalize_with_nonzero_weight_decay_logs_error(self) -> None: + """LaProp logs an ERROR when ``frob_normalize=True`` is combined with a non-zero weight decay.""" + param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) + with self.assertLogs(level="ERROR") as cm: + LaProp([param], frob_normalize=True, weight_decay=0.1) + self.assertTrue( + any("frob_normalize=True is intended to be used with weight_decay=0.0" in msg for msg in cm.output), + f"expected frob_normalize/weight_decay warning, got: {cm.output}", + ) + + @parameterized.parameters( + {"shape": (3, 3)}, + {"shape": (15, 31)}, + {"shape": (127, 255)}, + ) + def test_state_evolves_correctly(self, shape) -> None: + """After one step, ``exp_avg`` and ``exp_avg_sq`` match LaProp's analytical values.""" + beta1, beta2 = 0.5, 0.75 + param = torch.nn.Parameter(torch.randint(1, 5, shape, device=self.device, dtype=torch.float32)) + opt = LaProp([param], lr=0.25, betas=(beta1, beta2), weight_decay=0.0, correct_bias=True) + grad = torch.randint_like(param, 1, 5) + param.grad = grad.clone() + opt.step() + expected_exp_avg_sq = (1 - beta2) * grad.square() + normalized_grad = grad / (grad.abs() + opt.param_groups[0]["eps"]) + expected_exp_avg = (1 - beta1) * normalized_grad + torch.testing.assert_close(opt.state[param]["exp_avg_sq"], expected_exp_avg_sq, atol=0, rtol=0) + torch.testing.assert_close(opt.state[param]["exp_avg"], expected_exp_avg, atol=0, rtol=0) + + @parameterized.parameters( + {"shape": (3, 3)}, + {"shape": (15, 31)}, + {"shape": (127, 255)}, + ) + def test_optimizer_step_matches_update_function(self, shape) -> None: + """LaProp optimizer delegates update math to ``calculate_laprop_update``.""" + lr = 0.25 + betas = (0.5, 0.75) + eps = 1e-8 + param = torch.nn.Parameter(torch.randint(-5, 5, shape, device=self.device, dtype=torch.float32)) + grad = torch.randint(-5, 5, shape, device=self.device, dtype=torch.float32) + opt = LaProp([param], lr=lr, betas=betas, eps=eps, weight_decay=0.0) + old_param = param.detach().clone() + exp_avg = torch.zeros_like(param) + exp_avg_sq = torch.zeros_like(param) + expected_update = update_functions.calculate_laprop_update( + grad, exp_avg, exp_avg_sq, betas=betas, eps=eps, correct_bias=True, step=1 + ) + param.grad = grad.clone() + opt.step() + torch.testing.assert_close(param, old_param - lr * expected_update, atol=0, rtol=0) + torch.testing.assert_close(opt.state[param]["exp_avg"], exp_avg, atol=0, rtol=0) + torch.testing.assert_close(opt.state[param]["exp_avg_sq"], exp_avg_sq, atol=0, rtol=0) + + @parameterized.parameters( + {"shape": (3, 3)}, + {"shape": (15, 31)}, + {"shape": (127, 255)}, + ) + def test_frob_normalize_preserves_parameter_norm(self, shape) -> None: + """LaProp with ``frob_normalize=True`` restores the pre-step Frobenius norm.""" + param = torch.nn.Parameter(torch.randint(1, 5, shape, device=self.device, dtype=torch.float32)) + opt = LaProp([param], lr=0.25, weight_decay=0.0, frob_normalize=True) + param.grad = torch.randint(-2, 3, shape, device=self.device, dtype=torch.float32) + original_norm = param.norm() + opt.step() + torch.testing.assert_close(param.norm(), original_norm, atol=0, rtol=2e-5) + + +class SimplifiedAdEMAMixOptimizerTest( + _CommonScalarOptimizerTests, _HasBetasTests, _HasEpsTests, parameterized.TestCase +): + OPTIMIZER_CLS = SimplifiedAdEMAMix + STATE_KEYS = ("exp_avg", "exp_avg_sq", "step") + + def test_invalid_min_beta_fast_raises_value_error(self) -> None: + param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) + with self.assertRaisesRegex(ValueError, "Invalid min_beta_fast"): + SimplifiedAdEMAMix([param], min_beta_fast=1.0) + + def test_invalid_num_beta_fast_warmup_steps_raises_value_error(self) -> None: + param = torch.nn.Parameter(torch.randn(3, 3, device=self.device)) + with self.assertRaisesRegex(ValueError, "Invalid num_beta_fast_warmup_steps"): + SimplifiedAdEMAMix([param], num_beta_fast_warmup_steps=-1) + + @parameterized.parameters( + {"shape": (3, 3)}, + {"shape": (15, 31)}, + ) + def test_optimizer_step_matches_update_function(self, shape) -> None: + """SimplifiedAdEMAMix optimizer delegates update math to ``calculate_sim_ademamix_update``.""" + lr = 0.25 + betas = (0.9999, 0.999) + eps = 1e-8 + min_beta_fast = 0.9 + alpha = 2.0 + param = torch.nn.Parameter(torch.randn(*shape, device=self.device)) + grad = torch.randn_like(param) + opt = SimplifiedAdEMAMix( + [param], + lr=lr, + betas=betas, + eps=eps, + weight_decay=0.0, + min_beta_fast=min_beta_fast, + alpha=alpha, + ) + old_param = param.detach().clone() + exp_avg = torch.zeros_like(param) + exp_avg_sq = torch.zeros_like(param) + expected_update = update_functions.calculate_sim_ademamix_update( + grad, + exp_avg, + exp_avg_sq, + betas=betas, + eps=eps, + correct_bias=True, + step=1, + num_beta_fast_warmup_steps=None, + min_beta_fast=min_beta_fast, + alpha=alpha, + ) + param.grad = grad.clone() + opt.step() + torch.testing.assert_close(param, old_param - lr * expected_update, atol=0, rtol=0) + torch.testing.assert_close(opt.state[param]["exp_avg"], exp_avg, atol=0, rtol=0) + torch.testing.assert_close(opt.state[param]["exp_avg_sq"], exp_avg_sq, atol=0, rtol=0) + + if __name__ == "__main__": testing.absltest.main()