Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion sagemaker-serve/src/sagemaker/serve/model_builder.py
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add 1-2 integration tests that verify the new IC-level parameters (data_cache_config, variant_name) work end-to-end through ModelBuilder.deploy(). Place the new test file at sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py.

Follow the exact patterns from the existing integ tests for structure, cleanup, and assertions.

Test 1: Deploy with data_cache_config via the standard IC path (_deploy_core_endpoint). Use a JumpStart model (like the pattern in test_jumpstart_integration.py, ideally use the same JumpStart model) but deploy with inference_config=ResourceRequirements to trigger the IC-based path. Pass data_cache_config={"enable_caching": True} and a custom variant_name. After deployment, use boto3 sagemaker_client.describe_inference_component() to verify:

  • The IC was created with DataCacheConfig.EnableCaching == True
  • The variant name matches what was passed (not "AllTraffic")

Test 2: Deploy with data_cache_config via the model customization path (_deploy_model_customization). Use a TrainingJob-based ModelBuilder (like test_model_customization_deployment.py lines 131-170) and pass data_cache_config={"enable_caching": True}. After deployment, describe the inference component and verify DataCacheConfig.EnableCaching is True. Also verify the variant_name defaults to endpoint_name (backward compat) when variant_name is not explicitly provided.

Both tests should include proper cleanup (delete endpoint, endpoint config, model, inference components) in a finally block. Use unique names with uuid to avoid collisions. Mark both with @pytest.mark.slow_test.

$context sagemaker-serve/tests/integ/test_jumpstart_integration.py
$context sagemaker-serve/tests/integ/test_model_customization_deployment.py

Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
ModelLifeCycle,
DriftCheckBaselines,
InferenceComponentComputeResourceRequirements,
InferenceComponentDataCacheConfig,
InferenceComponentContainerSpecification,
)
from sagemaker.core.resources import (
ModelPackage,
Expand Down Expand Up @@ -2978,18 +2980,49 @@ def _deploy_core_endpoint(self, **kwargs):
"StartupParameters": startup_parameters,
"ComputeResourceRequirements": resources.get_compute_resource_requirements(),
}

# Wire optional IC-level parameters into the specification
ic_data_cache_config = kwargs.get("data_cache_config")
if ic_data_cache_config is not None:
resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config)
if resolved_cache_config is not None:
inference_component_spec["DataCacheConfig"] = {
"EnableCaching": resolved_cache_config.enable_caching
}

ic_base_component_name = kwargs.get("base_inference_component_name")
if ic_base_component_name is not None:
inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name

ic_container = kwargs.get("container")
if ic_container is not None:
resolved_container = self._resolve_container_spec(ic_container)
if resolved_container is not None:
container_dict = {}
if hasattr(resolved_container, "image") and resolved_container.image:
container_dict["Image"] = resolved_container.image
if hasattr(resolved_container, "artifact_url") and resolved_container.artifact_url:
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using hasattr checks on a Pydantic BaseModel (which InferenceComponentContainerSpecification likely is) is unnecessary — the attributes are always present (possibly None). Simplify to:

container_dict = {}
if resolved_container.image:
    container_dict["Image"] = resolved_container.image
if resolved_container.artifact_url:
    container_dict["ArtifactUrl"] = resolved_container.artifact_url
if resolved_container.environment:
    container_dict["Environment"] = resolved_container.environment

This is cleaner and more idiomatic for Pydantic models.

container_dict["ArtifactUrl"] = resolved_container.artifact_url
if hasattr(resolved_container, "environment") and resolved_container.environment:
container_dict["Environment"] = resolved_container.environment
if container_dict:
inference_component_spec["Container"] = container_dict

runtime_config = {"CopyCount": resources.copy_count}
self.inference_component_name = (
inference_component_name
or self.inference_component_name
or unique_name_from_base(self.model_name)
)

# Use user-provided variant_name or default to "AllTraffic"
ic_variant_name = kwargs.get("variant_name", "AllTraffic")

