@@ -58,7 +58,6 @@
g_shard_bypass_dygraph_optimizer = int(
os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)
)
g_shard_fused_gradient = int(os.environ.get("FLAGS_shard_fused_gradient", 0))


def _is_trainable(param):
@@ -133,6 +132,9 @@ def __init__(self, optimizer, hcg=None):

pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap
self.pp_overlap = pp_overlap
assert not self.pp_overlap, (
"muon_sharding_optimizer do not support PP overlap"
)

self._use_main_grad = hasattr(optimizer._parameter_list[0], "main_grad")

@@ -146,12 +148,6 @@ def __init__(self, optimizer, hcg=None):
self._parameter_list, sharding_group
)

# Extract MoE group info from color_to_group_info for backward compatibility
moe_info = self._color_to_group_info.get('moe_expert', {})
self._moe_sharding_world_size = moe_info.get('world_size', 1)
self._moe_sharding_rank = moe_info.get('rank', 0)
self._moe_sharding_group = moe_info.get('group', None)

# Get muon_param_info_map from the inner Muon optimizer.
# Each entry has use_muon=True/False, set by the Trainer before construction.
self._muon_param_info_map = getattr(
@@ -239,29 +235,25 @@ def __init__(self, optimizer, hcg=None):
key=lambda p: self._param2rank_2d_by_color[color_key][p.name]
)

# ---- Backward compatibility: expose legacy attributes ----
# These are kept for any external code that might reference them
self._params_2d = self._params_2d_by_color.get(None, [])
self._params_2d_moe = self._params_2d_by_color.get('moe_expert', [])
self._rank2params_2d = self._rank2params_2d_by_color.get(None, {0: []})
self._param2rank_2d = self._param2rank_2d_by_color.get(None, {})
self._rank2params_2d_moe = self._rank2params_2d_by_color.get(
'moe_expert', {0: []}
)
self._param2rank_2d_moe = self._param2rank_2d_by_color.get(
'moe_expert', {}
)
# 2D params owned by this sharding rank
self._local_2d = []
for color_key, params_2d in self._params_2d_by_color.items():
rank2params_2d_by_color = self._rank2params_2d_by_color[color_key]

group_info = self._color_to_group_info[color_key]
sharding_rank = max(group_info['rank'], 0)

self._local_2d.extend(rank2params_2d_by_color[sharding_rank])

self.sd_release_grads = (
strategy.hybrid_configs['pp_configs'].release_gradients
or sharding_configs.release_gradients
)
self._use_fuse_gradients = g_shard_fused_gradient
self._use_fuse_gradients = self.comm_buffer_size_MB > 0
# ---- Build comm buffers for 2D params (V1-style) ----
if self._use_fuse_gradients:
if not hasattr(self, 'comm_buffer_2d'):
self.comm_buffer_2d = self._build_2d_comm_buffers()
self.comm_buffer_2d.sort(key=lambda x: x._dst)
self.comm_buffer_2d = self._build_2d_comm_buffers()
self.comm_buffer_2d.sort(key=lambda x: x._dst)

# ---- Step 3: Build comm buffers for 1D params (V2-style) ----
self._slice_params = {}
@@ -278,15 +270,9 @@ def __init__(self, optimizer, hcg=None):
# The optimizer should see:
# - All 2D params assigned to this rank (all colors, as whole tensors)
# - 1D slice_params for all non-2D params (element-wise shards)
local_2d_params = []
for color_key, rank2params in self._rank2params_2d_by_color.items():
group_info = self._color_to_group_info.get(color_key, {})
color_rank = group_info.get('rank', 0)
world_size = group_info.get('world_size', 1)
rank_key = color_rank if world_size > 1 else 0
local_2d_params.extend(rank2params.get(rank_key, []))

local_opt_params = local_2d_params + list(self._local_parameter_list_1d)
local_opt_params = list(self._local_2d) + list(
self._local_parameter_list_1d
)

self._set_inner_opt_attr('_parameter_list', local_opt_params)
self._set_inner_opt_attr('_param_groups', local_opt_params)
@@ -306,18 +292,10 @@ def __init__(self, optimizer, hcg=None):
_sg_group = hcg.get_sharding_parallel_group()
_N = self._sharding_world_size

# 2D params owned by this sharding rank (default color, via legacy alias)
# 2D params owned by this sharding rank
_local_2d_numel = sum(
int(functools_reduce(lambda x, y: x * y, p.shape, 1))
for p in self._rank2params_2d.get(self._sharding_rank, [])
)
# 2D MoE-expert params owned by this rank (moe_expert color, via legacy alias)
_moe_rank_key = (
self._moe_sharding_rank if self._moe_sharding_world_size > 1 else 0
)
_local_2d_moe_numel = sum(
int(functools_reduce(lambda x, y: x * y, p.shape, 1))
for p in self._rank2params_2d_moe.get(_moe_rank_key, [])
for p in self._local_2d
)
# 1D (AdamW) slice: each rank holds ceil(numel / sharding_world_size) elements.
_local_1d_numel = sum(
@@ -327,9 +305,7 @@ def __init__(self, optimizer, hcg=None):
for p in self._params_1d
)

_local_total_numel = (
_local_2d_numel + _local_2d_moe_numel + _local_1d_numel
)
_local_total_numel = _local_2d_numel + _local_1d_numel
_local_total_MB = (
_local_total_numel * 2 / (1024 * 1024)
) # bf16/fp16 = 2 bytes
@@ -545,9 +521,15 @@ def _build_1d_comm_buffers(self):
self._comm_buffer_list.sort(key=lambda x: x._dst)

def clear_param_storage(self, color):
assert self._multi_precision, (
"Muon Sharding Optimizer only support clear param with multi_precision mode"
)

self.clear_color.add(color)
# 1D params
if color in self._color_to_comm_buffer_list.keys():
for comm_buffer in self._color_to_comm_buffer_list[color]:
has_clear = False
for param in comm_buffer.params:
grad_view = comm_buffer._sharding_param_grad_view[
param.name
@@ -559,16 +541,36 @@ def clear_param_storage(self, color):
):
grad_view.fill_slice_param(slice_param)
self._create_master_weight(slice_param)
slice_param._clear_dataptr()
comm_buffer._clear_param_storage()
if param.name in self._master_weights:
slice_param._clear_dataptr()
has_clear = True

if has_clear:
comm_buffer._clear_param_storage()
# 2D params
if color in self._params_2d_by_color.keys():
for param in self._params_2d_by_color[color]:
if not g_shard_bypass_dygraph_optimizer:
self._create_master_weight(param)

if param.name in self._master_weights:
param._clear_to_zero_allocation()

def reset_param_storage(self):
for color in self.clear_color:
if color is None:
continue
# 1D params
if color in self._color_to_comm_buffer_list.keys():
for comm_buffer in self._color_to_comm_buffer_list[color]:
comm_buffer._reset_param_storage()
if not comm_buffer.param_storage._is_initialized():
comm_buffer._reset_param_storage()
# 2D params
if color in self._params_2d_by_color.keys():
for param in self._params_2d_by_color[color]:
if not param._is_initialized():
new_param = paddle.empty_like(param)
new_param._share_buffer_to(param)

# ------------------------------------------------------------------
# Gradient communication
@@ -671,31 +673,13 @@ def reduce_gradients(self, parameter_list, hcg):

def filter_parameters(self, parameter_list, hcg):
"""Filter parameters: return local 2D params + initialized 1D slices."""
sharding_rank = hcg.get_sharding_parallel_rank()
local_2d = [
p
for p in parameter_list
if p.name in self._param2rank_2d
and self._param2rank_2d[p.name] == sharding_rank
]
# Also include MoE 2D params owned by this rank
if self._moe_sharding_world_size > 1:
moe_rank = self._moe_sharding_rank
else:
moe_rank = 0
local_2d_moe = [
p
for p in parameter_list
if p.name in self._param2rank_2d_moe
and self._param2rank_2d_moe[p.name] == moe_rank
]
local_1d = [
self._slice_params[p.name]
for p in parameter_list
if p.name in self._slice_params
]
local_1d = [p for p in local_1d if p._is_initialized()]
return local_2d + local_2d_moe + local_1d
return self._local_2d + local_1d

# ------------------------------------------------------------------
# Parameter sync after optimizer step
@@ -771,12 +755,12 @@ def clear_grad_func(p):
for p in self._parameter_list:
clear_grad_func(p)

# 1D params are managed by comm buffers
if self.sd_release_grads and not self.pp_overlap:
# 1D params are managed by comm buffers
for comm_buffer in self._comm_buffer_list:
if comm_buffer.need_reduce_scale_sync():
comm_buffer._clear_grad_storage()

# 2D params are managed by comm buffers
if self._use_fuse_gradients:
for comm_buffer in self.comm_buffer_2d:
if comm_buffer.need_reduce_scale_sync():
@@ -820,6 +804,8 @@ def _assign_slice_grad(self):

def step(self):
"""Optimizer step: update local 2D params and 1D slices, then sync."""
self.reset_param_storage()

self._collect_comm_buffers()
self._assign_slice_grad()

@@ -884,18 +870,7 @@ def step(self):
def set_state_dict(self, state_dict):
inner_state = {}
# Collect local parameters: 2D whole-tensor params + 1D original params
# (set_state_dict uses legacy aliases; covers default and moe_expert colors)
local_2d = list(self._rank2params_2d.get(self._sharding_rank, []))
if self._moe_sharding_world_size > 1:
local_2d_moe = list(
self._rank2params_2d_moe.get(self._moe_sharding_rank, [])
)
else:
local_2d_moe = list(self._rank2params_2d_moe.get(0, []))
parameters = local_2d + local_2d_moe
# Add 1D params (use original param names for matching)
for p in self._params_1d:
parameters.append(p)
parameters = list(self._local_2d) + list(self._params_1d)

if "LR_Scheduler" in state_dict:
inner_state["LR_Scheduler"] = state_dict.pop("LR_Scheduler")
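The hunks above replace the legacy default/`moe_expert` aliases (`_params_2d`, `_rank2params_2d_moe`, ...) with a single per-color walk that collects every 2D parameter owned by the current sharding rank into `self._local_2d`. A minimal sketch of that ownership walk, using hypothetical placeholder maps in place of the structures the optimizer builds in `__init__`:

```python
# Sketch of the per-color ownership walk that builds _local_2d.
# The two dicts below are hypothetical stand-ins for the maps the
# optimizer derives from the hybrid communication groups.
rank2params_2d_by_color = {
    None: {0: ["dense_w_0"], 1: ["dense_w_1"]},            # default sharding group
    "moe_expert": {0: ["expert_w_0"], 1: ["expert_w_1"]},  # MoE expert group
}
color_to_group_info = {
    None: {"rank": 1, "world_size": 2},
    "moe_expert": {"rank": 0, "world_size": 2},
}

local_2d = []
for color_key, rank2params in rank2params_2d_by_color.items():
    # A rank of -1 means this process is outside the group; clamp to 0,
    # mirroring max(group_info['rank'], 0) in the diff.
    sharding_rank = max(color_to_group_info[color_key]["rank"], 0)
    local_2d.extend(rank2params[sharding_rank])

print(local_2d)  # ['dense_w_1', 'expert_w_0']
```

The same list then feeds the inner optimizer's parameter list, the local-numel accounting, `filter_parameters`, and `set_state_dict`, which is what lets the MoE-specific branches be dropped.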
8 changes: 2 additions & 6 deletions python/paddle/optimizer/muon.py
@@ -445,9 +445,7 @@ def _adamw_update(
):
with_decay = False

find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
param.dtype
)
find_master = param.name in self._master_weights
master_weight = (
self._master_weights[param.name] if find_master else None
)
@@ -557,9 +555,7 @@ def ortho_fn(m):
# Default: whole matrix orthogonalisation
orthogonal_update = ortho_fn(matrix_2d_global)

find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(
param.dtype
)
find_master = param.name in self._master_weights
master_weight = (
self._master_weights[param.name] if find_master else None
)
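In `muon.py`, both update paths now decide whether to fetch a master weight by looking the parameter up in `self._master_weights`, rather than by combining the precision mode with a dtype check. A small standalone sketch of the contrast (plain Python with hypothetical names; the real checks live inside the Muon optimizer):

```python
def find_master_old(param_dtype, multi_precision):
    # Former check: multi-precision mode plus a low-precision parameter dtype.
    return multi_precision and param_dtype in ("float16", "bfloat16")

def find_master_new(param_name, master_weights):
    # New check: a master weight is used iff one was registered for this name,
    # decoupling the decision from the parameter's dtype.
    return param_name in master_weights

master_weights = {"linear_0.w_0": "fp32 master copy"}  # hypothetical registry
assert find_master_new("linear_0.w_0", master_weights)
assert not find_master_new("linear_1.w_0", master_weights)
```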
6 changes: 6 additions & 0 deletions test/collective/fleet/hybrid_parallel_sharding_muon_model.py
@@ -27,6 +27,9 @@

import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.meta_optimizers.muon_sharding_optimizer import (
MuonShardingOptimizer,
)
from paddle.distributed.fleet.utils import mix_precision_utils
from paddle.optimizer.muon import (
MuonParamInfo,
@@ -240,6 +243,9 @@ def train_batch(self, batch, model, optimizer):
output = model(batch)
loss = output.mean()
loss.backward()
inner_opt = getattr(optimizer, '_inner_opt', optimizer)
if isinstance(inner_opt, MuonShardingOptimizer):
optimizer.clear_param_storage('test_color')
optimizer.step()
optimizer.clear_grad()
return loss
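The test hook above exercises the new storage-release path once per step. Read together with the optimizer changes, the intended cycle appears to be: release a color's fused parameter storage after backward (values survive in the fp32 master weights, hence the `MULTI_PRECISION=1` requirement added to the tests below), then let `step()` call `reset_param_storage()` to re-materialize any cleared buffers before the update. A condensed sketch of that cycle, assuming `model` and `optimizer` are built by the hybrid-parallel test harness:

```python
# Condensed sketch of the clear/reset cycle driven by the test above.
# 'test_color' is the color used in the test file; construction of the
# model and wrapped optimizer is assumed to follow the fleet harness.
loss = model(batch).mean()
loss.backward()

inner_opt = getattr(optimizer, '_inner_opt', optimizer)
if isinstance(inner_opt, MuonShardingOptimizer):
    # Requires multi_precision: fp32 master weights keep the values alive
    # while the low-precision parameter storage is released.
    optimizer.clear_param_storage('test_color')

optimizer.step()        # step() begins with reset_param_storage()
optimizer.clear_grad()
```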
16 changes: 13 additions & 3 deletions test/collective/fleet/test_parallel_dygraph_muon.py
@@ -26,7 +26,10 @@ def test_muon_sharding_optimizer(self):
Test logic is in hybrid_parallel_sharding_muon_model.py,
iterating 4 ns_coeff_types. fp32 matmul is auto-selected on V100.
"""
self.run_mnist_2accelerators('hybrid_parallel_sharding_muon_model.py')
self.run_mnist_2accelerators(
'hybrid_parallel_sharding_muon_model.py',
need_envs={"MULTI_PRECISION": "1"},
)

def test_muon_sharding_fused_gradient(self):
"""MuonSharding test with FLAGS_shard_fused_gradient=1.
@@ -36,7 +39,10 @@ def test_muon_sharding_fused_gradient(self):
"""
self.run_mnist_2accelerators(
'hybrid_parallel_sharding_muon_model.py',
need_envs={"FLAGS_shard_fused_gradient": "1"},
need_envs={
"FLAGS_shard_fused_gradient": "1",
"MULTI_PRECISION": "1",
},
)

def test_muon_sharding_fuse_optimizer_states(self):
@@ -46,7 +52,10 @@ def test_muon_sharding_fuse_optimizer_states(self):
"""
self.run_mnist_2accelerators(
'hybrid_parallel_sharding_muon_model.py',
need_envs={"ENABLE_FUSE_OPTIMIZER_STATES": "1"},
need_envs={
"ENABLE_FUSE_OPTIMIZER_STATES": "1",
"MULTI_PRECISION": "1",
},
)

def test_muon_sharding_release_grads_fused(self):
@@ -60,6 +69,7 @@ def test_muon_sharding_release_grads_fused(self):
need_envs={
"FLAGS_shard_fused_gradient": "1",
"RELEASE_GRADIENTS": "1",
"MULTI_PRECISION": "1",
},
)
