Skip to content

Commit b995d2e

Browse files
authored
Unify scalar optimizers. (#210)
* unify scalar optimizers. add optimizer classes Signed-off-by: Hao Wu <skyw@nvidia.com>
1 parent a0e376b commit b995d2e

8 files changed

Lines changed: 754 additions & 811 deletions

File tree

emerging_optimizers/scalar_optimizers/__init__.py

Lines changed: 180 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,183 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15-
from emerging_optimizers.scalar_optimizers.laprop import LaProp
16-
from emerging_optimizers.scalar_optimizers.lion import Lion
15+
from typing import Any, override
16+
17+
import torch
18+
from absl import logging
19+
from torch.optim.optimizer import ParamsT
20+
21+
from emerging_optimizers.mixin import WeightDecayT
22+
from emerging_optimizers.registry import register_optimizer
23+
from emerging_optimizers.scalar_optimizers.base import (
24+
SingleMomentumOptimizer,
25+
TwoMomentsOptimizer,
26+
_validate_common_hparams,
27+
)
28+
from emerging_optimizers.scalar_optimizers.update_functions import (
29+
calculate_laprop_update,
30+
calculate_lion_update,
31+
calculate_signum_update,
32+
calculate_sim_ademamix_update,
33+
)
34+
35+
36+
__all__ = [
37+
"LaProp",
38+
"Lion",
39+
"Signum",
40+
"SimplifiedAdEMAMix",
41+
"SingleMomentumOptimizer",
42+
"TwoMomentsOptimizer",
43+
]
44+
45+
46+
@register_optimizer("lion")
47+
class Lion(SingleMomentumOptimizer):
48+
"""Lion optimizer (Chen et al., 2023): sign-based update with a single first-moment EMA."""
49+
50+
def __init__(
51+
self,
52+
params: ParamsT,
53+
lr: float = 1e-4,
54+
betas: tuple[float, float] = (0.9, 0.99),
55+
weight_decay: float = 0.01,
56+
*,
57+
weight_decay_method: WeightDecayT = "decoupled",
58+
) -> None:
59+
_validate_common_hparams(lr=lr, betas=betas, weight_decay=weight_decay)
60+
super().__init__(
61+
params,
62+
defaults=dict(lr=lr, betas=betas, weight_decay=weight_decay),
63+
update_fn=calculate_lion_update,
64+
update_kwarg_names=("betas",),
65+
weight_decay_method=weight_decay_method,
66+
)
67+
68+
69+
@register_optimizer("signum")
70+
class Signum(SingleMomentumOptimizer):
71+
"""Sign-SGD / Signum optimizer (Bernstein et al., 2018): sign of a bias-corrected single-moment EMA."""
72+
73+
def __init__(
74+
self,
75+
params: ParamsT,
76+
lr: float = 1e-3,
77+
momentum: float = 0.9,
78+
weight_decay: float = 0.0,
79+
*,
80+
correct_bias: bool = True,
81+
nesterov: bool = False,
82+
use_shape_scaling: bool = False,
83+
weight_decay_method: WeightDecayT = "decoupled",
84+
) -> None:
85+
_validate_common_hparams(lr=lr, weight_decay=weight_decay)
86+
if not 0.0 <= momentum < 1.0:
87+
raise ValueError(f"Invalid momentum: {momentum}")
88+
super().__init__(
89+
params,
90+
defaults=dict(
91+
lr=lr,
92+
momentum=momentum,
93+
weight_decay=weight_decay,
94+
correct_bias=correct_bias,
95+
nesterov=nesterov,
96+
use_shape_scaling=use_shape_scaling,
97+
),
98+
update_fn=calculate_signum_update,
99+
update_kwarg_names=("momentum", "correct_bias", "nesterov", "use_shape_scaling"),
100+
weight_decay_method=weight_decay_method,
101+
)
102+
103+
104+
@register_optimizer("laprop")
105+
class LaProp(TwoMomentsOptimizer):
106+
"""LaProp optimizer (Ziyin et al., 2020): Adam with the gradient normalized before the first-moment update."""
107+
108+
def __init__(
109+
self,
110+
params: ParamsT,
111+
lr: float = 1e-3,
112+
betas: tuple[float, float] = (0.9, 0.999),
113+
eps: float = 1e-8,
114+
weight_decay: float = 0.0,
115+
*,
116+
correct_bias: bool = True,
117+
frob_normalize: bool = False,
118+
weight_decay_method: WeightDecayT = "decoupled",
119+
) -> None:
120+
_validate_common_hparams(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
121+
if frob_normalize and weight_decay != 0.0:
122+
logging.error("LaProp with frob_normalize=True is intended to be used with weight_decay=0.0.")
123+
self.frob_normalize = frob_normalize
124+
super().__init__(
125+
params,
126+
defaults=dict(
127+
lr=lr,
128+
betas=betas,
129+
eps=eps,
130+
weight_decay=weight_decay,
131+
correct_bias=correct_bias,
132+
),
133+
update_fn=calculate_laprop_update,
134+
update_kwarg_names=("betas", "eps", "correct_bias"),
135+
weight_decay_method=weight_decay_method,
136+
)
137+
138+
@override
139+
def pre_step_inplace(self, p: torch.Tensor, group: dict) -> Any:
140+
return p.data.norm() if self.frob_normalize else None
141+
142+
@override
143+
def post_step_inplace(self, p: torch.Tensor, group: dict, ctx: Any) -> None:
144+
if self.frob_normalize:
145+
pre_norm = ctx
146+
p.data.mul_(pre_norm / p.data.norm().clamp_min(group["eps"]))
147+
148+
149+
@register_optimizer("sim_ademamix")
150+
class SimplifiedAdEMAMix(TwoMomentsOptimizer):
151+
"""Simplified AdEMAMix: two-buffer variant mixing alpha-scaled current gradient into a theory-style first-moment EMA."""
152+
153+
def __init__(
154+
self,
155+
params: ParamsT,
156+
lr: float = 1e-3,
157+
betas: tuple[float, float] = (0.9999, 0.999),
158+
eps: float = 1e-8,
159+
weight_decay: float = 0.0,
160+
*,
161+
correct_bias: bool = True,
162+
num_beta_fast_warmup_steps: int | None = None,
163+
min_beta_fast: float = 0.9,
164+
alpha: float = 2.0,
165+
weight_decay_method: WeightDecayT = "decoupled",
166+
) -> None:
167+
_validate_common_hparams(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
168+
if not 0.0 <= min_beta_fast < 1.0:
169+
raise ValueError(f"Invalid min_beta_fast: {min_beta_fast}")
170+
if num_beta_fast_warmup_steps is not None and num_beta_fast_warmup_steps <= 0:
171+
raise ValueError(f"Invalid num_beta_fast_warmup_steps: {num_beta_fast_warmup_steps}")
172+
super().__init__(
173+
params,
174+
defaults=dict(
175+
lr=lr,
176+
betas=betas,
177+
eps=eps,
178+
weight_decay=weight_decay,
179+
correct_bias=correct_bias,
180+
num_beta_fast_warmup_steps=num_beta_fast_warmup_steps,
181+
min_beta_fast=min_beta_fast,
182+
alpha=alpha,
183+
),
184+
update_fn=calculate_sim_ademamix_update,
185+
update_kwarg_names=(
186+
"betas",
187+
"eps",
188+
"correct_bias",
189+
"num_beta_fast_warmup_steps",
190+
"min_beta_fast",
191+
"alpha",
192+
),
193+
weight_decay_method=weight_decay_method,
194+
)
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from collections.abc import Callable
16+
from typing import TYPE_CHECKING, Any, ClassVar, override
17+
18+
19+
if TYPE_CHECKING:
20+
from typing import overload
21+
22+
import torch
23+
from torch.optim.optimizer import ParamsT
24+
25+
from emerging_optimizers.mixin import WeightDecayMixin, WeightDecayT
26+
27+
28+
__all__ = [
29+
"SingleMomentumOptimizer",
30+
"TwoMomentsOptimizer",
31+
]
32+
33+
34+
def _validate_common_hparams(
35+
*,
36+
lr: float | None = None,
37+
betas: tuple[float, ...] | None = None,
38+
eps: float | None = None,
39+
weight_decay: float | None = None,
40+
) -> None:
41+
"""Validates the hyperparameters shared by most scalar optimizers."""
42+
if lr is not None and lr < 0.0:
43+
raise ValueError(f"Invalid learning rate: {lr}")
44+
if betas is not None:
45+
for i, b in enumerate(betas):
46+
if not 0.0 <= b < 1.0:
47+
raise ValueError(f"Invalid beta at index {i}: {b}")
48+
if eps is not None and eps < 0.0:
49+
raise ValueError(f"Invalid epsilon value: {eps}")
50+
if weight_decay is not None and weight_decay < 0.0:
51+
raise ValueError(f"Invalid weight_decay value: {weight_decay}")
52+
53+
54+
class _ScalarOptimizerBase(WeightDecayMixin, torch.optim.Optimizer):
55+
"""Shared implementation for scalar optimizers grouped by state shape.
56+
57+
Subclasses set ``state_keys`` as a ``ClassVar``. The base lazily allocates one
58+
zero-initialized buffer per name plus a per-parameter ``step`` counter, then
59+
dispatches each step to a constructor-supplied ``update_fn`` whose signature is
60+
``update_fn(grad, *buffers, **kwargs) -> Tensor``.
61+
62+
Hyperparameters forwarded into ``update_fn`` are selected from the parameter
63+
group via ``update_kwarg_names`` (a tuple of dict keys present in the
64+
``defaults`` mapping). The per-parameter ``step`` is always forwarded as
65+
``step=state["step"]``, so every update function must accept a ``step`` kwarg.
66+
67+
Subclasses can additionally override :meth:`pre_step_inplace` /
68+
:meth:`post_step_inplace` to bracket the per-parameter update with custom
69+
logic (e.g. norm preservation).
70+
"""
71+
72+
state_keys: ClassVar[tuple[str, ...]]
73+
74+
def __init__(
75+
self,
76+
params: ParamsT,
77+
defaults: dict[str, Any],
78+
*,
79+
update_fn: Callable[..., torch.Tensor],
80+
update_kwarg_names: tuple[str, ...],
81+
weight_decay_method: WeightDecayT = "decoupled",
82+
) -> None:
83+
missing = set(update_kwarg_names) - set(defaults.keys())
84+
if missing:
85+
raise ValueError(
86+
f"update_kwarg_names {sorted(missing)} not present in defaults (keys: {sorted(defaults.keys())})"
87+
)
88+
self.update_fn = update_fn
89+
self.update_kwarg_names = update_kwarg_names
90+
self.weight_decay_method = weight_decay_method
91+
super().__init__(params, defaults)
92+
93+
@torch.no_grad()
94+
def _init_group(
95+
self,
96+
group: dict,
97+
skip_non_grad_params: bool = True,
98+
) -> None:
99+
"""Performs lazy state initialization for parameters."""
100+
for p in group["params"]:
101+
if skip_non_grad_params and p.grad is None:
102+
continue
103+
state = self.state[p]
104+
if len(state) == 0:
105+
for key in self.state_keys:
106+
state[key] = torch.zeros_like(p.data)
107+
state["step"] = 0
108+
109+
def pre_step_inplace(self, p: torch.Tensor, group: dict) -> Any:
110+
"""Hook called before weight decay and the update. Return value is forwarded to ``post_step_inplace``."""
111+
return None
112+
113+
def post_step_inplace(self, p: torch.Tensor, group: dict, ctx: Any) -> None:
114+
"""Hook called after the update. Receives the value returned by ``pre_step_inplace``."""
115+
return None
116+
117+
if TYPE_CHECKING:
118+
119+
@overload
120+
def step(self, closure: None = ...) -> None: ...
121+
122+
@overload
123+
def step(self, closure: Callable[[], float]) -> float: ...
124+
125+
@torch.no_grad() # type: ignore[misc]
126+
@override
127+
def step(self, closure: Callable[[], float] | None = None) -> float | None:
128+
"""Perform a single optimization step.
129+
130+
Note:
131+
When ``weight_decay_method="l2"``, ``p.grad`` is modified in-place
132+
(the L2 penalty ``weight_decay * p`` is added to the gradient).
133+
If you need the original gradient after this call, clone it beforehand.
134+
135+
Args:
136+
closure: Unsupported; must be ``None``.
137+
"""
138+
if closure is not None:
139+
raise ValueError("closure is not supported")
140+
141+
for group in self.param_groups:
142+
self._init_group(group)
143+
144+
lr = group["lr"]
145+
weight_decay = group["weight_decay"]
146+
update_kwargs = {key: group[key] for key in self.update_kwarg_names}
147+
148+
for p in group["params"]:
149+
if p.grad is None:
150+
continue # pragma: no cover
151+
152+
state = self.state[p]
153+
state["step"] += 1
154+
update_kwargs["step"] = state["step"]
155+
156+
ctx = self.pre_step_inplace(p, group)
157+
self._apply_weight_decay_inplace(p.data, p.grad, lr, weight_decay)
158+
159+
buffers = tuple(state[key] for key in self.state_keys)
160+
update = self.update_fn(p.grad, *buffers, **update_kwargs)
161+
p.data.add_(update, alpha=-lr)
162+
163+
self.post_step_inplace(p, group, ctx)
164+
165+
return None
166+
167+
168+
class SingleMomentumOptimizer(_ScalarOptimizerBase):
169+
"""Base for scalar optimizers tracking a single first-moment EMA buffer."""
170+
171+
state_keys: ClassVar[tuple[str, ...]] = ("exp_avg",)
172+
173+
174+
class TwoMomentsOptimizer(_ScalarOptimizerBase):
175+
"""Base for Adam-style scalar optimizers tracking first + second moment buffers."""
176+
177+
state_keys: ClassVar[tuple[str, ...]] = ("exp_avg", "exp_avg_sq")

0 commit comments

Comments
 (0)