@@ -2851,62 +2851,6 @@ def _deploy_core_endpoint(self, **kwargs):
28512851 if self .role_arn is None :
28522852 raise ValueError ("Role can not be null for deploying a model" )
28532853
2854- routing_config = _resolve_routing_config (routing_config )
2855-
2856- if (
2857- inference_recommendation_id is not None
2858- or self .inference_recommender_job_results is not None
2859- ):
2860- instance_type , initial_instance_count = self ._update_params (
2861- instance_type = instance_type ,
2862- initial_instance_count = initial_instance_count ,
2863- accelerator_type = accelerator_type ,
2864- async_inference_config = async_inference_config ,
2865- serverless_inference_config = serverless_inference_config ,
2866- explainer_config = explainer_config ,
2867- inference_recommendation_id = inference_recommendation_id ,
2868- inference_recommender_job_results = self .inference_recommender_job_results ,
2869- )
2870-
2871- is_async = async_inference_config is not None
2872- if is_async and not isinstance (async_inference_config , AsyncInferenceConfig ):
2873- raise ValueError ("async_inference_config needs to be a AsyncInferenceConfig object" )
2874-
2875- is_explainer_enabled = explainer_config is not None
2876- if is_explainer_enabled and not isinstance (explainer_config , ExplainerConfig ):
2877- raise ValueError ("explainer_config needs to be a ExplainerConfig object" )
2878-
2879- is_serverless = serverless_inference_config is not None
2880- if not is_serverless and not (instance_type and initial_instance_count ):
2881- raise ValueError (
2882- "Must specify instance type and instance count unless using serverless inference"
2883- )
2884-
2885- if is_serverless and not isinstance (serverless_inference_config , ServerlessInferenceConfig ):
2886- raise ValueError (
2887- "serverless_inference_config needs to be a ServerlessInferenceConfig object"
2888- )
2889-
2890- if self ._is_sharded_model :
2891- if endpoint_type != EndpointType .INFERENCE_COMPONENT_BASED :
2892- logger .warning (
2893- "Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
2894- "Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
2895- )
2896- endpoint_type = EndpointType .INFERENCE_COMPONENT_BASED
2897-
2898- if self ._enable_network_isolation :
2899- raise ValueError (
2900- "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
2901- "Loading of model requires network access."
2902- )
2903-
2904- if resources and resources .num_cpus and resources .num_cpus > 0 :
2905- logger .warning (
2906- "NumberOfCpuCoresRequired should be 0 for the best experience with SageMaker "
2907- "Fast Model Loading. Configure by setting `num_cpus` to 0 in `resources`."
2908- )
2909-
29102854 if endpoint_type == EndpointType .INFERENCE_COMPONENT_BASED :
29112855 if update_endpoint :
29122856 raise ValueError (
@@ -2933,10 +2877,14 @@ def _deploy_core_endpoint(self, **kwargs):
29332877 else :
29342878 managed_instance_scaling_config ["MinInstanceCount" ] = initial_instance_count
29352879
2880+ # Use user-provided variant_name or default to "AllTraffic"
2881+ ic_variant_name = kwargs .get ("variant_name" , "AllTraffic" )
2882+
29362883 if not self .sagemaker_session .endpoint_in_service_or_not (self .endpoint_name ):
29372884 production_variant = session_helper .production_variant (
29382885 instance_type = instance_type ,
29392886 initial_instance_count = initial_instance_count ,
2887+ variant_name = ic_variant_name ,
29402888 volume_size = volume_size ,
29412889 model_data_download_timeout = model_data_download_timeout ,
29422890 container_startup_health_check_timeout = container_startup_health_check_timeout ,
@@ -2986,9 +2934,9 @@ def _deploy_core_endpoint(self, **kwargs):
29862934 if ic_data_cache_config is not None :
29872935 resolved_cache_config = self ._resolve_data_cache_config (ic_data_cache_config )
29882936 if resolved_cache_config is not None :
2989- cache_dict = {"EnableCaching" : resolved_cache_config . enable_caching }
2990- # Forward any additional fields from the shape as they become available
2991- inference_component_spec [ "DataCacheConfig" ] = cache_dict
2937+ inference_component_spec [ "DataCacheConfig" ] = {
2938+ "EnableCaching" : resolved_cache_config . enable_caching
2939+ }
29922940
29932941 ic_base_component_name = kwargs .get ("base_inference_component_name" )
29942942 if ic_base_component_name is not None :
@@ -3015,9 +2963,6 @@ def _deploy_core_endpoint(self, **kwargs):
30152963 or unique_name_from_base (self .model_name )
30162964 )
30172965
3018- # Use user-provided variant_name or default to "AllTraffic"
3019- ic_variant_name = kwargs .get ("variant_name" , "AllTraffic" )
3020-
30212966 # [TODO]: Add endpoint_logging support
30222967 self .sagemaker_session .create_inference_component (
30232968 inference_component_name = self .inference_component_name ,
@@ -3201,6 +3146,34 @@ def _update_inference_component(
32013146 "StartupParameters" : startup_parameters ,
32023147 "ComputeResourceRequirements" : compute_rr ,
32033148 }
3149+
3150+ # Wire optional IC-level parameters into the update specification
3151+ ic_data_cache_config = kwargs .get ("data_cache_config" )
3152+ if ic_data_cache_config is not None :
3153+ resolved_cache_config = self ._resolve_data_cache_config (ic_data_cache_config )
3154+ if resolved_cache_config is not None :
3155+ inference_component_spec ["DataCacheConfig" ] = {
3156+ "EnableCaching" : resolved_cache_config .enable_caching
3157+ }
3158+
3159+ ic_base_component_name = kwargs .get ("base_inference_component_name" )
3160+ if ic_base_component_name is not None :
3161+ inference_component_spec ["BaseInferenceComponentName" ] = ic_base_component_name
3162+
3163+ ic_container = kwargs .get ("container" )
3164+ if ic_container is not None :
3165+ resolved_container = self ._resolve_container_spec (ic_container )
3166+ if resolved_container is not None :
3167+ container_dict = {}
3168+ if resolved_container .image :
3169+ container_dict ["Image" ] = resolved_container .image
3170+ if resolved_container .artifact_url :
3171+ container_dict ["ArtifactUrl" ] = resolved_container .artifact_url
3172+ if resolved_container .environment :
3173+ container_dict ["Environment" ] = resolved_container .environment
3174+ if container_dict :
3175+ inference_component_spec ["Container" ] = container_dict
3176+
32043177 runtime_config = {"CopyCount" : resource_requirements .copy_count }
32053178
32063179 return self .sagemaker_session .update_inference_component (
@@ -4160,6 +4133,7 @@ def deploy(
41604133 ] = None ,
41614134 custom_orchestrator_instance_type : str = None ,
41624135 custom_orchestrator_initial_instance_count : int = None ,
4136+ inference_component_name : Optional [str ] = None ,
41634137 data_cache_config : Optional [Union ["InferenceComponentDataCacheConfig" , Dict [str , Any ]]] = None ,
41644138 base_inference_component_name : Optional [str ] = None ,
41654139 container : Optional [Union ["InferenceComponentContainerSpecification" , Dict [str , Any ]]] = None ,
@@ -4197,6 +4171,9 @@ def deploy(
41974171 orchestrator deployment. (Default: None).
41984172 custom_orchestrator_initial_instance_count (int, optional): Initial instance count
41994173 for custom orchestrator deployment. (Default: None).
4174+ inference_component_name (str, optional): The name of the inference component
4175+ to create. Only used for inference-component-based endpoints. If not specified,
4176+ a unique name is generated from the model name. (Default: None).
42004177 data_cache_config (Union[InferenceComponentDataCacheConfig, dict], optional):
42014178 Data cache configuration for the inference component. Enables caching of model
42024179 artifacts and container images on instances for faster auto-scaling cold starts.
@@ -4213,6 +4190,7 @@ def deploy(
42134190 variant_name (str, optional): The name of the production variant to deploy to.
42144191 If not provided (or explicitly ``None``), defaults to ``'AllTraffic'``.
42154192 (Default: None).
4193+
42164194 Returns:
42174195 Union[Endpoint, LocalEndpoint, Transformer]: A ``sagemaker.core.resources.Endpoint``
42184196 resource representing the deployed endpoint, a ``LocalEndpoint`` for local mode,
@@ -4235,15 +4213,16 @@ def deploy(
42354213 if not hasattr (self , "built_model" ) and not hasattr (self , "_deployables" ):
42364214 raise ValueError ("Model needs to be built before deploying" )
42374215
4238- # Store IC-level parameters for use in _deploy_core_endpoint
4216+ # Centralize variant_name defaulting and always forward IC-level params
4217+ kwargs ["variant_name" ] = variant_name or "AllTraffic"
4218+ if inference_component_name is not None :
4219+ kwargs ["inference_component_name" ] = inference_component_name
42394220 if data_cache_config is not None :
42404221 kwargs ["data_cache_config" ] = data_cache_config
42414222 if base_inference_component_name is not None :
42424223 kwargs ["base_inference_component_name" ] = base_inference_component_name
42434224 if container is not None :
42444225 kwargs ["container" ] = container
4245- if variant_name is not None :
4246- kwargs ["variant_name" ] = variant_name
42474226
42484227 # Handle model customization deployment
42494228 if self ._is_model_customization ():
@@ -4401,6 +4380,8 @@ def _deploy_model_customization(
44014380 initial_instance_count : int = 1 ,
44024381 inference_component_name : Optional [str ] = None ,
44034382 inference_config : Optional [ResourceRequirements ] = None ,
4383+ variant_name : Optional [str ] = None ,
4384+ data_cache_config : Optional [Union ["InferenceComponentDataCacheConfig" , Dict [str , Any ]]] = None ,
44044385 ** kwargs ,
44054386 ) -> Endpoint :
44064387 """Deploy a model customization (fine-tuned) model to an endpoint with inference components.
@@ -4442,6 +4423,14 @@ def _deploy_model_customization(
44424423 # Fetch model package
44434424 model_package = self ._fetch_model_package ()
44444425
4426+ # Resolve variant_name: use provided value or default to "AllTraffic"
4427+ effective_variant_name = variant_name or "AllTraffic"
4428+
4429+ # Resolve data_cache_config if provided
4430+ resolved_data_cache_config = None
4431+ if data_cache_config is not None :
4432+ resolved_data_cache_config = self ._resolve_data_cache_config (data_cache_config )
4433+
44454434 # Check if endpoint exists
44464435 is_existing_endpoint = self ._does_endpoint_exist (endpoint_name )
44474436
@@ -4450,7 +4439,7 @@ def _deploy_model_customization(
44504439 endpoint_config_name = endpoint_name ,
44514440 production_variants = [
44524441 ProductionVariant (
4453- variant_name = endpoint_name ,
4442+ variant_name = effective_variant_name ,
44544443 instance_type = self .instance_type ,
44554444 initial_instance_count = initial_instance_count or 1 ,
44564445 )
@@ -4491,6 +4480,7 @@ def _deploy_model_customization(
44914480
44924481 base_ic_spec = InferenceComponentSpecification (
44934482 model_name = self .built_model .model_name ,
4483+ data_cache_config = resolved_data_cache_config ,
44944484 )
44954485 if inference_config is not None :
44964486 base_ic_spec .compute_resource_requirements = (
@@ -4507,7 +4497,7 @@ def _deploy_model_customization(
45074497 InferenceComponent .create (
45084498 inference_component_name = base_ic_name ,
45094499 endpoint_name = endpoint_name ,
4510- variant_name = endpoint_name ,
4500+ variant_name = effective_variant_name ,
45114501 specification = base_ic_spec ,
45124502 runtime_config = InferenceComponentRuntimeConfig (copy_count = 1 ),
45134503 tags = [{"key" : "Base" , "value" : base_model_recipe_name }],
@@ -4549,7 +4539,8 @@ def _deploy_model_customization(
45494539 ic_spec = InferenceComponentSpecification (
45504540 container = InferenceComponentContainerSpecification (
45514541 image = self .image_uri , artifact_url = artifact_url , environment = self .env_vars
4552- )
4542+ ),
4543+ data_cache_config = resolved_data_cache_config ,
45534544 )
45544545
45554546 if inference_config is not None :
@@ -4567,7 +4558,7 @@ def _deploy_model_customization(
45674558 InferenceComponent .create (
45684559 inference_component_name = inference_component_name ,
45694560 endpoint_name = endpoint_name ,
4570- variant_name = endpoint_name ,
4561+ variant_name = effective_variant_name ,
45714562 specification = ic_spec ,
45724563 runtime_config = InferenceComponentRuntimeConfig (copy_count = 1 ),
45734564 )
0 commit comments