Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion sagemaker-serve/src/sagemaker/serve/model_builder.py
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Two fixes needed:

  1. Bug: variant_name always overrides model customization default. In deploy(), kwargs["variant_name"] = variant_name or "AllTraffic" always sets the key, so _deploy_model_customization never sees None and its backward-compat default of endpoint_name is dead code. Fix: only forward variant_name when explicitly provided:
# Replace:
kwargs["variant_name"] = variant_name or "AllTraffic"
# With:
if variant_name is not None:
    kwargs["variant_name"] = variant_name

Each downstream path already has its own default — _deploy_core_endpoint defaults to "AllTraffic" via kwargs.get("variant_name", "AllTraffic"), and _deploy_model_customization defaults to endpoint_name via variant_name or endpoint_name or "AllTraffic".

  2. Drop the second integ test (test_deploy_with_data_cache_config_via_model_customization_path). The model customization path requires ml.g5.4xlarge which has a non-adjustable account quota of 2 instances. When CI runs tests in parallel, this test competes with the existing test_model_customization_deployment.py for the same quota, causing flaky InsufficientInstanceCapacity failures. The model customization path's data_cache_config and variant_name wiring is already covered by unit tests. Keep only the first integ test (test_deploy_with_data_cache_config_and_variant_name_via_ic_path) which uses ml.g5.2xlarge.
    Also remove the TRAINING_JOB_NAME constant since it's no longer needed.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do not worry about CI failures! Removing the second integ test will fix one failure, and the other failures are due to flakiness.

Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
ModelLifeCycle,
DriftCheckBaselines,
InferenceComponentComputeResourceRequirements,
InferenceComponentDataCacheConfig,
Comment thread
aviruthen marked this conversation as resolved.
InferenceComponentContainerSpecification,
)
from sagemaker.core.resources import (
ModelPackage,
Expand Down Expand Up @@ -2978,18 +2980,49 @@ def _deploy_core_endpoint(self, **kwargs):
"StartupParameters": startup_parameters,
Comment thread
aviruthen marked this conversation as resolved.
"ComputeResourceRequirements": resources.get_compute_resource_requirements(),
}

# Wire optional IC-level parameters into the specification
ic_data_cache_config = kwargs.get("data_cache_config")
if ic_data_cache_config is not None:
resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config)
Comment thread
aviruthen marked this conversation as resolved.
Outdated
if resolved_cache_config is not None:
Comment thread
aviruthen marked this conversation as resolved.
Outdated
inference_component_spec["DataCacheConfig"] = {
"EnableCaching": resolved_cache_config.enable_caching
}

ic_base_component_name = kwargs.get("base_inference_component_name")
Comment thread
aviruthen marked this conversation as resolved.
Outdated
if ic_base_component_name is not None:
inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name

ic_container = kwargs.get("container")
if ic_container is not None:
resolved_container = self._resolve_container_spec(ic_container)
if resolved_container is not None:
container_dict = {}
if hasattr(resolved_container, "image") and resolved_container.image:
container_dict["Image"] = resolved_container.image
if hasattr(resolved_container, "artifact_url") and resolved_container.artifact_url:

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using hasattr checks on a Pydantic BaseModel (which InferenceComponentContainerSpecification likely is) is unnecessary — the attributes are always present (possibly None). Simplify to:

container_dict = {}
if resolved_container.image:
    container_dict["Image"] = resolved_container.image
if resolved_container.artifact_url:
    container_dict["ArtifactUrl"] = resolved_container.artifact_url
if resolved_container.environment:
    container_dict["Environment"] = resolved_container.environment

This is cleaner and more idiomatic for Pydantic models.

container_dict["ArtifactUrl"] = resolved_container.artifact_url
Comment thread
aviruthen marked this conversation as resolved.
Outdated
if hasattr(resolved_container, "environment") and resolved_container.environment:
container_dict["Environment"] = resolved_container.environment
if container_dict:
Comment thread
aviruthen marked this conversation as resolved.
Outdated
inference_component_spec["Container"] = container_dict

runtime_config = {"CopyCount": resources.copy_count}
self.inference_component_name = (
inference_component_name
or self.inference_component_name
or unique_name_from_base(self.model_name)
)

# Use user-provided variant_name or default to "AllTraffic"
ic_variant_name = kwargs.get("variant_name", "AllTraffic")

