Skip to content

Commit f9a45db

Browse files
author
Pravali Uppugunduri
committed
fix: Security fixes for Triton HMAC key exposure and missing integrity check (v2)
Backport of v3 security fixes for P400136088 and V2146375387. 1. check_integrity.py: Switch from HMAC-SHA256 to plain SHA-256. Remove generate_secret_key, remove env var dependency. 2. triton/model.py: Add integrity check in initialize() BEFORE cloudpickle deserialization. 3. triton/server.py: Remove SAGEMAKER_SERVE_SECRET_KEY from container environment variables. 4. triton/triton_builder.py: Remove hardcoded dummy secret key for ONNX path. Rename _hmac_signing to _compute_integrity_hash. Use plain SHA-256. 5. All prepare.py files (torchserve, mms, tf_serving, smd): Remove generate_secret_key usage, switch to plain SHA-256.
1 parent e5f349c commit f9a45db

File tree

8 files changed

+23
-43
lines changed

8 files changed

+23
-43
lines changed

src/sagemaker/serve/model_server/multi_model_server/prepare.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from sagemaker.serve.spec.inference_spec import InferenceSpec
2727
from sagemaker.serve.detector.dependency_manager import capture_dependencies
2828
from sagemaker.serve.validations.check_integrity import (
29-
generate_secret_key,
3029
compute_hash,
3130
)
3231
from sagemaker.remote_function.core.serialization import _MetaData
@@ -120,11 +119,10 @@ def prepare_for_mms(
120119

121120
capture_dependencies(dependencies=dependencies, work_dir=code_dir)
122121

123-
secret_key = generate_secret_key()
124122
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
125123
buffer = f.read()
126-
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
124+
hash_value = compute_hash(buffer=buffer)
127125
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
128126
metadata.write(_MetaData(hash_value).to_json())
129127

130-
return secret_key
128+
return ""

src/sagemaker/serve/model_server/smd/prepare.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from sagemaker.serve.spec.inference_spec import InferenceSpec
1313
from sagemaker.serve.detector.dependency_manager import capture_dependencies
1414
from sagemaker.serve.validations.check_integrity import (
15-
generate_secret_key,
1615
compute_hash,
1716
)
1817
from sagemaker.remote_function.core.serialization import _MetaData
@@ -64,11 +63,10 @@ def prepare_for_smd(
6463

6564
capture_dependencies(dependencies=dependencies, work_dir=code_dir)
6665

67-
secret_key = generate_secret_key()
6866
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
6967
buffer = f.read()
70-
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
68+
hash_value = compute_hash(buffer=buffer)
7169
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
7270
metadata.write(_MetaData(hash_value).to_json())
7371

74-
return secret_key
72+
return ""

src/sagemaker/serve/model_server/tensorflow_serving/prepare.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
)
1212
from sagemaker.serve.detector.dependency_manager import capture_dependencies
1313
from sagemaker.serve.validations.check_integrity import (
14-
generate_secret_key,
1514
compute_hash,
1615
)
1716
from sagemaker.remote_function.core.serialization import _MetaData
@@ -57,11 +56,10 @@ def prepare_for_tf_serving(
5756
raise ValueError("SavedModel is not found for Tensorflow or Keras flavor.")
5857
_move_contents(src_dir=mlflow_saved_model_dir, dest_dir=saved_model_bundle_dir)
5958

60-
secret_key = generate_secret_key()
6159
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
6260
buffer = f.read()
63-
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
61+
hash_value = compute_hash(buffer=buffer)
6462
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
6563
metadata.write(_MetaData(hash_value).to_json())
6664

67-
return secret_key
65+
return ""

src/sagemaker/serve/model_server/torchserve/prepare.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from sagemaker.serve.spec.inference_spec import InferenceSpec
1414
from sagemaker.serve.detector.dependency_manager import capture_dependencies
1515
from sagemaker.serve.validations.check_integrity import (
16-
generate_secret_key,
1716
compute_hash,
1817
)
1918
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
@@ -69,11 +68,10 @@ def prepare_for_torchserve(
6968

7069
capture_dependencies(dependencies=dependencies, work_dir=code_dir)
7170

72-
secret_key = generate_secret_key()
7371
with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
7472
buffer = f.read()
75-
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
73+
hash_value = compute_hash(buffer=buffer)
7674
with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
7775
metadata.write(_MetaData(hash_value).to_json())
7876

79-
return secret_key
77+
return ""

src/sagemaker/serve/model_server/triton/model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,12 @@ def auto_complete_config(auto_complete_model_config):
2626
def initialize(self, args: dict) -> None:
2727
"""Placeholder docstring"""
2828
serve_path = Path(TRITON_MODEL_DIR).joinpath("serve.pkl")
29+
metadata_path = Path(TRITON_MODEL_DIR).joinpath("metadata.json")
2930
with open(str(serve_path), mode="rb") as f:
30-
inference_spec, schema_builder = cloudpickle.load(f)
31+
buffer = f.read()
3132

32-
# TODO: HMAC signing for integrity check
33+
perform_integrity_check(buffer=buffer, metadata_path=metadata_path)
34+
inference_spec, schema_builder = cloudpickle.loads(buffer)
3335

3436
self.inference_spec = inference_spec
3537
self.schema_builder = schema_builder

src/sagemaker/serve/model_server/triton/server.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ def _start_triton_server(
4343
env_vars.update(
4444
{
4545
"TRITON_MODEL_DIR": "/models/model",
46-
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
4746
"LOCAL_PYTHON": platform.python_version(),
4847
}
4948
)
@@ -146,7 +145,6 @@ def _upload_triton_artifacts(
146145
env_vars = {
147146
"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model",
148147
"TRITON_MODEL_DIR": "/opt/ml/model/model",
149-
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
150148
"LOCAL_PYTHON": platform.python_version(),
151149
}
152150
return s3_upload_path, env_vars

src/sagemaker/serve/model_server/triton/triton_builder.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
from sagemaker.serve.detector.pickler import save_pkl
2424
from sagemaker.serve.model_server.triton.config_template import CONFIG_TEMPLATE
2525
from sagemaker.serve.validations.check_integrity import (
26-
generate_secret_key,
2726
compute_hash,
2827
)
2928

@@ -213,7 +212,7 @@ def _prepare_for_triton(self):
213212
export_path.mkdir(parents=True)
214213

215214
if self.model:
216-
self.secret_key = "dummy secret key for onnx backend"
215+
# ONNX path: no pickle serialization, no serve.pkl, no integrity check needed.
217216

218217
if self._framework == "pytorch":
219218
self._export_pytorch_to_onnx(
@@ -237,26 +236,23 @@ def _prepare_for_triton(self):
237236

238237
self._pack_conda_env(pkl_path=pkl_path)
239238

240-
self._hmac_signing()
239+
self._compute_integrity_hash()
241240

242241
return
243242

244243
raise ValueError("Either model or inference_spec should be provided to ModelBuilder.")
245244

246-
def _hmac_signing(self):
247-
"""Perform HMAC signing on pickle file for integrity check"""
248-
secret_key = generate_secret_key()
245+
def _compute_integrity_hash(self):
246+
"""Compute SHA-256 integrity hash on pickle file for integrity check"""
249247
pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")
250248

251249
with open(str(pkl_path.joinpath("serve.pkl")), "rb") as f:
252250
buffer = f.read()
253-
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
251+
hash_value = compute_hash(buffer=buffer)
254252

255253
with open(str(pkl_path.joinpath("metadata.json")), "wb") as metadata:
256254
metadata.write(_MetaData(hash_value).to_json())
257255

258-
self.secret_key = secret_key
259-
260256
def _generate_config_pbtxt(self, pkl_path: Path):
261257
config_path = pkl_path.joinpath("config.pbtxt")
262258

src/sagemaker/serve/validations/check_integrity.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,21 @@
1-
"""Validates the integrity of pickled file with HMAC signing."""
1+
"""Validates the integrity of pickled file with SHA-256 hash."""
22

33
from __future__ import absolute_import
4-
import secrets
54
import hmac
65
import hashlib
7-
import os
86
from pathlib import Path
97

108
from sagemaker.remote_function.core.serialization import _MetaData
119

1210

13-
def generate_secret_key(nbytes: int = 32) -> str:
14-
"""Generates secret key"""
15-
return secrets.token_hex(nbytes)
16-
17-
18-
def compute_hash(buffer: bytes, secret_key: str) -> str:
19-
"""Compute hash value using HMAC"""
20-
return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
11+
def compute_hash(buffer: bytes) -> str:
12+
"""Compute SHA-256 hash of the given buffer."""
13+
return hashlib.sha256(buffer).hexdigest()
2114

2215

2316
def perform_integrity_check(buffer: bytes, metadata_path: Path):
24-
"""Validates the integrity of bytes by comparing the hash value"""
25-
secret_key = os.environ.get("SAGEMAKER_SERVE_SECRET_KEY")
26-
actual_hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
17+
"""Validates the integrity of bytes by comparing the hash value."""
18+
actual_hash_value = compute_hash(buffer=buffer)
2719

2820
if not Path.exists(metadata_path):
2921
raise ValueError("Path to metadata.json does not exist")

0 commit comments

Comments
 (0)