11# Copyright (c) OpenMMLab. All rights reserved.
22import asyncio
33import contextlib
4- import json
54import os
65from typing import Any
76
@@ -46,42 +45,6 @@ def _get_master_port():
4645 return find_available_port ()
4746
4847
49- def get_ascend_device_rank_mapping (master_addr ):
50- rank_table_file = _envs .ascend_rank_table_file
51- if not rank_table_file :
52- raise ValueError ('ASCEND_RANK_TABLE_FILE_PATH is not set' )
53- with open (rank_table_file ) as f :
54- rank_table = json .load (f )
55- try :
56- assert master_addr == rank_table ['server_list' ][0 ]['server_id' ], 'Master address does not match rank table'
57- rank_mapping : dict [int , int ] = {}
58- worker_ip_by_rank : dict [int , str ] = {}
59- for server in rank_table ['server_list' ]:
60- node_ip = server ['server_id' ]
61- for idx , device in enumerate (server ['device' ]):
62- # Prefer explicit device_id if present; fall back to enumeration order.
63- local_rank = int (device .get ('device_id' , idx ))
64- global_rank = int (device ['rank_id' ])
65- rank_mapping [global_rank ] = local_rank
66- worker_ip_by_rank [global_rank ] = node_ip
67-
68- if len (worker_ip_by_rank ) == 0 :
69- raise ValueError ('Rank table contains no devices.' )
70-
71- ranks = sorted (worker_ip_by_rank .keys ())
72- if ranks [0 ] != 0 or ranks [- 1 ] != len (ranks ) - 1 :
73- raise ValueError (f'Rank ids are not contiguous starting from 0: { ranks [:8 ]} ...{ ranks [- 8 :]} ' )
74- worker_ips = [worker_ip_by_rank [r ] for r in range (len (ranks ))]
75- except Exception as e :
76- logger .error (f'Parse rank table file({ rank_table } ) failed' )
77- raise e
78-
79- envs = {
80- 'ASCEND_RANK_TABLE_FILE_PATH' : rank_table_file ,
81- }
82- return rank_mapping , worker_ips , envs
83-
84-
8548def _update_env_cuda_alloc_conf (env_vars : dict ):
8649 """Update runtime env for CUDA alloc conf."""
8750 cuda_alloc_conf = os .getenv ('PYTORCH_CUDA_ALLOC_CONF' , None )
@@ -672,11 +635,14 @@ def _init_workers_ray(self, placement_group: PlacementGroup, worker_kwargs: dict
672635 runtime_env = runtime_env ,
673636 )(RayWorkerWrapper ).remote (** worker_kwargs )
674637 else :
638+ runtime_env = dict ()
639+ runtime_env = _update_runtime_envs (runtime_env )
675640 worker = ray .remote (
676641 num_cpus = 0 ,
677642 num_gpus = 0 ,
678643 resources = {device_str : 0.01 },
679644 scheduling_strategy = scheduling_strategy ,
645+ runtime_env = runtime_env ,
680646 )(RayWorkerWrapper ).remote (** worker_kwargs )
681647 workers .append (worker )
682648 return workers
@@ -697,33 +663,40 @@ def _init_distributed_environment_by_device(self, device_str: str):
697663
698664 def _init_ascend_distributed_environment (self , driver_ip ):
699665 """Init ascend distributed environment."""
700- rank_table_file = _envs .ascend_rank_table_file
701- set_rt_visable_devices_by_ray = _envs .ascend_set_rt_visable_devices_by_ray
666+ from collections import defaultdict
702667
703- if rank_table_file :
704- # if rank table file is set, use it to get rank mapping, multiple nodes
705- rank_mapping , worker_ips , envs = get_ascend_device_rank_mapping (driver_ip )
706- rank_start = self .rank_offset
707- rank_end = rank_start + len (self .workers )
708- if rank_end > len (worker_ips ):
709- raise ValueError (
710- 'Rank table world_size is smaller than required ranks for current dp_rank. '
711- f'rank_table_world_size={ len (worker_ips )} , required_rank_range=[{ rank_start } , { rank_end } )' )
712-
713- # In dp mode each process only owns a slice of global ranks.
714- expected_worker_ips = worker_ips [rank_start :rank_end ]
715- self .workers = self ._sort_workers_by_ip (expected_worker_ips , self .workers )
716-
717- ray .get (
718- [worker .set_device .remote (rank_mapping [rank_start + idx ]) for idx , worker in enumerate (self .workers )])
719- ray .get ([worker .set_env .remote (envs ) for worker in self .workers ])
720- elif not set_rt_visable_devices_by_ray :
721- # if rank table file is not set, treat as single node
722- # simply set device by index, this is for single node, multiple devices
723- self .workers = self ._sort_workers (driver_ip , self .workers )
724- ray .get ([worker .set_device .remote (idx + self .rank_offset ) for idx , worker in enumerate (self .workers )])
668+ set_rt_visable_devices_by_ray = _envs .ascend_set_rt_visable_devices_by_ray
669+ self .workers = self ._sort_workers (driver_ip , self .workers )
670+
671+ if set_rt_visable_devices_by_ray :
672+ # Ray populated ASCEND_RT_VISIBLE_DEVICES per actor; no set_device.
673+ return
674+
675+ worker_ips = ray .get ([w .get_node_ip .remote () for w in self .workers ])
676+ is_multi_node_pg = len (set (worker_ips )) > 1
677+
678+ if is_multi_node_pg :
679+ # Cross-node TP: each worker uses its index within its own node.
680+ local_indices : list [int ] = []
681+ counts : dict [str , int ] = defaultdict (int )
682+ for ip in worker_ips :
683+ local_indices .append (counts [ip ])
684+ counts [ip ] += 1
685+ ray .get ([w .set_device .remote (local_indices [idx ]) for idx , w in enumerate (self .workers )])
686+ return
687+
688+ # Single-node PG below.
689+ if 'ASCEND_RT_VISIBLE_DEVICES' in os .environ :
690+ ray .get ([w .set_device .remote (idx ) for idx , w in enumerate (self .workers )])
725691 else :
726- self .workers = self ._sort_workers (driver_ip , self .workers )
692+ local_npu_count = torch .npu .device_count ()
693+ if local_npu_count <= 0 :
694+ raise RuntimeError (
695+ 'torch.npu.device_count() returned a non-positive value; '
696+ 'cannot derive local NPU offset. Please set '
697+ 'ASCEND_RT_VISIBLE_DEVICES explicitly.' )
698+ local_offset = self .rank_offset % local_npu_count
699+ ray .get ([w .set_device .remote (idx + local_offset ) for idx , w in enumerate (self .workers )])
727700
728701 """ PD Disaggregation API Begin """
729702
0 commit comments