Skip to content

Commit b386ea8

Browse files
committed
Update muon_sharding_optimizer with rebuilding of 2d_params.
1 parent 5e989c9 commit b386ea8

2 files changed

Lines changed: 17 additions & 3 deletions

File tree

python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,10 @@ def _build_1d_comm_buffers(self):
521521
self._comm_buffer_list.sort(key=lambda x: x._dst)
522522

523523
def clear_param_storage(self, color):
524+
assert self._multi_precision, (
525+
"Muon Sharding Optimizer only support clear param with multi_precision mode"
526+
)
527+
524528
self.clear_color.add(color)
525529
# 1D params
526530
if color in self._color_to_comm_buffer_list.keys():

test/collective/fleet/test_parallel_dygraph_muon.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ def test_muon_sharding_optimizer(self):
2626
Test logic is in hybrid_parallel_sharding_muon_model.py,
2727
iterating 4 ns_coeff_types. fp32 matmul is auto-selected on V100.
2828
"""
29-
self.run_mnist_2accelerators('hybrid_parallel_sharding_muon_model.py')
29+
self.run_mnist_2accelerators(
30+
'hybrid_parallel_sharding_muon_model.py',
31+
need_envs={"MULTI_PRECISION": "1"},
32+
)
3033

3134
def test_muon_sharding_fused_gradient(self):
3235
"""MuonSharding test with FLAGS_shard_fused_gradient=1.
@@ -36,7 +39,10 @@ def test_muon_sharding_fused_gradient(self):
3639
"""
3740
self.run_mnist_2accelerators(
3841
'hybrid_parallel_sharding_muon_model.py',
39-
need_envs={"FLAGS_shard_fused_gradient": "1"},
42+
need_envs={
43+
"FLAGS_shard_fused_gradient": "1",
44+
"MULTI_PRECISION": "1",
45+
},
4046
)
4147

4248
def test_muon_sharding_fuse_optimizer_states(self):
@@ -46,7 +52,10 @@ def test_muon_sharding_fuse_optimizer_states(self):
4652
"""
4753
self.run_mnist_2accelerators(
4854
'hybrid_parallel_sharding_muon_model.py',
49-
need_envs={"ENABLE_FUSE_OPTIMIZER_STATES": "1"},
55+
need_envs={
56+
"ENABLE_FUSE_OPTIMIZER_STATES": "1",
57+
"MULTI_PRECISION": "1",
58+
},
5059
)
5160

5261
def test_muon_sharding_release_grads_fused(self):
@@ -60,6 +69,7 @@ def test_muon_sharding_release_grads_fused(self):
6069
need_envs={
6170
"FLAGS_shard_fused_gradient": "1",
6271
"RELEASE_GRADIENTS": "1",
72+
"MULTI_PRECISION": "1",
6373
},
6474
)
6575

0 commit comments

Comments
 (0)