@@ -1537,12 +1537,6 @@ def init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata):
15371537 parameter_list = []
15381538
15391539 # --- 1D params: build shard-sized slice params from FusedCommBuffer ---
1540- # (same logic as V2 branch above, using _comm_buffer_list)
1541- # IMPORTANT: set slice_param.name = "slice@" + param_name so that the
1542- # accumulator key matches what muon_sharding's sharded_state_dict expects via
1543- # _split_state_name (it strips the "_moment1_0" suffix to get static_name,
1544- # which must match param_slice_info keys = original param names after
1545- # removing the "slice@" prefix added back in sharded_state_dict).
15461540 for buffer in optimizer ._comm_buffer_list :
15471541 for param_name , grad_view in buffer ._sharding_param_grad_view .items ():
15481542 if param_name not in static_to_struct_mapping :
@@ -1559,31 +1553,24 @@ def init_optimizer(optimizer, model_sharded_state_dict, state_dict_metadata):
15591553 slice_param .name = param_name
15601554 parameter_list .append (slice_param )
15611555
1562- # --- 2D non-MoE params: local rank's full tensors (Muon) ---
1563- local_2d = optimizer ._rank2params_2d .get (optimizer ._sharding_rank , [])
1564- for param in local_2d :
1565- param_name = param .name
1566- if param_name not in static_to_struct_mapping :
1567- continue
1568- struct_name = static_to_struct_mapping [param_name ]
1569- if not any (struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names ):
1570- continue
1571- parameter_list .append (param )
1572-
1573- # --- 2D MoE expert params: local rank's full tensors (Muon) ---
1574- if optimizer ._moe_sharding_world_size > 1 :
1575- moe_rank = optimizer ._moe_sharding_rank
1576- else :
1577- moe_rank = 0
1578- local_2d_moe = optimizer ._rank2params_2d_moe .get (moe_rank , [])
1579- for param in local_2d_moe :
1580- param_name = param .name
1581- if param_name not in static_to_struct_mapping :
1582- continue
1583- struct_name = static_to_struct_mapping [param_name ]
1584- if not any (struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names ):
1585- continue
1586- parameter_list .append (param )
1556+ # --- 2D params: collect the local rank's full-sized 2D params for each color key in _params_2d_by_color ---
1557+ for color_key , _ in optimizer ._params_2d_by_color .items ():
1558+ assert (
1559+ color_key in optimizer ._rank2params_2d_by_color
1560+ ), f"color_key '{ color_key } ' not in optimizer._rank2params_2d_by_color."
1561+ rank2params_2d_by_color = optimizer ._rank2params_2d_by_color [color_key ]
1562+
1563+ group_info = optimizer ._color_to_group_info [color_key ]
1564+ sharding_rank = group_info ["rank" ] if group_info ["rank" ] >= 0 else 0
1565+ local_2d = rank2params_2d_by_color [sharding_rank ]
1566+ for param in local_2d :
1567+ param_name = param .name
1568+ if param_name not in static_to_struct_mapping :
1569+ continue
1570+ struct_name = static_to_struct_mapping [param_name ]
1571+ if not any (struct_name + state_name in state_dict_metadata for state_name in optimizer_state_names ):
1572+ continue
1573+ parameter_list .append (param )
15871574
15881575 optimizer ._create_accumulators (paddle .base .framework .default_main_program ().global_block (), parameter_list )
15891576 return
0 commit comments