@@ -279,7 +279,7 @@ def get_sample() -> Any:
             return get_sample
 
         def get_lr(lr_params: dict[str, Any]) -> BaseLR:
-            lr_params["stop_steps"] = self.num_steps - self.warmup_steps
+            lr_params["num_steps"] = self.num_steps
             lr_schedule = BaseLR(**lr_params)
             return lr_schedule
 
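With the warmup offset no longer subtracted here, the schedule object itself is expected to handle warmup internally, so it is given the full step budget rather than num_steps - warmup_steps. BaseLR's implementation is not part of this diff; the class below is only a hypothetical sketch of such a schedule, with made-up parameter names:

```python
import math


class ExpLRWithWarmup:
    """Hypothetical schedule: linear warmup folded into an exponential decay."""

    def __init__(self, start_lr, stop_lr, num_steps, warmup_steps=0,
                 warmup_start_factor=0.1):
        self.start_lr = start_lr
        self.stop_lr = stop_lr
        self.num_steps = num_steps
        self.warmup_steps = warmup_steps
        # Nonzero floor keeps value(0) > 0, which matters if the trainer
        # divides by value(start_step) as the patch below does.
        self.warmup_start_factor = warmup_start_factor

    def value(self, step):
        if self.warmup_steps > 0 and step < self.warmup_steps:
            frac = self.warmup_start_factor + (
                1.0 - self.warmup_start_factor
            ) * step / self.warmup_steps
            return self.start_lr * frac
        decay_steps = max(self.num_steps - self.warmup_steps, 1)
        rate = math.log(self.stop_lr / self.start_lr) / decay_steps
        return self.start_lr * math.exp(rate * (step - self.warmup_steps))
```

Because value() already accounts for warmup in a schedule of this kind, the trainer only has to forward num_steps, which is what the new line does.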
@@ -437,27 +437,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
         )
 
         # Learning rate
-        warmup_steps = training_params.get("warmup_steps", None)
-        warmup_ratio = training_params.get("warmup_ratio", None)
-        if warmup_steps is not None:
-            self.warmup_steps = warmup_steps
-        elif warmup_ratio is not None:
-            if not 0 <= warmup_ratio < 1:
-                raise ValueError(f"warmup_ratio must be in [0, 1), got {warmup_ratio}")
-            self.warmup_steps = int(warmup_ratio * self.num_steps)
-            if self.warmup_steps == 0 and warmup_ratio > 0:
-                log.warning(
-                    f"warmup_ratio {warmup_ratio} results in 0 warmup steps "
-                    f"due to truncation. Consider using a larger ratio or "
-                    f"specify warmup_steps directly."
-                )
-        else:
-            self.warmup_steps = 0
-        self.warmup_start_factor = training_params.get("warmup_start_factor", 0.0)
         self.gradient_max_norm = training_params.get("gradient_max_norm", 0.0)
-        assert self.num_steps - self.warmup_steps > 0 or self.warmup_steps == 0, (
-            "Warm up steps must be less than total training steps!"
-        )
         if self.multi_task and config.get("learning_rate_dict", None) is not None:
             self.lr_exp = {}
             for model_key in self.model_keys:
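For reference, the warning removed above guards against a warmup_ratio small enough that the integer conversion truncates to zero steps; a quick illustration with made-up numbers:

```python
num_steps = 1000

warmup_ratio = 0.0005                  # passes the 0 <= ratio < 1 check
print(int(warmup_ratio * num_steps))   # 0 -> truncation silently disables warmup

warmup_ratio = 0.05
print(int(warmup_ratio * num_steps))   # 50 warmup steps
```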
@@ -702,44 +682,43 @@ def single_model_finetune(
 
         # TODO add lr warmups for multitask
         # author: iProzd
-        def warm_up_linear(step: int, warmup_steps: int) -> float:
-            if step < warmup_steps:
-                return self.warmup_start_factor + (1.0 - self.warmup_start_factor) * (
-                    step / warmup_steps
-                )
-            else:
-                return self.lr_exp.value(step - warmup_steps) / self.lr_exp.start_lr
-
         # TODO add optimizers for multitask
         # author: iProzd
         if self.opt_type in ["Adam", "AdamW"]:
+            # Initialize optimizer with the actual learning rate at start_step
+            # to ensure warmup is applied from the first step
+            initial_lr = self.lr_exp.value(self.start_step)
             if self.opt_type == "Adam":
                 self.optimizer = torch.optim.Adam(
                     self.wrapper.parameters(),
-                    lr=self.lr_exp.start_lr,
+                    lr=initial_lr,
                     fused=False if DEVICE.type == "cpu" else True,
                 )
             else:
                 self.optimizer = torch.optim.AdamW(
                     self.wrapper.parameters(),
-                    lr=self.lr_exp.start_lr,
+                    lr=initial_lr,
                     weight_decay=float(self.opt_param["weight_decay"]),
                     fused=False if DEVICE.type == "cpu" else True,
                 )
             if optimizer_state_dict is not None and self.restart_training:
                 self.optimizer.load_state_dict(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
-                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+                lambda step: self.lr_exp.value(step + self.start_step) / initial_lr,
+                last_epoch=self.start_step - 1,
             )
         elif self.opt_type == "LKF":
             self.optimizer = LKFOptimizer(
                 self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
             )
         elif self.opt_type == "AdaMuon":
+            # Initialize optimizer with the actual learning rate at start_step
+            # to ensure warmup is applied from the first step
+            initial_lr = self.lr_exp.value(self.start_step)
             self.optimizer = AdaMuonOptimizer(
                 self.wrapper.parameters(),
-                lr=self.lr_exp.start_lr,
+                lr=initial_lr,
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
                 adam_betas=(
@@ -749,10 +728,20 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 lr_adjust=float(self.opt_param["lr_adjust"]),
                 lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
             )
+            if optimizer_state_dict is not None and self.restart_training:
+                self.optimizer.load_state_dict(optimizer_state_dict)
+            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
+                self.optimizer,
+                lambda step: self.lr_exp.value(step + self.start_step) / initial_lr,
+                last_epoch=self.start_step - 1,
+            )
         elif self.opt_type == "HybridMuon":
+            # Initialize optimizer with the actual learning rate at start_step
+            # to ensure warmup is applied from the first step
+            initial_lr = self.lr_exp.value(self.start_step)
             self.optimizer = HybridMuonOptimizer(
                 self.wrapper.parameters(),
-                lr=self.lr_exp.start_lr,
+                lr=initial_lr,
                 momentum=float(self.opt_param["momentum"]),
                 weight_decay=float(self.opt_param["weight_decay"]),
                 adam_betas=(
@@ -768,7 +757,8 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
                 self.optimizer.load_state_dict(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
                 self.optimizer,
-                lambda step: warm_up_linear(step + self.start_step, self.warmup_steps),
+                lambda step: self.lr_exp.value(step + self.start_step) / initial_lr,
+                last_epoch=self.start_step - 1,
             )
         else:
             raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
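The pattern that replaces warm_up_linear throughout this block: the optimizer is constructed with the schedule's value at start_step, and LambdaLR rescales it by value(step + start_step) / initial_lr, so the learning rate actually applied equals the schedule's absolute value, warmup included. A self-contained sketch with a toy schedule (the schedule function and numbers are invented, and the restart-specific last_epoch argument is left out here):

```python
import torch


def schedule_value(step, start_lr=1e-3, stop_lr=1e-5, num_steps=1000, warmup_steps=100):
    """Toy stand-in for lr_exp.value(): linear warmup, then exponential decay."""
    if step < warmup_steps:
        return start_lr * (0.1 + 0.9 * step / warmup_steps)
    frac = (step - warmup_steps) / max(num_steps - warmup_steps, 1)
    return start_lr * (stop_lr / start_lr) ** frac


start_step = 300  # pretend we are resuming from a checkpoint
initial_lr = schedule_value(start_step)

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=initial_lr)

# LambdaLR multiplies the optimizer's base lr by whatever the lambda returns,
# so dividing by initial_lr recovers the schedule's absolute learning rate.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lambda step: schedule_value(step + start_step) / initial_lr,
)

for local_step in range(5):
    optimizer.step()
    scheduler.step()
    applied = scheduler.get_last_lr()[0]
    expected = schedule_value(start_step + local_step + 1)
    assert abs(applied - expected) < 1e-12
```

The patch additionally passes last_epoch=self.start_step - 1 when building the scheduler; PyTorch's LambdaLR only accepts a non-default last_epoch when 'initial_lr' is already present in the optimizer's param groups, as it typically is after loading an optimizer state saved from a run that already had a scheduler attached.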
@@ -883,10 +873,7 @@ def step(_step_id: int, task_key: str = "Default") -> None:
                 fout1.flush()
             if self.opt_type in ["Adam", "AdamW", "AdaMuon", "HybridMuon"]:
                 cur_lr = self.scheduler.get_last_lr()[0]
-                if _step_id < self.warmup_steps:
-                    pref_lr = _lr.start_lr
-                else:
-                    pref_lr = cur_lr
+                pref_lr = cur_lr
                 model_pred, loss, more_loss = self.wrapper(
                     **input_dict, cur_lr=pref_lr, label=label_dict, task_key=task_key
                 )
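pref_lr is forwarded through the wrapper to the loss, where the learning rate commonly drives a prefactor ramp; with warmup now part of the scheduler output, the old special case of using start_lr during warmup is dropped. The interpolation below is only an illustration of that kind of ramp, not the project's actual loss code:

```python
def interpolate_pref(cur_lr, start_lr, start_pref, limit_pref):
    """Illustrative prefactor ramp driven by the current learning rate."""
    coef = cur_lr / start_lr
    return limit_pref + (start_pref - limit_pref) * coef


# During warmup cur_lr is below start_lr, so the prefactor sits closer to its
# limit value instead of being pinned at the start_lr-based value.
print(interpolate_pref(cur_lr=2.5e-4, start_lr=1e-3, start_pref=0.02, limit_pref=1.0))
```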