Skip to content

Commit c72a77d

Browse files
authored
Bug Fixes Model Customization (#5558)
* Bug Fixes Model Customization * fixes * Fix * Fix * Fix * Test Fixes
1 parent 88963f8 commit c72a77d

15 files changed

+6004
-444
lines changed

sagemaker-serve/debug.ipynb

Lines changed: 428 additions & 0 deletions
Large diffs are not rendered by default.

sagemaker-serve/src/sagemaker/serve/async_inference/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,6 @@
1616

1717
from sagemaker.core.inference_config import AsyncInferenceConfig # noqa: F401
1818
from sagemaker.serve.async_inference.waiter_config import WaiterConfig # noqa: F401
19-
from sagemaker.serve.async_inference.async_inference_response import AsyncInferenceResponse # noqa: F401
19+
from sagemaker.serve.async_inference.async_inference_response import (
20+
AsyncInferenceResponse,
21+
) # noqa: F401

sagemaker-serve/src/sagemaker/serve/async_inference/async_inference_config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
" from sagemaker.core.inference_config import AsyncInferenceConfig\n"
3434
"This compatibility shim will be removed in a future version.",
3535
DeprecationWarning,
36-
stacklevel=2
36+
stacklevel=2,
3737
)
3838

39-
__all__ = ['AsyncInferenceConfig']
39+
__all__ = ["AsyncInferenceConfig"]

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 1055 additions & 437 deletions
Large diffs are not rendered by default.

sagemaker-serve/test_script.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
print("Hello from processing script!")

sagemaker-serve/tests/unit/test_artifact_path_propagation.py

Lines changed: 408 additions & 0 deletions
Large diffs are not rendered by default.

sagemaker-serve/tests/unit/test_artifact_path_resolution.py

Lines changed: 485 additions & 0 deletions
Large diffs are not rendered by default.

sagemaker-serve/tests/unit/test_compute_requirements_resolution.py

