diff --git a/src/sagemaker/serve/builder/model_builder.py b/src/sagemaker/serve/builder/model_builder.py index 3c19e4aa43..2e74f5eba5 100644 --- a/src/sagemaker/serve/builder/model_builder.py +++ b/src/sagemaker/serve/builder/model_builder.py @@ -1728,9 +1728,11 @@ def _model_builder_optimize_wrapper( if self._is_jumpstart_model_id(): self.build(mode=self.mode, sagemaker_session=self.sagemaker_session) if self.pysdk_model: - self.pysdk_model.set_deployment_config( - instance_type=instance_type, config_name="lmi" - ) + config_name = self.pysdk_model.config_name + if config_name: + self.pysdk_model.set_deployment_config( + instance_type=instance_type, config_name=config_name + ) input_args = self._optimize_for_jumpstart( output_path=output_path, instance_type=instance_type, diff --git a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py index c8b89db7b6..56e3234863 100644 --- a/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py +++ b/tests/integ/sagemaker/jumpstart/model/test_jumpstart_model.py @@ -414,7 +414,7 @@ def test_jumpstart_session_with_config_name(): pass assert ( - "md/js_model_id#meta-textgeneration-llama-2-7b md/js_model_ver#* md/js_config#tgi" + f"md/js_model_id#meta-textgeneration-llama-2-7b md/js_model_ver#* md/js_config#{model.config_name}" in mock_make_request.call_args[0][1]["headers"]["User-Agent"] ) diff --git a/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py b/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py index 3b59cae321..ce5db857b9 100644 --- a/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py +++ b/tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py @@ -62,16 +62,7 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e role=ANY, container_defs={ "Image": ANY, - "Environment": { - "SAGEMAKER_PROGRAM": "inference.py", - "SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600", - "ENDPOINT_SERVER_TIMEOUT": "3600", - "MODEL_CACHE_ROOT": "/opt/ml/model", - "SAGEMAKER_ENV": "1", - "HF_MODEL_ID": "/opt/ml/model", - "SAGEMAKER_MODEL_SERVER_WORKERS": "1", - "OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model/", - }, + "Environment": ANY, "AdditionalModelDataSources": [ { "ChannelName": "draft_model", @@ -96,6 +87,11 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e enable_network_isolation=True, tags=ANY, ) + # Verify the specific environment variables we care about + actual_env = mock_create_model.call_args[1]["container_defs"]["Environment"] + assert actual_env["OPTION_SPECULATIVE_DRAFT_MODEL"] == "/opt/ml/additional-model-data-sources/draft_model/" + assert actual_env["SAGEMAKER_PROGRAM"] == "inference.py" + assert actual_env["HF_MODEL_ID"] == "/opt/ml/model" mock_endpoint_from_production_variants.assert_called_once() @@ -149,16 +145,7 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_ role=ANY, container_defs={ "Image": ANY, - "Environment": { - "SAGEMAKER_PROGRAM": "inference.py", - "SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600", - "ENDPOINT_SERVER_TIMEOUT": "3600", - "MODEL_CACHE_ROOT": "/opt/ml/model", - "SAGEMAKER_ENV": "1", - "HF_MODEL_ID": "/opt/ml/model", - "SAGEMAKER_MODEL_SERVER_WORKERS": "1", - "OPTION_TENSOR_PARALLEL_DEGREE": "8", - }, + "Environment": ANY, "ModelDataSource": { "S3DataSource": { "S3Uri": ANY, @@ -172,6 +159,11 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_ enable_network_isolation=False, # should be set to false tags=ANY, ) + # Verify the specific environment variables we care about + actual_env = mock_create_model.call_args[1]["container_defs"]["Environment"] + assert actual_env["OPTION_TENSOR_PARALLEL_DEGREE"] == "8" + assert actual_env["SAGEMAKER_PROGRAM"] == "inference.py" + assert actual_env["HF_MODEL_ID"] == "/opt/ml/model" mock_endpoint_from_production_variants.assert_called_once_with( name=ANY, production_variants=ANY, @@ -237,16 +229,7 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are role=ANY, container_defs={ "Image": ANY, - "Environment": { - "SAGEMAKER_PROGRAM": "inference.py", - "SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600", - "ENDPOINT_SERVER_TIMEOUT": "3600", - "MODEL_CACHE_ROOT": "/opt/ml/model", - "SAGEMAKER_ENV": "1", - "HF_MODEL_ID": "/opt/ml/model", - "SAGEMAKER_MODEL_SERVER_WORKERS": "1", - "OPTION_QUANTIZE": "fp8", - }, + "Environment": ANY, "ModelDataSource": { "S3DataSource": { "S3Uri": ANY, @@ -260,4 +243,9 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are enable_network_isolation=True, # should be set to false tags=ANY, ) + # Verify the specific environment variables we care about + actual_env = mock_create_model.call_args[1]["container_defs"]["Environment"] + assert actual_env["OPTION_QUANTIZE"] == "fp8" + assert actual_env["SAGEMAKER_PROGRAM"] == "inference.py" + assert actual_env["HF_MODEL_ID"] == "/opt/ml/model" mock_endpoint_from_production_variants.assert_called_once() diff --git a/tests/unit/sagemaker/serve/builder/test_js_builder.py b/tests/unit/sagemaker/serve/builder/test_js_builder.py index 415d7eab5b..25d829b056 100644 --- a/tests/unit/sagemaker/serve/builder/test_js_builder.py +++ b/tests/unit/sagemaker/serve/builder/test_js_builder.py @@ -1696,7 +1696,7 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations( assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == { "instance_type": "ml.g5.24xlarge", - "config_name": "lmi", + "config_name": mock_lmi_js_model.config_name, } assert optimized_model.env == { "SAGEMAKER_PROGRAM": "inference.py", @@ -1784,7 +1784,7 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_over assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == { "instance_type": "ml.g5.24xlarge", - "config_name": "lmi", + "config_name": mock_lmi_js_model.config_name, } assert optimized_model.env == { "SAGEMAKER_PROGRAM": "inference.py",