Skip to content

Commit 9475603

Browse files
committed
Fix codestyle issues
1 parent 2b870a3 commit 9475603

3 files changed

Lines changed: 113 additions & 101 deletions

File tree

src/sagemaker/remote_function/job.py

Lines changed: 45 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,32 +1298,32 @@ def _generate_input_data_config(job_settings: _JobSettings, s3_base_uri: str):
12981298

12991299
def _decrement_version(version_str: str) -> str:
13001300
"""Decrement a version string by one minor or patch version.
1301-
1301+
13021302
Rules:
13031303
- If patch version is 0 (e.g., 2.256.0), decrement minor: 2.256.0 -> 2.255.0
13041304
- If patch version is not 0 (e.g., 2.254.2), decrement patch: 2.254.2 -> 2.254.1
1305-
1305+
13061306
Args:
13071307
version_str: Version string (e.g., "2.256.0")
1308-
1308+
13091309
Returns:
13101310
Decremented version string
13111311
"""
13121312
from packaging import version as pkg_version
1313-
1313+
13141314
try:
13151315
parsed = pkg_version.parse(version_str)
13161316
major = parsed.major
13171317
minor = parsed.minor
13181318
patch = parsed.micro
1319-
1319+
13201320
if patch == 0:
13211321
# Decrement minor version
13221322
minor = max(0, minor - 1)
13231323
else:
13241324
# Decrement patch version
13251325
patch = max(0, patch - 1)
1326-
1326+
13271327
return f"{major}.{minor}.{patch}"
13281328
except Exception:
13291329
return version_str
@@ -1335,55 +1335,60 @@ def _resolve_version_from_specifier(specifier_str: str) -> str:
13351335
Upper bounds take priority. If upper bound is <3.0.0, it's safe (V2 only).
13361336
If no upper bound exists, it's safe (unbounded).
13371337
If the decremented upper bound is less than a lower bound, use the lower bound.
1338-
1338+
13391339
Args:
13401340
specifier_str: Version specifier string (e.g., ">=2.256.0", "<2.256.0", "==2.255.0")
1341-
1341+
13421342
Returns:
13431343
The resolved version string to check, or None if safe
13441344
"""
13451345
import re
13461346
from packaging import version as pkg_version
1347-
1347+
13481348
# Handle exact version pinning (==)
1349-
match = re.search(r'==\s*([\d.]+)', specifier_str)
1349+
match = re.search(r"==\s*([\d.]+)", specifier_str)
13501350
if match:
13511351
return match.group(1)
1352-
1352+
13531353
# Extract lower bounds for comparison
13541354
lower_bounds = []
1355-
for match in re.finditer(r'>=\s*([\d.]+)', specifier_str):
1355+
for match in re.finditer(r">=\s*([\d.]+)", specifier_str):
13561356
lower_bounds.append(match.group(1))
1357-
1357+
13581358
# Handle upper bounds - find the most restrictive one
13591359
upper_bounds = []
1360-
1360+
13611361
# Find all <= bounds
1362-
for match in re.finditer(r'<=\s*([\d.]+)', specifier_str):
1363-
upper_bounds.append(('<=', match.group(1)))
1364-
1362+
for match in re.finditer(r"<=\s*([\d.]+)", specifier_str):
1363+
upper_bounds.append(("<=", match.group(1)))
1364+
13651365
# Find all < bounds
1366-
for match in re.finditer(r'<\s*([\d.]+)', specifier_str):
1367-
upper_bounds.append(('<', match.group(1)))
1368-
1366+
for match in re.finditer(r"<\s*([\d.]+)", specifier_str):
1367+
upper_bounds.append(("<", match.group(1)))
1368+
13691369
if upper_bounds:
13701370
# Sort by version to find the most restrictive (lowest) upper bound
13711371
upper_bounds.sort(key=lambda x: pkg_version.parse(x[1]))
13721372
operator, version = upper_bounds[0]
1373-
1373+
13741374
# Special case: if upper bound is <3.0.0, it's safe (V2 only)
13751375
try:
13761376
parsed_upper = pkg_version.parse(version)
1377-
if operator == '<' and parsed_upper.major == 3 and parsed_upper.minor == 0 and parsed_upper.micro == 0:
1377+
if (
1378+
operator == "<"
1379+
and parsed_upper.major == 3
1380+
and parsed_upper.minor == 0
1381+
and parsed_upper.micro == 0
1382+
):
13781383
# <3.0.0 means V2 only, which is safe
13791384
return None
13801385
except Exception:
13811386
pass
1382-
1387+
13831388
resolved_version = version
1384-
if operator == '<':
1389+
if operator == "<":
13851390
resolved_version = _decrement_version(version)
1386-
1391+
13871392
# If we have a lower bound and the resolved version is less than it, use the lower bound
13881393
if lower_bounds:
13891394
try:
@@ -1394,9 +1399,9 @@ def _resolve_version_from_specifier(specifier_str: str) -> str:
13941399
resolved_version = lower_bound_str
13951400
except Exception:
13961401
pass
1397-
1402+
13981403
return resolved_version
1399-
1404+
14001405
# For lower bounds only (>=, >), we don't check
14011406
return None
14021407

@@ -1415,35 +1420,35 @@ def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None:
14151420
"""
14161421
import re
14171422
from packaging import version as pkg_version
1418-
1419-
match = re.search(r'sagemaker\s*(.+)$', sagemaker_requirement.strip(), re.IGNORECASE)
1423+
1424+
match = re.search(r"sagemaker\s*(.+)$", sagemaker_requirement.strip(), re.IGNORECASE)
14201425
if not match:
14211426
return
14221427

