 g_shard_bypass_dygraph_optimizer = int(
     os.environ.get("FLAGS_shard_bypass_dygraph_optimizer", 0)
 )
-g_shard_fused_gradient = int(os.environ.get("FLAGS_shard_fused_gradient", 0))


 def _is_trainable(param):
@@ -239,29 +238,25 @@ def __init__(self, optimizer, hcg=None):
                 key=lambda p: self._param2rank_2d_by_color[color_key][p.name]
             )

-        # ---- Backward compatibility: expose legacy attributes ----
-        # These are kept for any external code that might reference them
-        self._params_2d = self._params_2d_by_color.get(None, [])
-        self._params_2d_moe = self._params_2d_by_color.get('moe_expert', [])
-        self._rank2params_2d = self._rank2params_2d_by_color.get(None, {0: []})
-        self._param2rank_2d = self._param2rank_2d_by_color.get(None, {})
-        self._rank2params_2d_moe = self._rank2params_2d_by_color.get(
-            'moe_expert', {0: []}
-        )
-        self._param2rank_2d_moe = self._param2rank_2d_by_color.get(
-            'moe_expert', {}
-        )
+        # 2D params owned by this sharding rank
+        self._local_2d = []
+        for color_key, params_2d in self._params_2d_by_color.items():
+            rank2params_2d_by_color = self._rank2params_2d_by_color[color_key]
+
+            group_info = self._color_to_group_info[color_key]
+            sharding_rank = max(group_info['rank'], 0)
+
+            self._local_2d.extend(rank2params_2d_by_color[sharding_rank])

         self.sd_release_grads = (
             strategy.hybrid_configs['pp_configs'].release_gradients
             or sharding_configs.release_gradients
         )
-        self._use_fuse_gradients = g_shard_fused_gradient
+        self._use_fuse_gradients = self.comm_buffer_size_MB > 0
         # ---- Build comm buffers for 2D params (V1-style) ----
         if self._use_fuse_gradients:
-            if not hasattr(self, 'comm_buffer_2d'):
-                self.comm_buffer_2d = self._build_2d_comm_buffers()
-                self.comm_buffer_2d.sort(key=lambda x: x._dst)
+            self.comm_buffer_2d = self._build_2d_comm_buffers()
+            self.comm_buffer_2d.sort(key=lambda x: x._dst)

         # ---- Step 3: Build comm buffers for 1D params (V2-style) ----
         self._slice_params = {}
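Below is a minimal sketch (not part of the patch) of how the new `_local_2d` list is assembled from the per-color maps built earlier in `__init__`; the toy color keys, parameter names, and dict layouts are assumptions for illustration only.

    # Per-color map: sharding rank -> list of 2D params owned by that rank.
    rank2params_2d_by_color = {
        None: {0: ["w_dense_0"], 1: ["w_dense_1"]},            # default color
        'moe_expert': {0: ["w_expert_0"], 1: ["w_expert_1"]},  # MoE-expert color
    }
    # Per-color group info; 'rank' may be -1 when this process is outside the group.
    color_to_group_info = {
        None: {'rank': 1, 'world_size': 2},
        'moe_expert': {'rank': -1, 'world_size': 2},
    }

    local_2d = []
    for color_key, rank2params in rank2params_2d_by_color.items():
        # max(rank, 0) mirrors the patch: a rank outside the group falls back to slot 0.
        sharding_rank = max(color_to_group_info[color_key]['rank'], 0)
        local_2d.extend(rank2params[sharding_rank])

    print(local_2d)  # ['w_dense_1', 'w_expert_0']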
@@ -278,15 +273,9 @@ def __init__(self, optimizer, hcg=None):
         # The optimizer should see:
         # - All 2D params assigned to this rank (all colors, as whole tensors)
         # - 1D slice_params for all non-2D params (element-wise shards)
-        local_2d_params = []
-        for color_key, rank2params in self._rank2params_2d_by_color.items():
-            group_info = self._color_to_group_info.get(color_key, {})
-            color_rank = group_info.get('rank', 0)
-            world_size = group_info.get('world_size', 1)
-            rank_key = color_rank if world_size > 1 else 0
-            local_2d_params.extend(rank2params.get(rank_key, []))
-
-        local_opt_params = local_2d_params + list(self._local_parameter_list_1d)
+        local_opt_params = list(self._local_2d) + list(
+            self._local_parameter_list_1d
+        )

         self._set_inner_opt_attr('_parameter_list', local_opt_params)
         self._set_inner_opt_attr('_param_groups', local_opt_params)
@@ -306,18 +295,10 @@ def __init__(self, optimizer, hcg=None):
         _sg_group = hcg.get_sharding_parallel_group()
         _N = self._sharding_world_size

-        # 2D params owned by this sharding rank (default color, via legacy alias)
+        # 2D params owned by this sharding rank
         _local_2d_numel = sum(
             int(functools_reduce(lambda x, y: x * y, p.shape, 1))
-            for p in self._rank2params_2d.get(self._sharding_rank, [])
-        )
-        # 2D MoE-expert params owned by this rank (moe_expert color, via legacy alias)
-        _moe_rank_key = (
-            self._moe_sharding_rank if self._moe_sharding_world_size > 1 else 0
-        )
-        _local_2d_moe_numel = sum(
-            int(functools_reduce(lambda x, y: x * y, p.shape, 1))
-            for p in self._rank2params_2d_moe.get(_moe_rank_key, [])
+            for p in self._local_2d
         )
         # 1D (AdamW) slice: each rank holds ceil(numel / sharding_world_size) elements.
         _local_1d_numel = sum(
@@ -327,9 +308,7 @@ def __init__(self, optimizer, hcg=None):
             for p in self._params_1d
         )

-        _local_total_numel = (
-            _local_2d_numel + _local_2d_moe_numel + _local_1d_numel
-        )
+        _local_total_numel = _local_2d_numel + _local_1d_numel
         _local_total_MB = (
             _local_total_numel * 2 / (1024 * 1024)
         )  # bf16/fp16 = 2 bytes
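As a rough worked example of the per-rank memory estimate computed above (toy shapes and world size, and a plain functools.reduce in place of the module's functools_reduce alias; not part of the patch):

    from functools import reduce

    sharding_world_size = 4
    local_2d_shapes = [(1024, 1024), (4096, 1024)]  # 2D params owned by this rank
    params_1d_shapes = [(1024,), (4096,)]           # 1D params, sharded element-wise

    # Whole tensors for locally owned 2D params.
    local_2d_numel = sum(reduce(lambda x, y: x * y, s, 1) for s in local_2d_shapes)
    # Each rank holds ceil(numel / sharding_world_size) elements of every 1D param.
    local_1d_numel = sum(
        (reduce(lambda x, y: x * y, s, 1) + sharding_world_size - 1)
        // sharding_world_size
        for s in params_1d_shapes
    )
    local_total_numel = local_2d_numel + local_1d_numel
    local_total_MB = local_total_numel * 2 / (1024 * 1024)  # bf16/fp16 = 2 bytes
    print(f"{local_total_MB:.2f} MB")  # ~10.00 MB for these toy shapes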
@@ -545,6 +524,7 @@ def _build_1d_comm_buffers(self):
         self._comm_buffer_list.sort(key=lambda x: x._dst)

     def clear_param_storage(self, color):
+        # Only clear param_storage for 1d_params; 2d_params are not added to comm_buffers.
         self.clear_color.add(color)
         if color in self._color_to_comm_buffer_list.keys():
             for comm_buffer in self._color_to_comm_buffer_list[color]:
@@ -671,31 +651,13 @@ def reduce_gradients(self, parameter_list, hcg):

     def filter_parameters(self, parameter_list, hcg):
         """Filter parameters: return local 2D params + initialized 1D slices."""
-        sharding_rank = hcg.get_sharding_parallel_rank()
-        local_2d = [
-            p
-            for p in parameter_list
-            if p.name in self._param2rank_2d
-            and self._param2rank_2d[p.name] == sharding_rank
-        ]
-        # Also include MoE 2D params owned by this rank
-        if self._moe_sharding_world_size > 1:
-            moe_rank = self._moe_sharding_rank
-        else:
-            moe_rank = 0
-        local_2d_moe = [
-            p
-            for p in parameter_list
-            if p.name in self._param2rank_2d_moe
-            and self._param2rank_2d_moe[p.name] == moe_rank
-        ]
         local_1d = [
             self._slice_params[p.name]
             for p in parameter_list
             if p.name in self._slice_params
         ]
         local_1d = [p for p in local_1d if p._is_initialized()]
-        return local_2d + local_2d_moe + local_1d
+        return self._local_2d + local_1d

     # ------------------------------------------------------------------
     # Parameter sync after optimizer step
@@ -884,18 +846,7 @@ def step(self):
     def set_state_dict(self, state_dict):
         inner_state = {}
         # Collect local parameters: 2D whole-tensor params + 1D original params
-        # (set_state_dict uses legacy aliases; covers default and moe_expert colors)
-        local_2d = list(self._rank2params_2d.get(self._sharding_rank, []))
-        if self._moe_sharding_world_size > 1:
-            local_2d_moe = list(
-                self._rank2params_2d_moe.get(self._moe_sharding_rank, [])
-            )
-        else:
-            local_2d_moe = list(self._rank2params_2d_moe.get(0, []))
-        parameters = local_2d + local_2d_moe
-        # Add 1D params (use original param names for matching)
-        for p in self._params_1d:
-            parameters.append(p)
+        parameters = list(self._local_2d) + list(self._params_1d)

         if "LR_Scheduler" in state_dict:
             inner_state["LR_Scheduler"] = state_dict.pop("LR_Scheduler")