@@ -167,6 +167,82 @@ def test_invalid_type_list_raises(self, utils):
167167 utils ._resolve_container_spec ([{"image" : "img" }])
168168
169169
170+ # ============================================================
171+ # Tests for _apply_optional_ic_params helper
172+ # ============================================================
173+
174+ class TestApplyOptionalIcParams :
175+ """Tests for the static helper that wires optional IC params into a spec dict."""
176+
177+ def test_no_params_no_mutation (self ):
178+ from sagemaker .serve .model_builder import ModelBuilder
179+ spec = {"ModelName" : "m" }
180+ ModelBuilder ._apply_optional_ic_params (spec )
181+ assert "DataCacheConfig" not in spec
182+ assert "BaseInferenceComponentName" not in spec
183+ assert "Container" not in spec
184+
185+ def test_data_cache_config_dict (self ):
186+ from sagemaker .serve .model_builder import ModelBuilder
187+ spec = {"ModelName" : "m" }
188+ ModelBuilder ._apply_optional_ic_params (
189+ spec , data_cache_config = {"enable_caching" : True }
190+ )
191+ assert spec ["DataCacheConfig" ] == {"EnableCaching" : True }
192+
193+ def test_data_cache_config_typed (self ):
194+ from sagemaker .serve .model_builder import ModelBuilder
195+ spec = {"ModelName" : "m" }
196+ cfg = InferenceComponentDataCacheConfig (enable_caching = False )
197+ ModelBuilder ._apply_optional_ic_params (spec , data_cache_config = cfg )
198+ assert spec ["DataCacheConfig" ] == {"EnableCaching" : False }
199+
200+ def test_base_inference_component_name (self ):
201+ from sagemaker .serve .model_builder import ModelBuilder
202+ spec = {"ModelName" : "m" }
203+ ModelBuilder ._apply_optional_ic_params (
204+ spec , base_inference_component_name = "base-ic"
205+ )
206+ assert spec ["BaseInferenceComponentName" ] == "base-ic"
207+
208+ def test_container_dict (self ):
209+ from sagemaker .serve .model_builder import ModelBuilder
210+ spec = {"ModelName" : "m" }
211+ ModelBuilder ._apply_optional_ic_params (
212+ spec ,
213+ container = {
214+ "image" : "img:latest" ,
215+ "artifact_url" : "s3://b/a" ,
216+ "environment" : {"K" : "V" },
217+ },
218+ )
219+ assert spec ["Container" ] == {
220+ "Image" : "img:latest" ,
221+ "ArtifactUrl" : "s3://b/a" ,
222+ "Environment" : {"K" : "V" },
223+ }
224+
225+ def test_container_typed (self ):
226+ from sagemaker .serve .model_builder import ModelBuilder
227+ spec = {"ModelName" : "m" }
228+ c = InferenceComponentContainerSpecification (image = "img" )
229+ ModelBuilder ._apply_optional_ic_params (spec , container = c )
230+ assert spec ["Container" ] == {"Image" : "img" }
231+
232+ def test_all_params_together (self ):
233+ from sagemaker .serve .model_builder import ModelBuilder
234+ spec = {"ModelName" : "m" }
235+ ModelBuilder ._apply_optional_ic_params (
236+ spec ,
237+ data_cache_config = {"enable_caching" : True },
238+ base_inference_component_name = "base" ,
239+ container = {"image" : "img" },
240+ )
241+ assert spec ["DataCacheConfig" ] == {"EnableCaching" : True }
242+ assert spec ["BaseInferenceComponentName" ] == "base"
243+ assert spec ["Container" ] == {"Image" : "img" }
244+
245+
170246# ============================================================
171247# Tests for core wiring logic in _deploy_core_endpoint
172248# ============================================================
@@ -419,5 +495,163 @@ def test_variant_name_passed_to_production_variant_on_new_endpoint(self, mock_en
419495 # Verify production_variant was called with variant_name="CustomVariant"
420496 mock_pv .assert_called_once ()
421497 pv_kwargs = mock_pv .call_args
422- assert pv_kwargs .kwargs .get ("variant_name" ) == "CustomVariant" or \
423- (len (pv_kwargs .args ) > 0 and False ) # variant_name is always a kwarg
498+ assert pv_kwargs .kwargs .get ("variant_name" ) == "CustomVariant"
499+
500+
501+ # ============================================================
502+ # Tests for _update_inference_component wiring
503+ # ============================================================
504+
505+ class TestUpdateInferenceComponentWiring :
506+ """Tests that _update_inference_component correctly wires optional IC params."""
507+
508+ def _make_model_builder (self ):
509+ from sagemaker .serve .model_builder import ModelBuilder
510+
511+ mb = object .__new__ (ModelBuilder )
512+ mb .model_name = "test-model"
513+ mb .sagemaker_session = MagicMock ()
514+ return mb
515+
516+ def test_update_ic_with_data_cache_config (self ):
517+ mb = self ._make_model_builder ()
518+ from sagemaker .core .inference_config import ResourceRequirements
519+ resources = ResourceRequirements (
520+ requests = {"memory" : 8192 , "num_accelerators" : 1 , "num_cpus" : 2 , "copies" : 1 }
521+ )
522+
523+ mb ._update_inference_component (
524+ "my-ic" , resources , data_cache_config = {"enable_caching" : True }
525+ )
526+
527+ call_kwargs = mb .sagemaker_session .update_inference_component .call_args
528+ spec = call_kwargs .kwargs ["specification" ]
529+ assert spec ["DataCacheConfig" ] == {"EnableCaching" : True }
530+
531+ def test_update_ic_with_container (self ):
532+ mb = self ._make_model_builder ()
533+ from sagemaker .core .inference_config import ResourceRequirements
534+ resources = ResourceRequirements (
535+ requests = {"memory" : 8192 , "num_accelerators" : 1 , "num_cpus" : 2 , "copies" : 1 }
536+ )
537+
538+ mb ._update_inference_component (
539+ "my-ic" , resources , container = {"image" : "img:v1" }
540+ )
541+
542+ call_kwargs = mb .sagemaker_session .update_inference_component .call_args
543+ spec = call_kwargs .kwargs ["specification" ]
544+ assert spec ["Container" ] == {"Image" : "img:v1" }
545+
546+ def test_update_ic_with_base_inference_component_name (self ):
547+ mb = self ._make_model_builder ()
548+ from sagemaker .core .inference_config import ResourceRequirements
549+ resources = ResourceRequirements (
550+ requests = {"memory" : 8192 , "num_accelerators" : 1 , "num_cpus" : 2 , "copies" : 1 }
551+ )
552+
553+ mb ._update_inference_component (
554+ "my-ic" , resources , base_inference_component_name = "base-ic"
555+ )
556+
557+ call_kwargs = mb .sagemaker_session .update_inference_component .call_args
558+ spec = call_kwargs .kwargs ["specification" ]
559+ assert spec ["BaseInferenceComponentName" ] == "base-ic"
560+
561+ def test_update_ic_no_optional_params (self ):
562+ mb = self ._make_model_builder ()
563+ from sagemaker .core .inference_config import ResourceRequirements
564+ resources = ResourceRequirements (
565+ requests = {"memory" : 8192 , "num_accelerators" : 1 , "num_cpus" : 2 , "copies" : 1 }
566+ )
567+
568+ mb ._update_inference_component ("my-ic" , resources )
569+
570+ call_kwargs = mb .sagemaker_session .update_inference_component .call_args
571+ spec = call_kwargs .kwargs ["specification" ]
572+ assert "DataCacheConfig" not in spec
573+ assert "BaseInferenceComponentName" not in spec
574+ assert "Container" not in spec
575+
576+
577+ # ============================================================
578+ # Tests for deploy() parameter forwarding
579+ # ============================================================
580+
581+ class TestDeployParameterForwarding :
582+ """Tests that deploy() correctly forwards new IC params into kwargs."""
583+
584+ def test_deploy_forwards_variant_name_to_kwargs (self ):
585+ """deploy() should set kwargs['variant_name'] to the provided value."""
586+ from sagemaker .serve .model_builder import ModelBuilder
587+
588+ mb = object .__new__ (ModelBuilder )
589+ mb .built_model = MagicMock ()
590+ mb ._deployed = False
591+ mb ._is_sharded_model = False
592+ mb .model_name = "test"
593+ mb .instance_type = "ml.m5.large"
594+ mb .endpoint_name = None
595+ mb .mode = None
596+ mb .model_server = None
597+
598+ # Mock _is_model_customization to return False
599+ mb ._is_model_customization = MagicMock (return_value = False )
600+ # Mock _deploy to capture kwargs
601+ captured = {}
602+
603+ def fake_deploy (** kw ):
604+ captured .update (kw )
605+ return MagicMock ()
606+
607+ mb ._deploy = fake_deploy
608+
609+ mb .deploy (
610+ endpoint_name = "ep" ,
611+ instance_type = "ml.m5.large" ,
612+ initial_instance_count = 1 ,
613+ variant_name = "MyVariant" ,
614+ data_cache_config = {"enable_caching" : True },
615+ base_inference_component_name = "base-ic" ,
616+ container = {"image" : "img" },
617+ )
618+
619+ assert captured ["variant_name" ] == "MyVariant"
620+ assert captured ["data_cache_config" ] == {"enable_caching" : True }
621+ assert captured ["base_inference_component_name" ] == "base-ic"
622+ assert captured ["container" ] == {"image" : "img" }
623+
624+ def test_deploy_defaults_variant_name_to_all_traffic (self ):
625+ """deploy() should default variant_name to 'AllTraffic' when not provided."""
626+ from sagemaker .serve .model_builder import ModelBuilder
627+
628+ mb = object .__new__ (ModelBuilder )
629+ mb .built_model = MagicMock ()
630+ mb ._deployed = False
631+ mb ._is_sharded_model = False
632+ mb .model_name = "test"
633+ mb .instance_type = "ml.m5.large"
634+ mb .endpoint_name = None
635+ mb .mode = None
636+ mb .model_server = None
637+ mb ._is_model_customization = MagicMock (return_value = False )
638+
639+ captured = {}
640+
641+ def fake_deploy (** kw ):
642+ captured .update (kw )
643+ return MagicMock ()
644+
645+ mb ._deploy = fake_deploy
646+
647+ mb .deploy (
648+ endpoint_name = "ep" ,
649+ instance_type = "ml.m5.large" ,
650+ initial_instance_count = 1 ,
651+ )
652+
653+ assert captured ["variant_name" ] == "AllTraffic"
654+ # Optional params should not be in kwargs when not provided
655+ assert "data_cache_config" not in captured
656+ assert "base_inference_component_name" not in captured
657+ assert "container" not in captured
0 commit comments