1- import ctypes
21import os
32import random
43import time
98from checkpoint_engine .device_utils import DeviceManager , get_ip
109
1110
def _ibv_get_device_list() -> list[str]:
    """Return the names of the RDMA devices reported by libibverbs.

    Loads ``libibverbs.so.1`` through ctypes, calls ``ibv_get_device_list`` and
    reads every entry's name with ``ibv_get_device_name``.

    Returns:
        Device names in the order libibverbs reports them; an empty list when
        no device list or no devices are reported.

    Raises:
        OSError: if ``libibverbs.so.1`` cannot be loaded.
    """
    lib = ctypes.CDLL("libibverbs.so.1")
    lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)]  # int *num_devices
    lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p)  # struct ibv_device **

    lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
    lib.ibv_free_device_list.restype = None  # void
    lib.ibv_get_device_name.argtypes = [ctypes.c_void_p]  # struct ibv_device *
    lib.ibv_get_device_name.restype = ctypes.c_char_p  # const char *

    num = ctypes.c_int()
    dev_array = lib.ibv_get_device_list(ctypes.byref(num))
    if not dev_array:
        return []

    # A non-NULL array must always be released, even when it holds zero devices
    # or when reading a name fails, hence the try/finally.
    try:
        devices = []
        for i in range(num.value):
            dev_ptr = dev_array[i]  # struct ibv_device *
            if not dev_ptr:
                continue
            name = lib.ibv_get_device_name(dev_ptr)  # const char *, None when NULL
            if name is not None:
                devices.append(name.decode())
        return devices
    finally:
        lib.ibv_free_device_list(dev_array)
33-
34-
def _get_rdma_devices() -> list[str]:
    """Return the candidate RDMA device names for this host.

    Resolution order:
      1. ``PS_P2P_STORE_RDMA_DEVICES``: a comma-separated list of device names.
         Surrounding whitespace is stripped and empty entries are dropped.
      2. ``NCCL_IB_HCA``, parsed by ``_parse_NCCL_IB_HCA`` against the devices
         reported by ``_ibv_get_device_list``.
      3. Every device reported by ``_ibv_get_device_list`` when step 2 selects
         nothing.
    """
    devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
    if devices_str:
        devices = [name.strip() for name in devices_str.split(",") if name.strip()]
        if devices:
            return devices
    # PS_P2P_STORE_RDMA_DEVICES is unset (or holds no usable name): fall back to
    # NCCL_IB_HCA. The device list is enumerated once and reused for the fallback.
    available = _ibv_get_device_list()
    hca = os.getenv("NCCL_IB_HCA")
    return _parse_NCCL_IB_HCA(hca or "", available) or available
