Skip to content

Commit b1a6d54

Browse files
committed
Adding unit and integ tests with minor bug fix to rf script
1 parent ea789a4 commit b1a6d54

3 files changed

Lines changed: 468 additions & 13 deletions

File tree

sagemaker-core/src/sagemaker/core/remote_function/job.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -175,12 +175,12 @@
175175
fi
176176
177177
printf "INFO: Invoking remote function inside conda environment: $conda_env.\\n"
178-
printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function \\n"
179-
$conda_exe run -n $conda_env python -m sagemaker.train.remote_function.invoke_function "$@"
178+
printf "INFO: $conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function \\n"
179+
$conda_exe run -n $conda_env python -m sagemaker.core.remote_function.invoke_function "$@"
180180
else
181181
printf "INFO: No conda env provided. Invoking remote function\\n"
182-
printf "INFO: python -m sagemaker.train.remote_function.invoke_function \\n"
183-
python -m sagemaker.train.remote_function.invoke_function "$@"
182+
printf "INFO: python -m sagemaker.core.remote_function.invoke_function \\n"
183+
python -m sagemaker.core.remote_function.invoke_function "$@"
184184
fi
185185
"""
186186

@@ -234,14 +234,14 @@
234234
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
235235
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
236236
237-
python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n"
237+
python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n"
238238
$conda_exe run -n $conda_env mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
239239
--allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
240240
-mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
241241
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
242242
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
243243
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
244-
python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@"
244+
python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@"
245245
246246
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
247247
else
@@ -259,15 +259,15 @@
259259
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
260260
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
261261
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
262-
python -m mpi4py -m sagemaker.train.remote_function.invoke_function \\n"
262+
python -m mpi4py -m sagemaker.core.remote_function.invoke_function \\n"
263263
264264
mpirun --host $SM_HOSTS_LIST -np $SM_NPROC_PER_NODE \
265265
--allow-run-as-root --display-map --tag-output -mca btl_tcp_if_include $SM_NETWORK_INTERFACE_NAME \
266266
-mca plm_rsh_no_tree_spawn 1 -mca pml ob1 -mca btl ^openib -mca orte_abort_on_non_zero_status 1 \
267267
-mca btl_vader_single_copy_mechanism none -mca plm_rsh_num_concurrent $SM_HOST_COUNT \
268268
-x NCCL_SOCKET_IFNAME=$SM_NETWORK_INTERFACE_NAME -x LD_LIBRARY_PATH -x PATH \
269269
$SM_FI_PROVIDER $SM_NCCL_PROTO $SM_FI_EFA_USE_DEVICE_RDMA \
270-
python -m mpi4py -m sagemaker.train.remote_function.invoke_function "$@"
270+
python -m mpi4py -m sagemaker.core.remote_function.invoke_function "$@"
271271
272272
python /opt/ml/input/data/{RUNTIME_SCRIPTS_CHANNEL_NAME}/{MPI_UTILS_SCRIPT_NAME} --job_ended 1
273273
else
@@ -320,18 +320,18 @@
320320
printf "INFO: Invoking remote function with torchrun inside conda environment: $conda_env.\\n"
321321
printf "INFO: $conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
322322
--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
323-
-m sagemaker.train.remote_function.invoke_function \\n"
323+
-m sagemaker.core.remote_function.invoke_function \\n"
324324
325325
$conda_exe run -n $conda_env torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE \
326326
--master_addr $SM_MASTER_ADDR --master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK \
327-
-m sagemaker.train.remote_function.invoke_function "$@"
327+
-m sagemaker.core.remote_function.invoke_function "$@"
328328
else
329329
printf "INFO: No conda env provided. Invoking remote function with torchrun\\n"
330330
printf "INFO: torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
331-
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function \\n"
331+
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function \\n"
332332
333333
torchrun --nnodes $SM_HOST_COUNT --nproc_per_node $SM_NPROC_PER_NODE --master_addr $SM_MASTER_ADDR \
334-
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.train.remote_function.invoke_function "$@"
334+
--master_port $SM_MASTER_PORT --node_rank $SM_CURRENT_HOST_RANK -m sagemaker.core.remote_function.invoke_function "$@"
335335
fi
336336
"""
337337

@@ -1467,7 +1467,7 @@ def _ensure_sagemaker_dependency(local_dependencies_path: str) -> str:
14671467
return local_dependencies_path
14681468

14691469

