@@ -142,6 +142,7 @@ def __init__(
         self.restart_training = restart_model is not None
         model_params = config["model"]
         training_params = config["training"]
+        optimizer_params = config.get("optimizer", {})
         self.multi_task = "model_dict" in model_params
         self.finetune_links = finetune_links
         self.finetune_update_stat = False
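For orientation, the new lookup implies a top-level `optimizer` section in the input config. A hypothetical minimal snippet, assuming key names that mirror the defaults removed from `get_opt_param` below (only `"type"` is read explicitly in this diff):

```python
# Hypothetical input config; keys besides "type" are assumptions mirroring
# the defaults that the old get_opt_param used to fill in.
config = {
    "model": {},     # model section elided
    "training": {},  # training section elided
    "optimizer": {
        "type": "AdamW",
        "weight_decay": 0.001,
        "adam_beta1": 0.9,
        "adam_beta2": 0.95,
    },
}
optimizer_params = config.get("optimizer", {})  # {} when the section is absent
```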
@@ -185,26 +186,17 @@ def __init__(
         self.lcurve_should_print_header = True
 
         def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
-            opt_type = params.get("opt_type", "Adam")
-            opt_param = {
-                # LKF parameters
-                "kf_blocksize": params.get("kf_blocksize", 5120),
-                "kf_start_pref_e": params.get("kf_start_pref_e", 1),
-                "kf_limit_pref_e": params.get("kf_limit_pref_e", 1),
-                "kf_start_pref_f": params.get("kf_start_pref_f", 1),
-                "kf_limit_pref_f": params.get("kf_limit_pref_f", 1),
-                # Common parameters
-                "weight_decay": params.get("weight_decay", 0.001),
-                # Muon/AdaMuon parameters
-                "momentum": params.get("momentum", 0.95),
-                "adam_beta1": params.get("adam_beta1", 0.9),
-                "adam_beta2": params.get("adam_beta2", 0.95),
-                "lr_adjust": params.get("lr_adjust", 10.0),
-                "lr_adjust_coeff": params.get("lr_adjust_coeff", 0.2),
-                "muon_2d_only": params.get("muon_2d_only", True),
-                "min_2d_dim": params.get("min_2d_dim", 1),
-                "flash_muon": params.get("flash_muon", True),
-            }
+            """
+            Extract optimizer parameters.
+
+            Note: Default values are already filled by argcheck.normalize()
+            before this function is called.
+            """
+            opt_type = params.get("type", "Adam")
+            if opt_type not in ("Adam", "AdamW", "LKF", "AdaMuon", "HybridMuon"):
+                raise ValueError(f"Unsupported optimizer type '{opt_type}'")
+            opt_param = dict(params)
+            opt_param.pop("type", None)
             return opt_type, opt_param
 
         def cycle_iterator(iterable: Iterable) -> Generator[Any, None, None]:
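A minimal usage sketch of the rewritten `get_opt_param`, assuming defaults were already filled by `argcheck.normalize()` upstream:

```python
# The function now just splits "type" off and forwards everything else verbatim.
params = {"type": "AdaMuon", "weight_decay": 0.001, "momentum": 0.95}
opt_type, opt_param = get_opt_param(params)
assert opt_type == "AdaMuon"
assert opt_param == {"weight_decay": 0.001, "momentum": 0.95}
```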
@@ -313,22 +305,7 @@ def get_lr(lr_params: dict[str, Any]) -> BaseLR:
             return lr_schedule
 
         # Optimizer
-        if self.multi_task and training_params.get("optim_dict", None) is not None:
-            self.optim_dict = training_params.get("optim_dict")
-            missing_keys = [
-                key for key in self.model_keys if key not in self.optim_dict
-            ]
-            assert not missing_keys, (
-                f"These keys are not in optim_dict: {missing_keys}!"
-            )
-            self.opt_type = {}
-            self.opt_param = {}
-            for model_key in self.model_keys:
-                self.opt_type[model_key], self.opt_param[model_key] = get_opt_param(
-                    self.optim_dict[model_key]
-                )
-        else:
-            self.opt_type, self.opt_param = get_opt_param(training_params)
+        self.opt_type, self.opt_param = get_opt_param(optimizer_params)
         if self.zero_stage > 0 and self.multi_task:
             raise ValueError(
                 "training.zero_stage is currently only supported in single-task training."
@@ -782,71 +759,48 @@ def single_model_finetune(
         # TODO add optimizers for multitask
         # author: iProzd
         initial_lr = self.lr_schedule.value(self.start_step)
-        if self.opt_type in ["Adam", "AdamW"]:
-            # Initialize optimizer with the actual learning rate at start_step
-            # to ensure warmup is applied from the first step
-            if self.opt_type == "Adam":
-                self.optimizer = self._create_optimizer(
-                    torch.optim.Adam,
-                    lr=initial_lr,
-                    fused=DEVICE.type != "cpu",
-                )
-            else:
-                self.optimizer = self._create_optimizer(
-                    torch.optim.AdamW,
-                    lr=initial_lr,
-                    weight_decay=float(self.opt_param["weight_decay"]),
-                    fused=DEVICE.type != "cpu",
-                )
-            self._load_optimizer_state(optimizer_state_dict)
-            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
-                self.optimizer,
-                lambda step: (
-                    self.lr_schedule.value(step + self.start_step) / initial_lr
-                ),
-                last_epoch=self.start_step - 1,
-            )
-        elif self.opt_type == "LKF":
+        if self.opt_type == "LKF":
             self.optimizer = LKFOptimizer(
                 self.wrapper.parameters(), 0.98, 0.99870, self.opt_param["kf_blocksize"]
             )
-        elif self.opt_type == "AdaMuon":
-            self.optimizer = self._create_optimizer(
-                AdaMuonOptimizer,
-                lr=initial_lr,
-                momentum=float(self.opt_param["momentum"]),
-                weight_decay=float(self.opt_param["weight_decay"]),
-                adam_betas=(
-                    float(self.opt_param["adam_beta1"]),
-                    float(self.opt_param["adam_beta2"]),
-                ),
-                lr_adjust=float(self.opt_param["lr_adjust"]),
-                lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
-            )
-            if optimizer_state_dict is not None and self.restart_training:
-                self.optimizer.load_state_dict(optimizer_state_dict)
-            self.scheduler = torch.optim.lr_scheduler.LambdaLR(
-                self.optimizer,
-                lambda step: (
-                    self.lr_schedule.value(step + self.start_step) / initial_lr
-                ),
-                last_epoch=self.start_step - 1,
+        else:
+            # === Common path for gradient-based optimizers ===
+            adam_betas = (
+                float(self.opt_param["adam_beta1"]),
+                float(self.opt_param["adam_beta2"]),
             )
-        elif self.opt_type == "HybridMuon":
+            weight_decay = float(self.opt_param["weight_decay"])
+
+            if self.opt_type in ("Adam", "AdamW"):
+                cls = torch.optim.Adam if self.opt_type == "Adam" else torch.optim.AdamW
+                extra = {"betas": adam_betas, "fused": DEVICE.type != "cpu"}
+            elif self.opt_type == "AdaMuon":
+                cls = AdaMuonOptimizer
+                extra = {
+                    "adam_betas": adam_betas,
+                    "momentum": float(self.opt_param["momentum"]),
+                    "lr_adjust": float(self.opt_param["lr_adjust"]),
+                    "lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
+                }
+            elif self.opt_type == "HybridMuon":
+                cls = HybridMuonOptimizer
+                extra = {
+                    "adam_betas": adam_betas,
+                    "momentum": float(self.opt_param["momentum"]),
+                    "lr_adjust": float(self.opt_param["lr_adjust"]),
+                    "lr_adjust_coeff": float(self.opt_param["lr_adjust_coeff"]),
+                    "muon_2d_only": bool(self.opt_param["muon_2d_only"]),
+                    "min_2d_dim": int(self.opt_param["min_2d_dim"]),
+                    "flash_muon": bool(self.opt_param["flash_muon"]),
+                }
+            else:
+                raise ValueError(f"Unsupported optimizer type '{self.opt_type}'")
+
             self.optimizer = self._create_optimizer(
-                HybridMuonOptimizer,
+                cls,
                 lr=initial_lr,
-                momentum=float(self.opt_param["momentum"]),
-                weight_decay=float(self.opt_param["weight_decay"]),
-                adam_betas=(
-                    float(self.opt_param["adam_beta1"]),
-                    float(self.opt_param["adam_beta2"]),
-                ),
-                lr_adjust=float(self.opt_param["lr_adjust"]),
-                lr_adjust_coeff=float(self.opt_param["lr_adjust_coeff"]),
-                muon_2d_only=bool(self.opt_param["muon_2d_only"]),
-                min_2d_dim=int(self.opt_param["min_2d_dim"]),
-                flash_muon=bool(self.opt_param["flash_muon"]),
+                weight_decay=weight_decay,
+                **extra,
             )
             self._load_optimizer_state(optimizer_state_dict)
             self.scheduler = torch.optim.lr_scheduler.LambdaLR(
@@ -856,8 +810,6 @@ def single_model_finetune(
                 ),
                 last_epoch=self.start_step - 1,
             )
-        else:
-            raise ValueError(f"Not supported optimizer type '{self.opt_type}'")
 
         if self.zero_stage > 0 and self.rank == 0:
             if self.zero_stage == 1:
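Taken together, the refactor funnels every gradient-based optimizer through a single `_create_optimizer` call driven by a `(cls, extra)` pair, with LKF as the only special case. Below is a self-contained sketch of that dispatch plus the schedule-rescaling `LambdaLR` trick, using plain `torch` only; `build_optimizer` and `schedule` are stand-ins rather than the project's API, and the Muon branches and `fused` flag are omitted:

```python
import torch

def build_optimizer(opt_type: str, model: torch.nn.Module,
                    initial_lr: float, opt_param: dict) -> torch.optim.Optimizer:
    # Shared hyperparameters are pulled out once, as in the diff.
    adam_betas = (float(opt_param["adam_beta1"]), float(opt_param["adam_beta2"]))
    if opt_type in ("Adam", "AdamW"):
        cls = torch.optim.Adam if opt_type == "Adam" else torch.optim.AdamW
        extra = {"betas": adam_betas}
    else:  # AdaMuon/HybridMuon would contribute their own "extra" keys here
        raise ValueError(f"Unsupported optimizer type '{opt_type}'")
    return cls(
        model.parameters(),
        lr=initial_lr,  # schedule value at start_step, so warmup resumes correctly
        weight_decay=float(opt_param["weight_decay"]),
        **extra,
    )

model = torch.nn.Linear(4, 4)
start_step = 0
schedule = lambda step: 1e-3 * 0.95**step  # stand-in for self.lr_schedule.value
optimizer = build_optimizer(
    "AdamW", model, schedule(start_step),
    {"adam_beta1": 0.9, "adam_beta2": 0.95, "weight_decay": 0.001},
)
# LambdaLR multiplies the optimizer's base lr by the lambda, so dividing the
# absolute schedule by its value at start_step replays it exactly on restart.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: schedule(step + start_step) / schedule(start_step)
)
```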