Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 180 additions & 2 deletions emerging_optimizers/scalar_optimizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,183 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from emerging_optimizers.scalar_optimizers.laprop import LaProp
from emerging_optimizers.scalar_optimizers.lion import Lion
from typing import Any, override

import torch
from absl import logging
from torch.optim.optimizer import ParamsT

from emerging_optimizers.mixin import WeightDecayT
from emerging_optimizers.registry import register_optimizer
from emerging_optimizers.scalar_optimizers.base import (
SingleMomentumOptimizer,
TwoMomentsOptimizer,
_validate_common_hparams,
)
from emerging_optimizers.scalar_optimizers.update_functions import (
calculate_laprop_update,
calculate_lion_update,
calculate_signum_update,
calculate_sim_ademamix_update,
)


__all__ = [
"LaProp",
"Lion",
"Signum",
"SimplifiedAdEMAMix",
"SingleMomentumOptimizer",
"TwoMomentsOptimizer",
]


@register_optimizer("lion")
class Lion(SingleMomentumOptimizer):
    """Lion optimizer (Chen et al., 2023).

    Sign-based update driven by a single first-moment EMA buffer. The per-element
    update itself is computed by ``calculate_lion_update``.

    Args:
        params: Parameters (or parameter groups) to optimize.
        lr: Learning rate; validated to be non-negative.
        betas: Coefficients forwarded to ``calculate_lion_update``; each must lie in ``[0, 1)``.
        weight_decay: Weight decay coefficient; validated to be non-negative.
        weight_decay_method: Weight decay scheme handed to the base optimizer.

    Raises:
        ValueError: If ``lr``, ``betas`` or ``weight_decay`` fail validation.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-4,
        betas: tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.01,
        *,
        weight_decay_method: WeightDecayT = "decoupled",
    ) -> None:
        # Reject bad hyperparameters before any optimizer state is created.
        _validate_common_hparams(lr=lr, betas=betas, weight_decay=weight_decay)

        group_defaults = {"lr": lr, "betas": betas, "weight_decay": weight_decay}
        # Only ``betas`` is passed through to the update function; ``lr`` and
        # ``weight_decay`` are consumed by the base class ``step``.
        forwarded = ("betas",)
        super().__init__(
            params,
            defaults=group_defaults,
            update_fn=calculate_lion_update,
            update_kwarg_names=forwarded,
            weight_decay_method=weight_decay_method,
        )


@register_optimizer("signum")
class Signum(SingleMomentumOptimizer):
    """Sign-SGD / Signum optimizer (Bernstein et al., 2018).

    Applies the sign of a (by default bias-corrected) single-moment EMA. The
    per-element update is computed by ``calculate_signum_update``.

    Args:
        params: Parameters (or parameter groups) to optimize.
        lr: Learning rate; validated to be non-negative.
        momentum: EMA coefficient; must lie in ``[0, 1)``.
        weight_decay: Weight decay coefficient; validated to be non-negative.
        correct_bias: Forwarded to ``calculate_signum_update``.
        nesterov: Forwarded to ``calculate_signum_update``.
        use_shape_scaling: Forwarded to ``calculate_signum_update``.
        weight_decay_method: Weight decay scheme handed to the base optimizer.

    Raises:
        ValueError: If ``lr``, ``weight_decay`` or ``momentum`` is out of range.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        *,
        correct_bias: bool = True,
        nesterov: bool = False,
        use_shape_scaling: bool = False,
        weight_decay_method: WeightDecayT = "decoupled",
    ) -> None:
        _validate_common_hparams(lr=lr, weight_decay=weight_decay)

        # Keep the chained comparison: it also rejects NaN, which an
        # ``a < 0 or a >= 1`` rewrite would silently accept.
        momentum_in_range = 0.0 <= momentum < 1.0
        if not momentum_in_range:
            raise ValueError(f"Invalid momentum: {momentum}")

        group_defaults = {
            "lr": lr,
            "momentum": momentum,
            "weight_decay": weight_decay,
            "correct_bias": correct_bias,
            "nesterov": nesterov,
            "use_shape_scaling": use_shape_scaling,
        }
        forwarded = ("momentum", "correct_bias", "nesterov", "use_shape_scaling")
        super().__init__(
            params,
            defaults=group_defaults,
            update_fn=calculate_signum_update,
            update_kwarg_names=forwarded,
            weight_decay_method=weight_decay_method,
        )


