@@ -408,9 +408,8 @@ def __init__(self) -> None:
408408 (float (a ), float (b ), float (c )) for a , b , c in POLAR_EXPRESS_COEFFICIENTS
409409 )
410410 self ._restart_iteration_set = frozenset ((2 ,))
411- self ._call_impl = self ._orthogonalize_impl
412411 self ._compiled_call = torch .compile (
413- self ._call_impl ,
412+ self ._orthogonalize_impl ,
414413 fullgraph = True ,
415414 dynamic = True ,
416415 )
@@ -561,85 +560,6 @@ def _reshape_update_to_matrix_batch(
561560 return update_tensor .reshape (batch_size , rows , cols ).contiguous ()
562561
563562
def _stack_bucket_updates(
    bucket_entries: list[
        tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]
    ],
    batch_size: int,
    rows: int,
    cols: int,
) -> torch.Tensor:
    """
    Stack same-shape Muon updates into one tensor for orthogonalization.

    Parameters
    ----------
    bucket_entries : list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]]
        Bucket entries as ``(entry, update_tensor, grad, momentum_buffer)``.
    batch_size : int
        Number of matrix slices per parameter in the bucket.
    rows : int
        Matrix row count for each slice.
    cols : int
        Matrix column count for each slice.

    Returns
    -------
    torch.Tensor
        Tensor with shape ``(num_entries, batch_size, rows, cols)``.
    """
    # Only the update tensor (second tuple element) participates; the entry
    # dict, gradient, and momentum buffer ride along untouched.
    reshaped_batches = []
    for bucket_entry in bucket_entries:
        update_tensor = bucket_entry[1]
        reshaped_batches.append(
            _reshape_update_to_matrix_batch(update_tensor, batch_size, rows, cols)
        )
    return torch.stack(reshaped_batches, dim=0)
597-
def _orthogonalize_standard_stacked(
    stacked_updates: torch.Tensor,
    rows: int,
    cols: int,
    use_flash: bool,
    flash_buffers: tuple[torch.Tensor, torch.Tensor] | None = None,
) -> torch.Tensor:
    """
    Orthogonalize stacked updates with the current standard Newton-Schulz path.

    Parameters
    ----------
    stacked_updates : torch.Tensor
        Tensor with shape ``(num_entries, batch_size, rows, cols)``.
    rows : int
        Matrix row count.
    cols : int
        Matrix column count.
    use_flash : bool
        Whether to use the Triton single-matrix path when ``batch_size == 1``.
    flash_buffers : tuple[torch.Tensor, torch.Tensor] | None, optional
        Pre-allocated flash buffers when ``use_flash`` is enabled.

    Returns
    -------
    torch.Tensor
        Orthogonalized tensor with the same shape as ``stacked_updates``.
    """
    num_entries, batch_size, _, _ = stacked_updates.shape

    # Flash path only supports one matrix per entry; batched inputs take the
    # flattened bmm path below.
    if not (use_flash and batch_size == 1):
        flat_updates = stacked_updates.reshape(num_entries * batch_size, rows, cols)
        result = _batched_newton_schulz_orth(flat_updates)
        return result.reshape(num_entries, batch_size, rows, cols)

    if flash_buffers is None:
        raise ValueError("Flash buffers are required when use_flash=True.")
    buf1, buf2 = flash_buffers
    per_entry = [
        _flash_newton_schulz_orth(stacked_updates[idx, 0], buf1, buf2)
        for idx in range(num_entries)
    ]
    # Restore the (num_entries, 1, rows, cols) layout the caller expects.
    return torch.stack(per_entry, dim=0).unsqueeze(1)
642-
643563def _compute_muon_nesterov_updates (
644564 gradients : list [torch .Tensor ],
645565 momentum_buffers : list [torch .Tensor ],
@@ -958,23 +878,25 @@ def __init__(
958878 ] = {}
959879 self ._gram_orthogonalizer : _GramNewtonSchulzOrthogonalizer | None = None
960880
961- # === Step 5. Foreach acceleration (disabled under FSDP2 / DTensor) ===
881+ # === Step 5. Foreach acceleration ===
882+ # Defaults to True for single-GPU / DDP / ZeRO-1 (plain tensors). Callers
883+ # that train under FSDP2 (``fully_shard``) should pass
884+ # ``use_foreach=False`` explicitly because several ``torch._foreach_*``
885+ # ops lack DTensor sharding propagation on older PyTorch builds.
962886 self ._use_foreach = self ._resolve_foreach (use_foreach )
963887
964888 @staticmethod
965889 def _resolve_foreach (use_foreach : bool | None ) -> bool :
966- """Decide whether to use ``torch._foreach_*`` multi-tensor kernels.
890+ """Resolve the ``use_foreach`` flag for ``torch._foreach_*`` kernels.
967891
968892 Foreach fuses per-parameter loops into single kernel launches,
969- eliminating Python overhead. Disabled when parameters are FSDP2
970- ``DTensor`` because several foreach ops lack DTensor sharding
971- propagation rules in older PyTorch builds.
893+ eliminating Python overhead. When ``use_foreach`` is ``None`` the
894+ default is ``True`` because plain ``torch.Tensor`` (single-GPU, DDP,
895+ ZeRO-1) always supports these ops; callers that hit DTensor dispatch
896+ errors under FSDP2 must pass ``use_foreach=False`` explicitly.
972897 """
973898 if use_foreach is not None :
974899 return bool (use_foreach )
975- # Conservative default: enable foreach (safe for DDP / ZeRO-1).
976- # FSDP2 users should pass use_foreach=False explicitly if they
977- # encounter DTensor dispatch errors.
978900 return True
979901
980902 def _compute_magma_scales_merged (
@@ -1392,7 +1314,6 @@ def _process_merged_gram_buckets(
13921314
13931315 for (rows , cols , dev , dt ), bucket_entries in gram_buckets .items ():
13941316 min_dim = min (rows , cols )
1395- max_dim = max (rows , cols )
13961317 transposed = rows > cols
13971318 sb_key = (min_dim , dev , dt )
13981319 if sb_key not in super_buckets :
@@ -1401,7 +1322,7 @@ def _process_merged_gram_buckets(
14011322
14021323 gram_orth = self ._get_gram_orthogonalizer ()
14031324
1404- for (min_dim , _dev , _dt ), sub_list in super_buckets .items ():
1325+ for (_min_dim , _dev , _dt ), sub_list in super_buckets .items ():
14051326 # Find the maximum large-dimension across all sub-buckets.
14061327 padded_max_dim = max (max (r , c ) for r , c , _ , _ in sub_list )
14071328
0 commit comments