Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/sagemaker/serve/builder/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1728,9 +1728,11 @@ def _model_builder_optimize_wrapper(
if self._is_jumpstart_model_id():
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
if self.pysdk_model:
self.pysdk_model.set_deployment_config(
instance_type=instance_type, config_name="lmi"
)
config_name = self.pysdk_model.config_name
if config_name:
self.pysdk_model.set_deployment_config(
instance_type=instance_type, config_name=config_name
)
input_args = self._optimize_for_jumpstart(
output_path=output_path,
instance_type=instance_type,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def test_jumpstart_session_with_config_name():
pass

assert (
"md/js_model_id#meta-textgeneration-llama-2-7b md/js_model_ver#* md/js_config#tgi"
f"md/js_model_id#meta-textgeneration-llama-2-7b md/js_model_ver#* md/js_config#{model.config_name}"
in mock_make_request.call_args[0][1]["headers"]["User-Agent"]
)

Expand Down
48 changes: 18 additions & 30 deletions tests/integ/sagemaker/serve/test_serve_js_deep_unit_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,7 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
role=ANY,
container_defs={
"Image": ANY,
"Environment": {
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600",
"ENDPOINT_SERVER_TIMEOUT": "3600",
"MODEL_CACHE_ROOT": "/opt/ml/model",
"SAGEMAKER_ENV": "1",
"HF_MODEL_ID": "/opt/ml/model",
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
"OPTION_SPECULATIVE_DRAFT_MODEL": "/opt/ml/additional-model-data-sources/draft_model/",
},
"Environment": ANY,
"AdditionalModelDataSources": [
{
"ChannelName": "draft_model",
Expand All @@ -96,6 +87,11 @@ def test_js_model_with_optimize_speculative_decoding_config_gated_requests_are_e
enable_network_isolation=True,
tags=ANY,
)
# Verify the specific environment variables we care about
actual_env = mock_create_model.call_args[1]["container_defs"]["Environment"]
assert actual_env["OPTION_SPECULATIVE_DRAFT_MODEL"] == "/opt/ml/additional-model-data-sources/draft_model/"
assert actual_env["SAGEMAKER_PROGRAM"] == "inference.py"
assert actual_env["HF_MODEL_ID"] == "/opt/ml/model"
mock_endpoint_from_production_variants.assert_called_once()


Expand Down Expand Up @@ -149,16 +145,7 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_
role=ANY,
container_defs={
"Image": ANY,
"Environment": {
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600",
"ENDPOINT_SERVER_TIMEOUT": "3600",
"MODEL_CACHE_ROOT": "/opt/ml/model",
"SAGEMAKER_ENV": "1",
"HF_MODEL_ID": "/opt/ml/model",
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
"OPTION_TENSOR_PARALLEL_DEGREE": "8",
},
"Environment": ANY,
"ModelDataSource": {
"S3DataSource": {
"S3Uri": ANY,
Expand All @@ -172,6 +159,11 @@ def test_js_model_with_optimize_sharding_and_resource_requirements_requests_are_
enable_network_isolation=False, # should be set to false
tags=ANY,
)
# Verify the specific environment variables we care about
actual_env = mock_create_model.call_args[1]["container_defs"]["Environment"]
assert actual_env["OPTION_TENSOR_PARALLEL_DEGREE"] == "8"
assert actual_env["SAGEMAKER_PROGRAM"] == "inference.py"
assert actual_env["HF_MODEL_ID"] == "/opt/ml/model"
mock_endpoint_from_production_variants.assert_called_once_with(
name=ANY,
production_variants=ANY,
Expand Down Expand Up @@ -237,16 +229,7 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are
role=ANY,
container_defs={
"Image": ANY,
"Environment": {
"SAGEMAKER_PROGRAM": "inference.py",
"SAGEMAKER_MODEL_SERVER_TIMEOUT": "3600",
"ENDPOINT_SERVER_TIMEOUT": "3600",
"MODEL_CACHE_ROOT": "/opt/ml/model",
"SAGEMAKER_ENV": "1",
"HF_MODEL_ID": "/opt/ml/model",
"SAGEMAKER_MODEL_SERVER_WORKERS": "1",
"OPTION_QUANTIZE": "fp8",
},
"Environment": ANY,
"ModelDataSource": {
"S3DataSource": {
"S3Uri": ANY,
Expand All @@ -260,4 +243,9 @@ def test_js_model_with_optimize_quantization_on_pre_optimized_model_requests_are
enable_network_isolation=True, # should be set to false
tags=ANY,
)
# Verify the specific environment variables we care about
actual_env = mock_create_model.call_args[1]["container_defs"]["Environment"]
assert actual_env["OPTION_QUANTIZE"] == "fp8"
assert actual_env["SAGEMAKER_PROGRAM"] == "inference.py"
assert actual_env["HF_MODEL_ID"] == "/opt/ml/model"
mock_endpoint_from_production_variants.assert_called_once()
4 changes: 2 additions & 2 deletions tests/unit/sagemaker/serve/builder/test_js_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1696,7 +1696,7 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations(

assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == {
"instance_type": "ml.g5.24xlarge",
"config_name": "lmi",
"config_name": mock_lmi_js_model.config_name,
}
assert optimized_model.env == {
"SAGEMAKER_PROGRAM": "inference.py",
Expand Down Expand Up @@ -1784,7 +1784,7 @@ def test_optimize_on_js_model_should_ignore_pre_optimized_configurations_no_over

assert mock_lmi_js_model.set_deployment_config.call_args_list[0].kwargs == {
"instance_type": "ml.g5.24xlarge",
"config_name": "lmi",
"config_name": mock_lmi_js_model.config_name,
}
assert optimized_model.env == {
"SAGEMAKER_PROGRAM": "inference.py",
Expand Down
Loading