1470-
1470+
def _generate_input_data_config(job_settings, s3_base_uri):
14711471
"""Generates input data config"""
14721472
from sagemaker.core.workflow.utilities import load_step_compilation_context
14731473

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
"""Integration tests for sagemaker dependency injection in remote functions.
2+
3+
These tests verify that the sagemaker>=3.2.0 dependency is properly injected
4+
into remote function jobs, preventing version mismatch issues.
5+
"""
6+
7+
import os
8+
import sys
9+
import tempfile
10+
import pytest
11+
12+
# Skip decorator for AWS configuration
13+
# skip_if_no_aws_region = pytest.mark.skipif(
14+
# not os.environ.get('AWS_DEFAULT_REGION'),
15+
# reason="AWS credentials not configured"
16+
# )
17+
18+
# Add src to path
19+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../../../src'))
20+
21+
from sagemaker.core.remote_function import remote
22+
23+
24+
class TestRemoteFunctionDependencyInjection:
25+
"""Integration tests for dependency injection in remote functions."""
26+
27+
@pytest.mark.integ
28+
# @skip_if_no_aws_region
29+
def test_remote_function_without_dependencies(self):
30+
"""Test remote function execution without explicit dependencies.
31+
32+
This test verifies that when no dependencies are provided, the remote
33+
function still executes successfully because sagemaker>=3.2.0 is
34+
automatically injected.
35+
"""
36+
@remote(
37+
instance_type="ml.m5.large",
38+
# No dependencies specified - sagemaker should be injected automatically
39+
)
40+
def simple_add(x, y):
41+
"""Simple function that adds two numbers."""
42+
return x + y
43+
44+
# Execute the function
45+
result = simple_add(5, 3)
46+
47+
# Verify result
48+
assert result == 8, f"Expected 8, got {result}"
49+
print("✓ Remote function without dependencies executed successfully")
50+
51+
@pytest.mark.integ
52+
# @skip_if_no_aws_region
53+
def test_remote_function_with_user_dependencies_no_sagemaker(self):
54+
"""Test remote function with user dependencies but no sagemaker.
55+
56+
This test verifies that when user provides dependencies without sagemaker,
57+
sagemaker>=3.2.0 is automatically appended.
58+
"""
59+
# Create a temporary requirements.txt without sagemaker
60+
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False) as f:
61+
f.write("numpy>=1.20.0\npandas>=1.3.0\n")
62+
req_file = f.name
63+
64+
try:
65+
@remote(
66+
instance_type="ml.m5.large",
67+
dependencies=req_file,
68+
)
69+
def compute_with_numpy(x):
70+
"""Function that uses numpy."""
71+
import numpy as np
72+
return np.array([x, x*2, x*3]).sum()
73+
74+
# Execute the function
75+
result = compute_with_numpy(5)
76+
77+
# Verify result (5 + 10 + 15 = 30)
78+
assert result == 30, f"Expected 30, got {result}"
79+
print("✓ Remote function with user dependencies executed successfully")
80+
finally:
81+
os.remove(req_file)
82+
83+
84+
class TestRemoteFunctionVersionCompatibility:
85+
"""Tests for version compatibility between local and remote environments."""
86+
87+
@pytest.mark.integ
88+
# @skip_if_no_aws_region
89+
def test_deserialization_with_injected_sagemaker(self):
90+
"""Test that deserialization works with injected sagemaker dependency.
91+
92+
This test verifies that the remote environment can properly deserialize
93+
functions when sagemaker>=3.2.0 is available.
94+
"""
95+
@remote(
96+
instance_type="ml.m5.large",
97+
)
98+
def complex_computation(data):
99+
"""Function that performs complex computation."""
100+
result = sum(data) * len(data)
101+
return result
102+
103+
# Execute with various data types
104+
test_data = [1, 2, 3, 4, 5]
105+
result = complex_computation(test_data)
106+
107+
# Verify result (sum=15, len=5, 15*5=75)
108+
assert result == 75, f"Expected 75, got {result}"
109+
print("✓ Deserialization with injected sagemaker works correctly")
110+
111+
@pytest.mark.integ
112+
# @skip_if_no_aws_region
113+
def test_multiple_remote_functions_with_dependencies(self):
114+
"""Test multiple remote functions with different dependency configurations.
115+
116+
This test verifies that the dependency injection works correctly
117+
when multiple remote functions are defined and executed.
118+
"""
119+
@remote(instance_type="ml.m5.large")
120+
def func1(x):
121+
return x + 1
122+
123+
@remote(instance_type="ml.m5.large")
124+
def func2(x):
125+
return x * 2
126+
127+
# Execute both functions
128+
result1 = func1(5)
129+
result2 = func2(5)
130+
131+
assert result1 == 6, f"func1: Expected 6, got {result1}"
132+
assert result2 == 10, f"func2: Expected 10, got {result2}"
133+
print("✓ Multiple remote functions with dependencies executed successfully")
134+
135+
136+
if __name__ == "__main__":
137+
pytest.main([__file__, "-v", "-m", "integ"])

0 commit comments

Comments
 (0)