Skip to content

Commit 9894b32

Browse files
authored
[Cherry-Pick][RL] Support cpu tensor broadcast(#7833) (#7840)
* support cpu tensor broadcast * fix place * fix group * fix init * fix shutdown process group
1 parent 514ed5c commit 9894b32

2 files changed

Lines changed: 19 additions & 5 deletions

File tree

fastdeploy/rl/dynamic_weight_manager.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,13 @@ def clear_parameters(self, pid: int = 0, shutdown_process_group=False) -> None:
348348
if shutdown_process_group:
349349
paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
350350
if shutdown_process_group:
351+
# ProcessGroupGloo has no shutdown(); remove it from paddle's registry
352+
# before the global sweep to avoid AttributeError.
353+
from paddle.distributed.collective import _get_group_map_by_name
354+
355+
for name, pg in list(_get_group_map_by_name().items()):
356+
if pg.process_group is not None and not hasattr(pg.process_group, "shutdown"):
357+
_get_group_map_by_name().pop(name, None)
351358
paddle.distributed.shutdown_process_group()
352359
self._update_shared_status(pid, ModelWeightsStatus.CLEARED)
353360

fastdeploy/worker/worker_process.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,10 @@ def __init__(self, fd_config: FDConfig, ranks: int = 1, local_rank: int = 0) ->
174174
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
175175
self.enable_overlap_schedule = self.scheduler_config.enable_overlap_schedule
176176
self.cached_control_reqs = []
177+
if self.ranks > 1:
178+
self.gloo_group = dist.new_group(list(range(self.ranks)), backend="gloo")
179+
else:
180+
self.gloo_group = None
177181

178182
def init_control(self):
179183
engine_worker_queue_port = self.parallel_config.local_engine_worker_queue_port
@@ -316,9 +320,12 @@ def update_weights_from_tensor(self, mmap_infos):
316320
self.experts_manager.tensor_infos = None
317321

318322
def _broadcast_model_weights_signal(self, src: int, group) -> int:
319-
signal_list = [self.model_weights_signal[0]]
320-
paddle.distributed.broadcast_object_list(signal_list, src=src, group=group)
321-
return int(signal_list[0])
323+
model_weights_signal_tensor = paddle.full(
324+
shape=[1], fill_value=self.model_weights_signal[0], dtype="int32", device="cpu"
325+
)
326+
paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
327+
value = model_weights_signal_tensor.numpy()[0]
328+
return int(value)
322329

323330
def _get_exist_task_flag(self) -> bool:
324331
if self.nnode > 1:
@@ -498,7 +505,7 @@ def event_loop_normal(self) -> None:
498505
if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
499506
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
500507
if self.ranks > 1:
501-
self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None)
508+
self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=self.gloo_group)
502509

503510
req_dicts = None
504511
self.worker_healthy_live_signal.value[tp_rank % self.max_chips_per_node] = int(time.time())
@@ -563,7 +570,7 @@ def event_loop_normal(self) -> None:
563570
self.model_weights_signal[0] = self.model_weights_status.value[0]
564571
if self.ranks > 1:
565572
self.model_weights_signal[0] = self._broadcast_model_weights_signal(
566-
src=0, group=None
573+
src=0, group=self.gloo_group
567574
)
568575
time.sleep(1)
569576
self.model_weights_status.value[0] = (

0 commit comments

Comments
 (0)