Skip to content

Commit 31a145e

Browse files
author
Namrata Madan
committed
Revert "Add sagemaker dependency for remote function by default V3 (aws#5487)"
This reverts commit 422b35f.
1 parent a3ab7c6 commit 31a145e

3 files changed

Lines changed: 13 additions & 681 deletions

File tree

sagemaker-core/src/sagemaker/core/remote_function/job.py

Lines changed: 13 additions & 226 deletions
Original file line numberDiff line numberDiff line change
@@ -176,12 +176,12 @@
176176
fi
177177
178178
printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n"
179-
printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function \\n"
180-
$conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function "$@"
179+
printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function \\n"
180+
$conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function "$@"
181181
else
182182
printf "INFO: No conda env provided. Invoking remote function\\n"
183-
printf "INFO: python -m sagemaker.core.remote_function.invoke_function \\n"
184-
python -m sagemaker.core.remote_function.invoke_function "$@"
183+
printf "INFO: python -m sagemaker.train.remote_function.invoke_function \\n"
184+
python -m sagemaker.train.remote_function.invoke_function "$@"
185185
fi
186186
"""
187187

@@ -235,14 +235,14 @@
235235
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
236236
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
237237
238-
python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n"
238+
python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n"
239239
$conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
240240
--allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
241241
-mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
242242
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
243243
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
244244
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
245-
python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@"
245+
python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@"
246246
247247
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
248248
else
@@ -260,15 +260,15 @@
260260
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
261261
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
262262
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
263-
python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n"
263+
python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n"
264264
265265
mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
266266
--allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
267267
-mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
268268
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
269269
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
270270
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
271-
python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@"
271+
python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@"
272272
273273
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
274274
else
@@ -321,18 +321,18 @@
321321
printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
322322
printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
323323
--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
324-
-m sagemaker.core.remote_function.invoke_function \\n"
324+
-m sagemaker.train.remote_function.invoke_function \\n"
325325
326326
$conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
327327
--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
328-
-m sagemaker.core.remote_function.invoke_function "$@"
328+
-m sagemaker.train.remote_function.invoke_function "$@"
329329
else
330330
printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
331331
printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
332-
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function \\n"
332+
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function \\n"
333333
334334
torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
335-
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function "$@"
335+
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function "$@"
336336
fi
337337
"""
338338

