@@ -80,14 +80,44 @@ def apply_momentum(
8080 return update
8181
8282
83- def apply_scaling (grad : Tensor , rms_scale : bool = False ) -> Tensor :
84- """Post-NS scaling: either Moonlight RMS or Keller Jordan max(1, m/n)^0.5."""
85- if rms_scale :
86- # https://github.com/MoonshotAI/Moonlight/blob/5afcb6911077e7f182d05865fe90d9f39abcbcbd/examples/toy_train.py#L146
87- grad *= 0.2 * math .sqrt (max (grad .shape [1 ], grad .shape [0 ]))
83+ def apply_scaling (
84+ grad : Tensor ,
85+ mode : str = "spectral" ,
86+ extra_scale_factor : float = 1.0 ,
87+ ) -> Tensor :
88+ """Post-Newton-Schulz update scaling.
89+
90+ Naming aligned with Megatron-Core / emerging_optimizers (NVIDIA-NeMo).
91+
92+ Final scale = scale_factor(mode) * extra_scale_factor, where:
93+ - 'spectral' : sqrt(max(m, n))
94+ Kimi/Moonlight (arXiv:2502.16982); emerging_optimizers default.
95+ - 'unit_rms_norm' : sqrt(m / n)
96+ Scion (arXiv:2502.07529) / Bernstein
97+ (https://jeremybernste.in/writing/deriving-muon).
98+ - 'shape_scaling' : max(1, m / n)**0.5
99+ Keller Jordan original (https://kellerjordan.github.io/posts/muon).
100+
101+ Set extra_scale_factor=0.2 with mode='spectral' to reproduce the legacy
102+ Moonlight `https://github.com/MoonshotAI/Moonlight/blob/5afcb6911077e7f182d05865fe90d9f39abcbcbd/examples/toy_train.py#L146`
103+ setting (= 0.2 * sqrt(max(m, n))), which
104+ approximately matches AdamW's update RMS norm so a single lr works for
105+ both Muon and the AdamW backend.
106+ """
107+ m = grad .size (- 2 )
108+ n = grad .size (- 1 )
109+ if mode == "spectral" :
110+ scale = math .sqrt (max (m , n ))
111+ elif mode == "unit_rms_norm" :
112+ scale = math .sqrt (m / n )
113+ elif mode == "shape_scaling" :
114+ scale = max (1 , m / n ) ** 0.5
88115 else :
89- # https://github.com/KellerJordan/Muon/blob/f90a42b28e00b8d9d2d05865fe90d9f39abcbcbd/muon.py#L40
90- grad *= max (1 , grad .size (- 2 ) / grad .size (- 1 )) ** 0.5
116+ raise ValueError (
117+ f"Invalid muon_scale_mode { mode !r} . Valid: "
118+ "{'spectral', 'unit_rms_norm', 'shape_scaling'}."
119+ )
120+ grad *= scale * extra_scale_factor
91121 return grad
92122
93123
@@ -194,7 +224,11 @@ def finish(self):
194224 else :
195225 scatter (grad .to_local (), None , src = dest_rank , group = pg , async_op = False )
196226
197- update = apply_scaling (grad , self .group ["rms_scale" ])
227+ update = apply_scaling (
228+ grad ,
229+ self .group ["scale_mode" ],
230+ self .group ["extra_scale_factor" ],
231+ )
198232
199233 self .param .mul_ (1 - self .group ["lr" ] * self .group ["weight_decay" ])
200234 self .param .add_ (update .reshape (self .param .shape ), alpha = - self .group ["lr" ])
@@ -272,7 +306,11 @@ def finish(self):
272306 new_local = distribute_tensor (g_full , mesh , placements ).to_local ()
273307 grad .to_local ().copy_ (new_local )
274308
275- update = apply_scaling (grad , self .group ["rms_scale" ])
309+ update = apply_scaling (
310+ grad ,
311+ self .group ["scale_mode" ],
312+ self .group ["extra_scale_factor" ],
313+ )
276314
277315 self .param .mul_ (1 - self .group ["lr" ] * self .group ["weight_decay" ])
278316 self .param .add_ (update .reshape (self .param .shape ), alpha = - self .group ["lr" ])
@@ -309,7 +347,11 @@ def start(self):
309347 )
310348 update = zeropower_via_newtonschulz5 (update , self .group ["ns_steps" ])
311349 update = update .to (self .param .grad .dtype )
312- update = apply_scaling (update , self .group ["rms_scale" ])
350+ update = apply_scaling (
351+ update ,
352+ self .group ["scale_mode" ],
353+ self .group ["extra_scale_factor" ],
354+ )
313355 self .param .mul_ (1 - self .group ["lr" ] * self .group ["weight_decay" ])
314356 self .param .add_ (update .reshape (self .param .shape ), alpha = - self .group ["lr" ])
315357
@@ -330,7 +372,8 @@ class Muon(torch.optim.Optimizer):
330372
331373 Notable changes:
332374 - DTensor/FSDP2 native: uses gather/scatter for distributed NS instead of DDP.
333- - ``rms_scale`` argument following the Moonlight paper (https://arxiv.org/abs/2502.16982).
375+ - ``scale_mode`` / ``extra_scale_factor`` arguments aligned with Megatron-Core /
376+ emerging_optimizers (NVIDIA-NeMo). See :func:`apply_scaling` for details.
334377
335378 Example::
336379
@@ -340,7 +383,7 @@ class Muon(torch.optim.Optimizer):
340383 ])
341384
342385 Param group args (``use_muon=True``):
343- lr, momentum, weight_decay, rms_scale , nesterov, ns_steps
386+ lr, momentum, weight_decay, scale_mode, extra_scale_factor , nesterov, ns_steps
344387
345388 Param group args (``use_muon=False``):
346389 lr, betas, eps, weight_decay
@@ -353,7 +396,8 @@ def __init__(self, param_groups):
353396 group .setdefault ("lr" , 0.02 )
354397 group .setdefault ("momentum" , 0.95 )
355398 group .setdefault ("weight_decay" , 0 )
356- group .setdefault ("rms_scale" , True )
399+ group .setdefault ("scale_mode" , "spectral" )
400+ group .setdefault ("extra_scale_factor" , 1.0 )
357401 group .setdefault ("nesterov" , True )
358402 group .setdefault ("ns_steps" , 5 )
359403 assert set (group .keys ()) == {
@@ -362,7 +406,8 @@ def __init__(self, param_groups):
362406 "momentum" ,
363407 "weight_decay" ,
364408 "use_muon" ,
365- "rms_scale" ,
409+ "scale_mode" ,
410+ "extra_scale_factor" ,
366411 "nesterov" ,
367412 "ns_steps" ,
368413 }
0 commit comments