@register_optimizer("laprop")
class LaProp(TwoMomentsOptimizer):
    """LaProp optimizer (Ziyin et al., 2020).

    Adam with the gradient normalized before the first-moment update. The
    per-element update is computed by ``calculate_laprop_update``.

    Args:
        params: Parameters (or parameter groups) to optimize.
        lr: Learning rate; validated to be non-negative.
        betas: Coefficients forwarded to ``calculate_laprop_update``; each must lie in ``[0, 1)``.
        eps: Epsilon forwarded to ``calculate_laprop_update``; validated to be non-negative.
            Also used as the lower clamp on the norm when ``frob_normalize`` is on.
        weight_decay: Weight decay coefficient; validated to be non-negative.
        correct_bias: Forwarded to ``calculate_laprop_update``.
        frob_normalize: When ``True``, each parameter is rescaled after the step so that
            its norm (``p.data.norm()``) matches the value measured before the step.
        weight_decay_method: Weight decay scheme handed to the base optimizer.

    Raises:
        ValueError: If ``lr``, ``betas``, ``eps`` or ``weight_decay`` fail validation.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        *,
        correct_bias: bool = True,
        frob_normalize: bool = False,
        weight_decay_method: WeightDecayT = "decoupled",
    ) -> None:
        _validate_common_hparams(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        # The combination is only reported, not rejected.
        if frob_normalize and weight_decay != 0.0:
            logging.error("LaProp with frob_normalize=True is intended to be used with weight_decay=0.0.")

        # Optimizer-wide switch read by the pre/post step hooks.
        self.frob_normalize = frob_normalize

        group_defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "correct_bias": correct_bias,
        }
        forwarded = ("betas", "eps", "correct_bias")
        super().__init__(
            params,
            defaults=group_defaults,
            update_fn=calculate_laprop_update,
            update_kwarg_names=forwarded,
            weight_decay_method=weight_decay_method,
        )

    @override
    def pre_step_inplace(self, p: torch.Tensor, group: dict) -> Any:
        """Return the parameter norm before the step, or ``None`` when norm preservation is off."""
        if not self.frob_normalize:
            return None
        return p.data.norm()

    @override
    def post_step_inplace(self, p: torch.Tensor, group: dict, ctx: Any) -> None:
        """Rescale ``p`` in place so its norm equals ``ctx``, the norm captured before the step."""
        if not self.frob_normalize:
            return
        # Clamp the post-update norm at ``eps`` to avoid dividing by (near) zero.
        post_norm = p.data.norm().clamp_min(group["eps"])
        p.data.mul_(ctx / post_norm)


@register_optimizer("sim_ademamix")
class SimplifiedAdEMAMix(TwoMomentsOptimizer):
    """Simplified AdEMAMix optimizer.

    Two-buffer variant mixing an alpha-scaled current gradient into a theory-style
    first-moment EMA. The per-element update is computed by
    ``calculate_sim_ademamix_update``.

    Args:
        params: Parameters (or parameter groups) to optimize.
        lr: Learning rate; validated to be non-negative.
        betas: Coefficients forwarded to the update function; each must lie in ``[0, 1)``.
        eps: Epsilon forwarded to the update function; validated to be non-negative.
        weight_decay: Weight decay coefficient; validated to be non-negative.
        correct_bias: Forwarded to the update function.
        num_beta_fast_warmup_steps: Forwarded to the update function; ``None`` or a
            strictly positive integer.
        min_beta_fast: Forwarded to the update function; must lie in ``[0, 1)``.
        alpha: Forwarded to the update function.
        weight_decay_method: Weight decay scheme handed to the base optimizer.

    Raises:
        ValueError: If any validated hyperparameter is out of range.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9999, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        *,
        correct_bias: bool = True,
        num_beta_fast_warmup_steps: int | None = None,
        min_beta_fast: float = 0.9,
        alpha: float = 2.0,
        weight_decay_method: WeightDecayT = "decoupled",
    ) -> None:
        _validate_common_hparams(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        # Chained comparison kept on purpose: it rejects NaN as well.
        min_beta_fast_ok = 0.0 <= min_beta_fast < 1.0
        if not min_beta_fast_ok:
            raise ValueError(f"Invalid min_beta_fast: {min_beta_fast}")

        # ``None`` is allowed; an explicit value must be strictly positive.
        if num_beta_fast_warmup_steps is not None:
            if num_beta_fast_warmup_steps <= 0:
                raise ValueError(f"Invalid num_beta_fast_warmup_steps: {num_beta_fast_warmup_steps}")

        group_defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "correct_bias": correct_bias,
            "num_beta_fast_warmup_steps": num_beta_fast_warmup_steps,
            "min_beta_fast": min_beta_fast,
            "alpha": alpha,
        }
        forwarded = (
            "betas",
            "eps",
            "correct_bias",
            "num_beta_fast_warmup_steps",
            "min_beta_fast",
            "alpha",
        )
        super().__init__(
            params,
            defaults=group_defaults,
            update_fn=calculate_sim_ademamix_update,
            update_kwarg_names=forwarded,
            weight_decay_method=weight_decay_method,
        )
