Skip to content

Commit 69961fb

Browse files
committed
feat: add adamW
1 parent 40a0833 commit 69961fb

2 files changed

Lines changed: 14 additions & 5 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@ def get_opt_param(params):
155155
"kf_limit_pref_e": params.get("kf_limit_pref_e", 1),
156156
"kf_start_pref_f": params.get("kf_start_pref_f", 1),
157157
"kf_limit_pref_f": params.get("kf_limit_pref_f", 1),
158+
"weight_decay": params.get("weight_decay", 0.001),
158159
}
159160
return opt_type, opt_param
160161

@@ -577,10 +578,17 @@ def warm_up_linear(step, warmup_steps):
577578

578579
# TODO add optimizers for multitask
579580
# author: iProzd
580-
if self.opt_type == "Adam":
581-
self.optimizer = torch.optim.Adam(
582-
self.wrapper.parameters(), lr=self.lr_exp.start_lr, fused=True
583-
)
581+
if self.opt_type in ["Adam", "AdamW"]:
582+
if self.opt_type == "Adam":
583+
self.optimizer = torch.optim.Adam(
584+
self.wrapper.parameters(), lr=self.lr_exp.start_lr, fused=True
585+
)
586+
else:
587+
self.optimizer = torch.optim.AdamW(
588+
self.wrapper.parameters(),
589+
lr=self.lr_exp.start_lr,
590+
weight_decay=self.opt_param["weight_decay"],
591+
)
584592
if optimizer_state_dict is not None and self.restart_training:
585593
self.optimizer.load_state_dict(optimizer_state_dict)
586594
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
@@ -676,7 +684,7 @@ def step(_step_id, task_key="Default") -> None:
676684
print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n"
677685
fout1.write(print_str)
678686
fout1.flush()
679-
if self.opt_type == "Adam":
687+
if self.opt_type in ["Adam", "AdamW"]:
680688
cur_lr = self.scheduler.get_last_lr()[0]
681689
if _step_id < self.warmup_steps:
682690
pref_lr = _lr.start_lr

deepmd/utils/argcheck.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3327,6 +3327,7 @@ def training_args(
33273327
"opt_type",
33283328
choices=[
33293329
Argument("Adam", dict, [], [], optional=True),
3330+
Argument("AdamW", dict, [], [], optional=True),
33303331
Argument(
33313332
"LKF",
33323333
dict,

0 commit comments

Comments (0)