Skip to content

Commit 59a3c4f

Browse files
committed
feat: support aws efa
1 parent 5850e31 commit 59a3c4f

1 file changed

Lines changed: 22 additions & 2 deletions

File tree

checkpoint_engine/device_utils.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import socket
55
import subprocess
66
from functools import lru_cache
7+
from pathlib import Path
78

89
import torch
910
from loguru import logger
@@ -178,6 +179,22 @@ def _resolve_device_specs(
178179
return sorted(devices)
179180

180181

182+
def has_efa_pci(infiniband_root: str = "/sys/class/infiniband/") -> bool:
    """Report whether any RDMA device in sysfs has the Amazon PCI vendor ID.

    Each entry under ``infiniband_root`` is expected to expose its PCI vendor
    ID in ``<entry>/device/vendor``. Only the vendor ID is compared (not the
    PCI device ID); an entry whose vendor is Amazon (``0x1d0f``) is treated as
    EFA hardware.

    Args:
        infiniband_root: Directory to scan. Defaults to the kernel's
            ``/sys/class/infiniband/``; overridable for non-standard sysfs
            mounts and for testing.

    Returns:
        True if at least one entry reports the Amazon vendor ID. False if no
        entry matches, or if the directory is missing, unreadable, or not a
        directory.
    """
    amazon_vendor_id = "0x1d0f"
    try:
        # Listing can fail (missing path, permissions, not a directory);
        # detection is best-effort, so any such failure means "no EFA".
        devices = list(Path(infiniband_root).iterdir())
    except OSError:
        return False
    for device in devices:
        try:
            vendor = (device / "device" / "vendor").read_text().strip()
        except (OSError, ValueError):  # noqa: PERF203
            # Entry without a readable/decodable vendor file: skip it.
            continue
        # Compare case-insensitively so "0x1D0F" is accepted as well.
        if vendor.lower() == amazon_vendor_id:
            return True
    return False
196+
197+
181198
class DeviceManager:
182199
def __init__(self):
183200
self.device_type = self._detect_device_type()
@@ -224,14 +241,17 @@ def transfer_engine_protocol(self) -> str:
224241
if self.device_type == "npu":
225242
return "ascend_direct"
226243
elif self.device_type == "cuda":
227-
return "rdma"
244+
if has_efa_pci():
245+
return "efa"
246+
else:
247+
return "rdma"
228248
else:
229249
raise TypeError("The current device type is not supported")
230250

231251
def rdma_device(self, rank: int) -> str:
    """Return the RDMA device selected for ``rank`` under the active protocol.

    Args:
        rank: Rank passed through to ``_get_my_rdma_device``.

    Returns:
        An empty string for the ``"ascend_direct"`` protocol. For ``"rdma"``
        and ``"efa"``, the result of ``_get_my_rdma_device`` called with the
        rank, the device count of ``self.device_module`` and the list from
        ``_get_rdma_devices()``.

    Raises:
        TypeError: If the transfer engine protocol is none of the above.
    """
    # Read the protocol once: on CUDA, resolving it calls has_efa_pci(),
    # which scans sysfs, so repeated reads repeat that work.
    protocol = self.transfer_engine_protocol
    if protocol == "ascend_direct":
        return ""
    if protocol in ("rdma", "efa"):
        return _get_my_rdma_device(rank, self.device_module.device_count(), _get_rdma_devices())
    raise TypeError("The current transfer engine protocol is not supported")

0 commit comments

Comments
 (0)