Skip to content

Commit 8bc8db3

Browse files
committed
fix: address review comments (iteration #1)
1 parent 05e59de commit 8bc8db3

File tree

3 files changed

+232
-9
lines changed

3 files changed

+232
-9
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2980,18 +2980,49 @@ def _deploy_core_endpoint(self, **kwargs):
29802980
"StartupParameters": startup_parameters,
29812981
"ComputeResourceRequirements": resources.get_compute_resource_requirements(),
29822982
}
2983+
2984+
# Wire optional IC-level parameters into the specification
2985+
ic_data_cache_config = kwargs.get("data_cache_config")
2986+
if ic_data_cache_config is not None:
2987+
resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config)
2988+
if resolved_cache_config is not None:
2989+
inference_component_spec["DataCacheConfig"] = {
2990+
"EnableCaching": resolved_cache_config.enable_caching
2991+
}
2992+
2993+
ic_base_component_name = kwargs.get("base_inference_component_name")
2994+
if ic_base_component_name is not None:
2995+
inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name
2996+
2997+
ic_container = kwargs.get("container")
2998+
if ic_container is not None:
2999+
resolved_container = self._resolve_container_spec(ic_container)
3000+
if resolved_container is not None:
3001+
container_dict = {}
3002+
if hasattr(resolved_container, "image") and resolved_container.image:
3003+
container_dict["Image"] = resolved_container.image
3004+
if hasattr(resolved_container, "artifact_url") and resolved_container.artifact_url:
3005+
container_dict["ArtifactUrl"] = resolved_container.artifact_url
3006+
if hasattr(resolved_container, "environment") and resolved_container.environment:
3007+
container_dict["Environment"] = resolved_container.environment
3008+
if container_dict:
3009+
inference_component_spec["Container"] = container_dict
3010+
29833011
runtime_config = {"CopyCount": resources.copy_count}
29843012
self.inference_component_name = (
29853013
inference_component_name
29863014
or self.inference_component_name
29873015
or unique_name_from_base(self.model_name)
29883016
)
29893017

