As described in the MuonH paper, the hyperball radius is fixed to the parameter's Frobenius norm at the first step, so this value should be computed at most once.
The current implementation:
```python
@override
def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
    """Store the original weight norm and normalize the update using Frobenius norm.

    Args:
        p: The parameter tensor.
        update: The orthogonalized gradient tensor.
    """
    # Use user-specified radius or compute R = ||W_t||_F (Frobenius norm)
    R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item()
    self.state[p]["hyperball_R"] = R

    # Normalize the update in-place and scale by R
    # This modifies update to be: R * normalize(update) using Frobenius norm.
    update_norm = update.norm().clamp_min(self.hyperball_eps)
    update.mul_(R / update_norm)
```
This computes `p.norm()` on every optimizer call, which is redundant. My suggestion is something like:
```python
@override
def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
    if "hyperball_R" not in self.state[p]:
        R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item()
        self.state[p]["hyperball_R"] = R
    else:
        R = self.state[p]["hyperball_R"]
    update_norm = update.norm().clamp_min(self.hyperball_eps)
    update.mul_(R / update_norm)
```
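To illustrate the caching behavior, here is a minimal, self-contained sketch, not the real optimizer class; `_CachedRadiusDemo`, its `norm_calls` counter, and the plain-dict `state` are stand-ins for the actual `self.state` machinery. It shows that with the per-parameter cache, `p.norm()` is evaluated only on the first step:

```python
import torch


class _CachedRadiusDemo:
    """Hypothetical stand-in demonstrating the cached hyperball radius."""

    def __init__(self, hyperball_radius=None, hyperball_eps=1e-8):
        self.hyperball_radius = hyperball_radius
        self.hyperball_eps = hyperball_eps
        self.state = {}      # per-parameter state, keyed by tensor identity
        self.norm_calls = 0  # instrumentation: counts p.norm() evaluations

    def pre_weight_update_fn_inplace(self, p, update):
        state = self.state.setdefault(p, {})
        if "hyperball_R" not in state:
            if self.hyperball_radius is not None:
                R = self.hyperball_radius
            else:
                self.norm_calls += 1
                R = p.norm().item()  # ||W_0||_F, computed only once
            state["hyperball_R"] = R
        else:
            R = state["hyperball_R"]
        # Rescale the update in-place to Frobenius norm R
        update_norm = update.norm().clamp_min(self.hyperball_eps)
        update.mul_(R / update_norm)


p = torch.randn(4, 4)
opt = _CachedRadiusDemo()
for _ in range(3):  # three "optimizer steps" on the same parameter
    opt.pre_weight_update_fn_inplace(p, torch.randn(4, 4))
assert opt.norm_calls == 1  # radius computed only at the first step
```

Keying the state dict by the tensor itself mirrors how `torch.optim.Optimizer.state` is organized, so the cached radius stays per-parameter.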
The implementation in question: `Emerging-Optimizers/emerging_optimizers/orthogonalized_optimizers/muon_hyperball.py`, lines 84 to 99 at commit 3b6c5fb.