14231428
specifier_str = match.group(1).strip()
1424-
1429+
14251430
# Resolve the version that would be installed
14261431
resolved_version_str = _resolve_version_from_specifier(specifier_str)
14271432
if not resolved_version_str:
14281433
# No upper bound or exact version, so we can't determine if it's bad
14291434
return
1430-
1435+
14311436
try:
14321437
resolved_version = pkg_version.parse(resolved_version_str)
14331438
except Exception:
14341439
return
1435-
1440+
14361441
# Define HMAC thresholds for each major version
14371442
v2_hmac_threshold = pkg_version.parse("2.256.0")
14381443
v3_hmac_threshold = pkg_version.parse("3.2.0")
1439-
1444+
14401445
# Check if the resolved version uses HMAC hashing
14411446
uses_hmac = False
14421447
if resolved_version.major == 2 and resolved_version < v2_hmac_threshold:
14431448
uses_hmac = True
14441449
elif resolved_version.major == 3 and resolved_version < v3_hmac_threshold:
14451450
uses_hmac = True
1446-
1451+
14471452
if uses_hmac:
14481453
raise ValueError(
14491454
f"The sagemaker version specified in requirements.txt ({sagemaker_requirement}) "
@@ -1453,7 +1458,6 @@ def _check_sagemaker_version_compatibility(sagemaker_requirement: str) -> None:
14531458
)
14541459

14551460