3018+
# Use user-provided variant_name or default to "AllTraffic"
3019+
ic_variant_name = kwargs.get("variant_name", "AllTraffic")
3020+
29903021
# [TODO]: Add endpoint_logging support
29913022
self.sagemaker_session.create_inference_component(
29923023
inference_component_name=self.inference_component_name,
29933024
endpoint_name=self.endpoint_name,
2994-
variant_name="AllTraffic", # default variant name
3025+
variant_name=ic_variant_name,
29953026
specification=inference_component_spec,
29963027
runtime_config=runtime_config,
29973028
tags=tags,
@@ -4129,6 +4160,10 @@ def deploy(
41294160
] = None,
41304161
custom_orchestrator_instance_type: str = None,
41314162
custom_orchestrator_initial_instance_count: int = None,
4163+
data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None,
4164+
base_inference_component_name: Optional[str] = None,
4165+
container: Optional[Union["InferenceComponentContainerSpecification", Dict[str, Any]]] = None,
4166+
variant_name: Optional[str] = None,
41324167
**kwargs,
41334168
) -> Union[Endpoint, LocalEndpoint, Transformer]:
41344169
"""Deploy the built model to an ``Endpoint``.
@@ -4162,6 +4197,21 @@ def deploy(
41624197
orchestrator deployment. (Default: None).
41634198
custom_orchestrator_initial_instance_count (int, optional): Initial instance count
41644199
for custom orchestrator deployment. (Default: None).
4200+
data_cache_config (Union[InferenceComponentDataCacheConfig, dict], optional):
4201+
Data cache configuration for the inference component. Enables caching of model
4202+
artifacts and container images on instances for faster auto-scaling cold starts.
4203+
Can be a dict with 'enable_caching' key (e.g., {'enable_caching': True}) or an
4204+
InferenceComponentDataCacheConfig instance. (Default: None).
4205+
base_inference_component_name (str, optional): Name of the base inference component
4206+
for adapter deployments (e.g., LoRA adapters attached to a base model).
4207+
(Default: None).
4208+
container (Union[InferenceComponentContainerSpecification, dict], optional):
4209+
Custom container specification for the inference component, including image URI,
4210+
artifact URL, and environment variables. Can be a dict with keys 'image',
4211+
'artifact_url', 'environment' or an InferenceComponentContainerSpecification
4212+
instance. (Default: None).
4213+
variant_name (str, optional): The name of the production variant to deploy to.
4214+
If not specified, defaults to 'AllTraffic'. (Default: None).
41654215
Returns:
41664216
Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint``
41674217
resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode,
@@ -4184,6 +4234,16 @@ def deploy(
41844234
if not hasattr(self, "built_model") and not hasattr(self, "_deployables"):
41854235
raise ValueError("Model needs to be built before deploying")
41864236

4237+
# Store IC-level parameters for use in _deploy_core_endpoint
4238+
if data_cache_config is not None:
4239+
kwargs["data_cache_config"] = data_cache_config
4240+
if base_inference_component_name is not None:
4241+
kwargs["base_inference_component_name"] = base_inference_component_name
4242+
if container is not None:
4243+
kwargs["container"] = container
4244+
if variant_name is not None:
4245+
kwargs["variant_name"] = variant_name
4246+
41874247
# Handle model customization deployment
41884248
if self._is_model_customization():
41894249
logger.info("Deploying Model Customization model")

sagemaker-serve/src/sagemaker/serve/model_builder_utils.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,10 @@ def build(self):
8080
from sagemaker.core.resources import Model
8181

8282
# MLflow imports
83+
from sagemaker.core.shapes import (
84+
InferenceComponentDataCacheConfig,
85+
InferenceComponentContainerSpecification,
86+
)
8387
from sagemaker.serve.model_format.mlflow.constants import (
8488
MLFLOW_METADATA_FILE,
8589
MLFLOW_MODEL_PATH,
@@ -3369,7 +3373,10 @@ def _extract_speculative_draft_model_provider(
33693373

33703374
return "auto"
33713375

3372-
def _resolve_data_cache_config(self, data_cache_config):
3376+
def _resolve_data_cache_config(
3377+
self,
3378+
data_cache_config: Union[InferenceComponentDataCacheConfig, Dict[str, Any], None],
3379+
) -> Optional[InferenceComponentDataCacheConfig]:
33733380
"""Resolve data_cache_config to InferenceComponentDataCacheConfig.
33743381
33753382
Args:
@@ -3380,26 +3387,33 @@ def _resolve_data_cache_config(self, data_cache_config):
33803387
InferenceComponentDataCacheConfig or None.
33813388
33823389
Raises:
3383-
ValueError: If data_cache_config is an unsupported type.
3390+
ValueError: If data_cache_config is an unsupported type or dict
3391+
is missing the required 'enable_caching' key.
33843392
"""
33853393
if data_cache_config is None:
33863394
return None
33873395

3388-
from sagemaker.core.shapes import InferenceComponentDataCacheConfig
3389-
33903396
if isinstance(data_cache_config, InferenceComponentDataCacheConfig):
33913397
return data_cache_config
33923398
elif isinstance(data_cache_config, dict):
3399+
if "enable_caching" not in data_cache_config:
3400+
raise ValueError(
3401+
"data_cache_config dict must contain the required 'enable_caching' key. "
3402+
"Example: {'enable_caching': True}"
3403+
)
33933404
return InferenceComponentDataCacheConfig(
3394-
enable_caching=data_cache_config.get("enable_caching", False)
3405+
enable_caching=data_cache_config["enable_caching"]
33953406
)
33963407
else:
33973408
raise ValueError(
33983409
f"data_cache_config must be a dict with 'enable_caching' key or an "
33993410
f"InferenceComponentDataCacheConfig instance, got {type(data_cache_config)}"
34003411
)
34013412

3402-
def _resolve_container_spec(self, container):
3413+
def _resolve_container_spec(
3414+
self,
3415+
container: Union[InferenceComponentContainerSpecification, Dict[str, Any], None],
3416+
) -> Optional[InferenceComponentContainerSpecification]:
34033417
"""Resolve container to InferenceComponentContainerSpecification.
34043418
34053419
Args:
@@ -3415,8 +3429,6 @@ def _resolve_container_spec(self, container):
34153429
if container is None:
34163430
return None
34173431

