
Commit ef53d85

update muon_sharding_optimizer with rebuilding 2d_params.
1 parent 1561597 commit ef53d85

2 files changed: 46 additions & 78 deletions

python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py

Lines changed: 40 additions & 78 deletions
@@ -58,7 +58,6 @@
 g_shard_bypass_dygraph_optimizer = int(
     os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)
 )
-g_shard_fused_gradient = int(os.environ.get("FLAGS_shard_fused_gradient", 0))
 
 
 def _is_trainable(param):
@@ -133,6 +132,9 @@ def __init__(self, optimizer, hcg=None):
 
         pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap
         self.pp_overlap = pp_overlap
+        assert not self.pp_overlap, (
+            "muon_sharding_optimizer do not support PP overlap"
+        )
 
         self._use_main_grad = hasattr(optimizer._parameter_list[0], "main_grad")
 
@@ -146,12 +148,6 @@ def __init__(self, optimizer, hcg=None):
             self._parameter_list, sharding_group
         )
 
-        # Extract MoE group info from color_to_group_info for backward compatibility
-        moe_info = self._color_to_group_info.get('moe_expert', {})
-        self._moe_sharding_world_size = moe_info.get('world_size', 1)
-        self._moe_sharding_rank = moe_info.get('rank', 0)
-        self._moe_sharding_group = moe_info.get('group', None)
-
         # Get muon_param_info_map from the inner Muon optimizer.
         # Each entry has use_muon=True/False, set by the Trainer before construction.
         self._muon_param_info_map = getattr(
@@ -239,29 +235,25 @@ def __init__(self, optimizer, hcg=None):
             key=lambda p: self._param2rank_2d_by_color[color_key][p.name]
         )
 
-        # ---- Backward compatibility: expose legacy attributes ----
-        # These are kept for any external code that might reference them
-        self._params_2d = self._params_2d_by_color.get(None, [])
-        self._params_2d_moe = self._params_2d_by_color.get('moe_expert', [])
-        self._rank2params_2d = self._rank2params_2d_by_color.get(None, {0: []})
-        self._param2rank_2d = self._param2rank_2d_by_color.get(None, {})
-        self._rank2params_2d_moe = self._rank2params_2d_by_color.get(
-            'moe_expert', {0: []}
-        )
-        self._param2rank_2d_moe = self._param2rank_2d_by_color.get(
-            'moe_expert', {}
-        )
+        # 2D params owned by this sharding rank
+        self._local_2d = []
+        for color_key, params_2d in self._params_2d_by_color.items():
+            rank2params_2d_by_color = self._rank2params_2d_by_color[color_key]
+
+            group_info = self._color_to_group_info[color_key]
+            sharding_rank = max(group_info['rank'], 0)
+
+            self._local_2d.extend(rank2params_2d_by_color[sharding_rank])
 
         self.sd_release_grads = (
             strategy.hybrid_configs['pp_configs'].release_gradients
             or sharding_configs.release_gradients
         )
-        self._use_fuse_gradients = g_shard_fused_gradient
+        self._use_fuse_gradients = self.comm_buffer_size_MB > 0
         # ---- Build comm buffers for 2D params (V1-style) ----
         if self._use_fuse_gradients:
-            if not hasattr(self, 'comm_buffer_2d'):
-                self.comm_buffer_2d = self._build_2d_comm_buffers()
-                self.comm_buffer_2d.sort(key=lambda x: x._dst)
+            self.comm_buffer_2d = self._build_2d_comm_buffers()
+            self.comm_buffer_2d.sort(key=lambda x: x._dst)
 
         # ---- Step 3: Build comm buffers for 1D params (V2-style) ----
         self._slice_params = {}
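
The per-color rebuild of _local_2d above can be exercised in isolation. A minimal sketch, using hypothetical plain dicts and parameter names in place of the real _rank2params_2d_by_color and _color_to_group_info attributes (which hold Paddle parameters and communication-group metadata):

# Illustrative stand-ins only; not taken from the optimizer.
rank2params_2d_by_color = {
    None: {0: ["dense_w0", "dense_w1"], 1: ["dense_w2"]},  # default color
    "moe_expert": {0: ["expert_w0"], 1: ["expert_w1"]},
}
color_to_group_info = {
    None: {"rank": 1, "world_size": 2},
    "moe_expert": {"rank": -1, "world_size": 1},  # this rank is outside the group
}

local_2d = []
for color_key, rank2params in rank2params_2d_by_color.items():
    group_info = color_to_group_info[color_key]
    # Ranks outside a group report -1; clamp to 0 so they fall back to bucket 0.
    sharding_rank = max(group_info["rank"], 0)
    local_2d.extend(rank2params[sharding_rank])

print(local_2d)  # ['dense_w2', 'expert_w0']
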
@@ -278,15 +270,9 @@ def __init__(self, optimizer, hcg=None):
         # The optimizer should see:
         # - All 2D params assigned to this rank (all colors, as whole tensors)
         # - 1D slice_params for all non-2D params (element-wise shards)
-        local_2d_params = []
-        for color_key, rank2params in self._rank2params_2d_by_color.items():
-            group_info = self._color_to_group_info.get(color_key, {})
-            color_rank = group_info.get('rank', 0)
-            world_size = group_info.get('world_size', 1)
-            rank_key = color_rank if world_size > 1 else 0
-            local_2d_params.extend(rank2params.get(rank_key, []))
-
-        local_opt_params = local_2d_params + list(self._local_parameter_list_1d)
+        local_opt_params = list(self._local_2d) + list(
+            self._local_parameter_list_1d
+        )
 
         self._set_inner_opt_attr('_parameter_list', local_opt_params)
         self._set_inner_opt_attr('_param_groups', local_opt_params)
@@ -306,18 +292,10 @@ def __init__(self, optimizer, hcg=None):
         _sg_group = hcg.get_sharding_parallel_group()
         _N = self._sharding_world_size
 
-        # 2D params owned by this sharding rank (default color, via legacy alias)
+        # 2D params owned by this sharding rank
         _local_2d_numel = sum(
             int(functools_reduce(lambda x, y: x * y, p.shape, 1))
-            for p in self._rank2params_2d.get(self._sharding_rank, [])
-        )
-        # 2D MoE-expert params owned by this rank (moe_expert color, via legacy alias)
-        _moe_rank_key = (
-            self._moe_sharding_rank if self._moe_sharding_world_size > 1 else 0
-        )
-        _local_2d_moe_numel = sum(
-            int(functools_reduce(lambda x, y: x * y, p.shape, 1))
-            for p in self._rank2params_2d_moe.get(_moe_rank_key, [])
+            for p in self._local_2d
         )
         # 1D (AdamW) slice: each rank holds ceil(numel / sharding_world_size) elements.
         _local_1d_numel = sum(
@@ -327,9 +305,7 @@ def __init__(self, optimizer, hcg=None):
             for p in self._params_1d
         )
 
-        _local_total_numel = (
-            _local_2d_numel + _local_2d_moe_numel + _local_1d_numel
-        )
+        _local_total_numel = _local_2d_numel + _local_1d_numel
         _local_total_MB = (
             _local_total_numel * 2 / (1024 * 1024)
         )  # bf16/fp16 = 2 bytes
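
The size estimate above is plain arithmetic over the locally owned elements. A self-contained sketch with made-up shapes (the real code iterates over self._local_2d and self._params_1d), keeping the bf16/fp16 assumption of 2 bytes per element:

import math
from functools import reduce as functools_reduce

# Hypothetical shapes, for illustration only.
local_2d_shapes = [(4096, 4096), (4096, 11008)]  # whole 2D tensors owned by this rank
params_1d_numels = [4096, 11008]                 # full 1D sizes before slicing
sharding_world_size = 8

_local_2d_numel = sum(
    int(functools_reduce(lambda x, y: x * y, shape, 1)) for shape in local_2d_shapes
)
# Each rank holds ceil(numel / sharding_world_size) elements of every 1D param.
_local_1d_numel = sum(math.ceil(n / sharding_world_size) for n in params_1d_numels)

_local_total_MB = (_local_2d_numel + _local_1d_numel) * 2 / (1024 * 1024)
print(f"{_local_total_MB:.1f} MB")  # ~118.0 MB for these shapes
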
@@ -546,6 +522,7 @@ def _build_1d_comm_buffers(self):
 
     def clear_param_storage(self, color):
         self.clear_color.add(color)
+        # 1D params
         if color in self._color_to_comm_buffer_list.keys():
             for comm_buffer in self._color_to_comm_buffer_list[color]:
                 for param in comm_buffer.params:
@@ -561,14 +538,26 @@ def clear_param_storage(self, color):
                         self._create_master_weight(slice_param)
                     slice_param._clear_dataptr()
                 comm_buffer._clear_param_storage()
+        # 2D params
+        if color in self._params_2d_by_color.keys():
+            for param in self._params_2d_by_color[color]:
+                if not g_shard_bypass_dygraph_optimizer:
+                    self._create_master_weight(param)
+                param._clear_to_zero_allocation()
 
     def reset_param_storage(self):
         for color in self.clear_color:
             if color is None:
                 continue
+            # 1D params
             if color in self._color_to_comm_buffer_list.keys():
                 for comm_buffer in self._color_to_comm_buffer_list[color]:
                     comm_buffer._reset_param_storage()
+            # 2D params
+            if color in self._params_2d_by_color.keys():
+                for param in self._params_2d_by_color[color]:
+                    new_param = paddle.empty_like(param)
+                    new_param._share_buffer_to(param)
 
     # ------------------------------------------------------------------
     # Gradient communication
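
The two new 2D branches form a clear/reset pair: clear_param_storage releases a color's 2D parameter storage once master weights exist, and reset_param_storage (now called at the top of step(), see the later hunk) re-materializes that storage before the update. A minimal sketch of the cycle on one standalone parameter, assuming the tensor-level internals used in the diff (_clear_to_zero_allocation, _share_buffer_to, _is_initialized) behave the same outside the optimizer:

import paddle

# Hypothetical stand-in for a single 2D param of a cleared color.
param = paddle.create_parameter(shape=[4, 4], dtype="float32")

# clear_param_storage(color): drop the bf16/fp16 storage; a master weight is kept,
# so these values are not needed again until the next optimizer step.
param._clear_to_zero_allocation()

# reset_param_storage(), invoked at the start of step(): allocate a fresh buffer of
# the same shape/dtype and point the original parameter at it, so the update and
# the post-step broadcast have storage to write into.
new_param = paddle.empty_like(param)
new_param._share_buffer_to(param)
assert param._is_initialized()
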
@@ -671,31 +660,13 @@ def reduce_gradients(self, parameter_list, hcg):
 
     def filter_parameters(self, parameter_list, hcg):
         """Filter parameters: return local 2D params + initialized 1D slices."""
-        sharding_rank = hcg.get_sharding_parallel_rank()
-        local_2d = [
-            p
-            for p in parameter_list
-            if p.name in self._param2rank_2d
-            and self._param2rank_2d[p.name] == sharding_rank
-        ]
-        # Also include MoE 2D params owned by this rank
-        if self._moe_sharding_world_size > 1:
-            moe_rank = self._moe_sharding_rank
-        else:
-            moe_rank = 0
-        local_2d_moe = [
-            p
-            for p in parameter_list
-            if p.name in self._param2rank_2d_moe
-            and self._param2rank_2d_moe[p.name] == moe_rank
-        ]
         local_1d = [
             self._slice_params[p.name]
             for p in parameter_list
             if p.name in self._slice_params
         ]
         local_1d = [p for p in local_1d if p._is_initialized()]
-        return local_2d + local_2d_moe + local_1d
+        return self._local_2d + local_1d
 
     # ------------------------------------------------------------------
     # Parameter sync after optimizer step
@@ -771,12 +742,12 @@ def clear_grad_func(p):
         for p in self._parameter_list:
             clear_grad_func(p)
 
-        # 1D params are managed by comm buffers
         if self.sd_release_grads and not self.pp_overlap:
+            # 1D params are managed by comm buffers
             for comm_buffer in self._comm_buffer_list:
                 if comm_buffer.need_reduce_scale_sync():
                     comm_buffer._clear_grad_storage()
-
+            # 2D params are managed by comm buffers
             if self._use_fuse_gradients:
                 for comm_buffer in self.comm_buffer_2d:
                     if comm_buffer.need_reduce_scale_sync():
@@ -820,6 +791,8 @@ def _assign_slice_grad(self):
 
     def step(self):
         """Optimizer step: update local 2D params and 1D slices, then sync."""
+        self.reset_param_storage()
+
         self._collect_comm_buffers()
         self._assign_slice_grad()
 
@@ -884,18 +857,7 @@ def step(self):
     def set_state_dict(self, state_dict):
         inner_state = {}
         # Collect local parameters: 2D whole-tensor params + 1D original params
-        # (set_state_dict uses legacy aliases; covers default and moe_expert colors)
-        local_2d = list(self._rank2params_2d.get(self._sharding_rank, []))
-        if self._moe_sharding_world_size > 1:
-            local_2d_moe = list(
-                self._rank2params_2d_moe.get(self._moe_sharding_rank, [])
-            )
-        else:
-            local_2d_moe = list(self._rank2params_2d_moe.get(0, []))
-        parameters = local_2d + local_2d_moe
-        # Add 1D params (use original param names for matching)
-        for p in self._params_1d:
-            parameters.append(p)
+        parameters = list(self._local_2d) + list(self._params_1d)
 
         if "LR_Scheduler" in state_dict:
             inner_state["LR_Scheduler"] = state_dict.pop("LR_Scheduler")

test/collective/fleet/hybrid_parallel_sharding_muon_model.py

Lines changed: 6 additions & 0 deletions
@@ -27,6 +27,9 @@
 
 import paddle
 from paddle.distributed import fleet
+from paddle.distributed.fleet.meta_optimizers.muon_sharding_optimizer import (
+    MuonShardingOptimizer,
+)
 from paddle.distributed.fleet.utils import mix_precision_utils
 from paddle.optimizer.muon import (
     MuonParamInfo,
@@ -240,6 +243,9 @@ def train_batch(self, batch, model, optimizer):
         output = model(batch)
         loss = output.mean()
         loss.backward()
+        inner_opt = getattr(optimizer, '_inner_opt', optimizer)
+        if isinstance(inner_opt, MuonShardingOptimizer):
+            optimizer.clear_param_storage('test_color')
         optimizer.step()
         optimizer.clear_grad()
         return loss
