@@ -158,6 +158,7 @@ def get_opt_param(params):
158158 "kf_limit_pref_e" : params .get ("kf_limit_pref_e" , 1 ),
159159 "kf_start_pref_f" : params .get ("kf_start_pref_f" , 1 ),
160160 "kf_limit_pref_f" : params .get ("kf_limit_pref_f" , 1 ),
161+ "weight_decay" : params .get ("weight_decay" , 0.001 ),
161162 }
162163 return opt_type , opt_param
163164
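Note for reviewers: a minimal sketch of how the new key is consumed, not part of the diff. The `training_params` dict is a hypothetical stand-in for the training section deepmd-kit passes into `get_opt_param`, how `opt_type` is derived is assumed here, and 0.001 is the default this patch introduces.

```python
# Hypothetical user input; only "weight_decay" is new in this patch.
training_params = {"opt_type": "AdamW", "weight_decay": 0.01}

opt_type = training_params.get("opt_type", "Adam")  # derivation assumed
opt_param = {
    # Falls back to the patch's default when the user omits the key.
    "weight_decay": training_params.get("weight_decay", 0.001),
}
assert opt_type == "AdamW" and opt_param["weight_decay"] == 0.01
```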
@@ -609,12 +610,19 @@ def warm_up_linear(step, warmup_steps):

         # TODO add optimizers for multitask
         # author: iProzd
-        if self.opt_type == "Adam":
-            self.optimizer = torch.optim.Adam(
-                self.wrapper.parameters(),
-                lr=self.lr_exp.start_lr,
-                fused=False if DEVICE.type == "cpu" else True,
-            )
+        if self.opt_type in ["Adam", "AdamW"]:
+            if self.opt_type == "Adam":
+                self.optimizer = torch.optim.Adam(
+                    self.wrapper.parameters(),
+                    lr=self.lr_exp.start_lr,
+                    fused=False if DEVICE.type == "cpu" else True,
+                )
+            else:
+                self.optimizer = torch.optim.AdamW(
+                    self.wrapper.parameters(),
+                    lr=self.lr_exp.start_lr,
+                    weight_decay=float(self.opt_param["weight_decay"]),
+                )
             if optimizer_state_dict is not None and self.restart_training:
                 self.optimizer.load_state_dict(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
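For review context, a self-contained sketch of the branch this hunk introduces. `build_optimizer` is a hypothetical helper, not a function in this file, and the CPU/accelerator distinction via `device_type` mirrors the `DEVICE.type` check in the removed lines.

```python
import torch

def build_optimizer(model, opt_type, start_lr, weight_decay=0.001, device_type="cuda"):
    # Mirrors the added branch: Adam keeps the fused kernel off CPU,
    # AdamW adds decoupled weight decay on top of the same base LR.
    if opt_type == "Adam":
        return torch.optim.Adam(
            model.parameters(),
            lr=start_lr,
            fused=device_type != "cpu",
        )
    elif opt_type == "AdamW":
        return torch.optim.AdamW(
            model.parameters(),
            lr=start_lr,
            weight_decay=float(weight_decay),
        )
    raise ValueError(f"Unsupported optimizer type: {opt_type}")

# Example: CPU run with AdamW and the patch's default decay.
opt = build_optimizer(torch.nn.Linear(4, 4), "AdamW", 1e-3, device_type="cpu")
```

A separate optimizer type is warranted here rather than passing `weight_decay` to `torch.optim.Adam`, since AdamW decouples the decay from the adaptive gradient update (Loshchilov & Hutter) while Adam folds it into the gradient.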
@@ -710,7 +718,7 @@ def step(_step_id, task_key="Default") -> None:
                 print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n"
                 fout1.write(print_str)
                 fout1.flush()
-            if self.opt_type == "Adam":
+            if self.opt_type in ["Adam", "AdamW"]:
                 cur_lr = self.scheduler.get_last_lr()[0]
                 if _step_id < self.warmup_steps:
                     pref_lr = _lr.start_lr
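Aside: this hunk only widens the optimizer check, but for readers skimming the warmup handling, here is a hedged sketch of how the reported learning rate is chosen. The linear-warmup lambda is a placeholder, not necessarily the actual `warm_up_linear` schedule from this file, and the values are illustrative.

```python
import torch

start_lr, warmup_steps = 1e-3, 100  # illustrative values
model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=start_lr, weight_decay=0.001)
# Placeholder warmup lambda; the real schedule decays after warmup.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: min((step + 1) / warmup_steps, 1.0)
)

for step_id in range(3):
    cur_lr = scheduler.get_last_lr()[0]
    # As in the hunk: during warmup, report start_lr instead of the ramped value.
    pref_lr = start_lr if step_id < warmup_steps else cur_lr
    optimizer.step()
    scheduler.step()
```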