99import tempfile
1010import pytest
1111
12- # Skip decorator for AWS configuration
13- # skip_if_no_aws_region = pytest.mark.skipif(
14- # not os.environ.get('AWS_DEFAULT_REGION'),
15- # reason="AWS credentials not configured"
16- # )
17-
1812# Add src to path
1913sys .path .insert (0 , os .path .join (os .path .dirname (__file__ ), '../../../../src' ))
2014
@@ -25,58 +19,47 @@ class TestRemoteFunctionDependencyInjection:
2519 """Integration tests for dependency injection in remote functions."""
2620
2721 @pytest .mark .integ
28- # @skip_if_no_aws_region
29- def test_remote_function_without_dependencies (self ):
30- """Test remote function execution without explicit dependencies.
31-
32- This test verifies that when no dependencies are provided, the remote
33- function still executes successfully because sagemaker>=3.2.0 is
34- automatically injected.
35- """
22+ def test_remote_function_without_dependencies (
23+ self , dev_sdk_pre_execution_commands , role , image_uri , sagemaker_session
24+ ):
25+ """Test remote function execution without explicit dependencies."""
3626 @remote (
3727 instance_type = "ml.m5.large" ,
38- # No dependencies specified - sagemaker should be injected automatically
28+ role = role ,
29+ image_uri = image_uri ,
30+ sagemaker_session = sagemaker_session ,
31+ pre_execution_commands = dev_sdk_pre_execution_commands ,
3932 )
4033 def simple_add (x , y ):
41- """Simple function that adds two numbers."""
4234 return x + y
43-
44- # Execute the function
35+
4536 result = simple_add (5 , 3 )
46-
47- # Verify result
4837 assert result == 8 , f"Expected 8, got { result } "
49- print ("✓ Remote function without dependencies executed successfully" )
5038
5139 @pytest .mark .integ
52- # @skip_if_no_aws_region
53- def test_remote_function_with_user_dependencies_no_sagemaker (self ):
54- """Test remote function with user dependencies but no sagemaker.
55-
56- This test verifies that when user provides dependencies without sagemaker,
57- sagemaker>=3.2.0 is automatically appended.
58- """
59- # Create a temporary requirements.txt without sagemaker
40+ def test_remote_function_with_user_dependencies_no_sagemaker (
41+ self , dev_sdk_pre_execution_commands , role , image_uri , sagemaker_session
42+ ):
43+ """Test remote function with user dependencies but no sagemaker."""
6044 with tempfile .NamedTemporaryFile (mode = 'w' , suffix = '.txt' , delete = False ) as f :
6145 f .write ("numpy>=1.20.0\n pandas>=1.3.0\n " )
6246 req_file = f .name
63-
47+
6448 try :
6549 @remote (
6650 instance_type = "ml.m5.large" ,
51+ role = role ,
52+ image_uri = image_uri ,
53+ sagemaker_session = sagemaker_session ,
6754 dependencies = req_file ,
55+ pre_execution_commands = dev_sdk_pre_execution_commands ,
6856 )
6957 def compute_with_numpy (x ):
70- """Function that uses numpy."""
7158 import numpy as np
7259 return np .array ([x , x * 2 , x * 3 ]).sum ()
73-
74- # Execute the function
60+
7561 result = compute_with_numpy (5 )
76-
77- # Verify result (5 + 10 + 15 = 30)
7862 assert result == 30 , f"Expected 30, got { result } "
79- print ("✓ Remote function with user dependencies executed successfully" )
8063 finally :
8164 os .remove (req_file )
8265
@@ -85,52 +68,55 @@ class TestRemoteFunctionVersionCompatibility:
8568 """Tests for version compatibility between local and remote environments."""
8669
8770 @pytest .mark .integ
88- # @skip_if_no_aws_region
89- def test_deserialization_with_injected_sagemaker (self ):
90- """Test that deserialization works with injected sagemaker dependency.
91-
92- This test verifies that the remote environment can properly deserialize
93- functions when sagemaker>=3.2.0 is available.
94- """
71+ def test_deserialization_with_injected_sagemaker (
72+ self , dev_sdk_pre_execution_commands , role , image_uri , sagemaker_session
73+ ):
74+ """Test that deserialization works with injected sagemaker dependency."""
9575 @remote (
9676 instance_type = "ml.m5.large" ,
77+ role = role ,
78+ image_uri = image_uri ,
79+ sagemaker_session = sagemaker_session ,
80+ pre_execution_commands = dev_sdk_pre_execution_commands ,
9781 )
9882 def complex_computation (data ):
99- """Function that performs complex computation."""
10083 result = sum (data ) * len (data )
10184 return result
102-
103- # Execute with various data types
85+
10486 test_data = [1 , 2 , 3 , 4 , 5 ]
10587 result = complex_computation (test_data )
106-
107- # Verify result (sum=15, len=5, 15*5=75)
10888 assert result == 75 , f"Expected 75, got { result } "
109- print ("✓ Deserialization with injected sagemaker works correctly" )
11089
11190 @pytest .mark .integ
112- # @skip_if_no_aws_region
113- def test_multiple_remote_functions_with_dependencies (self ):
114- """Test multiple remote functions with different dependency configurations.
115-
116- This test verifies that the dependency injection works correctly
117- when multiple remote functions are defined and executed.
118- """
119- @remote (instance_type = "ml.m5.large" )
91+ def test_multiple_remote_functions_with_dependencies (
92+ self , dev_sdk_pre_execution_commands , role , image_uri , sagemaker_session
93+ ):
94+ """Test multiple remote functions with different dependency configurations."""
95+ @remote (
96+ instance_type = "ml.m5.large" ,
97+ role = role ,
98+ image_uri = image_uri ,
99+ sagemaker_session = sagemaker_session ,
100+ pre_execution_commands = dev_sdk_pre_execution_commands ,
101+ )
120102 def func1 (x ):
121103 return x + 1
122-
123- @remote (instance_type = "ml.m5.large" )
104+
105+ @remote (
106+ instance_type = "ml.m5.large" ,
107+ role = role ,
108+ image_uri = image_uri ,
109+ sagemaker_session = sagemaker_session ,
110+ pre_execution_commands = dev_sdk_pre_execution_commands ,
111+ )
124112 def func2 (x ):
125113 return x * 2
126-
127- # Execute both functions
114+
128115 result1 = func1 (5 )
129116 result2 = func2 (5 )
130-
117+
131118 assert result1 == 6 , f"func1: Expected 6, got { result1 } "
132119 assert result2 == 10 , f"func2: Expected 10, got { result2 } "
133- print ("✓ Multiple remote functions with dependencies executed successfully" )
134120
135121
136122if __name__ == "__main__" :
0 commit comments