3418-
from sagemaker.core.shapes import InferenceComponentContainerSpecification
3419-
34203432
if isinstance(container, InferenceComponentContainerSpecification):
34213433
return container
34223434
elif isinstance(container, dict):
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Unit tests for _resolve_data_cache_config and _resolve_container_spec."""
14+
from __future__ import absolute_import
15+
16+
import pytest
17+
18+
from sagemaker.core.shapes import (
19+
InferenceComponentDataCacheConfig,
20+
InferenceComponentContainerSpecification,
21+
)
22+
from sagemaker.serve.model_builder_utils import _ModelBuilderUtils
23+
24+
25+
class ConcreteUtils(_ModelBuilderUtils):
26+
"""Concrete class to test mixin methods."""
27+
pass
28+
29+
30+
@pytest.fixture
31+
def utils():
32+
return ConcreteUtils()
33+
34+
35+
# ============================================================
36+
# Tests for _resolve_data_cache_config
37+
# ============================================================
38+
39+
class TestResolveDataCacheConfig:
40+
def test_none_returns_none(self, utils):
41+
assert utils._resolve_data_cache_config(None) is None
42+
43+
def test_already_typed_passthrough(self, utils):
44+
config = InferenceComponentDataCacheConfig(enable_caching=True)
45+
result = utils._resolve_data_cache_config(config)
46+
assert result is config
47+
assert result.enable_caching is True
48+
49+
def test_dict_with_enable_caching_true(self, utils):
50+
result = utils._resolve_data_cache_config({"enable_caching": True})
51+
assert isinstance(result, InferenceComponentDataCacheConfig)
52+
assert result.enable_caching is True
53+
54+
def test_dict_with_enable_caching_false(self, utils):
55+
result = utils._resolve_data_cache_config({"enable_caching": False})
56+
assert isinstance(result, InferenceComponentDataCacheConfig)
57+
assert result.enable_caching is False
58+
59+
def test_dict_missing_enable_caching_raises(self, utils):
60+
with pytest.raises(ValueError, match="must contain the required 'enable_caching' key"):
61+
utils._resolve_data_cache_config({})
62+
63+
def test_dict_with_extra_keys_still_works(self, utils):
64+
"""Extra keys are ignored; only enable_caching is required."""
65+
result = utils._resolve_data_cache_config(
66+
{"enable_caching": True, "extra_key": "ignored"}
67+
)
68+
assert isinstance(result, InferenceComponentDataCacheConfig)
69+
assert result.enable_caching is True
70+
71+
def test_invalid_type_raises(self, utils):
72+
with pytest.raises(ValueError, match="data_cache_config must be a dict"):
73+
utils._resolve_data_cache_config("invalid")
74+
75+
def test_invalid_type_int_raises(self, utils):
76+
with pytest.raises(ValueError, match="data_cache_config must be a dict"):
77+
utils._resolve_data_cache_config(42)
78+
79+
def test_invalid_type_list_raises(self, utils):
80+
with pytest.raises(ValueError, match="data_cache_config must be a dict"):
81+
utils._resolve_data_cache_config([True])
82+
83+
84+
# ============================================================
85+
# Tests for _resolve_container_spec
86+
# ============================================================
87+
88+
class TestResolveContainerSpec:
89+
def test_none_returns_none(self, utils):
90+
assert utils._resolve_container_spec(None) is None
91+
92+
def test_already_typed_passthrough(self, utils):
93+
spec = InferenceComponentContainerSpecification(
94+
image="my-image:latest",
95+
artifact_url="s3://bucket/artifact",
96+
environment={"KEY": "VALUE"},
97+
)
98+
result = utils._resolve_container_spec(spec)
99+
assert result is spec
100+
101+
def test_dict_full(self, utils):
102+
result = utils._resolve_container_spec({
103+
"image": "my-image:latest",
104+
"artifact_url": "s3://bucket/artifact",
105+
"environment": {"KEY": "VALUE"},
106+
})
107+
assert isinstance(result, InferenceComponentContainerSpecification)
108+
assert result.image == "my-image:latest"
109+
assert result.artifact_url == "s3://bucket/artifact"
110+
assert result.environment == {"KEY": "VALUE"}
111+
112+
def test_dict_image_only(self, utils):
113+
result = utils._resolve_container_spec({"image": "my-image:latest"})
114+
assert isinstance(result, InferenceComponentContainerSpecification)
115+
assert result.image == "my-image:latest"
116+
117+
def test_dict_artifact_url_only(self, utils):
118+
result = utils._resolve_container_spec({"artifact_url": "s3://bucket/model.tar.gz"})
119+
assert isinstance(result, InferenceComponentContainerSpecification)
120+
assert result.artifact_url == "s3://bucket/model.tar.gz"
121+
122+
def test_dict_environment_only(self, utils):
123+
result = utils._resolve_container_spec({"environment": {"A": "B"}})
124+
assert isinstance(result, InferenceComponentContainerSpecification)
125+
assert result.environment == {"A": "B"}
126+
127+
def test_dict_empty(self, utils):
128+
"""Empty dict creates a spec with no fields set."""
129+
result = utils._resolve_container_spec({})
130+
assert isinstance(result, InferenceComponentContainerSpecification)
131+
132+
def test_dict_with_extra_keys(self, utils):
133+
"""Extra keys are ignored."""
134+
result = utils._resolve_container_spec({
135+
"image": "img",
136+
"unknown_key": "ignored",
137+
})
138+
assert isinstance(result, InferenceComponentContainerSpecification)
139+
assert result.image == "img"
140+
141+
def test_invalid_type_raises(self, utils):
142+
with pytest.raises(ValueError, match="container must be a dict"):
143+
utils._resolve_container_spec("invalid")
144+
145+
def test_invalid_type_int_raises(self, utils):
146+
with pytest.raises(ValueError, match="container must be a dict"):
147+
utils._resolve_container_spec(123)
148+
149+
def test_invalid_type_list_raises(self, utils):
150+
with pytest.raises(ValueError, match="container must be a dict"):
151+
utils._resolve_container_spec([{"image": "img"}])

0 commit comments

Comments
 (0)