@@ -784,71 +784,48 @@ def single_model_finetune(
784784 # TODO add optimizers for multitask
785785 # author: iProzd
786786 initial_lr = self .lr_schedule .value (self .start_step )
787- if self .opt_type in ["Adam" , "AdamW" ]:
787+ if self .opt_type == "LKF" :
788+ self .optimizer = LKFOptimizer (
789+ self .wrapper .parameters (), 0.98 , 0.99870 , self .opt_param ["kf_blocksize" ]
790+ )
791+ else :
792+ # === Common path for gradient-based optimizers ===
788793 adam_betas = (
789794 float (self .opt_param ["adam_beta1" ]),
790795 float (self .opt_param ["adam_beta2" ]),
791796 )
792797 weight_decay = float (self .opt_param ["weight_decay" ])
793- optimizer_class = (
794- torch .optim .Adam if self .opt_type == "Adam" else torch .optim .AdamW
795- )
798+
799+ if self .opt_type in ("Adam" , "AdamW" ):
800+ cls = torch .optim .Adam if self .opt_type == "Adam" else torch .optim .AdamW
801+ extra = {"betas" : adam_betas , "fused" : DEVICE .type != "cpu" }
802+ elif self .opt_type == "AdaMuon" :
803+ cls = AdaMuonOptimizer
804+ extra = {
805+ "adam_betas" : adam_betas ,
806+ "momentum" : float (self .opt_param ["momentum" ]),
807+ "lr_adjust" : float (self .opt_param ["lr_adjust" ]),
808+ "lr_adjust_coeff" : float (self .opt_param ["lr_adjust_coeff" ]),
809+ }
810+ elif self .opt_type == "HybridMuon" :
811+ cls = HybridMuonOptimizer
812+ extra = {
813+ "adam_betas" : adam_betas ,
814+ "momentum" : float (self .opt_param ["momentum" ]),
815+ "lr_adjust" : float (self .opt_param ["lr_adjust" ]),
816+ "lr_adjust_coeff" : float (self .opt_param ["lr_adjust_coeff" ]),
817+ "muon_2d_only" : bool (self .opt_param ["muon_2d_only" ]),
818+ "min_2d_dim" : int (self .opt_param ["min_2d_dim" ]),
819+ "flash_muon" : bool (self .opt_param ["flash_muon" ]),
820+ }
821+ else :
822+ raise ValueError (f"Not supported optimizer type '{ self .opt_type } '" )
823+
796824 self .optimizer = self ._create_optimizer (
797- optimizer_class ,
825+ cls ,
798826 lr = initial_lr ,
799- betas = adam_betas ,
800827 weight_decay = weight_decay ,
801- fused = DEVICE .type != "cpu" ,
802- )
803- self ._load_optimizer_state (optimizer_state_dict )
804- self .scheduler = torch .optim .lr_scheduler .LambdaLR (
805- self .optimizer ,
806- lambda step : (
807- self .lr_schedule .value (step + self .start_step ) / initial_lr
808- ),
809- last_epoch = self .start_step - 1 ,
810- )
811- elif self .opt_type == "LKF" :
812- self .optimizer = LKFOptimizer (
813- self .wrapper .parameters (), 0.98 , 0.99870 , self .opt_param ["kf_blocksize" ]
814- )
815- elif self .opt_type == "AdaMuon" :
816- self .optimizer = self ._create_optimizer (
817- AdaMuonOptimizer ,
818- lr = initial_lr ,
819- momentum = float (self .opt_param ["momentum" ]),
820- weight_decay = float (self .opt_param ["weight_decay" ]),
821- adam_betas = (
822- float (self .opt_param ["adam_beta1" ]),
823- float (self .opt_param ["adam_beta2" ]),
824- ),
825- lr_adjust = float (self .opt_param ["lr_adjust" ]),
826- lr_adjust_coeff = float (self .opt_param ["lr_adjust_coeff" ]),
827- )
828- if optimizer_state_dict is not None and self .restart_training :
829- self .optimizer .load_state_dict (optimizer_state_dict )
830- self .scheduler = torch .optim .lr_scheduler .LambdaLR (
831- self .optimizer ,
832- lambda step : (
833- self .lr_schedule .value (step + self .start_step ) / initial_lr
834- ),
835- last_epoch = self .start_step - 1 ,
836- )
837- elif self .opt_type == "HybridMuon" :
838- self .optimizer = self ._create_optimizer (
839- HybridMuonOptimizer ,
840- lr = initial_lr ,
841- momentum = float (self .opt_param ["momentum" ]),
842- weight_decay = float (self .opt_param ["weight_decay" ]),
843- adam_betas = (
844- float (self .opt_param ["adam_beta1" ]),
845- float (self .opt_param ["adam_beta2" ]),
846- ),
847- lr_adjust = float (self .opt_param ["lr_adjust" ]),
848- lr_adjust_coeff = float (self .opt_param ["lr_adjust_coeff" ]),
849- muon_2d_only = bool (self .opt_param ["muon_2d_only" ]),
850- min_2d_dim = int (self .opt_param ["min_2d_dim" ]),
851- flash_muon = bool (self .opt_param ["flash_muon" ]),
828+ ** extra ,
852829 )
853830 self ._load_optimizer_state (optimizer_state_dict )
854831 self .scheduler = torch .optim .lr_scheduler .LambdaLR (
@@ -858,8 +835,6 @@ def single_model_finetune(
858835 ),
859836 last_epoch = self .start_step - 1 ,
860837 )
861- else :
862- raise ValueError (f"Not supported optimizer type '{ self .opt_type } '" )
863838
864839 if self .zero_stage > 0 and self .rank == 0 :
865840 if self .zero_stage == 1 :
0 commit comments