-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: ModelBuilder.deploy() should expose DataCacheConfig and other CreateInferenceComponent parameters (5750) #5753
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 2 commits
05e59de
8bc8db3
ccc3425
869474a
f865a27
afcad51
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add 1-2 integration tests that verify the new IC-level parameters (data_cache_config, variant_name) work end-to-end through ModelBuilder.deploy(). Place the new test file at sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py. Follow the exact patterns from the existing integ tests for structure, cleanup, and assertions. Test 1: Deploy with data_cache_config via the standard IC path (_deploy_core_endpoint). Use a JumpStart model (like the pattern in test_jumpstart_integration.py, ideally use the same JumpStart model) but deploy with inference_config=ResourceRequirements to trigger the IC-based path. Pass data_cache_config={"enable_caching": True} and a custom variant_name. After deployment, use boto3 sagemaker_client.describe_inference_component() to verify:
Test 2: Deploy with data_cache_config via the model customization path (_deploy_model_customization). Use a TrainingJob-based ModelBuilder (like test_model_customization_deployment.py lines 131-170) and pass data_cache_config={"enable_caching": True}. After deployment, describe the inference component and verify DataCacheConfig.EnableCaching is True. Also verify the variant_name defaults to endpoint_name (backward compat) when variant_name is not explicitly provided. Both tests should include proper cleanup (delete endpoint, endpoint config, model, inference components) in a finally block. Use unique names with uuid to avoid collisions. Mark both with @pytest.mark.slow_test. $context sagemaker-serve/tests/integ/test_jumpstart_integration.py |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,6 +45,8 @@ | |
| ModelLifeCycle, | ||
| DriftCheckBaselines, | ||
| InferenceComponentComputeResourceRequirements, | ||
| InferenceComponentDataCacheConfig, | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| InferenceComponentContainerSpecification, | ||
| ) | ||
| from sagemaker.core.resources import ( | ||
| ModelPackage, | ||
|
|
@@ -2978,18 +2980,49 @@ def _deploy_core_endpoint(self, **kwargs): | |
| "StartupParameters": startup_parameters, | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| "ComputeResourceRequirements": resources.get_compute_resource_requirements(), | ||
| } | ||
|
|
||
| # Wire optional IC-level parameters into the specification | ||
| ic_data_cache_config = kwargs.get("data_cache_config") | ||
| if ic_data_cache_config is not None: | ||
| resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config) | ||
aviruthen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if resolved_cache_config is not None: | ||
aviruthen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| inference_component_spec["DataCacheConfig"] = { | ||
| "EnableCaching": resolved_cache_config.enable_caching | ||
| } | ||
|
|
||
| ic_base_component_name = kwargs.get("base_inference_component_name") | ||
aviruthen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if ic_base_component_name is not None: | ||
| inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name | ||
|
|
||
| ic_container = kwargs.get("container") | ||
| if ic_container is not None: | ||
| resolved_container = self._resolve_container_spec(ic_container) | ||
| if resolved_container is not None: | ||
| container_dict = {} | ||
| if hasattr(resolved_container, "image") and resolved_container.image: | ||
| container_dict["Image"] = resolved_container.image | ||
| if hasattr(resolved_container, "artifact_url") and resolved_container.artifact_url: | ||
|
||
| container_dict["ArtifactUrl"] = resolved_container.artifact_url | ||
aviruthen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| if hasattr(resolved_container, "environment") and resolved_container.environment: | ||
| container_dict["Environment"] = resolved_container.environment | ||
| if container_dict: | ||
aviruthen marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| inference_component_spec["Container"] = container_dict | ||
|
|
||
| runtime_config = {"CopyCount": resources.copy_count} | ||
| self.inference_component_name = ( | ||
| inference_component_name | ||
| or self.inference_component_name | ||
| or unique_name_from_base(self.model_name) | ||
| ) | ||
|
|
||
| # Use user-provided variant_name or default to "AllTraffic" | ||
| ic_variant_name = kwargs.get("variant_name", "AllTraffic") | ||
|
|
||
| # [TODO]: Add endpoint_logging support | ||
| self.sagemaker_session.create_inference_component( | ||
| inference_component_name=self.inference_component_name, | ||
| endpoint_name=self.endpoint_name, | ||
| variant_name="AllTraffic", # default variant name | ||
| variant_name=ic_variant_name, | ||
| specification=inference_component_spec, | ||
| runtime_config=runtime_config, | ||
| tags=tags, | ||
|
|
@@ -4127,6 +4160,10 @@ def deploy( | |
| ] = None, | ||
| custom_orchestrator_instance_type: str = None, | ||
| custom_orchestrator_initial_instance_count: int = None, | ||
| data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None, | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| base_inference_component_name: Optional[str] = None, | ||
| container: Optional[Union["InferenceComponentContainerSpecification", Dict[str, Any]]] = None, | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| variant_name: Optional[str] = None, | ||
| **kwargs, | ||
| ) -> Union[Endpoint, LocalEndpoint, Transformer]: | ||
| """Deploy the built model to an ``Endpoint``. | ||
|
|
@@ -4160,6 +4197,21 @@ def deploy( | |
| orchestrator deployment. (Default: None). | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| custom_orchestrator_initial_instance_count (int, optional): Initial instance count | ||
| for custom orchestrator deployment. (Default: None). | ||
| data_cache_config (Union[InferenceComponentDataCacheConfig, dict], optional): | ||
| Data cache configuration for the inference component. Enables caching of model | ||
| artifacts and container images on instances for faster auto-scaling cold starts. | ||
| Can be a dict with 'enable_caching' key (e.g., {'enable_caching': True}) or an | ||
| InferenceComponentDataCacheConfig instance. (Default: None). | ||
| base_inference_component_name (str, optional): Name of the base inference component | ||
| for adapter deployments (e.g., LoRA adapters attached to a base model). | ||
| (Default: None). | ||
| container (Union[InferenceComponentContainerSpecification, dict], optional): | ||
| Custom container specification for the inference component, including image URI, | ||
| artifact URL, and environment variables. Can be a dict with keys 'image', | ||
| 'artifact_url', 'environment' or an InferenceComponentContainerSpecification | ||
| instance. (Default: None). | ||
| variant_name (str, optional): The name of the production variant to deploy to. | ||
| If not specified, defaults to 'AllTraffic'. (Default: None). | ||
| Returns: | ||
| Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint`` | ||
| resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode, | ||
|
|
@@ -4182,6 +4234,16 @@ def deploy( | |
| if not hasattr(self, "built_model") and not hasattr(self, "_deployables"): | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| raise ValueError("Model needs to be built before deploying") | ||
|
|
||
| # Store IC-level parameters for use in _deploy_core_endpoint | ||
| if data_cache_config is not None: | ||
| kwargs["data_cache_config"] = data_cache_config | ||
| if base_inference_component_name is not None: | ||
| kwargs["base_inference_component_name"] = base_inference_component_name | ||
| if container is not None: | ||
| kwargs["container"] = container | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if variant_name is not None: | ||
| kwargs["variant_name"] = variant_name | ||
|
|
||
| # Handle model customization deployment | ||
| if self._is_model_customization(): | ||
| logger.info("Deploying Model Customization model") | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,151 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"). You | ||
| # may not use this file except in compliance with the License. A copy of | ||
| # the License is located at | ||
| # | ||
| # http://aws.amazon.com/apache2.0/ | ||
| # | ||
| # or in the "license" file accompanying this file. This file is | ||
| # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
| # ANY KIND, either express or implied. See the License for the specific | ||
| # language governing permissions and limitations under the License. | ||
| """Unit tests for _resolve_data_cache_config and _resolve_container_spec.""" | ||
| from __future__ import absolute_import | ||
|
|
||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| import pytest | ||
|
|
||
| from sagemaker.core.shapes import ( | ||
| InferenceComponentDataCacheConfig, | ||
| InferenceComponentContainerSpecification, | ||
| ) | ||
| from sagemaker.serve.model_builder_utils import _ModelBuilderUtils | ||
|
|
||
|
|
||
| class ConcreteUtils(_ModelBuilderUtils): | ||
| """Concrete class to test mixin methods.""" | ||
| pass | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| @pytest.fixture | ||
| def utils(): | ||
| return ConcreteUtils() | ||
|
|
||
|
|
||
| # ============================================================ | ||
| # Tests for _resolve_data_cache_config | ||
| # ============================================================ | ||
|
|
||
| class TestResolveDataCacheConfig: | ||
| def test_none_returns_none(self, utils): | ||
| assert utils._resolve_data_cache_config(None) is None | ||
|
|
||
| def test_already_typed_passthrough(self, utils): | ||
| config = InferenceComponentDataCacheConfig(enable_caching=True) | ||
| result = utils._resolve_data_cache_config(config) | ||
| assert result is config | ||
| assert result.enable_caching is True | ||
|
|
||
| def test_dict_with_enable_caching_true(self, utils): | ||
| result = utils._resolve_data_cache_config({"enable_caching": True}) | ||
| assert isinstance(result, InferenceComponentDataCacheConfig) | ||
| assert result.enable_caching is True | ||
|
|
||
| def test_dict_with_enable_caching_false(self, utils): | ||
| result = utils._resolve_data_cache_config({"enable_caching": False}) | ||
| assert isinstance(result, InferenceComponentDataCacheConfig) | ||
| assert result.enable_caching is False | ||
|
|
||
| def test_dict_missing_enable_caching_raises(self, utils): | ||
| with pytest.raises(ValueError, match="must contain the required 'enable_caching' key"): | ||
| utils._resolve_data_cache_config({}) | ||
|
|
||
| def test_dict_with_extra_keys_still_works(self, utils): | ||
| """Extra keys are ignored; only enable_caching is required.""" | ||
| result = utils._resolve_data_cache_config( | ||
| {"enable_caching": True, "extra_key": "ignored"} | ||
| ) | ||
| assert isinstance(result, InferenceComponentDataCacheConfig) | ||
| assert result.enable_caching is True | ||
|
|
||
| def test_invalid_type_raises(self, utils): | ||
| with pytest.raises(ValueError, match="data_cache_config must be a dict"): | ||
| utils._resolve_data_cache_config("invalid") | ||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| def test_invalid_type_int_raises(self, utils): | ||
| with pytest.raises(ValueError, match="data_cache_config must be a dict"): | ||
| utils._resolve_data_cache_config(42) | ||
|
|
||
| def test_invalid_type_list_raises(self, utils): | ||
| with pytest.raises(ValueError, match="data_cache_config must be a dict"): | ||
| utils._resolve_data_cache_config([True]) | ||
|
|
||
|
|
||
| # ============================================================ | ||
| # Tests for _resolve_container_spec | ||
| # ============================================================ | ||
|
|
||
| class TestResolveContainerSpec: | ||
| def test_none_returns_none(self, utils): | ||
| assert utils._resolve_container_spec(None) is None | ||
|
|
||
| def test_already_typed_passthrough(self, utils): | ||
| spec = InferenceComponentContainerSpecification( | ||
| image="my-image:latest", | ||
| artifact_url="s3://bucket/artifact", | ||
| environment={"KEY": "VALUE"}, | ||
| ) | ||
| result = utils._resolve_container_spec(spec) | ||
| assert result is spec | ||
|
|
||
| def test_dict_full(self, utils): | ||
| result = utils._resolve_container_spec({ | ||
| "image": "my-image:latest", | ||
| "artifact_url": "s3://bucket/artifact", | ||
| "environment": {"KEY": "VALUE"}, | ||
| }) | ||
| assert isinstance(result, InferenceComponentContainerSpecification) | ||
| assert result.image == "my-image:latest" | ||
| assert result.artifact_url == "s3://bucket/artifact" | ||
| assert result.environment == {"KEY": "VALUE"} | ||
|
|
||
| def test_dict_image_only(self, utils): | ||
| result = utils._resolve_container_spec({"image": "my-image:latest"}) | ||
| assert isinstance(result, InferenceComponentContainerSpecification) | ||
| assert result.image == "my-image:latest" | ||
|
|
||
| def test_dict_artifact_url_only(self, utils): | ||
| result = utils._resolve_container_spec({"artifact_url": "s3://bucket/model.tar.gz"}) | ||
| assert isinstance(result, InferenceComponentContainerSpecification) | ||
| assert result.artifact_url == "s3://bucket/model.tar.gz" | ||
|
|
||
| def test_dict_environment_only(self, utils): | ||
| result = utils._resolve_container_spec({"environment": {"A": "B"}}) | ||
| assert isinstance(result, InferenceComponentContainerSpecification) | ||
| assert result.environment == {"A": "B"} | ||
|
|
||
| def test_dict_empty(self, utils): | ||
| """Empty dict creates a spec with no fields set.""" | ||
| result = utils._resolve_container_spec({}) | ||
| assert isinstance(result, InferenceComponentContainerSpecification) | ||
|
|
||
aviruthen marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| def test_dict_with_extra_keys(self, utils): | ||
| """Extra keys are ignored.""" | ||
| result = utils._resolve_container_spec({ | ||
| "image": "img", | ||
| "unknown_key": "ignored", | ||
| }) | ||
| assert isinstance(result, InferenceComponentContainerSpecification) | ||
| assert result.image == "img" | ||
|
|
||
| def test_invalid_type_raises(self, utils): | ||
| with pytest.raises(ValueError, match="container must be a dict"): | ||
| utils._resolve_container_spec("invalid") | ||
|
|
||
| def test_invalid_type_int_raises(self, utils): | ||
| with pytest.raises(ValueError, match="container must be a dict"): | ||
| utils._resolve_container_spec(123) | ||
|
|
||
| def test_invalid_type_list_raises(self, utils): | ||
| with pytest.raises(ValueError, match="container must be a dict"): | ||
| utils._resolve_container_spec([{"image": "img"}]) | ||
Uh oh!
There was an error while loading. Please reload this page.