Skip to content

Commit 586ca17

Browse files
committed
fix(pt): compatible with AdaMuon
1 parent eb1a5f9 commit 586ca17

2 files changed

Lines changed: 63 additions & 0 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -745,6 +745,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
745745
float(self.opt_param.get("adam_beta1", 0.9)),
746746
float(self.opt_param.get("adam_beta2", 0.95)),
747747
),
748+
lr_adjust=float(self.opt_param.get("lr_adjust", 10.0)),
749+
lr_adjust_coeff=float(self.opt_param.get("lr_adjust_coeff", 0.2)),
748750
)
749751
elif self.opt_type == "Muon":
750752
self.optimizer = MuonOptimizer(

deepmd/utils/argcheck.py

Lines changed: 61 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3471,6 +3471,67 @@ def training_args(
34713471
doc=doc_only_pt_supported
34723472
+ "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.",
34733473
),
3474+
],
3475+
[],
3476+
optional=True,
3477+
),
3478+
Argument(
3479+
"Muon",
3480+
dict,
3481+
[
3482+
Argument(
3483+
"momentum",
3484+
float,
3485+
optional=True,
3486+
default=0.95,
3487+
alias=["muon_momentum"],
3488+
doc=doc_only_pt_supported
3489+
+ "Momentum coefficient for Muon optimizer (>=2D params). "
3490+
"Used in Nesterov momentum update: m_t = beta*m_{t-1} + (1-beta)*g_t.",
3491+
),
3492+
Argument(
3493+
"adam_beta1",
3494+
float,
3495+
optional=True,
3496+
default=0.9,
3497+
doc=doc_only_pt_supported
3498+
+ "Adam beta1 coefficient for 1D parameters (biases, norms).",
3499+
),
3500+
Argument(
3501+
"adam_beta2",
3502+
float,
3503+
optional=True,
3504+
default=0.95,
3505+
doc=doc_only_pt_supported
3506+
+ "Adam beta2 coefficient for 1D parameters (biases, norms).",
3507+
),
3508+
Argument(
3509+
"weight_decay",
3510+
float,
3511+
optional=True,
3512+
default=0.001,
3513+
doc=doc_only_pt_supported
3514+
+ "Weight decay coefficient. Applied only to >=2D parameters (Muon path).",
3515+
),
3516+
Argument(
3517+
"lr_adjust",
3518+
float,
3519+
optional=True,
3520+
default=10.0,
3521+
doc=doc_only_pt_supported
3522+
+ "Learning rate adjustment mode for Muon scaling and Adam learning rate. "
3523+
"If lr_adjust <= 0: use match-RMS scaling (scale = coeff*sqrt(max(m,n))), Adam uses lr directly. "
3524+
"If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1, m/n))), Adam uses lr/lr_adjust. "
3525+
"Default is 10.0 (Adam lr = lr/10).",
3526+
),
3527+
Argument(
3528+
"lr_adjust_coeff",
3529+
float,
3530+
optional=True,
3531+
default=0.2,
3532+
doc=doc_only_pt_supported
3533+
+ "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0.",
3534+
),
34743535
Argument(
34753536
"min_2d_dim",
34763537
int,

0 commit comments

Comments
 (0)