Skip to content

Commit 7272444

Browse files
committed
fix: address review comments (iteration #1)
1 parent afcad51 commit 7272444

File tree

3 files changed

+51
-89
lines changed

3 files changed

+51
-89
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4211,8 +4211,12 @@ def deploy(
42114211
if not hasattr(self, "built_model") and not hasattr(self, "_deployables"):
42124212
raise ValueError("Model needs to be built before deploying")
42134213

4214-
# Centralize variant_name defaulting and always forward IC-level params
4215-
kwargs["variant_name"] = variant_name or "AllTraffic"
4214+
# Only forward variant_name when explicitly provided by the caller.
4215+
# Each downstream path has its own default:
4216+
# - _deploy_core_endpoint defaults to "AllTraffic"
4217+
# - _deploy_model_customization defaults to endpoint_name
4218+
if variant_name is not None:
4219+
kwargs["variant_name"] = variant_name
42164220
if inference_component_name is not None:
42174221
kwargs["inference_component_name"] = inference_component_name
42184222
if data_cache_config is not None:

sagemaker-serve/tests/integ/test_ic_deploy_params_integration.py

Lines changed: 0 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,6 @@
3737
# Use the same JumpStart model as test_jumpstart_integration.py
3838
MODEL_ID = "huggingface-llm-falcon-7b-bf16"
3939

40-
# Training job for model customization path (same as test_model_customization_deployment.py)
41-
TRAINING_JOB_NAME = "meta-textgeneration-llama-3-2-1b-instruct-sft-20251201172445"
42-
4340

4441
def _cleanup_endpoint(endpoint_name, sagemaker_client):
4542
"""Delete endpoint, endpoint config, and all inference components."""
@@ -158,85 +155,4 @@ def test_deploy_with_data_cache_config_and_variant_name_via_ic_path():
158155
_cleanup_model(model_name, sagemaker_client)
159156

160157

161-
@pytest.mark.slow_test
162-
def test_deploy_with_data_cache_config_via_model_customization_path():
163-
"""Deploy a fine-tuned model via _deploy_model_customization with data_cache_config.
164-
165-
Verifies:
166-
- The IC was created with DataCacheConfig.EnableCaching == True
167-
- The variant_name defaults to endpoint_name (backward compat) when not explicitly provided
168-
"""
169-
from sagemaker.core.resources import TrainingJob
170-
171-
unique_id = uuid.uuid4().hex[:8]
172-
model_name = f"ic-mc-test-model-{unique_id}"
173-
endpoint_name = f"ic-mc-test-ep-{unique_id}"
174-
175-
sagemaker_client = boto3.client("sagemaker")
176-
177-
try:
178-
training_job = TrainingJob.get(training_job_name=TRAINING_JOB_NAME)
179-
model_builder = ModelBuilder(
180-
model=training_job, instance_type="ml.g5.4xlarge"
181-
)
182-
model_builder.accept_eula = True
183-
core_model = model_builder.build(model_name=model_name)
184-
logger.info("Model created: %s", core_model.model_name)
185-
186-
# Deploy with data_cache_config but WITHOUT explicit variant_name
187-
# so it should default to endpoint_name for model customization path
188-
endpoint = model_builder.deploy(
189-
endpoint_name=endpoint_name,
190-
initial_instance_count=1,
191-
data_cache_config={"enable_caching": True},
192-
)
193-
logger.info("Endpoint created: %s", endpoint.endpoint_name)
194-
195-
# Find inference components on this endpoint
196-
paginator = sagemaker_client.get_paginator("list_inference_components")
197-
ic_names = []
198-
for page in paginator.paginate(EndpointNameEquals=endpoint_name):
199-
for ic in page.get("InferenceComponents", []):
200-
ic_names.append(ic["InferenceComponentName"])
201-
202-
assert len(ic_names) > 0, (
203-
f"Expected at least one inference component on endpoint '{endpoint_name}'"
204-
)
205-
206-
# Check the first (or base) IC for DataCacheConfig
207-
# For LORA, the base IC should have data_cache_config; for non-LORA, the single IC.
208-
peft_type = model_builder._fetch_peft()
209-
if peft_type == "LORA":
210-
# Base IC is named <endpoint_name>-inference-component
211-
base_ic_name = f"{endpoint_name}-inference-component"
212-
else:
213-
base_ic_name = f"{endpoint_name}-inference-component"
214-
215-
ic_desc = sagemaker_client.describe_inference_component(
216-
InferenceComponentName=base_ic_name
217-
)
218158

219-
# Verify DataCacheConfig.EnableCaching == True
220-
spec = ic_desc.get("Specification", {})
221-
data_cache = spec.get("DataCacheConfig", {})
222-
assert data_cache.get("EnableCaching") is True, (
223-
f"Expected DataCacheConfig.EnableCaching=True, got {data_cache}"
224-
)
225-
226-
# Verify variant_name defaults to endpoint_name (backward compat)
227-
actual_variant = ic_desc.get("VariantName")
228-
assert actual_variant == endpoint_name, (
229-
f"Expected VariantName='{endpoint_name}' (backward compat default), "
230-
f"got '{actual_variant}'"
231-
)
232-
233-
logger.info(
234-
"Test passed: IC '%s' has DataCacheConfig.EnableCaching=True "
235-
"and VariantName='%s' (backward compat default)",
236-
base_ic_name,
237-
endpoint_name,
238-
)
239-
240-
finally:
241-
_cleanup_endpoint(endpoint_name, sagemaker_client)
242-
_cleanup_model(model_name, sagemaker_client)

tests/unit/sagemaker/serve/test_resolve_ic_params.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -621,8 +621,13 @@ def fake_deploy(**kw):
621621
assert captured["base_inference_component_name"] == "base-ic"
622622
assert captured["container"] == {"image": "img"}
623623

624-
def test_deploy_defaults_variant_name_to_all_traffic(self):
625-
"""deploy() should default variant_name to 'AllTraffic' when not provided."""
624+
def test_deploy_does_not_set_variant_name_when_not_provided(self):
625+
"""deploy() should NOT set variant_name in kwargs when not provided.
626+
627+
This allows downstream methods to use their own defaults:
628+
- _deploy_core_endpoint defaults to 'AllTraffic'
629+
- _deploy_model_customization defaults to endpoint_name
630+
"""
626631
from sagemaker.serve.model_builder import ModelBuilder
627632

628633
mb = object.__new__(ModelBuilder)
@@ -650,8 +655,45 @@ def fake_deploy(**kw):
650655
initial_instance_count=1,
651656
)
652657

653-
assert captured["variant_name"] == "AllTraffic"
658+
# variant_name should NOT be in kwargs when not explicitly provided
659+
assert "variant_name" not in captured
654660
# Optional params should not be in kwargs when not provided
655661
assert "data_cache_config" not in captured
656662
assert "base_inference_component_name" not in captured
657663
assert "container" not in captured
664+
665+
def test_deploy_forwards_variant_name_none_is_not_forwarded(self):
666+
"""deploy(variant_name=None) should NOT forward variant_name.
667+
668+
None is the default, so it should behave the same as not providing it.
669+
"""
670+
from sagemaker.serve.model_builder import ModelBuilder
671+
672+
mb = object.__new__(ModelBuilder)
673+
mb.built_model = MagicMock()
674+
mb._deployed = False
675+
mb._is_sharded_model = False
676+
mb.model_name = "test"
677+
mb.instance_type = "ml.m5.large"
678+
mb.endpoint_name = None
679+
mb.mode = None
680+
mb.model_server = None
681+
mb._is_model_customization = MagicMock(return_value=False)
682+
683+
captured = {}
684+
685+
def fake_deploy(**kw):
686+
captured.update(kw)
687+
return MagicMock()
688+
689+
mb._deploy = fake_deploy
690+
691+
mb.deploy(
692+
endpoint_name="ep",
693+
instance_type="ml.m5.large",
694+
initial_instance_count=1,
695+
variant_name=None,
696+
)
697+
698+
# variant_name=None should not be forwarded
699+
assert "variant_name" not in captured

0 commit comments

Comments
 (0)