Skip to content

Commit 5d425da

Browse files
authored
update init_optimizer (#4367)
1 parent 5708bb8 commit 5d425da

1 file changed

Lines changed: 18 additions & 31 deletions

File tree

paddleformers/trainer/trainer_utils.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1537,12 +1537,6 @@ def init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata):
15371537
parameter_list = []
15381538

15391539
# --- 1D params: build shard-sized slice params from FusedCommBuffer ---
1540-
# (same logic as V2 branch above, using _comm_buffer_list)
1541-
# IMPORTANT: set slice_param.name = "slice@" + param_name so that the
1542-
# accumulator key matches what muon_sharding's sharded_state_dict expects via
1543-
# _split_state_name (it strips the "_moment1_0" suffix to get static_name,
1544-
# which must match param_slice_info keys = original param names after
1545-
# removing the "slice@" prefix added back in sharded_state_dict).
15461540
for buffer in optimizer._comm_buffer_list:
15471541
for param_name, grad_view in buffer._sharding_param_grad_view.items():
15481542
if param_name not in static_to_struct_mapping:
@@ -1559,31 +1553,24 @@ def init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata):
15591553
slice_param.name = param_name
15601554
parameter_list.append(slice_param)
15611555

1562-
# --- 2D non-MoE params: local rank's full tensors (Muon) ---
1563-
local_2d = optimizer._rank2params_2d.get(optimizer._sharding_rank, [])
1564-
for param in local_2d:
1565-
param_name = param.name
1566-
if param_name not in static_to_struct_mapping:
1567-
continue
1568-
struct_name = static_to_struct_mapping[param_name]
1569-
if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names):
1570-
continue
1571-
parameter_list.append(param)
1572-
1573-
# --- 2D MoE expert params: local rank's full tensors (Muon) ---
1574-
if optimizer._moe_sharding_world_size > 1:
1575-
moe_rank = optimizer._moe_sharding_rank
1576-
else:
1577-
moe_rank = 0
1578-
local_2d_moe = optimizer._rank2params_2d_moe.get(moe_rank, [])
1579-
for param in local_2d_moe:
1580-
param_name = param.name
1581-
if param_name not in static_to_struct_mapping:
1582-
continue
1583-
struct_name = static_to_struct_mapping[param_name]
1584-
if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names):
1585-
continue
1586-
parameter_list.append(param)
1556+
# -- 2D params: build full-sized 2D params from _params_2d_by_color ---
1557+
for color_key, _ in optimizer._params_2d_by_color.items():
1558+
assert (
1559+
color_key in optimizer._rank2params_2d_by_color
1560+
), f"color_key '{color_key}' not in optimizer._rank2params_2d_by_color."
1561+
rank2params_2d_by_color = optimizer._rank2params_2d_by_color[color_key]
1562+
1563+
group_info = optimizer._color_to_group_info[color_key]
1564+
sharding_rank = group_info["rank"] if group_info["rank"] >= 0 else 0
1565+
local_2d = rank2params_2d_by_color[sharding_rank]
1566+
for param in local_2d:
1567+
param_name = param.name
1568+
if param_name not in static_to_struct_mapping:
1569+
continue
1570+
struct_name = static_to_struct_mapping[param_name]
1571+
if not any(struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names):
1572+
continue
1573+
parameter_list.append(param)
15871574

15881575
optimizer._create_accumulators(paddle.base.framework.default_main_program().global_block(), parameter_list)
15891576
return

0 commit comments

Comments
 (0)