@@ -1298,32 +1298,32 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
12981298
12991299def _decrement_version (version_str : str ) -> str :
13001300 """Decrement a version string by one minor or patch version.
1301-
1301+
13021302 Rules:
13031303 - If patch version is 0 (e.g., 2.256.0), decrement minor: 2.256.0 -> 2.255.0
13041304 - If patch version is not 0 (e.g., 2.254.2), decrement patch: 2.254.2 -> 2.254.1
1305-
1305+
13061306 Args:
13071307 version_str: Version string (e.g., "2.256.0")
1308-
1308+
13091309 Returns:
13101310 Decremented version string
13111311 """
13121312 from packaging import version as pkg_version
1313-
1313+
13141314 try :
13151315 parsed = pkg_version .parse (version_str )
13161316 major = parsed .major
13171317 minor = parsed .minor
13181318 patch = parsed .micro
1319-
1319+
13201320 if patch == 0 :
13211321 # Decrement minor version
13221322 minor = max (0 , minor - 1 )
13231323 else :
13241324 # Decrement patch version
13251325 patch = max (0 , patch - 1 )
1326-
1326+
13271327 return f"{ major } .{ minor } .{ patch } "
13281328 except Exception :
13291329 return version_str
@@ -1335,55 +1335,60 @@ def _resolve_version_from_specifier(specifier_str: str) -> str:
13351335 Upper bounds take priority. If upper bound is <3.0.0, it's safe (V2 only).
13361336 If no upper bound exists, it's safe (unbounded).
13371337 If the decremented upper bound is less than a lower bound, use the lower bound.
1338-
1338+
13391339 Args:
13401340 specifier_str: Version specifier string (e.g., ">=2.256.0", "<2.256.0", "==2.255.0")
1341-
1341+
13421342 Returns:
13431343 The resolved version string to check, or None if safe
13441344 """
13451345 import re
13461346 from packaging import version as pkg_version
1347-
1347+
13481348 # Handle exact version pinning (==)
1349- match = re .search (r' ==\s*([\d.]+)' , specifier_str )
1349+ match = re .search (r" ==\s*([\d.]+)" , specifier_str )
13501350 if match :
13511351 return match .group (1 )
1352-
1352+
13531353 # Extract lower bounds for comparison
13541354 lower_bounds = []
1355- for match in re .finditer (r' >=\s*([\d.]+)' , specifier_str ):
1355+ for match in re .finditer (r" >=\s*([\d.]+)" , specifier_str ):
13561356 lower_bounds .append (match .group (1 ))
1357-
1357+
13581358 # Handle upper bounds - find the most restrictive one
13591359 upper_bounds = []
1360-
1360+
13611361 # Find all <= bounds
1362- for match in re .finditer (r' <=\s*([\d.]+)' , specifier_str ):
1363- upper_bounds .append (('<=' , match .group (1 )))
1364-
1362+ for match in re .finditer (r" <=\s*([\d.]+)" , specifier_str ):
1363+ upper_bounds .append (("<=" , match .group (1 )))
1364+
13651365 # Find all < bounds
1366- for match in re .finditer (r' <\s*([\d.]+)' , specifier_str ):
1367- upper_bounds .append (('<' , match .group (1 )))
1368-
1366+ for match in re .finditer (r" <\s*([\d.]+)" , specifier_str ):
1367+ upper_bounds .append (("<" , match .group (1 )))
1368+
13691369 if upper_bounds :
13701370 # Sort by version to find the most restrictive (lowest) upper bound
13711371 upper_bounds .sort (key = lambda x : pkg_version .parse (x [1 ]))
13721372 operator , version = upper_bounds [0 ]
1373-
1373+
13741374 # Special case: if upper bound is <3.0.0, it's safe (V2 only)
13751375 try :
13761376 parsed_upper = pkg_version .parse (version )
1377- if operator == '<' and parsed_upper .major == 3 and parsed_upper .minor == 0 and parsed_upper .micro == 0 :
1377+ if (
1378+ operator == "<"
1379+ and parsed_upper .major == 3
1380+ and parsed_upper .minor == 0
1381+ and parsed_upper .micro == 0
1382+ ):
13781383 # <3.0.0 means V2 only, which is safe
13791384 return None
13801385 except Exception :
13811386 pass
1382-
1387+
13831388 resolved_version = version
1384- if operator == '<' :
1389+ if operator == "<" :
13851390 resolved_version = _decrement_version (version )
1386-
1391+
13871392 # If we have a lower bound and the resolved version is less than it, use the lower bound
13881393 if lower_bounds :
13891394 try :
@@ -1394,9 +1399,9 @@ def _resolve_version_from_specifier(specifier_str: str) -> str:
13941399 resolved_version = lower_bound_str
13951400 except Exception :
13961401 pass
1397-
1402+
13981403 return resolved_version
1399-
1404+
14001405 # For lower bounds only (>=, >), we don't check
14011406 return None
14021407
@@ -1415,35 +1420,35 @@ def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None:
14151420 """
14161421 import re
14171422 from packaging import version as pkg_version
1418-
1419- match = re .search (r' sagemaker\s*(.+)$' , sagemaker_requirement .strip (), re .IGNORECASE )
1423+
1424+ match = re .search (r" sagemaker\s*(.+)$" , sagemaker_requirement .strip (), re .IGNORECASE )
14201425 if not match :
14211426 return
14221427
14231428 specifier_str = match .group (1 ).strip ()
1424-
1429+
14251430 # Resolve the version that would be installed
14261431 resolved_version_str = _resolve_version_from_specifier (specifier_str )
14271432 if not resolved_version_str :
14281433 # No upper bound or exact version, so we can't determine if it's bad
14291434 return
1430-
1435+
14311436 try :
14321437 resolved_version = pkg_version .parse (resolved_version_str )
14331438 except Exception :
14341439 return
1435-
1440+
14361441 # Define HMAC thresholds for each major version
14371442 v2_hmac_threshold = pkg_version .parse ("2.256.0" )
14381443 v3_hmac_threshold = pkg_version .parse ("3.2.0" )
1439-
1444+
14401445 # Check if the resolved version uses HMAC hashing
14411446 uses_hmac = False
14421447 if resolved_version .major == 2 and resolved_version < v2_hmac_threshold :
14431448 uses_hmac = True
14441449 elif resolved_version .major == 3 and resolved_version < v3_hmac_threshold :
14451450 uses_hmac = True
1446-
1451+
14471452 if uses_hmac :
14481453 raise ValueError (
14491454 f"The sagemaker version specified in requirements.txt ({ sagemaker_requirement } ) "
@@ -1453,7 +1458,6 @@ def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None:
14531458 )
14541459
14551460
1456-
14571461def _ensure_sagemaker_dependency (local_dependencies_path : str ) -> str :
14581462 """Ensure sagemaker>=2.256.0 is in the dependencies.
14591463
@@ -1481,13 +1485,14 @@ def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str:
14811485
14821486 if local_dependencies_path is None :
14831487 # Create a temporary requirements.txt in the system temp directory
1484- # This avoids overwriting any user files in their working directory
14851488 fd , req_file = tempfile .mkstemp (suffix = ".txt" , prefix = "sagemaker_requirements_" )
1486- os .close (fd ) # Close the file descriptor, we'll write to it ourselves
1489+ os .close (fd )
14871490
14881491 with open (req_file , "w" ) as f :
14891492 f .write (f"{ SAGEMAKER_MIN_VERSION } \n " )
1490- logger .info ("Created temporary requirements.txt at %s with %s" , req_file , SAGEMAKER_MIN_VERSION )
1493+ logger .info (
1494+ "Created temporary requirements.txt at %s with %s" , req_file , SAGEMAKER_MIN_VERSION
1495+ )
14911496 return req_file
14921497
14931498 # If dependencies provided, ensure sagemaker is included
@@ -1498,8 +1503,8 @@ def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str:
14981503 # Check if sagemaker is already specified
14991504 if "sagemaker" in content .lower ():
15001505 # Extract the sagemaker requirement line for compatibility check
1501- for line in content .split (' \n ' ):
1502- if ' sagemaker' in line .lower ():
1506+ for line in content .split (" \n " ):
1507+ if " sagemaker" in line .lower ():
15031508 _check_sagemaker_version_compatibility (line .strip ())
15041509 break
15051510 else :
0 commit comments