# [TODO]: Add endpoint_logging support
self.sagemaker_session.create_inference_component(
inference_component_name=self.inference_component_name,
endpoint_name=self.endpoint_name,
variant_name="AllTraffic", # default variant name
variant_name=ic_variant_name,
specification=inference_component_spec,
runtime_config=runtime_config,
tags=tags,
Expand Down Expand Up @@ -4127,6 +4160,10 @@ def deploy(
] = None,
custom_orchestrator_instance_type: str = None,
custom_orchestrator_initial_instance_count: int = None,
data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None,
base_inference_component_name: Optional[str] = None,
container: Optional[Union["InferenceComponentContainerSpecification", Dict[str, Any]]] = None,
variant_name: Optional[str] = None,
**kwargs,
) -> Union[Endpoint, LocalEndpoint, Transformer]:
"""Deploy the built model to an ``Endpoint``.
Expand Down Expand Up @@ -4160,6 +4197,21 @@ def deploy(
orchestrator deployment. (Default: None).
custom_orchestrator_initial_instance_count (int, optional): Initial instance count
for custom orchestrator deployment. (Default: None).
data_cache_config (Union[InferenceComponentDataCacheConfig, dict], optional):
Data cache configuration for the inference component. Enables caching of model
artifacts and container images on instances for faster auto-scaling cold starts.
Can be a dict with 'enable_caching' key (e.g., {'enable_caching': True}) or an
InferenceComponentDataCacheConfig instance. (Default: None).
base_inference_component_name (str, optional): Name of the base inference component
for adapter deployments (e.g., LoRA adapters attached to a base model).
(Default: None).
container (Union[InferenceComponentContainerSpecification, dict], optional):
Custom container specification for the inference component, including image URI,
artifact URL, and environment variables. Can be a dict with keys 'image',
'artifact_url', 'environment' or an InferenceComponentContainerSpecification
instance. (Default: None).
variant_name (str, optional): The name of the production variant to deploy to.
If not specified, defaults to 'AllTraffic'. (Default: None).
Returns:
Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint``
resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode,
Expand All @@ -4182,6 +4234,16 @@ def deploy(
if not hasattr(self, "built_model") and not hasattr(self, "_deployables"):
raise ValueError("Model needs to be built before deploying")

# Store IC-level parameters for use in _deploy_core_endpoint
if data_cache_config is not None:
kwargs["data_cache_config"] = data_cache_config
if base_inference_component_name is not None:
kwargs["base_inference_component_name"] = base_inference_component_name
if container is not None:
kwargs["container"] = container
if variant_name is not None:
kwargs["variant_name"] = variant_name

# Handle model customization deployment
if self._is_model_customization():
logger.info("Deploying Model Customization model")
Expand Down
77 changes: 77 additions & 0 deletions sagemaker-serve/src/sagemaker/serve/model_builder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def build(self):
from sagemaker.core.resources import Model

# MLflow imports
from sagemaker.core.shapes import (
InferenceComponentDataCacheConfig,
InferenceComponentContainerSpecification,
)
from sagemaker.serve.model_format.mlflow.constants import (
MLFLOW_METADATA_FILE,
MLFLOW_MODEL_PATH,
Expand Down Expand Up @@ -3369,6 +3373,79 @@ def _extract_speculative_draft_model_provider(

return "auto"

def _resolve_data_cache_config(
    self,
    data_cache_config: Union[InferenceComponentDataCacheConfig, Dict[str, Any], None],
) -> Optional[InferenceComponentDataCacheConfig]:
    """Normalize a user-supplied cache config into the typed shape.

    Args:
        data_cache_config: An ``InferenceComponentDataCacheConfig`` instance,
            a dict carrying an ``'enable_caching'`` key, or None.

    Returns:
        The typed ``InferenceComponentDataCacheConfig``, or None when no
        configuration was supplied.

    Raises:
        ValueError: If the input is an unsupported type, or a dict that
            lacks the mandatory ``'enable_caching'`` key.
    """
    # Guard-clause style: handle each accepted shape and bail out early.
    if data_cache_config is None:
        return None

    if isinstance(data_cache_config, InferenceComponentDataCacheConfig):
        # Already typed — hand it back unchanged.
        return data_cache_config

    if isinstance(data_cache_config, dict):
        if "enable_caching" not in data_cache_config:
            raise ValueError(
                "data_cache_config dict must contain the required 'enable_caching' key. "
                "Example: {'enable_caching': True}"
            )
        # Extra keys, if any, are intentionally ignored.
        return InferenceComponentDataCacheConfig(
            enable_caching=data_cache_config["enable_caching"]
        )

    raise ValueError(
        f"data_cache_config must be a dict with 'enable_caching' key or an "
        f"InferenceComponentDataCacheConfig instance, got {type(data_cache_config)}"
    )

def _resolve_container_spec(
    self,
    container: Union[InferenceComponentContainerSpecification, Dict[str, Any], None],
) -> Optional[InferenceComponentContainerSpecification]:
    """Normalize a user-supplied container spec into the typed shape.

    Args:
        container: An ``InferenceComponentContainerSpecification`` instance,
            a dict with any of the keys ``image``, ``artifact_url``,
            ``environment``, or None.

    Returns:
        The typed ``InferenceComponentContainerSpecification``, or None when
        no container was supplied.

    Raises:
        ValueError: If the input is neither None, a dict, nor a typed spec.
    """
    if container is None:
        return None

    if isinstance(container, InferenceComponentContainerSpecification):
        # Already typed — pass through untouched.
        return container

    if isinstance(container, dict):
        # Forward only the fields the shape understands; unknown keys are dropped.
        recognized_fields = ("image", "artifact_url", "environment")
        spec_kwargs = {
            field: container[field] for field in recognized_fields if field in container
        }
        return InferenceComponentContainerSpecification(**spec_kwargs)

    raise ValueError(
        f"container must be a dict or an InferenceComponentContainerSpecification "
        f"instance, got {type(container)}"
    )

def get_huggingface_model_metadata(
self, model_id: str, hf_hub_token: Optional[str] = None
) -> dict:
Expand Down
151 changes: 151 additions & 0 deletions tests/unit/sagemaker/serve/test_resolve_ic_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Unit tests for _resolve_data_cache_config and _resolve_container_spec."""
from __future__ import absolute_import

