Skip to content

Commit 9a86678

Browse files
committed
sort args
1 parent 3046a13 commit 9a86678

3 files changed

Lines changed: 7 additions & 9 deletions

File tree

deepmd/pt/optimizer/adamuon.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -160,7 +160,7 @@ class AdaMuonOptimizer(Optimizer):
160160
Key AdaMuon features:
161161
- Sign-stabilized orthogonal direction: Applies sign() before orthogonalization.
162162
- Per-element second-moment normalization using momentum coefficient.
163-
- RMS-aligned global scaling: 0.2 * sqrt(min * max) / norm.
163+
- RMS-aligned global scaling: 0.2 * sqrt(m * n) / norm.
164164
165165
Parameters
166166
----------
@@ -245,7 +245,7 @@ def step(
245245
246246
Returns
247247
-------
248-
loss : float, optional
248+
loss : torch.Tensor, optional
249249
The loss value if closure is provided.
250250
"""
251251
loss = None

deepmd/pt/train/training.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -166,11 +166,9 @@ def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
166166
# Common parameters
167167
"weight_decay": params.get("weight_decay", 0.001),
168168
# Muon/AdaMuon parameters
169-
"muon_momentum": params.get("muon_momentum", 0.95),
169+
"momentum": params.get("momentum", 0.95),
170170
"adam_beta1": params.get("adam_beta1", 0.9),
171171
"adam_beta2": params.get("adam_beta2", 0.95),
172-
"adam_eps": params.get("adam_eps", 1e-7),
173-
"nesterov": params.get("nesterov", True),
174172
}
175173
return opt_type, opt_param
176174

@@ -710,7 +708,7 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
710708
self.optimizer = AdaMuonOptimizer(
711709
self.wrapper.parameters(),
712710
lr=self.lr_exp.start_lr,
713-
momentum=float(self.opt_param.get("muon_momentum", 0.95)),
711+
momentum=float(self.opt_param.get("momentum", 0.95)),
714712
weight_decay=float(self.opt_param.get("weight_decay", 0.001)),
715713
adam_betas=(
716714
float(self.opt_param.get("adam_beta1", 0.9)),

deepmd/utils/argcheck.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3375,7 +3375,7 @@ def training_args(
33753375
dict,
33763376
[
33773377
Argument(
3378-
"muon_momentum",
3378+
"momentum",
33793379
float,
33803380
optional=True,
33813381
default=0.95,
@@ -3413,8 +3413,8 @@ def training_args(
34133413
default=10.0,
34143414
doc=doc_only_pt_supported
34153415
+ "Learning rate adjustment factor for Adam (1D params). "
3416-
"If lr_adjust <= 0: use match-RMS scaling for AdaMuon, Adam uses lr directly. "
3417-
"If lr_adjust > 0: use rectangular correction for AdaMuon, Adam uses lr/lr_adjust.",
3416+
"If lr_adjust <= 0: use match-RMS scaling (scale = lr_adjust_coeff * sqrt(max(m, n))), Adam uses lr directly. "
3417+
"If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1.0, m/n))), Adam uses lr/lr_adjust.",
34183418
),
34193419
Argument(
34203420
"lr_adjust_coeff",

0 commit comments

Comments
 (0)