
Add NorMuon optimizer (row-wise adaptive normalization for Muon)#1651

Open
bzantium wants to merge 1 commit into google-deepmind:main from bzantium:add-normuon-optimizer

Conversation

@bzantium
Contributor

Summary

  • Add NorMuon optimizer to optax.contrib based on arXiv:2510.05491
  • NorMuon extends Muon with row-wise second moment tracking and adaptive normalization after Newton-Schulz orthogonalization
  • Ensures balanced neuron utilization with negligible memory overhead vs Muon

Algorithm

NorMuon adds row-wise adaptive normalization after Newton-Schulz orthogonalization:

  1. Momentum: M_t = β₁ M_{t-1} + (1 - β₁) G_t
  2. NS orthogonalization: O_t = NS5(M_t) (same as Muon)
  3. Row-wise second moment: v_t = β₂ v_{t-1} + (1 - β₂) mean_cols(O_t²) where v_t ∈ ℝ^m
  4. Normalize: Ô_t = O_t / (√v_t + ε), dividing each row of O_t by the corresponding entry of √v_t + ε
  5. Adaptive scale: η̂ = 0.2 · η · √(mn) / ‖Ô_t‖_F

Memory overhead vs Muon: only m additional scalars per (m, n) parameter — negligible.
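The five steps above can be sketched as a standalone numpy update. This is an illustrative reconstruction from the algorithm description, not the PR's actual `scale_by_normuon` implementation; the Newton-Schulz coefficients are the ones commonly used by Muon, and all function names here are hypothetical.

```python
import numpy as np

def newton_schulz5(G, steps=5, eps=1e-7):
    """Approximate orthogonalization via 5 Newton-Schulz iterations (NS5)."""
    a, b, c = 3.4445, -4.7750, 2.0315  # coefficients commonly used by Muon
    X = G / (np.linalg.norm(G) + eps)  # scale so the iteration converges
    transposed = X.shape[0] > X.shape[1]
    if transposed:  # iterate on the wide orientation
        X = X.T
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * A @ A) @ X
    return X.T if transposed else X

def normuon_update(G, M, v, lr, beta1=0.95, beta2=0.95, eps=1e-8):
    """One NorMuon step for a 2-D weight; returns (update, new_M, new_v)."""
    m_rows, n_cols = G.shape
    M = beta1 * M + (1 - beta1) * G                       # 1. momentum
    O = newton_schulz5(M)                                 # 2. NS orthogonalization
    v = beta2 * v + (1 - beta2) * np.mean(O**2, axis=1)   # 3. row-wise 2nd moment
    O_hat = O / (np.sqrt(v)[:, None] + eps)               # 4. row-wise normalize
    scale = 0.2 * lr * np.sqrt(m_rows * n_cols) / (np.linalg.norm(O_hat) + eps)  # 5.
    return -scale * O_hat, M, v

rng = np.random.default_rng(0)
G = rng.normal(size=(4, 8))
update, M, v = normuon_update(G, np.zeros_like(G), np.zeros(4), lr=0.02)
print(update.shape, v.shape)
```

Note that step 5 makes the update's Frobenius norm equal 0.2 · η · √(mn) regardless of the row-wise rescaling, and the only state beyond Muon's momentum buffer is the length-m vector `v`.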

Changes

  • New: optax/contrib/_normuon.py — scale_by_normuon() and normuon() implementations
  • New: optax/contrib/_normuon_test.py — 8 unit tests
  • New: optax/contrib/_normuon_benchmark_test.py — training benchmark comparing NorMuon vs Muon
  • Modified: optax/contrib/__init__.py — export normuon, NorMuonState, scale_by_normuon
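The exported NorMuonState presumably carries the two buffers from the algorithm above. The sketch below is a guess at its shape (the field names `mu` and `nu` and the helper `init_state` are illustrative, not the PR's actual definitions) and shows the claimed overhead of m extra scalars per (m, n) weight:

```python
from typing import NamedTuple
import numpy as np

# Illustrative guess at the state held per 2-D parameter; the actual
# fields in optax/contrib/_normuon.py may differ.
class NorMuonState(NamedTuple):
    mu: np.ndarray  # momentum buffer M_t, same shape as the weight
    nu: np.ndarray  # row-wise second moment v_t, shape (m,)

def init_state(weight: np.ndarray) -> NorMuonState:
    m, _ = weight.shape
    return NorMuonState(mu=np.zeros_like(weight), nu=np.zeros(m))

W = np.zeros((64, 32))
state = init_state(W)
# Overhead vs Muon (which keeps only the momentum buffer): m extra scalars.
extra = state.nu.size
print(extra)
```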

Benchmark Results

Training a 2-layer MLP (32→64→1) on synthetic regression data for 500 steps:

  • Both Muon and NorMuon converge (final loss < 10% of initial)
  • NorMuon reaches a final loss comparable to Muon's (within a factor of 5)
  • No NaN/Inf or numerical instability observed
  • All parameters updated correctly (2-D weights via NorMuon, 1-D biases via Adam)
  • Loss decreases monotonically when averaged over 50-step windows

Test plan

  • pytest optax/contrib/_normuon_test.py — 8 unit tests pass
  • test_normuon_vs_muon_convergence — both optimizers converge, NorMuon comparable to Muon
  • test_normuon_no_side_effects — no NaN/Inf, monotonic decrease, finite params
  • test_normuon_mixed_params_training — all params updated correctly

@google-cla

google-cla Bot commented Mar 30, 2026

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@bzantium bzantium force-pushed the add-normuon-optimizer branch 2 times, most recently from 277c1f7 to 7df9c09 on March 30, 2026 at 15:27
NorMuon extends Muon with row-wise second moment tracking and adaptive
normalization after Newton-Schulz orthogonalization, ensuring balanced
neuron utilization with negligible memory overhead.

Reference: Li et al., "NorMuon: Making Muon more efficient and scalable"
(arxiv:2510.05491), 2025
@bzantium bzantium force-pushed the add-normuon-optimizer branch from 107a36f to 4bdb9e6 on March 31, 2026 at 00:54
