Skip to content

Commit 3d3826f

Browse files
authored
Update muon_sharding_optimizer by rebuilding 2d_params. (#78814)
1 parent 7f0b4b5 commit 3d3826f

4 files changed

Lines changed: 77 additions & 90 deletions

File tree

python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py

Lines changed: 56 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@
5858
g_shard_bypass_dygraph_optimizer = int(
5959
os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)
6060
)
61-
g_shard_fused_gradient = int(os.environ.get("FLAGS_shard_fused_gradient", 0))
6261

6362

6463
def _is_trainable(param):
@@ -133,6 +132,9 @@ def __init__(self, optimizer, hcg=None):
133132

134133
pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap
135134
self.pp_overlap = pp_overlap
135+
assert not self.pp_overlap, (
136+
"muon_sharding_optimizer do not support PP overlap"
137+
)
136138

137139
self._use_main_grad = hasattr(optimizer._parameter_list[0], "main_grad")
138140

@@ -146,12 +148,6 @@ def __init__(self, optimizer, hcg=None):
146148
self._parameter_list, sharding_group
147149
)
148150

149-
# Extract MoE group info from color_to_group_info for backward compatibility
150-
moe_info = self._color_to_group_info.get('moe_expert', {})
151-
self._moe_sharding_world_size = moe_info.get('world_size', 1)
152-
self._moe_sharding_rank = moe_info.get('rank', 0)
153-
self._moe_sharding_group = moe_info.get('group', None)
154-
155151
# Get muon_param_info_map from the inner Muon optimizer.
156152
# Each entry has use_muon=True/False, set by the Trainer before construction.
157153
self._muon_param_info_map = getattr(
@@ -239,29 +235,25 @@ def __init__(self, optimizer, hcg=None):
239235
key=lambda p: self._param2rank_2d_by_color[color_key][p.name]
240236
)
241237

242-
# ---- Backward compatibility: expose legacy attributes ----
243-
# These are kept for any external code that might reference them
244-
self._params_2d = self._params_2d_by_color.get(None, [])
245-
self._params_2d_moe = self._params_2d_by_color.get('moe_expert', [])
246-
self._rank2params_2d = self._rank2params_2d_by_color.get(None, {0: []})
247-
self._param2rank_2d = self._param2rank_2d_by_color.get(None, {})
248-
self._rank2params_2d_moe = self._rank2params_2d_by_color.get(
249-
'moe_expert', {0: []}
250-
)
251-
self._param2rank_2d_moe = self._param2rank_2d_by_color.get(
252-
'moe_expert', {}
253-
)
238+
# 2D params owned by this sharding rank
239+
self._local_2d = []
240+
for color_key, params_2d in self._params_2d_by_color.items():
241+
rank2params_2d_by_color = self._rank2params_2d_by_color[color_key]
242+
243+
group_info = self._color_to_group_info[color_key]
244+
sharding_rank = max(group_info['rank'], 0)
245+
246+
self._local_2d.extend(rank2params_2d_by_color[sharding_rank])
254247

255248
self.sd_release_grads = (
256249
strategy.hybrid_configs['pp_configs'].release_gradients
257250
or sharding_configs.release_gradients
258251
)
259-
self._use_fuse_gradients = g_shard_fused_gradient
252+
self._use_fuse_gradients = self.comm_buffer_size_MB > 0
260253
# ---- Build comm buffers for 2D params (V1-style) ----
261254
if self._use_fuse_gradients:
262-
if not hasattr(self, 'comm_buffer_2d'):
263-
self.comm_buffer_2d = self._build_2d_comm_buffers()
264-
self.comm_buffer_2d.sort(key=lambda x: x._dst)
255+
self.comm_buffer_2d = self._build_2d_comm_buffers()
256+
self.comm_buffer_2d.sort(key=lambda x: x._dst)
265257

