Skip to content

Commit a2c52c6

Browse files
committed
fixup
1 parent d29ee0c commit a2c52c6

File tree

2 files changed

+15
-13
lines changed

2 files changed

+15
-13
lines changed

deepmd/pt/optimizer/hybrid_muon.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@
7272
.. [4] Flash-Muon: Triton-accelerated symmetric matmul for Newton-Schulz.
7373
https://github.com/lintianyang/flash-muon (MIT License, Tianyang Lin)
7474
.. [5] Magma: Momentum-Aligned Gradient Masking for Stable Optimizer Updates.
75-
arXiv:2602.15322, 2025.
75+
arXiv:2602.15322, 2026.
7676
https://arxiv.org/abs/2602.15322
7777
Implements block-wise momentum-gradient alignment scoring with EMA smoothing
7878
and soft scaling for improved stability under heavy-tailed gradient noise.
@@ -340,9 +340,7 @@ def _batched_newton_schulz_orth(
340340
"""
341341
# === Step 1. Validate and prepare matrix orientation ===
342342
if G.ndim != 3:
343-
raise ValueError(
344-
"Batched Newton-Schulz expects a 3D tensor with shape (B, m, n)."
345-
)
343+
raise ValueError("Batched Newton-Schulz expects a 3D tensor (B, m, n).")
346344

347345
X = G.to(dtype=torch.bfloat16)
348346
transposed = X.size(-2) > X.size(-1)
@@ -474,9 +472,7 @@ def get_matrix_view_shape(
474472
rows = int(effective_shape[-2])
475473
cols = int(effective_shape[-1])
476474
return (batch_size, rows, cols)
477-
raise ValueError(
478-
f"Unsupported muon_mode '{muon_mode}'. Expected one of ['2d', 'flat', 'slice']."
479-
)
475+
raise ValueError(f"Invalid muon_mode '{muon_mode}'. Use '2d', 'flat', or 'slice'.")
480476

481477

482478
class HybridMuonOptimizer(Optimizer):
@@ -601,7 +597,9 @@ def __init__(
601597
# === Step 1. Validate routing mode ===
602598
muon_mode = str(muon_mode).lower()
603599
if muon_mode not in {"2d", "flat", "slice"}:
604-
raise ValueError("muon_mode must be one of ['2d', 'flat', 'slice'].")
600+
raise ValueError(
601+
f"Invalid muon_mode '{muon_mode}'. Use '2d', 'flat', or 'slice'."
602+
)
605603

606604
# === Step 2. Register optimizer defaults ===
607605
defaults = {
@@ -1024,10 +1022,12 @@ def step(
10241022

10251023
# exp_avg = beta1 * exp_avg + (1 - beta1) * grad
10261024
# exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
1027-
for ea, g in zip(adam_no_decay_exp_avgs, adam_no_decay_grads_fp32):
1025+
for ea, g in zip(
1026+
adam_no_decay_exp_avgs, adam_no_decay_grads_fp32, strict=True
1027+
):
10281028
ea.lerp_(g, 1 - adam_betas[0])
10291029
grad_sq = [g * g for g in adam_no_decay_grads_fp32]
1030-
for eas, gsq in zip(adam_no_decay_exp_avg_sqs, grad_sq):
1030+
for eas, gsq in zip(adam_no_decay_exp_avg_sqs, grad_sq, strict=True):
10311031
eas.lerp_(gsq, 1 - adam_betas[1])
10321032

10331033
# === Step 1.3. Bias correction and parameter update ===
@@ -1083,10 +1083,12 @@ def step(
10831083

10841084
# exp_avg = beta1 * exp_avg + (1 - beta1) * grad
10851085
# exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
1086-
for ea, g in zip(adam_decay_exp_avgs, adam_decay_grads_fp32):
1086+
for ea, g in zip(
1087+
adam_decay_exp_avgs, adam_decay_grads_fp32, strict=True
1088+
):
10871089
ea.lerp_(g, 1 - adam_betas[0])
10881090
grad_sq = [g * g for g in adam_decay_grads_fp32]
1089-
for eas, gsq in zip(adam_decay_exp_avg_sqs, grad_sq):
1091+
for eas, gsq in zip(adam_decay_exp_avg_sqs, grad_sq, strict=True):
10901092
eas.lerp_(gsq, 1 - adam_betas[1])
10911093

10921094
# === Step 2.3. Bias correction and parameter update ===

source/tests/pt/test_hybrid_muon.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
357357

358358
self.assertFalse(torch.allclose(model1.weight, model2.weight))
359359
self.assertTrue(
360-
torch.allclose(model1.adam_scale, model2.adam_scale, atol=0.0, rtol=0.0)
360+
torch.allclose(model1.adam_scale, model2.adam_scale, atol=1e-7, rtol=1e-7)
361361
)
362362

363363

0 commit comments

Comments
 (0)