Skip to content

Commit 5d84038

Browse files
committed
Revise sagemaker compatibility check
1 parent 19f6195 commit 5d84038

2 files changed

Lines changed: 142 additions & 30 deletions

File tree

src/sagemaker/remote_function/job.py

Lines changed: 132 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1296,6 +1296,111 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
12961296
return input_data_config
12971297

12981298

1299+
def _decrement_version(version_str: str) -> str:
1300+
"""Decrement a version string by one minor or patch version.
1301+
1302+
Rules:
1303+
- If patch version is 0 (e.g., 2.256.0), decrement minor: 2.256.0 -> 2.255.0
1304+
- If patch version is not 0 (e.g., 2.254.2), decrement patch: 2.254.2 -> 2.254.1
1305+
1306+
Args:
1307+
version_str: Version string (e.g., "2.256.0")
1308+
1309+
Returns:
1310+
Decremented version string
1311+
"""
1312+
from packaging import version as pkg_version
1313+
1314+
try:
1315+
parsed = pkg_version.parse(version_str)
1316+
major = parsed.major
1317+
minor = parsed.minor
1318+
patch = parsed.micro
1319+
1320+
if patch == 0:
1321+
# Decrement minor version
1322+
minor = max(0, minor - 1)
1323+
else:
1324+
# Decrement patch version
1325+
patch = max(0, patch - 1)
1326+
1327+
return f"{major}.{minor}.{patch}"
1328+
except Exception:
1329+
return version_str
1330+
1331+
1332+
def _resolve_version_from_specifier(specifier_str: str) -> str:
1333+
"""Resolve the version to check based on upper bounds.
1334+
1335+
Upper bounds take priority. If upper bound is <3.0.0, it's safe (V2 only).
1336+
If no upper bound exists, it's safe (unbounded).
1337+
If the decremented upper bound is less than a lower bound, use the lower bound.
1338+
1339+
Args:
1340+
specifier_str: Version specifier string (e.g., ">=2.256.0", "<2.256.0", "==2.255.0")
1341+
1342+
Returns:
1343+
The resolved version string to check, or None if safe
1344+
"""
1345+
import re
1346+
from packaging import version as pkg_version
1347+
1348+
# Handle exact version pinning (==)
1349+
match = re.search(r'==\s*([\d.]+)', specifier_str)
1350+
if match:
1351+
return match.group(1)
1352+
1353+
# Extract lower bounds for comparison
1354+
lower_bounds = []
1355+
for match in re.finditer(r'>=\s*([\d.]+)', specifier_str):
1356+
lower_bounds.append(match.group(1))
1357+
1358+
# Handle upper bounds - find the most restrictive one
1359+
upper_bounds = []
1360+
1361+
# Find all <= bounds
1362+
for match in re.finditer(r'<=\s*([\d.]+)', specifier_str):
1363+
upper_bounds.append(('<=', match.group(1)))
1364+
1365+
# Find all < bounds
1366+
for match in re.finditer(r'<\s*([\d.]+)', specifier_str):
1367+
upper_bounds.append(('<', match.group(1)))
1368+
1369+
if upper_bounds:
1370+
# Sort by version to find the most restrictive (lowest) upper bound
1371+
upper_bounds.sort(key=lambda x: pkg_version.parse(x[1]))
1372+
operator, version = upper_bounds[0]
1373+
1374+
# Special case: if upper bound is <3.0.0, it's safe (V2 only)
1375+
try:
1376+
parsed_upper = pkg_version.parse(version)
1377+
if operator == '<' and parsed_upper.major == 3 and parsed_upper.minor == 0 and parsed_upper.micro == 0:
1378+
# <3.0.0 means V2 only, which is safe
1379+
return None
1380+
except Exception:
1381+
pass
1382+
1383+
resolved_version = version
1384+
if operator == '<':
1385+
resolved_version = _decrement_version(version)
1386+
1387+
# If we have a lower bound and the resolved version is less than it, use the lower bound
1388+
if lower_bounds:
1389+
try:
1390+
resolved_parsed = pkg_version.parse(resolved_version)
1391+
for lower_bound_str in lower_bounds:
1392+
lower_parsed = pkg_version.parse(lower_bound_str)
1393+
if resolved_parsed < lower_parsed:
1394+
resolved_version = lower_bound_str
1395+
except Exception:
1396+
pass
1397+
1398+
return resolved_version
1399+
1400+
# For lower bounds only (>=, >), we don't check
1401+
return None
1402+
1403+
12991404
def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None:
13001405
"""Check if the sagemaker version requirement uses incompatible hashing.
13011406
@@ -1309,42 +1414,44 @@ def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None:
13091414
ValueError: If the requirement would install a version using HMAC hashing
13101415
"""
13111416
import re
1312-
from packaging.specifiers import SpecifierSet
13131417
from packaging import version as pkg_version
1314-
1418+
13151419
match = re.search(r'sagemaker\s*(.+)$', sagemaker_requirement.strip(), re.IGNORECASE)
13161420
if not match:
13171421
return
13181422