# [TODO]: Add endpoint_logging support
self.sagemaker_session.create_inference_component(
inference_component_name=self.inference_component_name,
endpoint_name=self.endpoint_name,
variant_name="AllTraffic", # default variant name
variant_name=ic_variant_name,
specification=inference_component_spec,
runtime_config=runtime_config,
tags=tags,
Expand Down Expand Up @@ -4127,6 +4160,10 @@ def deploy(
] = None,
custom_orchestrator_instance_type: str = None,
custom_orchestrator_initial_instance_count: int = None,
data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None,
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
base_inference_component_name: Optional[str] = None,
container: Optional[Union["InferenceComponentContainerSpecification", Dict[str, Any]]] = None,
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
variant_name: Optional[str] = None,
**kwargs,
) -> Union[Endpoint, LocalEndpoint, Transformer]:
"""Deploy the built model to an ``Endpoint``.
Expand Down Expand Up @@ -4160,6 +4197,21 @@ def deploy(
orchestrator deployment. (Default: None).
Comment thread
aviruthen marked this conversation as resolved.
custom_orchestrator_initial_instance_count (int, optional): Initial instance count
for custom orchestrator deployment. (Default: None).
data_cache_config (Union[InferenceComponentDataCacheConfig, dict], optional):
Data cache configuration for the inference component. Enables caching of model
artifacts and container images on instances for faster auto-scaling cold starts.
Can be a dict with 'enable_caching' key (e.g., {'enable_caching': True}) or an
InferenceComponentDataCacheConfig instance. (Default: None).
base_inference_component_name (str, optional): Name of the base inference component
for adapter deployments (e.g., LoRA adapters attached to a base model).
(Default: None).
container (Union[InferenceComponentContainerSpecification, dict], optional):
Custom container specification for the inference component, including image URI,
artifact URL, and environment variables. Can be a dict with keys 'image',
'artifact_url', 'environment' or an InferenceComponentContainerSpecification
instance. (Default: None).
variant_name (str, optional): The name of the production variant to deploy to.
If not specified, defaults to 'AllTraffic'. (Default: None).
Returns:
Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint``
resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode,
Expand All @@ -4182,6 +4234,16 @@ def deploy(
if not hasattr(self, "built_model") and not hasattr(self, "_deployables"):
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
raise ValueError("Model needs to be built before deploying")

# Store IC-level parameters for use in _deploy_core_endpoint
if data_cache_config is not None:
kwargs["data_cache_config"] = data_cache_config
if base_inference_component_name is not None:
kwargs["base_inference_component_name"] = base_inference_component_name
if container is not None:
kwargs["container"] = container
Comment thread
aviruthen marked this conversation as resolved.
if variant_name is not None:
kwargs["variant_name"] = variant_name

# Handle model customization deployment
if self._is_model_customization():
logger.info("Deploying Model Customization model")
Expand Down
77 changes: 77 additions & 0 deletions sagemaker-serve/src/sagemaker/serve/model_builder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ def build(self):
from sagemaker.core.resources import Model

# MLflow imports
from sagemaker.core.shapes import (
InferenceComponentDataCacheConfig,
InferenceComponentContainerSpecification,
)
from sagemaker.serve.model_format.mlflow.constants import (
MLFLOW_METADATA_FILE,
MLFLOW_MODEL_PATH,
Expand Down Expand Up @@ -3369,6 +3373,79 @@ def _extract_speculative_draft_model_provider(

Comment thread
aviruthen marked this conversation as resolved.
return "auto"
Comment thread
aviruthen marked this conversation as resolved.

Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
Comment thread
aviruthen marked this conversation as resolved.
def _resolve_data_cache_config(
    self,
    data_cache_config: Union[InferenceComponentDataCacheConfig, Dict[str, Any], None],
) -> Optional[InferenceComponentDataCacheConfig]:
    """Normalize ``data_cache_config`` into an InferenceComponentDataCacheConfig.

    Args:
        data_cache_config: ``None``, an InferenceComponentDataCacheConfig
            instance, or a dict carrying an ``'enable_caching'`` key. Any
            other keys present in the dict are not read.

    Returns:
        ``None`` when the input is ``None``; the same object when the input is
        already an InferenceComponentDataCacheConfig; otherwise a new
        InferenceComponentDataCacheConfig built from the dict's
        ``'enable_caching'`` value.

    Raises:
        ValueError: If the input is neither ``None``, a dict, nor an
            InferenceComponentDataCacheConfig, or if it is a dict without
            the ``'enable_caching'`` key.
    """
    # Nothing to resolve: absent or already-typed values pass straight through.
    if data_cache_config is None:
        return None
    if isinstance(data_cache_config, InferenceComponentDataCacheConfig):
        return data_cache_config

    # Reject unsupported types up front so the dict handling below stays flat.
    if not isinstance(data_cache_config, dict):
        raise ValueError(
            f"data_cache_config must be a dict with 'enable_caching' key or an "
            f"InferenceComponentDataCacheConfig instance, got {type(data_cache_config)}"
        )

    if "enable_caching" not in data_cache_config:
        raise ValueError(
            "data_cache_config dict must contain the required 'enable_caching' key. "
            "Example: {'enable_caching': True}"
        )

    caching_flag = data_cache_config["enable_caching"]
    return InferenceComponentDataCacheConfig(enable_caching=caching_flag)

Comment thread
aviruthen marked this conversation as resolved.
def _resolve_container_spec(
    self,
    container: Union[InferenceComponentContainerSpecification, Dict[str, Any], None],
) -> Optional[InferenceComponentContainerSpecification]:
    """Normalize ``container`` into an InferenceComponentContainerSpecification.

    Args:
        container: ``None``, an InferenceComponentContainerSpecification
            instance, or a dict that may carry any of the keys ``'image'``,
            ``'artifact_url'`` and ``'environment'``. Other keys in the dict
            are not forwarded.

    Returns:
        ``None`` when the input is ``None``; the same object when the input is
        already an InferenceComponentContainerSpecification; otherwise a new
        InferenceComponentContainerSpecification built from whichever of the
        recognised keys the dict contains.

    Raises:
        ValueError: If the input is neither ``None``, a dict, nor an
            InferenceComponentContainerSpecification.
    """
    # Absent or already-typed values need no conversion.
    if container is None:
        return None
    if isinstance(container, InferenceComponentContainerSpecification):
        return container

    if not isinstance(container, dict):
        raise ValueError(
            f"container must be a dict or an InferenceComponentContainerSpecification "
            f"instance, got {type(container)}"
        )

    # Forward only the recognised keys that the caller actually supplied.
    recognised_keys = ("image", "artifact_url", "environment")
    spec_fields = {key: container[key] for key in recognised_keys if key in container}
    return InferenceComponentContainerSpecification(**spec_fields)

def get_huggingface_model_metadata(
self, model_id: str, hf_hub_token: Optional[str] = None
) -> dict:
Expand Down
151 changes: 151 additions & 0 deletions tests/unit/sagemaker/serve/test_resolve_ic_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Comment thread
aviruthen marked this conversation as resolved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Unit tests for _resolve_data_cache_config and _resolve_container_spec."""
from __future__ import absolute_import

