Skip to content

Commit e1b73cd

Browse files
Refine multi-node support on ascend-A3 (#4711)
* refactor ascend multinode * fixed rebase * cleanup code * cleanup again --------- Co-authored-by: yaofengchen <fengchenyao@foxmail.com>
1 parent 2871443 commit e1b73cd

2 files changed

Lines changed: 41 additions & 63 deletions

File tree

lmdeploy/pytorch/engine/executor/ray_executor.py

Lines changed: 35 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
# Copyright (c) OpenMMLab. All rights reserved.
22
import asyncio
33
import contextlib
4-
import json
54
import os
65
from typing import Any
76

@@ -46,42 +45,6 @@ def _get_master_port():
4645
return find_available_port()
4746

4847

49-
def get_ascend_device_rank_mapping(master_addr):
50-
rank_table_file = _envs.ascend_rank_table_file
51-
if not rank_table_file:
52-
raise ValueError('ASCEND_RANK_TABLE_FILE_PATH is not set')
53-
with open(rank_table_file) as f:
54-
rank_table = json.load(f)
55-
try:
56-
assert master_addr == rank_table['server_list'][0]['server_id'], 'Master address does not match rank table'
57-
rank_mapping: dict[int, int] = {}
58-
worker_ip_by_rank: dict[int, str] = {}
59-
for server in rank_table['server_list']:
60-
node_ip = server['server_id']
61-
for idx, device in enumerate(server['device']):
62-
# Prefer explicit device_id if present; fall back to enumeration order.
63-
local_rank = int(device.get('device_id', idx))
64-
global_rank = int(device['rank_id'])
65-
rank_mapping[global_rank] = local_rank
66-
worker_ip_by_rank[global_rank] = node_ip
67-
68-
if len(worker_ip_by_rank) == 0:
69-
raise ValueError('Rank table contains no devices.')
70-
71-
ranks = sorted(worker_ip_by_rank.keys())
72-
if ranks[0] != 0 or ranks[-1] != len(ranks) - 1:
73-
raise ValueError(f'Rank ids are not contiguous starting from 0: {ranks[:8]}...{ranks[-8:]}')
74-
worker_ips = [worker_ip_by_rank[r] for r in range(len(ranks))]
75-
except Exception as e:
76-
logger.error(f'Parse rank table file({rank_table}) failed')
77-
raise e
78-
79-
envs = {
80-
'ASCEND_RANK_TABLE_FILE_PATH': rank_table_file,
81-
}
82-
return rank_mapping, worker_ips, envs
83-
84-
8548
def _update_env_cuda_alloc_conf(env_vars: dict):
8649
"""Update runtime env for CUDA alloc conf."""
8750
cuda_alloc_conf = os.getenv('PYTORCH_CUDA_ALLOC_CONF', None)
@@ -672,11 +635,14 @@ def _init_workers_ray(self, placement_group: PlacementGroup, worker_kwargs: dict
672635
runtime_env=runtime_env,
673636
)(RayWorkerWrapper).remote(**worker_kwargs)
674637
else:
638+
runtime_env = dict()
639+
runtime_env = _update_runtime_envs(runtime_env)
675640
worker = ray.remote(
676641
num_cpus=0,
677642
num_gpus=0,
678643
resources={device_str: 0.01},
679644
scheduling_strategy=scheduling_strategy,
645+
runtime_env=runtime_env,
680646
)(RayWorkerWrapper).remote(**worker_kwargs)
681647
workers.append(worker)
682648
return workers
@@ -697,33 +663,40 @@ def _init_distributed_environment_by_device(self, device_str: str):
697663

698664
def _init_ascend_distributed_environment(self, driver_ip):
699665
"""Init ascend distributed environment."""
700-
rank_table_file = _envs.ascend_rank_table_file
701-
set_rt_visable_devices_by_ray = _envs.ascend_set_rt_visable_devices_by_ray
666+
from collections import defaultdict
702667

703-
if rank_table_file:
704-
# if rank table file is set, use it to get rank mapping, multiple nodes
705-
rank_mapping, worker_ips, envs = get_ascend_device_rank_mapping(driver_ip)
706-
rank_start = self.rank_offset
707-
rank_end = rank_start + len(self.workers)
708-
if rank_end > len(worker_ips):
709-
raise ValueError(
710-
'Rank table world_size is smaller than required ranks for current dp_rank. '
711-
f'rank_table_world_size={len(worker_ips)}, required_rank_range=[{rank_start}, {rank_end})')
712-
713-
# In dp mode each process only owns a slice of global ranks.
714-
expected_worker_ips = worker_ips[rank_start:rank_end]
715-
self.workers = self._sort_workers_by_ip(expected_worker_ips, self.workers)
716-
717-
ray.get(
718-
[worker.set_device.remote(rank_mapping[rank_start + idx]) for idx, worker in enumerate(self.workers)])
719-
ray.get([worker.set_env.remote(envs) for worker in self.workers])
720-
elif not set_rt_visable_devices_by_ray:
721-
# if rank table file is not set, treat as single node
722-
# simply set device by index, this is for single node, multiple devices
723-
self.workers = self._sort_workers(driver_ip, self.workers)
724-
ray.get([worker.set_device.remote(idx + self.rank_offset) for idx, worker in enumerate(self.workers)])
668+
set_rt_visable_devices_by_ray = _envs.ascend_set_rt_visable_devices_by_ray
669+
self.workers = self._sort_workers(driver_ip, self.workers)
670+
671+
if set_rt_visable_devices_by_ray:
672+
# Ray populated ASCEND_RT_VISIBLE_DEVICES per actor; no set_device.
673+
return
674+
675+
worker_ips = ray.get([w.get_node_ip.remote() for w in self.workers])
676+
is_multi_node_pg = len(set(worker_ips)) > 1
677+
678+
if is_multi_node_pg:
679+
# Cross-node TP: each worker uses its index within its own node.
680+
local_indices: list[int] = []
681+
counts: dict[str, int] = defaultdict(int)
682+
for ip in worker_ips:
683+
local_indices.append(counts[ip])
684+
counts[ip] += 1
685+
ray.get([w.set_device.remote(local_indices[idx]) for idx, w in enumerate(self.workers)])
686+
return
687+
688+
# Single-node PG below.
689+
if 'ASCEND_RT_VISIBLE_DEVICES' in os.environ:
690+
ray.get([w.set_device.remote(idx) for idx, w in enumerate(self.workers)])
725691
else:
726-
self.workers = self._sort_workers(driver_ip, self.workers)
692+
local_npu_count = torch.npu.device_count()
693+
if local_npu_count <= 0:
694+
raise RuntimeError(
695+
'torch.npu.device_count() returned a non-positive value; '
696+
'cannot derive local NPU offset. Please set '
697+
'ASCEND_RT_VISIBLE_DEVICES explicitly.')
698+
local_offset = self.rank_offset % local_npu_count
699+
ray.get([w.set_device.remote(idx + local_offset) for idx, w in enumerate(self.workers)])
727700

728701
""" PD Disaggregation API Begin """
729702

lmdeploy/pytorch/envs.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,6 @@ def _patched_get_env(
102102

103103
# ascend
104104
ascend_set_rt_visable_devices_by_ray = env_to_bool('ASCEND_SET_RT_VISIBLE_DEVICES_BY_RAY', False)
105-
ascend_rank_table_file = os.getenv('ASCEND_RANK_TABLE_FILE_PATH')
106105

107106
# dp
108107
dp_master_addr = os.getenv('LMDEPLOY_DP_MASTER_ADDR', None)
@@ -140,6 +139,12 @@ def _patched_get_env(
140139
# check env
141140
enable_check_env = env_to_bool('LMDEPLOY_ENABLE_CHECK_ENV', True)
142141

142+
# hccl / ascend - passed to ray workers
143+
os.getenv('HCCL_BUFFSIZE', None)
144+
os.getenv('HCCL_CONNECT_TIMEOUT', None)
145+
os.getenv('HCCL_OP_EXPANSION_MODE', None)
146+
os.getenv('HCCL_IF_IP', None)
147+
143148
# dlblas
144149
# we don't need to read this, it would be passed to ray workers
145150
# If Ray is launched from outside, it may fail to access the environment variables.

0 commit comments

Comments
 (0)