Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions fastdeploy/rl/dynamic_weight_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,13 @@ def clear_parameters(self, pid: int = 0, shutdown_process_group=False) -> None:
if shutdown_process_group:
paddle.distributed.shutdown_process_group(self.parallel_config.ep_group)
if shutdown_process_group:
# ProcessGroupGloo has no shutdown(); remove it from paddle's registry
# before the global sweep to avoid AttributeError.
from paddle.distributed.collective import _get_group_map_by_name

for name, pg in list(_get_group_map_by_name().items()):

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 使用了 Paddle 内部私有 API _get_group_map_by_name,存在版本兼容风险

  1. `_get_group_map_by_name` 以 `_` 开头,是 Paddle 的内部实现细节,不在公开 API 保证范围内;Paddle 版本升级时可能被改名、移除或改变返回格式,导致 ImportError 或逻辑静默失效。
  2. 下一行直接访问 `pg.process_group`,未先检查 `hasattr(pg, 'process_group')`;若 Paddle 内部 Group 对象结构发生变化,会抛出 AttributeError(注释中声称 "to avoid AttributeError",这段代码自身却同样存在 AttributeError 风险)。

建议加防御性保护:

try:
    from paddle.distributed.collective import _get_group_map_by_name
    for name, pg in list(_get_group_map_by_name().items()):
        proc_group = getattr(pg, 'process_group', None)
        if proc_group is not None and not hasattr(proc_group, "shutdown"):
            _get_group_map_by_name().pop(name, None)
except (ImportError, AttributeError):
    pass  # paddle version without gloo registry; safe to skip

if pg.process_group is not None and not hasattr(pg.process_group, "shutdown"):
_get_group_map_by_name().pop(name, None)
paddle.distributed.shutdown_process_group()
self._update_shared_status(pid, ModelWeightsStatus.CLEARED)

Expand Down
17 changes: 12 additions & 5 deletions fastdeploy/worker/worker_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ def __init__(self, fd_config: FDConfig, ranks: int = 1, local_rank: int = 0) ->
self.max_chips_per_node = 16 if current_platform.is_iluvatar() else 8
self.enable_overlap_schedule = self.scheduler_config.enable_overlap_schedule
self.cached_control_reqs = []
if self.ranks > 1:
self.gloo_group = dist.new_group(list(range(self.ranks)), backend="gloo")
else:
self.gloo_group = None

def init_control(self):
engine_worker_queue_port = self.parallel_config.local_engine_worker_queue_port
Expand Down Expand Up @@ -316,9 +320,12 @@ def update_weights_from_tensor(self, mmap_infos):
self.experts_manager.tensor_infos = None

def _broadcast_model_weights_signal(self, src: int, group) -> int:
signal_list = [self.model_weights_signal[0]]
paddle.distributed.broadcast_object_list(signal_list, src=src, group=group)
return int(signal_list[0])
model_weights_signal_tensor = paddle.full(
shape=[1], fill_value=self.model_weights_signal[0], dtype="int32", device="cpu"
)
paddle.distributed.broadcast(model_weights_signal_tensor, src=src, group=group)
value = model_weights_signal_tensor.numpy()[0]
return int(value)

def _get_exist_task_flag(self) -> bool:
if self.nnode > 1:
Expand Down Expand Up @@ -498,7 +505,7 @@ def event_loop_normal(self) -> None:
if self.fd_config.load_config.dynamic_load_weight and not envs.FD_ENABLE_V1_UPDATE_WEIGHTS:
self.model_weights_signal[0] = int(self.model_weights_status.value[0])
if self.ranks > 1:
self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=None)
self.model_weights_signal[0] = self._broadcast_model_weights_signal(src=0, group=self.gloo_group)

req_dicts = None
self.worker_healthy_live_signal.value[tp_rank % self.max_chips_per_node] = int(time.time())
Expand Down Expand Up @@ -563,7 +570,7 @@ def event_loop_normal(self) -> None:
self.model_weights_signal[0] = self.model_weights_status.value[0]
if self.ranks > 1:
self.model_weights_signal[0] = self._broadcast_model_weights_signal(
src=0, group=None
src=0, group=self.gloo_group
)
time.sleep(1)
self.model_weights_status.value[0] = (
Expand Down
Loading