Skip to content

Commit 19f6195

Browse files
committed
Add sagemaker dependency for remote function by default
1 parent a140cfc commit 19f6195

3 files changed

Lines changed: 555 additions & 0 deletions

File tree

src/sagemaker/remote_function/job.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1235,6 +1235,11 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
12351235

12361236
local_dependencies_path = RuntimeEnvironmentManager().snapshot(job_settings.dependencies)
12371237

1238+
# Ensure sagemaker dependency is included to prevent version mismatch issues
1239+
# Resolves issue where computing hash for integrity check changed in 2.256.0
1240+
local_dependencies_path = _ensure_sagemaker_dependency(local_dependencies_path)
1241+
job_settings.dependencies = local_dependencies_path
1242+
12381243
if step_compilation_context:
12391244
with _tmpdir() as tmp_dir:
12401245
script_and_dependencies_s3uri = _prepare_dependencies_and_pre_execution_scripts(
@@ -1291,6 +1296,113 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
12911296
return input_data_config
12921297

12931298

1299+
def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None:
    """Reject sagemaker requirements that could install an HMAC-hashing release.

    Releases before 2.256.0 (V2) and before 3.2.0 (V3) use HMAC-based integrity
    checks, which are incompatible with the current SHA256-based checks.

    The specifier cannot be compared against every published release, so it is
    probed with a fixed set of sample versions plus every version mentioned in
    the specifier itself (and the patch release right after it). This catches
    exact pins such as ``sagemaker==2.150.0`` that fixed samples alone miss.

    Lines that are not a requirement on the ``sagemaker`` distribution
    (comments, ``sagemaker-core``, URLs, unparsable specifiers) and bare
    ``sagemaker`` requirements without a version specifier are accepted
    silently.

    Args:
        sagemaker_requirement: The sagemaker requirement string
            (e.g., "sagemaker>=2.200.0"). Inline comments, extras, environment
            markers and trailing pip options are tolerated.

    Raises:
        ValueError: If the requirement would install a version using HMAC hashing
    """
    import re
    from packaging.specifiers import InvalidSpecifier, SpecifierSet
    from packaging.version import InvalidVersion, Version

    # Reduce the line to "<name>[extras]<specifier>": drop inline comments,
    # environment markers and pip options / line continuations.
    requirement = re.sub(r"(^|\s+)#.*$", "", sagemaker_requirement.strip())
    requirement = requirement.split(";", 1)[0]
    requirement = re.split(r"\s+(?:--|\\)", requirement, maxsplit=1)[0].strip()

    # The negative lookahead keeps sibling distributions such as
    # "sagemaker-core" from being mistaken for the SDK itself.
    match = re.match(
        r"sagemaker(?![\w.\-])\s*(?:\[[^\]]*\])?\s*(.*)$",
        requirement,
        re.IGNORECASE,
    )
    if not match:
        return

    specifier_str = match.group(1).strip()
    if not specifier_str:
        # Unpinned requirement: nothing to validate.
        return

    try:
        specifier_set = SpecifierSet(specifier_str)
    except InvalidSpecifier:
        # Not a plain version specifier (e.g. a direct URL reference).
        return

    # First release of each major line that uses SHA256 integrity checks.
    hmac_cutoffs = {2: Version("2.256.0"), 3: Version("3.2.0")}

    candidates = [
        Version(sample)
        for sample in (
            # V2 HMAC versions: < 2.256.0
            "2.0.0", "2.100.0", "2.200.0", "2.255.0", "2.255.1", "2.255.99",
            # V3 HMAC versions: < 3.2.0
            "3.0.0", "3.0.1", "3.1.0", "3.1.99",
        )
    ]

    # Also probe the versions the user actually wrote, so pins and narrow
    # ranges that fall between the fixed samples are still detected.
    for spec in specifier_set:
        raw_version = spec.version
        if raw_version.endswith(".*"):
            raw_version = raw_version[:-2]
        try:
            mentioned = Version(raw_version)
        except InvalidVersion:
            continue
        major, minor, micro = (tuple(mentioned.release) + (0, 0))[:3]
        candidates.append(mentioned)
        candidates.append(Version(f"{major}.{minor}.{micro}"))
        # The next patch release covers exclusive lower bounds (">X").
        candidates.append(Version(f"{major}.{minor}.{micro + 1}"))

    for candidate in candidates:
        cutoff = hmac_cutoffs.get(candidate.release[0])
        if cutoff is None or candidate >= cutoff:
            continue
        if specifier_set.contains(candidate, prereleases=True):
            raise ValueError(
                f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) "
                f"could install a version using HMAC-based integrity checks which are incompatible "
                f"with the current SHA256-based integrity checks. Please update to "
                f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)."
            )
