65 changes: 64 additions & 1 deletion sagemaker-serve/src/sagemaker/serve/model_builder.py
Collaborator
Two fixes needed:

  1. Bug: variant_name always overrides model customization default. In deploy(), kwargs["variant_name"] = variant_name or "AllTraffic" always sets the key, so _deploy_model_customization never sees None and its backward-compat default of endpoint_name is dead code. Fix: only forward variant_name when explicitly provided:
# Replace:
kwargs["variant_name"] = variant_name or "AllTraffic"
# With:
if variant_name is not None:
    kwargs["variant_name"] = variant_name

Each downstream path already has its own default — _deploy_core_endpoint defaults to "AllTraffic" via kwargs.get("variant_name", "AllTraffic"), and _deploy_model_customization defaults to endpoint_name via variant_name or endpoint_name or "AllTraffic".

  2. Drop the second integ test (test_deploy_with_data_cache_config_via_model_customization_path). The model customization path requires ml.g5.4xlarge, which has a non-adjustable account quota of 2 instances. When CI runs tests in parallel, this test competes with the existing test_model_customization_deployment.py for the same quota, causing flaky InsufficientInstanceCapacity failures. The model customization path's data_cache_config and variant_name wiring is already covered by unit tests. Keep only the first integ test (test_deploy_with_data_cache_config_and_variant_name_via_ic_path), which uses ml.g5.2xlarge.
    Also remove the TRAINING_JOB_NAME constant since it's no longer needed.
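The precedence behavior behind fix 1 can be sketched with stand-in functions (hypothetical names, not the real SDK code) to show why the unconditional assignment makes the downstream fallback dead code:

```python
# Sketch of the variant_name precedence bug; stand-in functions only.

def customization_default(kwargs, endpoint_name):
    # Mirrors _deploy_model_customization's backward-compat fallback chain:
    # explicit variant_name, else endpoint_name, else "AllTraffic".
    return kwargs.get("variant_name") or endpoint_name or "AllTraffic"

def deploy_buggy(variant_name=None, endpoint_name="my-endpoint"):
    kwargs = {}
    kwargs["variant_name"] = variant_name or "AllTraffic"  # always sets the key
    return customization_default(kwargs, endpoint_name)

def deploy_fixed(variant_name=None, endpoint_name="my-endpoint"):
    kwargs = {}
    if variant_name is not None:  # forward only when explicitly provided
        kwargs["variant_name"] = variant_name
    return customization_default(kwargs, endpoint_name)

# Buggy: the endpoint_name fallback can never fire.
assert deploy_buggy() == "AllTraffic"
# Fixed: the customization path falls back to endpoint_name.
assert deploy_fixed() == "my-endpoint"
# An explicit value still wins in both versions.
assert deploy_fixed(variant_name="v1") == "v1"
```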

Collaborator

Do not worry about the CI failures! Removing the second integ test will fix one of them; the remaining failures are due to flakiness.

@@ -45,6 +45,8 @@
ModelLifeCycle,
DriftCheckBaselines,
InferenceComponentComputeResourceRequirements,
InferenceComponentDataCacheConfig,
InferenceComponentContainerSpecification,
)
from sagemaker.core.resources import (
ModelPackage,
@@ -2978,18 +2980,49 @@ def _deploy_core_endpoint(self, **kwargs):
"StartupParameters": startup_parameters,
"ComputeResourceRequirements": resources.get_compute_resource_requirements(),
}

# Wire optional IC-level parameters into the specification
ic_data_cache_config = kwargs.get("data_cache_config")
if ic_data_cache_config is not None:
resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config)
if resolved_cache_config is not None:
cache_dict = {"EnableCaching": resolved_cache_config.enable_caching}
# Forward any additional fields from the shape as they become available
Collaborator Author

The DataCacheConfig is being manually serialized to a dict ({"EnableCaching": ...}), but the spec dict already uses PascalCase API keys. Consider whether create_inference_component expects the Pydantic shape object directly (as sagemaker-core typically handles serialization) rather than a manually constructed dict. If the session method handles serialization, passing the resolved InferenceComponentDataCacheConfig object directly would be more robust and future-proof as new fields are added to the shape.
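To illustrate the concern with stand-in classes (not the actual sagemaker-core shapes or session API): if the session layer serializes the shape object itself, new fields propagate to the API dict without touching this call site, whereas a hand-built dict silently drops them:

```python
from dataclasses import dataclass, asdict
from typing import Optional

# Stand-in for a sagemaker-core shape; not the real class.
@dataclass
class DataCacheConfig:
    enable_caching: bool
    cache_ttl_seconds: Optional[int] = None  # hypothetical future field

def to_api_dict(shape) -> dict:
    # Generic snake_case -> PascalCase serialization, as a shape-aware
    # session layer might perform; drops unset fields.
    def pascal(name: str) -> str:
        return "".join(part.capitalize() for part in name.split("_"))
    return {pascal(k): v for k, v in asdict(shape).items() if v is not None}

cfg = DataCacheConfig(enable_caching=True, cache_ttl_seconds=300)

# Hand-built dict: the new field is silently dropped.
manual = {"EnableCaching": cfg.enable_caching}
# Shape-aware serialization: all populated fields flow through.
auto = to_api_dict(cfg)

assert manual == {"EnableCaching": True}
assert auto == {"EnableCaching": True, "CacheTtlSeconds": 300}
```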

inference_component_spec["DataCacheConfig"] = cache_dict

ic_base_component_name = kwargs.get("base_inference_component_name")
if ic_base_component_name is not None:
inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name

ic_container = kwargs.get("container")
if ic_container is not None:
resolved_container = self._resolve_container_spec(ic_container)
if resolved_container is not None:
container_dict = {}
if resolved_container.image:
container_dict["Image"] = resolved_container.image
if resolved_container.artifact_url:
container_dict["ArtifactUrl"] = resolved_container.artifact_url
if resolved_container.environment:
container_dict["Environment"] = resolved_container.environment
if container_dict:
inference_component_spec["Container"] = container_dict

runtime_config = {"CopyCount": resources.copy_count}
self.inference_component_name = (
inference_component_name
or self.inference_component_name
or unique_name_from_base(self.model_name)
)

# Use user-provided variant_name or default to "AllTraffic"
ic_variant_name = kwargs.get("variant_name", "AllTraffic")

# [TODO]: Add endpoint_logging support
self.sagemaker_session.create_inference_component(
inference_component_name=self.inference_component_name,
endpoint_name=self.endpoint_name,
variant_name="AllTraffic", # default variant name
variant_name=ic_variant_name,
specification=inference_component_spec,
runtime_config=runtime_config,
tags=tags,
@@ -4127,6 +4160,10 @@ def deploy(
] = None,
custom_orchestrator_instance_type: str = None,
custom_orchestrator_initial_instance_count: int = None,
data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None,
base_inference_component_name: Optional[str] = None,
container: Optional[Union["InferenceComponentContainerSpecification", Dict[str, Any]]] = None,
variant_name: Optional[str] = None,
**kwargs,
) -> Union[Endpoint, LocalEndpoint, Transformer]:
"""Deploy the built model to an ``Endpoint``.
@@ -4160,6 +4197,22 @@
orchestrator deployment. (Default: None).
custom_orchestrator_initial_instance_count (int, optional): Initial instance count
for custom orchestrator deployment. (Default: None).
data_cache_config (Union[InferenceComponentDataCacheConfig, dict], optional):
Data cache configuration for the inference component. Enables caching of model
artifacts and container images on instances for faster auto-scaling cold starts.
Can be a dict with 'enable_caching' key (e.g., {'enable_caching': True}) or an
InferenceComponentDataCacheConfig instance. (Default: None).
base_inference_component_name (str, optional): Name of the base inference component
for adapter deployments (e.g., LoRA adapters attached to a base model).
(Default: None).
container (Union[InferenceComponentContainerSpecification, dict], optional):
Custom container specification for the inference component, including image URI,
artifact URL, and environment variables. Can be a dict with keys 'image',
'artifact_url', 'environment' or an InferenceComponentContainerSpecification
instance. (Default: None).
variant_name (str, optional): The name of the production variant to deploy to.
If not provided (or explicitly ``None``), defaults to ``'AllTraffic'``.
(Default: None).
Returns:
Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint``
resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode,
@@ -4182,6 +4235,16 @@
if not hasattr(self, "built_model") and not hasattr(self, "_deployables"):
raise ValueError("Model needs to be built before deploying")

# Store IC-level parameters for use in _deploy_core_endpoint
if data_cache_config is not None:
kwargs["data_cache_config"] = data_cache_config
if base_inference_component_name is not None:
kwargs["base_inference_component_name"] = base_inference_component_name
if container is not None:
kwargs["container"] = container
if variant_name is not None:
kwargs["variant_name"] = variant_name

# Handle model customization deployment
if self._is_model_customization():
logger.info("Deploying Model Customization model")
74 changes: 74 additions & 0 deletions sagemaker-serve/src/sagemaker/serve/model_builder_utils.py
@@ -80,6 +80,10 @@ def build(self):
from sagemaker.core.resources import Model

# MLflow imports
from sagemaker.core.shapes import (
InferenceComponentDataCacheConfig,
InferenceComponentContainerSpecification,
)
from sagemaker.serve.model_format.mlflow.constants import (
MLFLOW_METADATA_FILE,
MLFLOW_MODEL_PATH,
@@ -3369,6 +3373,76 @@ def _extract_speculative_draft_model_provider(

return "auto"

def _resolve_data_cache_config(
self,
data_cache_config: Union[InferenceComponentDataCacheConfig, Dict[str, Any], None],
) -> Optional[InferenceComponentDataCacheConfig]:
"""Resolve data_cache_config to InferenceComponentDataCacheConfig.

Args:
data_cache_config: Either a dict with 'enable_caching' key,
an InferenceComponentDataCacheConfig instance, or None.

Returns:
InferenceComponentDataCacheConfig or None.

Raises:
ValueError: If data_cache_config is an unsupported type or dict
is missing the required 'enable_caching' key.
"""
if data_cache_config is None:
return None

if isinstance(data_cache_config, InferenceComponentDataCacheConfig):
return data_cache_config
elif isinstance(data_cache_config, dict):
if "enable_caching" not in data_cache_config:
raise ValueError(
"data_cache_config dict must contain the required 'enable_caching' key. "
"Example: {'enable_caching': True}"
)
return InferenceComponentDataCacheConfig(
enable_caching=data_cache_config["enable_caching"]
)
else:
raise ValueError(
f"data_cache_config must be a dict with 'enable_caching' key or an "
f"InferenceComponentDataCacheConfig instance, got {type(data_cache_config)}"
)
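The resolver's contract can be exercised with a self-contained stand-in (a minimal dataclass in place of the sagemaker-core shape; not the real class):

```python
from dataclasses import dataclass

# Stand-in for sagemaker.core.shapes.InferenceComponentDataCacheConfig.
@dataclass
class CacheConfig:
    enable_caching: bool

def resolve_data_cache_config(data_cache_config):
    # Mirrors the resolution logic above: None passes through, shape
    # instances are returned as-is, dicts must carry 'enable_caching'.
    if data_cache_config is None:
        return None
    if isinstance(data_cache_config, CacheConfig):
        return data_cache_config
    if isinstance(data_cache_config, dict):
        if "enable_caching" not in data_cache_config:
            raise ValueError("data_cache_config dict must contain 'enable_caching'")
        return CacheConfig(enable_caching=data_cache_config["enable_caching"])
    raise ValueError(f"unsupported type: {type(data_cache_config)}")

assert resolve_data_cache_config(None) is None
assert resolve_data_cache_config({"enable_caching": True}).enable_caching is True
try:
    resolve_data_cache_config({"ttl": 300})
except ValueError:
    pass
else:
    raise AssertionError("dict without 'enable_caching' should raise")
```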

def _resolve_container_spec(
self,
container: Union[InferenceComponentContainerSpecification, Dict[str, Any], None],
) -> Optional[InferenceComponentContainerSpecification]:
"""Resolve container to InferenceComponentContainerSpecification.

Args:
container: Either a dict with container config keys (image, artifact_url,
environment), an InferenceComponentContainerSpecification instance, or None.

Returns:
InferenceComponentContainerSpecification or None.

Raises:
ValueError: If container is an unsupported type.
"""
if container is None:
return None

if isinstance(container, InferenceComponentContainerSpecification):
return container
elif isinstance(container, dict):
# Only pass known keys to avoid Pydantic validation errors
# if the model has extra='forbid' configured
known_keys = {"image", "artifact_url", "environment"}
filtered = {k: v for k, v in container.items() if k in known_keys}
return InferenceComponentContainerSpecification(**filtered)
else:
raise ValueError(
f"container must be a dict or an InferenceComponentContainerSpecification "
f"instance, got {type(container)}"
)
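The known-keys filtering can be checked with a stand-in spec class (again, not the real sagemaker-core shape): unknown dict keys are dropped rather than passed through, so a strict (extra='forbid') model would not reject the input.

```python
from dataclasses import dataclass
from typing import Optional

# Stand-in for InferenceComponentContainerSpecification.
@dataclass
class ContainerSpec:
    image: Optional[str] = None
    artifact_url: Optional[str] = None
    environment: Optional[dict] = None

def resolve_container_spec(container):
    # Mirrors the logic above: only known keys reach the constructor.
    if container is None:
        return None
    if isinstance(container, ContainerSpec):
        return container
    if isinstance(container, dict):
        known_keys = {"image", "artifact_url", "environment"}
        filtered = {k: v for k, v in container.items() if k in known_keys}
        return ContainerSpec(**filtered)
    raise ValueError(f"unsupported type: {type(container)}")

spec = resolve_container_spec(
    {"image": "my-image:latest", "typo_key": "ignored"}  # unknown key dropped
)
assert spec.image == "my-image:latest"
assert spec.environment is None
```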

def get_huggingface_model_metadata(
self, model_id: str, hf_hub_token: Optional[str] = None
) -> dict: