diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index eab5601d55..af8f46d7a9 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -158,6 +158,7 @@ def get_opt_param(params): "kf_limit_pref_e": params.get("kf_limit_pref_e", 1), "kf_start_pref_f": params.get("kf_start_pref_f", 1), "kf_limit_pref_f": params.get("kf_limit_pref_f", 1), + "weight_decay": params.get("weight_decay", 0.001), } return opt_type, opt_param @@ -609,12 +610,20 @@ def warm_up_linear(step, warmup_steps): # TODO add optimizers for multitask # author: iProzd - if self.opt_type == "Adam": - self.optimizer = torch.optim.Adam( - self.wrapper.parameters(), - lr=self.lr_exp.start_lr, - fused=False if DEVICE.type == "cpu" else True, - ) + if self.opt_type in ["Adam", "AdamW"]: + if self.opt_type == "Adam": + self.optimizer = torch.optim.Adam( + self.wrapper.parameters(), + lr=self.lr_exp.start_lr, + fused=False if DEVICE.type == "cpu" else True, + ) + else: + self.optimizer = torch.optim.AdamW( + self.wrapper.parameters(), + lr=self.lr_exp.start_lr, + weight_decay=float(self.opt_param["weight_decay"]), + fused=False if DEVICE.type == "cpu" else True, + ) if optimizer_state_dict is not None and self.restart_training: self.optimizer.load_state_dict(optimizer_state_dict) self.scheduler = torch.optim.lr_scheduler.LambdaLR( @@ -710,7 +719,7 @@ def step(_step_id, task_key="Default") -> None: print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" fout1.write(print_str) fout1.flush() - if self.opt_type == "Adam": + if self.opt_type in ["Adam", "AdamW"]: cur_lr = self.scheduler.get_last_lr()[0] if _step_id < self.warmup_steps: pref_lr = _lr.start_lr diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 0260700165..dbfb71459e 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3177,6 +3177,7 @@ def training_args( "opt_type", choices=[ Argument("Adam", dict, [], [], optional=True), + Argument("AdamW", dict, [], [], optional=True), Argument( "LKF", dict,