@@ -166,11 +166,9 @@ def get_opt_param(params: dict[str, Any]) -> tuple[str, dict[str, Any]]:
166166 # Common parameters
167167 "weight_decay" : params .get ("weight_decay" , 0.001 ),
168168 # Muon/AdaMuon parameters
169- "muon_momentum " : params .get ("muon_momentum " , 0.95 ),
169+ "momentum " : params .get ("momentum " , 0.95 ),
170170 "adam_beta1" : params .get ("adam_beta1" , 0.9 ),
171171 "adam_beta2" : params .get ("adam_beta2" , 0.95 ),
172- "adam_eps" : params .get ("adam_eps" , 1e-7 ),
173- "nesterov" : params .get ("nesterov" , True ),
174172 }
175173 return opt_type , opt_param
176174
@@ -710,7 +708,7 @@ def warm_up_linear(step: int, warmup_steps: int) -> float:
710708 self .optimizer = AdaMuonOptimizer (
711709 self .wrapper .parameters (),
712710 lr = self .lr_exp .start_lr ,
713- momentum = float (self .opt_param .get ("muon_momentum " , 0.95 )),
711+ momentum = float (self .opt_param .get ("momentum " , 0.95 )),
714712 weight_decay = float (self .opt_param .get ("weight_decay" , 0.001 )),
715713 adam_betas = (
716714 float (self .opt_param .get ("adam_beta1" , 0.9 )),
0 commit comments