33
44import torch
55
6- from deepmd .pt .optimizer .muon import (
7- MuonOptimizer ,
6+ from deepmd .pt .optimizer .hybrid_muon import (
7+ HybridMuonOptimizer ,
88 zeropower_via_newtonschulz5 ,
99)
1010from deepmd .pt .utils import (
@@ -82,8 +82,8 @@ def test_invalid_input(self) -> None:
8282
8383
8484@unittest .skipIf (not BF16_SUPPORTED , "bf16 matmul not supported on this device" )
85- class TestMuonOptimizer (unittest .TestCase ):
86- """Test MuonOptimizer class."""
85+ class TestHybridMuonOptimizer (unittest .TestCase ):
86+ """Test HybridMuonOptimizer class."""
8787
8888 def setUp (self ) -> None :
8989 self .device = env .DEVICE
@@ -96,7 +96,7 @@ def test_step(self) -> None:
9696 torch .nn .ReLU (),
9797 torch .nn .Linear (20 , 5 , device = self .device ),
9898 )
99- optimizer = MuonOptimizer (model .parameters (), lr = 0.02 )
99+ optimizer = HybridMuonOptimizer (model .parameters (), lr = 0.02 )
100100
101101 x = torch .randn (4 , 10 , device = self .device )
102102 model (x ).sum ().backward ()
@@ -111,7 +111,7 @@ def test_weight_decay(self) -> None:
111111 """Test weight decay reduces parameter norm."""
112112 torch .manual_seed (42 )
113113 model = torch .nn .Linear (10 , 10 , device = self .device )
114- optimizer = MuonOptimizer (model .parameters (), lr = 0.02 , weight_decay = 0.1 )
114+ optimizer = HybridMuonOptimizer (model .parameters (), lr = 0.02 , weight_decay = 0.1 )
115115
116116 initial_norm = model .weight .norm ().item ()
117117 for _ in range (10 ):
@@ -126,7 +126,7 @@ def test_muon_adam_separation(self) -> None:
126126 """Test Muon for 2D params, Adam for 1D params."""
127127 torch .manual_seed (42 )
128128 model = torch .nn .Linear (10 , 10 , device = self .device )
129- optimizer = MuonOptimizer (model .parameters (), lr = 0.02 )
129+ optimizer = HybridMuonOptimizer (model .parameters (), lr = 0.02 )
130130
131131 x = torch .randn (4 , 10 , device = self .device )
132132 model (x ).sum ().backward ()
@@ -145,7 +145,7 @@ def test_muon_adam_fallback_small_2d(self) -> None:
145145 torch .manual_seed (42 )
146146 linear_small = torch .nn .Linear (10 , 1 , bias = False , device = self .device )
147147 linear_large = torch .nn .Linear (10 , 10 , bias = False , device = self .device )
148- optimizer = MuonOptimizer (
148+ optimizer = HybridMuonOptimizer (
149149 list (linear_small .parameters ()) + list (linear_large .parameters ()),
150150 lr = 0.02 ,
151151 min_2d_dim = 2 ,
@@ -172,8 +172,8 @@ def test_lr_adjust_modes(self) -> None:
172172 model2 = torch .nn .Linear (10 , 20 , bias = False , device = self .device )
173173 model2 .load_state_dict (model1 .state_dict ())
174174
175- opt1 = MuonOptimizer (model1 .parameters (), lr = 0.02 , lr_adjust = 0.0 )
176- opt2 = MuonOptimizer (model2 .parameters (), lr = 0.02 , lr_adjust = 10.0 )
175+ opt1 = HybridMuonOptimizer (model1 .parameters (), lr = 0.02 , lr_adjust = 0.0 )
176+ opt2 = HybridMuonOptimizer (model2 .parameters (), lr = 0.02 , lr_adjust = 10.0 )
177177
178178 x = torch .randn (4 , 10 , device = self .device )
179179
@@ -192,7 +192,7 @@ def test_lr_adjust_modes(self) -> None:
192192
193193
194194@unittest .skipIf (not BF16_SUPPORTED , "bf16 matmul not supported on this device" )
195- class TestMuonOptimizerStateDict (unittest .TestCase ):
195+ class TestHybridMuonOptimizerStateDict (unittest .TestCase ):
196196 """Test optimizer state dict save/load."""
197197
198198 def setUp (self ) -> None :
@@ -202,7 +202,7 @@ def test_state_dict_save_load(self) -> None:
202202 """Test saving and loading optimizer state."""
203203 torch .manual_seed (42 )
204204 model = torch .nn .Linear (10 , 10 , device = self .device )
205- optimizer = MuonOptimizer (model .parameters (), lr = 0.02 )
205+ optimizer = HybridMuonOptimizer (model .parameters (), lr = 0.02 )
206206
207207 for _ in range (3 ):
208208 optimizer .zero_grad ()
@@ -212,7 +212,7 @@ def test_state_dict_save_load(self) -> None:
212212
213213 state_dict = optimizer .state_dict ()
214214
215- optimizer2 = MuonOptimizer (model .parameters (), lr = 0.02 )
215+ optimizer2 = HybridMuonOptimizer (model .parameters (), lr = 0.02 )
216216 optimizer2 .load_state_dict (state_dict )
217217
218218 # Verify state matches by param id, not iteration order
0 commit comments