1010# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
13- """Unit tests for _resolve_data_cache_config and _resolve_container_spec ."""
13+ """Unit tests for IC parameter resolvers and wiring logic ."""
1414from __future__ import absolute_import
1515
1616import pytest
17+ from unittest .mock import MagicMock , patch , ANY
1718
1819from sagemaker .core .shapes import (
1920 InferenceComponentDataCacheConfig ,
@@ -61,7 +62,11 @@ def test_dict_missing_enable_caching_raises(self, utils):
6162 utils ._resolve_data_cache_config ({})
6263
6364 def test_dict_with_extra_keys_still_works (self , utils ):
64- """Extra keys are ignored; only enable_caching is required."""
65+ """Extra keys in the input dict are ignored (not forwarded to the Pydantic constructor).
66+
67+ The resolver only extracts 'enable_caching' from the dict, so extra keys
68+ do not cause Pydantic validation errors even if the model forbids extras.
69+ """
6570 result = utils ._resolve_data_cache_config (
6671 {"enable_caching" : True , "extra_key" : "ignored" }
6772 )
@@ -130,7 +135,11 @@ def test_dict_empty(self, utils):
130135 assert isinstance (result , InferenceComponentContainerSpecification )
131136
132137 def test_dict_with_extra_keys (self , utils ):
133- """Extra keys are ignored."""
138+ """Extra keys are filtered out before passing to the Pydantic constructor.
139+
140+ This ensures compatibility even if InferenceComponentContainerSpecification
141+ has extra='forbid' in its Pydantic model config.
142+ """
134143 result = utils ._resolve_container_spec ({
135144 "image" : "img" ,
136145 "unknown_key" : "ignored" ,
@@ -149,3 +158,231 @@ def test_invalid_type_int_raises(self, utils):
149158 def test_invalid_type_list_raises (self , utils ):
150159 with pytest .raises (ValueError , match = "container must be a dict" ):
151160 utils ._resolve_container_spec ([{"image" : "img" }])
161+
162+
163+ # ============================================================
164+ # Tests for core wiring logic in _deploy_core_endpoint
165+ # ============================================================
166+
class TestDeployCoreEndpointWiring:
    """Tests that new IC parameters are correctly wired through _deploy_core_endpoint.

    Each test builds a bare ``ModelBuilder`` whose SageMaker session is a
    ``MagicMock``, runs ``_deploy_core_endpoint`` for an
    ``INFERENCE_COMPONENT_BASED`` endpoint, and inspects the call recorded on
    ``sagemaker_session.create_inference_component``.
    """

    def _make_model_builder(self):
        """Create a minimally-configured ModelBuilder for testing _deploy_core_endpoint.

        ``object.__new__`` is used so that ``ModelBuilder.__init__`` never runs;
        every attribute the test relies on is assigned by hand below.

        Returns:
            A ``ModelBuilder`` instance with mocked ``built_model`` and
            ``sagemaker_session`` attributes.
        """
        from sagemaker.serve.model_builder import ModelBuilder

        mb = object.__new__(ModelBuilder)
        # Set minimum required attributes
        mb.model_name = "test-model"
        mb.endpoint_name = None
        mb.inference_component_name = None
        mb.instance_type = "ml.g5.2xlarge"
        mb.instance_count = 1
        mb.accelerator_type = None
        mb._tags = None
        mb.kms_key = None
        mb.async_inference_config = None
        mb.serverless_inference_config = None
        mb.model_data_download_timeout = None
        mb.resource_requirements = None
        mb.container_startup_health_check_timeout = None
        mb.inference_ami_version = None
        mb._is_sharded_model = False
        mb._enable_network_isolation = False
        mb.role_arn = "arn:aws:iam::123456789012:role/SageMakerRole"
        mb.vpc_config = None
        mb.inference_recommender_job_results = None
        mb.model_server = None
        mb.mode = None
        mb.region = "us-east-1"

        # Mock built_model
        mb.built_model = MagicMock()
        mb.built_model.model_name = "test-model"

        # Mock sagemaker_session
        mb.sagemaker_session = MagicMock()
        mb.sagemaker_session.endpoint_in_service_or_not.return_value = True
        mb.sagemaker_session.boto_session = MagicMock()
        mb.sagemaker_session.boto_region_name = "us-east-1"

        return mb

    @staticmethod
    def _make_resources():
        """Build the ResourceRequirements object shared by every wiring test."""
        # Imported lazily, as the individual tests originally did, so that the
        # module can be imported without pulling in inference_config up front.
        from sagemaker.core.inference_config import ResourceRequirements

        return ResourceRequirements(
            requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1}
        )

    def _deploy_and_capture_call(self, **ic_kwargs):
        """Run _deploy_core_endpoint and return the recorded create_inference_component call.

        Args:
            **ic_kwargs: Extra keyword arguments forwarded verbatim to
                ``_deploy_core_endpoint`` (for example ``variant_name`` or
                ``data_cache_config``).

        Returns:
            The ``call_args`` record (an ``(args, kwargs)`` pair) of the single
            ``create_inference_component`` invocation on the mocked session.

        Raises:
            AssertionError: If ``create_inference_component`` was not called
                exactly once.
        """
        mb = self._make_model_builder()

        mb._deploy_core_endpoint(
            endpoint_type="INFERENCE_COMPONENT_BASED",
            resources=self._make_resources(),
            instance_type="ml.g5.2xlarge",
            initial_instance_count=1,
            wait=False,
            **ic_kwargs,
        )

        create_ic = mb.sagemaker_session.create_inference_component
        # Fail with a clear message here instead of a TypeError on a None
        # call_args further down when the call never happened.
        create_ic.assert_called_once()
        return create_ic.call_args

    @staticmethod
    def _variant_name_from(recorded_call):
        """Extract the variant name from a recorded create_inference_component call.

        The keyword form is preferred. The positional fallback is only consulted
        when the keyword is absent, so a positional call no longer raises
        ``KeyError`` before the fallback can be evaluated.

        Args:
            recorded_call: An ``(args, kwargs)`` pair as stored in ``call_args``.

        Returns:
            The variant name, or ``None`` when it was passed neither by keyword
            nor as the third positional argument.
        """
        args, kwargs = recorded_call
        if "variant_name" in kwargs:
            return kwargs["variant_name"]
        # NOTE(review): assumes variant_name is the third positional parameter
        # of create_inference_component, as the original assertion did - confirm.
        return args[2] if len(args) > 2 else None

    @staticmethod
    def _spec_from(recorded_call):
        """Return the ``specification`` keyword argument of a recorded call."""
        _, kwargs = recorded_call
        return kwargs["specification"]

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_variant_name_defaults_to_all_traffic(self, mock_endpoint_cls):
        """When variant_name is not provided, it defaults to 'AllTraffic'."""
        mock_endpoint_cls.get.return_value = MagicMock()

        recorded_call = self._deploy_and_capture_call()

        assert self._variant_name_from(recorded_call) == "AllTraffic"

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_variant_name_custom(self, mock_endpoint_cls):
        """When variant_name is provided, it is used instead of 'AllTraffic'."""
        mock_endpoint_cls.get.return_value = MagicMock()

        recorded_call = self._deploy_and_capture_call(variant_name="MyVariant")

        assert self._variant_name_from(recorded_call) == "MyVariant"

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_data_cache_config_wired_into_spec(self, mock_endpoint_cls):
        """data_cache_config dict is resolved and added to inference_component_spec."""
        mock_endpoint_cls.get.return_value = MagicMock()

        recorded_call = self._deploy_and_capture_call(
            data_cache_config={"enable_caching": True},
        )

        spec = self._spec_from(recorded_call)
        assert "DataCacheConfig" in spec
        assert spec["DataCacheConfig"]["EnableCaching"] is True

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_base_inference_component_name_wired_into_spec(self, mock_endpoint_cls):
        """base_inference_component_name is added to inference_component_spec."""
        mock_endpoint_cls.get.return_value = MagicMock()

        recorded_call = self._deploy_and_capture_call(
            base_inference_component_name="base-ic-name",
        )

        spec = self._spec_from(recorded_call)
        assert spec["BaseInferenceComponentName"] == "base-ic-name"

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_container_wired_into_spec(self, mock_endpoint_cls):
        """container dict is resolved and added to inference_component_spec."""
        mock_endpoint_cls.get.return_value = MagicMock()

        recorded_call = self._deploy_and_capture_call(
            container={
                "image": "my-image:latest",
                "artifact_url": "s3://bucket/artifact",
                "environment": {"KEY": "VALUE"},
            },
        )

        spec = self._spec_from(recorded_call)
        assert "Container" in spec
        assert spec["Container"]["Image"] == "my-image:latest"
        assert spec["Container"]["ArtifactUrl"] == "s3://bucket/artifact"
        assert spec["Container"]["Environment"] == {"KEY": "VALUE"}

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_no_optional_params_no_extra_keys_in_spec(self, mock_endpoint_cls):
        """When no optional IC params are provided, spec has no extra keys."""
        mock_endpoint_cls.get.return_value = MagicMock()

        recorded_call = self._deploy_and_capture_call()

        spec = self._spec_from(recorded_call)
        assert "DataCacheConfig" not in spec
        assert "BaseInferenceComponentName" not in spec
        assert "Container" not in spec

    @patch("sagemaker.serve.model_builder.Endpoint")
    def test_data_cache_config_typed_object_wired(self, mock_endpoint_cls):
        """InferenceComponentDataCacheConfig object is correctly wired."""
        mock_endpoint_cls.get.return_value = MagicMock()

        config = InferenceComponentDataCacheConfig(enable_caching=True)
        recorded_call = self._deploy_and_capture_call(data_cache_config=config)

        spec = self._spec_from(recorded_call)
        assert spec["DataCacheConfig"]["EnableCaching"] is True
0 commit comments