Skip to content

Commit a2d4de4

Browse files
fix: Add ml.p5e.48xlarge and ml.p5.48xlarge to EFA instance lists
Add ml.p5e.48xlarge to SM_EFA_NCCL_INSTANCES and SM_EFA_RDMA_INSTANCES. Add ml.p5.48xlarge to SM_EFA_RDMA_INSTANCES, where it was missing. Without these entries, the EFA environment variables (FI_PROVIDER, FI_EFA_USE_DEVICE_RDMA, RDMAV_FORK_SAFE) are not set, and NCCL hangs during distributed training initialization on P5e instances. Fixes #5491
1 parent: 33bf993 · commit: a2d4de4

File tree

5 files changed

+38
-0
lines changed

5 files changed

+38
-0
lines changed

sagemaker-core/src/sagemaker/core/modules/train/container_drivers/common/utils.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,12 +50,15 @@
5050
"ml.p4d.24xlarge",
5151
"ml.p4de.24xlarge",
5252
"ml.p5.48xlarge",
53+
"ml.p5e.48xlarge",
5354
"ml.trn1.32xlarge",
5455
]
5556

5657
SM_EFA_RDMA_INSTANCES = [
5758
"ml.p4d.24xlarge",
5859
"ml.p4de.24xlarge",
60+
"ml.p5.48xlarge",
61+
"ml.p5e.48xlarge",
5962
"ml.trn1.32xlarge",
6063
]
6164

sagemaker-core/src/sagemaker/core/remote_function/runtime_environment/bootstrap_runtime_environment.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,12 +75,15 @@
7575
"ml.p4d.24xlarge",
7676
"ml.p4de.24xlarge",
7777
"ml.p5.48xlarge",
78+
"ml.p5e.48xlarge",
7879
"ml.trn1.32xlarge",
7980
]
8081

8182
SM_EFA_RDMA_INSTANCES = [
8283
"ml.p4d.24xlarge",
8384
"ml.p4de.24xlarge",
85+
"ml.p5.48xlarge",
86+
"ml.p5e.48xlarge",
8487
"ml.trn1.32xlarge",
8588
]
8689

sagemaker-train/src/sagemaker/train/container_drivers/common/utils.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,12 +50,15 @@
5050
"ml.p4d.24xlarge",
5151
"ml.p4de.24xlarge",
5252
"ml.p5.48xlarge",
53+
"ml.p5e.48xlarge",
5354
"ml.trn1.32xlarge",
5455
]
5556

5657
SM_EFA_RDMA_INSTANCES = [
5758
"ml.p4d.24xlarge",
5859
"ml.p4de.24xlarge",
60+
"ml.p5.48xlarge",
61+
"ml.p5e.48xlarge",
5962
"ml.trn1.32xlarge",
6063
]
6164

sagemaker-train/src/sagemaker/train/remote_function/runtime_environment/bootstrap_runtime_environment.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -75,12 +75,15 @@
7575
"ml.p4d.24xlarge",
7676
"ml.p4de.24xlarge",
7777
"ml.p5.48xlarge",
78+
"ml.p5e.48xlarge",
7879
"ml.trn1.32xlarge",
7980
]
8081

8182
SM_EFA_RDMA_INSTANCES = [
8283
"ml.p4d.24xlarge",
8384
"ml.p4de.24xlarge",
85+
"ml.p5.48xlarge",
86+
"ml.p5e.48xlarge",
8487
"ml.trn1.32xlarge",
8588
]
8689

sagemaker-train/tests/unit/train/container_drivers/test_torchrun_driver.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import sys
1818
import json
1919

20+
import pytest
2021
from unittest.mock import patch, MagicMock
2122

2223
sys.modules["utils"] = MagicMock()
@@ -146,3 +147,28 @@ def test_create_commands_multi_node(
146147

147148
command = torchrun_driver.create_commands()
148149
assert command == expected_command
150+
151+
152+
@pytest.mark.parametrize("instance_type", ["ml.p5.48xlarge", "ml.p5e.48xlarge"])
153+
@patch.dict(
154+
os.environ,
155+
{
156+
"SM_NETWORK_INTERFACE_NAME": "eth0",
157+
"SM_HOST_COUNT": "2",
158+
"SM_MASTER_ADDR": "algo-1",
159+
"SM_MASTER_PORT": "7777",
160+
"SM_CURRENT_HOST_RANK": "0",
161+
"SM_HPS": json.dumps({}),
162+
"SM_DISTRIBUTED_CONFIG": json.dumps(DUMMY_DISTRIBUTED),
163+
"SM_ENTRY_SCRIPT": "script.py",
164+
},
165+
)
166+
def test_p5_p5e_efa_environment_setup(instance_type):
167+
"""Test that P5 and P5e instances are in EFA instance lists."""
168+
from sagemaker.train.container_drivers.common.utils import (
169+
SM_EFA_NCCL_INSTANCES,
170+
SM_EFA_RDMA_INSTANCES,
171+
)
172+
173+
assert instance_type in SM_EFA_NCCL_INSTANCES
174+
assert instance_type in SM_EFA_RDMA_INSTANCES

0 commit comments

Comments
 (0)