Skip to content

Commit 9e8bf5c

Browse files
authored
refactor: move rdma device detection to device_utils.py (#87)
1 parent b8516e3 commit 9e8bf5c

3 files changed

Lines changed: 153 additions & 138 deletions

File tree

checkpoint_engine/device_utils.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ctypes
12
import os
23
import re
34
import socket
@@ -44,6 +45,133 @@ def npu_generate_uuid() -> str:
4445
raise ValueError("The current process is not running on the npu device") from e
4546

4647

48+
def _ibv_get_device_list() -> list[str]:
49+
lib = ctypes.CDLL("libibverbs.so.1")
50+
lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices
51+
lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p) # struct ibv_device **
52+
53+
lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
54+
lib.ibv_get_device_name.argtypes = [ctypes.c_void_p] # struct ibv_device *
55+
lib.ibv_get_device_name.restype = ctypes.c_char_p # const char *
56+
57+
num = ctypes.c_int()
58+
dev_array = lib.ibv_get_device_list(ctypes.byref(num))
59+
if not dev_array or num.value <= 0:
60+
return []
61+
62+
devices = []
63+
for i in range(num.value):
64+
dev_ptr = dev_array[i] # struct ibv_device *
65+
name = lib.ibv_get_device_name(dev_ptr) # const char *
66+
devices.append(name.decode())
67+
lib.ibv_free_device_list(dev_array)
68+
return devices
69+
70+
71+
def _get_rdma_devices() -> list[str]:
72+
"""
73+
use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
74+
"""
75+
devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
76+
if devices_str:
77+
return devices_str.split(",")
78+
# if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
79+
hca = os.getenv("NCCL_IB_HCA", None)
80+
return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()
81+
82+
83+
def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
84+
"""
85+
implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc.
86+
"""
87+
if not devices:
88+
raise RuntimeError("no rdma devices found")
89+
try:
90+
assert len(devices) <= gpu_count, (
91+
f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
92+
)
93+
assert gpu_count % len(devices) == 0, (
94+
f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
95+
)
96+
return devices[local_rank // (gpu_count // len(devices))]
97+
except AssertionError:
98+
logger.error(
99+
"Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."
100+
"The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices."
101+
"The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
102+
)
103+
raise
104+
105+
106+
def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
107+
"""
108+
The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
109+
The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.
110+
111+
The list is comma-separated; port numbers are NOT supported yet.
112+
An optional prefix '^' indicates the list is an exclude list.
113+
A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix.
114+
Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported.
115+
116+
Examples:
117+
- `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
118+
- `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
119+
- `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
120+
- `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
121+
"""
122+
max_hcas = 32
123+
if not value or value.strip() == "":
124+
return available_devices[:max_hcas]
125+
126+
value = value.strip()
127+
result = []
128+
is_exclude = value.startswith("^")
129+
if is_exclude:
130+
value = value.removeprefix("^")
131+
is_exact_match = value.startswith("=")
132+
if is_exact_match:
133+
value = value.removeprefix("=")
134+
135+
device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
136+
137+
result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
138+
if is_exclude:
139+
result = [dev for dev in available_devices if dev not in result]
140+
if len(result) > max_hcas:
141+
result = result[:max_hcas]
142+
143+
logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")
144+
145+
return result
146+
147+
148+
def _resolve_device_specs(
149+
device_specs: list[str], is_exact_match: bool, available_devices: list[str]
150+
) -> list[str]:
151+
devices = set()
152+
for spec in device_specs:
153+
parts = spec.split(":", 1)
154+
device_name = parts[0].strip()
155+
# HACK: mooncake transfer engine does not support port specification yet, so we ignore it
156+
# port = parts[1].strip() if len(parts) > 1 else None
157+
base_devices = (
158+
[device_name]
159+
if device_name in available_devices
160+
else []
161+
if is_exact_match
162+
else [dev for dev in available_devices if dev.startswith(device_name)]
163+
)
164+
165+
if not base_devices:
166+
logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
167+
continue
168+
169+
for base_dev in base_devices:
170+
devices.add(base_dev)
171+
172+
return sorted(devices)
173+
174+
47175
class DeviceManager:
48176
def __init__(self):
49177
self.device_type = self._detect_device_type()
@@ -84,3 +212,20 @@ def backend(self) -> str:
84212
return "nccl"
85213
else:
86214
raise TypeError("The current device type is not supported")
215+
216+
@property
217+
def transfer_engine_protocol(self) -> str:
218+
if self.device_type == "npu":
219+
return "ascend_direct"
220+
elif self.device_type == "cuda":
221+
return "rdma"
222+
else:
223+
raise TypeError("The current device type is not supported")
224+
225+
def rdma_device(self, rank: int) -> str:
226+
if self.transfer_engine_protocol == "ascend_direct":
227+
return ""
228+
elif self.transfer_engine_protocol == "rdma":
229+
return _get_my_rdma_device(rank, self.device_module.device_count(), _get_rdma_devices())
230+
else:
231+
raise TypeError("The current transfer engine protocol is not supported")

checkpoint_engine/p2p_store.py

Lines changed: 2 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import ctypes
21
import os
32
import random
43
import time
@@ -9,145 +8,14 @@
98
from checkpoint_engine.device_utils import DeviceManager, get_ip
109

1110

12-
def _ibv_get_device_list() -> list[str]:
13-
lib = ctypes.CDLL("libibverbs.so.1")
14-
lib.ibv_get_device_list.argtypes = [ctypes.POINTER(ctypes.c_int)] # int *num_devices
15-
lib.ibv_get_device_list.restype = ctypes.POINTER(ctypes.c_void_p) # struct ibv_device **
16-
17-
lib.ibv_free_device_list.argtypes = [ctypes.POINTER(ctypes.c_void_p)]
18-
lib.ibv_get_device_name.argtypes = [ctypes.c_void_p] # struct ibv_device *
19-
lib.ibv_get_device_name.restype = ctypes.c_char_p # const char *
20-
21-
num = ctypes.c_int()
22-
dev_array = lib.ibv_get_device_list(ctypes.byref(num))
23-
if not dev_array or num.value <= 0:
24-
return []
25-
26-
devices = []
27-
for i in range(num.value):
28-
dev_ptr = dev_array[i] # struct ibv_device *
29-
name = lib.ibv_get_device_name(dev_ptr) # const char *
30-
devices.append(name.decode())
31-
lib.ibv_free_device_list(dev_array)
32-
return devices
33-
34-
35-
def _get_rdma_devices() -> list[str]:
36-
"""
37-
use _ibv_get_device_list to get RDMA devices, if NCCL_IB_HCA has multiple values, just return
38-
"""
39-
devices_str = os.getenv("PS_P2P_STORE_RDMA_DEVICES")
40-
if devices_str:
41-
return devices_str.split(",")
42-
# if PS_P2P_STORE_RDMA_DEVICES is not set, try to use NCCL_IB_HCA to get RDMA devices
43-
hca = os.getenv("NCCL_IB_HCA", None)
44-
return _parse_NCCL_IB_HCA(hca or "", _ibv_get_device_list()) or _ibv_get_device_list()
45-
46-
47-
def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) -> str:
48-
"""
49-
implement network card device allocation, if network card is "mlx5_0,mlx5_1", then 0-3 will share mlx5_0, 4-7 will share mlx5_1, etc.
50-
"""
51-
if not devices:
52-
raise RuntimeError("no rdma devices found")
53-
try:
54-
assert len(devices) <= gpu_count, (
55-
f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
56-
)
57-
assert gpu_count % len(devices) == 0, (
58-
f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
59-
)
60-
return devices[local_rank // (gpu_count // len(devices))]
61-
except AssertionError:
62-
logger.error(
63-
"Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."
64-
"The number of RDMA devices should be less than or equal to GPU count, and GPU count should be divisible by the number of RDMA devices."
65-
"The acceptable value by NCCL_IB_HCA is documented in 'https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8'."
66-
)
67-
raise
68-
69-
70-
def _parse_NCCL_IB_HCA(value: str, available_devices: list[str]) -> list[str]:
71-
"""
72-
The acceptable value by NCCL_IB_HCA is documented in https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#id8.
73-
The Python version parser is referred to the CPP parser in NCCL: https://github.com/NVIDIA/nccl/blob/v2.28.3-1/src/transport/net_ib.cc#L658-L662.
74-
75-
The list is comma-separated; port numbers are NOT supported yet.
76-
An optional prefix '^' indicates the list is an exclude list.
77-
A second optional prefix '=' indicates that the tokens are exact names, otherwise by default NCCL would treat each token as a prefix.
78-
Please note that when '^' and '=' appear together, only '^=' is allowed, '=^' is not supported.
79-
80-
Examples:
81-
- `NCCL_IB_HCA="mlx5"`: Use all cards starting with `mlx5`.
82-
- `NCCL_IB_HCA="=mlx5_0,mlx5_1"`: Use specific cards `mlx5_0` and `mlx5_1`.
83-
- `NCCL_IB_HCA="^mlx5"`: Use all cards except those starting with `mlx5`.
84-
- `NCCL_IB_HCA="^=mlx5_0,mlx5_1"`: Use all cards except `mlx5_0` and `mlx5_1`.
85-
"""
86-
max_hcas = 32
87-
if not value or value.strip() == "":
88-
return available_devices[:max_hcas]
89-
90-
value = value.strip()
91-
result = []
92-
is_exclude = value.startswith("^")
93-
if is_exclude:
94-
value = value.removeprefix("^")
95-
is_exact_match = value.startswith("=")
96-
if is_exact_match:
97-
value = value.removeprefix("=")
98-
99-
device_specs = [spec.strip() for spec in value.split(",") if spec.strip()]
100-
101-
result = _resolve_device_specs(device_specs, is_exact_match, available_devices)
102-
if is_exclude:
103-
result = [dev for dev in available_devices if dev not in result]
104-
if len(result) > max_hcas:
105-
result = result[:max_hcas]
106-
107-
logger.info(f"RDMA Devices from 'NCCL_IB_HCA': {result}")
108-
109-
return result
110-
111-
112-
def _resolve_device_specs(
113-
device_specs: list[str], is_exact_match: bool, available_devices: list[str]
114-
) -> list[str]:
115-
devices = set()
116-
for spec in device_specs:
117-
parts = spec.split(":", 1)
118-
device_name = parts[0].strip()
119-
# HACK: mooncake transfer engine does not support port specification yet, so we ignore it
120-
# port = parts[1].strip() if len(parts) > 1 else None
121-
base_devices = (
122-
[device_name]
123-
if device_name in available_devices
124-
else []
125-
if is_exact_match
126-
else [dev for dev in available_devices if dev.startswith(device_name)]
127-
)
128-
129-
if not base_devices:
130-
logger.warning(f"No RDMA device match {device_name=} where {is_exact_match=}.")
131-
continue
132-
133-
for base_dev in base_devices:
134-
devices.add(base_dev)
135-
136-
return sorted(devices)
137-
138-
13911
class P2PStore:
14012
def __init__(self, device_manager: DeviceManager):
14113
from mooncake.engine import TransferEngine
14214

14315
self.rank = int(os.environ["RANK"]) # ENV RANK is required
14416
gpu_count = device_manager.device_module.device_count()
14517
local_rank = self.rank % gpu_count
146-
device_type = device_manager.device_type
147-
if device_type == "npu" and os.getenv("PS_P2P_STORE_RDMA_DEVICES") is None:
148-
self.device = ""
149-
else:
150-
self.device = _get_my_rdma_device(local_rank, gpu_count, _get_rdma_devices())
18+
self.device = device_manager.rdma_device(local_rank)
15119
self.ip = get_ip()
15220

15321
# we will start at most 8 ps processes, so we use 8 retries to avoid port conflicts in extreme cases
@@ -157,7 +25,7 @@ def __init__(self, device_manager: DeviceManager):
15725
ret = self.engine.initialize(
15826
self.ip,
15927
"P2PHANDSHAKE",
160-
"ascend_direct" if device_type == "npu" else "rdma",
28+
device_manager.transfer_engine_protocol,
16129
self.device,
16230
)
16331
if ret == 0:

tests/test_rdma_parser.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
import pytest
55

6-
from checkpoint_engine.p2p_store import (
6+
from checkpoint_engine.device_utils import (
77
_get_my_rdma_device,
88
_get_rdma_devices,
99
_ibv_get_device_list,
@@ -43,7 +43,8 @@ def test_get_rdma_devices_no_env_vars(mock_available_devices: list[str]):
4343
with (
4444
patch.dict(os.environ, clear=True),
4545
patch(
46-
"checkpoint_engine.p2p_store._ibv_get_device_list", return_value=mock_available_devices
46+
"checkpoint_engine.device_utils._ibv_get_device_list",
47+
return_value=mock_available_devices,
4748
),
4849
):
4950
devices = _get_rdma_devices()
@@ -123,7 +124,7 @@ def test_parse_exact_match_with_nonexistent_device(
123124
mock_available_devices: list[str],
124125
):
125126
"""Test exact matching with non-existent device"""
126-
with patch("checkpoint_engine.p2p_store.logger") as mock_logger:
127+
with patch("checkpoint_engine.device_utils.logger") as mock_logger:
127128
result = _parse_NCCL_IB_HCA(input_value, mock_available_devices)
128129
assert result == expected_result
129130
mock_logger.warning.assert_called_once_with(expected_warning)
@@ -151,7 +152,8 @@ def test_get_rdma_devices_with_env_vars(
151152
with (
152153
patch.dict(os.environ, env_dict),
153154
patch(
154-
"checkpoint_engine.p2p_store._ibv_get_device_list", return_value=mock_available_devices
155+
"checkpoint_engine.device_utils._ibv_get_device_list",
156+
return_value=mock_available_devices,
155157
),
156158
):
157159
devices = _get_rdma_devices()

0 commit comments

Comments
 (0)