Skip to content

Commit 2279459

Browse files
authored
feat(pt): add AdamW for pt training (#4757)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit - **New Features** - Added support for the "AdamW" optimizer in training configurations. - Introduced a "weight_decay" parameter for optimizer settings, with a default value of 0.001. - **Chores** - Updated configuration options to allow selection of "AdamW" as an optimizer type. <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent cb78ec0 commit 2279459

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

deepmd/pt/train/training.py

Lines changed: 16 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -158,6 +158,7 @@ def get_opt_param(params):
158158
"kf_limit_pref_e": params.get("kf_limit_pref_e", 1),
159159
"kf_start_pref_f": params.get("kf_start_pref_f", 1),
160160
"kf_limit_pref_f": params.get("kf_limit_pref_f", 1),
161+
"weight_decay": params.get("weight_decay", 0.001),
161162
}
162163
return opt_type, opt_param
163164

@@ -609,12 +610,20 @@ def warm_up_linear(step, warmup_steps):
609610

610611
# TODO add optimizers for multitask
611612
# author: iProzd
612-
if self.opt_type == "Adam":
613-
self.optimizer = torch.optim.Adam(
614-
self.wrapper.parameters(),
615-
lr=self.lr_exp.start_lr,
616-
fused=False if DEVICE.type == "cpu" else True,
617-
)
613+
if self.opt_type in ["Adam", "AdamW"]:
614+
if self.opt_type == "Adam":
615+
self.optimizer = torch.optim.Adam(
616+
self.wrapper.parameters(),
617+
lr=self.lr_exp.start_lr,
618+
fused=False if DEVICE.type == "cpu" else True,
619+
)
620+
else:
621+
self.optimizer = torch.optim.AdamW(
622+
self.wrapper.parameters(),
623+
lr=self.lr_exp.start_lr,
624+
weight_decay=float(self.opt_param["weight_decay"]),
625+
fused=False if DEVICE.type == "cpu" else True,
626+
)
618627
if optimizer_state_dict is not None and self.restart_training:
619628
self.optimizer.load_state_dict(optimizer_state_dict)
620629
self.scheduler = torch.optim.lr_scheduler.LambdaLR(
@@ -710,7 +719,7 @@ def step(_step_id, task_key="Default") -> None:
710719
print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n"
711720
fout1.write(print_str)
712721
fout1.flush()
713-
if self.opt_type == "Adam":
722+
if self.opt_type in ["Adam", "AdamW"]:
714723
cur_lr = self.scheduler.get_last_lr()[0]
715724
if _step_id < self.warmup_steps:
716725
pref_lr = _lr.start_lr

deepmd/utils/argcheck.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3177,6 +3177,7 @@ def training_args(
31773177
"opt_type",
31783178
choices=[
31793179
Argument("Adam", dict, [], [], optional=True),
3180+
Argument("AdamW", dict, [], [], optional=True),
31803181
Argument(
31813182
"LKF",
31823183
dict,

0 commit comments

Comments (0)