Skip to content

Commit 4bdb9e6

Browse files
bzantium and claude
authored and committed
Add NorMuon optimizer (row-wise adaptive normalization for Muon)
NorMuon extends Muon with row-wise second moment tracking and adaptive normalization after Newton-Schulz orthogonalization, ensuring balanced neuron utilization with negligible memory overhead. Reference: Li et al., "NorMuon: Making Muon more efficient and scalable" (arxiv:2510.05491), 2025
1 parent 26ba03c commit 4bdb9e6

4 files changed

Lines changed: 671 additions & 0 deletions

File tree

optax/contrib/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@
5454
from optax.contrib._muon import MuonDimensionNumbers
5555
from optax.contrib._muon import MuonState
5656
from optax.contrib._muon import scale_by_muon
57+
from optax.contrib._normuon import normuon
58+
from optax.contrib._normuon import NorMuonState
59+
from optax.contrib._normuon import scale_by_normuon
5760
from optax.contrib._privacy import differentially_private_aggregate
5861
from optax.contrib._privacy import DifferentiallyPrivateAggregateState
5962
from optax.contrib._privacy import dpsgd

optax/contrib/_normuon.py

Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
# Copyright 2026 DeepMind Technologies Limited. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""NorMuon optimizer."""
16+
17+
import math
18+
from typing import Any, Callable, Literal, NamedTuple, Optional, Union
19+
20+
import jax
21+
import jax.numpy as jnp
22+
23+
from optax._src import alias
24+
from optax._src import base
25+
from optax._src import combine
26+
from optax._src import numerics
27+
from optax._src import transform
28+
from optax._src import utils
29+
from optax.contrib._muon import _DEFAULT_NS_COEFFS
30+
from optax.contrib._muon import _is_weight_dim_nums
31+
from optax.contrib._muon import _NS_COEFFS_PRESET_DICT
32+
from optax.contrib._muon import MuonDimensionNumbers
33+
from optax.contrib._muon import orthogonalize_via_newton_schulz
34+
from optax.contrib._muon import scale_by_shape
35+
from optax.contrib._muon import WeightDimNumOrFn
36+
from optax.transforms import _masking
37+
import optax.tree
38+
39+
40+
class NorMuonState(NamedTuple):
  """Optimizer state carried between successive NorMuon updates.

  The fields are created by the ``init_fn`` of :func:`scale_by_normuon` and
  replaced (never mutated in place) by its ``update_fn``.
  """
  # Step counter: a scalar (shape ``()``) of dtype ``jnp.int32``.
  count: jax.typing.ArrayLike
  # First-moment (momentum) accumulator; same tree structure as the params.
  mu: base.Updates
  # Row-wise second-moment accumulator; each leaf has the shape of the
  # matching param with its last axis dropped.
  nu: base.Updates
  # Newton-Schulz coefficients as an array of shape ``(3,)`` or ``(n, 3)``.
  ns_coeffs: jax.typing.ArrayLike