266258
# ---- Step 3: Build comm buffers for 1D params (V2-style) ----
267259
self._slice_params = {}
@@ -278,15 +270,9 @@ def __init__(self, optimizer, hcg=None):
278270
# The optimizer should see:
279271
# - All 2D params assigned to this rank (all colors, as whole tensors)
280272
# - 1D slice_params for all non-2D params (element-wise shards)
281-
local_2d_params = []
282-
for color_key, rank2params in self._rank2params_2d_by_color.items():
283-
group_info = self._color_to_group_info.get(color_key, {})
284-
color_rank = group_info.get('rank', 0)
285-
world_size = group_info.get('world_size', 1)
286-
rank_key = color_rank if world_size > 1 else 0
287-
local_2d_params.extend(rank2params.get(rank_key, []))
288-
289-
local_opt_params = local_2d_params + list(self._local_parameter_list_1d)
273+
local_opt_params = list(self._local_2d) + list(
274+
self._local_parameter_list_1d
275+
)
290276

291277
self._set_inner_opt_attr('_parameter_list', local_opt_params)
292278
self._set_inner_opt_attr('_param_groups', local_opt_params)
@@ -306,18 +292,10 @@ def __init__(self, optimizer, hcg=None):
306292
_sg_group = hcg.get_sharding_parallel_group()
307293
_N = self._sharding_world_size
308294

309-
# 2D params owned by this sharding rank (default color, via legacy alias)
295+
# 2D params owned by this sharding rank
310296
_local_2d_numel = sum(
311297
int(functools_reduce(lambda x, y: x * y, p.shape, 1))
312-
for p in self._rank2params_2d.get(self._sharding_rank, [])
313-
)
314-
# 2D MoE-expert params owned by this rank (moe_expert color, via legacy alias)
315-
_moe_rank_key = (
316-
self._moe_sharding_rank if self._moe_sharding_world_size > 1 else 0
317-
)
318-
_local_2d_moe_numel = sum(
319-
int(functools_reduce(lambda x, y: x * y, p.shape, 1))
320-
for p in self._rank2params_2d_moe.get(_moe_rank_key, [])
298+
for p in self._local_2d
321299
)
322300
# 1D (AdamW) slice: each rank holds ceil(numel / sharding_world_size) elements.
323301
_local_1d_numel = sum(
@@ -327,9 +305,7 @@ def __init__(self, optimizer, hcg=None):
327305
for p in self._params_1d
328306
)
329307

330-
_local_total_numel = (
331-
_local_2d_numel + _local_2d_moe_numel + _local_1d_numel
332-
)
308+
_local_total_numel = _local_2d_numel + _local_1d_numel
333309
_local_total_MB = (
334310
_local_total_numel * 2 / (1024 * 1024)
335311
) # bf16/fp16 = 2 bytes
@@ -545,9 +521,15 @@ def _build_1d_comm_buffers(self):
545521
self._comm_buffer_list.sort(key=lambda x: x._dst)
546522

547523
def clear_param_storage(self, color):
524+
assert self._multi_precision, (
525+
"Muon Sharding Optimizer only support clear param with multi_precision mode"
526+
)
527+
548528
self.clear_color.add(color)
529+
# 1D params
549530
if color in self._color_to_comm_buffer_list.keys():
550531
for comm_buffer in self._color_to_comm_buffer_list[color]:
532+
has_clear = False
551533
for param in comm_buffer.params:
552534
grad_view = comm_buffer._sharding_param_grad_view[
553535
param.name
@@ -559,16 +541,36 @@ def clear_param_storage(self, color):
559541
):
560542
grad_view.fill_slice_param(slice_param)
561543
self._create_master_weight(slice_param)
562-
slice_param._clear_dataptr()
563-
comm_buffer._clear_param_storage()
544+
if param.name in self._master_weights:
545+
slice_param._clear_dataptr()
546+
has_clear = True
547+
548+
if has_clear:
549+
comm_buffer._clear_param_storage()
550+
# 2D params
551+
if color in self._params_2d_by_color.keys():
552+
for param in self._params_2d_by_color[color]:
553+
if not g_shard_bypass_dygraph_optimizer:
554+
self._create_master_weight(param)
555+
556+
if param.name in self._master_weights:
557+
param._clear_to_zero_allocation()
564558

