
Possible redundant norm computation in MuonHyperball #155

@Harry-Chen

Description


As described in the MuonH paper, the hyperball radius is the parameter's Frobenius norm captured at the first step, and the update is normalized to that radius, so this value should be computed at most once per parameter.
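For reference, a minimal sketch of the update rule as I read it from the paper (the function name and epsilon value are mine, purely illustrative): the orthogonalized update is rescaled so its Frobenius norm equals R = ||W_0||_F, the weight norm captured at the first step.

    import torch

    def hyperball_update_inplace(p: torch.Tensor, update: torch.Tensor, lr: float, R: float) -> None:
        # Rescale the orthogonalized update so that ||update||_F == R, where R
        # was computed once as ||W_0||_F, then take the step in-place.
        update.mul_(R / update.norm().clamp_min(1e-8))
        p.sub_(lr * update)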

The current implementation:

    @override
    def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
        """Store the original weight norm and normalize the update using Frobenius norm.

        Args:
            p: The parameter tensor.
            update: The orthogonalized gradient tensor.
        """
        # Use user-specified radius or compute R = ||W_t||_F (Frobenius norm)
        R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item()
        self.state[p]["hyperball_R"] = R
        # Normalize the update in-place and scale by R
        # This modifies update to be: R * normalize(update) using Frobenius norm.
        update_norm = update.norm().clamp_min(self.hyperball_eps)
        update.mul_(R / update_norm)

This computes `p.norm()` on every optimizer step, which is redundant. My suggestion is something like:

    @override
    def pre_weight_update_fn_inplace(self, p: torch.Tensor, update: torch.Tensor) -> None:
        # Compute R = ||W_0||_F only once and cache it in the per-parameter state.
        # Note: the check must be against self.state[p], not self.state, whose
        # keys are the parameters themselves.
        if "hyperball_R" not in self.state[p]:
            R = self.hyperball_radius if self.hyperball_radius is not None else p.norm().item()
            self.state[p]["hyperball_R"] = R
        else:
            R = self.state[p]["hyperball_R"]
        update_norm = update.norm().clamp_min(self.hyperball_eps)
        update.mul_(R / update_norm)
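As a quick sanity check (the constructor arguments below are hypothetical, just to illustrate the expected behavior), the cached radius should stay fixed across steps even as the weights move:

    import torch

    p = torch.randn(64, 64, requires_grad=True)
    opt = MuonHyperball([p], lr=0.02)  # hypothetical args; actual signature may differ
    radii = []
    for _ in range(3):
        p.grad = torch.randn_like(p)
        opt.step()
        radii.append(opt.state[p]["hyperball_R"])
    assert radii[0] == radii[1] == radii[2]  # R computed once, then reused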
