# Integer switch read once, at import time, from the environment variable
# FLAGS_shard_bypass_dygraph_optimizer; evaluates to 0 when the variable is unset.
g_shard_bypass_dygraph_optimizer = int(
    os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)
)
61- g_shard_fused_gradient = int (os .environ .get ("FLAGS_shard_fused_gradient" , 0 ))
6261
6362
6463def _is_trainable (param ):
@@ -133,6 +132,9 @@ def __init__(self, optimizer, hcg=None):
133132
134133 pp_overlap = strategy .hybrid_configs ['pp_configs' ].sharding_comm_overlap
135134 self .pp_overlap = pp_overlap
135+ assert not self .pp_overlap , (
136+ "muon_sharding_optimizer do not support PP overlap"
137+ )
136138
137139 self ._use_main_grad = hasattr (optimizer ._parameter_list [0 ], "main_grad" )
138140
@@ -146,12 +148,6 @@ def __init__(self, optimizer, hcg=None):
146148 self ._parameter_list , sharding_group
147149 )
148150
149- # Extract MoE group info from color_to_group_info for backward compatibility
150- moe_info = self ._color_to_group_info .get ('moe_expert' , {})
151- self ._moe_sharding_world_size = moe_info .get ('world_size' , 1 )
152- self ._moe_sharding_rank = moe_info .get ('rank' , 0 )
153- self ._moe_sharding_group = moe_info .get ('group' , None )
154-
155151 # Get muon_param_info_map from the inner Muon optimizer.
156152 # Each entry has use_muon=True/False, set by the Trainer before construction.
157153 self ._muon_param_info_map = getattr (
@@ -239,29 +235,25 @@ def __init__(self, optimizer, hcg=None):
239235 key = lambda p : self ._param2rank_2d_by_color [color_key ][p .name ]
240236 )
241237
242- # ---- Backward compatibility: expose legacy attributes ----
243- # These are kept for any external code that might reference them
244- self ._params_2d = self ._params_2d_by_color .get (None , [])
245- self ._params_2d_moe = self ._params_2d_by_color .get ('moe_expert' , [])
246- self ._rank2params_2d = self ._rank2params_2d_by_color .get (None , {0 : []})
247- self ._param2rank_2d = self ._param2rank_2d_by_color .get (None , {})
248- self ._rank2params_2d_moe = self ._rank2params_2d_by_color .get (
249- 'moe_expert' , {0 : []}
250- )
251- self ._param2rank_2d_moe = self ._param2rank_2d_by_color .get (
252- 'moe_expert' , {}
253- )
238+ # 2D params owned by this sharding rank
239+ self ._local_2d = []
240+ for color_key , params_2d in self ._params_2d_by_color .items ():
241+ rank2params_2d_by_color = self ._rank2params_2d_by_color [color_key ]
242+
243+ group_info = self ._color_to_group_info [color_key ]
244+ sharding_rank = max (group_info ['rank' ], 0 )
245+
246+ self ._local_2d .extend (rank2params_2d_by_color [sharding_rank ])
254247
255248 self .sd_release_grads = (
256249 strategy .hybrid_configs ['pp_configs' ].release_gradients
257250 or sharding_configs .release_gradients
258251 )
259- self ._use_fuse_gradients = g_shard_fused_gradient
252+ self ._use_fuse_gradients = self . comm_buffer_size_MB > 0
260253 # ---- Build comm buffers for 2D params (V1-style) ----
261254 if self ._use_fuse_gradients :
262- if not hasattr (self , 'comm_buffer_2d' ):
263- self .comm_buffer_2d = self ._build_2d_comm_buffers ()
264- self .comm_buffer_2d .sort (key = lambda x : x ._dst )
255+ self .comm_buffer_2d = self ._build_2d_comm_buffers ()
256+ self .comm_buffer_2d .sort (key = lambda x : x ._dst )
265257
266258 # ---- Step 3: Build comm buffers for 1D params (V2-style) ----
267259 self ._slice_params = {}
@@ -278,15 +270,9 @@ def __init__(self, optimizer, hcg=None):
278270 # The optimizer should see:
279271 # - All 2D params assigned to this rank (all colors, as whole tensors)
280272 # - 1D slice_params for all non-2D params (element-wise shards)
281- local_2d_params = []
282- for color_key , rank2params in self ._rank2params_2d_by_color .items ():
283- group_info = self ._color_to_group_info .get (color_key , {})
284- color_rank = group_info .get ('rank' , 0 )
285- world_size = group_info .get ('world_size' , 1 )
286- rank_key = color_rank if world_size > 1 else 0
287- local_2d_params .extend (rank2params .get (rank_key , []))
288-
289- local_opt_params = local_2d_params + list (self ._local_parameter_list_1d )
273+ local_opt_params = list (self ._local_2d ) + list (
274+ self ._local_parameter_list_1d
275+ )
290276
291277 self ._set_inner_opt_attr ('_parameter_list' , local_opt_params )
292278 self ._set_inner_opt_attr ('_param_groups' , local_opt_params )
@@ -306,18 +292,10 @@ def __init__(self, optimizer, hcg=None):
306292 _sg_group = hcg .get_sharding_parallel_group ()
307293 _N = self ._sharding_world_size
308294
309- # 2D params owned by this sharding rank (default color, via legacy alias)
295+ # 2D params owned by this sharding rank
310296 _local_2d_numel = sum (
311297 int (functools_reduce (lambda x , y : x * y , p .shape , 1 ))
312- for p in self ._rank2params_2d .get (self ._sharding_rank , [])
313- )
314- # 2D MoE-expert params owned by this rank (moe_expert color, via legacy alias)
315- _moe_rank_key = (
316- self ._moe_sharding_rank if self ._moe_sharding_world_size > 1 else 0
317- )
318- _local_2d_moe_numel = sum (
319- int (functools_reduce (lambda x , y : x * y , p .shape , 1 ))
320- for p in self ._rank2params_2d_moe .get (_moe_rank_key , [])
298+ for p in self ._local_2d
321299 )
322300 # 1D (AdamW) slice: each rank holds ceil(numel / sharding_world_size) elements.
323301 _local_1d_numel = sum (
@@ -327,9 +305,7 @@ def __init__(self, optimizer, hcg=None):
327305 for p in self ._params_1d
328306 )
329307
330- _local_total_numel = (
331- _local_2d_numel + _local_2d_moe_numel + _local_1d_numel
332- )
308+ _local_total_numel = _local_2d_numel + _local_1d_numel
333309 _local_total_MB = (
334310 _local_total_numel * 2 / (1024 * 1024 )
335311 ) # bf16/fp16 = 2 bytes
@@ -545,9 +521,15 @@ def _build_1d_comm_buffers(self):
545521 self ._comm_buffer_list .sort (key = lambda x : x ._dst )
546522
547523 def clear_param_storage (self , color ):
524+ assert self ._multi_precision , (
525+ "Muon Sharding Optimizer only support clear param with multi_precision mode"
526+ )
527+
548528 self .clear_color .add (color )
529+ # 1D params
549530 if color in self ._color_to_comm_buffer_list .keys ():
550531 for comm_buffer in self ._color_to_comm_buffer_list [color ]:
532+ has_clear = False
551533 for param in comm_buffer .params :
552534 grad_view = comm_buffer ._sharding_param_grad_view [
553535 param .name
@@ -559,16 +541,36 @@ def clear_param_storage(self, color):
559541 ):
560542 grad_view .fill_slice_param (slice_param )
561543 self ._create_master_weight (slice_param )
562- slice_param ._clear_dataptr ()
563- comm_buffer ._clear_param_storage ()
544+ if param .name in self ._master_weights :
545+ slice_param ._clear_dataptr ()
546+ has_clear = True
547+
548+ if has_clear :
549+ comm_buffer ._clear_param_storage ()
550+ # 2D params
551+ if color in self ._params_2d_by_color .keys ():
552+ for param in self ._params_2d_by_color [color ]:
553+ if not g_shard_bypass_dygraph_optimizer :
554+ self ._create_master_weight (param )
555+
556+ if param .name in self ._master_weights :
557+ param ._clear_to_zero_allocation ()
564558
def reset_param_storage(self):
    """Re-create parameter storage for every color recorded in ``self.clear_color``.

    Only storage that currently reports ``_is_initialized() == False`` is
    touched, so calling this repeatedly is harmless for storage that is
    already live. The ``None`` color is always skipped.
    """
    for color in self.clear_color:
        if color is None:
            continue

        # 1D params: their storage lives in the per-color comm buffers.
        for comm_buffer in self._color_to_comm_buffer_list.get(color, ()):
            if not comm_buffer.param_storage._is_initialized():
                comm_buffer._reset_param_storage()

        # 2D params: each one is a whole tensor with its own allocation.
        for param in self._params_2d_by_color.get(color, ()):
            if not param._is_initialized():
                # NOTE(review): presumably `_share_buffer_to` makes `param`
                # alias the freshly allocated (uninitialized) tensor -- confirm
                # against the Paddle tensor API.
                new_param = paddle.empty_like(param)
                new_param._share_buffer_to(param)
