Skip to content

Commit c13ca6c

Browse files
committed
feat: add Magma-lite damping for Muon path; fix AdamW decay lr
- Implement block-wise momentum-gradient alignment with EMA smoothing and soft scaling [0.1, 1.0] on Muon updates (magma_muon option)
- Fix AdamW weight decay to use adam_lr instead of base lr
- Wire magma_muon through training config and argcheck
- Clean up redundant optimizer tests
1 parent 2ab1d28 commit c13ca6c

File tree

4 files changed

+400
-69
lines changed

4 files changed

+400
-69
lines changed

deepmd/pt/optimizer/hybrid_muon.py

Lines changed: 294 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@
7171
https://github.com/MoonshotAI/Moonlight
7272
.. [4] Flash-Muon: Triton-accelerated symmetric matmul for Newton-Schulz.
7373
https://github.com/lintianyang/flash-muon (MIT License, Tianyang Lin)
74+
.. [5] Magma: Momentum-Aligned Gradient Masking for Stable Optimizer Updates.
75+
arXiv:2602.15322, 2026.
76+
https://arxiv.org/abs/2602.15322
77+
Implements block-wise momentum-gradient alignment scoring with EMA smoothing
78+
and soft scaling for improved stability under heavy-tailed gradient noise.
79+
HybridMuon uses a stabilized variant (Magma-lite) with sigmoid range stretching
80+
and continuous soft scaling [0.1, 1.0] instead of Bernoulli masking, optimized
81+
for MLIP force-field training.
7482
"""
7583

7684
from __future__ import (
@@ -122,6 +130,13 @@
122130
# Below this threshold, triton kernel launch overhead dominates over compute,
123131
# and cuBLAS (via torch.mm/addmm) is faster for small matrices.
124132
FLASH_MIN_DIM: int = 1024
133+
# Magma-lite constants (Muon path update damping only)
134+
MAGMA_TAU: float = 2.0
135+
MAGMA_EMA_DECAY: float = 0.9
136+
MAGMA_MIN_SCALE: float = 0.1
137+
MAGMA_EPS: float = 1e-12
138+
MAGMA_SIGMOID_MIN: float = 1.0 / (1.0 + math.exp(1.0 / MAGMA_TAU))
139+
MAGMA_SIGMOID_MAX: float = 1.0 / (1.0 + math.exp(-1.0 / MAGMA_TAU))
125140

126141

127142
# ============================================================================
@@ -554,6 +569,11 @@ class HybridMuonOptimizer(Optimizer):
554569
Requires triton and CUDA. Falls back to PyTorch implementation
555570
when triton is unavailable or running on CPU.
556571
Default is True.
572+
magma_muon : bool
573+
Whether to enable Magma-lite damping on Muon updates. Default is False.
574+
This computes momentum-gradient cosine alignment per Muon block,
575+
applies EMA smoothing, and rescales Muon updates in [0.1, 1.0].
576+
Adam/AdamW paths are unchanged.
557577
558578
Examples
559579
--------
@@ -576,6 +596,7 @@ def __init__(
576596
muon_mode: str = "slice",
577597
named_parameters: Iterable[tuple[str, torch.Tensor]] | None = None,
578598
flash_muon: bool = True,
599+
magma_muon: bool = False,
579600
) -> None:
580601
# === Step 1. Validate routing mode ===
581602
muon_mode = str(muon_mode).lower()
@@ -591,6 +612,7 @@ def __init__(
591612
"lr_adjust": lr_adjust,
592613
"lr_adjust_coeff": lr_adjust_coeff,
593614
"muon_mode": muon_mode,
615+
"magma_muon": bool(magma_muon),
594616
}
595617
super().__init__(params, defaults)
596618

@@ -612,6 +634,226 @@ def __init__(
612634
tuple[torch.Tensor, torch.Tensor],
613635
] = {}
614636

637+
def _compute_magma_scale(
    self,
    param: torch.Tensor,
    grad: torch.Tensor,
    momentum_buffer: torch.Tensor,
    batch_size: int,
    rows: int,
    cols: int,
) -> torch.Tensor:
    """
    Compute Magma-lite damping scales for one Muon-routed parameter.

    Measures the block-wise cosine alignment between the Muon momentum
    buffer and the current gradient, maps it through a range-stretched
    sigmoid, smooths the result with an EMA kept in ``self.state[param]``
    under the ``"magma_score"`` key, and converts the smoothed score into
    a damping factor in ``[MAGMA_MIN_SCALE, 1.0]``.

    Notes
    -----
    For each Muon block b the pipeline is:

    1. ``cos(b) = <mu_t^(b), g_t^(b)> / (||mu_t^(b)|| * ||g_t^(b)||)``
    2. ``s_raw^(b) = (sigmoid(cos(b) / tau) - s_min) / (s_max - s_min)``
       with ``tau = MAGMA_TAU``; the stretch maps the narrow sigmoid
       range back onto [0, 1].
    3. EMA smoothing with decay ``MAGMA_EMA_DECAY``.
    4. Affine map into ``[MAGMA_MIN_SCALE, 1.0]`` (soft scaling — no
       Bernoulli masking, and a guaranteed minimum learning rate).

    Parameters
    ----------
    param : torch.Tensor
        Parameter updated by Muon; used as the key into ``self.state``.
    grad : torch.Tensor
        Current gradient, reshapeable to ``(batch_size, rows, cols)``.
    momentum_buffer : torch.Tensor
        Updated Muon momentum buffer with the same shape as ``grad``.
    batch_size : int
        Number of Muon blocks (1 for 2d/flat mode, >1 for slice mode).
    rows : int
        Matrix row count per block.
    cols : int
        Matrix column count per block.

    Returns
    -------
    torch.Tensor
        Damping scales of shape ``(batch_size,)`` in
        ``[MAGMA_MIN_SCALE, 1.0]``.
    """
    # --- Restore (or reinitialize) the per-parameter EMA score ---
    param_state = self.state[param]
    score = param_state.get("magma_score")
    needs_reset = (
        score is None
        or score.ndim != 1
        or score.numel() != batch_size
        or score.device != param.device
    )
    if needs_reset:
        # A neutral 0.5 score maps to the midpoint of the scale range.
        score = torch.full(
            (batch_size,),
            0.5,
            dtype=torch.float32,
            device=param.device,
        )
    else:
        score = score.to(dtype=torch.float32, device=param.device)

    # --- Flatten each Muon block and promote to fp32 for the cosine ---
    g_flat = grad.reshape(batch_size, rows, cols).reshape(batch_size, -1)
    m_flat = momentum_buffer.reshape(batch_size, rows, cols).reshape(
        batch_size, -1
    )
    g_flat = g_flat.to(dtype=torch.float32)
    m_flat = m_flat.to(dtype=torch.float32)

    # --- Cosine alignment, guarded against zero norms ---
    inner = (m_flat * g_flat).sum(dim=1)
    norms = (m_flat.norm(dim=1) * g_flat.norm(dim=1)).clamp(min=MAGMA_EPS)
    cosine = (inner / norms).clamp(min=-1.0, max=1.0)

    # --- Sigmoid mapping plus range stretching onto [0, 1] ---
    stretched = (torch.sigmoid(cosine / MAGMA_TAU) - MAGMA_SIGMOID_MIN) / (
        MAGMA_SIGMOID_MAX - MAGMA_SIGMOID_MIN
    )
    stretched = stretched.clamp(min=0.0, max=1.0)

    # --- EMA smoothing, persist state, map into the damping range ---
    score = MAGMA_EMA_DECAY * score + (1.0 - MAGMA_EMA_DECAY) * stretched
    param_state["magma_score"] = score
    return MAGMA_MIN_SCALE + (1.0 - MAGMA_MIN_SCALE) * score
759+
def _compute_magma_scales_for_bucket(
    self,
    bucket_entries: list[
        tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]
    ],
    batch_size: int,
    rows: int,
    cols: int,
) -> list[torch.Tensor]:
    """
    Compute Magma-lite damping scales for one Muon bucket in a batched way.

    Stacks the flattened block views of every entry in the bucket so the
    cosine/sigmoid stage runs as a single batched computation, then updates
    each parameter's ``"magma_score"`` EMA state individually.

    Parameters
    ----------
    bucket_entries : list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]]
        Bucket entries as ``(entry, update_tensor, grad, momentum_buffer)``.
    batch_size : int
        Number of Muon blocks per parameter in this bucket.
    rows : int
        Matrix row count for this bucket.
    cols : int
        Matrix column count for this bucket.

    Returns
    -------
    list[torch.Tensor]
        One scale tensor of shape ``(batch_size,)`` per bucket entry,
        each in ``[MAGMA_MIN_SCALE, 1.0]``.
    """
    # Single-entry buckets delegate to the scalar helper.
    if len(bucket_entries) == 1:
        entry, _update, grad, momentum_buffer = bucket_entries[0]
        scale = self._compute_magma_scale(
            param=entry["param"],
            grad=grad,
            momentum_buffer=momentum_buffer,
            batch_size=batch_size,
            rows=rows,
            cols=cols,
        )
        return [scale]

    # --- Stack fp32 flattened views: (n_entries, batch_size, rows*cols) ---
    flat_grads: list[torch.Tensor] = []
    flat_moms: list[torch.Tensor] = []
    for _, _, grad, momentum_buffer in bucket_entries:
        g = grad.reshape(batch_size, rows, cols).reshape(batch_size, -1)
        m = momentum_buffer.reshape(batch_size, rows, cols).reshape(
            batch_size, -1
        )
        flat_grads.append(g.to(dtype=torch.float32))
        flat_moms.append(m.to(dtype=torch.float32))

    grad_batch = torch.stack(flat_grads, dim=0)
    momentum_batch = torch.stack(flat_moms, dim=0)

    # --- Batched cosine alignment with zero-norm protection ---
    inner = (momentum_batch * grad_batch).sum(dim=2)
    norms = (momentum_batch.norm(dim=2) * grad_batch.norm(dim=2)).clamp(
        min=MAGMA_EPS
    )
    cosine = (inner / norms).clamp(min=-1.0, max=1.0)

    # --- Range-stretched sigmoid score in [0, 1] ---
    fresh_scores = (torch.sigmoid(cosine / MAGMA_TAU) - MAGMA_SIGMOID_MIN) / (
        MAGMA_SIGMOID_MAX - MAGMA_SIGMOID_MIN
    )
    fresh_scores = fresh_scores.clamp(min=0.0, max=1.0)

    # --- Per-parameter EMA update, then affine map to damping scales ---
    scales: list[torch.Tensor] = []
    for idx, (entry, _, _, _) in enumerate(bucket_entries):
        param = entry["param"]
        param_state = self.state[param]
        score = param_state.get("magma_score")
        if (
            score is None
            or score.ndim != 1
            or score.numel() != batch_size
            or score.device != param.device
        ):
            # Fresh (or invalidated) state: neutral starting score.
            score = torch.full(
                (batch_size,),
                0.5,
                dtype=torch.float32,
                device=param.device,
            )
            param_state["magma_score"] = score
        elif score.dtype != torch.float32:
            # Device already matched above; only the dtype needs fixing.
            score = score.to(dtype=torch.float32, device=param.device)
            param_state["magma_score"] = score

        # In-place EMA keeps the stored state tensor current.
        score.mul_(MAGMA_EMA_DECAY).add_(
            fresh_scores[idx], alpha=(1.0 - MAGMA_EMA_DECAY)
        )
        scales.append(MAGMA_MIN_SCALE + (1.0 - MAGMA_MIN_SCALE) * score)

    return scales
856+
615857
def _get_ns_buffers(
616858
self,
617859
M: int,
@@ -742,6 +984,7 @@ def step(
742984
adam_betas = group["adam_betas"]
743985
lr_adjust = group["lr_adjust"]
744986
lr_adjust_coeff = group["lr_adjust_coeff"]
987+
magma_muon = bool(group.get("magma_muon", False))
745988

746989
# === Step 1. Adam update for non-decay Adam path ===
747990
# === Step 1.1. Collect gradients and initialize state ===
@@ -836,7 +1079,7 @@ def step(
8361079
# AdamW decay for >=2D Adam path.
8371080
if weight_decay > 0:
8381081
for p in adam_decay_params:
839-
p.mul_(1.0 - lr * weight_decay)
1082+
p.mul_(1.0 - adam_lr * weight_decay)
8401083

8411084
# exp_avg = beta1 * exp_avg + (1 - beta1) * grad
8421085
# exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
@@ -904,7 +1147,7 @@ def step(
9041147
# === Step 3.4. Bucket by (batch_size, rows, cols, device, dtype) ===
9051148
buckets: dict[
9061149
tuple[int, int, int, torch.device, torch.dtype],
907-
list[tuple[dict[str, Any], torch.Tensor]],
1150+
list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]],
9081151
] = {}
9091152

9101153
for idx, entry_info in enumerate(active_entries):
@@ -919,7 +1162,14 @@ def step(
9191162
)
9201163
if bucket_key not in buckets:
9211164
buckets[bucket_key] = []
922-
buckets[bucket_key].append((entry, muon_updates[idx]))
1165+
buckets[bucket_key].append(
1166+
(
1167+
entry,
1168+
muon_updates[idx],
1169+
muon_grads[idx],
1170+
muon_momentum_buffers[idx],
1171+
)
1172+
)
9231173

9241174
# === Step 3.5. Newton-Schulz orthogonalization and update ===
9251175
for (batch_size, rows, cols, _device, _), bucket_entries in buckets.items():
@@ -944,24 +1194,57 @@ def step(
9441194
if use_flash:
9451195
buf1, buf2 = self._get_ns_buffers(M, _device)
9461196

1197+
if magma_muon:
1198+
bucket_magma_scales = self._compute_magma_scales_for_bucket(
1199+
bucket_entries=bucket_entries,
1200+
batch_size=batch_size,
1201+
rows=rows,
1202+
cols=cols,
1203+
)
1204+
else:
1205+
bucket_magma_scales = [None] * len(bucket_entries)
1206+
9471207
# Process each entry individually with Newton-Schulz orth.
9481208
# Compatible with sharding propagation under FSDP2.
949-
for entry, update_tensor in bucket_entries:
1209+
for (entry, update_tensor, _grad, _buffer), magma_scale in zip(
1210+
bucket_entries, bucket_magma_scales, strict=True
1211+
):
9501212
if batch_size > 1:
951-
update_batch = update_tensor.reshape(batch_size, rows, cols)
952-
if not update_batch.is_contiguous():
953-
update_batch = update_batch.contiguous()
1213+
if update_tensor.is_contiguous():
1214+
update_batch = update_tensor.view(batch_size, rows, cols)
1215+
else:
1216+
update_batch = update_tensor.reshape(
1217+
batch_size, rows, cols
1218+
).contiguous()
9541219
orth = _batched_newton_schulz_orth(update_batch)
9551220
else:
956-
update_matrix = update_tensor.reshape(rows, cols)
957-
if not update_matrix.is_contiguous():
958-
update_matrix = update_matrix.contiguous()
1221+
if update_tensor.is_contiguous():
1222+
update_matrix = update_tensor.view(rows, cols)
1223+
else:
1224+
update_matrix = update_tensor.reshape(
1225+
rows, cols
1226+
).contiguous()
9591227
if use_flash:
9601228
orth = _flash_newton_schulz_orth(update_matrix, buf1, buf2)
9611229
else:
9621230
orth = _newton_schulz_orth(update_matrix)
9631231
orth.mul_(scale)
964-
delta = orth.reshape(entry["param"].shape)
1232+
if batch_size > 1:
1233+
orth_view = orth.reshape(batch_size, rows, cols)
1234+
if magma_scale is not None:
1235+
orth_view.mul_(
1236+
magma_scale.view(batch_size, 1, 1).to(
1237+
dtype=orth.dtype,
1238+
device=orth.device,
1239+
)
1240+
)
1241+
delta = orth_view.reshape(entry["param"].shape)
1242+
else:
1243+
if magma_scale is not None:
1244+
orth.mul_(
1245+
magma_scale[0].to(dtype=orth.dtype, device=orth.device)
1246+
)
1247+
delta = orth.reshape(entry["param"].shape)
9651248
entry["param"].add_(delta, alpha=-lr)
9661249

9671250
return loss

deepmd/pt/train/training.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,7 @@ def single_model_finetune(
818818
"muon_mode": str(self.opt_param.get("muon_mode", "slice")),
819819
"named_parameters": tuple(self.wrapper.named_parameters()),
820820
"flash_muon": bool(self.opt_param.get("flash_muon", True)),
821+
"magma_muon": bool(self.opt_param.get("magma_muon", False)),
821822
}
822823
else:
823824
raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

0 commit comments

Comments
 (0)