Skip to content

Commit 422b35f

Browse files
authored
Add sagemaker dependency for remote function by default V3 (#5487)
* Add sagemaker dependency for remote function by default V3 * Adding unit and integ tests with minor bug fix to rf script
1 parent 33bf993 commit 422b35f

File tree

3 files changed

+681
-13
lines changed

3 files changed

+681
-13
lines changed

sagemaker-core/src/sagemaker/core/remote_function/job.py

Lines changed: 226 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,12 @@
175175
fi
176176
177177
printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n"
178-
printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function \\n"
179-
$conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function "$@"
178+
printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function \\n"
179+
$conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function "$@"
180180
else
181181
printf "INFO: No conda env provided. Invoking remote function\\n"
182-
printf "INFO: python -m sagemaker.train.remote_function.invoke_function \\n"
183-
python -m sagemaker.train.remote_function.invoke_function "$@"
182+
printf "INFO: python -m sagemaker.core.remote_function.invoke_function \\n"
183+
python -m sagemaker.core.remote_function.invoke_function "$@"
184184
fi
185185
"""
186186

@@ -234,14 +234,14 @@
234234
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
235235
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
236236
237-
python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n"
237+
python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n"
238238
$conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
239239
--allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
240240
-mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
241241
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
242242
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
243243
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
244-
python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@"
244+
python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@"
245245
246246
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
247247
else
@@ -259,15 +259,15 @@
259259
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
260260
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
261261
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
262-
python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n"
262+
python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n"
263263
264264
mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
265265
--allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
266266
-mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
267267
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
268268
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
269269
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
270-
python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@"
270+
python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@"
271271
272272
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
273273
else
@@ -320,18 +320,18 @@
320320
printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
321321
printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
322322
--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
323-
-m sagemaker.train.remote_function.invoke_function \\n"
323+
-m sagemaker.core.remote_function.invoke_function \\n"
324324
325325
$conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
326326
--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
327-
-m sagemaker.train.remote_function.invoke_function "$@"
327+
-m sagemaker.core.remote_function.invoke_function "$@"
328328
else
329329
printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
330330
printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
331-
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function \\n"
331+
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function \\n"
332332
333333
torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
334-
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function "$@"
334+
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function "$@"
335335
fi
336336
"""
337337

@@ -1259,7 +1259,215 @@ def _prepare_and_upload_runtime_scripts(
12591259
return upload_path
12601260

12611261

1262-
def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
1262+
def _decrement_version(version_str: str) -> str:
1263+
"""Decrement a version string by one minor or patch version.
1264+
1265+
Rules:
1266+
- If patch version is 0 (e.g., 3.2.0), decrement minor: 3.2.0 -> 3.1.0
1267+
- If patch version is not 0 (e.g., 3.1.2), decrement patch: 3.1.2 -> 3.1.1
1268+
1269+
Args:
1270+
version_str: Version string (e.g., "3.2.0")
1271+
1272+
Returns:
1273+
Decremented version string
1274+
"""
1275+
from packaging import version as pkg_version
1276+
1277+
try:
1278+
parsed = pkg_version.parse(version_str)
1279+
major = parsed.major
1280+
minor = parsed.minor
1281+
patch = parsed.micro
1282+
1283+
if patch == 0:
1284+
# Decrement minor version
1285+
minor = max(0, minor - 1)
1286+
else:
1287+
# Decrement patch version
1288+
patch = max(0, patch - 1)
1289+
1290+
return f"{major}.{minor}.{patch}"
1291+
except Exception:
1292+
return version_str
1293+
1294+
1295+
def _resolve_version_from_specifier(specifier_str: str) -> str:
1296+
"""Resolve the version to check based on upper bounds.
1297+
1298+
Upper bounds take priority. If upper bound is <4.0.0, it's safe (V3 only).
1299+
If no upper bound exists, it's safe (unbounded).
1300+
If the decremented upper bound is less than a lower bound, use the lower bound.
1301+
1302+
Args:
1303+
specifier_str: Version specifier string (e.g., ">=3.2.0", "<3.2.0", "==3.1.0")
1304+
1305+
Returns:
1306+
The resolved version string to check, or None if safe
1307+
"""
1308+
import re
1309+
from packaging import version as pkg_version
1310+
1311+
# Handle exact version pinning (==)
1312+
match = re.search(r'==\s*([\d.]+)', specifier_str)
1313+
if match:
1314+
return match.group(1)
1315+
1316+
# Extract lower bounds for comparison
1317+
lower_bounds = []
1318+
for match in re.finditer(r'>=\s*([\d.]+)', specifier_str):
1319+
lower_bounds.append(match.group(1))
1320+
1321+
# Handle upper bounds - find the most restrictive one
1322+
upper_bounds = []
1323+
1324+
# Find all <= bounds
1325+
for match in re.finditer(r'<=\s*([\d.]+)', specifier_str):
1326+
upper_bounds.append(('<=', match.group(1)))
1327+
1328+
# Find all < bounds
1329+
for match in re.finditer(r'<\s*([\d.]+)', specifier_str):
1330+
upper_bounds.append(('<', match.group(1)))
1331+
1332+
if upper_bounds:
1333+
# Sort by version to find the most restrictive (lowest) upper bound
1334+
upper_bounds.sort(key=lambda x: pkg_version.parse(x[1]))
1335+
operator, version = upper_bounds[0]
1336+
1337+
# Special case: if upper bound is <4.0.0, it's safe (V3 only)
1338+
try:
1339+
parsed_upper = pkg_version.parse(version)
1340+
if operator == '<' and parsed_upper.major == 4 and parsed_upper.minor == 0 and parsed_upper.micro == 0:
1341+
# <4.0.0 means V3 only, which is safe
1342+
return None
1343+
except Exception:
1344+
pass
1345+
1346+
resolved_version = version
1347+
if operator == '<':
1348+
resolved_version = _decrement_version(version)
1349+
1350+
# If we have a lower bound and the resolved version is less than it, use the lower bound
1351+
if lower_bounds:
1352+
try:
1353+
resolved_parsed = pkg_version.parse(resolved_version)
1354+
for lower_bound_str in lower_bounds:
1355+
lower_parsed = pkg_version.parse(lower_bound_str)
1356+
if resolved_parsed < lower_parsed:
1357+
resolved_version = lower_bound_str
1358+
except Exception:
1359+
pass
1360+
1361+
return resolved_version
1362+
1363+
# For lower bounds only (>=, >), we don't check
1364+
return None
1365+
1366+
1367+
def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None:
1368+
"""Check if the sagemaker version requirement uses incompatible hashing.
1369+
1370+
Raises ValueError if the requirement would install a version that uses HMAC hashing
1371+
(which is incompatible with the current SHA256-based integrity checks).
1372+
1373+
Args:
1374+
sagemaker_requirement: The sagemaker requirement string (e.g., "sagemaker>=3.2.0")
1375+
1376+
Raises:
1377+
ValueError: If the requirement would install a version using HMAC hashing
1378+
"""
1379+
import re
1380+
from packaging import version as pkg_version
1381+
1382+
match = re.search(r'sagemaker\s*(.+)$', sagemaker_requirement.strip(), re.IGNORECASE)
1383+
if not match:
1384+
return
1385+
1386+
specifier_str = match.group(1).strip()
1387+
1388+
# Resolve the version that would be installed
1389+
resolved_version_str = _resolve_version_from_specifier(specifier_str)
1390+
if not resolved_version_str:
1391+
# No upper bound or exact version, so we can't determine if it's bad
1392+
return
1393+
1394+
try:
1395+
resolved_version = pkg_version.parse(resolved_version_str)
1396+
except Exception:
1397+
return
1398+
1399+
# Define HMAC thresholds for each major version
1400+
v2_hmac_threshold = pkg_version.parse("2.256.0")
1401+
v3_hmac_threshold = pkg_version.parse("3.2.0")
1402+
1403+
# Check if the resolved version uses HMAC hashing
1404+
uses_hmac = False
1405+
if resolved_version.major == 2 and resolved_version < v2_hmac_threshold:
1406+
uses_hmac = True
1407+
elif resolved_version.major == 3 and resolved_version < v3_hmac_threshold:
1408+
uses_hmac = True
1409+
1410+
if uses_hmac:
1411+
raise ValueError(
1412+
f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) "
1413+
f"could install a version using HMAC-based integrity checks which are incompatible "
1414+
f"with the current SHA256-based integrity checks. Please update to "
1415+
f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)."
1416+
)
1417+
1418+
1419+
def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str:
1420+
"""Ensure sagemaker>=3.2.0 is in the dependencies.
1421+
1422+
This function ensures that the remote environment has a compatible version of sagemaker
1423+
that includes the fix for the HMAC key security issue. Versions < 3.2.0 use HMAC-based
1424+
integrity checks which require the REMOTE_FUNCTION_SECRET_KEY environment variable.
1425+
Versions >= 3.2.0 use SHA256-based integrity checks which are secure and don't require
1426+
the secret key.
1427+
1428+
If no dependencies are provided, creates a temporary requirements.txt with sagemaker.
1429+
If dependencies are provided, appends sagemaker if not already present.
1430+
1431+
Args:
1432+
local_dependencies_path: Path to user's dependencies file or None
1433+
1434+
Returns:
1435+
Path to the dependencies file (created or modified)
1436+
1437+
Raises:
1438+
ValueError: If user has pinned sagemaker to a version using HMAC hashing
1439+
"""
1440+
import tempfile
1441+
1442+
SAGEMAKER_MIN_VERSION = "sagemaker>=3.2.0,<4.0.0"
1443+
1444+
if local_dependencies_path is None:
1445+
fd, req_file = tempfile.mkstemp(suffix=".txt", prefix="sagemaker_requirements_")
1446+
os.close(fd)
1447+
1448+
with open(req_file, "w") as f:
1449+
f.write(f"{SAGEMAKER_MIN_VERSION}\n")
1450+
logger.info("Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION)
1451+
return req_file
1452+
1453+
if local_dependencies_path.endswith(".txt"):
1454+
with open(local_dependencies_path, "r") as f:
1455+
content = f.read()
1456+
1457+
if "sagemaker" in content.lower():
1458+
for line in content.split('\n'):
1459+
if 'sagemaker' in line.lower():
1460+
_check_sagemaker_version_compatibility(line.strip())
1461+
break
1462+
else:
1463+
with open(local_dependencies_path, "a") as f:
1464+
f.write(f"\n{SAGEMAKER_MIN_VERSION}\n")
1465+
logger.info("Appended %s to requirements.txt", SAGEMAKER_MIN_VERSION)
1466+
1467+
return local_dependencies_path
1468+
1469+
1470+
def _generate_input_data_config(job_settings, s3_base_uri):
12631471
"""Generates input data config"""
12641472
from sagemaker.core.workflow.utilities import load_step_compilation_context
12651473

@@ -1288,6 +1496,11 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
12881496

12891497
local_dependencies_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies)
12901498

1499+
# Ensure sagemaker dependency is included to prevent version mismatch issues
1500+
# Resolves issue where computing hash for integrity check changed in 3.2.0
1501+
local_dependencies_path = _ensure_sagemaker_dependency(local_dependencies_path)
1502+
job_settings.dependencies = local_dependencies_path
1503+
12911504
if step_compilation_context:
12921505
with _tmpdir() as tmp_dir:
12931506
script_and_dependencies_s3uri = _prepare_dependencies_and_pre_execution_scripts(

0 commit comments

Comments
 (0)