Skip to content

Commit f865a27

Browse files
committed
fix: address review comments (iteration #4)
1 parent 869474a commit f865a27

File tree

2 files changed

+290
-54
lines changed

2 files changed

+290
-54
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 54 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2702,6 +2702,52 @@ def _wait_for_endpoint(
27022702

27032703
return desc
27042704

2705+
@staticmethod
def _apply_optional_ic_params(inference_component_spec, **kwargs):
    """Wire optional inference-component (IC) settings into a spec dict.

    Reads ``data_cache_config``, ``base_inference_component_name`` and
    ``container`` from ``kwargs``. For each one that is supplied (not
    ``None``), the matching ``DataCacheConfig``,
    ``BaseInferenceComponentName`` or ``Container`` entry is written into
    ``inference_component_spec``. Both _deploy_core_endpoint and
    _update_inference_component call this so the wiring lives in one place.

    Args:
        inference_component_spec (dict): Spec dict, mutated in place.
        **kwargs: Optional ``data_cache_config``,
            ``base_inference_component_name`` and ``container`` values;
            any other keys are ignored.
    """
    from sagemaker.serve.model_builder_utils import _ModelBuilderUtils

    cache_cfg_input = kwargs.get("data_cache_config")
    if cache_cfg_input is not None:
        # The resolver is invoked unbound, with None standing in for self.
        cache_cfg = _ModelBuilderUtils._resolve_data_cache_config(None, cache_cfg_input)
        if cache_cfg is not None:
            inference_component_spec["DataCacheConfig"] = {
                "EnableCaching": cache_cfg.enable_caching
            }

    base_name = kwargs.get("base_inference_component_name")
    if base_name is not None:
        inference_component_spec["BaseInferenceComponentName"] = base_name

    container_input = kwargs.get("container")
    if container_input is None:
        return

    # Same unbound-call convention as the data cache resolver above.
    container = _ModelBuilderUtils._resolve_container_spec(None, container_input)
    if container is None:
        return

    # Only truthy fields are emitted; an all-empty container adds no key.
    candidate_fields = (
        ("Image", container.image),
        ("ArtifactUrl", container.artifact_url),
        ("Environment", container.environment),
    )
    populated = {key: value for key, value in candidate_fields if value}
    if populated:
        inference_component_spec["Container"] = populated
2750+
27052751
def _deploy_core_endpoint(self, **kwargs):
27062752
# Extract and update self parameters
27072753
initial_instance_count = kwargs.get(
@@ -2930,31 +2976,7 @@ def _deploy_core_endpoint(self, **kwargs):
29302976
}
29312977

29322978
# Wire optional IC-level parameters into the specification
2933-
ic_data_cache_config = kwargs.get("data_cache_config")
2934-
if ic_data_cache_config is not None:
2935-
resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config)
2936-
if resolved_cache_config is not None:
2937-
inference_component_spec["DataCacheConfig"] = {
2938-
"EnableCaching": resolved_cache_config.enable_caching
2939-
}
2940-
2941-
ic_base_component_name = kwargs.get("base_inference_component_name")
2942-
if ic_base_component_name is not None:
2943-
inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name
2944-
2945-
ic_container = kwargs.get("container")
2946-
if ic_container is not None:
2947-
resolved_container = self._resolve_container_spec(ic_container)
2948-
if resolved_container is not None:
2949-
container_dict = {}
2950-
if resolved_container.image:
2951-
container_dict["Image"] = resolved_container.image
2952-
if resolved_container.artifact_url:
2953-
container_dict["ArtifactUrl"] = resolved_container.artifact_url
2954-
if resolved_container.environment:
2955-
container_dict["Environment"] = resolved_container.environment
2956-
if container_dict:
2957-
inference_component_spec["Container"] = container_dict
2979+
self._apply_optional_ic_params(inference_component_spec, **kwargs)
29582980

29592981
runtime_config = {"CopyCount": resources.copy_count}
29602982
self.inference_component_name = (
@@ -3148,31 +3170,7 @@ def _update_inference_component(
31483170
}
31493171

31503172
# Wire optional IC-level parameters into the update specification
3151-
ic_data_cache_config = kwargs.get("data_cache_config")
3152-
if ic_data_cache_config is not None:
3153-
resolved_cache_config = self._resolve_data_cache_config(ic_data_cache_config)
3154-
if resolved_cache_config is not None:
3155-
inference_component_spec["DataCacheConfig"] = {
3156-
"EnableCaching": resolved_cache_config.enable_caching
3157-
}
3158-
3159-
ic_base_component_name = kwargs.get("base_inference_component_name")
3160-
if ic_base_component_name is not None:
3161-
inference_component_spec["BaseInferenceComponentName"] = ic_base_component_name
3162-
3163-
ic_container = kwargs.get("container")
3164-
if ic_container is not None:
3165-
resolved_container = self._resolve_container_spec(ic_container)
3166-
if resolved_container is not None:
3167-
container_dict = {}
3168-
if resolved_container.image:
3169-
container_dict["Image"] = resolved_container.image
3170-
if resolved_container.artifact_url:
3171-
container_dict["ArtifactUrl"] = resolved_container.artifact_url
3172-
if resolved_container.environment:
3173-
container_dict["Environment"] = resolved_container.environment
3174-
if container_dict:
3175-
inference_component_spec["Container"] = container_dict
3173+
self._apply_optional_ic_params(inference_component_spec, **kwargs)
31763174

31773175
runtime_config = {"CopyCount": resource_requirements.copy_count}
31783176

@@ -4384,6 +4382,9 @@ def _deploy_model_customization(
43844382
data_cache_config: Optional[Union["InferenceComponentDataCacheConfig", Dict[str, Any]]] = None,
43854383
**kwargs,
43864384
) -> Endpoint:
4385+
# NOTE: For backward compatibility, model customization deployments
4386+
# default variant_name to endpoint_name (not "AllTraffic") when the
4387+
# caller does not provide an explicit value.
43874388
"""Deploy a model customization (fine-tuned) model to an endpoint with inference components.
43884389
43894390
This method handles the special deployment flow for fine-tuned models, creating:
@@ -4423,8 +4424,9 @@ def _deploy_model_customization(
44234424
# Fetch model package
44244425
model_package = self._fetch_model_package()
44254426

4426-
# Resolve variant_name: use provided value or default to "AllTraffic"
4427-
effective_variant_name = variant_name or "AllTraffic"
4427+
# Resolve variant_name: preserve backward-compatible default of
4428+
# endpoint_name for model customization deployments.
4429+
effective_variant_name = variant_name or endpoint_name or "AllTraffic"
44284430

44294431
# Resolve data_cache_config if provided
44304432
resolved_data_cache_config = None

tests/unit/sagemaker/serve/test_resolve_ic_params.py

Lines changed: 236 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,82 @@ def test_invalid_type_list_raises(self, utils):
167167
utils._resolve_container_spec([{"image": "img"}])
168168

169169

170+
# ============================================================
171+
# Tests for _apply_optional_ic_params helper
172+
# ============================================================
173+
174+
class TestApplyOptionalIcParams:
    """Exercise ModelBuilder._apply_optional_ic_params against plain spec dicts."""

    @staticmethod
    def _apply(**ic_params):
        """Run the helper on a fresh ``{"ModelName": "m"}`` spec and return it."""
        from sagemaker.serve.model_builder import ModelBuilder

        spec = {"ModelName": "m"}
        ModelBuilder._apply_optional_ic_params(spec, **ic_params)
        return spec

    def test_no_params_no_mutation(self):
        result = self._apply()
        assert "DataCacheConfig" not in result
        assert "BaseInferenceComponentName" not in result
        assert "Container" not in result

    def test_data_cache_config_dict(self):
        result = self._apply(data_cache_config={"enable_caching": True})
        assert result["DataCacheConfig"] == {"EnableCaching": True}

    def test_data_cache_config_typed(self):
        typed_cfg = InferenceComponentDataCacheConfig(enable_caching=False)
        result = self._apply(data_cache_config=typed_cfg)
        assert result["DataCacheConfig"] == {"EnableCaching": False}

    def test_base_inference_component_name(self):
        result = self._apply(base_inference_component_name="base-ic")
        assert result["BaseInferenceComponentName"] == "base-ic"

    def test_container_dict(self):
        container_input = {
            "image": "img:latest",
            "artifact_url": "s3://b/a",
            "environment": {"K": "V"},
        }
        result = self._apply(container=container_input)
        assert result["Container"] == {
            "Image": "img:latest",
            "ArtifactUrl": "s3://b/a",
            "Environment": {"K": "V"},
        }

    def test_container_typed(self):
        typed_container = InferenceComponentContainerSpecification(image="img")
        result = self._apply(container=typed_container)
        assert result["Container"] == {"Image": "img"}

    def test_all_params_together(self):
        result = self._apply(
            data_cache_config={"enable_caching": True},
            base_inference_component_name="base",
            container={"image": "img"},
        )
        assert result["DataCacheConfig"] == {"EnableCaching": True}
        assert result["BaseInferenceComponentName"] == "base"
        assert result["Container"] == {"Image": "img"}
244+
245+
170246
# ============================================================
171247
# Tests for core wiring logic in _deploy_core_endpoint
172248
# ============================================================
@@ -419,5 +495,163 @@ def test_variant_name_passed_to_production_variant_on_new_endpoint(self, mock_en
419495
# Verify production_variant was called with variant_name="CustomVariant"
420496
mock_pv.assert_called_once()
421497
pv_kwargs = mock_pv.call_args
422-
assert pv_kwargs.kwargs.get("variant_name") == "CustomVariant" or \
423-
(len(pv_kwargs.args) > 0 and False) # variant_name is always a kwarg
498+
assert pv_kwargs.kwargs.get("variant_name") == "CustomVariant"
499+
500+
501+
# ============================================================
502+
# Tests for _update_inference_component wiring
503+
# ============================================================
504+
505+
class TestUpdateInferenceComponentWiring:
    """Check the specification _update_inference_component hands to the session."""

    def _make_model_builder(self):
        """Build a bare ModelBuilder (no __init__) with a mocked session."""
        from sagemaker.serve.model_builder import ModelBuilder

        builder = object.__new__(ModelBuilder)
        builder.model_name = "test-model"
        builder.sagemaker_session = MagicMock()
        return builder

    def _updated_spec(self, **ic_params):
        """Call _update_inference_component and return the captured specification."""
        builder = self._make_model_builder()
        from sagemaker.core.inference_config import ResourceRequirements

        resources = ResourceRequirements(
            requests={"memory": 8192, "num_accelerators": 1, "num_cpus": 2, "copies": 1}
        )
        builder._update_inference_component("my-ic", resources, **ic_params)

        # The mocked session records the call; the spec is passed by keyword.
        update_call = builder.sagemaker_session.update_inference_component.call_args
        return update_call.kwargs["specification"]

    def test_update_ic_with_data_cache_config(self):
        spec = self._updated_spec(data_cache_config={"enable_caching": True})
        assert spec["DataCacheConfig"] == {"EnableCaching": True}

    def test_update_ic_with_container(self):
        spec = self._updated_spec(container={"image": "img:v1"})
        assert spec["Container"] == {"Image": "img:v1"}

    def test_update_ic_with_base_inference_component_name(self):
        spec = self._updated_spec(base_inference_component_name="base-ic")
        assert spec["BaseInferenceComponentName"] == "base-ic"

    def test_update_ic_no_optional_params(self):
        spec = self._updated_spec()
        assert "DataCacheConfig" not in spec
        assert "BaseInferenceComponentName" not in spec
        assert "Container" not in spec
575+
576+
577+
# ============================================================
578+
# Tests for deploy() parameter forwarding
579+
# ============================================================
580+
581+
class TestDeployParameterForwarding:
    """Verify which keyword arguments deploy() hands on to _deploy()."""

    @staticmethod
    def _capture_deploy_kwargs(**deploy_args):
        """Call deploy() on a bare ModelBuilder and return the kwargs _deploy received."""
        from sagemaker.serve.model_builder import ModelBuilder

        builder = object.__new__(ModelBuilder)
        initial_state = {
            "built_model": MagicMock(),
            "_deployed": False,
            "_is_sharded_model": False,
            "model_name": "test",
            "instance_type": "ml.m5.large",
            "endpoint_name": None,
            "mode": None,
            "model_server": None,
            # Stubbed so the check for model customization returns False.
            "_is_model_customization": MagicMock(return_value=False),
        }
        for attr_name, attr_value in initial_state.items():
            setattr(builder, attr_name, attr_value)

        seen = {}

        def fake_deploy(**kw):
            # Record whatever deploy() forwards instead of deploying anything.
            seen.update(kw)
            return MagicMock()

        builder._deploy = fake_deploy
        builder.deploy(**deploy_args)
        return seen

    def test_deploy_forwards_variant_name_to_kwargs(self):
        """deploy() should set kwargs['variant_name'] to the provided value."""
        forwarded = self._capture_deploy_kwargs(
            endpoint_name="ep",
            instance_type="ml.m5.large",
            initial_instance_count=1,
            variant_name="MyVariant",
            data_cache_config={"enable_caching": True},
            base_inference_component_name="base-ic",
            container={"image": "img"},
        )

        assert forwarded["variant_name"] == "MyVariant"
        assert forwarded["data_cache_config"] == {"enable_caching": True}
        assert forwarded["base_inference_component_name"] == "base-ic"
        assert forwarded["container"] == {"image": "img"}

    def test_deploy_defaults_variant_name_to_all_traffic(self):
        """deploy() should default variant_name to 'AllTraffic' when not provided."""
        forwarded = self._capture_deploy_kwargs(
            endpoint_name="ep",
            instance_type="ml.m5.large",
            initial_instance_count=1,
        )

        assert forwarded["variant_name"] == "AllTraffic"
        # Optional params should be absent from kwargs when not provided.
        assert "data_cache_config" not in forwarded
        assert "base_inference_component_name" not in forwarded
        assert "container" not in forwarded

0 commit comments

Comments
 (0)