572574
573575 # ------------------------------------------------------------------
574576 # Gradient communication
@@ -671,31 +673,13 @@ def reduce_gradients(self, parameter_list, hcg):
671673
def filter_parameters(self, parameter_list, hcg):
    """Return this rank's 2D params followed by its initialized 1D slices.

    Args:
        parameter_list: Iterable of parameters exposing a ``name`` attribute.
            It selects which 1D slices are returned; parameters without an
            entry in ``self._slice_params`` are ignored.
        hcg: Not used by this implementation; kept so the signature stays
            unchanged for existing callers.

    Returns:
        A new list: every entry of ``self._local_2d`` (returned in full,
        independent of ``parameter_list``), then the slice params looked up
        for ``parameter_list`` whose ``_is_initialized()`` is true.
    """
    local_1d = [
        self._slice_params[p.name]
        for p in parameter_list
        if p.name in self._slice_params
    ]
    # Drop slices whose storage is not currently allocated.
    local_1d = [p for p in local_1d if p._is_initialized()]
    return self._local_2d + local_1d
699683
700684 # ------------------------------------------------------------------
701685 # Parameter sync after optimizer step
@@ -771,12 +755,12 @@ def clear_grad_func(p):
771755 for p in self ._parameter_list :
772756 clear_grad_func (p )
773757
774- # 1D params are managed by comm buffers
775758 if self .sd_release_grads and not self .pp_overlap :
759+ # 1D params are managed by comm buffers
776760 for comm_buffer in self ._comm_buffer_list :
777761 if comm_buffer .need_reduce_scale_sync ():
778762 comm_buffer ._clear_grad_storage ()
779-
763+ # 2D params are managed by comm buffers
780764 if self ._use_fuse_gradients :
781765 for comm_buffer in self .comm_buffer_2d :
782766 if comm_buffer .need_reduce_scale_sync ():
@@ -820,6 +804,8 @@ def _assign_slice_grad(self):
820804
821805 def step (self ):
822806 """Optimizer step: update local 2D params and 1D slices, then sync."""
807+ self .reset_param_storage ()
808+
823809 self ._collect_comm_buffers ()
824810 self ._assign_slice_grad ()
825811
@@ -884,18 +870,7 @@ def step(self):
884870 def set_state_dict (self , state_dict ):
885871 inner_state = {}
886872 # Collect local parameters: 2D whole-tensor params + 1D original params
887- # (set_state_dict uses legacy aliases; covers default and moe_expert colors)
888- local_2d = list (self ._rank2params_2d .get (self ._sharding_rank , []))
889- if self ._moe_sharding_world_size > 1 :
890- local_2d_moe = list (
891- self ._rank2params_2d_moe .get (self ._moe_sharding_rank , [])
892- )
893- else :
894- local_2d_moe = list (self ._rank2params_2d_moe .get (0 , []))
895- parameters = local_2d + local_2d_moe
896- # Add 1D params (use original param names for matching)
897- for p in self ._params_1d :
898- parameters .append (p )
873+ parameters = list (self ._local_2d ) + list (self ._params_1d )
899874
900875 if "LR_Scheduler" in state_dict :
901876 inner_state ["LR_Scheduler" ] = state_dict .pop ("LR_Scheduler" )
0 commit comments