From 72c21897063f43883468baa7e59582b665b939a4 Mon Sep 17 00:00:00 2001 From: xxyux <1650459510@qq.com> Date: Mon, 27 Apr 2026 17:15:21 +0800 Subject: [PATCH 1/4] update muon_sharding_optimizer with rebuilding 2d_params. --- .../muon_sharding_optimizer.py | 118 ++++++------------ .../hybrid_parallel_sharding_muon_model.py | 6 + 2 files changed, 46 insertions(+), 78 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py index 89996eaea1ca49..c4f5be3013555e 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py @@ -58,7 +58,6 @@ g_shard_bypass_dygraph_optimizer = int( os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0) ) -g_shard_fused_gradient = int(os.environ.get("FLAGS_shard_fused_gradient", 0)) def _is_trainable(param): @@ -133,6 +132,9 @@ def __init__(self, optimizer, hcg=None): pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap self.pp_overlap = pp_overlap + assert not self.pp_overlap, ( + "muon_sharding_optimizer do not support PP overlap" + ) self._use_main_grad = hasattr(optimizer._parameter_list[0], "main_grad") @@ -146,12 +148,6 @@ def __init__(self, optimizer, hcg=None): self._parameter_list, sharding_group ) - # Extract MoE group info from color_to_group_info for backward compatibility - moe_info = self._color_to_group_info.get('moe_expert', {}) - self._moe_sharding_world_size = moe_info.get('world_size', 1) - self._moe_sharding_rank = moe_info.get('rank', 0) - self._moe_sharding_group = moe_info.get('group', None) - # Get muon_param_info_map from the inner Muon optimizer. # Each entry has use_muon=True/False, set by the Trainer before construction. 
self._muon_param_info_map = getattr( @@ -239,29 +235,25 @@ def __init__(self, optimizer, hcg=None): key=lambda p: self._param2rank_2d_by_color[color_key][p.name] ) - # ---- Backward compatibility: expose legacy attributes ---- - # These are kept for any external code that might reference them - self._params_2d = self._params_2d_by_color.get(None, []) - self._params_2d_moe = self._params_2d_by_color.get('moe_expert', []) - self._rank2params_2d = self._rank2params_2d_by_color.get(None, {0: []}) - self._param2rank_2d = self._param2rank_2d_by_color.get(None, {}) - self._rank2params_2d_moe = self._rank2params_2d_by_color.get( - 'moe_expert', {0: []} - ) - self._param2rank_2d_moe = self._param2rank_2d_by_color.get( - 'moe_expert', {} - ) + # 2D params owned by this sharding rank + self._local_2d = [] + for color_key, params_2d in self._params_2d_by_color.items(): + rank2params_2d_by_color = self._rank2params_2d_by_color[color_key] + + group_info = self._color_to_group_info[color_key] + sharding_rank = max(group_info['rank'], 0) + + self._local_2d.extend(rank2params_2d_by_color[sharding_rank]) self.sd_release_grads = ( strategy.hybrid_configs['pp_configs'].release_gradients or sharding_configs.release_gradients ) - self._use_fuse_gradients = g_shard_fused_gradient + self._use_fuse_gradients = self.comm_buffer_size_MB > 0 # ---- Build comm buffers for 2D params (V1-style) ---- if self._use_fuse_gradients: - if not hasattr(self, 'comm_buffer_2d'): - self.comm_buffer_2d = self._build_2d_comm_buffers() - self.comm_buffer_2d.sort(key=lambda x: x._dst) + self.comm_buffer_2d = self._build_2d_comm_buffers() + self.comm_buffer_2d.sort(key=lambda x: x._dst) # ---- Step 3: Build comm buffers for 1D params (V2-style) ---- self._slice_params = {} @@ -278,15 +270,9 @@ def __init__(self, optimizer, hcg=None): # The optimizer should see: # - All 2D params assigned to this rank (all colors, as whole tensors) # - 1D slice_params for all non-2D params (element-wise shards) - local_2d_params = [] - for color_key, rank2params in self._rank2params_2d_by_color.items(): - group_info = self._color_to_group_info.get(color_key, {}) - color_rank = group_info.get('rank', 0) - world_size = group_info.get('world_size', 1) - rank_key = color_rank if world_size > 1 else 0 - local_2d_params.extend(rank2params.get(rank_key, [])) - - local_opt_params = local_2d_params + list(self._local_parameter_list_1d) + local_opt_params = list(self._local_2d) + list( + self._local_parameter_list_1d + ) self._set_inner_opt_attr('_parameter_list', local_opt_params) self._set_inner_opt_attr('_param_groups', local_opt_params) @@ -306,18 +292,10 @@ def __init__(self, optimizer, hcg=None): _sg_group = hcg.get_sharding_parallel_group() _N = self._sharding_world_size - # 2D params owned by this sharding rank (default color, via legacy alias) + # 2D params owned by this sharding rank _local_2d_numel = sum( int(functools_reduce(lambda x, y: x * y, p.shape, 1)) - for p in self._rank2params_2d.get(self._sharding_rank, []) - ) - # 2D MoE-expert params owned by this rank (moe_expert color, via legacy alias) - _moe_rank_key = ( - self._moe_sharding_rank if self._moe_sharding_world_size > 1 else 0 - ) - _local_2d_moe_numel = sum( - int(functools_reduce(lambda x, y: x * y, p.shape, 1)) - for p in self._rank2params_2d_moe.get(_moe_rank_key, []) + for p in self._local_2d ) # 1D (AdamW) slice: each rank holds ceil(numel / sharding_world_size) elements. 
_local_1d_numel = sum( @@ -327,9 +305,7 @@ def __init__(self, optimizer, hcg=None): for p in self._params_1d ) - _local_total_numel = ( - _local_2d_numel + _local_2d_moe_numel + _local_1d_numel - ) + _local_total_numel = _local_2d_numel + _local_1d_numel _local_total_MB = ( _local_total_numel * 2 / (1024 * 1024) ) # bf16/fp16 = 2 bytes @@ -546,6 +522,7 @@ def _build_1d_comm_buffers(self): def clear_param_storage(self, color): self.clear_color.add(color) + # 1D params if color in self._color_to_comm_buffer_list.keys(): for comm_buffer in self._color_to_comm_buffer_list[color]: for param in comm_buffer.params: @@ -561,14 +538,26 @@ def clear_param_storage(self, color): self._create_master_weight(slice_param) slice_param._clear_dataptr() comm_buffer._clear_param_storage() + # 2D params + if color in self._params_2d_by_color.keys(): + for param in self._params_2d_by_color[color]: + if not g_shard_bypass_dygraph_optimizer: + self._create_master_weight(param) + param._clear_to_zero_allocation() def reset_param_storage(self): for color in self.clear_color: if color is None: continue + # 1D params if color in self._color_to_comm_buffer_list.keys(): for comm_buffer in self._color_to_comm_buffer_list[color]: comm_buffer._reset_param_storage() + # 2D params + if color in self._params_2d_by_color.keys(): + for param in self._params_2d_by_color[color]: + new_param = paddle.empty_like(param) + new_param._share_buffer_to(param) # ------------------------------------------------------------------ # Gradient communication @@ -671,31 +660,13 @@ def reduce_gradients(self, parameter_list, hcg): def filter_parameters(self, parameter_list, hcg): """Filter parameters: return local 2D params + initialized 1D slices.""" - sharding_rank = hcg.get_sharding_parallel_rank() - local_2d = [ - p - for p in parameter_list - if p.name in self._param2rank_2d - and self._param2rank_2d[p.name] == sharding_rank - ] - # Also include MoE 2D params owned by this rank - if self._moe_sharding_world_size > 1: - moe_rank = self._moe_sharding_rank - else: - moe_rank = 0 - local_2d_moe = [ - p - for p in parameter_list - if p.name in self._param2rank_2d_moe - and self._param2rank_2d_moe[p.name] == moe_rank - ] local_1d = [ self._slice_params[p.name] for p in parameter_list if p.name in self._slice_params ] local_1d = [p for p in local_1d if p._is_initialized()] - return local_2d + local_2d_moe + local_1d + return self._local_2d + local_1d # ------------------------------------------------------------------ # Parameter sync after optimizer step @@ -771,12 +742,12 @@ def clear_grad_func(p): for p in self._parameter_list: clear_grad_func(p) - # 1D params are managed by comm buffers if self.sd_release_grads and not self.pp_overlap: + # 1D params are managed by comm buffers for comm_buffer in self._comm_buffer_list: if comm_buffer.need_reduce_scale_sync(): comm_buffer._clear_grad_storage() - + # 2D params are managed by comm buffers if self._use_fuse_gradients: for comm_buffer in self.comm_buffer_2d: if comm_buffer.need_reduce_scale_sync(): @@ -820,6 +791,8 @@ def _assign_slice_grad(self): def step(self): """Optimizer step: update local 2D params and 1D slices, then sync.""" + self.reset_param_storage() + self._collect_comm_buffers() self._assign_slice_grad() @@ -884,18 +857,7 @@ def step(self): def set_state_dict(self, state_dict): inner_state = {} # Collect local parameters: 2D whole-tensor params + 1D original params - # (set_state_dict uses legacy aliases; covers default and moe_expert colors) - local_2d = 
list(self._rank2params_2d.get(self._sharding_rank, [])) - if self._moe_sharding_world_size > 1: - local_2d_moe = list( - self._rank2params_2d_moe.get(self._moe_sharding_rank, []) - ) - else: - local_2d_moe = list(self._rank2params_2d_moe.get(0, [])) - parameters = local_2d + local_2d_moe - # Add 1D params (use original param names for matching) - for p in self._params_1d: - parameters.append(p) + parameters = list(self._local_2d) + list(self._params_1d) if "LR_Scheduler" in state_dict: inner_state["LR_Scheduler"] = state_dict.pop("LR_Scheduler") diff --git a/test/collective/fleet/hybrid_parallel_sharding_muon_model.py b/test/collective/fleet/hybrid_parallel_sharding_muon_model.py index 489ecdc5cac447..e4676c27501ed3 100644 --- a/test/collective/fleet/hybrid_parallel_sharding_muon_model.py +++ b/test/collective/fleet/hybrid_parallel_sharding_muon_model.py @@ -27,6 +27,9 @@ import paddle from paddle.distributed import fleet +from paddle.distributed.fleet.meta_optimizers.muon_sharding_optimizer import ( + MuonShardingOptimizer, +) from paddle.distributed.fleet.utils import mix_precision_utils from paddle.optimizer.muon import ( MuonParamInfo, @@ -240,6 +243,9 @@ def train_batch(self, batch, model, optimizer): output = model(batch) loss = output.mean() loss.backward() + if isinstance(optimizer, MuonShardingOptimizer): + optimizer.clear_param_storage('test_color') + optimizer.reset_param_storage() optimizer.step() optimizer.clear_grad() return loss From 5e989c9c6bf75f41671065ca3b1a797f410625b5 Mon Sep 17 00:00:00 2001 From: xxyux <1650459510@qq.com> Date: Mon, 27 Apr 2026 17:15:21 +0800 Subject: [PATCH 2/4] update muon_sharding_optimizer with rebuilding 2d_params. --- test/collective/fleet/hybrid_parallel_sharding_muon_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/collective/fleet/hybrid_parallel_sharding_muon_model.py b/test/collective/fleet/hybrid_parallel_sharding_muon_model.py index e4676c27501ed3..23909368e81a47 100644 --- a/test/collective/fleet/hybrid_parallel_sharding_muon_model.py +++ b/test/collective/fleet/hybrid_parallel_sharding_muon_model.py @@ -243,9 +243,9 @@ def train_batch(self, batch, model, optimizer): output = model(batch) loss = output.mean() loss.backward() - if isinstance(optimizer, MuonShardingOptimizer): + inner_opt = getattr(optimizer, '_inner_opt', optimizer) + if isinstance(inner_opt, MuonShardingOptimizer): optimizer.clear_param_storage('test_color') - optimizer.reset_param_storage() optimizer.step() optimizer.clear_grad() return loss From b386ea8049283c827ac101f3c5c1f7a7d3ba8402 Mon Sep 17 00:00:00 2001 From: xxyux <1650459510@qq.com> Date: Mon, 27 Apr 2026 17:15:21 +0800 Subject: [PATCH 3/4] update muon_sharding_optimizer with rebuilding 2d_params. 
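
clear_param_storage() releases the low-precision parameter storage for a
color, so the weights can only be recovered from the fp32 master copies that
are created just before the release. The method therefore now asserts
self._multi_precision, and the hybrid-parallel tests export MULTI_PRECISION=1
so the model under test satisfies that precondition. Minimal illustration of
the new guard (optimizer construction elided; `opt` is assumed to be the
MuonShardingOptimizer instance, e.g. obtained via optimizer._inner_opt):

    # Without multi_precision this now fails fast with an AssertionError
    # instead of silently dropping the weights; with it, the storage for
    # the given color is released after master weights are created.
    opt.clear_param_storage('test_color')
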
--- .../meta_optimizers/muon_sharding_optimizer.py | 4 ++++ .../fleet/test_parallel_dygraph_muon.py | 16 +++++++++++++--- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py index c4f5be3013555e..199e7dcff2d3ae 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py @@ -521,6 +521,10 @@ def _build_1d_comm_buffers(self): self._comm_buffer_list.sort(key=lambda x: x._dst) def clear_param_storage(self, color): + assert self._multi_precision, ( + "Muon Sharding Optimizer only support clear param with multi_precision mode" + ) + self.clear_color.add(color) # 1D params if color in self._color_to_comm_buffer_list.keys(): diff --git a/test/collective/fleet/test_parallel_dygraph_muon.py b/test/collective/fleet/test_parallel_dygraph_muon.py index d845baa4810c1a..13b6d990a8fc3a 100644 --- a/test/collective/fleet/test_parallel_dygraph_muon.py +++ b/test/collective/fleet/test_parallel_dygraph_muon.py @@ -26,7 +26,10 @@ def test_muon_sharding_optimizer(self): Test logic is in hybrid_parallel_sharding_muon_model.py, iterating 4 ns_coeff_types. fp32 matmul is auto-selected on V100. """ - self.run_mnist_2accelerators('hybrid_parallel_sharding_muon_model.py') + self.run_mnist_2accelerators( + 'hybrid_parallel_sharding_muon_model.py', + need_envs={"MULTI_PRECISION": "1"}, + ) def test_muon_sharding_fused_gradient(self): """MuonSharding test with FLAGS_shard_fused_gradient=1. @@ -36,7 +39,10 @@ def test_muon_sharding_fused_gradient(self): """ self.run_mnist_2accelerators( 'hybrid_parallel_sharding_muon_model.py', - need_envs={"FLAGS_shard_fused_gradient": "1"}, + need_envs={ + "FLAGS_shard_fused_gradient": "1", + "MULTI_PRECISION": "1", + }, ) def test_muon_sharding_fuse_optimizer_states(self): @@ -46,7 +52,10 @@ def test_muon_sharding_fuse_optimizer_states(self): """ self.run_mnist_2accelerators( 'hybrid_parallel_sharding_muon_model.py', - need_envs={"ENABLE_FUSE_OPTIMIZER_STATES": "1"}, + need_envs={ + "ENABLE_FUSE_OPTIMIZER_STATES": "1", + "MULTI_PRECISION": "1", + }, ) def test_muon_sharding_release_grads_fused(self): @@ -60,6 +69,7 @@ def test_muon_sharding_release_grads_fused(self): need_envs={ "FLAGS_shard_fused_gradient": "1", "RELEASE_GRADIENTS": "1", + "MULTI_PRECISION": "1", }, ) From eb271fbc977d6485d00b219072a7ba126d431eb6 Mon Sep 17 00:00:00 2001 From: xxyux <1650459510@qq.com> Date: Mon, 27 Apr 2026 17:15:21 +0800 Subject: [PATCH 4/4] update muon_sharding_optimizer with rebuilding 2d_params. 
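
Because step() now begins with reset_param_storage(), the clear/reset pair is
made safe to run on every iteration: parameter storage is only released when a
master weight actually exists for the parameter, and only rebuilt when the
storage was in fact cleared (checked via _is_initialized()). In muon.py,
find_master is likewise keyed off membership in self._master_weights rather
than the parameter dtype, so the update reads the master copy exactly when one
was created. Illustrative call pattern (names mirror the test model; the
surrounding training loop is elided):

    optimizer.clear_param_storage('test_color')  # release storage, keep fp32 masters
    optimizer.step()                             # rebuilds only the cleared buffers
    # A later step() without a preceding clear_param_storage() still calls
    # reset_param_storage(), but the _is_initialized() checks should make it
    # a no-op for storage that is already live.
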
--- .../muon_sharding_optimizer.py | 21 +++++++++++++------ python/paddle/optimizer/muon.py | 8 ++----- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py index 199e7dcff2d3ae..5d1d903a2b3573 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/muon_sharding_optimizer.py @@ -529,6 +529,7 @@ def clear_param_storage(self, color): # 1D params if color in self._color_to_comm_buffer_list.keys(): for comm_buffer in self._color_to_comm_buffer_list[color]: + has_clear = False for param in comm_buffer.params: grad_view = comm_buffer._sharding_param_grad_view[ param.name @@ -540,14 +541,20 @@ def clear_param_storage(self, color): ): grad_view.fill_slice_param(slice_param) self._create_master_weight(slice_param) - slice_param._clear_dataptr() - comm_buffer._clear_param_storage() + if param.name in self._master_weights: + slice_param._clear_dataptr() + has_clear = True + + if has_clear: + comm_buffer._clear_param_storage() # 2D params if color in self._params_2d_by_color.keys(): for param in self._params_2d_by_color[color]: if not g_shard_bypass_dygraph_optimizer: self._create_master_weight(param) - param._clear_to_zero_allocation() + + if param.name in self._master_weights: + param._clear_to_zero_allocation() def reset_param_storage(self): for color in self.clear_color: @@ -556,12 +563,14 @@ def reset_param_storage(self): # 1D params if color in self._color_to_comm_buffer_list.keys(): for comm_buffer in self._color_to_comm_buffer_list[color]: - comm_buffer._reset_param_storage() + if not comm_buffer.param_storage._is_initialized(): + comm_buffer._reset_param_storage() # 2D params if color in self._params_2d_by_color.keys(): for param in self._params_2d_by_color[color]: - new_param = paddle.empty_like(param) - new_param._share_buffer_to(param) + if not param._is_initialized(): + new_param = paddle.empty_like(param) + new_param._share_buffer_to(param) # ------------------------------------------------------------------ # Gradient communication diff --git a/python/paddle/optimizer/muon.py b/python/paddle/optimizer/muon.py index db8e70e8e221ef..31dbf7587b0946 100644 --- a/python/paddle/optimizer/muon.py +++ b/python/paddle/optimizer/muon.py @@ -445,9 +445,7 @@ def _adamw_update( ): with_decay = False - find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( - param.dtype - ) + find_master = param.name in self._master_weights master_weight = ( self._master_weights[param.name] if find_master else None ) @@ -557,9 +555,7 @@ def ortho_fn(m): # Default: whole matrix orthogonalisation orthogonal_update = ortho_fn(matrix_2d_global) - find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( - param.dtype - ) + find_master = param.name in self._master_weights master_weight = ( self._master_weights[param.name] if find_master else None )
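
Taken together, the series makes the per-step offload/rebuild cycle explicit:
the caller releases a color's parameter storage after backward, and step()
rebuilds it before updating. A minimal usage sketch, mirroring the updated
hybrid_parallel_sharding_muon_model.py test (model, batch, and the
fleet-wrapped optimizer construction are elided; multi_precision is assumed to
be enabled):

    from paddle.distributed.fleet.meta_optimizers.muon_sharding_optimizer import (
        MuonShardingOptimizer,
    )

    loss = model(batch).mean()
    loss.backward()
    inner_opt = getattr(optimizer, '_inner_opt', optimizer)
    if isinstance(inner_opt, MuonShardingOptimizer):
        # fp32 master weights are created before the low-precision storage
        # for this color is released.
        optimizer.clear_param_storage('test_color')
    optimizer.step()        # step() begins with reset_param_storage()
    optimizer.clear_grad()
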