177 changes: 177 additions & 0 deletions emerging_optimizers/scalar_optimizers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, ClassVar, override


if TYPE_CHECKING:
from typing import overload

import torch
from torch.optim.optimizer import ParamsT

from emerging_optimizers.mixin import WeightDecayMixin, WeightDecayT


__all__ = [
"SingleMomentumOptimizer",
"TwoMomentsOptimizer",
]


def _validate_common_hparams(
*,
lr: float | None = None,
betas: tuple[float, ...] | None = None,
eps: float | None = None,
weight_decay: float | None = None,
) -> None:
"""Validates the hyperparameters shared by most scalar optimizers."""
if lr is not None and lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if betas is not None:
for i, b in enumerate(betas):
if not 0.0 <= b < 1.0:
raise ValueError(f"Invalid beta at index {i}: {b}")
if eps is not None and eps < 0.0:
raise ValueError(f"Invalid epsilon value: {eps}")
if weight_decay is not None and weight_decay < 0.0:
raise ValueError(f"Invalid weight_decay value: {weight_decay}")


class _ScalarOptimizerBase(WeightDecayMixin, torch.optim.Optimizer):
    """Common machinery for scalar (element-wise) optimizers, grouped by state layout.

    A subclass declares its buffer names through the ``state_keys`` ``ClassVar``.
    For each parameter the base lazily creates one zero-filled buffer per name and
    an integer ``step`` counter, then delegates every step to the ``update_fn``
    given at construction time, called as ``update_fn(grad, *buffers, **kwargs) -> Tensor``.

    Note:
        Every state buffer is allocated with ``torch.zeros_like(p.data)`` and so
        always has the same shape as its parameter. Optimizers that need state of
        a different shape are not covered by this base.

    The hyperparameters passed to ``update_fn`` are read from the parameter group
    using ``update_kwarg_names`` (a tuple of keys that must exist in ``defaults``).
    The per-parameter counter is always passed as ``step=state["step"]``, so every
    update function has to accept a ``step`` keyword argument.

    Subclasses may override :meth:`pre_step_inplace` / :meth:`post_step_inplace`
    to wrap the per-parameter update with extra logic (e.g. norm preservation).
    """

    state_keys: ClassVar[tuple[str, ...]]

    def __init__(
        self,
        params: ParamsT,
        defaults: dict[str, Any],
        *,
        update_fn: Callable[..., torch.Tensor],
        update_kwarg_names: tuple[str, ...],
        weight_decay_method: WeightDecayT = "decoupled",
    ) -> None:
        """Store the update callable and its hyperparameter names, then build the optimizer.

        Raises:
            ValueError: If a name in ``update_kwarg_names`` is not a key of ``defaults``.
        """
        # Fail at construction time instead of with a KeyError on the first step.
        missing = sorted(set(update_kwarg_names).difference(defaults))
        if missing:
            raise ValueError(
                f"update_kwarg_names {missing} not present in defaults (keys: {sorted(defaults)})"
            )

        self.update_fn = update_fn
        self.update_kwarg_names = update_kwarg_names
        self.weight_decay_method = weight_decay_method
        super().__init__(params, defaults)

    @torch.no_grad()
    def _init_group(
        self,
        group: dict,
        skip_non_grad_params: bool = True,
    ) -> None:
        """Lazily create optimizer state for the parameters of ``group``.

        Parameters that already have state are left untouched. Each new buffer is
        zero-filled and shaped like the parameter it belongs to.

        Args:
            group: Parameter group whose ``"params"`` entry is iterated.
            skip_non_grad_params: When true, parameters whose ``grad`` is ``None``
                get no state.
        """
        for param in group["params"]:
            if skip_non_grad_params and param.grad is None:
                continue

            param_state = self.state[param]
            if param_state:
                # State already exists for this parameter.
                continue

            param_state.update({name: torch.zeros_like(param.data) for name in self.state_keys})
            param_state["step"] = 0

    def pre_step_inplace(self, p: torch.Tensor, group: dict) -> Any:
        """Hook run before weight decay and the update; its result is handed to ``post_step_inplace``."""
        return None

    def post_step_inplace(self, p: torch.Tensor, group: dict, ctx: Any) -> None:
        """Hook run after the update; ``ctx`` is whatever ``pre_step_inplace`` returned."""
        return None

    if TYPE_CHECKING:

        @overload
        def step(self, closure: None = ...) -> None: ...

        @overload
        def step(self, closure: Callable[[], float]) -> float: ...

    @torch.no_grad()  # type: ignore[misc]
    @override
    def step(self, closure: Callable[[], float] | None = None) -> float | None:
        """Perform a single optimization step.

        Note:
            When ``weight_decay_method="l2"``, ``p.grad`` is modified in-place
            (the L2 penalty ``weight_decay * p`` is added to the gradient).
            If you need the original gradient after this call, clone it beforehand.

        Args:
            closure: Unsupported; must be ``None``.

        Raises:
            ValueError: If ``closure`` is not ``None``.
        """
        if closure is not None:
            raise ValueError("closure is not supported")

        for group in self.param_groups:
            self._init_group(group)

            lr = group["lr"]
            weight_decay = group["weight_decay"]
            # Group-level hyperparameters forwarded to the update function.
            group_hparams = {name: group[name] for name in self.update_kwarg_names}

            for p in group["params"]:
                if p.grad is None:
                    continue  # pragma: no cover

                param_state = self.state[p]
                param_state["step"] += 1
                # ``step`` is merged last so it always takes precedence.
                call_kwargs = {**group_hparams, "step": param_state["step"]}

                ctx = self.pre_step_inplace(p, group)
                self._apply_weight_decay_inplace(p.data, p.grad, lr, weight_decay)

                moment_buffers = tuple(param_state[name] for name in self.state_keys)
                delta = self.update_fn(p.grad, *moment_buffers, **call_kwargs)
                p.data.add_(delta, alpha=-lr)

                self.post_step_inplace(p, group, ctx)

        return None


class SingleMomentumOptimizer(_ScalarOptimizerBase):
    """Base class for scalar optimizers that keep one first-moment EMA buffer.

    The only per-parameter buffer is ``exp_avg``; it is passed to the update
    function right after the gradient.
    """

    state_keys: ClassVar[tuple[str, ...]] = ("exp_avg",)


class TwoMomentsOptimizer(_ScalarOptimizerBase):
    """Base class for Adam-style scalar optimizers that keep first and second moment buffers.

    The per-parameter buffers are ``exp_avg`` and ``exp_avg_sq``; they are passed
    to the update function in that order, right after the gradient.
    """

    state_keys: ClassVar[tuple[str, ...]] = ("exp_avg", "exp_avg_sq")
Loading
Loading