 before using momentum, resulting in orthogonalized updates for weight matrices.
 This can improve training stability and convergence for certain architectures.
 
-Reference:
-https://github.com/KellerJordan/Muon
+Algorithm
+---------
+For >=2D parameters (weight matrices), the Muon update is:
+
+1. Momentum update (Nesterov):
+   m_t = beta * m_{t-1} + (1 - beta) * g_t
+   update = beta * m_t + (1 - beta) * g_t
+
+2. Newton-Schulz orthogonalization (quintic iteration):
+   X_0 = G / ||G||_F
+   X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T
+   Coefficients: a=3.4445, b=-4.7750, c=2.0315
+
+3. Scaling: scale = coeff * sqrt(max(m, n))   [match-RMS mode]
+            scale = sqrt(max(1, m/n))         [rectangular mode]
+
+4. Parameter update: theta -= lr * scale * orth(update)
+
+For 1D parameters (biases, norms), standard Adam is used.
+
+Dtype Behavior
+--------------
+- Newton-Schulz iterations: always bfloat16 (matches official Muon)
+- Adam state (exp_avg, exp_avg_sq): always float32 for numerical stability
+- Gradients: cast to parameter dtype before momentum update
+
+Reference
+---------
+https://github.com/KellerJordan/Muon
 """
 
 from __future__ import (
⋮
     Iterable,
 )
 
+# ============================================================================
+# Constants
+# ============================================================================
+
+# Newton-Schulz iteration count
+NS_STEPS: int = 5
+# Numerical stability epsilon for norm clamping
+NS_EPS: float = 1e-7
+# Adam epsilon for numerical stability
+ADAM_EPS: float = 1e-7
+
 
 def zeropower_via_newtonschulz5(
     G: torch.Tensor,
-    steps: int = 5,
-    eps: float = 1e-7,
 ) -> torch.Tensor:
     """
     Compute the zeroth power (orthogonalization) of a matrix via Newton-Schulz iteration.
 
     Uses quintic Newton-Schulz iteration to compute the orthogonal component of the
     input matrix. This is equivalent to computing U from the SVD decomposition G = USV^T.
 
+    Mathematical formulation:
+        X_0 = G / ||G||_F
+        X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T
+        Coefficients: a=3.4445, b=-4.7750, c=2.0315
+
     This implementation matches the official PyTorch Muon behavior: it always performs
     Newton-Schulz in bfloat16 and returns a bfloat16 tensor.
 
     Parameters
     ----------
     G : torch.Tensor
         Input matrix to orthogonalize with shape (..., M, N).
-    steps : int
-        Number of Newton-Schulz iterations with default 5.
-    eps : float
-        Numerical stability epsilon for norm clamping with default 1e-7.
 
     Returns
     -------
@@ -63,14 +100,10 @@ def zeropower_via_newtonschulz5(
     ------
     ValueError
         If G has fewer than 2 dimensions.
-    ValueError
-        If steps >= 100 (guard for efficiency).
     """
     # === Step 1. Validate ===
     if G.ndim < 2:
         raise ValueError("Input must have at least 2 dimensions (..., M, N).")
-    if steps >= 100:
-        raise ValueError("Number of steps must be less than 100 for efficiency.")
 
     a, b, c = (3.4445, -4.7750, 2.0315)
 
@@ -82,10 +115,10 @@ def zeropower_via_newtonschulz5(
         X = X.mT
 
     # === Step 4. Normalize Frobenius norm to at most 1 ===
-    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=eps)
+    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=NS_EPS)
 
     # === Step 5. Newton-Schulz iterations with fused GEMM ===
-    for _ in range(steps):
+    for _ in range(NS_STEPS):
         A = X @ X.mT
         # gram_update = b*A + c*(A@A) via addmm/baddbmm
         # X = a*X + gram_update@X via addmm/baddbmm
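The quintic iteration above can be sketched outside the diff. A minimal NumPy version (float64 for clarity, not the bfloat16 torch path, and without the fused-GEMM tricks) shows how five steps compress a wide spread of singular values into a narrow band near 1 without ever computing an SVD:

```python
import numpy as np

# NumPy sketch of the quintic Newton-Schulz orthogonalization above
# (float64 here for readability; the real code runs in bfloat16).
A_COEFF, B_COEFF, C_COEFF = 3.4445, -4.7750, 2.0315

def newtonschulz5(G: np.ndarray, steps: int = 5, eps: float = 1e-7) -> np.ndarray:
    X = np.asarray(G, dtype=np.float64)
    transposed = X.shape[-2] > X.shape[-1]
    if transposed:
        X = X.T  # iterate on the wide orientation so the Gram matrix is small
    X = X / max(np.linalg.norm(X), eps)  # Frobenius norm <= 1
    for _ in range(steps):
        gram = X @ X.T
        X = A_COEFF * X + (B_COEFF * gram + C_COEFF * (gram @ gram)) @ X
    return X.T if transposed else X

# Singular values spanning four orders of magnitude get pulled into a
# narrow band around 1 after five iterations.
G = np.diag([0.01, 0.1, 1.0, 10.0])
X = newtonschulz5(G)
```

Note that these coefficients trade exact convergence for speed: singular values land in a band around 1 rather than exactly at 1, which is good enough for the optimizer's purposes.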
@@ -107,11 +140,13 @@ def _prepare_muon_momentum(
     grad: torch.Tensor,
     momentum_buffer: torch.Tensor,
     beta: float,
-    nesterov: bool,
 ) -> tuple[torch.Tensor, tuple[int, ...]]:
     """
     Prepare momentum update and reshape for batched Newton-Schulz.
 
+    Uses Nesterov momentum: update = beta*m_t + (1-beta)*g_t, where m_t is
+    the updated momentum buffer.
+
     Parameters
     ----------
     grad : torch.Tensor
@@ -120,8 +155,6 @@ def _prepare_muon_momentum(
         Momentum buffer (will be updated in-place).
     beta : float
         Momentum coefficient.
-    nesterov : bool
-        Whether to use Nesterov momentum.
 
     Returns
     -------
@@ -132,7 +165,8 @@ def _prepare_muon_momentum(
     """
     # === Step 1. Update momentum buffer ===
     momentum_buffer.lerp_(grad, 1 - beta)
-    update = grad.lerp(momentum_buffer, beta) if nesterov else momentum_buffer
+    # Nesterov lookahead
+    update = grad.lerp(momentum_buffer, beta)
 
     # === Step 2. Handle tensor -> matrix reshape ===
     original_shape = update.shape
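The two `lerp` calls above encode the Nesterov form exactly; a scalar sketch (illustrative values, not from the diff) makes the identities explicit:

```python
# Scalar sketch of the two lerp-based momentum identities above.
# lerp(a, b, w) = a + w * (b - a) = (1 - w) * a + w * b
beta = 0.95
m, g = 2.0, -1.0  # illustrative momentum-buffer and gradient values

# momentum_buffer.lerp_(grad, 1 - beta)  ->  m_t = beta * m + (1 - beta) * g
m_t = m + (1 - beta) * (g - m)
assert abs(m_t - (beta * m + (1 - beta) * g)) < 1e-12

# grad.lerp(momentum_buffer, beta)  ->  update = beta * m_t + (1 - beta) * g
update = g + beta * (m_t - g)
assert abs(update - (beta * m_t + (1 - beta) * g)) < 1e-12
```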
@@ -147,12 +181,24 @@ class MuonOptimizer(Optimizer):
     Muon optimizer with auxiliary Adam for non-matrix parameters.
 
     This optimizer applies different update rules based on parameter dimensionality:
-    - For 2D+ parameters (weight matrices): Muon update with Newton-Schulz orthogonalization
+    - For >=2D parameters (weight matrices): Muon update with Newton-Schulz orthogonalization
     - For 1D parameters (biases, layer norms): Standard Adam update
 
     This hybrid approach is effective because Muon's orthogonalization is designed
     for weight matrices, while Adam is more suitable for biases and normalization params.
 
+    Update Rules
+    ------------
+    Muon (>=2D params):
+        1. Momentum update: m_t = beta*m_{t-1} + (1-beta)*g_t
+        2. Nesterov lookahead: update = beta*m_t + (1-beta)*g_t
+        3. Newton-Schulz orthogonalization: orth = NS(update)
+        4. Scaling: scale = coeff*sqrt(max(m,n)) or sqrt(max(1, m/n))
+        5. Parameter update: theta -= lr * scale * orth
+
+    Adam (1D params):
+        Standard Adam with bias correction, all computations in float32.
+
     Parameters
     ----------
     params : iterable
@@ -163,21 +209,15 @@ class MuonOptimizer(Optimizer):
         Momentum coefficient for Muon with default 0.95.
     weight_decay : float
         Weight decay coefficient (applied only to >=2D params) with default 0.001.
-    ns_steps : int
-        Number of Newton-Schulz iterations with default 5.
     adam_betas : tuple[float, float]
         Adam beta coefficients with default (0.9, 0.95).
-    adam_eps : float
-        Adam epsilon with default 1e-7.
-    nesterov : bool
-        Whether to use Nesterov momentum for Muon with default True.
     lr_adjust : float
-        Learning rate adjustment factor for Adam (1D params).
-        - If lr_adjust <= 0: use match-RMS scaling for Muon update,
+        Learning rate adjustment mode for Muon scaling and Adam learning rate.
+        - If lr_adjust <= 0: use match-RMS scaling for Muon,
           scale = lr_adjust_coeff * sqrt(max(m, n)). Adam uses lr directly.
-        - If lr_adjust > 0: use rectangular correction for Muon update,
-          scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust as learning rate.
-        Default is 10.0 (Adam lr = lr/10).
+        - If lr_adjust > 0: use rectangular correction for Muon,
+          scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust.
+        Default is 0.0 (match-RMS scaling).
     lr_adjust_coeff : float
         Coefficient for match-RMS scaling with default 0.2.
         Only effective when lr_adjust <= 0.
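The two scaling modes documented above can be sketched as a small helper (the function name is illustrative, not part of the diff):

```python
import math

def muon_scale(m: int, n: int, lr_adjust: float, lr_adjust_coeff: float = 0.2) -> float:
    """Illustrative sketch of the two Muon scaling modes described above."""
    if lr_adjust <= 0:
        # match-RMS mode
        return lr_adjust_coeff * math.sqrt(max(m, n))
    # rectangular-correction mode
    return math.sqrt(max(1.0, m / n))

# For a 1024x256 weight:
#   match-RMS:   0.2 * sqrt(1024) = 6.4
#   rectangular: sqrt(1024 / 256) = 2.0
```

The rectangular correction only upscales tall matrices (m > n); wide or square ones keep scale 1.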
@@ -197,21 +237,15 @@ def __init__(
         lr: float = 1e-3,
         momentum: float = 0.95,
         weight_decay: float = 0.001,
-        ns_steps: int = 5,
         adam_betas: tuple[float, float] = (0.9, 0.95),
-        adam_eps: float = 1e-7,
-        nesterov: bool = True,
-        lr_adjust: float = 10.0,
+        lr_adjust: float = 0.0,
         lr_adjust_coeff: float = 0.2,
     ) -> None:
         defaults = {
             "lr": lr,
             "momentum": momentum,
             "weight_decay": weight_decay,
-            "ns_steps": ns_steps,
             "adam_betas": adam_betas,
-            "adam_eps": adam_eps,
-            "nesterov": nesterov,
             "lr_adjust": lr_adjust,
             "lr_adjust_coeff": lr_adjust_coeff,
         }
@@ -232,8 +266,8 @@ def step(
 
         Returns
         -------
-        loss : float, optional
-            The loss value if closure is provided.
+        torch.Tensor | None
+            The loss value if closure is provided, otherwise None.
         """
         loss = None
         if closure is not None:
@@ -244,10 +278,7 @@ def step(
             lr = group["lr"]
             momentum = group["momentum"]
             weight_decay = group["weight_decay"]
-            ns_steps = group["ns_steps"]
             adam_betas = group["adam_betas"]
-            adam_eps = group["adam_eps"]
-            nesterov = group["nesterov"]
             lr_adjust = group["lr_adjust"]
             lr_adjust_coeff = group["lr_adjust_coeff"]
 
@@ -276,7 +307,7 @@ def step(
                     if "momentum_buffer" not in state:
                         state["momentum_buffer"] = torch.zeros_like(grad)
                     update, orig_shape = _prepare_muon_momentum(
-                        grad, state["momentum_buffer"], momentum, nesterov
+                        grad, state["momentum_buffer"], momentum
                     )
                     muon_entries.append((p, update, orig_shape))
                 else:
@@ -315,7 +346,7 @@ def step(
                 bias_corr2 = 1 - state["beta2_pow"]
                 step_size = adam_lr / bias_corr1
                 # FP32 computation: compute full delta in FP32, then cast once
-                denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(adam_eps)
+                denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(ADAM_EPS)
                 delta_fp32 = -step_size * (adam_exp_avgs[i] / denom)
                 p.add_(delta_fp32.to(p.dtype))
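Per element, the auxiliary Adam branch reduces to the standard bias-corrected update. A scalar sketch (a hypothetical helper, with plain floats standing in for the float32 tensors and the incrementally tracked beta-power state):

```python
import math

def adam_step(p, g, exp_avg, exp_avg_sq, step,
              lr=1e-3, beta1=0.9, beta2=0.95, eps=1e-7):
    """Scalar sketch of the bias-corrected Adam update used for 1D params."""
    exp_avg = beta1 * exp_avg + (1 - beta1) * g
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * g * g
    bias_corr1 = 1 - beta1 ** step
    bias_corr2 = 1 - beta2 ** step
    denom = math.sqrt(exp_avg_sq / bias_corr2) + eps
    p = p - (lr / bias_corr1) * (exp_avg / denom)
    return p, exp_avg, exp_avg_sq

# On the first step the bias corrections cancel the (1 - beta) factors,
# so the update magnitude is approximately lr regardless of |g|.
p, m, v = adam_step(0.0, 3.0, 0.0, 0.0, step=1, lr=0.01)
```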
@@ -351,7 +382,7 @@ def step(
                 if len(bucket) == 1:
                     # Single parameter: 2D path with addmm (faster, correct behavior)
                     p, update, orig_shape = bucket[0]
-                    orth = zeropower_via_newtonschulz5(update, steps=ns_steps)
+                    orth = zeropower_via_newtonschulz5(update)
                     # === Apply scaling and update parameters ===
                     orth.mul_(scale)
                     p.add_(orth.reshape(orig_shape), alpha=-lr)
@@ -360,7 +391,7 @@ def step(
                     stacked = torch.stack(
                         [item[1].contiguous() for item in bucket], dim=0
                     )
-                    orth_stacked = zeropower_via_newtonschulz5(stacked, steps=ns_steps)
+                    orth_stacked = zeropower_via_newtonschulz5(stacked)
                     # === Apply scaling and update parameters ===
                     orth_stacked.mul_(scale)
                     for i, (p, _, orig_shape) in enumerate(bucket):
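The shape-bucketing that feeds this branch can be sketched standalone (shape tuples stand in for the momentum-update tensors):

```python
from collections import defaultdict

# Group updates by matrix shape: same-shape entries take the batched
# (stacked 3D) Newton-Schulz path; singletons take the plain 2D path.
updates = [("w1", (128, 64)), ("w2", (128, 64)), ("w3", (256, 64))]

buckets: defaultdict[tuple[int, int], list[str]] = defaultdict(list)
for name, shape in updates:
    buckets[shape].append(name)

batched = {s: names for s, names in buckets.items() if len(names) > 1}
singles = {s: names for s, names in buckets.items() if len(names) == 1}
```

Stacking same-shape matrices lets one batched Newton-Schulz call replace several per-matrix calls, amortizing kernel-launch overhead.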