Skip to content

Commit 6045f04

Browse files
authored
[RL] Support cpu tensor broadcast (#7833)
* support cpu tensor broadcast
* fix place
* fix group
* fix init
* fix shutdown process group
1 parent 9d3dc0e commit 6045f04

2 files changed

Lines changed: 19 additions & 5 deletions

File tree

fastdeploy/rl/dynamic_weight_manager.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -348,6 +348,13 @@ def clear_parameters(self, pid: int = 0, shutdown_process_group=False) -> None:
348348
if shutdown_process_group:
349349
paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
350350
if shutdown_process_group:
351+
# ProcessGroupGloo has no shutdown(); remove it from paddle's registry
352+
# before the global sweep to avoid AttributeError.
353+
from paddle.distributed.collective import _get_group_map_by_name
354+
355+
for name, pg in list(_get_group_map_by_name().items()):
356+
if pg.process_group is not None and not hasattr(pg.process_group, "shutdown"):
357+
_get_group_map_by_name().pop(name, None)
351358
paddle.distributed.shutdown_process_group()
352359
self._update_shared_status(pid, ModelWeightsStatus.CLEARED)
353360

fastdeploy/worker/worker_process.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -178,6 +178,10 @@ def __init__(self, fd_config: FDConfig, ranks: int = 1, local_rank: int = 0) ->
178178
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
179179
self.enable_overlap_schedule = self.scheduler_config.enable_overlap_schedule
180180
self.cached_control_reqs = []
181+
if self.ranks > 1:
182+
self.gloo_group = dist.new_group(list(range(self.ranks)), backend="gloo")
183+
else:
184+
self.gloo_group = None
181185

182186
def init_control(self):
183187
engine_worker_queue_port = self.parallel_config.local_engine_worker_queue_port
@@ -312,9 +316,12 @@ def update_weights_from_tensor(self, mmap_infos):
312316
self.experts_manager.tensor_infos = None
313317

314318
def _broadcast_model_weights_signal(self, src: int, group) -> int:
315-
signal_list = [self.model_weights_signal[0]]
316-
paddle.distributed.broadcast_object_list(signal_list, src=src, group=group)
317-
return int(signal_list[0])
319+
model_weights_signal_tensor = paddle.full(
320+
shape=[1], fill_value=self.model_weights_signal[0], dtype="int32", device="cpu"
321+
)
322+
paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
323+
value = model_weights_signal_tensor.numpy()[0]
324+
return int(value)
318325

319326
def _get_exist_task_flag(self) -> bool:
320327
if self.nnode > 1:
@@ -465,7 +472,7 @@ def event_loop_normal(self) -> None:
465472
if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
466473
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
467474
if self.ranks > 1:
468-
self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None)
475+
self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=self.gloo_group)
469476

470477
req_dicts = None
471478
self.worker_healthy_live_signal.value[tp_rank % self.max_chips_per_node] = int(time.time())
@@ -530,7 +537,7 @@ def event_loop_normal(self) -> None:
530537
self.model_weights_signal[0] = self.model_weights_status.value[0]
531538
if self.ranks > 1:
532539
self.model_weights_signal[0] = self._broadcast_model_weights_signal(
533-
src=0, group=None
540+
src=0, group=self.gloo_group
534541
)
535542
time.sleep(1)
536543
self.model_weights_status.value[0] = (

0 commit comments

Comments
 (0)