
Commit 3ea9e66

doc & test: Muon
(cherry picked from commit 46fcb7d)
1 parent b87bf49 commit 3ea9e66

4 files changed

Lines changed: 147 additions & 145 deletions


deepmd/pt/optimizer/muon.py

Lines changed: 77 additions & 46 deletions
@@ -6,8 +6,35 @@
 before using momentum, resulting in orthogonalized updates for weight matrices.
 This can improve training stability and convergence for certain architectures.
 
-Reference:
-    https://github.com/KellerJordan/Muon
+Algorithm
+---------
+For >=2D parameters (weight matrices), the Muon update is:
+
+1. Momentum update (Nesterov):
+       m_t = beta * m_{t-1} + (1 - beta) * g_t
+       update = beta * m_t + (1 - beta) * g_t
+
+2. Newton-Schulz orthogonalization (quintic iteration):
+       X_0 = G / ||G||_F
+       X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T
+       Coefficients: a=3.4445, b=-4.7750, c=2.0315
+
+3. Scaling: scale = coeff * sqrt(max(m, n))   [match-RMS mode]
+            scale = sqrt(max(1, m/n))         [rectangular mode]
+
+4. Parameter update: theta -= lr * scale * orth(update)
+
+For 1D parameters (biases, norms), standard Adam is used.
+
+Dtype Behavior
+--------------
+- Newton-Schulz iterations: always bfloat16 (matches official Muon)
+- Adam state (exp_avg, exp_avg_sq): always float32 for numerical stability
+- Gradients: cast to parameter dtype before momentum update
+
+Reference
+---------
+https://github.com/KellerJordan/Muon
 """
 
 from __future__ import (
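The four steps in this docstring can be sanity-checked against a plain-PyTorch reference that substitutes the exact SVD polar factor for Newton-Schulz. A minimal sketch, not the committed implementation; the helper name and its match-RMS defaults are illustrative:

    import torch

    def muon_update_reference(param, grad, buf, lr=1e-3, beta=0.95, coeff=0.2):
        # 1. Momentum update: m_t = beta * m_{t-1} + (1 - beta) * g_t
        buf.mul_(beta).add_(grad, alpha=1 - beta)
        # 2. Nesterov lookahead: update = beta * m_t + (1 - beta) * g_t
        update = beta * buf + (1 - beta) * grad
        # 3. Orthogonalize: exact polar factor U @ V^T, which Newton-Schulz approximates
        U, _, Vh = torch.linalg.svd(update, full_matrices=False)
        orth = U @ Vh
        # 4. Match-RMS scaling, then theta -= lr * scale * orth(update)
        scale = coeff * max(update.shape[-2], update.shape[-1]) ** 0.5
        param.add_(orth, alpha=-lr * scale)

    p, g = torch.randn(64, 32), torch.randn(64, 32)
    muon_update_reference(p, g, torch.zeros_like(p))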
@@ -30,29 +57,39 @@
     Iterable,
 )
 
+# ============================================================================
+# Constants
+# ============================================================================
+
+# Newton-Schulz iteration count
+NS_STEPS: int = 5
+# Numerical stability epsilon for norm clamping
+NS_EPS: float = 1e-7
+# Adam epsilon for numerical stability
+ADAM_EPS: float = 1e-7
+
 
 def zeropower_via_newtonschulz5(
     G: torch.Tensor,
-    steps: int = 5,
-    eps: float = 1e-7,
 ) -> torch.Tensor:
     """
     Compute the zeroth power (orthogonalization) of a matrix via Newton-Schulz iteration.
 
     Uses quintic Newton-Schulz iteration to compute the orthogonal component of the
     input matrix. This is equivalent to computing U from the SVD decomposition G = USV^T.
 
+    Mathematical formulation:
+        X_0 = G / ||G||_F
+        X_{k+1} = a*X_k + (b*A_k + c*A_k^2) @ X_k, where A_k = X_k @ X_k^T
+        Coefficients: a=3.4445, b=-4.7750, c=2.0315
+
     This implementation matches PyTorch official Muon behavior: it always performs
     Newton-Schulz in bfloat16 and returns a bfloat16 tensor.
 
     Parameters
     ----------
     G : torch.Tensor
         Input matrix to orthogonalize with shape (..., M, N).
-    steps : int
-        Number of Newton-Schulz iterations with default 5.
-    eps : float
-        Numerical stability epsilon for norm clamping with default 1e-7.
 
     Returns
     -------
@@ -63,14 +100,10 @@ def zeropower_via_newtonschulz5(
     ------
     ValueError
         If G has fewer than 2 dimensions.
-    ValueError
-        If steps >= 100 (guard for efficiency).
     """
     # === Step 1. Validate ===
     if G.ndim < 2:
         raise ValueError("Input must have at least 2 dimensions (..., M, N).")
-    if steps >= 100:
-        raise ValueError("Number of steps must be less than 100 for efficiency.")
 
     a, b, c = (3.4445, -4.7750, 2.0315)
 
@@ -82,10 +115,10 @@
         X = X.mT
 
     # === Step 4. Normalize Frobenius norm to at most 1 ===
-    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=eps)
+    X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=NS_EPS)
 
     # === Step 5. Newton-Schulz iterations with fused GEMM ===
-    for _ in range(steps):
+    for _ in range(NS_STEPS):
         A = X @ X.mT
         # gram_update = b*A + c*(A@A) via addmm/baddbmm
         # X = a*X + gram_update@X via addmm/baddbmm
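To exercise the quintic recurrence outside the optimizer, here is a stripped-down sketch of the same iteration (float32 throughout, no bfloat16 cast, with NS_STEPS/NS_EPS inlined as arguments; the fused addmm/baddbmm form of the committed code is written out as plain matmuls):

    import torch

    def newton_schulz_sketch(G: torch.Tensor, steps: int = 5, eps: float = 1e-7) -> torch.Tensor:
        a, b, c = 3.4445, -4.7750, 2.0315
        X = G.clone()
        transposed = X.shape[-2] > X.shape[-1]
        if transposed:  # iterate on the wide orientation so A = X @ X^T is the smaller Gram matrix
            X = X.mT
        X = X / X.norm(dim=(-2, -1), keepdim=True).clamp(min=eps)  # ||X_0||_F <= 1
        for _ in range(steps):
            A = X @ X.mT
            X = a * X + (b * A + c * (A @ A)) @ X  # quintic update
        return X.mT if transposed else X

    O = newton_schulz_sketch(torch.randn(16, 32))
    # Rows come out near-orthonormal; the tuned coefficients trade exactness for speed
    print(torch.dist(O @ O.mT, torch.eye(16)))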
@@ -107,11 +140,13 @@ def _prepare_muon_momentum(
     grad: torch.Tensor,
     momentum_buffer: torch.Tensor,
     beta: float,
-    nesterov: bool,
 ) -> tuple[torch.Tensor, tuple[int, ...]]:
     """
     Prepare momentum update and reshape for batched Newton-Schulz.
 
+    Uses Nesterov momentum: update = beta*m_t + (1-beta)*g_t, where m_t is
+    the updated momentum buffer.
+
     Parameters
     ----------
     grad : torch.Tensor
@@ -120,8 +155,6 @@ def _prepare_muon_momentum(
         Momentum buffer (will be updated in-place).
     beta : float
        Momentum coefficient.
-    nesterov : bool
-        Whether to use Nesterov momentum.
 
     Returns
     -------
@@ -132,7 +165,8 @@
     """
     # === Step 1. Update momentum buffer ===
     momentum_buffer.lerp_(grad, 1 - beta)
-    update = grad.lerp(momentum_buffer, beta) if nesterov else momentum_buffer
+    # Nesterov lookahead
+    update = grad.lerp(momentum_buffer, beta)
 
     # === Step 2. Handle tensor -> matrix reshape ===
     original_shape = update.shape
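Dropping the nesterov flag leans on two torch.lerp identities worth spelling out: buf.lerp_(g, 1 - beta) computes beta*buf + (1-beta)*g, and g.lerp(buf, beta) computes (1-beta)*g + beta*buf. A quick standalone check:

    import torch

    beta = 0.95
    g, buf = torch.randn(8, 8), torch.randn(8, 8)

    m_t = beta * buf + (1 - beta) * g   # momentum update written out
    buf.lerp_(g, 1 - beta)              # in-place form used by the optimizer
    assert torch.allclose(buf, m_t, atol=1e-6)

    update = g.lerp(buf, beta)          # Nesterov lookahead
    assert torch.allclose(update, beta * m_t + (1 - beta) * g, atol=1e-6)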
@@ -147,12 +181,24 @@ class MuonOptimizer(Optimizer):
     Muon optimizer with auxiliary Adam for non-matrix parameters.
 
     This optimizer applies different update rules based on parameter dimensionality:
-    - For 2D+ parameters (weight matrices): Muon update with Newton-Schulz orthogonalization
+    - For >=2D parameters (weight matrices): Muon update with Newton-Schulz orthogonalization
     - For 1D parameters (biases, layer norms): Standard Adam update
 
     This hybrid approach is effective because Muon's orthogonalization is designed
     for weight matrices, while Adam is more suitable for biases and normalization params.
 
+    Update Rules
+    ------------
+    Muon (>=2D params):
+        1. Momentum update: m_t = beta*m_{t-1} + (1-beta)*g_t
+        2. Nesterov lookahead: update = beta*m_t + (1-beta)*g_t
+        3. Newton-Schulz orthogonalization: orth = NS(update)
+        4. Scaling: scale = coeff*sqrt(max(m,n)) or sqrt(max(1, m/n))
+        5. Parameter update: theta -= lr * scale * orth
+
+    Adam (1D params):
+        Standard Adam with bias correction, all computations in float32.
+
     Parameters
     ----------
     params : iterable
@@ -163,21 +209,15 @@ class MuonOptimizer(Optimizer):
         Momentum coefficient for Muon with default 0.95.
     weight_decay : float
         Weight decay coefficient (applied only to >=2D params) with default 0.001.
-    ns_steps : int
-        Number of Newton-Schulz iterations with default 5.
     adam_betas : tuple[float, float]
         Adam beta coefficients with default (0.9, 0.95).
-    adam_eps : float
-        Adam epsilon with default 1e-7.
-    nesterov : bool
-        Whether to use Nesterov momentum for Muon with default True.
     lr_adjust : float
-        Learning rate adjustment factor for Adam (1D params).
-        - If lr_adjust <= 0: use match-RMS scaling for Muon update,
+        Learning rate adjustment mode for Muon scaling and Adam learning rate.
+        - If lr_adjust <= 0: use match-RMS scaling for Muon,
           scale = lr_adjust_coeff * sqrt(max(m, n)). Adam uses lr directly.
-        - If lr_adjust > 0: use rectangular correction for Muon update,
-          scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust as learning rate.
-        Default is 10.0 (Adam lr = lr/10).
+        - If lr_adjust > 0: use rectangular correction for Muon,
+          scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust.
+        Default is 0.0 (match-RMS scaling).
     lr_adjust_coeff : float
         Coefficient for match-RMS scaling with default 0.2.
         Only effective when lr_adjust <= 0.
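The two modes in this docstring amount to a few lines of arithmetic; a sketch of the branch (the helper name muon_scale is illustrative, not from the commit):

    import math

    def muon_scale(m: int, n: int, lr_adjust: float, lr_adjust_coeff: float = 0.2) -> float:
        if lr_adjust <= 0:
            # match-RMS: scale = lr_adjust_coeff * sqrt(max(m, n))
            return lr_adjust_coeff * math.sqrt(max(m, n))
        # rectangular correction: scale = sqrt(max(1.0, m/n))
        return math.sqrt(max(1.0, m / n))

    print(muon_scale(1024, 256, lr_adjust=0.0))   # match-RMS: 0.2 * 32 = 6.4
    print(muon_scale(1024, 256, lr_adjust=10.0))  # rectangular: sqrt(1024/256) = 2.0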
@@ -197,21 +237,15 @@ def __init__(
         lr: float = 1e-3,
         momentum: float = 0.95,
         weight_decay: float = 0.001,
-        ns_steps: int = 5,
         adam_betas: tuple[float, float] = (0.9, 0.95),
-        adam_eps: float = 1e-7,
-        nesterov: bool = True,
-        lr_adjust: float = 10.0,
+        lr_adjust: float = 0.0,
         lr_adjust_coeff: float = 0.2,
     ) -> None:
         defaults = {
             "lr": lr,
             "momentum": momentum,
             "weight_decay": weight_decay,
-            "ns_steps": ns_steps,
             "adam_betas": adam_betas,
-            "adam_eps": adam_eps,
-            "nesterov": nesterov,
             "lr_adjust": lr_adjust,
             "lr_adjust_coeff": lr_adjust_coeff,
         }
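With ns_steps, adam_eps, and nesterov folded into module constants, the constructor surface shrinks to the knobs above; a usage sketch (assuming the import path of this commit):

    import torch
    from deepmd.pt.optimizer.muon import MuonOptimizer

    model = torch.nn.Linear(32, 16)   # 2D weight -> Muon path, 1D bias -> Adam path
    opt = MuonOptimizer(
        model.parameters(),
        lr=1e-3,
        momentum=0.95,
        weight_decay=0.001,
        adam_betas=(0.9, 0.95),
        lr_adjust=0.0,                # new default: match-RMS scaling
        lr_adjust_coeff=0.2,
    )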
@@ -232,8 +266,8 @@ def step(
 
         Returns
         -------
-        loss : float, optional
-            The loss value if closure is provided.
+        torch.Tensor | None
+            The loss value if closure is provided, otherwise None.
         """
         loss = None
         if closure is not None:
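The corrected return annotation follows the standard torch.optim closure contract; for example (toy data, and the import path of this commit assumed):

    import torch
    from deepmd.pt.optimizer.muon import MuonOptimizer

    model = torch.nn.Linear(32, 16)
    opt = MuonOptimizer(model.parameters(), lr=1e-3)
    x, y = torch.randn(4, 32), torch.randn(4, 16)

    def closure() -> torch.Tensor:
        opt.zero_grad()
        loss = torch.nn.functional.mse_loss(model(x), y)
        loss.backward()
        return loss

    loss = opt.step(closure)  # torch.Tensor from the closure; plain opt.step() returns None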
@@ -244,10 +278,7 @@
             lr = group["lr"]
             momentum = group["momentum"]
             weight_decay = group["weight_decay"]
-            ns_steps = group["ns_steps"]
             adam_betas = group["adam_betas"]
-            adam_eps = group["adam_eps"]
-            nesterov = group["nesterov"]
             lr_adjust = group["lr_adjust"]
             lr_adjust_coeff = group["lr_adjust_coeff"]
 
@@ -276,7 +307,7 @@
                     if "momentum_buffer" not in state:
                         state["momentum_buffer"] = torch.zeros_like(grad)
                     update, orig_shape = _prepare_muon_momentum(
-                        grad, state["momentum_buffer"], momentum, nesterov
+                        grad, state["momentum_buffer"], momentum
                     )
                     muon_entries.append((p, update, orig_shape))
                 else:
@@ -315,7 +346,7 @@
                 bias_corr2 = 1 - state["beta2_pow"]
                 step_size = adam_lr / bias_corr1
                 # FP32 computation: compute full delta in FP32, then cast once
-                denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(adam_eps)
+                denom = (adam_exp_avg_sqs[i] / bias_corr2).sqrt().add_(ADAM_EPS)
                 delta_fp32 = -step_size * (adam_exp_avgs[i] / denom)
                 p.add_(delta_fp32.to(p.dtype))
 
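The Adam branch now reads ADAM_EPS but is otherwise unchanged: the whole bias-corrected delta is formed in float32 and cast to the parameter dtype once. The arithmetic in isolation (a sketch; names mirror the hunk above):

    import torch

    def adam_delta_fp32(exp_avg, exp_avg_sq, beta1_pow, beta2_pow, lr, eps=1e-7):
        bias_corr1 = 1 - beta1_pow                       # 1 - beta1**t
        bias_corr2 = 1 - beta2_pow                       # 1 - beta2**t
        step_size = lr / bias_corr1
        denom = (exp_avg_sq / bias_corr2).sqrt().add_(eps)
        return -step_size * (exp_avg / denom)            # apply as p.add_(delta.to(p.dtype))

    delta = adam_delta_fp32(torch.randn(8), torch.rand(8), 0.9, 0.95, 1e-3)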
@@ -351,7 +382,7 @@
                 if len(bucket) == 1:
                     # Single parameter: 2D path with addmm (faster, correct behavior)
                     p, update, orig_shape = bucket[0]
-                    orth = zeropower_via_newtonschulz5(update, steps=ns_steps)
+                    orth = zeropower_via_newtonschulz5(update)
                     # === Apply scaling and update parameters ===
                     orth.mul_(scale)
                     p.add_(orth.reshape(orig_shape), alpha=-lr)
@@ -360,7 +391,7 @@
                     stacked = torch.stack(
                         [item[1].contiguous() for item in bucket], dim=0
                     )
-                    orth_stacked = zeropower_via_newtonschulz5(stacked, steps=ns_steps)
+                    orth_stacked = zeropower_via_newtonschulz5(stacked)
                     # === Apply scaling and update parameters ===
                     orth_stacked.mul_(scale)
                     for i, (p, _, orig_shape) in enumerate(bucket):
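Same-shape updates are bucketed so a single batched Newton-Schulz call covers them all; the pattern, using the newton_schulz_sketch helper from the earlier sketch as a stand-in for the committed function:

    import torch

    bucket = [torch.randn(16, 32) for _ in range(4)]                  # updates with identical shape
    stacked = torch.stack([u.contiguous() for u in bucket], dim=0)    # (4, 16, 32)
    orth_stacked = newton_schulz_sketch(stacked)                      # one batched NS pass
    per_param = [orth_stacked[i] for i in range(len(bucket))]         # unstack per parameter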

deepmd/pt/train/training.py

Lines changed: 4 additions & 2 deletions
@@ -171,6 +171,8 @@ def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
         "momentum": params.get("momentum", 0.95),
         "adam_beta1": params.get("adam_beta1", 0.9),
         "adam_beta2": params.get("adam_beta2", 0.95),
+        "lr_adjust": params.get("lr_adjust", 0.0),
+        "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2),
     }
     return opt_type, opt_param
 
@@ -748,13 +750,13 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
             self.optimizer = MuonOptimizer(
                 self.wrapper.parameters(),
                 lr=self.lr_exp.start_lr,
-                momentum=float(self.opt_param.get("muon_momentum", 0.95)),
+                momentum=float(self.opt_param.get("momentum", 0.95)),
                 weight_decay=float(self.opt_param.get("weight_decay", 0.001)),
                 adam_betas=(
                     float(self.opt_param.get("adam_beta1", 0.9)),
                     float(self.opt_param.get("adam_beta2", 0.95)),
                 ),
-                lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)),
+                lr_adjust=float(self.opt_param.get("lr_adjust", 0.0)),
                 lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)),
             )
             if optimizer_state_dict is not None and self.restart_training:
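End to end, the renamed key means get_opt_param's output feeds MuonOptimizer without translation; a hedged illustration of the flow (values are the defaults wired above, import path as in this commit):

    import torch
    from deepmd.pt.optimizer.muon import MuonOptimizer

    opt_param = {"momentum": 0.95, "adam_beta1": 0.9, "adam_beta2": 0.95,
                 "lr_adjust": 0.0, "lr_adjust_coeff": 0.2}
    model = torch.nn.Linear(8, 8)
    optimizer = MuonOptimizer(
        model.parameters(),
        lr=1e-3,
        momentum=float(opt_param.get("momentum", 0.95)),
        adam_betas=(
            float(opt_param.get("adam_beta1", 0.9)),
            float(opt_param.get("adam_beta2", 0.95)),
        ),
        lr_adjust=float(opt_param.get("lr_adjust", 0.0)),
        lr_adjust_coeff=float(opt_param.get("lr_adjust_coeff", 0.2)),
    )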

deepmd/utils/argcheck.py

Lines changed: 2 additions & 1 deletion
@@ -3425,6 +3425,7 @@ def training_args(
                 float,
                 optional=True,
                 default=0.95,
+                alias=["muon_momentum"],
                 doc=doc_only_pt_supported
                 + "Momentum coefficient for AdaMuon optimizer.",
             ),
@@ -3456,7 +3457,7 @@
                 "lr_adjust",
                 float,
                 optional=True,
-                default=10.0,
+                default=0.0,
                 doc=doc_only_pt_supported
                 + "Learning rate adjustment factor for Adam (1D params). "
                 "If lr_adjust <= 0: use match-RMS scaling (scale = lr_adjust_coeff * sqrt(max(m, n))), Adam uses lr directly. "
