|
4 | 4 | import socket |
5 | 5 | import subprocess |
6 | 6 | from functools import lru_cache |
| 7 | +from pathlib import Path |
7 | 8 |
|
8 | 9 | import torch |
9 | 10 | from loguru import logger |
@@ -178,6 +179,22 @@ def _resolve_device_specs( |
178 | 179 | return sorted(devices) |
179 | 180 |
|
180 | 181 |
|
def has_efa_pci(sysfs_root: "str | Path" = "/sys/class/infiniband/") -> bool:
    """Report whether an Amazon EFA adapter is registered as an RDMA device.

    Each entry under ``sysfs_root`` is treated as one RDMA device, and its
    ``device/vendor`` file is compared against Amazon's PCI vendor ID.
    Note that the match is on the *vendor* ID, not on a specific device ID.

    Args:
        sysfs_root: Directory whose children are RDMA device entries.
            Defaults to the kernel's ``/sys/class/infiniband/`` tree; the
            parameter exists so the probe can be pointed at another tree.

    Returns:
        True if at least one entry reports vendor ``0x1d0f``; False when
        there is no match or when the directory is missing, is not a
        directory, or cannot be listed.
    """
    amazon_vendor_id = "0x1d0f"

    try:
        # Listing is done inside the try so that a missing, unreadable or
        # non-directory path yields False instead of propagating OSError.
        entries = list(Path(sysfs_root).iterdir())
    except OSError:
        return False

    for entry in entries:
        try:
            vendor = (entry / "device" / "vendor").read_text().strip()
        except (OSError, ValueError):  # noqa: PERF203
            # Entry without a readable/decodable vendor file: skip it.
            continue
        if vendor.lower() == amazon_vendor_id:
            return True
    return False
| 196 | + |
| 197 | + |
181 | 198 | class DeviceManager: |
182 | 199 | def __init__(self): |
183 | 200 | self.device_type = self._detect_device_type() |
@@ -224,14 +241,17 @@ def transfer_engine_protocol(self) -> str: |
224 | 241 | if self.device_type == "npu": |
225 | 242 | return "ascend_direct" |
226 | 243 | elif self.device_type == "cuda": |
227 | | - return "rdma" |
| 244 | + if has_efa_pci(): |
| 245 | + return "efa" |
| 246 | + else: |
| 247 | + return "rdma" |
228 | 248 | else: |
229 | 249 | raise TypeError("The current device type is not supported") |
230 | 250 |
|
231 | 251 | def rdma_device(self, rank: int) -> str: |
232 | 252 | if self.transfer_engine_protocol == "ascend_direct": |
233 | 253 | return "" |
234 | | - elif self.transfer_engine_protocol == "rdma": |
| 254 | + elif self.transfer_engine_protocol in ["rdma", "efa"]: |
235 | 255 | return _get_my_rdma_device(rank, self.device_module.device_count(), _get_rdma_devices()) |
236 | 256 | else: |
237 | 257 | raise TypeError("The current transfer engine protocol is not supported") |
0 commit comments