Skip to content

Commit 8a2f9d2

Browse files
fix: Remove hardcoded lmi config name in ModelBuilder optimize (#5749)
* fix: Remove hardcoded lmi config name in ModelBuilder optimize

  Replace the hardcoded config_name='lmi' in ModelBuilder.optimize() with the model's dynamically resolved default config name. The JumpStart metadata for models like llama-3-1-8b-instruct no longer includes an 'lmi' config, which caused a ValueError during optimization. Also fix test_jumpstart_session_with_config_name to use the model's resolved config_name instead of hardcoding 'tgi', and update the unit test assertions in test_js_builder.py accordingly.

* fix: Use flexible Environment assertions in serve deep unit tests

  The JumpStart metadata update changed the default config from 'lmi' to 'max_context_best_price_performance', which adds extra environment variables to container_defs. Updated all three test assertions to use ANY for the Environment dict and to separately verify only the specific env vars each test cares about (OPTION_QUANTIZE, OPTION_TENSOR_PARALLEL_DEGREE, OPTION_SPECULATIVE_DRAFT_MODEL).

* Fix codestyle

---------

Co-authored-by: Molly He <mollyhe@amazon.com>
1 parent cebc27a commit 8a2f9d2

File tree

4 files changed

+29
-36
lines changed

4 files changed

+29
-36
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1728,9 +1728,11 @@ def _model_builder_optimize_wrapper(
17281728
if self._is_jumpstart_model_id():
17291729
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
17301730
if self.pysdk_model:
1731-
self.pysdk_model.set_deployment_config(
1732-
instance_type=instance_type, config_name="lmi"
1733-
)
1731+
config_name = self.pysdk_model.config_name
1732+
if config_name:
1733+
self.pysdk_model.set_deployment_config(
1734+
instance_type=instance_type, config_name=config_name
1735+
)
17341736
input_args = self._optimize_for_jumpstart(
17351737
output_path=output_path,
17361738
instance_type=instance_type,

tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -414,7 +414,7 @@ def test_jumpstart_session_with_config_name():
414414
pass
415415

416416
assert (
417-
"md/js_model_id#meta-textgeneration-llama-2-7b md/js_model_ver#* md/js_config#tgi"
417+
f"md/js_model_id#meta-textgeneration-llama-2-7b md/js_model_ver#* md/js_config#{model.config_name}"
418418
in mock_make_request.call_args[0][1]["headers"]["User-Agent"]
419419
)
420420

tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,7 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
6262
role=ANY,
6363
container_defs={
6464
"Image": ANY,
65-
"Environment": {
66-
"SAGEMAKER_PROGRAM": "inference.py",
67-
"SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600",
68-
"ENDPOINT_SERVER_TIMEOUT": "3600",
69-
"MODEL_CACHE_ROOT": "/opt/ml/model",
70-
"SAGEMAKER_ENV": "1",
71-
"HF_MODEL_ID": "/opt/ml/model",
72-
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
73-
"OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model/",
74-
},
65+
"Environment": ANY,
7566
"AdditionalModelDataSources": [
7667
{
7768
"ChannelName": "draft_model",
@@ -96,6 +87,14 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
9687
enable_network_isolation=True,
9788
tags=ANY,
9889
)
90+
# Verify the specific environment variables we care about
91+
actual_env = mock_create_model.call_args[1]["container_defs"]["Environment"]
92+
assert (
93+
actual_env["OPTION_SPECULATIVE_DRAFT_MODEL"]
94+
== "/opt/ml/additional-model-data-sources/draft_model/"
95+
)
96+
assert actual_env["SAGEMAKER_PROGRAM"] == "inference.py"
97+
assert actual_env["HF_MODEL_ID"] == "/opt/ml/model"
9998
mock_endpoint_from_production_variants.assert_called_once()
10099

101100

@@ -149,16 +148,7 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_
149148
role=ANY,
150149
container_defs={
151150
"Image": ANY,
152-
"Environment": {
153-
"SAGEMAKER_PROGRAM": "inference.py",
154-
"SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600",
155-
"ENDPOINT_SERVER_TIMEOUT": "3600",
156-
"MODEL_CACHE_ROOT": "/opt/ml/model",
157-
"SAGEMAKER_ENV": "1",
158-
"HF_MODEL_ID": "/opt/ml/model",
159-
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
160-
"OPTION_TENSOR_PARALLEL_DEGREE": "8",
161-
},
151+
"Environment": ANY,
162152
"ModelDataSource": {
163153
"S3DataSource": {
164154
"S3Uri": ANY,
@@ -172,6 +162,11 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_
172162
enable_network_isolation=False, # should be set to false
173163
tags=ANY,
174164
)
165+
# Verify the specific environment variables we care about
166+
actual_env = mock_create_model.call_args[1]["container_defs"]["Environment"]
167+
assert actual_env["OPTION_TENSOR_PARALLEL_DEGREE"] == "8"
168+
assert actual_env["SAGEMAKER_PROGRAM"] == "inference.py"
169+
assert actual_env["HF_MODEL_ID"] == "/opt/ml/model"
175170
mock_endpoint_from_production_variants.assert_called_once_with(
176171
name=ANY,
177172
production_variants=ANY,
@@ -237,16 +232,7 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are
237232
role=ANY,
238233
container_defs={
239234
"Image": ANY,
240-
"Environment": {
241-
"SAGEMAKER_PROGRAM": "inference.py",
242-
"SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600",
243-
"ENDPOINT_SERVER_TIMEOUT": "3600",
244-
"MODEL_CACHE_ROOT": "/opt/ml/model",
245-
"SAGEMAKER_ENV": "1",
246-
"HF_MODEL_ID": "/opt/ml/model",
247-
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
248-
"OPTION_QUANTIZE": "fp8",
249-
},
235+
"Environment": ANY,
250236
"ModelDataSource": {
251237
"S3DataSource": {
252238
"S3Uri": ANY,
@@ -260,4 +246,9 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are
260246
enable_network_isolation=True, # should be set to false
261247
tags=ANY,
262248
)
249+
# Verify the specific environment variables we care about
250+
actual_env = mock_create_model.call_args[1]["container_defs"]["Environment"]
251+
assert actual_env["OPTION_QUANTIZE"] == "fp8"
252+
assert actual_env["SAGEMAKER_PROGRAM"] == "inference.py"
253+
assert actual_env["HF_MODEL_ID"] == "/opt/ml/model"
263254
mock_endpoint_from_production_variants.assert_called_once()

tests/unit/sagemaker/serve/builder/test_js_builder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1696,7 +1696,7 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations(
16961696

16971697
assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == {
16981698
"instance_type": "ml.g5.24xlarge",
1699-
"config_name": "lmi",
1699+
"config_name": mock_lmi_js_model.config_name,
17001700
}
17011701
assert optimized_model.env == {
17021702
"SAGEMAKER_PROGRAM": "inference.py",
@@ -1784,7 +1784,7 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_over
17841784

17851785
assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == {
17861786
"instance_type": "ml.g5.24xlarge",
1787-
"config_name": "lmi",
1787+
"config_name": mock_lmi_js_model.config_name,
17881788
}
17891789
assert optimized_model.env == {
17901790
"SAGEMAKER_PROGRAM": "inference.py",

0 commit comments

Comments
 (0)