Skip to content

Commit 3046a13

Browse files
committed
feat(pt): Update AdaMuon
(cherry picked from commit ea5ac54)
1 parent 98cec8b commit 3046a13

2 files changed

Lines changed: 5 additions & 5 deletions

File tree

deepmd/pt/optimizer/adamuon.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@
4040
def zeropower_via_newtonschulz5(
4141
G: torch.Tensor,
4242
steps: int = 5,
43-
eps: float = 1e-7,
43+
eps: float = 1e-8,
4444
) -> torch.Tensor:
4545
"""
4646
Compute the zeroth power (orthogonalization) of a matrix via Newton-Schulz iteration.
@@ -58,7 +58,7 @@ def zeropower_via_newtonschulz5(
5858
steps : int
5959
Number of Newton-Schulz iterations with default 5.
6060
eps : float
61-
Numerical stability epsilon for norm clamping with default 1e-7.
61+
Numerical stability epsilon for norm clamping with default 1e-8.
6262
6363
Returns
6464
-------
@@ -177,7 +177,7 @@ class AdaMuonOptimizer(Optimizer):
177177
adam_betas : tuple[float, float]
178178
Adam beta coefficients with default (0.9, 0.95).
179179
adam_eps : float
180-
Adam epsilon with default 1e-7.
180+
Adam epsilon with default 1e-8.
181181
nesterov : bool
182182
Whether to use Nesterov momentum for AdaMuon with default True.
183183
lr_adjust : float
@@ -210,7 +210,7 @@ def __init__(
210210
weight_decay: float = 0.001,
211211
ns_steps: int = 5,
212212
adam_betas: tuple[float, float] = (0.9, 0.95),
213-
adam_eps: float = 1e-7,
213+
adam_eps: float = 1e-8,
214214
nesterov: bool = True,
215215
lr_adjust: float = 10.0,
216216
lr_adjust_coeff: float = 0.2,

deepmd/pt/train/training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
711711
self.wrapper.parameters(),
712712
lr=self.lr_exp.start_lr,
713713
momentum=float(self.opt_param.get("muon_momentum", 0.95)),
714-
weight_decay=float(self.opt_param.get("weight_decay", 0.0)),
714+
weight_decay=float(self.opt_param.get("weight_decay", 0.001)),
715715
adam_betas=(
716716
float(self.opt_param.get("adam_beta1", 0.9)),
717717
float(self.opt_param.get("adam_beta2", 0.95)),

0 commit comments

Comments
 (0)