Skip to content

Commit 90f1f66

Browse files
committed
Add wheel install for remote function integ test to use local code
1 parent 08c6146 commit 90f1f66

File tree

2 files changed

+119
-63
lines changed

2 files changed

+119
-63
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""Shared fixtures for remote function integration tests."""
2+
3+
import glob
4+
import os
5+
import subprocess
6+
import tempfile
7+
8+
import cloudpickle
9+
import pytest
10+
11+
from sagemaker.core.helper.session_helper import Session
12+
from sagemaker.core.s3 import S3Uploader
13+
14+
15+
def _get_repo_root():
16+
return os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
17+
18+
19+
def _build_and_upload_core_wheel(sagemaker_session):
20+
"""Build sagemaker-core wheel and upload to S3. Returns (s3_prefix, wheel_basename)."""
21+
repo_root = _get_repo_root()
22+
dist_dir = tempfile.mkdtemp(prefix="sagemaker_core_wheel_")
23+
24+
subprocess.run(
25+
f"python -m build --wheel --outdir {dist_dir}",
26+
shell=True,
27+
cwd=os.path.join(repo_root, "sagemaker-core"),
28+
check=True,
29+
)
30+
31+
matches = glob.glob(os.path.join(dist_dir, "sagemaker_core-*.whl"))
32+
if not matches:
33+
raise FileNotFoundError(f"No sagemaker-core wheel found in {dist_dir}")
34+
wheel_path = matches[0]
35+
36+
s3_prefix = f"s3://{sagemaker_session.default_bucket()}/remote-function-test/wheels"
37+
S3Uploader.upload(wheel_path, s3_prefix, sagemaker_session=sagemaker_session)
38+
39+
return s3_prefix, os.path.basename(wheel_path)
40+
41+
42+
@pytest.fixture(scope="module")
43+
def sagemaker_session():
44+
import boto3
45+
return Session(boto3.Session())
46+
47+
48+
@pytest.fixture(scope="module")
49+
def role(sagemaker_session):
50+
import boto3
51+
account_id = boto3.client("sts").get_caller_identity()["Account"]
52+
return f"arn:aws:iam::{account_id}:role/Admin"
53+
54+
55+
@pytest.fixture(scope="module")
56+
def image_uri(sagemaker_session):
57+
region = sagemaker_session.boto_region_name
58+
return f"763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
59+
60+
61+
@pytest.fixture(scope="module")
62+
def dev_sdk_pre_execution_commands(sagemaker_session):
63+
"""Build dev sagemaker-core wheel, upload to S3, and return pre_execution_commands."""
64+
s3_prefix, wheel_name = _build_and_upload_core_wheel(sagemaker_session)
65+
cp_version = cloudpickle.__version__
66+
return [
67+
f"pip install cloudpickle=={cp_version}",
68+
f"aws s3 cp {s3_prefix}/{wheel_name} /tmp/{wheel_name}",
69+
f"pip install /tmp/{wheel_name}",
70+
]

sagemaker-core/tests/integ/remote_function/test_sagemaker_dependency_injection.py

Lines changed: 49 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,6 @@
99
import tempfile
1010
import pytest
1111

12-
# Skip decorator for AWS configuration
13-
# skip_if_no_aws_region = pytest.mark.skipif(
14-
# not os.environ.get('AWS_DEFAULT_REGION'),
15-
# reason="AWS credentials not configured"
16-
# )
17-
1812
# Add src to path
1913
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../src'))
2014

@@ -25,58 +19,47 @@ class TestRemoteFunctionDependencyInjection:
2519
"""Integration tests for dependency injection in remote functions."""
2620

2721
@pytest.mark.integ
28-
# @skip_if_no_aws_region
29-
def test_remote_function_without_dependencies(self):
30-
"""Test remote function execution without explicit dependencies.
31-
32-
This test verifies that when no dependencies are provided, the remote
33-
function still executes successfully because sagemaker>=3.2.0 is
34-
automatically injected.
35-
"""
22+
def test_remote_function_without_dependencies(
23+
self, dev_sdk_pre_execution_commands, role, image_uri, sagemaker_session
24+
):
25+
"""Test remote function execution without explicit dependencies."""
3626
@remote(
3727
instance_type="ml.m5.large",
38-
# No dependencies specified - sagemaker should be injected automatically
28+
role=role,
29+
image_uri=image_uri,
30+
sagemaker_session=sagemaker_session,
31+
pre_execution_commands=dev_sdk_pre_execution_commands,
3932
)
4033
def simple_add(x, y):
41-
"""Simple function that adds two numbers."""
4234
return x + y
43-
44-
# Execute the function
35+
4536
result = simple_add(5, 3)
46-
47-
# Verify result
4837
assert result == 8, f"Expected 8, got {result}"
49-
print("✓ Remote function without dependencies executed successfully")
5038

5139
@pytest.mark.integ
52-
# @skip_if_no_aws_region
53-
def test_remote_function_with_user_dependencies_no_sagemaker(self):
54-
"""Test remote function with user dependencies but no sagemaker.
55-
56-
This test verifies that when user provides dependencies without sagemaker,
57-
sagemaker>=3.2.0 is automatically appended.
58-
"""
59-
# Create a temporary requirements.txt without sagemaker
40+
def test_remote_function_with_user_dependencies_no_sagemaker(
41+
self, dev_sdk_pre_execution_commands, role, image_uri, sagemaker_session
42+
):
43+
"""Test remote function with user dependencies but no sagemaker."""
6044
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
6145
f.write("numpy>=1.20.0\npandas>=1.3.0\n")
6246
req_file = f.name
63-
47+
6448
try:
6549
@remote(
6650
instance_type="ml.m5.large",
51+
role=role,
52+
image_uri=image_uri,
53+
sagemaker_session=sagemaker_session,
6754
dependencies=req_file,
55+
pre_execution_commands=dev_sdk_pre_execution_commands,
6856
)
6957
def compute_with_numpy(x):
70-
"""Function that uses numpy."""
7158
import numpy as np
7259
return np.array([x, x*2, x*3]).sum()
73-
74-
# Execute the function
60+
7561
result = compute_with_numpy(5)
76-
77-
# Verify result (5 + 10 + 15 = 30)
7862
assert result == 30, f"Expected 30, got {result}"
79-
print("✓ Remote function with user dependencies executed successfully")
8063
finally:
8164
os.remove(req_file)
8265

@@ -85,52 +68,55 @@ class TestRemoteFunctionVersionCompatibility:
8568
"""Tests for version compatibility between local and remote environments."""
8669

8770
@pytest.mark.integ
88-
# @skip_if_no_aws_region
89-
def test_deserialization_with_injected_sagemaker(self):
90-
"""Test that deserialization works with injected sagemaker dependency.
91-
92-
This test verifies that the remote environment can properly deserialize
93-
functions when sagemaker>=3.2.0 is available.
94-
"""
71+
def test_deserialization_with_injected_sagemaker(
72+
self, dev_sdk_pre_execution_commands, role, image_uri, sagemaker_session
73+
):
74+
"""Test that deserialization works with injected sagemaker dependency."""
9575
@remote(
9676
instance_type="ml.m5.large",
77+
role=role,
78+
image_uri=image_uri,
79+
sagemaker_session=sagemaker_session,
80+
pre_execution_commands=dev_sdk_pre_execution_commands,
9781
)
9882
def complex_computation(data):
99-
"""Function that performs complex computation."""
10083
result = sum(data) * len(data)
10184
return result
102-
103-
# Execute with various data types
85+
10486
test_data = [1, 2, 3, 4, 5]
10587
result = complex_computation(test_data)
106-
107-
# Verify result (sum=15, len=5, 15*5=75)
10888
assert result == 75, f"Expected 75, got {result}"
109-
print("✓ Deserialization with injected sagemaker works correctly")
11089

11190
@pytest.mark.integ
112-
# @skip_if_no_aws_region
113-
def test_multiple_remote_functions_with_dependencies(self):
114-
"""Test multiple remote functions with different dependency configurations.
115-
116-
This test verifies that the dependency injection works correctly
117-
when multiple remote functions are defined and executed.
118-
"""
119-
@remote(instance_type="ml.m5.large")
91+
def test_multiple_remote_functions_with_dependencies(
92+
self, dev_sdk_pre_execution_commands, role, image_uri, sagemaker_session
93+
):
94+
"""Test multiple remote functions with different dependency configurations."""
95+
@remote(
96+
instance_type="ml.m5.large",
97+
role=role,
98+
image_uri=image_uri,
99+
sagemaker_session=sagemaker_session,
100+
pre_execution_commands=dev_sdk_pre_execution_commands,
101+
)
120102
def func1(x):
121103
return x + 1
122-
123-
@remote(instance_type="ml.m5.large")
104+
105+
@remote(
106+
instance_type="ml.m5.large",
107+
role=role,
108+
image_uri=image_uri,
109+
sagemaker_session=sagemaker_session,
110+
pre_execution_commands=dev_sdk_pre_execution_commands,
111+
)
124112
def func2(x):
125113
return x * 2
126-
127-
# Execute both functions
114+
128115
result1 = func1(5)
129116
result2 = func2(5)
130-
117+
131118
assert result1 == 6, f"func1: Expected 6, got {result1}"
132119
assert result2 == 10, f"func2: Expected 10, got {result2}"
133-
print("✓ Multiple remote functions with dependencies executed successfully")
134120

135121

136122
if __name__ == "__main__":

0 commit comments

Comments
 (0)