Skip to content

Commit 5e79487

Browse files
committed
feat(pt): add AdamW for pt training
1 parent 8176173 commit 5e79487

2 files changed

Lines changed: 16 additions & 7 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ def get_opt_param(params):
158158
"kf_limit_pref_e": params.get("kf_limit_pref_e", 1),
159159
"kf_start_pref_f": params.get("kf_start_pref_f", 1),
160160
"kf_limit_pref_f": params.get("kf_limit_pref_f", 1),
161+
"weight_decay": params.get("weight_decay", 0.001),
161162
}
162163
return opt_type, opt_param
163164

@@ -609,12 +610,19 @@ def warm_up_linear(step, warmup_steps):
609610

610611
# TODO add optimizers for multitask
611612
# author: iProzd
612-
if self.opt_type == "Adam":
613-
self.optimizer = torch.optim.Adam(
614-
self.wrapper.parameters(),
615-
lr=self.lr_exp.start_lr,
616-
fused=False if DEVICE.type == "cpu" else True,
617-
)
613+
if self.opt_type in ["Adam", "AdamW"]:
614+
if self.opt_type == "Adam":
615+
self.optimizer = torch.optim.Adam(
616+
self.wrapper.parameters(),
617+
lr=self.lr_exp.start_lr,
618+
fused=False if DEVICE.type == "cpu" else True,
619+
)
620+
else:
621+
self.optimizer = torch.optim.AdamW(
622+
self.wrapper.parameters(),
623+
lr=self.lr_exp.start_lr,
624+
weight_decay=float(self.opt_param["weight_decay"]),
625+
)
618626
if optimizer_state_dict is not None and self.restart_training:
619627
self.optimizer.load_state_dict(optimizer_state_dict)
620628
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
@@ -710,7 +718,7 @@ def step(_step_id, task_key="Default") -> None:
710718
print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n"
711719
fout1.write(print_str)
712720
fout1.flush()
713-
if self.opt_type == "Adam":
721+
if self.opt_type in ["Adam", "AdamW"]:
714722
cur_lr = self.scheduler.get_last_lr()[0]
715723
if _step_id < self.warmup_steps:
716724
pref_lr = _lr.start_lr

deepmd/utils/argcheck.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3177,6 +3177,7 @@ def training_args(
31773177
"opt_type",
31783178
choices=[
31793179
Argument("Adam", dict, [], [], optional=True),
3180+
Argument("AdamW", dict, [], [], optional=True),
31803181
Argument(
31813182
"LKF",
31823183
dict,

0 commit comments

Comments
 (0)