Commit eb1a5f9

feat: add muon_min_2d_dim parameter for Muon optimizer behavior
(cherry picked from commit c6f7e9f)
1 parent 40b5ed4 commit eb1a5f9

4 files changed: 226 additions & 25 deletions

deepmd/pt/optimizer/muon.py

Lines changed: 186 additions & 20 deletions
@@ -78,9 +78,14 @@ def _maybe_compile(
     fn: callable,
 ) -> callable:
     """Compile a function if torch.compile is available."""
-    if hasattr(torch, "compile"):
-        return torch.compile(fn, fullgraph=True, dynamic=True)
-    return fn
+    if not hasattr(torch, "compile"):
+        return fn
+    # Skip compile if default device is CUDA but CUDA is unavailable.
+    if hasattr(torch, "get_default_device"):
+        default_device = torch.get_default_device()
+        if default_device.type == "cuda" and not torch.cuda.is_available():
+            return fn
+    return torch.compile(fn, fullgraph=True, dynamic=True)


 @_maybe_compile
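
The added guard skips torch.compile when the process-wide default device is CUDA but no CUDA runtime is available. A standalone probe of the same condition (illustrative sketch, not part of the commit):

import torch

# Probe the condition the new guard checks; torch.compile is only attempted
# when the default device is actually usable.
compile_ok = hasattr(torch, "compile")
if compile_ok and hasattr(torch, "get_default_device"):
    default_device = torch.get_default_device()
    compile_ok = not (default_device.type == "cuda" and not torch.cuda.is_available())
print("torch.compile would be used:", compile_ok)
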
@@ -181,13 +186,54 @@ def zeropower_via_newtonschulz5(
         raise ValueError("Input must be 2D or 3D for Newton-Schulz orthogonalization.")


+def should_fallback_to_adam_for_matrix(
+    p: torch.Tensor,
+    min_2d_dim: int,
+) -> bool:
+    """
+    Check if a 2D matrix should fallback to Adam due to small dimensions.
+
+    Parameters
+    ----------
+    p : torch.Tensor
+        Parameter tensor with ndim >= 2.
+    min_2d_dim : int
+        Minimum min(m, n) threshold for Muon. Matrices with min(m, n) >=
+        min_2d_dim use Muon; those with min(m, n) < min_2d_dim use Adam.
+
+    Returns
+    -------
+    bool
+        True if min(m, n) < min_2d_dim, False otherwise.
+
+    Raises
+    ------
+    ValueError
+        If tensor has ndim < 2.
+    """
+    # === Step 1. Validate ===
+    if p.ndim < 2:
+        raise ValueError("Parameter must have ndim >= 2 for Muon suitability check.")
+
+    # === Step 2. Derive matrix shape consistent with Muon reshape ===
+    m = int(p.shape[0])
+    n = int(p.numel() // p.shape[0])
+
+    # === Step 3. Check if any dimension too small for Muon ===
+    return min(m, n) < min_2d_dim
+
+
 class MuonOptimizer(Optimizer):
     """
-    Muon optimizer with auxiliary Adam for non-matrix parameters.
+    Muon optimizer with small-2D Adam fallback and 1D Adam path.

     This optimizer applies different update rules based on parameter dimensionality:
-    - For >=2D parameters (weight matrices): Muon update with Newton-Schulz orthogonalization
-    - For 1D parameters (biases, layer norms): Standard Adam update
+    - For >=2D parameters with min(m, n) >= min_2d_dim:
+      Muon update with Newton-Schulz orthogonalization.
+    - For 2D parameters with min(m, n) < min_2d_dim (small matrices):
+      Adam update with scaled learning rate and update clipping.
+    - For 1D parameters (biases, layer norms):
+      Standard Adam update.

     This hybrid approach is effective because Muon's orthogonalization is designed
     for weight matrices, while Adam is more suitable for biases and normalization params.
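
A quick illustration of how the new predicate routes shapes (a sketch; the import path follows the file header above, and min_2d_dim=8 is an arbitrary illustrative threshold):

import torch

from deepmd.pt.optimizer.muon import should_fallback_to_adam_for_matrix

# With min_2d_dim = 8, thin or tiny matrices fall back to Adam, while matrices
# whose smaller side is at least 8 stay on Muon.
for shape in [(128, 64), (64, 8), (5, 3), (240, 1)]:
    p = torch.zeros(shape)
    route = "Adam" if should_fallback_to_adam_for_matrix(p, min_2d_dim=8) else "Muon"
    print(shape, "->", route)
# (128, 64) -> Muon, (64, 8) -> Muon, (5, 3) -> Adam, (240, 1) -> Adam
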
@@ -224,8 +270,19 @@ class MuonOptimizer(Optimizer):
         scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust.
         Default is 10.0 (Adam lr = lr/10).
     lr_adjust_coeff : float
-        Coefficient for match-RMS scaling with default 0.2.
-        Only effective when lr_adjust <= 0.
+        Dual-purpose coefficient with default 0.2:
+        1. For Muon (when lr_adjust <= 0): match-RMS scaling factor,
+           scale = lr_adjust_coeff * sqrt(max(m, n)).
+        2. For 2D Adam fallback: learning rate multiplier,
+           adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1).
+        The min(., 0.1) cap ensures conservative updates for small matrices.
+    min_2d_dim : int
+        Minimum min(m, n) threshold for Muon on 2D matrices.
+        Matrices with min(m, n) >= min_2d_dim use Muon;
+        those with min(m, n) < min_2d_dim use Adam fallback.
+        Must be >= 1.
+        Set to 1 to disable fallback.
+        Default is 1.

     Examples
     --------
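
A worked example of the effective step sizes these parameters imply (lr is arbitrary; lr_adjust and lr_adjust_coeff are the documented defaults):

# Illustrative values only.
lr, lr_adjust, lr_adjust_coeff = 1e-3, 10.0, 0.2

adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust    # 1e-4: 1D Adam path
adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1)  # 1e-5: small-2D Adam fallback

# Match-RMS Muon scale, only used when lr_adjust <= 0, for an (m, n) = (32, 4) matrix:
m, n = 32, 4
muon_scale = lr_adjust_coeff * (max(m, n) ** 0.5)     # 0.2 * sqrt(32) ≈ 1.13
print(adam_lr, adam_lr_matrix, muon_scale)
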
@@ -245,14 +302,19 @@ def __init__(
         adam_betas: tuple[float, float] = (0.9, 0.95),
         lr_adjust: float = 10.0,
         lr_adjust_coeff: float = 0.2,
+        min_2d_dim: int = 1,
     ) -> None:
+        if min_2d_dim < 1:
+            raise ValueError("min_2d_dim must be >= 1.")
+
         defaults = {
             "lr": lr,
             "momentum": momentum,
             "weight_decay": weight_decay,
             "adam_betas": adam_betas,
             "lr_adjust": lr_adjust,
             "lr_adjust_coeff": lr_adjust_coeff,
+            "min_2d_dim": min_2d_dim,
         }
         super().__init__(params, defaults)
         # Static parameter routing: built once on first step() call.
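
A minimal construction sketch using the new argument (the toy model is hypothetical; keyword names follow the defaults dict above):

import torch

from deepmd.pt.optimizer.muon import MuonOptimizer

# Toy model: the (64, 64) weight routes to Muon, the (3, 64) weight falls back
# to Adam when min_2d_dim=8 (min(3, 64) = 3 < 8), and all biases use the 1D Adam path.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 3))
opt = MuonOptimizer(model.parameters(), lr=1e-3, min_2d_dim=8)

loss = model(torch.randn(16, 64)).square().mean()
loss.backward()
opt.step()
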
@@ -264,33 +326,50 @@ def _build_param_routing(self) -> None:
         Classify parameters into Muon and Adam routes (static routing).

         Routing logic:
-        - >=2D parameters → Muon path (Newton-Schulz + momentum)
-        - 1D parameters → Adam path (standard Adam update)
+        - >=2D parameters with min(m, n) >= min_2d_dim → Muon path
+        - 2D parameters with min(m, n) < min_2d_dim → Adam fallback path
+        - 1D parameters → Adam path
         """
         if self._routing_built:
             return

         self._routing = []
         for group in self.param_groups:
             muon_params: list[dict[str, Any]] = []
-            adam_params: list[dict[str, Any]] = []
+            adam_1d: list[dict[str, Any]] = []
+            adam_matrix: list[dict[str, Any]] = []
+
+            min_2d_dim = group["min_2d_dim"]

             for p in group["params"]:
-                if p.ndim >= 2:
-                    muon_params.append(
+                if p.ndim < 2:
+                    adam_1d.append({"param": p})
+                    continue
+
+                if (p.ndim == 2) and should_fallback_to_adam_for_matrix(
+                    p, min_2d_dim=min_2d_dim
+                ):
+                    adam_matrix.append(
                         {
                             "param": p,
-                            "rows": int(p.shape[0]),
-                            "cols": int(p.numel() // p.shape[0]),
+                            "abs_floor": 1e-3 * math.sqrt(float(p.numel())),
                         }
                     )
-                else:
-                    adam_params.append({"param": p})
+                    continue
+
+                muon_params.append(
+                    {
+                        "param": p,
+                        "rows": int(p.shape[0]),
+                        "cols": int(p.numel() // p.shape[0]),
+                    }
+                )

             self._routing.append(
                 {
                     "muon_params": muon_params,
-                    "adam_params": adam_params,
+                    "adam_1d": adam_1d,
+                    "adam_matrix": adam_matrix,
                 }
             )

@@ -332,13 +411,14 @@ def step(
             lr_adjust_coeff = group["lr_adjust_coeff"]

             # === Step 1. Adam update for 1D parameters (biases, norms, etc.) ===
+            # === Step 1.1. Collect gradients and initialize state ===
             adam_params: list[torch.Tensor] = []
             adam_grads_fp32: list[torch.Tensor] = []
             adam_exp_avgs: list[torch.Tensor] = []
             adam_exp_avg_sqs: list[torch.Tensor] = []
             adam_states: list[dict[str, Any]] = []

-            for entry in route["adam_params"]:
+            for entry in route["adam_1d"]:
                 p = entry["param"]
                 grad = p.grad
                 if grad is None:
@@ -363,6 +443,7 @@ def step(
                 adam_states.append(state)

             if adam_params:
+                # === Step 1.2. Update exp_avg / exp_avg_sq ===
                 adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust

                 # exp_avg = beta1 * exp_avg + (1 - beta1) * grad
@@ -371,6 +452,7 @@ def step(
                 grad_sq = torch._foreach_mul(adam_grads_fp32, adam_grads_fp32)
                 torch._foreach_lerp_(adam_exp_avg_sqs, grad_sq, 1 - adam_betas[1])

+                # === Step 1.3. Bias correction and parameter update ===
                 for i, p in enumerate(adam_params):
                     state = adam_states[i]
                     bias_corr1 = 1 - state["beta1_pow"]
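
The _foreach_lerp_ calls are exactly the Adam moment updates written in lerp form, since lerp(a, b, w) = (1 - w) * a + w * b. A single-tensor check (illustrative values only):

import torch

beta1 = 0.9
exp_avg, grad = torch.randn(4), torch.randn(4)

via_lerp = torch.lerp(exp_avg, grad, 1 - beta1)     # what the foreach call does per tensor
via_formula = beta1 * exp_avg + (1 - beta1) * grad  # textbook EMA form
print(torch.allclose(via_lerp, via_formula))        # True
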
@@ -381,7 +463,87 @@ def step(
                     delta_fp32 = -step_size * (adam_exp_avgs[i] / denom)
                     p.add_(delta_fp32.to(p.dtype))

-            # === Step 2. Muon update for >=2D parameters (weight matrices) ===
+            # === Step 2. Adam update for small 2D matrices (fallback path) ===
+            # === Step 2.1. Collect gradients and initialize state ===
+            adam_matrix_params: list[torch.Tensor] = []
+            adam_matrix_grads_fp32: list[torch.Tensor] = []
+            adam_matrix_exp_avgs: list[torch.Tensor] = []
+            adam_matrix_exp_avg_sqs: list[torch.Tensor] = []
+            adam_matrix_states: list[dict[str, Any]] = []
+            adam_matrix_abs_floor: list[float] = []
+
+            for entry in route["adam_matrix"]:
+                p = entry["param"]
+                grad = p.grad
+                if grad is None:
+                    continue
+
+                grad_fp32 = grad.float()
+
+                state = self.state[p]
+                if "exp_avg" not in state:
+                    state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32)
+                    state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32)
+                    state["beta1_pow"] = 1.0
+                    state["beta2_pow"] = 1.0
+
+                state["beta1_pow"] *= adam_betas[0]
+                state["beta2_pow"] *= adam_betas[1]
+
+                adam_matrix_params.append(p)
+                adam_matrix_grads_fp32.append(grad_fp32)
+                adam_matrix_exp_avgs.append(state["exp_avg"])
+                adam_matrix_exp_avg_sqs.append(state["exp_avg_sq"])
+                adam_matrix_states.append(state)
+                adam_matrix_abs_floor.append(entry["abs_floor"])
+
+            if adam_matrix_params:
+                # === Step 2.2. Update exp_avg / exp_avg_sq with scaled lr ===
+                adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust
+                adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1)
+
+                # exp_avg = beta1 * exp_avg + (1 - beta1) * grad
+                # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
+                torch._foreach_lerp_(
+                    adam_matrix_exp_avgs, adam_matrix_grads_fp32, 1 - adam_betas[0]
+                )
+                grad_sq_m = torch._foreach_mul(
+                    adam_matrix_grads_fp32, adam_matrix_grads_fp32
+                )
+                torch._foreach_lerp_(
+                    adam_matrix_exp_avg_sqs, grad_sq_m, 1 - adam_betas[1]
+                )
+
+                # === Step 2.3. Compute unclipped deltas ===
+                raw_deltas: list[torch.Tensor] = []
+                for i in range(len(adam_matrix_params)):
+                    state = adam_matrix_states[i]
+                    bias_corr1 = 1 - state["beta1_pow"]
+                    bias_corr2 = 1 - state["beta2_pow"]
+                    step_size = adam_lr_matrix / bias_corr1
+                    denom = (
+                        (adam_matrix_exp_avg_sqs[i] / bias_corr2).sqrt().add_(ADAM_EPS)
+                    )
+                    raw_deltas.append(-step_size * (adam_matrix_exp_avgs[i] / denom))
+
+                # === Step 2.4. Clip updates by relative norm and apply ===
+                max_rel_change = 0.05
+                p_norms = torch.stack(torch._foreach_norm(adam_matrix_params))
+                delta_norms = torch.stack(torch._foreach_norm(raw_deltas))
+                floors = torch.tensor(
+                    adam_matrix_abs_floor,
+                    device=p_norms.device,
+                    dtype=p_norms.dtype,
+                )
+                max_delta = torch.maximum(max_rel_change * p_norms, floors)
+                scales_tensor = torch.clamp(max_delta / (delta_norms + 1e-12), max=1.0)
+                for i, delta in enumerate(raw_deltas):
+                    delta.mul_(scales_tensor[i])
+
+                torch._foreach_add_(adam_matrix_params, raw_deltas)
+
+            # === Step 3. Muon update for >=2D parameters (weight matrices) ===
+            # === Step 3.1. Collect gradients and initialize momentum ===
             muon_params_for_decay: list[torch.Tensor] = []
             muon_grads: list[torch.Tensor] = []
             muon_momentum_buffers: list[torch.Tensor] = []
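
Step 2.4 caps each fallback update so its norm never exceeds max(5% of the parameter norm, the abs_floor recorded at routing time, i.e. 1e-3 * sqrt(numel)). A single-tensor sketch of the same rule (illustrative tensors only):

import math

import torch

p = torch.randn(5, 3)
raw_delta = 0.5 * torch.randn(5, 3)             # deliberately oversized update
abs_floor = 1e-3 * math.sqrt(float(p.numel()))  # roughly the norm of an all-1e-3 tensor

max_delta = max(0.05 * p.norm().item(), abs_floor)
scale = min(max_delta / (raw_delta.norm().item() + 1e-12), 1.0)
clipped = raw_delta * scale
print(f"{raw_delta.norm():.3f} -> {clipped.norm():.3f} (cap {max_delta:.3f})")
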
@@ -406,19 +568,22 @@ def step(
                 muon_momentum_buffers.append(buf)
                 active_entries.append((entry, grad))

+            # === Step 3.2. Apply weight decay (Muon path only) ===
             if weight_decay > 0 and muon_params_for_decay:
                 torch._foreach_mul_(muon_params_for_decay, 1.0 - lr * weight_decay)

             if not active_entries:
                 continue

+            # === Step 3.3. Momentum update (Nesterov) ===
             # m_t = beta * m_{t-1} + (1 - beta) * g_t
             torch._foreach_lerp_(muon_momentum_buffers, muon_grads, 1 - momentum)
             # update = beta * m_t + (1 - beta) * g_t
             muon_updates = torch._foreach_lerp(
                 muon_grads, muon_momentum_buffers, momentum
             )

+            # === Step 3.4. Bucket by shape/device/dtype for batched NS ===
             buckets: dict[
                 tuple[int, int, torch.device, torch.dtype],
                 list[tuple[dict[str, Any], torch.Tensor]],
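
The two lerps in step 3.3 implement the Nesterov-style blend described by the inline comments: first m_t = beta * m_{t-1} + (1 - beta) * g_t in place, then update = beta * m_t + (1 - beta) * g_t. A single-tensor check (illustrative values only):

import torch

momentum = 0.95
buf, grad = torch.randn(4), torch.randn(4)

buf_new = torch.lerp(buf, grad, 1 - momentum)  # m_t = beta * m_{t-1} + (1 - beta) * g_t
update = torch.lerp(grad, buf_new, momentum)   # beta * m_t + (1 - beta) * g_t
print(torch.allclose(update, momentum * buf_new + (1 - momentum) * grad))  # True
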
@@ -432,6 +597,7 @@ def step(
                     buckets[bucket_key] = []
                 buckets[bucket_key].append((entry, muon_updates[idx]))

+            # === Step 3.5. Newton-Schulz orthogonalization and update ===
             for (rows, cols, _device, dtype), bucket_entries in buckets.items():
                 # scale = coeff * sqrt(max(m, n)) [match-RMS mode]
                 # scale = sqrt(max(1, m/n)) [rectangular mode]
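
For concreteness, the two scale modes named in the comments differ like this for an illustrative (m, n) = (128, 32) matrix with the default coefficient 0.2:

import math

m, n, lr_adjust_coeff = 128, 32, 0.2

match_rms_scale = lr_adjust_coeff * math.sqrt(max(m, n))  # 0.2 * sqrt(128) ≈ 2.26 (lr_adjust <= 0)
rectangular_scale = math.sqrt(max(1.0, m / n))            # sqrt(4) = 2.0 (default mode)
print(match_rms_scale, rectangular_scale)
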

deepmd/pt/train/training.py

Lines changed: 4 additions & 5 deletions
@@ -173,6 +173,7 @@ def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
         "adam_beta2": params.get("adam_beta2", 0.95),
         "lr_adjust": params.get("lr_adjust", 10.0),
         "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2),
+        "min_2d_dim": params.get("min_2d_dim", 1),
     }
     return opt_type, opt_param

@@ -652,8 +653,7 @@ def single_model_finetune(
         missing, unexpected = self.model.load_state_dict(state, strict=False)
         if missing or unexpected:
             log.warning(
-                "Checkpoint loaded non-strictly. "
-                f"Missing keys: {missing}, Unexpected keys: {unexpected}"
+                f"Checkpoint loaded non-strictly. Missing keys: {missing}, Unexpected keys: {unexpected}"
             )

         # Get model prob for multi-task
@@ -758,6 +758,7 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 ),
                 lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)),
                 lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)),
+                min_2d_dim=int(self.opt_param.get("min_2d_dim", 1)),
             )
             if optimizer_state_dict is not None and self.restart_training:
                 self.optimizer.load_state_dict(optimizer_state_dict)
@@ -1577,8 +1578,6 @@ def model_change_out_bias(

         model_type_map = _model.get_type_map()
         log.info(
-            f"Change output bias of {model_type_map!s} "
-            f"from {to_numpy_array(old_bias).reshape(-1)!s} "
-            f"to {to_numpy_array(new_bias).reshape(-1)!s}."
+            f"Change output bias of {model_type_map!s} from {to_numpy_array(old_bias).reshape(-1)!s} to {to_numpy_array(new_bias).reshape(-1)!s}."
         )
         return _model

deepmd/utils/argcheck.py

Lines changed: 12 additions & 0 deletions
@@ -3471,6 +3471,18 @@ def training_args(
                 doc=doc_only_pt_supported
                 + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.",
             ),
+            Argument(
+                "min_2d_dim",
+                int,
+                optional=True,
+                default=1,
+                alias=["muon_min_2d_dim"],
+                doc=doc_only_pt_supported
+                + "Minimum min(m, n) threshold for Muon on 2D matrices. "
+                "Matrices with min(m, n) >= min_2d_dim use Muon; "
+                "those with min(m, n) < min_2d_dim use Adam fallback. "
+                "Set to 1 to disable fallback.",
+            ),
         ],
         [],
         optional=True,
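
For reference, an optimizer-parameter block that exercises the new argument might look like the sketch below (hypothetical except for the key names and defaults shown in get_opt_param and the argcheck entry above; muon_min_2d_dim is accepted as an alias of min_2d_dim):

# Hypothetical opt_param block as consumed by get_opt_param() in deepmd/pt/train/training.py.
opt_params = {
    "opt_type": "Muon",
    "adam_beta1": 0.9,
    "adam_beta2": 0.95,
    "lr_adjust": 10.0,
    "lr_adjust_coeff": 0.2,
    "min_2d_dim": 8,  # route 2D matrices with min(m, n) < 8 to the Adam fallback
}
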