46+
47+
48+
def scale_by_normuon(
    ns_coeffs: Union[
        tuple[jax.typing.ArrayLike, jax.typing.ArrayLike,
              jax.typing.ArrayLike],
        tuple[
            tuple[
                jax.typing.ArrayLike, jax.typing.ArrayLike,
                jax.typing.ArrayLike
            ],
            ...,
        ],
    ] = _DEFAULT_NS_COEFFS,
    ns_steps: jax.typing.ArrayLike = 5,
    beta: jax.typing.ArrayLike = 0.95,
    beta2: jax.typing.ArrayLike = 0.95,
    eps: jax.typing.ArrayLike = 1e-8,
    mu_dtype: Optional[jax.typing.DTypeLike] = None,
    *,
    nesterov: bool = True,
    preconditioning: Literal[
        'frobenius', 'spectral', 'aol', 'schatten'
    ] = 'frobenius',
    weight_dimension_numbers: WeightDimNumOrFn | None = None,
    normuon_scale: jax.typing.ArrayLike = 0.2,
) -> base.GradientTransformation:
  r"""Rescale updates according to the NorMuon algorithm.

  NorMuon extends Muon with row-wise adaptive normalization after the
  Newton-Schulz orthogonalization step. This balances neuron utilization
  with negligible memory overhead compared to Muon.

  Each update step performs, per parameter leaf:

  1. an exponential moving average of the gradients (optionally with a
     Nesterov look-ahead), bias corrected;
  2. Newton-Schulz orthogonalization of that momentum;
  3. an exponential moving average ``nu`` of the mean squared entries of
     every row (last axis) of the orthogonalized update;
  4. division of each row by ``sqrt(nu) + eps``;
  5. a rescale of every matrix by
     ``normuon_scale * sqrt(m * n) / (frobenius_norm + eps)``.

  Args:
    ns_coeffs: Coefficients for the Newton-Schulz method. Must convert to an
      array of shape ``(3,)`` or ``(n, 3)``.
    ns_steps: Number of Newton-Schulz iterations.
      Ignored if ``ns_coeffs`` is a tuple of tuples.
    beta: Decay rate for the exponentially weighted average of grads.
    beta2: Decay rate for the row-wise second moment estimates.
    eps: Term added to denominators to improve numerical stability.
    mu_dtype: Data type of the momentum accumulator. The row-wise second
      moment accumulator is stored in the same dtype.
    nesterov: Whether to use Nesterov momentum.
    preconditioning: Which preconditioning method to use before NS iterations.
    weight_dimension_numbers: An optional tree with the same structure as the
      params of ``MuonDimensionNumbers``s, specifying how to reshape the
      parameters before and after the orthogonalization OR a callable returning
      such a tree. None implies that all parameters are 2D matrices.
    normuon_scale: Adaptive learning rate coefficient (default 0.2).

  Returns:
    A :class:`optax.GradientTransformation` object.

  Raises:
    ValueError: At ``init`` time, if ``ns_coeffs`` does not have shape
      ``(3,)`` or ``(n, 3)``, or if a 2D ``ns_coeffs`` has more rows than
      ``ns_steps``.

  References:
    Li et al., `NorMuon: Making Muon more efficient and scalable
    <https://arxiv.org/abs/2510.05491>`_, 2025
  """
  mu_dtype = utils.canonicalize_dtype(mu_dtype)

  def init_fn(params):
    mu = optax.tree.zeros_like(params, dtype=mu_dtype)
    # nu stores row-wise second moments: shape (m,) for a (m, n) param.
    nu = jax.tree.map(
        lambda x: jnp.zeros(x.shape[:-1], dtype=mu_dtype), params)
    ns_coeffs_ = jnp.asarray(ns_coeffs)

    if ns_coeffs_.ndim > 2 or ns_coeffs_.shape[-1] != 3:
      raise ValueError(
          f'ns_coeffs must have shape (3,) or (n, 3), got {ns_coeffs_.shape}'
      )
    if ns_coeffs_.ndim == 2:
      # NOTE(review): this rejects coefficient stacks with MORE rows than
      # ``ns_steps`` although the message says "not enough" -- confirm the
      # intended comparison before changing it.
      if ns_coeffs_.shape[0] > ns_steps:
        raise ValueError(f'Not enough coeffs to perform {ns_steps} steps')
      ns_coeffs_ = ns_coeffs_[-ns_steps:]

    return NorMuonState(
        count=jnp.zeros([], jnp.int32),
        mu=mu,
        nu=nu,
        ns_coeffs=ns_coeffs_,
    )

  def update_fn(updates, state, params=None):
    del params
    if callable(weight_dimension_numbers):
      resolved_weight_dim_nums = weight_dimension_numbers(updates)
    else:
      resolved_weight_dim_nums = weight_dimension_numbers

    mu = optax.tree.update_moment(updates, state.mu, beta, 1)
    count_inc = numerics.safe_increment(state.count)
    if nesterov:
      # Nesterov look-ahead: blend the momentum (bias corrected one step
      # ahead) with the bias-corrected current gradient.
      mu_hat = jax.tree.map(
          lambda m, g: beta * m + (1 - beta) * g,
          optax.tree.bias_correction(
              mu, beta, numerics.safe_increment(count_inc)
          ),
          optax.tree.bias_correction(updates, beta, count_inc),
      )
    else:
      mu_hat = optax.tree.bias_correction(mu, beta, count_inc)

    # Apply Newton-Schulz orthogonalization.
    def _orthogonalize(x, dim_num=None):
      return orthogonalize_via_newton_schulz(
          x, state.ns_coeffs, ns_steps, preconditioning, eps, dim_num)

    if resolved_weight_dim_nums is None:
      # ``None`` is an empty pytree, so it cannot be zipped against a
      # non-trivial update tree; map over the momentum alone and let every
      # leaf use the default (plain 2D matrix) handling.
      ortho = jax.tree.map(_orthogonalize, mu_hat)
    else:
      ortho = jax.tree.map(
          _orthogonalize, mu_hat, resolved_weight_dim_nums,
          is_leaf=_is_weight_dim_nums)

    # Row-wise second moment tracking.
    def _update_nu(o, nu_prev):
      row_sq = jnp.mean(o ** 2, axis=-1)
      return beta2 * nu_prev + (1 - beta2) * row_sq

    new_nu = jax.tree.map(_update_nu, ortho, state.nu)

    # Row-wise normalization and adaptive scaling (paper Algorithm 1).
    def _normalize(o, nu_new):
      o_hat = o / (jnp.sqrt(nu_new[..., None]) + eps)
      if o.ndim >= 2:
        m_n = math.prod(o.shape[-2:])
        norm_axes = (-2, -1)
      else:
        m_n = o.shape[-1]
        norm_axes = -1
      # Reduce over the trailing matrix axes only. Without an explicit axis,
      # ``ord='fro'`` is rejected for anything that is not exactly 2D.
      # ``keepdims`` keeps the result broadcastable against ``o_hat``.
      # NOTE(review): rows are taken along the last axis even for tensors
      # reshaped through ``MuonDimensionNumbers`` -- confirm that this is the
      # intended layout for such parameters.
      frob = jnp.linalg.norm(o_hat, axis=norm_axes, keepdims=True)
      scale = normuon_scale * jnp.sqrt(m_n) / (frob + eps)
      return o_hat * scale

    new_updates = jax.tree.map(_normalize, ortho, new_nu)

    # Store both accumulators in ``mu_dtype`` so the state keeps the dtypes
    # it was initialised with (stable under ``lax.scan`` and checkpointing).
    mu = optax.tree.cast(mu, mu_dtype)
    new_nu = optax.tree.cast(new_nu, mu_dtype)
    return new_updates, NorMuonState(
        count=count_inc,
        mu=mu,
        nu=new_nu,
        ns_coeffs=state.ns_coeffs,
    )

  return base.GradientTransformation(init_fn, update_fn)
