Skip to content

Commit 869474a

Browse files
committed
fix: address review comments (iteration #3)
1 parent ccc3425 commit 869474a

File tree

3 files changed

+113
-83
lines changed

3 files changed

+113
-83
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 60 additions & 69 deletions
Original file line number · Diff line number · Diff line change
@@ -2851,62 +2851,6 @@ def _deploy_core_endpoint(self, **kwargs):
28512851
if self.role_arn is None:
28522852
raise ValueError("Role can not be null for deploying a model")
28532853

2854-
routing_config = _resolve_routing_config(routing_config)
2855-
2856-
if (
2857-
inference_recommendation_id is not None
2858-
or self.inference_recommender_job_results is not None
2859-
):
2860-
instance_type, initial_instance_count = self._update_params(
2861-
instance_type=instance_type,
2862-
initial_instance_count=initial_instance_count,
2863-
accelerator_type=accelerator_type,
2864-
async_inference_config=async_inference_config,
2865-
serverless_inference_config=serverless_inference_config,
2866-
explainer_config=explainer_config,
2867-
inference_recommendation_id=inference_recommendation_id,
2868-
inference_recommender_job_results=self.inference_recommender_job_results,
2869-
)
2870-
2871-
is_async = async_inference_config is not None
2872-
if is_async and not isinstance(async_inference_config, AsyncInferenceConfig):
2873-
raise ValueError("async_inference_config needs to be a AsyncInferenceConfig object")
2874-
2875-
is_explainer_enabled = explainer_config is not None
2876-
if is_explainer_enabled and not isinstance(explainer_config, ExplainerConfig):
2877-
raise ValueError("explainer_config needs to be a ExplainerConfig object")
2878-
2879-
is_serverless = serverless_inference_config is not None
2880-
if not is_serverless and not (instance_type and initial_instance_count):
2881-
raise ValueError(
2882-
"Must specify instance type and instance count unless using serverless inference"
2883-
)
2884-
2885-
if is_serverless and not isinstance(serverless_inference_config, ServerlessInferenceConfig):
2886-
raise ValueError(
2887-
"serverless_inference_config needs to be a ServerlessInferenceConfig object"
2888-
)
2889-
2890-
if self._is_sharded_model:
2891-
if endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
2892-
logger.warning(
2893-
"Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
2894-
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
2895-
)
2896-
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
2897-
2898-
if self._enable_network_isolation:
2899-
raise ValueError(
2900-
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
2901-
"Loading of model requires network access."
2902-
)
2903-
2904-
if resources and resources.num_cpus and resources.num_cpus > 0:
2905-
logger.warning(
2906-
"NumberOfCpuCoresRequired should be 0 for the best experience with SageMaker "
2907-
"Fast Model Loading. Configure by setting `num_cpus` to 0 in `resources`."
2908-
)
2909-
29102854
if endpoint_type == EndpointType.INFERENCE_COMPONENT_BASED:
29112855
if update_endpoint:
29122856
raise ValueError(
@@ -2933,10 +2877,14 @@ def _deploy_core_endpoint(self, **kwargs):
29332877
else:
29342878
managed_instance_scaling_config["MinInstanceCount"] = initial_instance_count
29352879

2880+
# Use user-provided variant_name or default to "AllTraffic"
2881+
ic_variant_name = kwargs.get("variant_name", "AllTraffic")
2882+
29362883
if not self.sagemaker_session.endpoint_in_service_or_not(self.endpoint_name):
29372884
production_variant = session_helper.production_variant(
29382885
instance_type=instance_type,
29392886
initial_instance_count=initial_instance_count,
2887+
variant_name=ic_variant_name,
29402888
volume_size=volume_size,
29412889
model_data_download_timeout=model_data_download_timeout,
29422890
container_startup_health_check_timeout=container_startup_health_check_timeout,
@@ -2986,9 +2934,9 @@ def _deploy_core_endpoint(self, **kwargs):
29862934
if ic_data_cache_config is not None:
29872935
resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config)
29882936
if resolved_cache_config is not None:
2989-
cache_dict = {"EnableCaching": resolved_cache_config.enable_caching}
2990-
# Forward any additional fields from the shape as they become available
2991-
inference_component_spec["DataCacheConfig"] = cache_dict
2937+
inference_component_spec["DataCacheConfig"] = {
2938+
"EnableCaching": resolved_cache_config.enable_caching
2939+
}
29922940

29932941
ic_base_component_name = kwargs.get("base_inference_component_name")
29942942
if ic_base_component_name is not None:
@@ -3015,9 +2963,6 @@ def _deploy_core_endpoint(self, **kwargs):
30152963
or unique_name_from_base(self.model_name)
30162964
)
30172965

3018-
# Use user-provided variant_name or default to "AllTraffic"
3019-
ic_variant_name = kwargs.get("variant_name", "AllTraffic")
3020-
30212966
# [TODO]: Add endpoint_logging support
30222967
self.sagemaker_session.create_inference_component(
30232968
inference_component_name=self.inference_component_name,
@@ -3201,6 +3146,34 @@ def _update_inference_component(
32013146
"StartupParameters": startup_parameters,
32023147
"ComputeResourceRequirements": compute_rr,
32033148
}
3149+
3150+
# Wire optional IC-level parameters into the update specification
3151+
ic_data_cache_config = kwargs.get("data_cache_config")
3152+
if ic_data_cache_config is not None:
3153+
resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config)
3154+
if resolved_cache_config is not None:
3155+
inference_component_spec["DataCacheConfig"] = {
3156+
"EnableCaching": resolved_cache_config.enable_caching
3157+
}
3158+
3159+
ic_base_component_name = kwargs.get("base_inference_component_name")
3160+
if ic_base_component_name is not None:
3161+
inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name
3162+
3163+
ic_container = kwargs.get("container")
3164+
if ic_container is not None:
3165+
resolved_container = self._resolve_container_spec(ic_container)
3166+
if resolved_container is not None:
3167+
container_dict = {}
3168+
if resolved_container.image:
3169+
container_dict["Image"] = resolved_container.image
3170+
if resolved_container.artifact_url:
3171+
container_dict["ArtifactUrl"] = resolved_container.artifact_url
3172+
if resolved_container.environment:
3173+
container_dict["Environment"] = resolved_container.environment
3174+
if container_dict:
3175+
inference_component_spec["Container"] = container_dict
3176+
32043177
runtime_config = {"CopyCount": resource_requirements.copy_count}
32053178

32063179
return self.sagemaker_session.update_inference_component(
@@ -4160,6 +4133,7 @@ def deploy(
41604133
] = None,
41614134
custom_orchestrator_instance_type: str = None,
41624135
custom_orchestrator_initial_instance_count: int = None,
4136+
inference_component_name: Optional[str] = None,
41634137
data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None,
41644138
base_inference_component_name: Optional[str] = None,
41654139
container: Optional[Union["InferenceComponentContainerSpecification", Dict[str, Any]]] = None,
@@ -4197,6 +4171,9 @@ def deploy(
41974171
orchestrator deployment. (Default: None).
41984172
custom_orchestrator_initial_instance_count (int, optional): Initial instance count
41994173
for custom orchestrator deployment. (Default: None).
4174+
inference_component_name (str, optional): The name of the inference component
4175+
to create. Only used for inference-component-based endpoints. If not specified,
4176+
a unique name is generated from the model name. (Default: None).
42004177
data_cache_config (Union[InferenceComponentDataCacheConfig, dict], optional):
42014178
Data cache configuration for the inference component. Enables caching of model
42024179
artifacts and container images on instances for faster auto-scaling cold starts.
@@ -4213,6 +4190,7 @@ def deploy(
42134190
variant_name (str, optional): The name of the production variant to deploy to.
42144191
If not provided (or explicitly ``None``), defaults to ``'AllTraffic'``.
42154192
(Default: None).
4193+
42164194
Returns:
42174195
Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint``
42184196
resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode,
@@ -4235,15 +4213,16 @@ def deploy(
42354213
if not hasattr(self, "built_model") and not hasattr(self, "_deployables"):
42364214
raise ValueError("Model needs to be built before deploying")
42374215

4238-
# Store IC-level parameters for use in _deploy_core_endpoint
4216+
# Centralize variant_name defaulting and always forward IC-level params
4217+
kwargs["variant_name"] = variant_name or "AllTraffic"
4218+
if inference_component_name is not None:
4219+
kwargs["inference_component_name"] = inference_component_name
42394220
if data_cache_config is not None:
42404221
kwargs["data_cache_config"] = data_cache_config
42414222
if base_inference_component_name is not None:
42424223
kwargs["base_inference_component_name"] = base_inference_component_name
42434224
if container is not None:
42444225
kwargs["container"] = container
4245-
if variant_name is not None:
4246-
kwargs["variant_name"] = variant_name
42474226

42484227
# Handle model customization deployment
42494228
if self._is_model_customization():
@@ -4401,6 +4380,8 @@ def _deploy_model_customization(
44014380
initial_instance_count: int = 1,
44024381
inference_component_name: Optional[str] = None,
44034382
inference_config: Optional[ResourceRequirements] = None,
4383+
variant_name: Optional[str] = None,
4384+
data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None,
44044385
**kwargs,
44054386
) -> Endpoint:
44064387
"""Deploy a model customization (fine-tuned) model to an endpoint with inference components.
@@ -4442,6 +4423,14 @@ def _deploy_model_customization(
44424423
# Fetch model package
44434424
model_package = self._fetch_model_package()
44444425

4426+
# Resolve variant_name: use provided value or default to "AllTraffic"
4427+
effective_variant_name = variant_name or "AllTraffic"
4428+
4429+
# Resolve data_cache_config if provided
4430+
resolved_data_cache_config = None
4431+
if data_cache_config is not None:
4432+
resolved_data_cache_config = self._resolve_data_cache_config(data_cache_config)
4433+
44454434
# Check if endpoint exists
44464435
is_existing_endpoint = self._does_endpoint_exist(endpoint_name)
44474436

@@ -4450,7 +4439,7 @@ def _deploy_model_customization(
44504439
endpoint_config_name=endpoint_name,
44514440
production_variants=[
44524441
ProductionVariant(
4453-
variant_name=endpoint_name,
4442+
variant_name=effective_variant_name,
44544443
instance_type=self.instance_type,
44554444
initial_instance_count=initial_instance_count or 1,
44564445
)
@@ -4491,6 +4480,7 @@ def _deploy_model_customization(
44914480

44924481
base_ic_spec = InferenceComponentSpecification(
44934482
model_name=self.built_model.model_name,
4483+
data_cache_config=resolved_data_cache_config,
44944484
)
44954485
if inference_config is not None:
44964486
base_ic_spec.compute_resource_requirements = (
@@ -4507,7 +4497,7 @@ def _deploy_model_customization(
45074497
InferenceComponent.create(
45084498
inference_component_name=base_ic_name,
45094499
endpoint_name=endpoint_name,
4510-
variant_name=endpoint_name,
4500+
variant_name=effective_variant_name,
45114501
specification=base_ic_spec,
45124502
runtime_config=InferenceComponentRuntimeConfig(copy_count=1),
45134503
tags=[{"key": "Base", "value": base_model_recipe_name}],
@@ -4549,7 +4539,8 @@ def _deploy_model_customization(
45494539
ic_spec = InferenceComponentSpecification(
45504540
container=InferenceComponentContainerSpecification(
45514541
image=self.image_uri, artifact_url=artifact_url, environment=self.env_vars
4552-
)
4542+
),
4543+
data_cache_config=resolved_data_cache_config,
45534544
)
45544545

45554546
if inference_config is not None:
@@ -4567,7 +4558,7 @@ def _deploy_model_customization(
45674558
InferenceComponent.create(
45684559
inference_component_name=inference_component_name,
45694560
endpoint_name=endpoint_name,
4570-
variant_name=endpoint_name,
4561+
variant_name=effective_variant_name,
45714562
specification=ic_spec,
45724563
runtime_config=InferenceComponentRuntimeConfig(copy_count=1),
45734564
)

sagemaker-serve/src/sagemaker/serve/model_builder_utils.py

Lines changed: 7 additions & 3 deletions
Original file line number · Diff line number · Diff line change
@@ -78,12 +78,12 @@ def build(self):
7878
from sagemaker.serve.utils.hardware_detector import _total_inference_model_size_mib
7979
from sagemaker.serve.utils.types import ModelServer
8080
from sagemaker.core.resources import Model
81-
82-
# MLflow imports
8381
from sagemaker.core.shapes import (
8482
InferenceComponentDataCacheConfig,
8583
InferenceComponentContainerSpecification,
8684
)
85+
86+
# MLflow imports
8787
from sagemaker.serve.model_format.mlflow.constants import (
8888
MLFLOW_METADATA_FILE,
8989
MLFLOW_MODEL_PATH,
@@ -3380,7 +3380,8 @@ def _resolve_data_cache_config(
33803380
"""Resolve data_cache_config to InferenceComponentDataCacheConfig.
33813381
33823382
Args:
3383-
data_cache_config: Either a dict with 'enable_caching' key,
3383+
data_cache_config: Either a dict with 'enable_caching' key (and any future
3384+
fields supported by InferenceComponentDataCacheConfig),
33843385
an InferenceComponentDataCacheConfig instance, or None.
33853386
33863387
Returns:
@@ -3401,6 +3402,9 @@ def _resolve_data_cache_config(
34013402
"data_cache_config dict must contain the required 'enable_caching' key. "
34023403
"Example: {'enable_caching': True}"
34033404
)
3405+
# Pass only 'enable_caching' to avoid Pydantic validation errors
3406+
# if the model has extra='forbid'. As new fields are added to
3407+
# InferenceComponentDataCacheConfig, add them here.
34043408
return InferenceComponentDataCacheConfig(
34053409
enable_caching=data_cache_config["enable_caching"]
34063410
)

0 commit comments

Comments (0)