@@ -178,6 +178,10 @@ def __init__(self, fd_config: FDConfig, ranks: int = 1, local_rank: int = 0) ->
178178 self .max_chips_per_node = 16 if current_platform .is_iluvatar () else 8
179179 self .enable_overlap_schedule = self .scheduler_config .enable_overlap_schedule
180180 self .cached_control_reqs = []
181+ if self .ranks > 1 :
182+ self .gloo_group = dist .new_group (list (range (self .ranks )), backend = "gloo" )
183+ else :
184+ self .gloo_group = None
181185
182186 def init_control (self ):
183187 engine_worker_queue_port = self .parallel_config .local_engine_worker_queue_port
@@ -312,9 +316,12 @@ def update_weights_from_tensor(self, mmap_infos):
312316 self .experts_manager .tensor_infos = None
313317
def _broadcast_model_weights_signal(self, src: int, group) -> int:
    """Broadcast the model-weights signal from rank ``src`` and return it.

    The current value of ``self.model_weights_signal[0]`` is packed into a
    one-element int32 tensor created on the CPU, broadcast across ``group``
    with ``paddle.distributed.broadcast``, and read back as a Python int.

    Args:
        src: Rank whose signal value is sent to the other participants.
        group: Communication group handed to ``paddle.distributed.broadcast``.
            NOTE(review): the tensor lives on the CPU, so this presumably
            needs a CPU-capable (e.g. gloo) group -- confirm against callers.

    Returns:
        The first element of the tensor after the broadcast, as ``int``.
    """
    signal_tensor = paddle.full(
        shape=[1],
        fill_value=self.model_weights_signal[0],
        dtype="int32",
        device="cpu",
    )
    # broadcast() writes the received value into the tensor in place.
    paddle.distributed.broadcast(signal_tensor, src=src, group=group)
    return int(signal_tensor.numpy()[0])
318325
319326 def _get_exist_task_flag (self ) -> bool :
320327 if self .nnode > 1 :
@@ -465,7 +472,7 @@ def event_loop_normal(self) -> None:
465472 if self .fd_config .load_config .dynamic_load_weight and not envs .FD_ENABLE_V1_UPDATE_WEIGHTS :
466473 self .model_weights_signal [0 ] = int (self .model_weights_status .value [0 ])
467474 if self .ranks > 1 :
468- self .model_weights_signal [0 ] = self ._broadcast_model_weights_signal (src = 0 , group = None )
475+ self .model_weights_signal [0 ] = self ._broadcast_model_weights_signal (src = 0 , group = self . gloo_group )
469476
470477 req_dicts = None
471478 self .worker_healthy_live_signal .value [tp_rank % self .max_chips_per_node ] = int (time .time ())
@@ -530,7 +537,7 @@ def event_loop_normal(self) -> None:
530537 self .model_weights_signal [0 ] = self .model_weights_status .value [0 ]
531538 if self .ranks > 1 :
532539 self .model_weights_signal [0 ] = self ._broadcast_model_weights_signal (
533- src = 0 , group = None
540+ src = 0 , group = self . gloo_group
534541 )
535542 time .sleep (1 )
536543 self .model_weights_status .value [0 ] = (
0 commit comments