565559
def reset_param_storage(self):
566560
for color in self.clear_color:
567561
if color is None:
568562
continue
563+
# 1D params
569564
if color in self._color_to_comm_buffer_list.keys():
570565
for comm_buffer in self._color_to_comm_buffer_list[color]:
571-
comm_buffer._reset_param_storage()
566+
if not comm_buffer.param_storage._is_initialized():
567+
comm_buffer._reset_param_storage()
568+
# 2D params
569+
if color in self._params_2d_by_color.keys():
570+
for param in self._params_2d_by_color[color]:
571+
if not param._is_initialized():
572+
new_param = paddle.empty_like(param)
573+
new_param._share_buffer_to(param)
572574

573575
# ------------------------------------------------------------------
574576
# Gradient communication
@@ -671,31 +673,13 @@ def reduce_gradients(self, parameter_list, hcg):
671673

672674
def filter_parameters(self, parameter_list, hcg):
673675
"""Filter parameters: return local 2D params + initialized 1D slices."""
674-
sharding_rank = hcg.get_sharding_parallel_rank()
675-
local_2d = [
676-
p
677-
for p in parameter_list
678-
if p.name in self._param2rank_2d
679-
and self._param2rank_2d[p.name] == sharding_rank
680-
]
681-
# Also include MoE 2D params owned by this rank
682-
if self._moe_sharding_world_size > 1:
683-
moe_rank = self._moe_sharding_rank
684-
else:
685-
moe_rank = 0
686-
local_2d_moe = [
687-
p
688-
for p in parameter_list
689-
if p.name in self._param2rank_2d_moe
690-
and self._param2rank_2d_moe[p.name] == moe_rank
691-
]
692676
local_1d = [
693677
self._slice_params[p.name]
694678
for p in parameter_list
695679
if p.name in self._slice_params
696680
]
697681
local_1d = [p for p in local_1d if p._is_initialized()]
698-
return local_2d + local_2d_moe + local_1d
682+
return self._local_2d + local_1d
699683

700684
# ------------------------------------------------------------------
701685
# Parameter sync after optimizer step
@@ -771,12 +755,12 @@ def clear_grad_func(p):
771755
for p in self._parameter_list:
772756
clear_grad_func(p)
773757

774-
# 1D params are managed by comm buffers
775758
if self.sd_release_grads and not self.pp_overlap:
759+
# 1D params are managed by comm buffers
776760
for comm_buffer in self._comm_buffer_list:
777761
if comm_buffer.need_reduce_scale_sync():
778762
comm_buffer._clear_grad_storage()
779-
763+
# 2D params are managed by comm buffers
780764
if self._use_fuse_gradients:
781765
for comm_buffer in self.comm_buffer_2d:
782766
if comm_buffer.need_reduce_scale_sync():
@@ -820,6 +804,8 @@ def _assign_slice_grad(self):
820804

821805
def step(self):
822806
"""Optimizer step: update local 2D params and 1D slices, then sync."""
807+
self.reset_param_storage()
808+
823809
self._collect_comm_buffers()
824810
self._assign_slice_grad()
825811

@@ -884,18 +870,7 @@ def step(self):
884870
def set_state_dict(self, state_dict):
885871
inner_state = {}
886872
# Collect local parameters: 2D whole-tensor params + 1D original params
887-
# (set_state_dict uses legacy aliases; covers default and moe_expert colors)
888-
local_2d = list(self._rank2params_2d.get(self._sharding_rank, []))
889-
if self._moe_sharding_world_size > 1:
890-
local_2d_moe = list(
891-
self._rank2params_2d_moe.get(self._moe_sharding_rank, [])
892-
)
893-
else:
894-
local_2d_moe = list(self._rank2params_2d_moe.get(0, []))
895-
parameters = local_2d + local_2d_moe
896-
# Add 1D params (use original param names for matching)
897-
for p in self._params_1d:
898-
parameters.append(p)
873+
parameters = list(self._local_2d) + list(self._params_1d)
899874