179+
180+
181+
def normuon(
    learning_rate: base.ScalarOrSchedule,
    ns_coeffs: Union[
        tuple[jax.typing.ArrayLike, jax.typing.ArrayLike,
              jax.typing.ArrayLike],
        tuple[
            tuple[
                jax.typing.ArrayLike, jax.typing.ArrayLike,
                jax.typing.ArrayLike
            ],
            ...,
        ],
        str,
    ] = _DEFAULT_NS_COEFFS,
    ns_steps: jax.typing.ArrayLike = 5,
    beta: jax.typing.ArrayLike = 0.95,
    beta2: jax.typing.ArrayLike = 0.95,
    eps: jax.typing.ArrayLike = 1e-8,
    weight_decay: jax.typing.ArrayLike = 0.0,
    weight_decay_mask: Optional[
        Union[Any, Callable[[base.Params], Any]]
    ] = None,
    mu_dtype: Optional[jax.typing.DTypeLike] = None,
    *,
    nesterov: bool = True,
    preconditioning: Literal[
        'frobenius', 'spectral', 'aol', 'schatten'
    ] = 'frobenius',
    adam_b1: jax.typing.ArrayLike = 0.9,
    adam_b2: jax.typing.ArrayLike = 0.999,
    adam_eps_root: jax.typing.ArrayLike = 0.0,
    adam_weight_decay: jax.typing.ArrayLike = 0.0,
    adam_learning_rate: base.ScalarOrSchedule | None = None,
    muon_weight_dimension_numbers: WeightDimNumOrFn | None = None,
    normuon_scale: jax.typing.ArrayLike = 0.2,
    consistent_rms: jax.typing.ArrayLike | None = None,
) -> base.GradientTransformation:
  r"""NorMuon: Muon with row-wise adaptive normalization.

  Builds a partitioned optimizer. Parameters labelled ``'normuon'`` go
  through :func:`scale_by_normuon`, then ``scale_by_shape``, decoupled weight
  decay and the learning rate. Parameters labelled ``'adam'`` are handled by
  AdamW.

  By default every 2D parameter is a NorMuon parameter and everything else is
  an Adam parameter. Passing ``muon_weight_dimension_numbers`` replaces that
  rule: entries that are ``None`` select Adam, all other entries select
  NorMuon.

  Args:
    learning_rate: A global scaling factor, either fixed or evolving along
      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
    ns_coeffs: Coefficients for the Newton-Schulz method, or the name of a
      preset. Existing presets: ``muon``, ``dion``.
    ns_steps: Number of Newton-Schulz iterations.
      Ignored if ``ns_coeffs`` is a tuple of tuples.
    beta: Decay rate for the exponentially weighted average of grads.
    beta2: Decay rate for the row-wise second moment estimates.
    eps: Term added to the denominator to improve numerical stability. Used
      by both the NorMuon and the Adam branch.
    weight_decay: Strength of the weight decay regularization applied to the
      NorMuon parameters.
    weight_decay_mask: A tree with same structure as (or a prefix of) the
      params PyTree, or a Callable that returns such a pytree given the
      params/updates. The leaves should be booleans, ``True`` for
      leaves/subtrees you want to apply the weight decay to, and ``False``
      for those you want to skip.
    mu_dtype: Data type of the momentum accumulator (both branches).
    nesterov: Whether to use Nesterov momentum (both branches).
    preconditioning: Which preconditioning method to use before NS iterations.
    adam_b1: Exponential decay rate for Adam's first moment estimates.
    adam_b2: Exponential decay rate for Adam's second moment estimates.
    adam_eps_root: Epsilon to stabilize division in Adam, square root version.
    adam_weight_decay: Weight decay factor for Adam.
    adam_learning_rate: Auxiliary learning rate for the Adam optimizer.
      If ``None``, the learning rate for Adam defaults to the same as NorMuon.
    muon_weight_dimension_numbers: An optional tree of
      ``MuonDimensionNumbers``s (or a callable returning one), specifying how
      to reshape the parameters for orthogonalization. A ``None`` value
      indicates that the parameter is not a NorMuon parameter and will be
      optimized with Adam. If not provided, NorMuon is applied to all 2D
      parameters.
    normuon_scale: Adaptive learning rate coefficient (default 0.2).
    consistent_rms: An optional float to activate consistent RMS scaling;
      forwarded to ``scale_by_shape``.

  Returns:
    The corresponding :class:`optax.GradientTransformation`.

  Raises:
    ValueError: If ``ns_coeffs`` is a string that is not a known preset.

  References:
    Li et al., `NorMuon: Making Muon more efficient and scalable
    <https://arxiv.org/abs/2510.05491>`_, 2025
  """
  # Adam follows the main learning rate unless it was given its own.
  adam_lr = learning_rate if adam_learning_rate is None else adam_learning_rate

  # Turn a preset name into its coefficients; pass explicit values through.
  if not isinstance(ns_coeffs, str):
    resolved_coeffs = ns_coeffs
  elif ns_coeffs in _NS_COEFFS_PRESET_DICT:
    resolved_coeffs = _NS_COEFFS_PRESET_DICT[ns_coeffs]
  else:
    raise ValueError(f'Unknown ns_coeff preset string: {ns_coeffs}')

  # A missing spec selects the default rule: every 2D leaf is a NorMuon
  # parameter and shares one default ``MuonDimensionNumbers``.
  use_default_2d_rule = muon_weight_dimension_numbers is None
  if use_default_2d_rule:
    dim_nums_spec = MuonDimensionNumbers()
  else:
    dim_nums_spec = muon_weight_dimension_numbers

  def _resolve_dim_nums(tree):
    """Returns the dimension-number tree, calling the spec if it is callable."""
    if callable(dim_nums_spec):
      return dim_nums_spec(tree)
    return dim_nums_spec

  if use_default_2d_rule:

    def param_labels(params):
      return jax.tree.map(
          lambda leaf: 'normuon' if leaf.ndim == 2 else 'adam', params)

  else:

    def _is_dim_num_entry(node):
      return node is None or _is_weight_dim_nums(node)

    def param_labels(params):
      def _label_subtree(dim_num, subtree):
        # One spec entry may cover a whole subtree of parameters.
        tag = 'adam' if dim_num is None else 'normuon'
        return jax.tree.map(lambda _: tag, subtree)

      return jax.tree.map(
          _label_subtree, _resolve_dim_nums(params), params,
          is_leaf=_is_dim_num_entry)

  def muon_weight_dim_nums_fn(params):
    """Expands the spec to one entry per leaf, masking non-NorMuon leaves."""
    dim_nums = _resolve_dim_nums(params)
    keep = jax.tree.map(lambda tag: tag == 'normuon', param_labels(params))

    def _stop_at(node):
      return (node is None or _is_weight_dim_nums(node)
              or isinstance(node, _masking.MaskedNode))

    def _fill(dim_num, keep_subtree):
      return jax.tree.map(
          lambda flag: dim_num if flag else _masking.MaskedNode(),
          keep_subtree)

    return jax.tree.map(_fill, dim_nums, keep, is_leaf=_stop_at)

  normuon_branch = combine.chain(
      scale_by_normuon(
          ns_coeffs=resolved_coeffs,
          ns_steps=ns_steps,
          beta=beta,
          beta2=beta2,
          eps=eps,
          mu_dtype=mu_dtype,
          nesterov=nesterov,
          preconditioning=preconditioning,
          weight_dimension_numbers=muon_weight_dim_nums_fn,
          normuon_scale=normuon_scale,
      ),
      scale_by_shape(
          weight_dimension_numbers=muon_weight_dim_nums_fn,
          consistent_rms=consistent_rms,
      ),
      transform.add_decayed_weights(weight_decay, weight_decay_mask),
      transform.scale_by_learning_rate(learning_rate),
  )
  adam_branch = alias.adamw(
      learning_rate=adam_lr,
      b1=adam_b1,
      b2=adam_b2,
      eps=eps,
      eps_root=adam_eps_root,
      weight_decay=adam_weight_decay,
      mu_dtype=mu_dtype,
      nesterov=nesterov,
  )

  return combine.partition(
      transforms={'normuon': normuon_branch, 'adam': adam_branch},
      param_labels=param_labels,
  )

0 commit comments

Comments
 (0)