Commit 3b24cf6

refactor(optimizer): finalize optimizer schema and backend handling (#5157)
## Summary by CodeRabbit

* **New Features**
  * Added a top-level, configurable optimizer section; PyTorch: Adam/AdamW/LKF/AdaMuon/HybridMuon, TensorFlow/Paddle: Adam.
  * Exposes Adam-style hyperparameters (beta1, beta2, weight_decay) for relevant optimizers.
* **Documentation**
  * Training docs updated with optimizer examples and framework-specific guidance.
* **Backward Compatibility**
  * Legacy optimizer fields are auto-migrated to the new format; tests and examples updated to use the new config.
1 parent 2deb70c commit 3b24cf6

34 files changed

Lines changed: 758 additions & 348 deletions
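For orientation, the new top-level `optimizer` section sits alongside `model`, `learning_rate`, `loss`, and `training` in the input script. A minimal sketch as a Python dict follows; the key names come from this diff, while the numeric values are illustrative rather than the package defaults:

```python
# Sketch of an input script with the new top-level "optimizer" section.
# Values are examples only; argcheck.normalize() fills in the real defaults.
config = {
    "model": {},          # unchanged
    "learning_rate": {},  # unchanged
    "loss": {},           # unchanged
    "training": {},       # legacy optimizer keys no longer live here
    "optimizer": {
        "type": "AdamW",  # PyTorch: Adam/AdamW/LKF/AdaMuon/HybridMuon; TF/Paddle: Adam
        "adam_beta1": 0.9,
        "adam_beta2": 0.999,
        "weight_decay": 0.001,
    },
}
```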

deepmd/pd/train/training.py

Lines changed: 18 additions & 25 deletions
@@ -120,6 +120,7 @@ def __init__(
         self.restart_training = restart_model is not None
         model_params = config["model"]
         training_params = config["training"]
+        optimizer_params = config.get("optimizer", {})
         self.multi_task = "model_dict" in model_params
         self.finetune_links = finetune_links
         self.finetune_update_stat = False
@@ -157,14 +158,17 @@ def __init__(
         self.lcurve_should_print_header = True

         def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
-            opt_type = params.get("opt_type", "Adam")
-            opt_param = {
-                "kf_blocksize": params.get("kf_blocksize", 5120),
-                "kf_start_pref_e": params.get("kf_start_pref_e", 1),
-                "kf_limit_pref_e": params.get("kf_limit_pref_e", 1),
-                "kf_start_pref_f": params.get("kf_start_pref_f", 1),
-                "kf_limit_pref_f": params.get("kf_limit_pref_f", 1),
-            }
+            """
+            Extract optimizer parameters.
+
+            Note: Default values are already filled by argcheck.normalize()
+            before this function is called.
+            """
+            opt_type = params.get("type", "Adam")
+            if opt_type != "Adam":
+                raise ValueError(f"Not supported optimizer type '{opt_type}'")
+            opt_param = dict(params)
+            opt_param.pop("type", None)
             return opt_type, opt_param

         def get_data_loader(
@@ -256,22 +260,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             return lr_schedule

         # Optimizer
-        if self.multi_task and training_params.get("optim_dict", None) is not None:
-            self.optim_dict = training_params.get("optim_dict")
-            missing_keys = [
-                key for key in self.model_keys if key not in self.optim_dict
-            ]
-            assert not missing_keys, (
-                f"These keys are not in optim_dict: {missing_keys}!"
-            )
-            self.opt_type = {}
-            self.opt_param = {}
-            for model_key in self.model_keys:
-                self.opt_type[model_key], self.opt_param[model_key] = get_opt_param(
-                    self.optim_dict[model_key]
-                )
-        else:
-            self.opt_type, self.opt_param = get_opt_param(training_params)
+        self.opt_type, self.opt_param = get_opt_param(optimizer_params)

         # loss_param_tmp for Hessian activation
         loss_param_tmp = None
@@ -667,7 +656,11 @@ def single_model_finetune(
                 ),
             )
             self.optimizer = paddle.optimizer.Adam(
-                learning_rate=self.scheduler, parameters=self.wrapper.parameters()
+                learning_rate=self.scheduler,
+                parameters=self.wrapper.parameters(),
+                beta1=float(self.opt_param["adam_beta1"]),
+                beta2=float(self.opt_param["adam_beta2"]),
+                weight_decay=float(self.opt_param["weight_decay"]),
             )
             if optimizer_state_dict is not None and self.restart_training:
                 self.optimizer.set_state_dict(optimizer_state_dict)
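Taken on its own, the refactored helper just splits a normalized `optimizer` section into the type and the remaining hyperparameters. A standalone sketch mirroring the Paddle version above, with illustrative (not default) input values:

```python
from typing import Any


def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    # Mirrors the refactored helper above: "type" selects the optimizer,
    # everything else is passed through untouched as hyperparameters.
    opt_type = params.get("type", "Adam")
    if opt_type != "Adam":
        raise ValueError(f"Not supported optimizer type '{opt_type}'")
    opt_param = dict(params)
    opt_param.pop("type", None)
    return opt_type, opt_param


# Example call with illustrative values:
opt_type, opt_param = get_opt_param(
    {"type": "Adam", "adam_beta1": 0.9, "adam_beta2": 0.999, "weight_decay": 0.0}
)
assert opt_type == "Adam"
assert "type" not in opt_param
```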

deepmd/pt/train/training.py

Lines changed: 49 additions & 97 deletions
@@ -142,6 +142,7 @@ def __init__(
         self.restart_training = restart_model is not None
         model_params = config["model"]
         training_params = config["training"]
+        optimizer_params = config.get("optimizer", {})
         self.multi_task = "model_dict" in model_params
         self.finetune_links = finetune_links
         self.finetune_update_stat = False
@@ -185,26 +186,17 @@ def __init__(
         self.lcurve_should_print_header = True

         def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
-            opt_type = params.get("opt_type", "Adam")
-            opt_param = {
-                # LKF parameters
-                "kf_blocksize": params.get("kf_blocksize", 5120),
-                "kf_start_pref_e": params.get("kf_start_pref_e", 1),
-                "kf_limit_pref_e": params.get("kf_limit_pref_e", 1),
-                "kf_start_pref_f": params.get("kf_start_pref_f", 1),
-                "kf_limit_pref_f": params.get("kf_limit_pref_f", 1),
-                # Common parameters
-                "weight_decay": params.get("weight_decay", 0.001),
-                # Muon/AdaMuon parameters
-                "momentum": params.get("momentum", 0.95),
-                "adam_beta1": params.get("adam_beta1", 0.9),
-                "adam_beta2": params.get("adam_beta2", 0.95),
-                "lr_adjust": params.get("lr_adjust", 10.0),
-                "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2),
-                "muon_2d_only": params.get("muon_2d_only", True),
-                "min_2d_dim": params.get("min_2d_dim", 1),
-                "flash_muon": params.get("flash_muon", True),
-            }
+            """
+            Extract optimizer parameters.
+
+            Note: Default values are already filled by argcheck.normalize()
+            before this function is called.
+            """
+            opt_type = params.get("type", "Adam")
+            if opt_type not in ("Adam", "AdamW", "LKF", "AdaMuon", "HybridMuon"):
+                raise ValueError(f"Not supported optimizer type '{opt_type}'")
+            opt_param = dict(params)
+            opt_param.pop("type", None)
             return opt_type, opt_param

         def cycle_iterator(iterable: Iterable) -> Generator[Any, None, None]:
@@ -313,22 +305,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             return lr_schedule

         # Optimizer
-        if self.multi_task and training_params.get("optim_dict", None) is not None:
-            self.optim_dict = training_params.get("optim_dict")
-            missing_keys = [
-                key for key in self.model_keys if key not in self.optim_dict
-            ]
-            assert not missing_keys, (
-                f"These keys are not in optim_dict: {missing_keys}!"
-            )
-            self.opt_type = {}
-            self.opt_param = {}
-            for model_key in self.model_keys:
-                self.opt_type[model_key], self.opt_param[model_key] = get_opt_param(
-                    self.optim_dict[model_key]
-                )
-        else:
-            self.opt_type, self.opt_param = get_opt_param(training_params)
+        self.opt_type, self.opt_param = get_opt_param(optimizer_params)
         if self.zero_stage > 0 and self.multi_task:
             raise ValueError(
                 "training.zero_stage is currently only supported in single-task training."
@@ -782,71 +759,48 @@ def single_model_finetune(
         # TODO add optimizers for multitask
         # author: iProzd
         initial_lr = self.lr_schedule.value(self.start_step)
-        if self.opt_type in ["Adam", "AdamW"]:
-            # Initialize optimizer with the actual learning rate at start_step
-            # to ensure warmup is applied from the first step
-            if self.opt_type == "Adam":
-                self.optimizer = self._create_optimizer(
-                    torch.optim.Adam,
-                    lr=initial_lr,
-                    fused=DEVICE.type != "cpu",
-                )
-            else:
-                self.optimizer = self._create_optimizer(
-                    torch.optim.AdamW,
-                    lr=initial_lr,
-                    weight_decay=float(self.opt_param["weight_decay"]),
-                    fused=DEVICE.type != "cpu",
-                )
-            self._load_optimizer_state(optimizer_state_dict)
-            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
-                self.optimizer,
-                lambda step: (
-                    self.lr_schedule.value(step + self.start_step) / initial_lr
-                ),
-                last_epoch=self.start_step - 1,
-            )
-        elif self.opt_type == "LKF":
+        if self.opt_type == "LKF":
             self.optimizer = LKFOptimizer(
                 self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
             )
-        elif self.opt_type == "AdaMuon":
-            self.optimizer = self._create_optimizer(
-                AdaMuonOptimizer,
-                lr=initial_lr,
-                momentum=float(self.opt_param["momentum"]),
-                weight_decay=float(self.opt_param["weight_decay"]),
-                adam_betas=(
-                    float(self.opt_param["adam_beta1"]),
-                    float(self.opt_param["adam_beta2"]),
-                ),
-                lr_adjust=float(self.opt_param["lr_adjust"]),
-                lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
-            )
-            if optimizer_state_dict is not None and self.restart_training:
-                self.optimizer.load_state_dict(optimizer_state_dict)
-            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
-                self.optimizer,
-                lambda step: (
-                    self.lr_schedule.value(step + self.start_step) / initial_lr
-                ),
-                last_epoch=self.start_step - 1,
+        else:
+            # === Common path for gradient-based optimizers ===
+            adam_betas = (
+                float(self.opt_param["adam_beta1"]),
+                float(self.opt_param["adam_beta2"]),
             )
-        elif self.opt_type == "HybridMuon":
+            weight_decay = float(self.opt_param["weight_decay"])
+
+            if self.opt_type in ("Adam", "AdamW"):
+                cls = torch.optim.Adam if self.opt_type == "Adam" else torch.optim.AdamW
+                extra = {"betas": adam_betas, "fused": DEVICE.type != "cpu"}
+            elif self.opt_type == "AdaMuon":
+                cls = AdaMuonOptimizer
+                extra = {
+                    "adam_betas": adam_betas,
+                    "momentum": float(self.opt_param["momentum"]),
+                    "lr_adjust": float(self.opt_param["lr_adjust"]),
+                    "lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
+                }
+            elif self.opt_type == "HybridMuon":
+                cls = HybridMuonOptimizer
+                extra = {
+                    "adam_betas": adam_betas,
+                    "momentum": float(self.opt_param["momentum"]),
+                    "lr_adjust": float(self.opt_param["lr_adjust"]),
+                    "lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
+                    "muon_2d_only": bool(self.opt_param["muon_2d_only"]),
+                    "min_2d_dim": int(self.opt_param["min_2d_dim"]),
+                    "flash_muon": bool(self.opt_param["flash_muon"]),
+                }
+            else:
+                raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
+
             self.optimizer = self._create_optimizer(
-                HybridMuonOptimizer,
+                cls,
                 lr=initial_lr,
-                momentum=float(self.opt_param["momentum"]),
-                weight_decay=float(self.opt_param["weight_decay"]),
-                adam_betas=(
-                    float(self.opt_param["adam_beta1"]),
-                    float(self.opt_param["adam_beta2"]),
-                ),
-                lr_adjust=float(self.opt_param["lr_adjust"]),
-                lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
-                muon_2d_only=bool(self.opt_param["muon_2d_only"]),
-                min_2d_dim=int(self.opt_param["min_2d_dim"]),
-                flash_muon=bool(self.opt_param["flash_muon"]),
+                weight_decay=weight_decay,
+                **extra,
             )
             self._load_optimizer_state(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
@@ -856,8 +810,6 @@ def single_model_finetune(
                 ),
                 last_epoch=self.start_step - 1,
            )
-        else:
-            raise ValueError(f"Not supported optimizer type '{self.opt_type}'")

         if self.zero_stage > 0 and self.rank == 0:
             if self.zero_stage == 1:
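The PyTorch refactor above replaces per-optimizer construction branches with a single call: pick a class plus its class-specific keyword arguments, then build the optimizer once on a shared path. A self-contained sketch of that dispatch pattern, restricted to the stock torch optimizers (the deepmd-specific `AdaMuonOptimizer`, `HybridMuonOptimizer`, and `LKFOptimizer` are omitted, and the hyperparameter values are illustrative):

```python
import torch


def build_optimizer(opt_type, parameters, lr, opt_param):
    # Dispatch pattern from the diff above: choose (cls, extra) per type,
    # then construct the optimizer once on a common code path.
    adam_betas = (float(opt_param["adam_beta1"]), float(opt_param["adam_beta2"]))
    weight_decay = float(opt_param["weight_decay"])
    if opt_type in ("Adam", "AdamW"):
        cls = torch.optim.Adam if opt_type == "Adam" else torch.optim.AdamW
        extra = {"betas": adam_betas}
    else:
        raise ValueError(f"Not supported optimizer type '{opt_type}'")
    return cls(parameters, lr=lr, weight_decay=weight_decay, **extra)


model = torch.nn.Linear(4, 1)
opt = build_optimizer(
    "AdamW",
    model.parameters(),
    lr=1e-3,
    opt_param={"adam_beta1": 0.9, "adam_beta2": 0.95, "weight_decay": 0.001},
)
```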

deepmd/tf/entrypoints/train.py

Lines changed: 0 additions & 1 deletion
@@ -167,7 +167,6 @@ def train(
         jdata["model"] = json.loads(t_training_script)["model"]

     jdata = update_deepmd_input(jdata, warning=True, dump="input_v2_compat.json")
-
     jdata = normalize(jdata)

     if not is_compress and not skip_neighbor_stat:

deepmd/tf/train/trainer.py

Lines changed: 33 additions & 3 deletions
@@ -136,6 +136,22 @@ def get_lr_and_coef(
         # learning rate
         lr_param = jdata["learning_rate"]
         self.lr, self.scale_lr_coef = get_lr_and_coef(lr_param)
+        # optimizer
+        # Note: Default values are already filled by argcheck.normalize()
+        optimizer_param = jdata.get("optimizer", {})
+        self.optimizer_type = optimizer_param.get("type", "Adam")
+        self.optimizer_beta1 = float(optimizer_param.get("adam_beta1"))
+        self.optimizer_beta2 = float(optimizer_param.get("adam_beta2"))
+        self.optimizer_weight_decay = float(optimizer_param.get("weight_decay"))
+        if self.optimizer_type != "Adam":
+            raise RuntimeError(
+                f"Unsupported optimizer type {self.optimizer_type} for TensorFlow backend."
+            )
+        if self.optimizer_weight_decay != 0.0:
+            raise RuntimeError(
+                "TensorFlow Adam optimizer does not support weight_decay. "
+                "Set optimizer/weight_decay to 0."
+            )
         # loss
         # infer loss type by fitting_type
         loss_param = jdata.get("loss", {})
@@ -328,17 +344,31 @@ def _build_network(self, data: DeepmdDataSystem, suffix: str = "") -> None:
         log.info("built network")

     def _build_optimizer(self) -> Any:
+        if self.optimizer_type != "Adam":
+            raise RuntimeError(
+                f"Unsupported optimizer type {self.optimizer_type} for TensorFlow backend."
+            )
         if self.run_opt.is_distrib:
             if self.scale_lr_coef > 1.0:
                 log.info("Scale learning rate by coef: %f", self.scale_lr_coef)
                 optimizer = tf.train.AdamOptimizer(
-                    self.learning_rate * self.scale_lr_coef
+                    self.learning_rate * self.scale_lr_coef,
+                    beta1=self.optimizer_beta1,
+                    beta2=self.optimizer_beta2,
                 )
             else:
-                optimizer = tf.train.AdamOptimizer(self.learning_rate)
+                optimizer = tf.train.AdamOptimizer(
+                    self.learning_rate,
+                    beta1=self.optimizer_beta1,
+                    beta2=self.optimizer_beta2,
+                )
             optimizer = self.run_opt._HVD.DistributedOptimizer(optimizer)
         else:
-            optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate)
+            optimizer = tf.train.AdamOptimizer(
+                learning_rate=self.learning_rate,
+                beta1=self.optimizer_beta1,
+                beta2=self.optimizer_beta2,
+            )

         if self.mixed_prec is not None:
             _TF_VERSION = Version(TF_VERSION)
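For the TensorFlow backend, the checks above only accept plain Adam, and `weight_decay` must stay 0 because `tf.train.AdamOptimizer` has no weight-decay argument (only `beta1`/`beta2` are forwarded). An `optimizer` section that passes those checks might look like this, with illustrative beta values:

```python
# Illustrative "optimizer" section for the TensorFlow backend.
tf_optimizer_section = {
    "type": "Adam",       # the only type the TF backend accepts
    "adam_beta1": 0.9,    # forwarded to tf.train.AdamOptimizer(beta1=...)
    "adam_beta2": 0.999,  # forwarded to tf.train.AdamOptimizer(beta2=...)
    "weight_decay": 0.0,  # must be 0; AdamOptimizer has no weight-decay argument
}
```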

deepmd/tf/utils/compat.py

Lines changed: 2 additions & 0 deletions
@@ -4,13 +4,15 @@
 from deepmd.utils.compat import (
     convert_input_v0_v1,
     convert_input_v1_v2,
+    convert_optimizer_v31_to_v32,
     deprecate_numb_test,
     update_deepmd_input,
 )

 __all__ = [
     "convert_input_v0_v1",
     "convert_input_v1_v2",
+    "convert_optimizer_v31_to_v32",
     "deprecate_numb_test",
     "update_deepmd_input",
 ]
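The re-exported `convert_optimizer_v31_to_v32` is defined in `deepmd.utils.compat` and is not shown in this diff. Purely as a hypothetical sketch of the migration idea from the release notes, moving legacy optimizer keys out of `training` into the new top-level section could look roughly like the following; the real function's key handling and naming may differ:

```python
from typing import Any

# Hypothetical illustration only; NOT the actual convert_optimizer_v31_to_v32.
_LEGACY_TRAINING_KEYS = (
    "opt_type",
    "kf_blocksize",
    "kf_start_pref_e",
    "kf_limit_pref_e",
    "kf_start_pref_f",
    "kf_limit_pref_f",
)


def migrate_optimizer_section(jdata: dict[str, Any]) -> dict[str, Any]:
    # Move legacy optimizer keys from "training" into a top-level "optimizer"
    # section, leaving inputs that already use the new format untouched.
    if "optimizer" in jdata:
        return jdata
    training = jdata.get("training", {})
    optimizer: dict[str, Any] = {}
    for key in _LEGACY_TRAINING_KEYS:
        if key in training:
            new_key = "type" if key == "opt_type" else key
            optimizer[new_key] = training.pop(key)
    if optimizer:
        jdata["optimizer"] = optimizer
    return jdata
```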