13191423
specifier_str = match.group(1).strip()
13201424

1425+
# Resolve the version that would be installed
1426+
resolved_version_str = _resolve_version_from_specifier(specifier_str)
1427+
if not resolved_version_str:
1428+
# No upper bound or exact version, so we can't determine if it's bad
1429+
return
1430+
13211431
try:
1322-
specifier_set = SpecifierSet(specifier_str)
1432+
resolved_version = pkg_version.parse(resolved_version_str)
13231433
except Exception:
13241434
return
1325-
1326-
# Test if any HMAC version would satisfy the specifier
1327-
# V2 HMAC versions: < 2.256.0
1328-
v2_hmac_test_versions = ["2.0.0", "2.100.0", "2.200.0", "2.255.0", "2.255.1", "2.255.99"]
1329-
for test_version in v2_hmac_test_versions:
1330-
if test_version in specifier_set:
1331-
raise ValueError(
1332-
f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) "
1333-
f"could install a version using HMAC-based integrity checks which are incompatible "
1334-
f"with the current SHA256-based integrity checks. Please update to "
1335-
f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)."
1336-
)
13371435

1338-
# V3 HMAC versions: < 3.2.0
1339-
v3_hmac_test_versions = ["3.0.0", "3.0.1", "3.1.0", "3.1.99"]
1340-
for test_version in v3_hmac_test_versions:
1341-
if test_version in specifier_set:
1342-
raise ValueError(
1343-
f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) "
1344-
f"could install a version using HMAC-based integrity checks which are incompatible "
1345-
f"with the current SHA256-based integrity checks. Please update to "
1346-
f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)."
1347-
)
1436+
# Define HMAC thresholds for each major version
1437+
v2_hmac_threshold = pkg_version.parse("2.256.0")
1438+
v3_hmac_threshold = pkg_version.parse("3.2.0")
1439+
1440+
# Check if the resolved version uses HMAC hashing
1441+
uses_hmac = False
1442+
if resolved_version.major == 2 and resolved_version < v2_hmac_threshold:
1443+
uses_hmac = True
1444+
elif resolved_version.major == 3 and resolved_version < v3_hmac_threshold:
1445+
uses_hmac = True
1446+
1447+
if uses_hmac:
1448+
raise ValueError(
1449+
f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) "
1450+
f"could install a version using HMAC-based integrity checks which are incompatible "
1451+
f"with the current SHA256-based integrity checks. Please update to "
1452+
f"sagemaker>=2.256.0,<3.0.0 (for V2) or sagemaker>=3.2.0,<4.0.0 (for V3)."
1453+
)
1454+
13481455

13491456

13501457
def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str:

tests/unit/sagemaker/remote_function/test_ensure_sagemaker_dependency.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -264,9 +264,9 @@ def test_v2_bad_less_equal_255(self):
264264
_check_sagemaker_version_compatibility("sagemaker<=2.255.0")
265265

266266
def test_v2_bad_greater_than_255_0(self):
267-
"""Test V2 greater than 2.255.0 (bad - could be 2.255.1 with HMAC)."""
268-
with self.assertRaises(ValueError):
269-
_check_sagemaker_version_compatibility("sagemaker>2.255.0")
267+
"""Test V2 greater than 2.255.0 (not checked - treat as lower bound only)."""
268+
# Should not raise - > is treated as a lower bound, we don't check those
269+
_check_sagemaker_version_compatibility("sagemaker>2.255.0")
270270

271271
def test_v2_bad_range_200_to_255(self):
272272
"""Test V2 range 2.200.0 to 2.255.0 (bad - HMAC)."""
@@ -305,11 +305,16 @@ def test_multiple_version_specifiers_good(self):
305305
# Should not raise
306306
_check_sagemaker_version_compatibility("sagemaker>=2.256.0,<3.0.0")
307307

308+
def test_multiple_version_specifiers_good_with_lower_bound(self):
309+
"""Test multiple version specifiers that are good (upper bound resolves to good version)."""
310+
# Should not raise - <2.300.0 decrements to 2.299.0 which is >= 2.256.0
311+
_check_sagemaker_version_compatibility("sagemaker>=2.200.0,<2.300.0")
312+
308313
def test_multiple_version_specifiers_bad(self):
309314
"""Test multiple version specifiers that are bad."""
310-
# Should raise because lower bound is < 2.256.0
315+
# Should raise - <2.256.0 decrements to 2.255.0 which is < 2.256.0 (HMAC)
311316
with self.assertRaises(ValueError):
312-
_check_sagemaker_version_compatibility("sagemaker>=2.200.0,<2.300.0")
317+
_check_sagemaker_version_compatibility("sagemaker>=2.200.0,<2.256.0")
313318

314319

315320
if __name__ == "__main__":

0 commit comments

Comments
 (0)