@@ -26,7 +26,10 @@ def test_muon_sharding_optimizer(self):
2626 Test logic is in hybrid_parallel_sharding_muon_model.py,
2727 iterating 4 ns_coeff_types. fp32 matmul is auto-selected on V100.
2828 """
29-        self.run_mnist_2accelerators('hybrid_parallel_sharding_muon_model.py')
29+        self.run_mnist_2accelerators(
30+            'hybrid_parallel_sharding_muon_model.py',
31+            need_envs={"MULTI_PRECISION": "1"},
32+        )
3033
3134    def test_muon_sharding_fused_gradient(self):
3235 """MuonSharding test with FLAGS_shard_fused_gradient=1.
@@ -36,7 +39,10 @@ def test_muon_sharding_fused_gradient(self):
3639 """
3740        self.run_mnist_2accelerators(
3841            'hybrid_parallel_sharding_muon_model.py',
39-            need_envs={"FLAGS_shard_fused_gradient": "1"},
42+            need_envs={
43+                "FLAGS_shard_fused_gradient": "1",
44+                "MULTI_PRECISION": "1",
45+            },
4046        )
4147
4248    def test_muon_sharding_fuse_optimizer_states(self):
@@ -46,7 +52,10 @@ def test_muon_sharding_fuse_optimizer_states(self):
4652 """
4753        self.run_mnist_2accelerators(
4854            'hybrid_parallel_sharding_muon_model.py',
49-            need_envs={"ENABLE_FUSE_OPTIMIZER_STATES": "1"},
55+            need_envs={
56+                "ENABLE_FUSE_OPTIMIZER_STATES": "1",
57+                "MULTI_PRECISION": "1",
58+            },
5059        )
5160
5261    def test_muon_sharding_release_grads_fused(self):
@@ -60,6 +69,7 @@ def test_muon_sharding_release_grads_fused(self):
6069            need_envs={
6170                "FLAGS_shard_fused_gradient": "1",
6271                "RELEASE_GRADIENTS": "1",
72+                "MULTI_PRECISION": "1",
6373            },
6474        )
6575
0 commit comments