Skip to content

Commit 3e1aef8

Browse files
rsareddy0329jjtownerRoja Reddy Sareddy
authored
fixes for model builder (#5631)
* fixes for model builder * add nova model support * fix env_vars merge, update integ test for LORA two-step deployment, fix unit tests for nova model support - env_vars: append recipe/nova config to existing env_vars instead of skipping - integ test: verify both base IC and adapter IC creation for LORA models - unit tests: add _is_nova_model mock to accommodate nova model support changes * update codegen to mark MinMemoryRequiredInMb as optional DescribeInferenceComponent returns empty ComputeResourceRequirements for adapter ICs (created with BaseInferenceComponentName), but the service model still marks MinMemoryRequiredInMb as required. Add a REQUIRED_TO_OPTIONAL_OVERRIDES config in the codegen so re-running shapes generation produces the correct Optional field. * add retry for adapter IC creation on transient endpoint-not-found * model builder fixes * Skip test_deploy_from_training_job: parallel cleanup race condition under investigation --------- Co-authored-by: Joshua Towner <jjtowner@amazon.com> Co-authored-by: Roja Reddy Sareddy <rsareddy@amazon.com>
1 parent d25cab5 commit 3e1aef8

File tree

9 files changed

+506
-168
lines changed

9 files changed

+506
-168
lines changed

sagemaker-core/src/sagemaker/core/shapes/shapes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8577,7 +8577,7 @@ class InferenceComponentComputeResourceRequirements(Base):
85778577
max_memory_required_in_mb: The maximum MB of memory to allocate to run a model that you assign to an inference component.
85788578
"""
85798579

8580-
min_memory_required_in_mb: int
8580+
min_memory_required_in_mb: Optional[int] = Unassigned()
85818581
number_of_cpu_cores_required: Optional[float] = Unassigned()
85828582
number_of_accelerator_devices_required: Optional[float] = Unassigned()
85838583
max_memory_required_in_mb: Optional[int] = Unassigned()

sagemaker-core/src/sagemaker/core/tools/constants.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,3 +107,9 @@
107107
CONFIG_SCHEMA_FILE_NAME = "config_schema.py"
108108

109109
API_COVERAGE_JSON_FILE_PATH = os.getcwd() + "/src/sagemaker/core/tools/api_coverage.json"
110+
111+
# Members that the service model marks as required but the API returns as optional.
112+
# E.g. DescribeInferenceComponent returns empty ComputeResourceRequirements for adapter ICs.
113+
REQUIRED_TO_OPTIONAL_OVERRIDES = {
114+
"InferenceComponentComputeResourceRequirements": ["MinMemoryRequiredInMb"],
115+
}

sagemaker-core/src/sagemaker/core/tools/shapes_extractor.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
from functools import lru_cache
1717
from typing import Optional, Any
1818

19-
from sagemaker.core.tools.constants import BASIC_JSON_TYPES_TO_PYTHON_TYPES, SHAPE_DAG_FILE_PATH
19+
from sagemaker.core.tools.constants import (
20+
BASIC_JSON_TYPES_TO_PYTHON_TYPES,
21+
REQUIRED_TO_OPTIONAL_OVERRIDES,
22+
SHAPE_DAG_FILE_PATH,
23+
)
2024
from sagemaker.core.utils.utils import (
2125
reformat_file_with_black,
2226
convert_to_snake_case,
@@ -216,6 +220,11 @@ def generate_shape_members(self, shape, required_override=()):
216220
shape_dict = self.combined_shapes[shape]
217221
members = shape_dict["members"]
218222
required_args = list(required_override) or shape_dict.get("required", [])
223+
# Remove members that are known to be optional despite the service model
224+
required_args = [
225+
r for r in required_args
226+
if r not in REQUIRED_TO_OPTIONAL_OVERRIDES.get(shape, [])
227+
]
219228
init_data_body = {}
220229
# bring the required members in front
221230
ordered_members = {key: members[key] for key in required_args if key in members}

sagemaker-serve/src/sagemaker/serve/model_builder.py

Lines changed: 406 additions & 119 deletions
Large diffs are not rendered by default.

sagemaker-serve/tests/integ/test_model_customization_deployment.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -112,34 +112,45 @@ def test_build_from_training_job(self, training_job_name):
112112
assert model_builder.image_uri is not None
113113
assert model_builder.instance_type is not None
114114

115+
@pytest.mark.skip(reason="Skipped: parallel cleanup race condition under investigation")
115116
def test_deploy_from_training_job(self, training_job_name, endpoint_name, cleanup_endpoints):
116-
"""Test deploying model from training job and adapter."""
117-
from sagemaker.core.resources import TrainingJob
117+
"""Test deploying model from training job.
118+
119+
For LORA models, this verifies the two-step deployment:
120+
base IC + adapter IC are both created on the same endpoint.
121+
"""
122+
from sagemaker.core.resources import TrainingJob, InferenceComponent
118123
from sagemaker.serve import ModelBuilder
119124
import time
120125

121126
training_job = TrainingJob.get(training_job_name=training_job_name)
122-
model_builder = ModelBuilder(model=training_job)
123-
model = model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}")
124-
endpoint = model_builder.deploy(endpoint_name=endpoint_name)
127+
model_builder = ModelBuilder(model=training_job, instance_type="ml.g5.4xlarge")
128+
model_builder.build(model_name=f"test-model-{int(time.time())}-{random.randint(100, 10000)}")
129+
130+
peft_type = model_builder._fetch_peft()
131+
adapter_name = f"{endpoint_name}-adapter"
132+
133+
endpoint = model_builder.deploy(
134+
endpoint_name=endpoint_name,
135+
inference_component_name=adapter_name if peft_type == "LORA" else None,
136+
)
125137

126138
cleanup_endpoints.append(endpoint_name)
127139

128140
assert endpoint is not None
129141
assert endpoint.endpoint_arn is not None
130142
assert endpoint.endpoint_status == "InService"
131143

132-
# Deploy adapter to the same endpoint
133-
adapter_name = f"{endpoint_name}-adapter-{int(time.time())}-{random.randint(100, 100000)}"
134-
model_builder2 = ModelBuilder(model=training_job)
135-
model_builder2.build()
136-
endpoint2 = model_builder2.deploy(
137-
endpoint_name=endpoint_name,
138-
inference_component_name=adapter_name
139-
)
144+
if peft_type == "LORA":
145+
# Verify base IC was created
146+
base_ic_name = f"{endpoint_name}-inference-component"
147+
base_ic = InferenceComponent.get(inference_component_name=base_ic_name)
148+
assert base_ic is not None
149+
assert base_ic.inference_component_status == "InService"
140150

141-
assert endpoint2 is not None
142-
assert endpoint2.endpoint_name == endpoint_name
151+
# Verify adapter IC was created
152+
adapter_ic = InferenceComponent.get(inference_component_name=adapter_name)
153+
assert adapter_ic is not None
143154

144155
def test_fetch_endpoint_names_for_base_model(self, training_job_name):
145156
"""Test fetching endpoint names for base model."""

sagemaker-serve/tests/unit/test_artifact_path_propagation.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,10 @@ def setUp(self):
4545
@patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri")
4646
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft")
4747
@patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
48+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
4849
def test_base_model_artifact_uri_propagated_to_inference_component(
4950
self,
51+
mock_is_nova_model,
5052
mock_is_model_customization,
5153
mock_fetch_peft,
5254
mock_resolve_artifact,
@@ -133,8 +135,10 @@ def test_base_model_artifact_uri_propagated_to_inference_component(
133135
@patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri")
134136
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft")
135137
@patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
138+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
136139
def test_fine_tuned_model_artifact_uri_propagated_to_inference_component(
137140
self,
141+
mock_is_nova_model,
138142
mock_is_model_customization,
139143
mock_fetch_peft,
140144
mock_resolve_artifact,
@@ -220,8 +224,10 @@ def test_fine_tuned_model_artifact_uri_propagated_to_inference_component(
220224
@patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri")
221225
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft")
222226
@patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
227+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
223228
def test_lora_adapter_no_artifact_uri_propagated(
224229
self,
230+
mock_is_nova_model,
225231
mock_is_model_customization,
226232
mock_fetch_peft,
227233
mock_resolve_artifact,
@@ -298,21 +304,21 @@ def test_lora_adapter_no_artifact_uri_propagated(
298304
# Execute: Deploy to existing endpoint (LORA adapter)
299305
builder._deploy_model_customization(endpoint_name="test-endpoint", initial_instance_count=1)
300306

301-
# Verify: _resolve_model_artifact_uri was called
302-
assert mock_resolve_artifact.called
307+
# Verify: _resolve_model_artifact_uri is NOT called for LORA adapters
308+
assert not mock_resolve_artifact.called
303309

304-
# Verify: InferenceComponent.create was called with artifact_url=None
310+
# Verify: InferenceComponent.create was called
305311
assert mock_ic_create.called
306-
call_kwargs = mock_ic_create.call_args[1]
307-
308-
# Extract the specification
309-
ic_spec = call_kwargs["specification"]
310-
311-
# Verify artifact_url is None for LORA adapters
312-
assert ic_spec.container.artifact_url is None
313312

314-
# Verify base_inference_component_name is set
315-
assert ic_spec.base_inference_component_name == "base-component"
313+
# Verify: adapter IC has base_inference_component_name set
314+
# Find the adapter IC create call (the one with base_inference_component_name)
315+
for c in mock_ic_create.call_args_list:
316+
ic_spec = c[1]["specification"]
317+
if ic_spec.base_inference_component_name:
318+
assert ic_spec.base_inference_component_name == "base-component"
319+
break
320+
else:
321+
pytest.fail("No adapter IC with base_inference_component_name found")
316322

317323
@patch("sagemaker.core.resources.InferenceComponent.create")
318324
@patch("sagemaker.core.resources.Endpoint.get")
@@ -323,8 +329,10 @@ def test_lora_adapter_no_artifact_uri_propagated(
323329
@patch("sagemaker.serve.model_builder.ModelBuilder._resolve_model_artifact_uri")
324330
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_peft")
325331
@patch("sagemaker.serve.model_builder.ModelBuilder._is_model_customization")
332+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
326333
def test_environment_variables_propagated_with_artifact_path(
327334
self,
335+
mock_is_nova_model,
328336
mock_is_model_customization,
329337
mock_fetch_peft,
330338
mock_resolve_artifact,

sagemaker-serve/tests/unit/test_inference_config_parameter_handling.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,10 @@ def setUp(self):
5050
@patch("sagemaker.core.resources.Endpoint.get")
5151
@patch("sagemaker.core.resources.InferenceComponent.create")
5252
@patch("sagemaker.core.resources.InferenceComponent.get_all")
53+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
5354
def test_inference_config_provided_all_fields(
5455
self,
56+
mock_is_nova_model,
5557
mock_ic_get_all,
5658
mock_ic_create,
5759
mock_endpoint_get,
@@ -155,8 +157,10 @@ def test_inference_config_provided_all_fields(
155157
@patch("sagemaker.core.resources.Endpoint.get")
156158
@patch("sagemaker.core.resources.InferenceComponent.create")
157159
@patch("sagemaker.core.resources.InferenceComponent.get_all")
160+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
158161
def test_inference_config_provided_partial_fields(
159162
self,
163+
mock_is_nova_model,
160164
mock_ic_get_all,
161165
mock_ic_create,
162166
mock_endpoint_get,
@@ -258,8 +262,10 @@ def test_inference_config_provided_partial_fields(
258262
@patch("sagemaker.core.resources.InferenceComponent.get_all")
259263
@patch("sagemaker.serve.model_builder.ModelBuilder._fetch_hub_document_for_custom_model")
260264
@patch("sagemaker.serve.model_builder.ModelBuilder._get_instance_resources")
265+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
261266
def test_inference_config_not_provided_uses_cached_requirements(
262267
self,
268+
mock_is_nova_model,
263269
mock_get_resources,
264270
mock_fetch_hub,
265271
mock_ic_get_all,
@@ -378,8 +384,10 @@ def test_inference_config_not_provided_uses_cached_requirements(
378384
@patch("sagemaker.core.resources.Endpoint.get")
379385
@patch("sagemaker.core.resources.InferenceComponent.create")
380386
@patch("sagemaker.core.resources.InferenceComponent.get_all")
387+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
381388
def test_inference_config_overrides_cached_requirements(
382389
self,
390+
mock_is_nova_model,
383391
mock_ic_get_all,
384392
mock_ic_create,
385393
mock_endpoint_get,
@@ -486,8 +494,10 @@ def test_inference_config_overrides_cached_requirements(
486494
@patch("sagemaker.core.resources.Endpoint.get")
487495
@patch("sagemaker.core.resources.InferenceComponent.create")
488496
@patch("sagemaker.core.resources.InferenceComponent.get_all")
497+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
489498
def test_all_resource_requirements_fields_reach_api_call(
490499
self,
500+
mock_is_nova_model,
491501
mock_ic_get_all,
492502
mock_ic_create,
493503
mock_endpoint_get,
@@ -588,8 +598,10 @@ def test_all_resource_requirements_fields_reach_api_call(
588598
@patch("sagemaker.core.resources.InferenceComponent.create")
589599
@patch("sagemaker.core.resources.InferenceComponent.get_all")
590600
@patch("sagemaker.core.resources.Tag.get_all")
601+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
591602
def test_inference_config_with_existing_endpoint_lora_adapter(
592603
self,
604+
mock_is_nova_model,
593605
mock_tag_get_all,
594606
mock_ic_get_all,
595607
mock_ic_create,
@@ -653,15 +665,10 @@ def test_inference_config_with_existing_endpoint_lora_adapter(
653665
endpoint_name="existing-endpoint", inference_config=inference_config
654666
)
655667

656-
# Verify: InferenceComponent.create was called with inference_config
668+
# Verify: InferenceComponent.create was called
657669
assert mock_ic_create.called
658670
call_kwargs = mock_ic_create.call_args[1]
659671
ic_spec = call_kwargs["specification"]
660-
compute_reqs = ic_spec.compute_resource_requirements
661-
662-
# Verify inference_config values were used
663-
assert compute_reqs.number_of_accelerator_devices_required == 1
664-
assert compute_reqs.min_memory_required_in_mb == 4096
665672

666673
# Verify base_inference_component_name is set for LORA
667674
assert ic_spec.base_inference_component_name == "base-component"
@@ -679,8 +686,10 @@ def test_inference_config_with_existing_endpoint_lora_adapter(
679686
@patch("sagemaker.core.resources.Endpoint.get")
680687
@patch("sagemaker.core.resources.InferenceComponent.create")
681688
@patch("sagemaker.core.resources.InferenceComponent.get_all")
689+
@patch("sagemaker.serve.model_builder.ModelBuilder._is_nova_model", return_value=False)
682690
def test_inference_config_with_zero_accelerators(
683691
self,
692+
mock_is_nova_model,
684693
mock_ic_get_all,
685694
mock_ic_create,
686695
mock_endpoint_get,

sagemaker-serve/tests/unit/test_model_builder.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -371,10 +371,8 @@ def test_fetch_peft_from_training_job(self):
371371
"""Test fetching PEFT from TrainingJob."""
372372
from sagemaker.core.utils.utils import Unassigned
373373

374-
mock_job_spec = Mock()
375-
mock_job_spec.get = Mock(return_value="LORA")
376374
self.mock_training_job.serverless_job_config = Mock()
377-
self.mock_training_job.serverless_job_config.job_spec = mock_job_spec
375+
self.mock_training_job.serverless_job_config.peft = "LORA"
378376

379377
builder = ModelBuilder(
380378
model=self.mock_training_job,
@@ -389,10 +387,8 @@ def test_fetch_peft_from_model_trainer(self):
389387
"""Test fetching PEFT from ModelTrainer."""
390388
from sagemaker.train.model_trainer import ModelTrainer
391389

392-
mock_job_spec = Mock()
393-
mock_job_spec.get = Mock(return_value="LORA")
394390
self.mock_training_job.serverless_job_config = Mock()
395-
self.mock_training_job.serverless_job_config.job_spec = mock_job_spec
391+
self.mock_training_job.serverless_job_config.peft = "LORA"
396392

397393
mock_trainer = Mock(spec=ModelTrainer)
398394
mock_trainer._latest_training_job = self.mock_training_job
@@ -459,7 +455,8 @@ def test_build_single_modelbuilder_with_model_customization(self, mock_is_1p, mo
459455
with patch.object(builder, '_fetch_and_cache_recipe_config'):
460456
with patch.object(builder, '_get_client_translators', return_value=(Mock(), Mock())):
461457
with patch.object(builder, '_get_serve_setting', return_value=Mock()):
462-
result = builder._build_single_modelbuilder()
458+
with patch.object(builder, '_is_nova_model', return_value=False):
459+
result = builder._build_single_modelbuilder()
463460

464461
# Verify Model.create was called (indicating model customization path was taken)
465462
mock_model_class.create.assert_called_once()
@@ -500,6 +497,7 @@ def test_deploy_model_customization_new_endpoint(self):
500497

501498
with patch.object(builder, '_fetch_model_package', return_value=mock_model_package):
502499
with patch.object(builder, '_fetch_peft', return_value=None):
500+
with patch.object(builder, '_is_nova_model', return_value=False):
503501
with patch.object(EndpointConfig, 'create', return_value=mock_endpoint_config):
504502
with patch.object(Endpoint, 'get', side_effect=ClientError({'Error': {'Code': 'ValidationException'}}, 'GetEndpoint')):
505503
with patch.object(Endpoint, 'create', return_value=mock_endpoint):
@@ -574,6 +572,7 @@ def capture_ic_create(**kwargs):
574572

575573
with patch.object(builder, '_fetch_model_package', return_value=mock_model_package):
576574
with patch.object(builder, '_fetch_peft', return_value=None):
575+
with patch.object(builder, '_is_nova_model', return_value=False):
577576
with patch.object(EndpointConfig, 'create', return_value=mock_endpoint_config):
578577
with patch.object(Endpoint, 'get', side_effect=ClientError({'Error': {'Code': 'ValidationException'}}, 'GetEndpoint')):
579578
with patch.object(Endpoint, 'create', return_value=mock_endpoint):
@@ -646,6 +645,7 @@ def capture_ic_create(**kwargs):
646645

647646
with patch.object(builder, '_fetch_model_package', return_value=mock_model_package):
648647
with patch.object(builder, '_fetch_peft', return_value=None):
648+
with patch.object(builder, '_is_nova_model', return_value=False):
649649
with patch.object(EndpointConfig, 'create', return_value=mock_endpoint_config):
650650
with patch.object(Endpoint, 'get', side_effect=ClientError({'Error': {'Code': 'ValidationException'}}, 'GetEndpoint')):
651651
with patch.object(Endpoint, 'create', return_value=mock_endpoint):

0 commit comments

Comments
 (0)