@@ -58,7 +58,6 @@
 g_shard_bypass_dygraph_optimizer = int(
     os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)
 )
-g_shard_fused_gradient = int(os.environ.get("FLAGS_shard_fused_gradient", 0))


 def _is_trainable(param):
@@ -133,6 +132,9 @@ def __init__(self, optimizer, hcg=None):

         pp_overlap = strategy.hybrid_configs['pp_configs'].sharding_comm_overlap
         self.pp_overlap = pp_overlap
+        assert not self.pp_overlap, (
+            "muon_sharding_optimizer does not support PP overlap"
+        )

         self._use_main_grad = hasattr(optimizer._parameter_list[0], "main_grad")

@@ -146,12 +148,6 @@ def __init__(self, optimizer, hcg=None):
             self._parameter_list, sharding_group
         )

-        # Extract MoE group info from color_to_group_info for backward compatibility
-        moe_info = self._color_to_group_info.get('moe_expert', {})
-        self._moe_sharding_world_size = moe_info.get('world_size', 1)
-        self._moe_sharding_rank = moe_info.get('rank', 0)
-        self._moe_sharding_group = moe_info.get('group', None)
-
         # Get muon_param_info_map from the inner Muon optimizer.
         # Each entry has use_muon=True/False, set by the Trainer before construction.
         self._muon_param_info_map = getattr(
@@ -239,29 +235,25 @@ def __init__(self, optimizer, hcg=None):
                 key=lambda p: self._param2rank_2d_by_color[color_key][p.name]
             )

-        # ---- Backward compatibility: expose legacy attributes ----
-        # These are kept for any external code that might reference them
-        self._params_2d = self._params_2d_by_color.get(None, [])
-        self._params_2d_moe = self._params_2d_by_color.get('moe_expert', [])
-        self._rank2params_2d = self._rank2params_2d_by_color.get(None, {0: []})
-        self._param2rank_2d = self._param2rank_2d_by_color.get(None, {})
-        self._rank2params_2d_moe = self._rank2params_2d_by_color.get(
-            'moe_expert', {0: []}
-        )
-        self._param2rank_2d_moe = self._param2rank_2d_by_color.get(
-            'moe_expert', {}
-        )
+        # 2D params owned by this sharding rank
+        self._local_2d = []
+        for color_key, params_2d in self._params_2d_by_color.items():
+            rank2params_2d_by_color = self._rank2params_2d_by_color[color_key]
+
+            group_info = self._color_to_group_info[color_key]
+            sharding_rank = max(group_info['rank'], 0)
+
+            self._local_2d.extend(rank2params_2d_by_color[sharding_rank])

         self.sd_release_grads = (
             strategy.hybrid_configs['pp_configs'].release_gradients
             or sharding_configs.release_gradients
         )
-        self._use_fuse_gradients = g_shard_fused_gradient
+        self._use_fuse_gradients = self.comm_buffer_size_MB > 0
         # ---- Build comm buffers for 2D params (V1-style) ----
         if self._use_fuse_gradients:
-            if not hasattr(self, 'comm_buffer_2d'):
-                self.comm_buffer_2d = self._build_2d_comm_buffers()
-                self.comm_buffer_2d.sort(key=lambda x: x._dst)
+            self.comm_buffer_2d = self._build_2d_comm_buffers()
+            self.comm_buffer_2d.sort(key=lambda x: x._dst)

         # ---- Step 3: Build comm buffers for 1D params (V2-style) ----
         self._slice_params = {}
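Note on the hunk above: a minimal standalone sketch of how `_local_2d` is assembled per color; the colors, parameter names, and group info below are made-up stand-ins for the structures built earlier in `__init__`, not values from this patch.

# Hypothetical toy data mirroring the per-color structures used above.
_params_2d_by_color = {None: ["w0", "w1"], "moe_expert": ["e0", "e1"]}
_rank2params_2d_by_color = {
    None: {0: ["w0"], 1: ["w1"]},
    "moe_expert": {0: ["e0"], 1: ["e1"]},
}
_color_to_group_info = {
    None: {"rank": 1, "world_size": 2},
    "moe_expert": {"rank": -1, "world_size": 1},  # presumably -1 when this process is not in the group
}

_local_2d = []
for color_key in _params_2d_by_color:
    rank = max(_color_to_group_info[color_key]["rank"], 0)  # clamp -1 to bucket 0, as in the patch
    _local_2d.extend(_rank2params_2d_by_color[color_key][rank])

print(_local_2d)  # ['w1', 'e0'] -- the 2D params this rank owns, across all colors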
@@ -278,15 +270,9 @@ def __init__(self, optimizer, hcg=None):
         # The optimizer should see:
         #   - All 2D params assigned to this rank (all colors, as whole tensors)
         #   - 1D slice_params for all non-2D params (element-wise shards)
-        local_2d_params = []
-        for color_key, rank2params in self._rank2params_2d_by_color.items():
-            group_info = self._color_to_group_info.get(color_key, {})
-            color_rank = group_info.get('rank', 0)
-            world_size = group_info.get('world_size', 1)
-            rank_key = color_rank if world_size > 1 else 0
-            local_2d_params.extend(rank2params.get(rank_key, []))
-
-        local_opt_params = local_2d_params + list(self._local_parameter_list_1d)
+        local_opt_params = list(self._local_2d) + list(
+            self._local_parameter_list_1d
+        )

         self._set_inner_opt_attr('_parameter_list', local_opt_params)
         self._set_inner_opt_attr('_param_groups', local_opt_params)
@@ -306,18 +292,10 @@ def __init__(self, optimizer, hcg=None):
         _sg_group = hcg.get_sharding_parallel_group()
         _N = self._sharding_world_size

-        # 2D params owned by this sharding rank (default color, via legacy alias)
+        # 2D params owned by this sharding rank
         _local_2d_numel = sum(
             int(functools_reduce(lambda x, y: x * y, p.shape, 1))
-            for p in self._rank2params_2d.get(self._sharding_rank, [])
-        )
-        # 2D MoE-expert params owned by this rank (moe_expert color, via legacy alias)
-        _moe_rank_key = (
-            self._moe_sharding_rank if self._moe_sharding_world_size > 1 else 0
-        )
-        _local_2d_moe_numel = sum(
-            int(functools_reduce(lambda x, y: x * y, p.shape, 1))
-            for p in self._rank2params_2d_moe.get(_moe_rank_key, [])
+            for p in self._local_2d
         )
         # 1D (AdamW) slice: each rank holds ceil(numel / sharding_world_size) elements.
         _local_1d_numel = sum(
@@ -327,9 +305,7 @@ def __init__(self, optimizer, hcg=None):
             for p in self._params_1d
         )

-        _local_total_numel = (
-            _local_2d_numel + _local_2d_moe_numel + _local_1d_numel
-        )
+        _local_total_numel = _local_2d_numel + _local_1d_numel
         _local_total_MB = (
             _local_total_numel * 2 / (1024 * 1024)
         )  # bf16/fp16 = 2 bytes
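A quick worked example of the memory estimate above; the element counts are hypothetical, only the arithmetic matches the patch.

_local_2d_numel = 50_000_000   # hypothetical locally owned 2D elements
_local_1d_numel = 14_000_000   # hypothetical locally owned 1D slice elements
_local_total_numel = _local_2d_numel + _local_1d_numel
_local_total_MB = _local_total_numel * 2 / (1024 * 1024)  # bf16/fp16 = 2 bytes
print(f"{_local_total_MB:.1f} MB")  # ~122.1 MB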
@@ -546,6 +522,7 @@ def _build_1d_comm_buffers(self):

     def clear_param_storage(self, color):
         self.clear_color.add(color)
+        # 1D params
         if color in self._color_to_comm_buffer_list.keys():
             for comm_buffer in self._color_to_comm_buffer_list[color]:
                 for param in comm_buffer.params:
@@ -561,14 +538,26 @@ def clear_param_storage(self, color):
                         self._create_master_weight(slice_param)
                     slice_param._clear_dataptr()
                 comm_buffer._clear_param_storage()
+        # 2D params
+        if color in self._params_2d_by_color.keys():
+            for param in self._params_2d_by_color[color]:
+                if not g_shard_bypass_dygraph_optimizer:
+                    self._create_master_weight(param)
+                param._clear_to_zero_allocation()

     def reset_param_storage(self):
         for color in self.clear_color:
             if color is None:
                 continue
+            # 1D params
             if color in self._color_to_comm_buffer_list.keys():
                 for comm_buffer in self._color_to_comm_buffer_list[color]:
                     comm_buffer._reset_param_storage()
+            # 2D params
+            if color in self._params_2d_by_color.keys():
+                for param in self._params_2d_by_color[color]:
+                    new_param = paddle.empty_like(param)
+                    new_param._share_buffer_to(param)

     # ------------------------------------------------------------------
     # Gradient communication
@@ -671,31 +660,13 @@ def reduce_gradients(self, parameter_list, hcg):

     def filter_parameters(self, parameter_list, hcg):
         """Filter parameters: return local 2D params + initialized 1D slices."""
-        sharding_rank = hcg.get_sharding_parallel_rank()
-        local_2d = [
-            p
-            for p in parameter_list
-            if p.name in self._param2rank_2d
-            and self._param2rank_2d[p.name] == sharding_rank
-        ]
-        # Also include MoE 2D params owned by this rank
-        if self._moe_sharding_world_size > 1:
-            moe_rank = self._moe_sharding_rank
-        else:
-            moe_rank = 0
-        local_2d_moe = [
-            p
-            for p in parameter_list
-            if p.name in self._param2rank_2d_moe
-            and self._param2rank_2d_moe[p.name] == moe_rank
-        ]
         local_1d = [
             self._slice_params[p.name]
             for p in parameter_list
             if p.name in self._slice_params
         ]
         local_1d = [p for p in local_1d if p._is_initialized()]
-        return local_2d + local_2d_moe + local_1d
+        return self._local_2d + local_1d

     # ------------------------------------------------------------------
     # Parameter sync after optimizer step
@@ -771,12 +742,12 @@ def clear_grad_func(p):
         for p in self._parameter_list:
             clear_grad_func(p)

-        # 1D params are managed by comm buffers
         if self.sd_release_grads and not self.pp_overlap:
+            # 1D params are managed by comm buffers
             for comm_buffer in self._comm_buffer_list:
                 if comm_buffer.need_reduce_scale_sync():
                     comm_buffer._clear_grad_storage()
-
+        # 2D params are managed by comm buffers
         if self._use_fuse_gradients:
             for comm_buffer in self.comm_buffer_2d:
                 if comm_buffer.need_reduce_scale_sync():
@@ -820,6 +791,8 @@ def _assign_slice_grad(self):

     def step(self):
         """Optimizer step: update local 2D params and 1D slices, then sync."""
+        self.reset_param_storage()
+
         self._collect_comm_buffers()
         self._assign_slice_grad()

@@ -884,18 +857,7 @@ def step(self):
     def set_state_dict(self, state_dict):
         inner_state = {}
         # Collect local parameters: 2D whole-tensor params + 1D original params
-        # (set_state_dict uses legacy aliases; covers default and moe_expert colors)
-        local_2d = list(self._rank2params_2d.get(self._sharding_rank, []))
-        if self._moe_sharding_world_size > 1:
-            local_2d_moe = list(
-                self._rank2params_2d_moe.get(self._moe_sharding_rank, [])
-            )
-        else:
-            local_2d_moe = list(self._rank2params_2d_moe.get(0, []))
-        parameters = local_2d + local_2d_moe
-        # Add 1D params (use original param names for matching)
-        for p in self._params_1d:
-            parameters.append(p)
+        parameters = list(self._local_2d) + list(self._params_1d)

         if "LR_Scheduler" in state_dict:
             inner_state["LR_Scheduler"] = state_dict.pop("LR_Scheduler")