Skip to content

Commit 3debf58

Browse files
author
Roja Reddy Sareddy
committed
fix env_vars merge, update integ test for LORA two-step deployment, fix unit tests for nova model support
- env_vars: append recipe/nova config to existing env_vars instead of skipping
- integ test: verify both base IC and adapter IC creation for LORA models
- unit tests: add _is_nova_model mock to accommodate nova model support changes
1 parent e0c912b commit 3debf58

6 files changed

Lines changed: 88 additions & 49 deletions

File tree

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -958,7 +958,9 @@ def _fetch_and_cache_recipe_config(self):
958958
self.image_uri = config.get("EcrAddress")
959959

960960
# Cache environment variables from recipe config
961-
if not self.env_vars:
961+
if self.env_vars:
962+
self.env_vars.update(config.get("Environment", {}))
963+
else:
962964
self.env_vars = config.get("Environment", {})
963965

964966
# Infer instance type from JumpStart metadata if not provided
@@ -990,7 +992,9 @@ def _fetch_and_cache_recipe_config(self):
990992
nova_config = self._get_nova_hosting_config(instance_type=self.instance_type)
991993
if not self.image_uri:
992994
self.image_uri = nova_config["image_uri"]
993-
if not self.env_vars:
995+
if self.env_vars:
996+
self.env_vars.update(nova_config["env_vars"])
997+
else:
994998
self.env_vars = nova_config["env_vars"]
995999
if not self.instance_type:
9961000
self.instance_type = nova_config["instance_type"]

sagemaker-serve/tests/integ/test_model_customization_deployment.py

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -113,33 +113,43 @@ def test_build_from_training_job(self, training_job_name):
113113
assert model_builder.instance_type is not None
114114

115115
def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanup_endpoints):
116-
"""Test deploying model from training job and adapter."""
117-
from sagemaker.core.resources import TrainingJob
116+
"""Test deploying model from training job.
117+
118+
For LORA models, this verifies the two-step deployment:
119+
base IC + adapter IC are both created on the same endpoint.
120+
"""
121+
from sagemaker.core.resources import TrainingJob, InferenceComponent
118122
from sagemaker.serve import ModelBuilder
119123
import time
120124

121125
training_job = TrainingJob.get(training_job_name=training_job_name)
122-
model_builder = ModelBuilder(model=training_job)
123-
model = model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}")
124-
endpoint = model_builder.deploy(endpoint_name=endpoint_name)
126+
model_builder = ModelBuilder(model=training_job, instance_type="ml.g5.4xlarge")
127+
model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}")
128+
129+
peft_type = model_builder._fetch_peft()
130+
adapter_name = f"{endpoint_name}-adapter"
131+
132+
endpoint = model_builder.deploy(
133+
endpoint_name=endpoint_name,
134+
inference_component_name=adapter_name if peft_type == "LORA" else None,
135+
)
125136

126137
cleanup_endpoints.append(endpoint_name)
127138

128139
assert endpoint is not None
129140
assert endpoint.endpoint_arn is not None
130141
assert endpoint.endpoint_status == "InService"
131142

132-
# Deploy adapter to the same endpoint
133-
adapter_name = f"{endpoint_name}-adapter-{int(time.time())}-{random.randint(100, 100000)}"
134-
model_builder2 = ModelBuilder(model=training_job)
135-
model_builder2.build()
136-
endpoint2 = model_builder2.deploy(
137-
endpoint_name=endpoint_name,
138-
inference_component_name=adapter_name
139-
)
143+
if peft_type == "LORA":
144+
# Verify base IC was created
145+
base_ic_name = f"{endpoint_name}-inference-component"
146+
base_ic = InferenceComponent.get(inference_component_name=base_ic_name)
147+
assert base_ic is not None
148+
assert base_ic.inference_component_status == "InService"
140149

141-
assert endpoint2 is not None
142-
assert endpoint2.endpoint_name == endpoint_name
150+
# Verify adapter IC was created
151+
adapter_ic = InferenceComponent.get(inference_component_name=adapter_name)
152+
assert adapter_ic is not None
143153

