@@ -299,7 +299,7 @@ def get_sample() -> Any:
299299 return get_sample
300300
301301 def get_lr (lr_params : dict [str , Any ]) -> BaseLR :
302- lr_params ["stop_steps " ] = self .num_steps - self . warmup_steps
302+ lr_params ["num_steps " ] = self .num_steps
303303 lr_schedule = BaseLR (** lr_params )
304304 return lr_schedule
305305
@@ -463,27 +463,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
463463 )
464464
465465 # Learning rate
466- warmup_steps = training_params .get ("warmup_steps" , None )
467- warmup_ratio = training_params .get ("warmup_ratio" , None )
468- if warmup_steps is not None :
469- self .warmup_steps = warmup_steps
470- elif warmup_ratio is not None :
471- if not 0 <= warmup_ratio < 1 :
472- raise ValueError (f"warmup_ratio must be in [0, 1), got { warmup_ratio } " )
473- self .warmup_steps = int (warmup_ratio * self .num_steps )
474- if self .warmup_steps == 0 and warmup_ratio > 0 :
475- log .warning (
476- f"warmup_ratio { warmup_ratio } results in 0 warmup steps "
477- f"due to truncation. Consider using a larger ratio or "
478- f"specify warmup_steps directly."
479- )
480- else :
481- self .warmup_steps = 0
482- self .warmup_start_factor = training_params .get ("warmup_start_factor" , 0.0 )
483466 self .gradient_max_norm = training_params .get ("gradient_max_norm" , 0.0 )
484- assert self .num_steps - self .warmup_steps > 0 or self .warmup_steps == 0 , (
485- "Warm up steps must be less than total training steps!"
486- )
487467 if self .multi_task and config .get ("learning_rate_dict" , None ) is not None :
488468 self .lr_exp = {}
489469 for model_key in self .model_keys :
@@ -738,34 +718,30 @@ def single_model_finetune(
738718
739719 # TODO add lr warmups for multitask
740720 # author: iProzd
741- def warm_up_linear (step : int , warmup_steps : int ) -> float :
742- if step < warmup_steps :
743- return self .warmup_start_factor + (1.0 - self .warmup_start_factor ) * (
744- step / warmup_steps
745- )
746- else :
747- return self .lr_exp .value (step - warmup_steps ) / self .lr_exp .start_lr
748-
749721 # TODO add optimizers for multitask
750722 # author: iProzd
723+ initial_lr = self .lr_exp .value (self .start_step )
751724 if self .opt_type in ["Adam" , "AdamW" ]:
725+ # Initialize optimizer with the actual learning rate at start_step
726+ # to ensure warmup is applied from the first step
752727 if self .opt_type == "Adam" :
753728 self .optimizer = self ._create_optimizer (
754729 torch .optim .Adam ,
755- lr = self . lr_exp . start_lr ,
730+ lr = initial_lr ,
756731 fused = DEVICE .type != "cpu" ,
757732 )
758733 else :
759734 self .optimizer = self ._create_optimizer (
760735 torch .optim .AdamW ,
761- lr = self . lr_exp . start_lr ,
736+ lr = initial_lr ,
762737 weight_decay = float (self .opt_param ["weight_decay" ]),
763738 fused = DEVICE .type != "cpu" ,
764739 )
765740 self ._load_optimizer_state (optimizer_state_dict )
766741 self .scheduler = torch .optim .lr_scheduler .LambdaLR (
767742 self .optimizer ,
768- lambda step : warm_up_linear (step + self .start_step , self .warmup_steps ),
743+ lambda step : self .lr_exp .value (step + self .start_step ) / initial_lr ,
744+ last_epoch = self .start_step - 1 ,
769745 )
770746 elif self .opt_type == "LKF" :
771747 self .optimizer = LKFOptimizer (
@@ -774,7 +750,7 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
774750 elif self .opt_type == "AdaMuon" :
775751 self .optimizer = self ._create_optimizer (
776752 AdaMuonOptimizer ,
777- lr = self . lr_exp . start_lr ,
753+ lr = initial_lr ,
778754 momentum = float (self .opt_param ["momentum" ]),
779755 weight_decay = float (self .opt_param ["weight_decay" ]),
780756 adam_betas = (
@@ -784,10 +760,17 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
784760 lr_adjust = float (self .opt_param ["lr_adjust" ]),
785761 lr_adjust_coeff = float (self .opt_param ["lr_adjust_coeff" ]),
786762 )
763+ if optimizer_state_dict is not None and self .restart_training :
764+ self .optimizer .load_state_dict (optimizer_state_dict )
765+ self .scheduler = torch .optim .lr_scheduler .LambdaLR (
766+ self .optimizer ,
767+ lambda step : self .lr_exp .value (step + self .start_step ) / initial_lr ,
768+ last_epoch = self .start_step - 1 ,
769+ )
787770 elif self .opt_type == "HybridMuon" :
788771 self .optimizer = self ._create_optimizer (
789772 HybridMuonOptimizer ,
790- lr = self . lr_exp . start_lr ,
773+ lr = initial_lr ,
791774 momentum = float (self .opt_param ["momentum" ]),
792775 weight_decay = float (self .opt_param ["weight_decay" ]),
793776 adam_betas = (
@@ -802,7 +785,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
802785 self ._load_optimizer_state (optimizer_state_dict )
803786 self .scheduler = torch .optim .lr_scheduler .LambdaLR (
804787 self .optimizer ,
805- lambda step : warm_up_linear (step + self .start_step , self .warmup_steps ),
788+ lambda step : self .lr_exp .value (step + self .start_step ) / initial_lr ,
789+ last_epoch = self .start_step - 1 ,
806790 )
807791 else :
808792 raise ValueError (f"Not supported optimizer type '{ self .opt_type } '" )
@@ -980,10 +964,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
980964 fout1 .flush ()
981965 if self .opt_type in ["Adam" , "AdamW" , "AdaMuon" , "HybridMuon" ]:
982966 cur_lr = self .scheduler .get_last_lr ()[0 ]
983- if _step_id < self .warmup_steps :
984- pref_lr = _lr .start_lr
985- else :
986- pref_lr = cur_lr
967+ pref_lr = cur_lr
987968 model_pred , loss , more_loss = self .wrapper (
988969 ** input_dict , cur_lr = pref_lr , label = label_dict , task_key = task_key
989970 )
0 commit comments