Lines changed: 984 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
"""
2+
Test to verify that deploy() method passes inference_config to _deploy_model_customization.
3+
This test validates task 4.4 requirements.
4+
"""
5+
6+
import unittest
7+
from unittest.mock import Mock, patch, MagicMock
8+
import pytest
9+
10+
from sagemaker.serve.model_builder import ModelBuilder
11+
from sagemaker.serve.mode.function_pointers import Mode
12+
from sagemaker.core.inference_config import ResourceRequirements # Correct import!
13+
14+
15+
class TestDeployPassesInferenceConfig(unittest.TestCase):
    """Test that deploy() correctly passes inference_config to _deploy_model_customization.

    Validates task 4.4 requirements: an explicit ``ResourceRequirements`` is
    forwarded, omitting it forwards ``None`` (backward compatibility), and a
    non-``ResourceRequirements`` config (e.g. serverless) is NOT forwarded.
    """

    def setUp(self):
        """Set up a mock SageMaker session with the attributes ModelBuilder reads."""
        self.mock_session = Mock()
        self.mock_session.boto_region_name = "us-west-2"
        self.mock_session.default_bucket.return_value = "test-bucket"
        self.mock_session.default_bucket_prefix = "test-prefix"
        self.mock_session.config = {}
        self.mock_session.sagemaker_config = {}
        self.mock_session.settings = Mock()
        self.mock_session.settings.include_jumpstart_tags = False

        mock_credentials = Mock()
        mock_credentials.access_key = "test-key"
        mock_credentials.secret_key = "test-secret"
        mock_credentials.token = None
        self.mock_session.boto_session = Mock()
        self.mock_session.boto_session.get_credentials.return_value = mock_credentials
        self.mock_session.boto_session.region_name = "us-west-2"

    def _make_built_builder(self):
        """Create a ModelBuilder configured for model customization, marked as built.

        Extracted helper: the identical construction was previously duplicated
        in all three test methods.
        """
        builder = ModelBuilder(
            model="huggingface-llm-mistral-7b",
            model_metadata={
                "CUSTOM_MODEL_ID": "huggingface-llm-mistral-7b",
                "CUSTOM_MODEL_VERSION": "1.0.0",
            },
            instance_type="ml.g5.12xlarge",
            mode=Mode.SAGEMAKER_ENDPOINT,
            role_arn="arn:aws:iam::123456789012:role/TestRole",
            sagemaker_session=self.mock_session,
            image_uri="123456789012.dkr.ecr.us-west-2.amazonaws.com/test:latest",
        )
        # Mark as built so deploy() skips the build step.
        builder.built_model = Mock()
        return builder

    @patch("sagemaker.serve.model_builder.ModelBuilder._deploy_model_customization")
    @patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
    @patch(
        "sagemaker.serve.model_builder.ModelBuilder._fetch_default_instance_type_for_custom_model"
    )
    def test_deploy_passes_inference_config_to_deploy_model_customization(
        self,
        mock_fetch_default_instance,
        mock_is_model_customization,
        mock_deploy_model_customization,
    ):
        """Test that deploy() passes inference_config parameter to _deploy_model_customization."""
        # Setup: route deploy() down the model-customization path.
        mock_is_model_customization.return_value = True
        mock_fetch_default_instance.return_value = "ml.g5.12xlarge"
        mock_endpoint = Mock()
        mock_deploy_model_customization.return_value = mock_endpoint

        builder = self._make_built_builder()

        inference_config = ResourceRequirements(
            requests={"num_cpus": 8, "memory": 16384, "num_accelerators": 4}
        )

        # Execute: call deploy() with an explicit inference_config.
        result = builder.deploy(
            endpoint_name="test-endpoint",
            inference_config=inference_config,
            initial_instance_count=1,
            wait=True,
        )

        # Verify: _deploy_model_customization received the config unchanged.
        assert mock_deploy_model_customization.called
        call_kwargs = mock_deploy_model_customization.call_args[1]
        assert "inference_config" in call_kwargs
        assert call_kwargs["inference_config"] == inference_config

        # Verify the other deploy() parameters were forwarded too.
        assert call_kwargs["endpoint_name"] == "test-endpoint"
        assert call_kwargs["initial_instance_count"] == 1
        # Fixed: identity check (`is True`) instead of `== True` (PEP 8 / E712).
        assert call_kwargs["wait"] is True

        # Verify deploy() returns what _deploy_model_customization produced.
        assert result == mock_endpoint

    @patch("sagemaker.serve.model_builder.ModelBuilder._deploy_model_customization")
    @patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
    @patch(
        "sagemaker.serve.model_builder.ModelBuilder._fetch_default_instance_type_for_custom_model"
    )
    def test_deploy_passes_none_when_inference_config_not_provided(
        self,
        mock_fetch_default_instance,
        mock_is_model_customization,
        mock_deploy_model_customization,
    ):
        """Test backward compatibility: deploy() passes None when inference_config not provided."""
        # Setup
        mock_is_model_customization.return_value = True
        mock_fetch_default_instance.return_value = "ml.g5.12xlarge"
        mock_endpoint = Mock()
        mock_deploy_model_customization.return_value = mock_endpoint

        builder = self._make_built_builder()

        # Execute: call deploy() WITHOUT inference_config.
        result = builder.deploy(endpoint_name="test-endpoint", initial_instance_count=1)

        # Verify: _deploy_model_customization was called with inference_config=None.
        assert mock_deploy_model_customization.called
        call_kwargs = mock_deploy_model_customization.call_args[1]
        assert "inference_config" in call_kwargs
        assert call_kwargs["inference_config"] is None

        # Fixed: previously `result` was assigned but never checked.
        assert result == mock_endpoint

    @patch("sagemaker.serve.model_builder.ModelBuilder._deploy_model_customization")
    @patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
    @patch(
        "sagemaker.serve.model_builder.ModelBuilder._fetch_default_instance_type_for_custom_model"
    )
    def test_deploy_only_passes_resource_requirements_type(
        self,
        mock_fetch_default_instance,
        mock_is_model_customization,
        mock_deploy_model_customization,
    ):
        """Test that deploy() only passes inference_config if it's ResourceRequirements type."""
        # Setup
        mock_is_model_customization.return_value = True
        mock_fetch_default_instance.return_value = "ml.g5.12xlarge"
        mock_endpoint = Mock()
        mock_deploy_model_customization.return_value = mock_endpoint

        builder = self._make_built_builder()

        # A non-ResourceRequirements inference_config (serverless) must NOT be
        # forwarded to _deploy_model_customization.
        from sagemaker.core.inference_config import ServerlessInferenceConfig

        serverless_config = ServerlessInferenceConfig(memory_size_in_mb=4096, max_concurrency=10)

        result = builder.deploy(endpoint_name="test-endpoint", inference_config=serverless_config)

        # Verify: forwarded inference_config is None because the supplied
        # config is not a ResourceRequirements instance.
        assert mock_deploy_model_customization.called
        call_kwargs = mock_deploy_model_customization.call_args[1]
        assert "inference_config" in call_kwargs
        assert call_kwargs["inference_config"] is None

        # Fixed: previously `result` was assigned but never checked.
        assert result == mock_endpoint
197+
198+
199+
# Allow running this test module directly with `python <file>`.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)