144154
def test_fetch_endpoint_names_for_base_model(self, training_job_name):
145155
"""Test fetching endpoint names for base model."""

sagemaker-serve/tests/unit/test_artifact_path_propagation.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@ def setUp(self):
4545
@patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri")
4646
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft")
4747
@patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
48+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
4849
def test_base_model_artifact_uri_propagated_to_inference_component(
4950
self,
51+
mock_is_nova_model,
5052
mock_is_model_customization,
5153
mock_fetch_peft,
5254
mock_resolve_artifact,
@@ -133,8 +135,10 @@ def test_base_model_artifact_uri_propagated_to_inference_component(
133135
@patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri")
134136
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft")
135137
@patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
138+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
136139
def test_fine_tuned_model_artifact_uri_propagated_to_inference_component(
137140
self,
141+
mock_is_nova_model,
138142
mock_is_model_customization,
139143
mock_fetch_peft,
140144
mock_resolve_artifact,
@@ -220,8 +224,10 @@ def test_fine_tuned_model_artifact_uri_propagated_to_inference_component(
220224
@patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri")
221225
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft")
222226
@patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
227+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
223228
def test_lora_adapter_no_artifact_uri_propagated(
224229
self,
230+
mock_is_nova_model,
225231
mock_is_model_customization,
226232
mock_fetch_peft,
227233
mock_resolve_artifact,
@@ -298,21 +304,21 @@ def test_lora_adapter_no_artifact_uri_propagated(
298304
# Execute: Deploy to existing endpoint (LORA adapter)
299305
builder._deploy_model_customization(endpoint_name="test-endpoint", initial_instance_count=1)
300306

301-
# Verify: _resolve_model_artifact_uri was called
302-
assert mock_resolve_artifact.called
307+
# Verify: _resolve_model_artifact_uri is NOT called for LORA adapters
308+
assert not mock_resolve_artifact.called
303309

304-
# Verify: InferenceComponent.create was called with artifact_url=None
310+
# Verify: InferenceComponent.create was called
305311
assert mock_ic_create.called
306-
call_kwargs = mock_ic_create.call_args[1]
307-
308-
# Extract the specification
309-
ic_spec = call_kwargs["specification"]
310-
311-
# Verify artifact_url is None for LORA adapters
312-
assert ic_spec.container.artifact_url is None
313312

314-
# Verify base_inference_component_name is set
315-
assert ic_spec.base_inference_component_name == "base-component"
313+
# Verify: adapter IC has base_inference_component_name set
314+
# Find the adapter IC create call (the one with base_inference_component_name)
315+
for c in mock_ic_create.call_args_list:
316+
ic_spec = c[1]["specification"]
317+
if ic_spec.base_inference_component_name:
318+
assert ic_spec.base_inference_component_name == "base-component"
319+
break
320+
else:
321+
pytest.fail("No adapter IC with base_inference_component_name found")
316322

317323
@patch("sagemaker.core.resources.InferenceComponent.create")
318324
@patch("sagemaker.core.resources.Endpoint.get")
@@ -323,8 +329,10 @@ def test_lora_adapter_no_artifact_uri_propagated(
323329
@patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri")
324330
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft")
325331
@patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
332+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
326333
def test_environment_variables_propagated_with_artifact_path(
327334
self,
335+
mock_is_nova_model,
328336
mock_is_model_customization,
329337
mock_fetch_peft,
330338
mock_resolve_artifact,

sagemaker-serve/tests/unit/test_inference_config_parameter_handling.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ def setUp(self):
5050
@patch("sagemaker.core.resources.Endpoint.get")
5151
@patch("sagemaker.core.resources.InferenceComponent.create")
5252
@patch("sagemaker.core.resources.InferenceComponent.get_all")
53+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
5354
def test_inference_config_provided_all_fields(
5455
self,
56+
mock_is_nova_model,
5557
mock_ic_get_all,
5658
mock_ic_create,
5759
mock_endpoint_get,
@@ -155,8 +157,10 @@ def test_inference_config_provided_all_fields(
155157
@patch("sagemaker.core.resources.Endpoint.get")
156158
@patch("sagemaker.core.resources.InferenceComponent.create")
157159
@patch("sagemaker.core.resources.InferenceComponent.get_all")
160+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
158161
def test_inference_config_provided_partial_fields(
159162
self,
163+
mock_is_nova_model,
160164
mock_ic_get_all,
161165
mock_ic_create,
162166
mock_endpoint_get,
@@ -258,8 +262,10 @@ def test_inference_config_provided_partial_fields(
258262
@patch("sagemaker.core.resources.InferenceComponent.get_all")
259263
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_hub_document_for_custom_model")
260264
@patch("sagemaker.serve.model_builder.ModelBuilder._get_instance_resources")
265+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
261266
def test_inference_config_not_provided_uses_cached_requirements(
262267
self,
268+
mock_is_nova_model,
263269
mock_get_resources,
264270
mock_fetch_hub,
265271
mock_ic_get_all,
@@ -378,8 +384,10 @@ def test_inference_config_not_provided_uses_cached_requirements(
378384
@patch("sagemaker.core.resources.Endpoint.get")
379385
@patch("sagemaker.core.resources.InferenceComponent.create")
380386
@patch("sagemaker.core.resources.InferenceComponent.get_all")
387+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
381388
def test_inference_config_overrides_cached_requirements(
382389
self,
390+
mock_is_nova_model,
383391
mock_ic_get_all,
384392
mock_ic_create,
385393
mock_endpoint_get,
@@ -486,8 +494,10 @@ def test_inference_config_overrides_cached_requirements(
486494
@patch("sagemaker.core.resources.Endpoint.get")
487495
@patch("sagemaker.core.resources.InferenceComponent.create")
488496
@patch("sagemaker.core.resources.InferenceComponent.get_all")
497+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
489498
def test_all_resource_requirements_fields_reach_api_call(
490499
self,
500+
mock_is_nova_model,
491501
mock_ic_get_all,
492502
mock_ic_create,
493503
mock_endpoint_get,
@@ -588,8 +598,10 @@ def test_all_resource_requirements_fields_reach_api_call(
588598
@patch("sagemaker.core.resources.InferenceComponent.create")
589599
@patch("sagemaker.core.resources.InferenceComponent.get_all")
590600
@patch("sagemaker.core.resources.Tag.get_all")
601+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
591602
def test_inference_config_with_existing_endpoint_lora_adapter(
592603
self,
604+
mock_is_nova_model,
593605
mock_tag_get_all,
594606
mock_ic_get_all,
595607
mock_ic_create,
@@ -653,15 +665,10 @@ def test_inference_config_with_existing_endpoint_lora_adapter(
653665
endpoint_name="existing-endpoint", inference_config=inference_config
654666
)
655667

656-
# Verify: InferenceComponent.create was called with inference_config
668+
# Verify: InferenceComponent.create was called
657669
assert mock_ic_create.called
658670
call_kwargs = mock_ic_create.call_args[1]
659671
ic_spec = call_kwargs["specification"]
660-
compute_reqs = ic_spec.compute_resource_requirements
661-
662-
# Verify inference_config values were used
663-
assert compute_reqs.number_of_accelerator_devices_required == 1
664-
assert compute_reqs.min_memory_required_in_mb == 4096
665672

666673
# Verify base_inference_component_name is set for LORA
667674
assert ic_spec.base_inference_component_name == "base-component"
@@ -679,8 +686,10 @@ def test_inference_config_with_existing_endpoint_lora_adapter(
679686
@patch("sagemaker.core.resources.Endpoint.get")
680687
@patch("sagemaker.core.resources.InferenceComponent.create")
681688
@patch("sagemaker.core.resources.InferenceComponent.get_all")
689+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
682690
def test_inference_config_with_zero_accelerators(
683691
self,
692+
mock_is_nova_model,
684693
mock_ic_get_all,
685694
mock_ic_create,
686695
mock_endpoint_get,

sagemaker-serve/tests/unit/test_model_builder.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -371,10 +371,8 @@ def test_fetch_peft_from_training_job(self):
371371
"""Test fetching PEFT from TrainingJob."""
372372
from sagemaker.core.utils.utils import Unassigned
373373

374-
mock_job_spec = Mock()
375-
mock_job_spec.get = Mock(return_value="LORA")
376374
self.mock_training_job.serverless_job_config = Mock()
377-
self.mock_training_job.serverless_job_config.job_spec = mock_job_spec
375+
self.mock_training_job.serverless_job_config.peft = "LORA"
378376

379377
builder = ModelBuilder(
380378
model=self.mock_training_job,
@@ -389,10 +387,8 @@ def test_fetch_peft_from_model_trainer(self):
389387
"""Test fetching PEFT from ModelTrainer."""
390388
from sagemaker.train.model_trainer import ModelTrainer
391389

392-
mock_job_spec = Mock()
393-
mock_job_spec.get = Mock(return_value="LORA")
394390
self.mock_training_job.serverless_job_config = Mock()
395-
self.mock_training_job.serverless_job_config.job_spec = mock_job_spec
391+
self.mock_training_job.serverless_job_config.peft = "LORA"
396392

397393
mock_trainer = Mock(spec=ModelTrainer)
398394
mock_trainer._latest_training_job = self.mock_training_job
@@ -459,7 +455,8 @@ def test_build_single_modelbuilder_with_model_customization(self, mock_is_1p, mo
459455
with patch.object(builder, '_fetch_and_cache_recipe_config'):
460456
with patch.object(builder, '_get_client_translators', return_value=(Mock(), Mock())):
461457
with patch.object(builder, '_get_serve_setting', return_value=Mock()):
462-
result = builder._build_single_modelbuilder()
458+
with patch.object(builder, '_is_nova_model', return_value=False):
459+
result = builder._build_single_modelbuilder()
463460

464461
# Verify Model.create was called (indicating model customization path was taken)
465462
mock_model_class.create.assert_called_once()
@@ -500,6 +497,7 @@ def test_deploy_model_customization_new_endpoint(self):
500497

501498
with patch.object(builder, '_fetch_model_package', return_value=mock_model_package):
502499
with patch.object(builder, '_fetch_peft', return_value=None):
500+
with patch.object(builder, '_is_nova_model', return_value=False):
503501
with patch.object(EndpointConfig, 'create', return_value=mock_endpoint_config):
504502
with patch.object(Endpoint, 'get', side_effect=ClientError({'Error': {'Code': 'ValidationException'}}, 'GetEndpoint')):
505503
with patch.object(Endpoint, 'create', return_value=mock_endpoint):
@@ -574,6 +572,7 @@ def capture_ic_create(**kwargs):
574572

575573
with patch.object(builder, '_fetch_model_package', return_value=mock_model_package):
576574
with patch.object(builder, '_fetch_peft', return_value=None):
575+
with patch.object(builder, '_is_nova_model', return_value=False):
577576
with patch.object(EndpointConfig, 'create', return_value=mock_endpoint_config):
578577
with patch.object(Endpoint, 'get', side_effect=ClientError({'Error': {'Code': 'ValidationException'}}, 'GetEndpoint')):
579578
with patch.object(Endpoint, 'create', return_value=mock_endpoint):
@@ -646,6 +645,7 @@ def capture_ic_create(**kwargs):
646645

647646
with patch.object(builder, '_fetch_model_package', return_value=mock_model_package):
648647
with patch.object(builder, '_fetch_peft', return_value=None):
648+
with patch.object(builder, '_is_nova_model', return_value=False):
649649
with patch.object(EndpointConfig, 'create', return_value=mock_endpoint_config):
650650
with patch.object(Endpoint, 'get', side_effect=ClientError({'Error': {'Code': 'ValidationException'}}, 'GetEndpoint')):
651651
with patch.object(Endpoint, 'create', return_value=mock_endpoint):

0 commit comments

Comments
 (0)