@@ -78,9 +78,14 @@ def _maybe_compile(
     fn: callable,
 ) -> callable:
     """Compile a function if torch.compile is available."""
-    if hasattr(torch, "compile"):
-        return torch.compile(fn, fullgraph=True, dynamic=True)
-    return fn
+    if not hasattr(torch, "compile"):
+        return fn
+    # Skip compile if default device is CUDA but CUDA is unavailable.
+    if hasattr(torch, "get_default_device"):
+        default_device = torch.get_default_device()
+        if default_device.type == "cuda" and not torch.cuda.is_available():
+            return fn
+    return torch.compile(fn, fullgraph=True, dynamic=True)


 @_maybe_compile
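
# Illustrative sketch (not part of the diff): how the guard above degrades
# gracefully. Assumes torch >= 2.0 and that _maybe_compile from this hunk is
# in scope; on builds without torch.compile, or when the default device is
# "cuda" on a machine with no visible GPU, the decorator is a plain no-op.
import torch

@_maybe_compile
def _double(x: torch.Tensor) -> torch.Tensor:
    return 2 * x

print(_double(torch.ones(3)))  # tensor([2., 2., 2.]), compiled or eager
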
@@ -181,13 +186,54 @@ def zeropower_via_newtonschulz5(
     raise ValueError("Input must be 2D or 3D for Newton-Schulz orthogonalization.")


+def should_fallback_to_adam_for_matrix(
+    p: torch.Tensor,
+    min_2d_dim: int,
+) -> bool:
+    """
+    Check whether a 2D matrix should fall back to Adam due to small dimensions.
+
+    Parameters
+    ----------
+    p : torch.Tensor
+        Parameter tensor with ndim >= 2.
+    min_2d_dim : int
+        Minimum min(m, n) threshold for Muon. Matrices with min(m, n) >=
+        min_2d_dim use Muon; those with min(m, n) < min_2d_dim use Adam.
+
+    Returns
+    -------
+    bool
+        True if min(m, n) < min_2d_dim, False otherwise.
+
+    Raises
+    ------
+    ValueError
+        If the tensor has ndim < 2.
+    """
+    # === Step 1. Validate ===
+    if p.ndim < 2:
+        raise ValueError("Parameter must have ndim >= 2 for Muon suitability check.")
+
+    # === Step 2. Derive matrix shape consistent with Muon's reshape ===
+    m = int(p.shape[0])
+    n = int(p.numel() // p.shape[0])
+
+    # === Step 3. Check whether any dimension is too small for Muon ===
+    return min(m, n) < min_2d_dim
+
+
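
# Illustrative check (not part of the diff): the predicate in action, assuming
# should_fallback_to_adam_for_matrix from the hunk above is in scope. Tensors
# with ndim > 2 are flattened to (shape[0], numel // shape[0]), matching
# Muon's reshape.
import torch

w_square = torch.empty(768, 768)  # min(m, n) = 768
w_thin = torch.empty(768, 3)      # min(m, n) = 3
assert not should_fallback_to_adam_for_matrix(w_square, min_2d_dim=16)
assert should_fallback_to_adam_for_matrix(w_thin, min_2d_dim=16)
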
 class MuonOptimizer(Optimizer):
     """
-    Muon optimizer with auxiliary Adam for non-matrix parameters.
+    Muon optimizer with small-2D Adam fallback and 1D Adam path.

     This optimizer applies different update rules based on parameter dimensionality:
-    - For >=2D parameters (weight matrices): Muon update with Newton-Schulz orthogonalization
-    - For 1D parameters (biases, layer norms): Standard Adam update
+    - For >=2D parameters with min(m, n) >= min_2d_dim:
+      Muon update with Newton-Schulz orthogonalization.
+    - For 2D parameters with min(m, n) < min_2d_dim (small matrices):
+      Adam update with scaled learning rate and update clipping.
+    - For 1D parameters (biases, layer norms):
+      Standard Adam update.

     This hybrid approach is effective because Muon's orthogonalization is designed
     for weight matrices, while Adam is more suitable for biases and normalization params.
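
# Illustrative sketch (not part of the diff): a hypothetical classify() that
# mirrors the three routing rules described in the docstring above.
import torch

def classify(p: torch.Tensor, min_2d_dim: int = 16) -> str:
    if p.ndim < 2:
        return "adam_1d"
    m, n = int(p.shape[0]), int(p.numel() // p.shape[0])
    if p.ndim == 2 and min(m, n) < min_2d_dim:
        return "adam_matrix"
    return "muon"

print(classify(torch.empty(1024, 1024)))  # muon
print(classify(torch.empty(2, 1024)))     # adam_matrix
print(classify(torch.empty(1024)))        # adam_1d
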
@@ -224,8 +270,19 @@ class MuonOptimizer(Optimizer):
         scale = sqrt(max(1.0, m/n)). Adam uses lr/lr_adjust.
         Default is 10.0 (Adam lr = lr/10).
     lr_adjust_coeff : float
-        Coefficient for match-RMS scaling with default 0.2.
-        Only effective when lr_adjust <= 0.
+        Dual-purpose coefficient with default 0.2:
+        1. For Muon (when lr_adjust <= 0): match-RMS scaling factor,
+           scale = lr_adjust_coeff * sqrt(max(m, n)).
+        2. For the 2D Adam fallback: learning-rate multiplier,
+           adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1).
+           The min(., 0.1) cap keeps updates to small matrices conservative.
+    min_2d_dim : int
+        Minimum min(m, n) threshold for Muon on 2D matrices.
+        Matrices with min(m, n) >= min_2d_dim use Muon;
+        those with min(m, n) < min_2d_dim use the Adam fallback.
+        Must be >= 1; set to 1 to disable the fallback.
+        Default is 1.

     Examples
     --------
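
# Worked numbers (illustrative, not part of the diff) for the two roles of
# lr_adjust_coeff = 0.2, using a hypothetical 4096 x 1024 weight and lr = 0.02.
import math

m, n, lr, lr_adjust = 4096, 1024, 0.02, 10.0

# Role 1: Muon match-RMS scale (active when lr_adjust <= 0).
muon_scale = 0.2 * math.sqrt(max(m, n))   # 0.2 * 64 = 12.8

# Role 2: learning rate for the small-2D Adam fallback.
adam_lr = lr / lr_adjust                  # 0.002
adam_lr_matrix = adam_lr * min(0.2, 0.1)  # capped at 0.1 -> 0.0002
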
@@ -245,14 +302,19 @@ def __init__(
         adam_betas: tuple[float, float] = (0.9, 0.95),
         lr_adjust: float = 10.0,
         lr_adjust_coeff: float = 0.2,
+        min_2d_dim: int = 1,
     ) -> None:
+        if min_2d_dim < 1:
+            raise ValueError("min_2d_dim must be >= 1.")
+
         defaults = {
             "lr": lr,
             "momentum": momentum,
             "weight_decay": weight_decay,
             "adam_betas": adam_betas,
             "lr_adjust": lr_adjust,
             "lr_adjust_coeff": lr_adjust_coeff,
             "min_2d_dim": min_2d_dim,
         }
         super().__init__(params, defaults)
         # Static parameter routing: built once on first step() call.
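
# Usage sketch (hypothetical model, not part of the diff), assuming the
# constructor defined above. A (4, 256) Linear weight has min(m, n) = 4, so
# with min_2d_dim = 16 it takes the Adam fallback; the bias takes the 1D path.
import torch

model = torch.nn.Linear(256, 4)
opt = MuonOptimizer(model.parameters(), lr=0.02, min_2d_dim=16)

# min_2d_dim < 1 is rejected up front:
# MuonOptimizer(model.parameters(), lr=0.02, min_2d_dim=0)  # -> ValueError
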
@@ -264,33 +326,50 @@ def _build_param_routing(self) -> None:
         Classify parameters into Muon and Adam routes (static routing).

         Routing logic:
-        - >=2D parameters → Muon path (Newton-Schulz + momentum)
-        - 1D parameters → Adam path (standard Adam update)
+        - >=2D parameters with min(m, n) >= min_2d_dim → Muon path
+        - 2D parameters with min(m, n) < min_2d_dim → Adam fallback path
+        - 1D parameters → Adam path
         """
         if self._routing_built:
             return

         self._routing = []
         for group in self.param_groups:
             muon_params: list[dict[str, Any]] = []
-            adam_params: list[dict[str, Any]] = []
+            adam_1d: list[dict[str, Any]] = []
+            adam_matrix: list[dict[str, Any]] = []
+
+            min_2d_dim = group["min_2d_dim"]

             for p in group["params"]:
-                if p.ndim >= 2:
-                    muon_params.append(
+                if p.ndim < 2:
+                    adam_1d.append({"param": p})
+                    continue
+
+                if (p.ndim == 2) and should_fallback_to_adam_for_matrix(
+                    p, min_2d_dim=min_2d_dim
+                ):
+                    adam_matrix.append(
                         {
                             "param": p,
-                            "rows": int(p.shape[0]),
-                            "cols": int(p.numel() // p.shape[0]),
+                            "abs_floor": 1e-3 * math.sqrt(float(p.numel())),
                         }
                     )
-                else:
-                    adam_params.append({"param": p})
+                    continue
+
+                muon_params.append(
+                    {
+                        "param": p,
+                        "rows": int(p.shape[0]),
+                        "cols": int(p.numel() // p.shape[0]),
+                    }
+                )

             self._routing.append(
                 {
                     "muon_params": muon_params,
-                    "adam_params": adam_params,
+                    "adam_1d": adam_1d,
+                    "adam_matrix": adam_matrix,
                 }
             )

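
# Illustrative sketch (not part of the diff): shape of one self._routing entry
# after building, for a hypothetical group holding a 512x512 weight, a 512x2
# weight, and a 512-element bias with min_2d_dim = 16:
#
# {
#     "muon_params": [{"param": w_big, "rows": 512, "cols": 512}],
#     "adam_1d": [{"param": bias}],
#     "adam_matrix": [{"param": w_thin, "abs_floor": 1e-3 * sqrt(1024)}],  # 0.032
# }
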
@@ -332,13 +411,14 @@ def step(
             lr_adjust_coeff = group["lr_adjust_coeff"]

             # === Step 1. Adam update for 1D parameters (biases, norms, etc.) ===
+            # === Step 1.1. Collect gradients and initialize state ===
             adam_params: list[torch.Tensor] = []
             adam_grads_fp32: list[torch.Tensor] = []
             adam_exp_avgs: list[torch.Tensor] = []
             adam_exp_avg_sqs: list[torch.Tensor] = []
             adam_states: list[dict[str, Any]] = []

-            for entry in route["adam_params"]:
+            for entry in route["adam_1d"]:
                 p = entry["param"]
                 grad = p.grad
                 if grad is None:
@@ -363,6 +443,7 @@ def step(
                 adam_states.append(state)

             if adam_params:
+                # === Step 1.2. Update exp_avg / exp_avg_sq ===
                 adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust

                 # exp_avg = beta1 * exp_avg + (1 - beta1) * grad
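
# Identity check (illustrative, not part of the diff): lerp(a, b, w) equals
# a + w * (b - a), so the in-place foreach lerp above is exactly the Adam EMA
# beta1 * exp_avg + (1 - beta1) * grad. Assumes a torch build exposing the
# private _foreach ops already used in this file.
import torch

beta1 = 0.9
exp_avg, grad = torch.randn(5), torch.randn(5)
expected = beta1 * exp_avg + (1 - beta1) * grad
torch._foreach_lerp_([exp_avg], [grad], 1 - beta1)
assert torch.allclose(exp_avg, expected)
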
@@ -371,6 +452,7 @@ def step(
                 grad_sq = torch._foreach_mul(adam_grads_fp32, adam_grads_fp32)
                 torch._foreach_lerp_(adam_exp_avg_sqs, grad_sq, 1 - adam_betas[1])

+                # === Step 1.3. Bias correction and parameter update ===
                 for i, p in enumerate(adam_params):
                     state = adam_states[i]
                     bias_corr1 = 1 - state["beta1_pow"]
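
# Worked numbers (illustrative, not part of the diff): bias correction after
# the first accumulation with adam_betas = (0.9, 0.95).
adam_lr = 0.002
bias_corr1 = 1 - 0.9               # beta1_pow = 0.9 -> 0.1
bias_corr2 = 1 - 0.95              # beta2_pow = 0.95 -> 0.05
step_size = adam_lr / bias_corr1   # 0.02: early steps are scaled up
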
@@ -381,7 +463,87 @@ def step(
                     delta_fp32 = -step_size * (adam_exp_avgs[i] / denom)
                     p.add_(delta_fp32.to(p.dtype))

-            # === Step 2. Muon update for >=2D parameters (weight matrices) ===
+            # === Step 2. Adam update for small 2D matrices (fallback path) ===
+            # === Step 2.1. Collect gradients and initialize state ===
+            adam_matrix_params: list[torch.Tensor] = []
+            adam_matrix_grads_fp32: list[torch.Tensor] = []
+            adam_matrix_exp_avgs: list[torch.Tensor] = []
+            adam_matrix_exp_avg_sqs: list[torch.Tensor] = []
+            adam_matrix_states: list[dict[str, Any]] = []
+            adam_matrix_abs_floor: list[float] = []
+
+            for entry in route["adam_matrix"]:
+                p = entry["param"]
+                grad = p.grad
+                if grad is None:
+                    continue
+
+                grad_fp32 = grad.float()
+
+                state = self.state[p]
+                if "exp_avg" not in state:
+                    state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32)
+                    state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32)
+                    state["beta1_pow"] = 1.0
+                    state["beta2_pow"] = 1.0
+
+                state["beta1_pow"] *= adam_betas[0]
+                state["beta2_pow"] *= adam_betas[1]
+
+                adam_matrix_params.append(p)
+                adam_matrix_grads_fp32.append(grad_fp32)
+                adam_matrix_exp_avgs.append(state["exp_avg"])
+                adam_matrix_exp_avg_sqs.append(state["exp_avg_sq"])
+                adam_matrix_states.append(state)
+                adam_matrix_abs_floor.append(entry["abs_floor"])
+
+            if adam_matrix_params:
+                # === Step 2.2. Update exp_avg / exp_avg_sq with scaled lr ===
+                adam_lr = lr if lr_adjust <= 0 else lr / lr_adjust
+                adam_lr_matrix = adam_lr * min(lr_adjust_coeff, 0.1)
+
+                # exp_avg = beta1 * exp_avg + (1 - beta1) * grad
+                # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
+                torch._foreach_lerp_(
+                    adam_matrix_exp_avgs, adam_matrix_grads_fp32, 1 - adam_betas[0]
+                )
+                grad_sq_m = torch._foreach_mul(
+                    adam_matrix_grads_fp32, adam_matrix_grads_fp32
+                )
+                torch._foreach_lerp_(
+                    adam_matrix_exp_avg_sqs, grad_sq_m, 1 - adam_betas[1]
+                )
+
+                # === Step 2.3. Compute unclipped deltas ===
+                raw_deltas: list[torch.Tensor] = []
+                for i in range(len(adam_matrix_params)):
+                    state = adam_matrix_states[i]
+                    bias_corr1 = 1 - state["beta1_pow"]
+                    bias_corr2 = 1 - state["beta2_pow"]
+                    step_size = adam_lr_matrix / bias_corr1
+                    denom = (
+                        (adam_matrix_exp_avg_sqs[i] / bias_corr2).sqrt().add_(ADAM_EPS)
+                    )
+                    raw_deltas.append(-step_size * (adam_matrix_exp_avgs[i] / denom))

+                # === Step 2.4. Clip updates by relative norm and apply ===
+                max_rel_change = 0.05
+                p_norms = torch.stack(torch._foreach_norm(adam_matrix_params))
+                delta_norms = torch.stack(torch._foreach_norm(raw_deltas))
+                floors = torch.tensor(
+                    adam_matrix_abs_floor,
+                    device=p_norms.device,
+                    dtype=p_norms.dtype,
+                )
+                max_delta = torch.maximum(max_rel_change * p_norms, floors)
+                scales_tensor = torch.clamp(max_delta / (delta_norms + 1e-12), max=1.0)
+                for i, delta in enumerate(raw_deltas):
+                    delta.mul_(scales_tensor[i])
+
+                torch._foreach_add_(adam_matrix_params, raw_deltas)
+
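
# Single-tensor sketch (illustrative, not part of the diff) of the Step 2.4
# clipping rule: each delta's norm is capped at max(0.05 * ||p||, abs_floor).
import math
import torch

p = torch.ones(4, 2)                            # ||p|| = sqrt(8) ~ 2.828
delta = torch.full((4, 2), 0.5)                 # ||delta|| = sqrt(2) ~ 1.414
abs_floor = 1e-3 * math.sqrt(float(p.numel()))  # ~ 0.0028
max_delta = max(0.05 * p.norm().item(), abs_floor)
scale = min(max_delta / (delta.norm().item() + 1e-12), 1.0)  # ~ 0.1
delta.mul_(scale)                               # ||delta|| now ~ 0.1414
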
+            # === Step 3. Muon update for >=2D parameters (weight matrices) ===
+            # === Step 3.1. Collect gradients and initialize momentum ===
             muon_params_for_decay: list[torch.Tensor] = []
             muon_grads: list[torch.Tensor] = []
             muon_momentum_buffers: list[torch.Tensor] = []
@@ -406,19 +568,22 @@ def step(
                 muon_momentum_buffers.append(buf)
                 active_entries.append((entry, grad))

+            # === Step 3.2. Apply weight decay (Muon path only) ===
             if weight_decay > 0 and muon_params_for_decay:
                 torch._foreach_mul_(muon_params_for_decay, 1.0 - lr * weight_decay)

             if not active_entries:
                 continue

+            # === Step 3.3. Momentum update (Nesterov) ===
             # m_t = beta * m_{t-1} + (1 - beta) * g_t
             torch._foreach_lerp_(muon_momentum_buffers, muon_grads, 1 - momentum)
             # update = beta * m_t + (1 - beta) * g_t
             muon_updates = torch._foreach_lerp(
                 muon_grads, muon_momentum_buffers, momentum
             )

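
# Identity check (illustrative, not part of the diff): the two lerps above
# form the Nesterov-style blend update = (1 - beta) * g_t + beta * m_t, with
# m_t = beta * m_{t-1} + (1 - beta) * g_t. Assumes a torch build exposing the
# private _foreach ops used in this file.
import torch

beta = 0.95
buf, grad = torch.randn(4, 4), torch.randn(4, 4)
m_t = beta * buf + (1 - beta) * grad
torch._foreach_lerp_([buf], [grad], 1 - beta)
(update,) = torch._foreach_lerp([grad], [buf], beta)
assert torch.allclose(buf, m_t)
assert torch.allclose(update, (1 - beta) * grad + beta * m_t)
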
+            # === Step 3.4. Bucket by shape/device/dtype for batched NS ===
             buckets: dict[
                 tuple[int, int, torch.device, torch.dtype],
                 list[tuple[dict[str, Any], torch.Tensor]],
@@ -432,6 +597,7 @@ def step(
                     buckets[bucket_key] = []
                 buckets[bucket_key].append((entry, muon_updates[idx]))

+            # === Step 3.5. Newton-Schulz orthogonalization and update ===
             for (rows, cols, _device, dtype), bucket_entries in buckets.items():
                 # scale = coeff * sqrt(max(m, n))  [match-RMS mode]
                 # scale = sqrt(max(1, m/n))        [rectangular mode]
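
# Worked numbers (illustrative, not part of the diff) comparing the two scale
# modes above for a hypothetical 4096 x 1024 weight with coeff = 0.2.
import math

m, n = 4096, 1024
match_rms = 0.2 * math.sqrt(max(m, n))    # 0.2 * 64 = 12.8
rectangular = math.sqrt(max(1.0, m / n))  # sqrt(4) = 2.0
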