import pytest

from sagemaker.core.shapes import (
InferenceComponentDataCacheConfig,
InferenceComponentContainerSpecification,
)
from sagemaker.serve.model_builder_utils import _ModelBuilderUtils


class ConcreteUtils(_ModelBuilderUtils):
    """Minimal concrete subclass exposing the mixin's resolver helpers for testing."""


@pytest.fixture
def utils():
    """Provide a fresh ConcreteUtils instance to each test."""
    return ConcreteUtils()


# ============================================================
# Tests for _resolve_data_cache_config
# ============================================================

class TestResolveDataCacheConfig:
    """Covers every accepted and rejected input shape for _resolve_data_cache_config."""

    def test_none_returns_none(self, utils):
        assert utils._resolve_data_cache_config(None) is None

    def test_already_typed_passthrough(self, utils):
        # A typed config must be returned as-is, not copied.
        typed = InferenceComponentDataCacheConfig(enable_caching=True)
        resolved = utils._resolve_data_cache_config(typed)
        assert resolved is typed
        assert resolved.enable_caching is True

    def test_dict_with_enable_caching_true(self, utils):
        resolved = utils._resolve_data_cache_config({"enable_caching": True})
        assert isinstance(resolved, InferenceComponentDataCacheConfig)
        assert resolved.enable_caching is True

    def test_dict_with_enable_caching_false(self, utils):
        resolved = utils._resolve_data_cache_config({"enable_caching": False})
        assert isinstance(resolved, InferenceComponentDataCacheConfig)
        assert resolved.enable_caching is False

    def test_dict_missing_enable_caching_raises(self, utils):
        with pytest.raises(ValueError, match="must contain the required 'enable_caching' key"):
            utils._resolve_data_cache_config({})

    def test_dict_with_extra_keys_still_works(self, utils):
        """Unknown keys are silently dropped; only enable_caching matters."""
        payload = {"enable_caching": True, "extra_key": "ignored"}
        resolved = utils._resolve_data_cache_config(payload)
        assert isinstance(resolved, InferenceComponentDataCacheConfig)
        assert resolved.enable_caching is True

    def test_invalid_type_raises(self, utils):
        with pytest.raises(ValueError, match="data_cache_config must be a dict"):
            utils._resolve_data_cache_config("invalid")

    def test_invalid_type_int_raises(self, utils):
        with pytest.raises(ValueError, match="data_cache_config must be a dict"):
            utils._resolve_data_cache_config(42)

    def test_invalid_type_list_raises(self, utils):
        with pytest.raises(ValueError, match="data_cache_config must be a dict"):
            utils._resolve_data_cache_config([True])