1456-
14571461
def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str:
14581462
"""Ensure sagemaker>=2.256.0 is in the dependencies.
14591463
@@ -1481,13 +1485,14 @@ def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str:
14811485

14821486
if local_dependencies_path is None:
14831487
# Create a temporary requirements.txt in the system temp directory
1484-
# This avoids overwriting any user files in their working directory
14851488
fd, req_file = tempfile.mkstemp(suffix=".txt", prefix="sagemaker_requirements_")
1486-
os.close(fd) # Close the file descriptor, we'll write to it ourselves
1489+
os.close(fd)
14871490

14881491
with open(req_file, "w") as f:
14891492
f.write(f"{SAGEMAKER_MIN_VERSION}\n")
1490-
logger.info("Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION)
1493+
logger.info(
1494+
"Created temporary requirements.txt at %s with %s", req_file, SAGEMAKER_MIN_VERSION
1495+
)
14911496
return req_file
14921497

14931498
# If dependencies provided, ensure sagemaker is included
@@ -1498,8 +1503,8 @@ def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str:
14981503
# Check if sagemaker is already specified
14991504
if "sagemaker" in content.lower():
15001505
# Extract the sagemaker requirement line for compatibility check
1501-
for line in content.split('\n'):
1502-
if 'sagemaker' in line.lower():
1506+
for line in content.split("\n"):
1507+
if "sagemaker" in line.lower():
15031508
_check_sagemaker_version_compatibility(line.strip())
15041509
break
15051510
else:

tests/integ/sagemaker/remote_function/test_sagemaker_dependency_injection.py

Lines changed: 21 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,11 @@
1111

1212
# Skip decorator for AWS configuration
1313
skip_if_no_aws_region = pytest.mark.skipif(
14-
not os.environ.get('AWS_DEFAULT_REGION'),
15-
reason="AWS credentials not configured"
14+
not os.environ.get('AWS_DEFAULT_REGION'), reason="AWS credentials not configured"
1615
)
1716

1817
# Add src to path
19-
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../src'))
18+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../../../src"))
2019

2120
from sagemaker.remote_function import remote
2221

@@ -28,22 +27,23 @@ class TestRemoteFunctionDependencyInjection:
2827
@skip_if_no_aws_region
2928
def test_remote_function_without_dependencies(self):
3029
"""Test remote function execution without explicit dependencies.
31-
30+
3231
This test verifies that when no dependencies are provided, the remote
3332
function still executes successfully because sagemaker>=2.256.0 is
3433
automatically injected.
3534
"""
35+
3636
@remote(
3737
instance_type="ml.m5.large",
3838
# No dependencies specified - sagemaker should be injected automatically
3939
)
4040
def simple_add(x, y):
4141
"""Simple function that adds two numbers."""
4242
return x + y
43-
43+
4444
# Execute the function
4545
result = simple_add(5, 3)
46-
46+
4747
# Verify result
4848
assert result == 8, f"Expected 8, got {result}"
4949
print("✓ Remote function without dependencies executed successfully")
@@ -52,15 +52,15 @@ def simple_add(x, y):
5252
@skip_if_no_aws_region
5353
def test_remote_function_with_user_dependencies_no_sagemaker(self):
5454
"""Test remote function with user dependencies but no sagemaker.
55-
55+
5656
This test verifies that when user provides dependencies without sagemaker,
5757
sagemaker>=2.256.0 is automatically appended.
5858
"""
5959
# Create a temporary requirements.txt without sagemaker
60-
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
60+
with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as f:
6161
f.write("numpy>=1.20.0\npandas>=1.3.0\n")
6262
req_file = f.name
63-
63+
6464
try:
6565
@remote(
6666
instance_type="ml.m5.large",
@@ -69,11 +69,12 @@ def test_remote_function_with_user_dependencies_no_sagemaker(self):
6969
def compute_with_numpy(x):
7070
"""Function that uses numpy."""
7171
import numpy as np
72+
7273
return np.array([x, x*2, x*3]).sum()
73-
74+
7475
# Execute the function
7576
result = compute_with_numpy(5)
76-
77+
7778
# Verify result (5 + 10 + 15 = 30)
7879
assert result == 30, f"Expected 30, got {result}"
7980
print("✓ Remote function with user dependencies executed successfully")
@@ -88,22 +89,23 @@ class TestRemoteFunctionVersionCompatibility:
8889
@skip_if_no_aws_region
8990
def test_deserialization_with_injected_sagemaker(self):
9091
"""Test that deserialization works with injected sagemaker dependency.
91-
92+
9293
This test verifies that the remote environment can properly deserialize
9394
functions when sagemaker>=2.256.0 is available.
9495
"""
96+
9597
@remote(
9698
instance_type="ml.m5.large",
9799
)
98100
def complex_computation(data):
99101
"""Function that performs complex computation."""
100102
result = sum(data) * len(data)
101103
return result
102-
104+
103105
# Execute with various data types
104106
test_data = [1, 2, 3, 4, 5]
105107
result = complex_computation(test_data)
106-
108+
107109
# Verify result (sum=15, len=5, 15*5=75)
108110
assert result == 75, f"Expected 75, got {result}"
109111
print("✓ Deserialization with injected sagemaker works correctly")
@@ -112,22 +114,23 @@ def complex_computation(data):
112114
@skip_if_no_aws_region
113115
def test_multiple_remote_functions_with_dependencies(self):
114116
"""Test multiple remote functions with different dependency configurations.
115-
117+
116118
This test verifies that the dependency injection works correctly
117119
when multiple remote functions are defined and executed.
118120
"""
121+
119122
@remote(instance_type="ml.m5.large")
120123
def func1(x):
121124
return x + 1
122-
125+
123126
@remote(instance_type="ml.m5.large")
124127
def func2(x):
125128
return x * 2
126-
129+
127130
# Execute both functions
128131
result1 = func1(5)
129132
result2 = func2(5)
130-
133+
131134
assert result1 == 6, f"func1: Expected 6, got {result1}"
132135
assert result2 == 10, f"func2: Expected 10, got {result2}"
133136
print("✓ Multiple remote functions with dependencies executed successfully")

0 commit comments

Comments
 (0)