|
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | | -from emerging_optimizers.scalar_optimizers.laprop import LaProp |
16 | | -from emerging_optimizers.scalar_optimizers.lion import Lion |
| 15 | +from typing import Any, override |
| 16 | + |
| 17 | +import torch |
| 18 | +from absl import logging |
| 19 | +from torch.optim.optimizer import ParamsT |
| 20 | + |
| 21 | +from emerging_optimizers.mixin import WeightDecayT |
| 22 | +from emerging_optimizers.registry import register_optimizer |
| 23 | +from emerging_optimizers.scalar_optimizers.base import ( |
| 24 | + SingleMomentumOptimizer, |
| 25 | + TwoMomentsOptimizer, |
| 26 | + _validate_common_hparams, |
| 27 | +) |
| 28 | +from emerging_optimizers.scalar_optimizers.update_functions import ( |
| 29 | + calculate_laprop_update, |
| 30 | + calculate_lion_update, |
| 31 | + calculate_signum_update, |
| 32 | + calculate_sim_ademamix_update, |
| 33 | +) |
| 34 | + |
| 35 | + |
# Public names exported by this module: the four concrete optimizers defined
# below, plus the two base classes imported from ``scalar_optimizers.base``.
__all__ = [
    "LaProp",
    "Lion",
    "Signum",
    "SimplifiedAdEMAMix",
    "SingleMomentumOptimizer",
    "TwoMomentsOptimizer",
]
| 44 | + |
| 45 | + |
@register_optimizer("lion")
class Lion(SingleMomentumOptimizer):
    """Lion optimizer (Chen et al., 2023): sign-based update with a single first-moment EMA.

    Registered under the name ``"lion"``. The per-step update is delegated to
    ``calculate_lion_update``, which receives the group's ``betas``.

    Args:
        params: Parameters (or parameter groups) to optimize.
        lr: Learning rate stored in the group defaults.
        betas: Pair of coefficients forwarded to ``calculate_lion_update``.
        weight_decay: Weight decay coefficient stored in the group defaults.
        weight_decay_method: Weight decay variant, passed through to the base class.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-4,
        betas: tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.01,
        *,
        weight_decay_method: WeightDecayT = "decoupled",
    ) -> None:
        # Shared hyperparameter checks run before the base constructor is called.
        _validate_common_hparams(lr=lr, betas=betas, weight_decay=weight_decay)

        group_defaults = {"lr": lr, "betas": betas, "weight_decay": weight_decay}
        super().__init__(
            params,
            defaults=group_defaults,
            update_fn=calculate_lion_update,
            update_kwarg_names=("betas",),
            weight_decay_method=weight_decay_method,
        )
| 67 | + |
| 68 | + |
@register_optimizer("signum")
class Signum(SingleMomentumOptimizer):
    """Sign-SGD / Signum optimizer (Bernstein et al., 2018): sign of a bias-corrected single-moment EMA.

    Registered under the name ``"signum"``. The per-step update is delegated to
    ``calculate_signum_update``, which receives ``momentum``, ``correct_bias``,
    ``nesterov`` and ``use_shape_scaling`` from the parameter group.

    Args:
        params: Parameters (or parameter groups) to optimize.
        lr: Learning rate stored in the group defaults.
        momentum: Momentum coefficient; must satisfy ``0.0 <= momentum < 1.0``.
        weight_decay: Weight decay coefficient stored in the group defaults.
        correct_bias: Flag forwarded to ``calculate_signum_update``.
        nesterov: Flag forwarded to ``calculate_signum_update``.
        use_shape_scaling: Flag forwarded to ``calculate_signum_update``.
        weight_decay_method: Weight decay variant, passed through to the base class.

    Raises:
        ValueError: If ``momentum`` lies outside ``[0.0, 1.0)``.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
        *,
        correct_bias: bool = True,
        nesterov: bool = False,
        use_shape_scaling: bool = False,
        weight_decay_method: WeightDecayT = "decoupled",
    ) -> None:
        _validate_common_hparams(lr=lr, weight_decay=weight_decay)

        # Kept as a chained comparison so that NaN is rejected as well.
        momentum_in_range = 0.0 <= momentum < 1.0
        if not momentum_in_range:
            raise ValueError(f"Invalid momentum: {momentum}")

        group_defaults = {
            "lr": lr,
            "momentum": momentum,
            "weight_decay": weight_decay,
            "correct_bias": correct_bias,
            "nesterov": nesterov,
            "use_shape_scaling": use_shape_scaling,
        }
        # Group entries handed to the update function on each step.
        forwarded_names = ("momentum", "correct_bias", "nesterov", "use_shape_scaling")
        super().__init__(
            params,
            defaults=group_defaults,
            update_fn=calculate_signum_update,
            update_kwarg_names=forwarded_names,
            weight_decay_method=weight_decay_method,
        )