Comment thread
aviruthen marked this conversation as resolved.
import pytest

from sagemaker.core.shapes import (
InferenceComponentDataCacheConfig,
InferenceComponentContainerSpecification,
)
from sagemaker.serve.model_builder_utils import _ModelBuilderUtils


class ConcreteUtils(_ModelBuilderUtils):
    """Bare subclass of ``_ModelBuilderUtils`` so its methods can be called in tests."""
Comment thread
aviruthen marked this conversation as resolved.


@pytest.fixture
def utils():
    """Provide a fresh ``ConcreteUtils`` instance to each test."""
    instance = ConcreteUtils()
    return instance


# ============================================================
# Tests for _resolve_data_cache_config
# ============================================================

class TestResolveDataCacheConfig:
    """Cover ``_resolve_data_cache_config`` for None, typed, dict and invalid inputs."""

    @staticmethod
    def _assert_config(resolved, expected_flag):
        """Check that ``resolved`` is a typed config carrying ``expected_flag``."""
        assert isinstance(resolved, InferenceComponentDataCacheConfig)
        assert resolved.enable_caching is expected_flag

    def test_none_returns_none(self, utils):
        """``None`` resolves to ``None``."""
        resolved = utils._resolve_data_cache_config(None)
        assert resolved is None

    def test_already_typed_passthrough(self, utils):
        """A typed config is returned as the very same object."""
        original = InferenceComponentDataCacheConfig(enable_caching=True)
        resolved = utils._resolve_data_cache_config(original)
        assert resolved is original
        assert resolved.enable_caching is True

    def test_dict_with_enable_caching_true(self, utils):
        """A dict with ``enable_caching`` True becomes a typed config."""
        resolved = utils._resolve_data_cache_config({"enable_caching": True})
        self._assert_config(resolved, True)

    def test_dict_with_enable_caching_false(self, utils):
        """A dict with ``enable_caching`` False becomes a typed config."""
        resolved = utils._resolve_data_cache_config({"enable_caching": False})
        self._assert_config(resolved, False)

    def test_dict_missing_enable_caching_raises(self, utils):
        """A dict lacking ``enable_caching`` is rejected with ValueError."""
        with pytest.raises(ValueError, match="must contain the required 'enable_caching' key"):
            utils._resolve_data_cache_config({})

    def test_dict_with_extra_keys_still_works(self, utils):
        """Unrecognised keys alongside ``enable_caching`` do not cause a failure."""
        payload = {"enable_caching": True, "extra_key": "ignored"}
        resolved = utils._resolve_data_cache_config(payload)
        self._assert_config(resolved, True)

    def test_invalid_type_raises(self, utils):
        """A string input is rejected with ValueError."""
        with pytest.raises(ValueError, match="data_cache_config must be a dict"):
            utils._resolve_data_cache_config("invalid")

    def test_invalid_type_int_raises(self, utils):
        """An integer input is rejected with ValueError."""
        with pytest.raises(ValueError, match="data_cache_config must be a dict"):
            utils._resolve_data_cache_config(42)

    def test_invalid_type_list_raises(self, utils):
        """A list input is rejected with ValueError."""
        with pytest.raises(ValueError, match="data_cache_config must be a dict"):
            utils._resolve_data_cache_config([True])


