Skip to content

Commit 8d55ad5

Browse files
committed
Bug fix for integrity check using hmac key in serve
1 parent 4ed9d91 commit 8d55ad5

13 files changed

Lines changed: 54 additions & 109 deletions

File tree

sagemaker-serve/src/sagemaker/serve/model_builder_servers.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -705,12 +705,6 @@ def _build_for_transformers(self) -> Model:
705705
self.s3_model_data_url, _ = self._prepare_for_mode()
706706

707707

708-
# Clean up empty secret key
709-
if (
710-
"SAGEMAKER_SERVE_SECRET_KEY" in self.env_vars
711-
and not self.env_vars["SAGEMAKER_SERVE_SECRET_KEY"]
712-
):
713-
del self.env_vars["SAGEMAKER_SERVE_SECRET_KEY"]
714708

715709
# Instance type validation for SAGEMAKER_ENDPOINT mode
716710
if self.mode == Mode.SAGEMAKER_ENDPOINT and not self.instance_type:

sagemaker-serve/src/sagemaker/serve/model_builder_utils.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,7 @@ def build(self):
130130
from sagemaker.core.deserializers import JSONDeserializer
131131
from sagemaker.serve.detector.pickler import save_pkl
132132
from sagemaker.serve.builder.requirements_manager import RequirementsManager
133-
from sagemaker.serve.validations.check_integrity import (
134-
generate_secret_key,
135-
compute_hash,
136-
)
133+
from sagemaker.serve.validations.check_integrity import compute_hash
137134
from sagemaker.core.remote_function.core.serialization import _MetaData
138135
from sagemaker.serve.model_server.triton.config_template import CONFIG_TEMPLATE
139136

@@ -2882,20 +2879,6 @@ def _save_inference_spec(self) -> None:
28822879
pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")
28832880
save_pkl(pkl_path, (self.inference_spec, self.schema_builder))
28842881

2885-
def _hmac_signing(self):
2886-
"""Perform HMAC signing on pickle file for integrity check"""
2887-
secret_key = generate_secret_key()
2888-
pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")
2889-
2890-
with open(str(pkl_path.joinpath("serve.pkl")), "rb") as f:
2891-
buffer = f.read()
2892-
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
2893-
2894-
with open(str(pkl_path.joinpath("metadata.json")), "wb") as metadata:
2895-
metadata.write(_MetaData(hash_value).to_json())
2896-
2897-
self.secret_key = secret_key
2898-
28992882
def _generate_config_pbtxt(self, pkl_path: Path):
29002883
"""Generate Triton config.pbtxt file."""
29012884
config_path = pkl_path.joinpath("config.pbtxt")
@@ -3097,7 +3080,12 @@ def _prepare_for_triton(self):
30973080

30983081
self._pack_conda_env(pkl_path=pkl_path)
30993082

3100-
self._hmac_signing()
3083+
# Compute SHA256 hash for integrity check
3084+
with open(str(pkl_path.joinpath("serve.pkl")), "rb") as f:
3085+
buffer = f.read()
3086+
hash_value = compute_hash(buffer=buffer)
3087+
with open(str(pkl_path.joinpath("metadata.json")), "wb") as metadata:
3088+
metadata.write(_MetaData(hash_value).to_json())
31013089

31023090
return
31033091

sagemaker-serve/src/sagemaker/serve/model_server/multi_model_server/prepare.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,7 @@
2525
from sagemaker.core.helper.session_helper import Session
2626
from sagemaker.serve.spec.inference_spec import InferenceSpec
2727
from sagemaker.serve.detector.dependency_manager import capture_dependencies
28-
from sagemaker.serve.validations.check_integrity import (
29-
generate_secret_key,
30-
compute_hash,
31-
)
28+
from sagemaker.serve.validations.check_integrity import compute_hash
3229
from sagemaker.core.remote_function.core.serialization import _MetaData
3330

