@@ -174,6 +174,10 @@ def __init__(self, fd_config: FDConfig, ranks: int = 1, local_rank: int = 0) ->
174174 self .max_chips_per_node = 16 if current_platform .is_iluvatar () else 8
175175 self .enable_overlap_schedule = self .scheduler_config .enable_overlap_schedule
176176 self .cached_control_reqs = []
177+ if self .ranks > 1 :
178+ self .gloo_group = dist .new_group (list (range (self .ranks )), backend = "gloo" )
179+ else :
180+ self .gloo_group = None
177181
178182 def init_control (self ):
179183 engine_worker_queue_port = self .parallel_config .local_engine_worker_queue_port
@@ -316,9 +320,12 @@ def update_weights_from_tensor(self, mmap_infos):
316320 self .experts_manager .tensor_infos = None
317321
def _broadcast_model_weights_signal(self, src: int, group) -> int:
    """Broadcast the model-weights signal from rank ``src`` and return it.

    The current value of ``self.model_weights_signal[0]`` is packed into a
    one-element int32 CPU tensor, sent through
    ``paddle.distributed.broadcast`` and read back from that same tensor.

    Args:
        src: Rank whose signal value is broadcast.
        group: Communication group forwarded unchanged to
            ``paddle.distributed.broadcast``.

    Returns:
        The first element of the tensor after the broadcast, as a plain
        Python ``int``.
    """
    # A scalar tensor is used as the payload instead of a pickled object list.
    # NOTE(review): presumably kept on CPU so that a CPU-capable (e.g. gloo)
    # group can carry it -- confirm against the callers.
    signal_buf = paddle.full(
        shape=[1],
        fill_value=self.model_weights_signal[0],
        dtype="int32",
        device="cpu",
    )
    paddle.distributed.broadcast(signal_buf, src=src, group=group)
    # Read the value back from the tensor that was handed to broadcast.
    return int(signal_buf.numpy()[0])
322329
323330 def _get_exist_task_flag (self ) -> bool :
324331 if self .nnode > 1 :
@@ -498,7 +505,7 @@ def event_loop_normal(self) -> None:
498505 if self .fd_config .load_config .dynamic_load_weight and not envs .FD_ENABLE_V1_UPDATE_WEIGHTS :
499506 self .model_weights_signal [0 ] = int (self .model_weights_status .value [0 ])
500507 if self .ranks > 1 :
501- self .model_weights_signal [0 ] = self ._broadcast_model_weights_signal (src = 0 , group = None )
508+ self .model_weights_signal [0 ] = self ._broadcast_model_weights_signal (src = 0 , group = self . gloo_group )
502509
503510 req_dicts = None
504511 self .worker_healthy_live_signal .value [tp_rank % self .max_chips_per_node ] = int (time .time ())
@@ -563,7 +570,7 @@ def event_loop_normal(self) -> None:
563570 self .model_weights_signal [0 ] = self .model_weights_status .value [0 ]
564571 if self .ranks > 1 :
565572 self .model_weights_signal [0 ] = self ._broadcast_model_weights_signal (
566- src = 0 , group = None
573+ src = 0 , group = self . gloo_group
567574 )
568575 time .sleep (1 )
569576 self .model_weights_status .value [0 ] = (
0 commit comments