2020 InferenceComponentDataCacheConfig ,
2121 InferenceComponentContainerSpecification ,
2222)
23+ from sagemaker .core .enums import EndpointType
2324from sagemaker .serve .model_builder_utils import _ModelBuilderUtils
2425
2526
@@ -303,7 +304,7 @@ def test_variant_name_defaults_to_all_traffic(self, mock_endpoint_cls):
303304 )
304305
305306 mb ._deploy_core_endpoint (
306- endpoint_type = " INFERENCE_COMPONENT_BASED" ,
307+ endpoint_type = EndpointType . INFERENCE_COMPONENT_BASED ,
307308 resources = resources ,
308309 instance_type = "ml.g5.2xlarge" ,
309310 initial_instance_count = 1 ,
@@ -327,7 +328,7 @@ def test_variant_name_custom(self, mock_endpoint_cls):
327328 )
328329
329330 mb ._deploy_core_endpoint (
330- endpoint_type = " INFERENCE_COMPONENT_BASED" ,
331+ endpoint_type = EndpointType . INFERENCE_COMPONENT_BASED ,
331332 resources = resources ,
332333 instance_type = "ml.g5.2xlarge" ,
333334 initial_instance_count = 1 ,
@@ -350,7 +351,7 @@ def test_data_cache_config_wired_into_spec(self, mock_endpoint_cls):
350351 )
351352
352353 mb ._deploy_core_endpoint (
353- endpoint_type = " INFERENCE_COMPONENT_BASED" ,
354+ endpoint_type = EndpointType . INFERENCE_COMPONENT_BASED ,
354355 resources = resources ,
355356 instance_type = "ml.g5.2xlarge" ,
356357 initial_instance_count = 1 ,
@@ -375,7 +376,7 @@ def test_base_inference_component_name_wired_into_spec(self, mock_endpoint_cls):
375376 )
376377
377378 mb ._deploy_core_endpoint (
378- endpoint_type = " INFERENCE_COMPONENT_BASED" ,
379+ endpoint_type = EndpointType . INFERENCE_COMPONENT_BASED ,
379380 resources = resources ,
380381 instance_type = "ml.g5.2xlarge" ,
381382 initial_instance_count = 1 ,
@@ -399,7 +400,7 @@ def test_container_wired_into_spec(self, mock_endpoint_cls):
399400 )
400401
401402 mb ._deploy_core_endpoint (
402- endpoint_type = " INFERENCE_COMPONENT_BASED" ,
403+ endpoint_type = EndpointType . INFERENCE_COMPONENT_BASED ,
403404 resources = resources ,
404405 instance_type = "ml.g5.2xlarge" ,
405406 initial_instance_count = 1 ,
@@ -430,7 +431,7 @@ def test_no_optional_params_no_extra_keys_in_spec(self, mock_endpoint_cls):
430431 )
431432
432433 mb ._deploy_core_endpoint (
433- endpoint_type = " INFERENCE_COMPONENT_BASED" ,
434+ endpoint_type = EndpointType . INFERENCE_COMPONENT_BASED ,
434435 resources = resources ,
435436 instance_type = "ml.g5.2xlarge" ,
436437 initial_instance_count = 1 ,
@@ -456,7 +457,7 @@ def test_data_cache_config_typed_object_wired(self, mock_endpoint_cls):
456457
457458 config = InferenceComponentDataCacheConfig (enable_caching = True )
458459 mb ._deploy_core_endpoint (
459- endpoint_type = " INFERENCE_COMPONENT_BASED" ,
460+ endpoint_type = EndpointType . INFERENCE_COMPONENT_BASED ,
460461 resources = resources ,
461462 instance_type = "ml.g5.2xlarge" ,
462463 initial_instance_count = 1 ,
@@ -484,7 +485,7 @@ def test_variant_name_passed_to_production_variant_on_new_endpoint(self, mock_en
484485 with patch ("sagemaker.serve.model_builder.session_helper.production_variant" ) as mock_pv :
485486 mock_pv .return_value = {"VariantName" : "CustomVariant" }
486487 mb ._deploy_core_endpoint (
487- endpoint_type = " INFERENCE_COMPONENT_BASED" ,
488+ endpoint_type = EndpointType . INFERENCE_COMPONENT_BASED ,
488489 resources = resources ,
489490 instance_type = "ml.g5.2xlarge" ,
490491 initial_instance_count = 1 ,
@@ -494,8 +495,12 @@ def test_variant_name_passed_to_production_variant_on_new_endpoint(self, mock_en
494495
495496 # Verify production_variant was called with variant_name="CustomVariant"
496497 mock_pv .assert_called_once ()
497- pv_kwargs = mock_pv .call_args
498- assert pv_kwargs .kwargs .get ("variant_name" ) == "CustomVariant"
498+ pv_call = mock_pv .call_args
499+ # variant_name may be passed by keyword or positionally (args[3]); check both
500+ variant = pv_call .kwargs .get ("variant_name" ) or (
501+ pv_call .args [3 ] if len (pv_call .args ) > 3 else None
502+ )
503+ assert variant == "CustomVariant"
499504
500505
501506# ============================================================
0 commit comments