45-
46-
def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
    """Pick the RDMA device that the GPU at ``local_rank`` should use.

    GPUs are split into ``len(devices)`` contiguous, equally sized groups and
    each group shares one device. For example, with devices
    ``["mlx5_0", "mlx5_1"]`` and 8 GPUs, ranks 0-3 share ``mlx5_0`` and ranks
    4-7 share ``mlx5_1``.

    Args:
        local_rank: index of the GPU on this host.
        gpu_count: number of GPUs on this host.
        devices: candidate RDMA device names.

    Returns:
        The device name assigned to ``local_rank``.

    Raises:
        RuntimeError: if ``devices`` is empty.
        AssertionError: if there are more devices than GPUs, or the GPU count
            is not divisible by the device count.
    """
    if not devices:
        raise RuntimeError("no rdma devices found")

    device_count = len(devices)
    # Validate with explicit raises rather than `assert` statements: asserts are
    # stripped under `python -O`, which would turn a bad configuration into an
    # unlogged ZeroDivisionError below. AssertionError is kept as the exception
    # type so existing callers behave the same.
    if device_count > gpu_count:
        problem = f"rdma devices count { len (devices )} should be less than or equal to gpu count { gpu_count } "
    elif gpu_count % device_count != 0:
        problem = f"gpu count { gpu_count } should be divisible by rdma devices count { len (devices )} "
    else:
        return devices[local_rank // (gpu_count // device_count)]

    logger.error(
        "Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices. "
        "The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices. "
        "The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
    )
    raise AssertionError(problem)
68-
69-
def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
    """Select RDMA devices from ``available_devices`` according to an NCCL_IB_HCA value.

    The accepted syntax is documented at
    https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
    This parser is modelled on the C++ parser in NCCL:
    https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.

    Syntax:
      - The value is a comma-separated list of tokens; port numbers are NOT supported yet.
      - An optional leading '^' turns the list into an exclude list.
      - An optional '=' (after '^', if present) makes the tokens exact names;
        otherwise each token is treated as a prefix.
      - When both are used only '^=' is recognised, '=^' is not supported.

    Examples:
      - `NCCL_IB_HCA="mlx5"`: use all cards starting with `mlx5`.
      - `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: use specific cards `mlx5_0` and `mlx5_1`.
      - `NCCL_IB_HCA="^mlx5"`: use all cards except those starting with `mlx5`.
      - `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: use all cards except `mlx5_0` and `mlx5_1`.

    Returns:
        At most 32 device names. An empty or blank ``value`` selects the first
        32 entries of ``available_devices``.
    """
    hca_limit = 32
    if not value or value.strip() == "":
        return available_devices[:hca_limit]

    spec_text = value.strip()

    # Peel the optional prefixes off in the only supported order: '^' first, then '='.
    exclude_mode = spec_text.startswith("^")
    if exclude_mode:
        spec_text = spec_text.removeprefix("^")
    exact_mode = spec_text.startswith("=")
    if exact_mode:
        spec_text = spec_text.removeprefix("=")

    tokens = []
    for raw_token in spec_text.split(","):
        token = raw_token.strip()
        if token:
            tokens.append(token)

    matched = _resolve_device_specs(tokens, exact_mode, available_devices)
    if exclude_mode:
        # Exclude list: keep everything that was NOT matched, in availability order.
        selected = [dev for dev in available_devices if dev not in matched]
    else:
        selected = matched
    selected = selected[:hca_limit]

    logger.info(f"RDMA Devices from 'NCCL_IB_HCA': { selected } ")

    return selected
110-
111-
def _resolve_device_specs(
    device_specs: list[str], is_exact_match: bool, available_devices: list[str]
) -> list[str]:
    """Map NCCL_IB_HCA tokens onto names from ``available_devices``.

    For each token, anything after the first ':' is discarded. A token that
    equals an available device name selects exactly that device. Otherwise it
    selects nothing in exact-match mode, or every available device whose name
    starts with the token in prefix mode. Tokens that select nothing are
    reported with a warning and skipped.

    Returns:
        The selected device names, de-duplicated and sorted.
    """
    matched = set()
    for spec in device_specs:
        # HACK: mooncake transfer engine does not support port specification yet,
        # so the ":port" suffix is dropped.
        device_name = spec.partition(":")[0].strip()

        if device_name in available_devices:
            candidates = [device_name]
        elif is_exact_match:
            candidates = []
        else:
            candidates = [dev for dev in available_devices if dev.startswith(device_name)]

        if candidates:
            matched.update(candidates)
        else:
            logger.warning(f"No RDMA device match { device_name = } where { is_exact_match = } .")

    return sorted(matched)
137-
138-
13911class P2PStore :
14012 def __init__ (self , device_manager : DeviceManager ):
14113 from mooncake .engine import TransferEngine
14214
14315 self .rank = int (os .environ ["RANK" ]) # ENV RANK is required
14416 gpu_count = device_manager .device_module .device_count ()
14517 local_rank = self .rank % gpu_count
146- device_type = device_manager .device_type
147- if device_type == "npu" and os .getenv ("PS_P2P_STORE_RDMA_DEVICES" ) is None :
148- self .device = ""
149- else :
150- self .device = _get_my_rdma_device (local_rank , gpu_count , _get_rdma_devices ())
18+ self .device = device_manager .rdma_device (local_rank )
15119 self .ip = get_ip ()
15220
15321 # we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
@@ -157,7 +25,7 @@ def __init__(self, device_manager: DeviceManager):
15725 ret = self .engine .initialize (
15826 self .ip ,
15927 "P2PHANDSHAKE" ,
160- "ascend_direct" if device_type == "npu" else "rdma" ,
28+ device_manager . transfer_engine_protocol ,
16129 self .device ,
16230 )
16331 if ret == 0 :
0 commit comments