Skip to content

Commit ccc3425

Browse files
committed
fix: address review comments (iteration #2)
1 parent 8bc8db3 commit ccc3425

File tree

3 files changed

+253
-18
lines changed

3 files changed

+253
-18
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2986,9 +2986,9 @@ def _deploy_core_endpoint(self, **kwargs):
29862986
if ic_data_cache_config is not None:
29872987
resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config)
29882988
if resolved_cache_config is not None:
2989-
inference_component_spec["DataCacheConfig"] = {
2990-
"EnableCaching": resolved_cache_config.enable_caching
2991-
}
2989+
cache_dict = {"EnableCaching": resolved_cache_config.enable_caching}
2990+
# Forward any additional fields from the shape as they become available
2991+
inference_component_spec["DataCacheConfig"] = cache_dict
29922992

29932993
ic_base_component_name = kwargs.get("base_inference_component_name")
29942994
if ic_base_component_name is not None:
@@ -2999,11 +2999,11 @@ def _deploy_core_endpoint(self, **kwargs):
29992999
resolved_container = self._resolve_container_spec(ic_container)
30003000
if resolved_container is not None:
30013001
container_dict = {}
3002-
if hasattr(resolved_container, "image") and resolved_container.image:
3002+
if resolved_container.image:
30033003
container_dict["Image"] = resolved_container.image
3004-
if hasattr(resolved_container, "artifact_url") and resolved_container.artifact_url:
3004+
if resolved_container.artifact_url:
30053005
container_dict["ArtifactUrl"] = resolved_container.artifact_url
3006-
if hasattr(resolved_container, "environment") and resolved_container.environment:
3006+
if resolved_container.environment:
30073007
container_dict["Environment"] = resolved_container.environment
30083008
if container_dict:
30093009
inference_component_spec["Container"] = container_dict
@@ -4211,7 +4211,8 @@ def deploy(
42114211
'artifact_url', 'environment' or an InferenceComponentContainerSpecification
42124212
instance. (Default: None).
42134213
variant_name (str, optional): The name of the production variant to deploy to.
4214-
If not specified, defaults to 'AllTraffic'. (Default: None).
4214+
If not provided (or explicitly ``None``), defaults to ``'AllTraffic'``.
4215+
(Default: None).
42154216
Returns:
42164217
Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint``
42174218
resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode,

sagemaker-serve/src/sagemaker/serve/model_builder_utils.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3432,14 +3432,11 @@ def _resolve_container_spec(
34323432
if isinstance(container, InferenceComponentContainerSpecification):
34333433
return container
34343434
elif isinstance(container, dict):
3435-
kwargs = {}
3436-
if "image" in container:
3437-
kwargs["image"] = container["image"]
3438-
if "artifact_url" in container:
3439-
kwargs["artifact_url"] = container["artifact_url"]
3440-
if "environment" in container:
3441-
kwargs["environment"] = container["environment"]
3442-
return InferenceComponentContainerSpecification(**kwargs)
3435+
# Only pass known keys to avoid Pydantic validation errors
3436+
# if the model has extra='forbid' configured
3437+
known_keys = {"image", "artifact_url", "environment"}
3438+
filtered = {k: v for k, v in container.items() if k in known_keys}
3439+
return InferenceComponentContainerSpecification(**filtered)
34433440
else:
34443441
raise ValueError(
34453442
f"container must be a dict or an InferenceComponentContainerSpecification "

tests/unit/sagemaker/serve/test_resolve_ic_params.py

Lines changed: 240 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,11 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Unit tests for _resolve_data_cache_config and _resolve_container_spec."""
13+
"""Unit tests for IC parameter resolvers and wiring logic."""
1414
from __future__ import absolute_import
1515

1616
import pytest
17+
from unittest.mock import MagicMock, patch, ANY
1718

1819
from sagemaker.core.shapes import (
1920
InferenceComponentDataCacheConfig,
@@ -61,7 +62,11 @@ def test_dict_missing_enable_caching_raises(self, utils):
6162
utils._resolve_data_cache_config({})
6263

6364
def test_dict_with_extra_keys_still_works(self, utils):
64-
"""Extra keys are ignored; only enable_caching is required."""
65+
"""Extra keys in the input dict are ignored (not forwarded to the Pydantic constructor).
66+
67+
The resolver only extracts 'enable_caching' from the dict, so extra keys
68+
do not cause Pydantic validation errors even if the model forbids extras.
69+
"""
6570
result = utils._resolve_data_cache_config(
6671
{"enable_caching": True, "extra_key": "ignored"}
6772
)
@@ -130,7 +135,11 @@ def test_dict_empty(self, utils):
130135
assert isinstance(result, InferenceComponentContainerSpecification)
131136

132137
def test_dict_with_extra_keys(self, utils):
133-
"""Extra keys are ignored."""
138+
"""Extra keys are filtered out before passing to the Pydantic constructor.
139+
140+
This ensures compatibility even if InferenceComponentContainerSpecification
141+
has extra='forbid' in its Pydantic model config.
142+
"""
134143
result = utils._resolve_container_spec({
135144
"image": "img",
136145
"unknown_key": "ignored",
@@ -149,3 +158,231 @@ def test_invalid_type_int_raises(self, utils):
149158
def test_invalid_type_list_raises(self, utils):
150159
with pytest.raises(ValueError, match="container must be a dict"):
151160
utils._resolve_container_spec([{"image": "img"}])
161+
162+
163+
# ============================================================
164+
# Tests for core wiring logic in _deploy_core_endpoint
165+
# ============================================================
166+
167+
class TestDeployCoreEndpointWiring:
    """Tests that new IC parameters are correctly wired through _deploy_core_endpoint.

    Every test drives ``ModelBuilder._deploy_core_endpoint`` with
    ``endpoint_type="INFERENCE_COMPONENT_BASED"`` against a mocked SageMaker
    session and then inspects the arguments handed to
    ``sagemaker_session.create_inference_component``.
    """

    def _make_model_builder(self):
        """Create a minimally-configured ModelBuilder for testing _deploy_core_endpoint.

        ``object.__new__`` is used so that ``ModelBuilder.__init__`` is never
        executed; only the attributes listed below are populated by hand.
        """
        from sagemaker.serve.model_builder import ModelBuilder

        mb = object.__new__(ModelBuilder)
        # Set minimum required attributes
        mb.model_name = "test-model"
        mb.endpoint_name = None
        mb.inference_component_name = None
        mb.instance_type = "ml.g5.2xlarge"
        mb.instance_count = 1
        mb.accelerator_type = None
        mb._tags = None
        mb.kms_key = None
        mb.async_inference_config = None
        mb.serverless_inference_config = None
        mb.model_data_download_timeout = None
        mb.resource_requirements = None
        mb.container_startup_health_check_timeout = None
        mb.inference_ami_version = None
        mb._is_sharded_model = False
        mb._enable_network_isolation = False
        mb.role_arn = "arn:aws:iam::123456789012:role/SageMakerRole"
        mb.vpc_config = None
        mb.inference_recommender_job_results = None
        mb.model_server = None
        mb.mode = None
        mb.region = "us-east-1"

        # Mock built_model
        mb.built_model = MagicMock()
        mb.built_model.model_name = "test-model"

        # Mock sagemaker_session
        mb.sagemaker_session = MagicMock()
        mb.sagemaker_session.endpoint_in_service_or_not.return_value = True
        mb.sagemaker_session.boto_session = MagicMock()
        mb.sagemaker_session.boto_region_name = "us-east-1"

        return mb

    @staticmethod
    def _make_resources():
        """Build the ResourceRequirements object shared by every wiring test."""
        from sagemaker.core.inference_config import ResourceRequirements

        return ResourceRequirements(
            requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1}
        )

    def _deploy_inference_component(self, **extra_kwargs):
        """Run _deploy_core_endpoint for an inference-component endpoint.

        Args:
            **extra_kwargs: Optional IC parameters under test (for example
                ``variant_name``, ``data_cache_config``, ``container``) that are
                forwarded unchanged to ``_deploy_core_endpoint``.

        Returns:
            The ``create_inference_component`` mock of the session, after
            asserting that it has been called.
        """
        mb = self._make_model_builder()
        mb._deploy_core_endpoint(
            endpoint_type="INFERENCE_COMPONENT_BASED",
            resources=self._make_resources(),
            instance_type="ml.g5.2xlarge",
            initial_instance_count=1,
            wait=False,
            **extra_kwargs,
        )
        create_ic = mb.sagemaker_session.create_inference_component
        # Fail with a clear message instead of a TypeError on ``call_args`` being None.
        create_ic.assert_called()
        return create_ic

    @staticmethod
    def _variant_name_from_call(call):
        """Return the variant name passed to create_inference_component.

        The keyword argument is preferred; the third positional argument is
        used as a fallback. A plain ``kwargs["variant_name"]`` lookup would
        raise ``KeyError`` before any positional fallback could be evaluated,
        so the keyword is checked for presence first.
        """
        args, kwargs = call
        if "variant_name" in kwargs:
            return kwargs["variant_name"]
        if len(args) > 2:
            return args[2]
        return None

    @staticmethod
    def _specification_from_call(call):
        """Return the ``specification`` keyword argument of a recorded call."""
        _, kwargs = call
        return kwargs["specification"]

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_variant_name_defaults_to_all_traffic(self, mock_endpoint_cls):
        """When variant_name is not provided, it defaults to 'AllTraffic'."""
        mock_endpoint_cls.get.return_value = MagicMock()

        create_ic = self._deploy_inference_component()

        # Verify create_inference_component was called with variant_name="AllTraffic"
        create_ic.assert_called_once()
        assert self._variant_name_from_call(create_ic.call_args) == "AllTraffic"

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_variant_name_custom(self, mock_endpoint_cls):
        """When variant_name is provided, it is used instead of 'AllTraffic'."""
        mock_endpoint_cls.get.return_value = MagicMock()

        create_ic = self._deploy_inference_component(variant_name="MyVariant")

        assert self._variant_name_from_call(create_ic.call_args) == "MyVariant"

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_data_cache_config_wired_into_spec(self, mock_endpoint_cls):
        """data_cache_config dict is resolved and added to inference_component_spec."""
        mock_endpoint_cls.get.return_value = MagicMock()

        create_ic = self._deploy_inference_component(
            data_cache_config={"enable_caching": True},
        )

        spec = self._specification_from_call(create_ic.call_args)
        assert "DataCacheConfig" in spec
        assert spec["DataCacheConfig"]["EnableCaching"] is True

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_base_inference_component_name_wired_into_spec(self, mock_endpoint_cls):
        """base_inference_component_name is added to inference_component_spec."""
        mock_endpoint_cls.get.return_value = MagicMock()

        create_ic = self._deploy_inference_component(
            base_inference_component_name="base-ic-name",
        )

        spec = self._specification_from_call(create_ic.call_args)
        assert spec["BaseInferenceComponentName"] == "base-ic-name"

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_container_wired_into_spec(self, mock_endpoint_cls):
        """container dict is resolved and added to inference_component_spec."""
        mock_endpoint_cls.get.return_value = MagicMock()

        create_ic = self._deploy_inference_component(
            container={
                "image": "my-image:latest",
                "artifact_url": "s3://bucket/artifact",
                "environment": {"KEY": "VALUE"},
            },
        )

        spec = self._specification_from_call(create_ic.call_args)
        assert "Container" in spec
        assert spec["Container"]["Image"] == "my-image:latest"
        assert spec["Container"]["ArtifactUrl"] == "s3://bucket/artifact"
        assert spec["Container"]["Environment"] == {"KEY": "VALUE"}

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_no_optional_params_no_extra_keys_in_spec(self, mock_endpoint_cls):
        """When no optional IC params are provided, spec has no extra keys."""
        mock_endpoint_cls.get.return_value = MagicMock()

        create_ic = self._deploy_inference_component()

        spec = self._specification_from_call(create_ic.call_args)
        assert "DataCacheConfig" not in spec
        assert "BaseInferenceComponentName" not in spec
        assert "Container" not in spec

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_data_cache_config_typed_object_wired(self, mock_endpoint_cls):
        """InferenceComponentDataCacheConfig object is correctly wired."""
        mock_endpoint_cls.get.return_value = MagicMock()

        config = InferenceComponentDataCacheConfig(enable_caching=True)
        create_ic = self._deploy_inference_component(data_cache_config=config)

        spec = self._specification_from_call(create_ic.call_args)
        assert spec["DataCacheConfig"]["EnableCaching"] is True

0 commit comments

Comments
 (0)