@@ -155,6 +155,7 @@ def get_opt_param(params):
155155 "kf_limit_pref_e" : params .get ("kf_limit_pref_e" , 1 ),
156156 "kf_start_pref_f" : params .get ("kf_start_pref_f" , 1 ),
157157 "kf_limit_pref_f" : params .get ("kf_limit_pref_f" , 1 ),
158+ "weight_decay" : params .get ("weight_decay" , 0.001 ),
158159 }
159160 return opt_type , opt_param
160161
@@ -577,10 +578,17 @@ def warm_up_linear(step, warmup_steps):
 
         # TODO add optimizers for multitask
         # author: iProzd
-        if self.opt_type == "Adam":
-            self.optimizer = torch.optim.Adam(
-                self.wrapper.parameters(), lr=self.lr_exp.start_lr, fused=True
-            )
+        if self.opt_type in ["Adam", "AdamW"]:
+            if self.opt_type == "Adam":
+                self.optimizer = torch.optim.Adam(
+                    self.wrapper.parameters(), lr=self.lr_exp.start_lr, fused=True
+                )
+            else:
+                self.optimizer = torch.optim.AdamW(
+                    self.wrapper.parameters(),
+                    lr=self.lr_exp.start_lr,
+                    weight_decay=self.opt_param["weight_decay"],
+                )
         if optimizer_state_dict is not None and self.restart_training:
             self.optimizer.load_state_dict(optimizer_state_dict)
         self.scheduler = torch.optim.lr_scheduler.LambdaLR(
@@ -676,7 +684,7 @@ def step(_step_id, task_key="Default") -> None:
                 print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n"
                 fout1.write(print_str)
                 fout1.flush()
-            if self.opt_type == "Adam":
+            if self.opt_type in ["Adam", "AdamW"]:
                 cur_lr = self.scheduler.get_last_lr()[0]
                 if _step_id < self.warmup_steps:
                     pref_lr = _lr.start_lr
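
A minimal sketch (not part of the diff above) of how the new `weight_decay` option is picked up when `opt_type` is set to `"AdamW"`. Here `model`, `start_lr`, and the `params` dict are stand-ins for `self.wrapper`, `self.lr_exp.start_lr`, and the training-parameter dict the trainer actually receives:

```python
import torch

# Hypothetical training params; "weight_decay" falls back to 0.001 when omitted,
# matching the default added in get_opt_param above.
params = {"opt_type": "AdamW", "weight_decay": 0.001}

model = torch.nn.Linear(4, 1)  # stand-in for the wrapped model
start_lr = 1e-3                # stand-in for the exponential scheduler's start LR

if params["opt_type"] == "AdamW":
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=start_lr,
        weight_decay=params.get("weight_decay", 0.001),
    )
```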