Skip to content

Commit 016dd16

Browse files
committed
Fix CI failures on new unit tests
1 parent 2420985 commit 016dd16

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

sagemaker-serve/tests/unit/test_resolve_ic_params.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
InferenceComponentDataCacheConfig,
2121
InferenceComponentContainerSpecification,
2222
)
23+
from sagemaker.core.enums import EndpointType
2324
from sagemaker.serve.model_builder_utils import _ModelBuilderUtils
2425

2526

@@ -303,7 +304,7 @@ def test_variant_name_defaults_to_all_traffic(self, mock_endpoint_cls):
303304
)
304305

305306
mb._deploy_core_endpoint(
306-
endpoint_type="INFERENCE_COMPONENT_BASED",
307+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
307308
resources=resources,
308309
instance_type="ml.g5.2xlarge",
309310
initial_instance_count=1,
@@ -327,7 +328,7 @@ def test_variant_name_custom(self, mock_endpoint_cls):
327328
)
328329

329330
mb._deploy_core_endpoint(
330-
endpoint_type="INFERENCE_COMPONENT_BASED",
331+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
331332
resources=resources,
332333
instance_type="ml.g5.2xlarge",
333334
initial_instance_count=1,
@@ -350,7 +351,7 @@ def test_data_cache_config_wired_into_spec(self, mock_endpoint_cls):
350351
)
351352

352353
mb._deploy_core_endpoint(
353-
endpoint_type="INFERENCE_COMPONENT_BASED",
354+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
354355
resources=resources,
355356
instance_type="ml.g5.2xlarge",
356357
initial_instance_count=1,
@@ -375,7 +376,7 @@ def test_base_inference_component_name_wired_into_spec(self, mock_endpoint_cls):
375376
)
376377

377378
mb._deploy_core_endpoint(
378-
endpoint_type="INFERENCE_COMPONENT_BASED",
379+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
379380
resources=resources,
380381
instance_type="ml.g5.2xlarge",
381382
initial_instance_count=1,
@@ -399,7 +400,7 @@ def test_container_wired_into_spec(self, mock_endpoint_cls):
399400
)
400401

401402
mb._deploy_core_endpoint(
402-
endpoint_type="INFERENCE_COMPONENT_BASED",
403+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
403404
resources=resources,
404405
instance_type="ml.g5.2xlarge",
405406
initial_instance_count=1,
@@ -430,7 +431,7 @@ def test_no_optional_params_no_extra_keys_in_spec(self, mock_endpoint_cls):
430431
)
431432

432433
mb._deploy_core_endpoint(
433-
endpoint_type="INFERENCE_COMPONENT_BASED",
434+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
434435
resources=resources,
435436
instance_type="ml.g5.2xlarge",
436437
initial_instance_count=1,
@@ -456,7 +457,7 @@ def test_data_cache_config_typed_object_wired(self, mock_endpoint_cls):
456457

457458
config = InferenceComponentDataCacheConfig(enable_caching=True)
458459
mb._deploy_core_endpoint(
459-
endpoint_type="INFERENCE_COMPONENT_BASED",
460+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
460461
resources=resources,
461462
instance_type="ml.g5.2xlarge",
462463
initial_instance_count=1,
@@ -484,7 +485,7 @@ def test_variant_name_passed_to_production_variant_on_new_endpoint(self, mock_en
484485
with patch("sagemaker.serve.model_builder.session_helper.production_variant") as mock_pv:
485486
mock_pv.return_value = {"VariantName": "CustomVariant"}
486487
mb._deploy_core_endpoint(
487-
endpoint_type="INFERENCE_COMPONENT_BASED",
488+
endpoint_type=EndpointType.INFERENCE_COMPONENT_BASED,
488489
resources=resources,
489490
instance_type="ml.g5.2xlarge",
490491
initial_instance_count=1,
@@ -494,8 +495,12 @@ def test_variant_name_passed_to_production_variant_on_new_endpoint(self, mock_en
494495

495496
# Verify production_variant was called with variant_name="CustomVariant"
496497
mock_pv.assert_called_once()
497-
pv_kwargs = mock_pv.call_args
498-
assert pv_kwargs.kwargs.get("variant_name") == "CustomVariant"
498+
pv_call = mock_pv.call_args
499+
# variant_name may be in kwargs or positional args
500+
variant = pv_call.kwargs.get("variant_name") or (
501+
pv_call.args[3] if len(pv_call.args) > 3 else None
502+
)
503+
assert variant == "CustomVariant"
499504

500505

501506
# ============================================================

0 commit comments

Comments
 (0)