File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -939,12 +939,10 @@ def single_model_finetune(
939939 ** extra ,
940940 )
941941 self ._load_optimizer_state (optimizer_state_dict )
942- self .scheduler = torch . optim . lr_scheduler . LambdaLR (
942+ self .scheduler = self . _create_lr_scheduler (
943943 self .optimizer ,
944- lambda step : (
945- self .lr_schedule .value (step + self .start_step ) / initial_lr
946- ),
947- last_epoch = self .start_step - 1 ,
944+ self .lr_schedule ,
945+ self .start_step ,
948946 )
949947
950948 if self .zero_stage > 0 and self .rank == 0 :
@@ -975,6 +973,21 @@ def single_model_finetune(
975973 if self .rank == 0 :
976974 self ._log_parameter_count ()
977975
976+ @staticmethod
977+ def _create_lr_scheduler (
978+ optimizer : torch .optim .Optimizer ,
979+ lr_schedule : BaseLR ,
980+ start_step : int ,
981+ ) -> torch .optim .lr_scheduler .LambdaLR :
982+ base_lr = float (lr_schedule .start_lr )
983+ for group in optimizer .param_groups :
984+ group ["initial_lr" ] = base_lr
985+ return torch .optim .lr_scheduler .LambdaLR (
986+ optimizer ,
987+ lambda step : lr_schedule .value (step ) / base_lr ,
988+ last_epoch = start_step - 1 ,
989+ )
990+
978991 def _create_full_validator (
979992 self ,
980993 * ,
You can’t perform that action at this time.
0 commit comments