-
Notifications
You must be signed in to change notification settings - Fork 1.2k
fix: ModelBuilder.deploy() should expose DataCacheConfig and other CreateInferenceComponent parameters (#5750) #5753
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
05e59de
8bc8db3
ccc3425
869474a
f865a27
afcad51
7272444
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
aviruthen marked this conversation as resolved.
aviruthen marked this conversation as resolved.
aviruthen marked this conversation as resolved.
aviruthen marked this conversation as resolved.
aviruthen marked this conversation as resolved.
aviruthen marked this conversation as resolved.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Two fixes needed:
# Replace:
kwargs["variant_name"] = variant_name or "AllTraffic"
# With:
if variant_name is not None:
    kwargs["variant_name"] = variant_name

Each downstream path already has its own default — _deploy_core_endpoint defaults to "AllTraffic" via kwargs.get("variant_name", "AllTraffic"), and _deploy_model_customization defaults to endpoint_name via variant_name or endpoint_name or "AllTraffic".
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Do not worry about CI failures! Removing the second integ test will fix one failure and the other failures are due to flakiness |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,6 +45,8 @@ | |
| ModelLifeCycle, | ||
| DriftCheckBaselines, | ||
| InferenceComponentComputeResourceRequirements, | ||
| InferenceComponentDataCacheConfig, | ||
|
aviruthen marked this conversation as resolved.
|
||
| InferenceComponentContainerSpecification, | ||
| ) | ||
| from sagemaker.core.resources import ( | ||
| ModelPackage, | ||
|
|
@@ -2978,18 +2980,49 @@ def _deploy_core_endpoint(self, **kwargs): | |
| "StartupParameters": startup_parameters, | ||
|
aviruthen marked this conversation as resolved.
|
||
| "ComputeResourceRequirements": resources.get_compute_resource_requirements(), | ||
| } | ||
|
|
||
| # Wire optional IC-level parameters into the specification | ||
| ic_data_cache_config = kwargs.get("data_cache_config") | ||
| if ic_data_cache_config is not None: | ||
| resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config) | ||
|
aviruthen marked this conversation as resolved.
Outdated
|
||
| if resolved_cache_config is not None: | ||
|
aviruthen marked this conversation as resolved.
Outdated
|
||
| cache_dict = {"EnableCaching": resolved_cache_config.enable_caching} | ||
| # Forward any additional fields from the shape as they become available | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
| inference_component_spec["DataCacheConfig"] = cache_dict | ||
|
|
||
| ic_base_component_name = kwargs.get("base_inference_component_name") | ||
|
aviruthen marked this conversation as resolved.
Outdated
|
||
| if ic_base_component_name is not None: | ||
| inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name | ||
|
|
||
| ic_container = kwargs.get("container") | ||
| if ic_container is not None: | ||
| resolved_container = self._resolve_container_spec(ic_container) | ||
| if resolved_container is not None: | ||
| container_dict = {} | ||
| if resolved_container.image: | ||
| container_dict["Image"] = resolved_container.image | ||
| if resolved_container.artifact_url: | ||
| container_dict["ArtifactUrl"] = resolved_container.artifact_url | ||
|
aviruthen marked this conversation as resolved.
Outdated
|
||
| if resolved_container.environment: | ||
| container_dict["Environment"] = resolved_container.environment | ||
| if container_dict: | ||
|
aviruthen marked this conversation as resolved.
Outdated
|
||
| inference_component_spec["Container"] = container_dict | ||
|
|
||
| runtime_config = {"CopyCount": resources.copy_count} | ||
| self.inference_component_name = ( | ||
| inference_component_name | ||
| or self.inference_component_name | ||
| or unique_name_from_base(self.model_name) | ||
| ) | ||
|
|
||
| # Use user-provided variant_name or default to "AllTraffic" | ||
| ic_variant_name = kwargs.get("variant_name", "AllTraffic") | ||
|
|
||
| # [TODO]: Add endpoint_logging support | ||
| self.sagemaker_session.create_inference_component( | ||
| inference_component_name=self.inference_component_name, | ||
| endpoint_name=self.endpoint_name, | ||
| variant_name="AllTraffic", # default variant name | ||
| variant_name=ic_variant_name, | ||
| specification=inference_component_spec, | ||
| runtime_config=runtime_config, | ||
| tags=tags, | ||
|
|
@@ -4127,6 +4160,10 @@ def deploy( | |
| ] = None, | ||
| custom_orchestrator_instance_type: str = None, | ||
| custom_orchestrator_initial_instance_count: int = None, | ||
| data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None, | ||
|
aviruthen marked this conversation as resolved.
aviruthen marked this conversation as resolved.
aviruthen marked this conversation as resolved.
|
||
| base_inference_component_name: Optional[str] = None, | ||
| container: Optional[Union["InferenceComponentContainerSpecification", Dict[str, Any]]] = None, | ||
|
aviruthen marked this conversation as resolved.
aviruthen marked this conversation as resolved.
|
||
| variant_name: Optional[str] = None, | ||
| **kwargs, | ||
| ) -> Union[Endpoint, LocalEndpoint, Transformer]: | ||
| """Deploy the built model to an ``Endpoint``. | ||
|
|
@@ -4160,6 +4197,22 @@ def deploy( | |
| orchestrator deployment. (Default: None). | ||
|
aviruthen marked this conversation as resolved.
|
||
| custom_orchestrator_initial_instance_count (int, optional): Initial instance count | ||
| for custom orchestrator deployment. (Default: None). | ||
| data_cache_config (Union[InferenceComponentDataCacheConfig, dict], optional): | ||
| Data cache configuration for the inference component. Enables caching of model | ||
| artifacts and container images on instances for faster auto-scaling cold starts. | ||
| Can be a dict with 'enable_caching' key (e.g., {'enable_caching': True}) or an | ||
| InferenceComponentDataCacheConfig instance. (Default: None). | ||
| base_inference_component_name (str, optional): Name of the base inference component | ||
| for adapter deployments (e.g., LoRA adapters attached to a base model). | ||
| (Default: None). | ||
| container (Union[InferenceComponentContainerSpecification, dict], optional): | ||
| Custom container specification for the inference component, including image URI, | ||
| artifact URL, and environment variables. Can be a dict with keys 'image', | ||
| 'artifact_url', 'environment' or an InferenceComponentContainerSpecification | ||
| instance. (Default: None). | ||
| variant_name (str, optional): The name of the production variant to deploy to. | ||
| If not provided (or explicitly ``None``), defaults to ``'AllTraffic'``. | ||
|
aviruthen marked this conversation as resolved.
|
||
| (Default: None). | ||
| Returns: | ||
| Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint`` | ||
| resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode, | ||
|
|
@@ -4182,6 +4235,16 @@ def deploy( | |
| if not hasattr(self, "built_model") and not hasattr(self, "_deployables"): | ||
|
aviruthen marked this conversation as resolved.
aviruthen marked this conversation as resolved.
|
||
| raise ValueError("Model needs to be built before deploying") | ||
|
|
||
| # Store IC-level parameters for use in _deploy_core_endpoint | ||
| if data_cache_config is not None: | ||
| kwargs["data_cache_config"] = data_cache_config | ||
| if base_inference_component_name is not None: | ||
| kwargs["base_inference_component_name"] = base_inference_component_name | ||
| if container is not None: | ||
| kwargs["container"] = container | ||
|
aviruthen marked this conversation as resolved.
|
||
| if variant_name is not None: | ||
| kwargs["variant_name"] = variant_name | ||
|
|
||
| # Handle model customization deployment | ||
| if self._is_model_customization(): | ||
| logger.info("Deploying Model Customization model") | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.