7171 https://github.com/MoonshotAI/Moonlight
7272.. [4] Flash-Muon: Triton-accelerated symmetric matmul for Newton-Schulz.
7373 https://github.com/lintianyang/flash-muon (MIT License, Tianyang Lin)
74+ .. [5] Magma: Momentum-Aligned Gradient Masking for Stable Optimizer Updates.
75+ arXiv:2602.15322, 2026.
76+ https://arxiv.org/abs/2602.15322
77+ Implements block-wise momentum-gradient alignment scoring with EMA smoothing
78+ and soft scaling for improved stability under heavy-tailed gradient noise.
79+ HybridMuon uses a stabilized variant (Magma-lite) with sigmoid range stretching
80+ and continuous soft scaling [0.1, 1.0] instead of Bernoulli masking, optimized
81+ for MLIP force-field training.
7482"""
7583
7684from __future__ import (
# Below this threshold, triton kernel launch overhead dominates over compute,
# and cuBLAS (via torch.mm/addmm) is faster for small matrices.
FLASH_MIN_DIM: int = 1024

# Magma-lite constants (Muon path update damping only).
MAGMA_TAU: float = 2.0          # temperature of the cosine -> sigmoid mapping
MAGMA_EMA_DECAY: float = 0.9    # EMA smoothing factor for alignment scores
MAGMA_MIN_SCALE: float = 0.1    # floor of the damping range [0.1, 1.0]
MAGMA_EPS: float = 1e-12        # guards the cosine denominator against zero norms
# Endpoints of sigmoid(x / MAGMA_TAU) over x in [-1, 1]: sigmoid(-1/tau) and
# sigmoid(+1/tau). Used to stretch the narrow raw sigmoid band back to [0, 1].
MAGMA_SIGMOID_MIN: float = 1.0 / (1.0 + math.exp(1.0 / MAGMA_TAU))
MAGMA_SIGMOID_MAX: float = 1.0 / (1.0 + math.exp(-1.0 / MAGMA_TAU))
125140
126141
127142# ============================================================================
@@ -554,6 +569,11 @@ class HybridMuonOptimizer(Optimizer):
554569 Requires triton and CUDA. Falls back to PyTorch implementation
555570 when triton is unavailable or running on CPU.
556571 Default is True.
572+ magma_muon : bool
573+ Whether to enable Magma-lite damping on Muon updates. Default is False.
574+ This computes momentum-gradient cosine alignment per Muon block,
575+ applies EMA smoothing, and rescales Muon updates in [0.1, 1.0].
576+ Adam/AdamW paths are unchanged.
557577
558578 Examples
559579 --------
@@ -576,6 +596,7 @@ def __init__(
576596 muon_mode : str = "slice" ,
577597 named_parameters : Iterable [tuple [str , torch .Tensor ]] | None = None ,
578598 flash_muon : bool = True ,
599+ magma_muon : bool = False ,
579600 ) -> None :
580601 # === Step 1. Validate routing mode ===
581602 muon_mode = str (muon_mode ).lower ()
@@ -591,6 +612,7 @@ def __init__(
591612 "lr_adjust" : lr_adjust ,
592613 "lr_adjust_coeff" : lr_adjust_coeff ,
593614 "muon_mode" : muon_mode ,
615+ "magma_muon" : bool (magma_muon ),
594616 }
595617 super ().__init__ (params , defaults )
596618
@@ -612,6 +634,226 @@ def __init__(
612634 tuple [torch .Tensor , torch .Tensor ],
613635 ] = {}
614636
637+ def _compute_magma_scale (
638+ self ,
639+ param : torch .Tensor ,
640+ grad : torch .Tensor ,
641+ momentum_buffer : torch .Tensor ,
642+ batch_size : int ,
643+ rows : int ,
644+ cols : int ,
645+ ) -> torch .Tensor :
646+ """
647+ Compute Magma-lite Muon damping scales from momentum-gradient alignment.
648+
649+ Implements a stabilized version of Magma (Momentum-Aligned Gradient Masking)
650+ adapted for MLIP force-field training. Computes block-wise alignment scores
651+ between Muon momentum and current gradients, applies EMA smoothing, and
652+ rescales Muon updates to improve stability under heavy-tailed gradient noise.
653+
654+ Notes
655+ -----
656+ For each Muon block b:
657+
658+ 1. Compute cosine similarity between momentum and gradient:
659+
660+ cos(b) = <μ_t^(b), g_t^(b)> / (||μ_t^(b)|| * ||g_t^(b)||)
661+
662+ 2. Apply sigmoid with range stretching to [0, 1]:
663+
664+ s_raw^(b) = (sigmoid(cos(b) / τ) - s_min) / (s_max - s_min)
665+
666+ where τ=2.0, s_min=sigmoid(-1/τ), s_max=sigmoid(1/τ).
667+ This stretches the narrow sigmoid range [0.38, 0.62] to [0, 1].
668+
669+ 3. Apply EMA smoothing:
670+
671+ s̃_t^(b) = a * s̃_{t-1}^(b) + (1-a) * s_raw^(b)
672+
673+ where a=0.9 (MAGMA_EMA_DECAY).
674+
675+ 4. Map to damping scale in [s_min_scale, 1.0]:
676+
677+ scale^(b) = s_min_scale + (1 - s_min_scale) * s̃_t^(b)
678+
679+ where s_min_scale=0.1 (MAGMA_MIN_SCALE).
680+
681+ 5. Apply damping to Muon update:
682+
683+ Δ̃^(b) = scale^(b) * Δ^(b) (soft scaling, no Bernoulli masking)
684+
685+ Key differences from the original Magma paper:
686+
687+ - Sigmoid range stretching: Paper uses raw sigmoid with narrow range [0.38, 0.62].
688+ We stretch to [0, 1] for better discrimination between aligned/misaligned blocks.
689+ - Soft scaling: Paper uses Bernoulli masking (50% skip probability).
690+ We use continuous soft scaling [0.1, 1.0] for stability in MLIP training.
691+ - Minimum scale: Paper allows scale=0 (complete skip).
692+ We enforce scale >= 0.1 to guarantee minimum learning rate.
693+
694+ Parameters
695+ ----------
696+ param : torch.Tensor
697+ Parameter updated by Muon.
698+ grad : torch.Tensor
699+ Current gradient tensor with shape compatible with ``(batch_size, rows, cols)``.
700+ momentum_buffer : torch.Tensor
701+ Muon momentum buffer (updated m_t) with same shape as ``grad``.
702+ batch_size : int
703+ Number of Muon blocks (1 for 2d/flat mode, >1 for slice mode).
704+ rows : int
705+ Matrix row count per block.
706+ cols : int
707+ Matrix column count per block.
708+
709+ Returns
710+ -------
711+ torch.Tensor
712+ Damping scales with shape (batch_size,) in [MAGMA_MIN_SCALE, 1.0].
713+ """
714+ # === Step 1. Restore or initialize EMA score state ===
715+ state = self .state [param ]
716+ magma_score = state .get ("magma_score" )
717+ if (
718+ magma_score is None
719+ or magma_score .ndim != 1
720+ or magma_score .numel () != batch_size
721+ or magma_score .device != param .device
722+ ):
723+ magma_score = torch .full (
724+ (batch_size ,),
725+ 0.5 ,
726+ dtype = torch .float32 ,
727+ device = param .device ,
728+ )
729+ else :
730+ magma_score = magma_score .to (dtype = torch .float32 , device = param .device )
731+
732+ # === Step 2. Build matrix-view for block-wise cosine ===
733+ grad_view = grad .reshape (batch_size , rows , cols ).reshape (batch_size , - 1 )
734+ momentum_view = momentum_buffer .reshape (batch_size , rows , cols ).reshape (
735+ batch_size , - 1
736+ )
737+ grad_view = grad_view .to (dtype = torch .float32 )
738+ momentum_view = momentum_view .to (dtype = torch .float32 )
739+
740+ # === Step 3. Compute cosine alignment with numerical protection ===
741+ dot = (momentum_view * grad_view ).sum (dim = 1 )
742+ denom = (momentum_view .norm (dim = 1 ) * grad_view .norm (dim = 1 )).clamp (min = MAGMA_EPS )
743+ cosine = (dot / denom ).clamp (min = - 1.0 , max = 1.0 )
744+
745+ # === Step 4. Sigmoid mapping + range stretching to [0, 1] ===
746+ raw_sigmoid = torch .sigmoid (cosine / MAGMA_TAU )
747+ raw_score = (raw_sigmoid - MAGMA_SIGMOID_MIN ) / (
748+ MAGMA_SIGMOID_MAX - MAGMA_SIGMOID_MIN
749+ )
750+ raw_score = raw_score .clamp (min = 0.0 , max = 1.0 )
751+
752+ # === Step 5. Update EMA score and convert to damping scale ===
753+ magma_score = (
754+ MAGMA_EMA_DECAY * magma_score + (1.0 - MAGMA_EMA_DECAY ) * raw_score
755+ )
756+ state ["magma_score" ] = magma_score
757+ return MAGMA_MIN_SCALE + (1.0 - MAGMA_MIN_SCALE ) * magma_score
758+
759+ def _compute_magma_scales_for_bucket (
760+ self ,
761+ bucket_entries : list [
762+ tuple [dict [str , Any ], torch .Tensor , torch .Tensor , torch .Tensor ]
763+ ],
764+ batch_size : int ,
765+ rows : int ,
766+ cols : int ,
767+ ) -> list [torch .Tensor ]:
768+ """
769+ Compute Magma-lite damping scales for one Muon bucket in a batched way.
770+
771+ Parameters
772+ ----------
773+ bucket_entries : list[tuple[dict[str, Any], torch.Tensor, torch.Tensor, torch.Tensor]]
774+ Bucket entries as ``(entry, update_tensor, grad, momentum_buffer)``.
775+ batch_size : int
776+ Number of Muon blocks per parameter in this bucket.
777+ rows : int
778+ Matrix row count for this bucket.
779+ cols : int
780+ Matrix column count for this bucket.
781+
782+ Returns
783+ -------
784+ list[torch.Tensor]
785+ Magma scales for each bucket entry. Each tensor has shape (batch_size,).
786+ """
787+ # === Step 0. Fast path for single-entry bucket ===
788+ if len (bucket_entries ) == 1 :
789+ entry , _update_tensor , grad , momentum_buffer = bucket_entries [0 ]
790+ return [
791+ self ._compute_magma_scale (
792+ param = entry ["param" ],
793+ grad = grad ,
794+ momentum_buffer = momentum_buffer ,
795+ batch_size = batch_size ,
796+ rows = rows ,
797+ cols = cols ,
798+ )
799+ ]
800+
801+ # === Step 1. Build batched matrix views ===
802+ grad_views : list [torch .Tensor ] = []
803+ momentum_views : list [torch .Tensor ] = []
804+ for _ , _ , grad , momentum_buffer in bucket_entries :
805+ grad_view = grad .reshape (batch_size , rows , cols ).reshape (batch_size , - 1 )
806+ momentum_view = momentum_buffer .reshape (batch_size , rows , cols ).reshape (
807+ batch_size , - 1
808+ )
809+ grad_views .append (grad_view .to (dtype = torch .float32 ))
810+ momentum_views .append (momentum_view .to (dtype = torch .float32 ))
811+
812+ grad_batch = torch .stack (grad_views , dim = 0 )
813+ momentum_batch = torch .stack (momentum_views , dim = 0 )
814+
815+ # === Step 2. Compute cosine alignment for all entries ===
816+ dot = (momentum_batch * grad_batch ).sum (dim = 2 )
817+ denom = (momentum_batch .norm (dim = 2 ) * grad_batch .norm (dim = 2 )).clamp (
818+ min = MAGMA_EPS
819+ )
820+ cosine = (dot / denom ).clamp (min = - 1.0 , max = 1.0 )
821+ raw_sigmoid = torch .sigmoid (cosine / MAGMA_TAU )
822+ raw_scores = (raw_sigmoid - MAGMA_SIGMOID_MIN ) / (
823+ MAGMA_SIGMOID_MAX - MAGMA_SIGMOID_MIN
824+ )
825+ raw_scores = raw_scores .clamp (min = 0.0 , max = 1.0 )
826+
827+ # === Step 3. Update per-parameter EMA score state ===
828+ scales : list [torch .Tensor ] = []
829+ for idx , (entry , _ , _ , _ ) in enumerate (bucket_entries ):
830+ param = entry ["param" ]
831+ state = self .state [param ]
832+ magma_score = state .get ("magma_score" )
833+ if (
834+ magma_score is None
835+ or magma_score .ndim != 1
836+ or magma_score .numel () != batch_size
837+ or magma_score .device != param .device
838+ ):
839+ magma_score = torch .full (
840+ (batch_size ,),
841+ 0.5 ,
842+ dtype = torch .float32 ,
843+ device = param .device ,
844+ )
845+ state ["magma_score" ] = magma_score
846+ elif magma_score .dtype != torch .float32 :
847+ magma_score = magma_score .to (dtype = torch .float32 , device = param .device )
848+ state ["magma_score" ] = magma_score
849+
850+ magma_score .mul_ (MAGMA_EMA_DECAY ).add_ (
851+ raw_scores [idx ], alpha = (1.0 - MAGMA_EMA_DECAY )
852+ )
853+ scales .append (MAGMA_MIN_SCALE + (1.0 - MAGMA_MIN_SCALE ) * magma_score )
854+
855+ return scales
856+
615857 def _get_ns_buffers (
616858 self ,
617859 M : int ,
@@ -742,6 +984,7 @@ def step(
742984 adam_betas = group ["adam_betas" ]
743985 lr_adjust = group ["lr_adjust" ]
744986 lr_adjust_coeff = group ["lr_adjust_coeff" ]
987+ magma_muon = bool (group .get ("magma_muon" , False ))
745988
746989 # === Step 1. Adam update for non-decay Adam path ===
747990 # === Step 1.1. Collect gradients and initialize state ===
@@ -836,7 +1079,7 @@ def step(
8361079 # AdamW decay for >=2D Adam path.
8371080 if weight_decay > 0 :
8381081 for p in adam_decay_params :
839- p .mul_ (1.0 - lr * weight_decay )
1082+ p .mul_ (1.0 - adam_lr * weight_decay )
8401083
8411084 # exp_avg = beta1 * exp_avg + (1 - beta1) * grad
8421085 # exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad^2
@@ -904,7 +1147,7 @@ def step(
9041147 # === Step 3.4. Bucket by (batch_size, rows, cols, device, dtype) ===
9051148 buckets : dict [
9061149 tuple [int , int , int , torch .device , torch .dtype ],
907- list [tuple [dict [str , Any ], torch .Tensor ]],
1150+ list [tuple [dict [str , Any ], torch .Tensor , torch . Tensor , torch . Tensor ]],
9081151 ] = {}
9091152
9101153 for idx , entry_info in enumerate (active_entries ):
@@ -919,7 +1162,14 @@ def step(
9191162 )
9201163 if bucket_key not in buckets :
9211164 buckets [bucket_key ] = []
922- buckets [bucket_key ].append ((entry , muon_updates [idx ]))
1165+ buckets [bucket_key ].append (
1166+ (
1167+ entry ,
1168+ muon_updates [idx ],
1169+ muon_grads [idx ],
1170+ muon_momentum_buffers [idx ],
1171+ )
1172+ )
9231173
9241174 # === Step 3.5. Newton-Schulz orthogonalization and update ===
9251175 for (batch_size , rows , cols , _device , _ ), bucket_entries in buckets .items ():
@@ -944,24 +1194,57 @@ def step(
9441194 if use_flash :
9451195 buf1 , buf2 = self ._get_ns_buffers (M , _device )
9461196
1197+ if magma_muon :
1198+ bucket_magma_scales = self ._compute_magma_scales_for_bucket (
1199+ bucket_entries = bucket_entries ,
1200+ batch_size = batch_size ,
1201+ rows = rows ,
1202+ cols = cols ,
1203+ )
1204+ else :
1205+ bucket_magma_scales = [None ] * len (bucket_entries )
1206+
9471207 # Process each entry individually with Newton-Schulz orth.
9481208 # Compatible with sharding propagation under FSDP2.
949- for entry , update_tensor in bucket_entries :
1209+ for (entry , update_tensor , _grad , _buffer ), magma_scale in zip (
1210+ bucket_entries , bucket_magma_scales , strict = True
1211+ ):
9501212 if batch_size > 1 :
951- update_batch = update_tensor .reshape (batch_size , rows , cols )
952- if not update_batch .is_contiguous ():
953- update_batch = update_batch .contiguous ()
1213+ if update_tensor .is_contiguous ():
1214+ update_batch = update_tensor .view (batch_size , rows , cols )
1215+ else :
1216+ update_batch = update_tensor .reshape (
1217+ batch_size , rows , cols
1218+ ).contiguous ()
9541219 orth = _batched_newton_schulz_orth (update_batch )
9551220 else :
956- update_matrix = update_tensor .reshape (rows , cols )
957- if not update_matrix .is_contiguous ():
958- update_matrix = update_matrix .contiguous ()
1221+ if update_tensor .is_contiguous ():
1222+ update_matrix = update_tensor .view (rows , cols )
1223+ else :
1224+ update_matrix = update_tensor .reshape (
1225+ rows , cols
1226+ ).contiguous ()
9591227 if use_flash :
9601228 orth = _flash_newton_schulz_orth (update_matrix , buf1 , buf2 )
9611229 else :
9621230 orth = _newton_schulz_orth (update_matrix )
9631231 orth .mul_ (scale )
964- delta = orth .reshape (entry ["param" ].shape )
1232+ if batch_size > 1 :
1233+ orth_view = orth .reshape (batch_size , rows , cols )
1234+ if magma_scale is not None :
1235+ orth_view .mul_ (
1236+ magma_scale .view (batch_size , 1 , 1 ).to (
1237+ dtype = orth .dtype ,
1238+ device = orth .device ,
1239+ )
1240+ )
1241+ delta = orth_view .reshape (entry ["param" ].shape )
1242+ else :
1243+ if magma_scale is not None :
1244+ orth .mul_ (
1245+ magma_scale [0 ].to (dtype = orth .dtype , device = orth .device )
1246+ )
1247+ delta = orth .reshape (entry ["param" ].shape )
9651248 entry ["param" ].add_ (delta , alpha = - lr )
9661249
9671250 return loss
0 commit comments