3431
logger = logging.getLogger(__name__)
@@ -119,11 +116,8 @@ def prepare_for_mms(
119116

120117
capture_dependencies(dependencies=dependencies, work_dir=code_dir)
121118

122-
secret_key = generate_secret_key()
123119
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
124120
buffer = f.read()
125-
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
121+
hash_value = compute_hash(buffer=buffer)
126122
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
127123
metadata.write(_MetaData(hash_value).to_json())
128-
129-
return secret_key

sagemaker-serve/src/sagemaker/serve/model_server/multi_model_server/server.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,12 @@ def _start_serving(
2828
client: object,
2929
image: str,
3030
model_path: str,
31-
secret_key: str,
3231
env_vars: dict,
3332
):
3433
"""Initializes the start of the server"""
3534
env = {
3635
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
3736
"SAGEMAKER_PROGRAM": "inference.py",
38-
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
3937
"LOCAL_PYTHON": platform.python_version(),
4038
}
4139
if env_vars:
@@ -80,7 +78,6 @@ class SageMakerMultiModelServer:
8078
def _upload_server_artifacts(
8179
self,
8280
model_path: str,
83-
secret_key: str,
8481
sagemaker_session: Session,
8582
s3_model_data_url: str = None,
8683
image: str = None,
@@ -127,15 +124,16 @@ def _upload_server_artifacts(
127124
else None
128125
)
129126

130-
if secret_key:
131-
env_vars = {
132-
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
133-
"SAGEMAKER_PROGRAM": "inference.py",
134-
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
135-
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
136-
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
137-
"LOCAL_PYTHON": platform.python_version(),
138-
}
127+
if env_vars is None:
128+
env_vars = {}
129+
130+
env_vars.update({
131+
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
132+
"SAGEMAKER_PROGRAM": "inference.py",
133+
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
134+
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
135+
"LOCAL_PYTHON": platform.python_version(),
136+
})
139137

140138
return model_data, _update_env_vars(env_vars)
141139

sagemaker-serve/src/sagemaker/serve/model_server/smd/prepare.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,7 @@
1111

1212
from sagemaker.serve.spec.inference_spec import InferenceSpec
1313
from sagemaker.serve.detector.dependency_manager import capture_dependencies
14-
from sagemaker.serve.validations.check_integrity import (
15-
generate_secret_key,
16-
compute_hash,
17-
)
14+
from sagemaker.serve.validations.check_integrity import compute_hash
1815
from sagemaker.core.remote_function.core.serialization import _MetaData
1916
from sagemaker.serve.spec.inference_base import CustomOrchestrator, AsyncCustomOrchestrator
2017

@@ -64,11 +61,8 @@ def prepare_for_smd(
6461

6562
capture_dependencies(dependencies=dependencies, work_dir=code_dir)
6663

67-
secret_key = generate_secret_key()
6864
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
6965
buffer = f.read()
70-
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
66+
hash_value = compute_hash(buffer=buffer)
7167
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
7268
metadata.write(_MetaData(hash_value).to_json())
73-
74-
return secret_key

sagemaker-serve/src/sagemaker/serve/model_server/smd/server.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ def _upload_smd_artifacts(
2020
self,
2121
model_path: str,
2222
sagemaker_session: Session,
23-
secret_key: str,
2423
s3_model_data_url: str = None,
2524
image: str = None,
2625
should_upload_artifacts: bool = False,
@@ -53,7 +52,6 @@ def _upload_smd_artifacts(
5352
"SAGEMAKER_INFERENCE_CODE_DIRECTORY": "/opt/ml/model/code",
5453
"SAGEMAKER_INFERENCE_CODE": "inference.handler",
5554
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
56-
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
5755
"LOCAL_PYTHON": platform.python_version(),
5856
}
5957
return s3_upload_path, env_vars

sagemaker-serve/src/sagemaker/serve/model_server/tei/server.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,20 +27,16 @@ class LocalTeiServing:
2727
"""LocalTeiServing class"""
2828

2929
def _start_tei_serving(
30-
self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict
30+
self, client: object, image: str, model_path: str, env_vars: dict
3131
):
3232
"""Starts a local tei serving container.
3333
3434
Args:
3535
client: Docker client
3636
image: Image to use
3737
model_path: Path to the model
38-
secret_key: Secret key to use for authentication
3938
env_vars: Environment variables to set
4039
"""
41-
if env_vars and secret_key:
42-
env_vars["SAGEMAKER_SERVE_SECRET_KEY"] = secret_key
43-
4440
self.container = client.containers.run(
4541
image,
4642
shm_size=_SHM_SIZE,

sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/prepare.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,7 @@
1010
_move_contents,
1111
)
1212
from sagemaker.serve.detector.dependency_manager import capture_dependencies
13-
from sagemaker.serve.validations.check_integrity import (
14-
generate_secret_key,
15-
compute_hash,
16-
)
13+
from sagemaker.serve.validations.check_integrity import compute_hash
1714
from sagemaker.core.remote_function.core.serialization import _MetaData
1815

1916

@@ -57,11 +54,8 @@ def prepare_for_tf_serving(
5754
raise ValueError("SavedModel is not found for Tensorflow or Keras flavor.")
5855
_move_contents(src_dir=mlflow_saved_model_dir, dest_dir=saved_model_bundle_dir)
5956

60-
secret_key = generate_secret_key()
6157
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
6258
buffer = f.read()
63-
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
59+
hash_value = compute_hash(buffer=buffer)
6460
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
6561
metadata.write(_MetaData(hash_value).to_json())
66-
67-
return secret_key

sagemaker-serve/src/sagemaker/serve/model_server/tensorflow_serving/server.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,14 @@ class LocalTensorflowServing:
2020
"""LocalTensorflowServing class."""
2121

2222
def _start_tensorflow_serving(
23-
self, client: object, image: str, model_path: str, secret_key: str, env_vars: dict
23+
self, client: object, image: str, model_path: str, env_vars: dict
2424
):
2525
"""Starts a local tensorflow serving container.
2626
2727
Args:
2828
client: Docker client
2929
image: Image to use
3030
model_path: Path to the model
31-
secret_key: Secret key to use for authentication
3231
env_vars: Environment variables to set
3332
"""
3433
self.container = client.containers.run(
@@ -47,7 +46,6 @@ def _start_tensorflow_serving(
4746
environment={
4847
"SAGEMAKER_SUBMIT_DIRECTORY": "/opt/ml/model/code",
4948
"SAGEMAKER_PROGRAM": "inference.py",
50-
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
5149
"LOCAL_PYTHON": platform.python_version(),
5250
**env_vars,
5351
},
@@ -81,7 +79,6 @@ def _upload_tensorflow_serving_artifacts(
8179
self,
8280
model_path: str,
8381
sagemaker_session: Session,
84-
secret_key: str,
8582
s3_model_data_url: str = None,
8683
image: str = None,
8784
should_upload_artifacts: bool = False,
@@ -91,7 +88,6 @@ def _upload_tensorflow_serving_artifacts(
9188
Args:
9289
model_path: Path to the model
9390
sagemaker_session: SageMaker session
94-
secret_key: Secret key to use for authentication
9591
s3_model_data_url: S3 model data URL
9692
image: Image to use
9793
model_data_s3_path: S3 model data URI
@@ -124,7 +120,6 @@ def _upload_tensorflow_serving_artifacts(
124120
"SAGEMAKER_PROGRAM": "inference.py",
125121
"SAGEMAKER_REGION": sagemaker_session.boto_region_name,
126122
"SAGEMAKER_CONTAINER_LOG_LEVEL": "10",
127-
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
128123
"LOCAL_PYTHON": platform.python_version(),
129124
}
130125
return s3_upload_path, env_vars

sagemaker-serve/src/sagemaker/serve/model_server/torchserve/prepare.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,7 @@
1212
from sagemaker.core.helper.session_helper import Session
1313
from sagemaker.serve.spec.inference_spec import InferenceSpec
1414
from sagemaker.serve.detector.dependency_manager import capture_dependencies
15-
from sagemaker.serve.validations.check_integrity import (
16-
generate_secret_key,
17-
compute_hash,
18-
)
15+
from sagemaker.serve.validations.check_integrity import compute_hash
1916
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
2017
from sagemaker.core.remote_function.core.serialization import _MetaData
2118

@@ -67,11 +64,8 @@ def prepare_for_torchserve(
6764

6865
capture_dependencies(dependencies=dependencies, work_dir=code_dir)
6966

70-
secret_key = generate_secret_key()
7167
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
7268
buffer = f.read()
73-
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
69+
hash_value = compute_hash(buffer=buffer)
7470
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
75-
metadata.write(_MetaData(hash_value).to_json())
76-
77-
return secret_key
71+
metadata.write(_MetaData(hash_value).to_json())

0 commit comments

Comments (0)