# ============================================================
# Tests for _resolve_container_spec
# ============================================================

class TestResolveContainerSpec:
    """Covers every accepted and rejected input shape for _resolve_container_spec."""

    def test_none_returns_none(self, utils):
        assert utils._resolve_container_spec(None) is None

    def test_already_typed_passthrough(self, utils):
        # A typed spec must be returned as-is, not rebuilt.
        typed = InferenceComponentContainerSpecification(
            image="my-image:latest",
            artifact_url="s3://bucket/artifact",
            environment={"KEY": "VALUE"},
        )
        assert utils._resolve_container_spec(typed) is typed

    def test_dict_full(self, utils):
        payload = {
            "image": "my-image:latest",
            "artifact_url": "s3://bucket/artifact",
            "environment": {"KEY": "VALUE"},
        }
        resolved = utils._resolve_container_spec(payload)
        assert isinstance(resolved, InferenceComponentContainerSpecification)
        assert resolved.image == "my-image:latest"
        assert resolved.artifact_url == "s3://bucket/artifact"
        assert resolved.environment == {"KEY": "VALUE"}

    def test_dict_image_only(self, utils):
        resolved = utils._resolve_container_spec({"image": "my-image:latest"})
        assert isinstance(resolved, InferenceComponentContainerSpecification)
        assert resolved.image == "my-image:latest"

    def test_dict_artifact_url_only(self, utils):
        resolved = utils._resolve_container_spec({"artifact_url": "s3://bucket/model.tar.gz"})
        assert isinstance(resolved, InferenceComponentContainerSpecification)
        assert resolved.artifact_url == "s3://bucket/model.tar.gz"

    def test_dict_environment_only(self, utils):
        resolved = utils._resolve_container_spec({"environment": {"A": "B"}})
        assert isinstance(resolved, InferenceComponentContainerSpecification)
        assert resolved.environment == {"A": "B"}

    def test_dict_empty(self, utils):
        """An empty dict yields a spec with no fields populated."""
        resolved = utils._resolve_container_spec({})
        assert isinstance(resolved, InferenceComponentContainerSpecification)

    def test_dict_with_extra_keys(self, utils):
        """Keys outside image/artifact_url/environment are ignored."""
        resolved = utils._resolve_container_spec({"image": "img", "unknown_key": "ignored"})
        assert isinstance(resolved, InferenceComponentContainerSpecification)
        assert resolved.image == "img"

    def test_invalid_type_raises(self, utils):
        with pytest.raises(ValueError, match="container must be a dict"):
            utils._resolve_container_spec("invalid")

    def test_invalid_type_int_raises(self, utils):
        with pytest.raises(ValueError, match="container must be a dict"):
            utils._resolve_container_spec(123)

    def test_invalid_type_list_raises(self, utils):
        with pytest.raises(ValueError, match="container must be a dict"):
            utils._resolve_container_spec([{"image": "img"}])
Loading