1348+
1349+
1350+
def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str:
    """Ensure sagemaker>=2.256.0 is in the dependencies.

    Versions < 2.256.0 use HMAC-based integrity checks which require the
    REMOTE_FUNCTION_SECRET_KEY environment variable; versions >= 2.256.0 use
    SHA256-based integrity checks. Injecting the requirement keeps the remote
    environment on a compatible release.

    Behavior by input:

    * ``None``: a new temporary requirements file containing only the
      sagemaker requirement is created in the system temp directory and its
      path is returned. The file is not deleted by this function.
    * A path ending in ``.txt``: the file is scanned line by line. If it has
      one or more requirements on the ``sagemaker`` distribution itself, each
      one is validated; otherwise the sagemaker requirement is appended to the
      file in place. Comments and other distributions whose names merely
      contain "sagemaker" (e.g. ``sagemaker-core``) do not count.
    * Any other path: returned unchanged.

    Args:
        local_dependencies_path: Path to user's dependencies file or None

    Returns:
        Path to the dependencies file (created or modified)

    Raises:
        ValueError: If user has pinned sagemaker to a version using HMAC hashing
    """
    import re
    import tempfile

    SAGEMAKER_MIN_VERSION = "sagemaker>=2.256.0,<3.0.0"

    if local_dependencies_path is None:
        # Create a temporary requirements.txt in the system temp directory
        # This avoids overwriting any user files in their working directory
        fd, req_file = tempfile.mkstemp(suffix=".txt", prefix="sagemaker_requirements_")
        # Write through the descriptor mkstemp opened instead of closing and
        # reopening the path.
        with os.fdopen(fd, "w") as f:
            f.write(f"{SAGEMAKER_MIN_VERSION}\n")
        logger.info(
            "Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION
        )
        return req_file

    # Only pip-style requirements files are inspected.
    if not local_dependencies_path.endswith(".txt"):
        return local_dependencies_path

    with open(local_dependencies_path, "r") as f:
        content = f.read()

    # A requirement on the sagemaker distribution itself: the name must not be
    # followed by further name characters ("sagemaker-core", "sagemaker_foo").
    sagemaker_requirement_re = re.compile(r"sagemaker(?![\w.\-])", re.IGNORECASE)

    sagemaker_requirements = []
    for line in content.splitlines():
        # Strip full-line and inline comments before matching.
        requirement = re.sub(r"(^|\s+)#.*$", "", line).strip()
        if sagemaker_requirement_re.match(requirement):
            sagemaker_requirements.append(requirement)

    if sagemaker_requirements:
        for requirement in sagemaker_requirements:
            _check_sagemaker_version_compatibility(requirement)
    else:
        with open(local_dependencies_path, "a") as f:
            f.write(f"\n{SAGEMAKER_MIN_VERSION}\n")
        logger.info("Appended %s to requirements.txt", SAGEMAKER_MIN_VERSION)

    return local_dependencies_path