900875
if "LR_Scheduler" in state_dict:
901876
inner_state["LR_Scheduler"] = state_dict.pop("LR_Scheduler")

python/paddle/optimizer/muon.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,7 @@ def _adamw_update(
445445
):
446446
with_decay = False
447447

448-
find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
449-
param.dtype
450-
)
448+
find_master = param.name in self._master_weights
451449
master_weight = (
452450
self._master_weights[param.name] if find_master else None
453451
)
@@ -557,9 +555,7 @@ def ortho_fn(m):
557555
# Default: whole matrix orthogonalisation
558556
orthogonal_update = ortho_fn(matrix_2d_global)
559557

560-
find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
561-
param.dtype
562-
)
558+
find_master = param.name in self._master_weights
563559
master_weight = (
564560
self._master_weights[param.name] if find_master else None
565561
)

test/collective/fleet/hybrid_parallel_sharding_muon_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@
2727

2828
import paddle
2929
from paddle.distributed import fleet
30+
from paddle.distributed.fleet.meta_optimizers.muon_sharding_optimizer import (
31+
MuonShardingOptimizer,
32+
)
3033
from paddle.distributed.fleet.utils import mix_precision_utils
3134
from paddle.optimizer.muon import (
3235
MuonParamInfo,
@@ -240,6 +243,9 @@ def train_batch(self, batch, model, optimizer):
240243
output = model(batch)
241244
loss = output.mean()
242245
loss.backward()
246+
inner_opt = getattr(optimizer, '_inner_opt', optimizer)
247+
if isinstance(inner_opt, MuonShardingOptimizer):
248+
optimizer.clear_param_storage('test_color')
243249
optimizer.step()
244250
optimizer.clear_grad()
245251
return loss

test/collective/fleet/test_parallel_dygraph_muon.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@ def test_muon_sharding_optimizer(self):
2626
Test logic is in hybrid_parallel_sharding_muon_model.py,
2727
iterating 4 ns_coeff_types. fp32 matmul is auto-selected on V100.
2828
"""
29-
self.run_mnist_2accelerators('hybrid_parallel_sharding_muon_model.py')
29+
self.run_mnist_2accelerators(
30+
'hybrid_parallel_sharding_muon_model.py',
31+
need_envs={"MULTI_PRECISION": "1"},
32+
)
3033

3134
def test_muon_sharding_fused_gradient(self):
3235
"""MuonSharding test with FLAGS_shard_fused_gradient=1.
@@ -36,7 +39,10 @@ def test_muon_sharding_fused_gradient(self):
3639
"""
3740
self.run_mnist_2accelerators(
3841
'hybrid_parallel_sharding_muon_model.py',
39-
need_envs={"FLAGS_shard_fused_gradient": "1"},
42+
need_envs={
43+
"FLAGS_shard_fused_gradient": "1",
44+
"MULTI_PRECISION": "1",
45+
},
4046
)
4147

4248
def test_muon_sharding_fuse_optimizer_states(self):
@@ -46,7 +52,10 @@ def test_muon_sharding_fuse_optimizer_states(self):
4652
"""
4753
self.run_mnist_2accelerators(
4854
'hybrid_parallel_sharding_muon_model.py',
49-
need_envs={"ENABLE_FUSE_OPTIMIZER_STATES": "1"},
55+
need_envs={
56+
"ENABLE_FUSE_OPTIMIZER_STATES": "1",
57+
"MULTI_PRECISION": "1",
58+
},
5059
)
5160

5261
def test_muon_sharding_release_grads_fused(self):
@@ -60,6 +69,7 @@ def test_muon_sharding_release_grads_fused(self):
6069
need_envs={
6170
"FLAGS_shard_fused_gradient": "1",
6271
"RELEASE_GRADIENTS": "1",
72+
"MULTI_PRECISION": "1",
6373
},
6474
)
6575

0 commit comments

Comments (0)