diff --git a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py index c9a8c09d..8d75b747 100644 --- a/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py +++ b/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py @@ -35,8 +35,8 @@ class MuonHyperball(muon.Muon): W_{t+1} = R \\cdot \\text{normalize}(W_t - \\text{lr} \\cdot R \\cdot \\text{normalize}(\\text{update})) - where :math:`R` is the Frobenius norm of :math:`W_t` (or a user-specified radius). This keeps - the weight matrix at constant scale while updating. + where :math:`R` is the user-specified Frobenius norm. This keeps the weight matrix at + constant scale while updating. Warning: This optimizer is experimental and may change in future versions. @@ -49,52 +49,60 @@ class MuonHyperball(muon.Muon): *args: Arguments passed to Muon. hyperball_eps: Epsilon for numerical stability in normalization. Default: ``1e-8``. - hyperball_radius: Fixed radius for the hyperball. If ``None`` (default), - uses each parameter's initial Frobenius norm as its radius. If specified, all - parameters will be rescaled to have this radius at initialization. + hyperball_radius: Fixed radius for the hyperball. All parameters must + already have this Frobenius norm at construction time. **kwargs: Keyword arguments passed to Muon. + Raises: + ValueError: If any parameter has zero norm, or if a parameter's + Frobenius norm does not match ``hyperball_radius``. + """ def __init__( self, *args: Any, hyperball_eps: float = 1e-8, - hyperball_radius: float | None = None, + hyperball_radius: float, **kwargs: Any, ) -> None: self.hyperball_eps = hyperball_eps self.hyperball_radius = hyperball_radius super().__init__(*args, **kwargs) - # Validate and optionally rescale parameters based on hyperball_radius. with torch.no_grad(): for group in self.param_groups: for p in group["params"]: p_norm = p.norm() - # Validate that parameter has non-zero norm. - if p_norm.item() == 0: + if p_norm == 0: + raise ValueError( + "MuonHyperball requires all parameters to have non-zero norm. " + "Found parameter with zero norm." + ) + if not torch.isclose( + p_norm, + torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device), + rtol=1e-5, + atol=1e-8, + ): raise ValueError( - "MuonHyperball requires all parameters to have non-zero norm. Found parameter with zero norm." + f"hyperball_radius={self.hyperball_radius} was specified but a parameter " + f"has Frobenius norm {p_norm.item()}. Rescale your model parameters to the " + f"desired radius before constructing the optimizer." ) - # Rescale parameter to have the specified radius if provided. - if self.hyperball_radius is not None: - p.mul_(self.hyperball_radius / p_norm.clamp_min(self.hyperball_eps)) @override def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None: - """Store the original weight norm and normalize the update using Frobenius norm. + """Normalize the update using Frobenius norm, scaled by R. Args: p: The parameter tensor. update: The orthogonalized gradient tensor. """ - # Use user-specified radius or compute R = ||W_t||_F (Frobenius norm) - R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item() - self.state[p]["hyperball_R"] = R + if "hyperball_R" not in self.state[p]: + self.state[p]["hyperball_R"] = torch.tensor(self.hyperball_radius, dtype=p.dtype, device=p.device) + R = self.state[p]["hyperball_R"] - # Normalize the update in-place and scale by R - # This modifies update to be: R * normalize(update) using Frobenius norm. update_norm = update.norm().clamp_min(self.hyperball_eps) update.mul_(R / update_norm)