@@ -3471,6 +3471,67 @@ def training_args(
34713471 doc = doc_only_pt_supported
34723472 + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0." ,
34733473 ),
3474+ ],
3475+ [],
3476+ optional = True ,
3477+ ),
3478+ Argument (
3479+ "Muon" ,
3480+ dict ,
3481+ [
3482+ Argument (
3483+ "momentum" ,
3484+ float ,
3485+ optional = True ,
3486+ default = 0.95 ,
3487+ alias = ["muon_momentum" ],
3488+ doc = doc_only_pt_supported
3489+ + "Momentum coefficient for Muon optimizer (>=2D params). "
3490+ "Used in Nesterov momentum update: m_t = beta*m_{t-1} + (1-beta)*g_t." ,
3491+ ),
3492+ Argument (
3493+ "adam_beta1" ,
3494+ float ,
3495+ optional = True ,
3496+ default = 0.9 ,
3497+ doc = doc_only_pt_supported
3498+ + "Adam beta1 coefficient for 1D parameters (biases, norms)." ,
3499+ ),
3500+ Argument (
3501+ "adam_beta2" ,
3502+ float ,
3503+ optional = True ,
3504+ default = 0.95 ,
3505+ doc = doc_only_pt_supported
3506+ + "Adam beta2 coefficient for 1D parameters (biases, norms)." ,
3507+ ),
3508+ Argument (
3509+ "weight_decay" ,
3510+ float ,
3511+ optional = True ,
3512+ default = 0.001 ,
3513+ doc = doc_only_pt_supported
3514+ + "Weight decay coefficient. Applied only to >=2D parameters (Muon path)." ,
3515+ ),
3516+ Argument (
3517+ "lr_adjust" ,
3518+ float ,
3519+ optional = True ,
3520+ default = 10.0 ,
3521+ doc = doc_only_pt_supported
3522+ + "Learning rate adjustment mode for Muon scaling and Adam learning rate. "
3523+ "If lr_adjust <= 0: use match-RMS scaling (scale = coeff*sqrt(max(m,n))), Adam uses lr directly. "
3524+ "If lr_adjust > 0: use rectangular correction (scale = sqrt(max(1, m/n))), Adam uses lr/lr_adjust. "
3525+ "Default is 10.0 (Adam lr = lr/10)." ,
3526+ ),
3527+ Argument (
3528+ "lr_adjust_coeff" ,
3529+ float ,
3530+ optional = True ,
3531+ default = 0.2 ,
3532+ doc = doc_only_pt_supported
3533+ + "Coefficient for match-RMS scaling. Only effective when lr_adjust <= 0." ,
3534+ ),
34743535 Argument (
34753536 "min_2d_dim" ,
34763537 int ,
0 commit comments