Skip to content

Commit 0a6620c

Browse files
committed
feat: support multiple interface for single device
1 parent 9e8bf5c commit 0a6620c

2 files changed

Lines changed: 14 additions & 8 deletions

File tree

checkpoint_engine/device_utils.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,19 @@ def _get_my_rdma_device(local_rank: int, gpu_count: int, devices: list[str]) ->
8787
if not devices:
8888
raise RuntimeError("no rdma devices found")
8989
try:
90-
assert len(devices) <= gpu_count, (
91-
f"rdma devices count {len(devices)} should be less than or equal to gpu count {gpu_count}"
92-
)
93-
assert gpu_count % len(devices) == 0, (
94-
f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
95-
)
96-
return devices[local_rank // (gpu_count // len(devices))]
90+
if len(devices) <= gpu_count:
91+
assert gpu_count % len(devices) == 0, (
92+
f"gpu count {gpu_count} should be divisible by rdma devices count {len(devices)}"
93+
)
94+
return devices[local_rank // (gpu_count // len(devices))]
95+
else:
96+
assert len(devices) % gpu_count == 0, (
97+
f"rdma devices count {len(devices)} should be divisible by gpu count {gpu_count}"
98+
)
99+
device_per_rank = len(devices) // gpu_count
100+
return ",".join(
101+
devices[local_rank * device_per_rank : (local_rank + 1) * device_per_rank]
102+
)
97103
except AssertionError:
98104
logger.error(
99105
"Please set 'NCCL_IB_HCA' or 'PS_P2P_STORE_RDMA_DEVICES' environment variable to choose proper number of RDMA devices."

checkpoint_engine/p2p_store.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def __init__(self, device_manager: DeviceManager):
4141
self.port = self.engine.get_rpc_port()
4242
self.named_tensors: dict[str, torch.Tensor] = {}
4343
logger.info(
44-
f"[rank{self.rank}] p2p store initialized, addr is {self.addr}, rdma device is {self.device}"
44+
f"[rank{self.rank}] p2p store initialized, protocol {device_manager.transfer_engine_protocol}, addr {self.addr}, device {self.device}"
4545
)
4646

4747
@property

0 commit comments

Comments
 (0)