Commit 51e73bc

update muon_sharding_optimizer with rebuilding 2d_params.

1 parent 1561597 commit 51e73bc

1 file changed: 21 additions & 70 deletions

python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py
```diff
@@ -58,7 +58,6 @@
 g_shard_bypass_dygraph_optimizer = int(
     os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)
 )
-g_shard_fused_gradient = int(os.environ.get("FLAGS_shard_fused_gradient", 0))
 
 
 def _is_trainable(param):
```
```diff
@@ -239,29 +238,25 @@ def __init__(self, optimizer, hcg=None):
                 key=lambda p: self._param2rank_2d_by_color[color_key][p.name]
             )
 
-        # ---- Backward compatibility: expose legacy attributes ----
-        # These are kept for any external code that might reference them
-        self._params_2d = self._params_2d_by_color.get(None, [])
-        self._params_2d_moe = self._params_2d_by_color.get('moe_expert', [])
-        self._rank2params_2d = self._rank2params_2d_by_color.get(None, {0: []})
-        self._param2rank_2d = self._param2rank_2d_by_color.get(None, {})
-        self._rank2params_2d_moe = self._rank2params_2d_by_color.get(
-            'moe_expert', {0: []}
-        )
-        self._param2rank_2d_moe = self._param2rank_2d_by_color.get(
-            'moe_expert', {}
-        )
+        # 2D params owned by this sharding rank
+        self._local_2d = []
+        for color_key, params_2d in self._params_2d_by_color.items():
+            rank2params_2d_by_color = self._rank2params_2d_by_color[color_key]
+
+            group_info = self._color_to_group_info[color_key]
+            sharding_rank = max(group_info['rank'], 0)
+
+            self._local_2d.extend(rank2params_2d_by_color[sharding_rank])
 
         self.sd_release_grads = (
             strategy.hybrid_configs['pp_configs'].release_gradients
             or sharding_configs.release_gradients
         )
-        self._use_fuse_gradients = g_shard_fused_gradient
+        self._use_fuse_gradients = self.comm_buffer_size_MB > 0
         # ---- Build comm buffers for 2D params (V1-style) ----
         if self._use_fuse_gradients:
-            if not hasattr(self, 'comm_buffer_2d'):
-                self.comm_buffer_2d = self._build_2d_comm_buffers()
-            self.comm_buffer_2d.sort(key=lambda x: x._dst)
+            self.comm_buffer_2d = self._build_2d_comm_buffers()
+            self.comm_buffer_2d.sort(key=lambda x: x._dst)
 
         # ---- Step 3: Build comm buffers for 1D params (V2-style) ----
         self._slice_params = {}
```
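
For context, a minimal sketch of what the rebuilt `_local_2d` contains: one bucket of whole 2D tensors per color, taken at this process's rank in that color's sharding group. The color keys, group-info shape, and parameter names below are hypothetical stand-ins for the real Paddle structures, and the reading that a rank of -1 marks a process outside the group (hence `max(rank, 0)`) is an assumption:

```python
# Hypothetical stand-ins for the structures built earlier in __init__.
params_2d_by_color = {
    None: ["w_dense_0", "w_dense_1"],            # default-color 2D params
    'moe_expert': ["w_expert_0", "w_expert_1"],  # MoE-expert 2D params
}
rank2params_2d_by_color = {
    None: {0: ["w_dense_0"], 1: ["w_dense_1"]},
    'moe_expert': {0: ["w_expert_0"], 1: ["w_expert_1"]},
}
# Assumption: group_info['rank'] is -1 when this process is outside the
# color's sharding group, so max(rank, 0) falls back to bucket 0.
color_to_group_info = {None: {'rank': 1}, 'moe_expert': {'rank': -1}}

local_2d = []
for color_key in params_2d_by_color:
    sharding_rank = max(color_to_group_info[color_key]['rank'], 0)
    local_2d.extend(rank2params_2d_by_color[color_key][sharding_rank])

print(local_2d)  # ['w_dense_1', 'w_expert_0']
```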
```diff
@@ -278,15 +273,9 @@ def __init__(self, optimizer, hcg=None):
         # The optimizer should see:
         # - All 2D params assigned to this rank (all colors, as whole tensors)
         # - 1D slice_params for all non-2D params (element-wise shards)
-        local_2d_params = []
-        for color_key, rank2params in self._rank2params_2d_by_color.items():
-            group_info = self._color_to_group_info.get(color_key, {})
-            color_rank = group_info.get('rank', 0)
-            world_size = group_info.get('world_size', 1)
-            rank_key = color_rank if world_size > 1 else 0
-            local_2d_params.extend(rank2params.get(rank_key, []))
-
-        local_opt_params = local_2d_params + list(self._local_parameter_list_1d)
+        local_opt_params = list(self._local_2d) + list(
+            self._local_parameter_list_1d
+        )
 
         self._set_inner_opt_attr('_parameter_list', local_opt_params)
         self._set_inner_opt_attr('_param_groups', local_opt_params)
```
```diff
@@ -306,18 +295,10 @@ def __init__(self, optimizer, hcg=None):
         _sg_group = hcg.get_sharding_parallel_group()
         _N = self._sharding_world_size
 
-        # 2D params owned by this sharding rank (default color, via legacy alias)
+        # 2D params owned by this sharding rank
         _local_2d_numel = sum(
             int(functools_reduce(lambda x, y: x * y, p.shape, 1))
-            for p in self._rank2params_2d.get(self._sharding_rank, [])
-        )
-        # 2D MoE-expert params owned by this rank (moe_expert color, via legacy alias)
-        _moe_rank_key = (
-            self._moe_sharding_rank if self._moe_sharding_world_size > 1 else 0
-        )
-        _local_2d_moe_numel = sum(
-            int(functools_reduce(lambda x, y: x * y, p.shape, 1))
-            for p in self._rank2params_2d_moe.get(_moe_rank_key, [])
+            for p in self._local_2d
         )
         # 1D (AdamW) slice: each rank holds ceil(numel / sharding_world_size) elements.
         _local_1d_numel = sum(
@@ -327,9 +308,7 @@ def __init__(self, optimizer, hcg=None):
             for p in self._params_1d
         )
 
-        _local_total_numel = (
-            _local_2d_numel + _local_2d_moe_numel + _local_1d_numel
-        )
+        _local_total_numel = _local_2d_numel + _local_1d_numel
         _local_total_MB = (
             _local_total_numel * 2 / (1024 * 1024)
         )  # bf16/fp16 = 2 bytes
```
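
As a worked check of the numel accounting above, a minimal sketch assuming bf16 storage (2 bytes per element) and the ceil-division 1D slicing named in the comment; the shapes and world size are made up:

```python
from functools import reduce as functools_reduce

sharding_world_size = 8
local_2d_shapes = [(4096, 4096)]   # whole 2D tensors owned by this rank
params_1d_numels = [4096, 4096]    # e.g. biases / norm weights

local_2d_numel = sum(
    int(functools_reduce(lambda x, y: x * y, shape, 1))
    for shape in local_2d_shapes
)
# Each rank holds ceil(numel / sharding_world_size) elements of every 1D param.
local_1d_numel = sum(
    (n + sharding_world_size - 1) // sharding_world_size
    for n in params_1d_numels
)
local_total_numel = local_2d_numel + local_1d_numel  # 16777216 + 1024
local_total_MB = local_total_numel * 2 / (1024 * 1024)  # bf16/fp16 = 2 bytes
print(f"{local_total_MB:.2f} MB")  # ~32.00 MB
```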
```diff
@@ -545,6 +524,7 @@ def _build_1d_comm_buffers(self):
         self._comm_buffer_list.sort(key=lambda x: x._dst)
 
     def clear_param_storage(self, color):
+        # Only clear param_storage for 1d_params, 2d_params are not added to comm_buffers.
         self.clear_color.add(color)
         if color in self._color_to_comm_buffer_list.keys():
             for comm_buffer in self._color_to_comm_buffer_list[color]:
```
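
A minimal sketch of the invariant the new comment states: the color-keyed comm-buffer registry only ever holds fused 1D buffers, so clearing by color cannot touch 2D params. The stub class, color names, and buffer contents are hypothetical:

```python
class CommBufferStub:
    """Hypothetical stand-in for a fused 1D gradient/param comm buffer."""

    def __init__(self, names):
        self.param_storage = dict.fromkeys(names)

    def clear(self):
        self.param_storage.clear()

# Only 1D buffers are registered per color; 2D params never enter this dict.
color_to_comm_buffer_list = {'dense': [CommBufferStub(['bias_0', 'norm_0'])]}
clear_color = set()

def clear_param_storage(color):
    clear_color.add(color)
    if color in color_to_comm_buffer_list:
        for comm_buffer in color_to_comm_buffer_list[color]:
            comm_buffer.clear()

clear_param_storage('dense')       # clears the fused 1D storage
clear_param_storage('moe_expert')  # no-op: no buffers under that color
```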
```diff
@@ -671,31 +651,13 @@ def reduce_gradients(self, parameter_list, hcg):
 
     def filter_parameters(self, parameter_list, hcg):
         """Filter parameters: return local 2D params + initialized 1D slices."""
-        sharding_rank = hcg.get_sharding_parallel_rank()
-        local_2d = [
-            p
-            for p in parameter_list
-            if p.name in self._param2rank_2d
-            and self._param2rank_2d[p.name] == sharding_rank
-        ]
-        # Also include MoE 2D params owned by this rank
-        if self._moe_sharding_world_size > 1:
-            moe_rank = self._moe_sharding_rank
-        else:
-            moe_rank = 0
-        local_2d_moe = [
-            p
-            for p in parameter_list
-            if p.name in self._param2rank_2d_moe
-            and self._param2rank_2d_moe[p.name] == moe_rank
-        ]
         local_1d = [
             self._slice_params[p.name]
             for p in parameter_list
             if p.name in self._slice_params
         ]
         local_1d = [p for p in local_1d if p._is_initialized()]
-        return local_2d + local_2d_moe + local_1d
+        return self._local_2d + local_1d
 
     # ------------------------------------------------------------------
     # Parameter sync after optimizer step
```
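
A minimal sketch of the simplified `filter_parameters` contract: the precomputed `_local_2d` list is returned as-is, plus only those 1D slices that exist and are initialized. The slice stub and parameter names are hypothetical:

```python
class SliceStub:
    """Hypothetical stand-in for a sharded 1D slice parameter."""

    def __init__(self, name, initialized):
        self.name = name
        self._initialized = initialized

    def _is_initialized(self):
        return self._initialized

local_2d = ["w_dense_1"]  # rebuilt once in __init__, reused here unchanged
slice_params = {
    'bias_0': SliceStub('bias_0@slice', initialized=True),
    'bias_1': SliceStub('bias_1@slice', initialized=False),
}

def filter_parameters(parameter_names):
    local_1d = [slice_params[n] for n in parameter_names if n in slice_params]
    local_1d = [p for p in local_1d if p._is_initialized()]
    return local_2d + local_1d

result = filter_parameters(['bias_0', 'bias_1', 'w_dense_1'])
print([getattr(p, 'name', p) for p in result])
# ['w_dense_1', 'bias_0@slice']
```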
```diff
@@ -884,18 +846,7 @@ def step(self):
     def set_state_dict(self, state_dict):
         inner_state = {}
         # Collect local parameters: 2D whole-tensor params + 1D original params
-        # (set_state_dict uses legacy aliases; covers default and moe_expert colors)
-        local_2d = list(self._rank2params_2d.get(self._sharding_rank, []))
-        if self._moe_sharding_world_size > 1:
-            local_2d_moe = list(
-                self._rank2params_2d_moe.get(self._moe_sharding_rank, [])
-            )
-        else:
-            local_2d_moe = list(self._rank2params_2d_moe.get(0, []))
-        parameters = local_2d + local_2d_moe
-        # Add 1D params (use original param names for matching)
-        for p in self._params_1d:
-            parameters.append(p)
+        parameters = list(self._local_2d) + list(self._params_1d)
 
         if "LR_Scheduler" in state_dict:
             inner_state["LR_Scheduler"] = state_dict.pop("LR_Scheduler")
```
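
Finally, a minimal sketch of how the simplified local parameter set can drive name-keyed state restoration; the prefix-matching scheme below is an assumption for illustration, not the method's actual matching logic:

```python
local_2d = ["w_dense_1"]          # whole 2D tensors owned by this rank
params_1d = ["bias_0", "bias_1"]  # original (unsliced) 1D param names
parameters = list(local_2d) + list(params_1d)

state_dict = {
    "LR_Scheduler": {"last_epoch": 3},
    "w_dense_1_moment1_0": "…",
    "bias_0_moment1_0": "…",
    "w_other_moment1_0": "…",  # owned by a different rank; dropped
}

inner_state = {}
if "LR_Scheduler" in state_dict:
    inner_state["LR_Scheduler"] = state_dict.pop("LR_Scheduler")
# Keep only entries whose key starts with a locally owned parameter name.
for key, value in state_dict.items():
    if any(key.startswith(name) for name in parameters):
        inner_state[key] = value

print(sorted(inner_state))
# ['LR_Scheduler', 'bias_0_moment1_0', 'w_dense_1_moment1_0']
```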