@@ -1278,215 +1278,7 @@ def _prepare_and_upload_runtime_scripts(
12781278
return upload_path
12791279

12801280

1281-
def _decrement_version(version_str: str) -> str:
1282-
"""Decrement a version string by one minor or patch version.
1283-
1284-
Rules:
1285-
- If patch version is 0 (e.g., 3.2.0), decrement minor: 3.2.0 -> 3.1.0
1286-
- If patch version is not 0 (e.g., 3.1.2), decrement patch: 3.1.2 -> 3.1.1
1287-
1288-
Args:
1289-
version_str: Version string (e.g., "3.2.0")
1290-
1291-
Returns:
1292-
Decremented version string
1293-
"""
1294-
from packaging import version as pkg_version
1295-
1296-
try:
1297-
parsed = pkg_version.parse(version_str)
1298-
major = parsed.major
1299-
minor = parsed.minor
1300-
patch = parsed.micro
1301-
1302-
if patch == 0:
1303-
# Decrement minor version
1304-
minor = max(0, minor - 1)
1305-
else:
1306-
# Decrement patch version
1307-
patch = max(0, patch - 1)
1308-
1309-
return f"{major}.{minor}.{patch}"
1310-
except Exception:
1311-
return version_str
1312-
1313-
1314-
def _resolve_version_from_specifier(specifier_str: str) -> str:
1315-
"""Resolve the version to check based on upper bounds.
1316-
1317-
Upper bounds take priority. If upper bound is <4.0.0, it's safe (V3 only).
1318-
If no upper bound exists, it's safe (unbounded).
1319-
If the decremented upper bound is less than a lower bound, use the lower bound.
1320-
1321-
Args:
1322-
specifier_str: Version specifier string (e.g., ">=3.2.0", "<3.2.0", "==3.1.0")
1323-
1324-
Returns:
1325-
The resolved version string to check, or None if safe
1326-
"""
1327-
import re
1328-
from packaging import version as pkg_version
1329-
1330-
# Handle exact version pinning (==)
1331-
match = re.search(r'==\s*([\d.]+)', specifier_str)
1332-
if match:
1333-
return match.group(1)
1334-
1335-
# Extract lower bounds for comparison
1336-
lower_bounds = []
1337-
for match in re.finditer(r'>=\s*([\d.]+)', specifier_str):
1338-
lower_bounds.append(match.group(1))
1339-
1340-
# Handle upper bounds - find the most restrictive one
1341-
upper_bounds = []
1342-
1343-
# Find all <= bounds
1344-
for match in re.finditer(r'<=\s*([\d.]+)', specifier_str):
1345-
upper_bounds.append(('<=', match.group(1)))
1346-
1347-
# Find all < bounds
1348-
for match in re.finditer(r'<\s*([\d.]+)', specifier_str):
1349-
upper_bounds.append(('<', match.group(1)))
1350-
1351-
if upper_bounds:
1352-
# Sort by version to find the most restrictive (lowest) upper bound
1353-
upper_bounds.sort(key=lambda x: pkg_version.parse(x[1]))
1354-
operator, version = upper_bounds[0]
1355-
1356-
# Special case: if upper bound is <4.0.0, it's safe (V3 only)
1357-
try:
1358-
parsed_upper = pkg_version.parse(version)
1359-
if operator == '<' and parsed_upper.major == 4 and parsed_upper.minor == 0 and parsed_upper.micro == 0:
1360-
# <4.0.0 means V3 only, which is safe
1361-
return None
1362-
except Exception:
1363-
pass
1364-
1365-
resolved_version = version
1366-
if operator == '<':
1367-
resolved_version = _decrement_version(version)
1368-
1369-
# If we have a lower bound and the resolved version is less than it, use the lower bound
1370-
if lower_bounds:
1371-
try:
1372-
resolved_parsed = pkg_version.parse(resolved_version)
1373-
for lower_bound_str in lower_bounds:
1374-
lower_parsed = pkg_version.parse(lower_bound_str)
1375-
if resolved_parsed < lower_parsed:
1376-
resolved_version = lower_bound_str
1377-
except Exception:
1378-
pass
1379-
1380-
return resolved_version
1381-
1382-
# For lower bounds only (>=, >), we don't check
1383-
return None
1384-
1385-
1386-
def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None:
1387-
"""Check if the sagemaker version requirement uses incompatible hashing.
1388-
1389-
Raises ValueError if the requirement would install a version that uses HMAC hashing
1390-
(which is incompatible with the current SHA256-based integrity checks).
1391-
1392-
Args:
1393-
sagemaker_requirement: The sagemaker requirement string (e.g., "sagemaker>=3.2.0")
1394-
1395-
Raises:
1396-
ValueError: If the requirement would install a version using HMAC hashing
1397-
"""
1398-
import re
1399-
from packaging import version as pkg_version
1400-
1401-
match = re.search(r'sagemaker\s*(.+)$', sagemaker_requirement.strip(), re.IGNORECASE)
1402-
if not match:
1403-
return
1404-
1405-
specifier_str = match.group(1).strip()
1406-
1407-
# Resolve the version that would be installed
1408-
resolved_version_str = _resolve_version_from_specifier(specifier_str)
1409-
if not resolved_version_str:
1410-
# No upper bound or exact version, so we can't determine if it's bad
1411-
return
1412-
1413-
try:
1414-
resolved_version = pkg_version.parse(resolved_version_str)
1415-
except Exception:
1416-
return
1417-
1418-
# Define HMAC thresholds for each major version
1419-
v2_hmac_threshold = pkg_version.parse("2.256.0")
1420-
v3_hmac_threshold = pkg_version.parse("3.2.0")
1421-
1422-
# Check if the resolved version uses HMAC hashing
1423-
uses_hmac = False
1424-
if resolved_version.major == 2 and resolved_version < v2_hmac_threshold:
1425-
uses_hmac = True
1426-
elif resolved_version.major == 3 and resolved_version < v3_hmac_threshold:
1427-
uses_hmac = True
1428-
1429-
if uses_hmac:
1430-
raise ValueError(
1431-
f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) "
1432-
f"could install a version using HMAC-based integrity checks which are incompatible "
1433-
f"with the current SHA256-based integrity checks. Please update to "
1434-
f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)."
1435-
)
1436-
1437-
1438-
def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str:
1439-
"""Ensure sagemaker>=3.2.0 is in the dependencies.
1440-
1441-
This function ensures that the remote environment has a compatible version of sagemaker
1442-
that includes the fix for the HMAC key security issue. Versions < 3.2.0 use HMAC-based
1443-
integrity checks which require the REMOTE_FUNCTION_SECRET_KEY environment variable.
1444-
Versions >= 3.2.0 use SHA256-based integrity checks which are secure and don't require
1445-
the secret key.
1446-
1447-
If no dependencies are provided, creates a temporary requirements.txt with sagemaker.
1448-
If dependencies are provided, appends sagemaker if not already present.
1449-
1450-
Args:
1451-
local_dependencies_path: Path to user's dependencies file or None
1452-
1453-
Returns:
1454-
Path to the dependencies file (created or modified)
1455-
1456-
Raises:
1457-
ValueError: If user has pinned sagemaker to a version using HMAC hashing
1458-
"""
1459-
import tempfile
1460-
1461-
SAGEMAKER_MIN_VERSION = "sagemaker>=3.2.0,<4.0.0"
1462-
1463-
if local_dependencies_path is None:
1464-
fd, req_file = tempfile.mkstemp(suffix=".txt", prefix="sagemaker_requirements_")
1465-
os.close(fd)
1466-
1467-
with open(req_file, "w") as f:
1468-
f.write(f"{SAGEMAKER_MIN_VERSION}\n")
1469-
logger.info("Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION)
1470-
return req_file
1471-
1472-
if local_dependencies_path.endswith(".txt"):
1473-
with open(local_dependencies_path, "r") as f:
1474-
content = f.read()
1475-
1476-
if "sagemaker" in content.lower():
1477-
for line in content.split('\n'):
1478-
if 'sagemaker' in line.lower():
1479-
_check_sagemaker_version_compatibility(line.strip())
1480-
break
1481-
else:
1482-
with open(local_dependencies_path, "a") as f:
1483-
f.write(f"\n{SAGEMAKER_MIN_VERSION}\n")
1484-
logger.info("Appended %s to requirements.txt", SAGEMAKER_MIN_VERSION)
1485-
1486-
return local_dependencies_path
1487-
1488-
1489-
def _generate_input_data_config(job_settings, s3_base_uri):
1281+
def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
14901282
"""Generates input data config"""
14911283
from sagemaker.core.workflow.utilities import load_step_compilation_context
14921284

@@ -1515,11 +1307,6 @@ def _generate_input_data_config(job_settings, s3_base_uri):
15151307

15161308
local_dependencies_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies)
15171309

1518-
# Ensure sagemaker dependency is included to prevent version mismatch issues
1519-
# Resolves issue where computing hash for integrity check changed in 3.2.0
1520-
local_dependencies_path = _ensure_sagemaker_dependency(local_dependencies_path)
1521-
job_settings.dependencies = local_dependencies_path
1522-
15231310
if step_compilation_context:
15241311
with _tmpdir() as tmp_dir:
15251312
script_and_dependencies_s3uri = _prepare_dependencies_and_pre_execution_scripts(

0 commit comments

Comments
 (0)