1404+
1405+
12941406
def _prepare_dependencies_and_pre_execution_scripts(
12951407
local_dependencies_path: str,
12961408
pre_execution_commands: List[str],
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""Integration tests for sagemaker dependency injection in remote functions.
2+
3+
These tests verify that the sagemaker>=2.256.0 dependency is properly injected
4+
into remote function jobs, preventing version mismatch issues.
5+
"""
6+
7+
import os
8+
import sys
9+
import tempfile
10+
import pytest
11+
12+
# Add src to path
13+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../src'))
14+
15+
from sagemaker.remote_function import remote
16+
17+
18+
class TestRemoteFunctionDependencyInjection:
    """End-to-end checks that remote jobs run with the injected sagemaker requirement."""

    @pytest.mark.integ
    def test_remote_function_without_dependencies(self):
        """Run a remote function that declares no dependencies.

        With nothing supplied by the user, the job is expected to succeed
        because sagemaker>=2.256.0 is injected automatically.
        """

        # No ``dependencies`` argument - sagemaker should be injected automatically.
        @remote(instance_type="ml.m5.large")
        def simple_add(x, y):
            """Simple function that adds two numbers."""
            return x + y

        result = simple_add(5, 3)

        assert result == 8, f"Expected 8, got {result}"
        print("✓ Remote function without dependencies executed successfully")

    @pytest.mark.integ
    def test_remote_function_with_user_dependencies_no_sagemaker(self):
        """Run a remote function whose requirements file omits sagemaker.

        The user-supplied requirements list only numpy and pandas; the job is
        expected to succeed because sagemaker>=2.256.0 is appended for them.
        """
        # Write a throwaway requirements file that does not mention sagemaker.
        handle = tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False)
        with handle:
            handle.write("numpy>=1.20.0\npandas>=1.3.0\n")
        req_file = handle.name

        try:

            @remote(instance_type="ml.m5.large", dependencies=req_file)
            def compute_with_numpy(x):
                """Function that uses numpy."""
                import numpy as np

                return np.array([x, x*2, x*3]).sum()

            result = compute_with_numpy(5)

            # 5 + 10 + 15 = 30
            assert result == 30, f"Expected 30, got {result}"
            print("✓ Remote function with user dependencies executed successfully")
        finally:
            # The file was created with delete=False, so remove it explicitly.
            os.remove(req_file)
74+
75+
76+
class TestRemoteFunctionVersionCompatibility:
    """Checks that local and remote environments stay compatible after injection."""

    @pytest.mark.integ
    def test_deserialization_with_injected_sagemaker(self):
        """Round-trip a function and its list argument through a remote job.

        The remote side must be able to deserialize the payload when
        sagemaker>=2.256.0 is available there.
        """

        @remote(instance_type="ml.m5.large")
        def complex_computation(data):
            """Function that performs complex computation."""
            return sum(data) * len(data)

        test_data = [1, 2, 3, 4, 5]
        result = complex_computation(test_data)

        # sum=15, len=5, 15*5=75
        assert result == 75, f"Expected 75, got {result}"
        print("✓ Deserialization with injected sagemaker works correctly")

    @pytest.mark.integ
    def test_multiple_remote_functions_with_dependencies(self):
        """Define and run two independent remote functions in one test.

        Dependency injection must work for each remote function when several
        are declared and executed.
        """

        @remote(instance_type="ml.m5.large")
        def func1(x):
            return x + 1

        @remote(instance_type="ml.m5.large")
        def func2(x):
            return x * 2

        # Run func1 first, then func2, as separate remote invocations.
        result1 = func1(5)
        result2 = func2(5)

        assert result1 == 6, f"func1: Expected 6, got {result1}"
        assert result2 == 10, f"func2: Expected 10, got {result2}"
        print("✓ Multiple remote functions with dependencies executed successfully")
124+
125+
126+
if __name__ == "__main__":
    # Run this file directly: verbose output, only tests marked ``integ``.
    pytest_args = [__file__, "-v", "-m", "integ"]
    pytest.main(pytest_args)

0 commit comments

Comments
 (0)