| 102 | + |
| 103 | + |
@register_optimizer("laprop")
class LaProp(TwoMomentsOptimizer):
    """LaProp optimizer (Ziyin et al., 2020): Adam with the gradient normalized before the first-moment update.

    Registered under the name ``"laprop"``. The per-step update is delegated to
    ``calculate_laprop_update``, which receives ``betas``, ``eps`` and
    ``correct_bias`` from the parameter group.

    Args:
        params: Parameters (or parameter groups) to optimize.
        lr: Learning rate stored in the group defaults.
        betas: Pair of coefficients forwarded to ``calculate_laprop_update``.
        eps: Small constant forwarded to ``calculate_laprop_update``; also used as
            the lower bound on the norm when ``frob_normalize`` rescales a parameter.
        weight_decay: Weight decay coefficient stored in the group defaults.
        correct_bias: Flag forwarded to ``calculate_laprop_update``.
        frob_normalize: When true, the in-place hooks record each parameter's norm
            before the step and rescale the parameter back to that norm afterwards.
            An error is logged if this is combined with a non-zero ``weight_decay``.
        weight_decay_method: Weight decay variant, passed through to the base class.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        *,
        correct_bias: bool = True,
        frob_normalize: bool = False,
        weight_decay_method: WeightDecayT = "decoupled",
    ) -> None:
        _validate_common_hparams(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        # The combination is only reported, not rejected.
        if frob_normalize:
            if weight_decay != 0.0:
                logging.error("LaProp with frob_normalize=True is intended to be used with weight_decay=0.0.")

        # Stored on the instance (not in the param groups); read by the hooks below.
        self.frob_normalize = frob_normalize

        group_defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "correct_bias": correct_bias,
        }
        super().__init__(
            params,
            defaults=group_defaults,
            update_fn=calculate_laprop_update,
            update_kwarg_names=("betas", "eps", "correct_bias"),
            weight_decay_method=weight_decay_method,
        )

    @override
    def pre_step_inplace(self, p: torch.Tensor, group: dict) -> Any:
        """Return the current norm of ``p`` when ``frob_normalize`` is on, else ``None``."""
        if not self.frob_normalize:
            return None
        return p.data.norm()

    @override
    def post_step_inplace(self, p: torch.Tensor, group: dict, ctx: Any) -> None:
        """Rescale ``p`` in place so its norm matches ``ctx``, the value from ``pre_step_inplace``.

        Does nothing unless ``frob_normalize`` is enabled.
        """
        if not self.frob_normalize:
            return
        # Clamp the new norm from below by the group's eps to avoid dividing by zero.
        updated_norm = p.data.norm().clamp_min(group["eps"])
        p.data.mul_(ctx / updated_norm)
| 147 | + |
| 148 | + |
@register_optimizer("sim_ademamix")
class SimplifiedAdEMAMix(TwoMomentsOptimizer):
    """Simplified AdEMAMix: two-buffer variant mixing alpha-scaled current gradient into a theory-style first-moment EMA.

    Registered under the name ``"sim_ademamix"``. The per-step update is delegated
    to ``calculate_sim_ademamix_update``, which receives ``betas``, ``eps``,
    ``correct_bias``, ``num_beta_fast_warmup_steps``, ``min_beta_fast`` and
    ``alpha`` from the parameter group.

    Args:
        params: Parameters (or parameter groups) to optimize.
        lr: Learning rate stored in the group defaults.
        betas: Pair of coefficients forwarded to the update function.
        eps: Small constant forwarded to the update function.
        weight_decay: Weight decay coefficient stored in the group defaults.
        correct_bias: Flag forwarded to the update function.
        num_beta_fast_warmup_steps: Optional warmup length forwarded to the update
            function; must be positive when given.
        min_beta_fast: Forwarded to the update function; must satisfy
            ``0.0 <= min_beta_fast < 1.0``.
        alpha: Mixing coefficient forwarded to the update function.
        weight_decay_method: Weight decay variant, passed through to the base class.

    Raises:
        ValueError: If ``min_beta_fast`` lies outside ``[0.0, 1.0)``, or if
            ``num_beta_fast_warmup_steps`` is given and is not positive.
    """

    def __init__(
        self,
        params: ParamsT,
        lr: float = 1e-3,
        betas: tuple[float, float] = (0.9999, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 0.0,
        *,
        correct_bias: bool = True,
        num_beta_fast_warmup_steps: int | None = None,
        min_beta_fast: float = 0.9,
        alpha: float = 2.0,
        weight_decay_method: WeightDecayT = "decoupled",
    ) -> None:
        _validate_common_hparams(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)

        # Kept as a chained comparison so that NaN is rejected as well.
        min_beta_fast_in_range = 0.0 <= min_beta_fast < 1.0
        if not min_beta_fast_in_range:
            raise ValueError(f"Invalid min_beta_fast: {min_beta_fast}")

        # None means "no warmup length supplied" and skips the positivity check.
        if num_beta_fast_warmup_steps is not None:
            if num_beta_fast_warmup_steps <= 0:
                raise ValueError(f"Invalid num_beta_fast_warmup_steps: {num_beta_fast_warmup_steps}")

        group_defaults = {
            "lr": lr,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
            "correct_bias": correct_bias,
            "num_beta_fast_warmup_steps": num_beta_fast_warmup_steps,
            "min_beta_fast": min_beta_fast,
            "alpha": alpha,
        }
        # Group entries handed to the update function on each step.
        forwarded_names = (
            "betas",
            "eps",
            "correct_bias",
            "num_beta_fast_warmup_steps",
            "min_beta_fast",
            "alpha",
        )
        super().__init__(
            params,
            defaults=group_defaults,
            update_fn=calculate_sim_ademamix_update,
            update_kwarg_names=forwarded_names,
            weight_decay_method=weight_decay_method,
        )
0 commit comments