# ============================================================
# Tests for _resolve_container_spec
# ============================================================

class TestResolveContainerSpec:
    """Cover ``_resolve_container_spec`` for None, typed, dict and invalid inputs."""

    @staticmethod
    def _resolve_dict(utils, payload):
        """Resolve ``payload`` and check the result is a typed container spec."""
        resolved = utils._resolve_container_spec(payload)
        assert isinstance(resolved, InferenceComponentContainerSpecification)
        return resolved

    def test_none_returns_none(self, utils):
        """``None`` resolves to ``None``."""
        resolved = utils._resolve_container_spec(None)
        assert resolved is None

    def test_already_typed_passthrough(self, utils):
        """A typed spec is returned as the very same object."""
        original = InferenceComponentContainerSpecification(
            image="my-image:latest",
            artifact_url="s3://bucket/artifact",
            environment={"KEY": "VALUE"},
        )
        resolved = utils._resolve_container_spec(original)
        assert resolved is original

    def test_dict_full(self, utils):
        """All three recognised keys are carried onto the typed spec."""
        payload = {
            "image": "my-image:latest",
            "artifact_url": "s3://bucket/artifact",
            "environment": {"KEY": "VALUE"},
        }
        resolved = self._resolve_dict(utils, payload)
        assert resolved.image == "my-image:latest"
        assert resolved.artifact_url == "s3://bucket/artifact"
        assert resolved.environment == {"KEY": "VALUE"}

    def test_dict_image_only(self, utils):
        """A dict holding only ``image`` resolves with that image."""
        resolved = self._resolve_dict(utils, {"image": "my-image:latest"})
        assert resolved.image == "my-image:latest"

    def test_dict_artifact_url_only(self, utils):
        """A dict holding only ``artifact_url`` resolves with that URL."""
        resolved = self._resolve_dict(utils, {"artifact_url": "s3://bucket/model.tar.gz"})
        assert resolved.artifact_url == "s3://bucket/model.tar.gz"

    def test_dict_environment_only(self, utils):
        """A dict holding only ``environment`` resolves with that mapping."""
        resolved = self._resolve_dict(utils, {"environment": {"A": "B"}})
        assert resolved.environment == {"A": "B"}

    def test_dict_empty(self, utils):
        """An empty dict still yields a typed spec."""
        self._resolve_dict(utils, {})

    def test_dict_with_extra_keys(self, utils):
        """Unrecognised keys alongside ``image`` do not cause a failure."""
        payload = {
            "image": "img",
            "unknown_key": "ignored",
        }
        resolved = self._resolve_dict(utils, payload)
        assert resolved.image == "img"

    def test_invalid_type_raises(self, utils):
        """A string input is rejected with ValueError."""
        with pytest.raises(ValueError, match="container must be a dict"):
            utils._resolve_container_spec("invalid")

    def test_invalid_type_int_raises(self, utils):
        """An integer input is rejected with ValueError."""
        with pytest.raises(ValueError, match="container must be a dict"):
            utils._resolve_container_spec(123)

    def test_invalid_type_list_raises(self, utils):
        """A list input is rejected with ValueError."""
        with pytest.raises(ValueError, match="container must be a dict"):
            utils._resolve_container_spec([{"image": "img"}])
Loading