Skip to content

Commit 644fc16

Browse files
author
Pravali Uppugunduri
committed
fix: Add HMAC integrity verification for Triton inference handler
Addresses P400136088 Bug 1 and V2146375387 (Triton path). Three changes: 1. check_integrity.py: Switch from HMAC-SHA256 to plain SHA-256. - Remove generate_secret_key() — no longer needed - compute_hash() now uses hashlib.sha256() instead of hmac.new() - perform_integrity_check() no longer reads SAGEMAKER_SERVE_SECRET_KEY from environment 2. triton/model.py: Add integrity check in initialize() BEFORE cloudpickle deserialization. Previously the handler called cloudpickle.load() with no verification (acknowledged by a TODO comment). Now reads the file into a buffer, runs perform_integrity_check(), then deserializes with cloudpickle.loads(). 3. triton/server.py: Remove SAGEMAKER_SERVE_SECRET_KEY from container environment variables in both local and SageMaker deployment modes. The key is no longer needed since integrity checking uses plain SHA-256. 4. model_builder_utils.py: Update _hmac_signing() to use plain SHA-256 and stop generating/storing a secret key. Remove generate_secret_key import. The integrity check still detects accidental corruption of model artifacts in S3. The HMAC was providing a false sense of security since the key was exposed via DescribeModel/DescribeEndpointConfig APIs.
1 parent d5eed80 commit 644fc16

File tree

5 files changed

+21
-31
lines changed

5 files changed

+21
-31
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder_utils.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ def build(self):
131131
from sagemaker.serve.detector.pickler import save_pkl
132132
from sagemaker.serve.builder.requirements_manager import RequirementsManager
133133
from sagemaker.serve.validations.check_integrity import (
134-
generate_secret_key,
135134
compute_hash,
136135
)
137136
from sagemaker.core.remote_function.core.serialization import _MetaData
@@ -2884,20 +2883,17 @@ def _save_inference_spec(self) -> None:
28842883
pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")
28852884
save_pkl(pkl_path, (self.inference_spec, self.schema_builder))
28862885

2887-
def _hmac_signing(self):
2888-
"""Perform HMAC signing on picke file for integrity check"""
2889-
secret_key = generate_secret_key()
2886+
def _compute_integrity_hash(self):
2887+
"""Compute SHA-256 hash of serve.pkl and store in metadata.json for integrity check."""
28902888
pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")
28912889

28922890
with open(str(pkl_path.joinpath("serve.pkl")), "rb") as f:
28932891
buffer = f.read()
2894-
hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
2892+
hash_value = compute_hash(buffer=buffer)
28952893

28962894
with open(str(pkl_path.joinpath("metadata.json")), "wb") as metadata:
28972895
metadata.write(_MetaData(hash_value).to_json())
28982896

2899-
self.secret_key = secret_key
2900-
29012897
def _generate_config_pbtxt(self, pkl_path: Path):
29022898
"""Generate Triton config.pbtxt file."""
29032899
config_path = pkl_path.joinpath("config.pbtxt")
@@ -3100,7 +3096,7 @@ def _prepare_for_triton(self):
31003096

31013097
self._pack_conda_env(pkl_path=pkl_path)
31023098

3103-
self._hmac_signing()
3099+
self._compute_integrity_hash()
31043100

31053101
return
31063102

sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,14 @@ def auto_complete_config(auto_complete_model_config):
2626
def initialize(self, args: dict) -> None:
2727
"""Placeholder docstring"""
2828
serve_path = Path(TRITON_MODEL_DIR).joinpath("serve.pkl")
29-
with open(str(serve_path), mode="rb") as f:
30-
inference_spec, schema_builder = cloudpickle.load(f)
29+
metadata_path = Path(TRITON_MODEL_DIR).joinpath("metadata.json")
3130

32-
# TODO: HMAC signing for integrity check
31+
# Integrity check BEFORE deserialization to prevent RCE via malicious pickle
32+
with open(str(serve_path), "rb") as f:
33+
buffer = f.read()
34+
perform_integrity_check(buffer=buffer, metadata_path=metadata_path)
35+
36+
inference_spec, schema_builder = cloudpickle.loads(buffer)
3337

3438
self.inference_spec = inference_spec
3539
self.schema_builder = schema_builder

sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def _start_triton_server(
4141
env_vars.update(
4242
{
4343
"TRITON_MODEL_DIR": "/models/model",
44-
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
4544
"LOCAL_PYTHON": platform.python_version(),
4645
}
4746
)
@@ -133,7 +132,6 @@ def _upload_triton_artifacts(
133132
env_vars = {
134133
"SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "model",
135134
"TRITON_MODEL_DIR": "/opt/ml/model/model",
136-
"SAGEMAKER_SERVE_SECRET_KEY": secret_key,
137135
"LOCAL_PYTHON": platform.python_version(),
138136
}
139137
return s3_upload_path, env_vars

sagemaker-serve/src/sagemaker/serve/validations/check_integrity.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,21 @@
1-
"""Validates the integrity of pickled file with HMAC signing."""
1+
"""Validates the integrity of pickled file with SHA-256 hash."""
22

33
from __future__ import absolute_import
4-
import secrets
54
import hmac
65
import hashlib
7-
import os
86
from pathlib import Path
97

108
from sagemaker.core.remote_function.core.serialization import _MetaData
119

1210

13-
def generate_secret_key(nbytes: int = 32) -> str:
14-
"""Generates secret key"""
15-
return secrets.token_hex(nbytes)
16-
17-
18-
def compute_hash(buffer: bytes, secret_key: str) -> str:
19-
"""Compute hash value using HMAC"""
20-
return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
11+
def compute_hash(buffer: bytes) -> str:
12+
"""Compute SHA-256 hash of the given buffer."""
13+
return hashlib.sha256(buffer).hexdigest()
2114

2215

2316
def perform_integrity_check(buffer: bytes, metadata_path: Path):
24-
"""Validates the integrity of bytes by comparing the hash value"""
25-
secret_key = os.environ.get("SAGEMAKER_SERVE_SECRET_KEY")
26-
actual_hash_value = compute_hash(buffer=buffer, secret_key=secret_key)
17+
"""Validates the integrity of bytes by comparing the hash value."""
18+
actual_hash_value = compute_hash(buffer=buffer)
2719

2820
if not Path.exists(metadata_path):
2921
raise ValueError("Path to metadata.json does not exist")

sagemaker-serve/tests/unit/test_model_builder_utils_triton.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def test_prepare_for_triton_tensorflow(self, mock_export, mock_copy):
113113
@patch('shutil.copy2')
114114
@patch.object(_ModelBuilderUtils, '_generate_config_pbtxt')
115115
@patch.object(_ModelBuilderUtils, '_pack_conda_env')
116-
@patch.object(_ModelBuilderUtils, '_hmac_signing')
116+
@patch.object(_ModelBuilderUtils, '_compute_integrity_hash')
117117
def test_prepare_for_triton_inference_spec(self, mock_hmac, mock_pack, mock_config, mock_copy):
118118
"""Test preparing inference spec for Triton."""
119119
utils = _ModelBuilderUtils()
@@ -262,9 +262,9 @@ def test_save_inference_spec(self):
262262

263263

264264
class TestHMACSignin(unittest.TestCase):
265-
"""Test _hmac_signing method."""
265+
"""Test _compute_integrity_hash method."""
266266

267-
def test_hmac_signing(self):
267+
def test_compute_integrity_hash(self):
268268
"""Test HMAC signing."""
269269
utils = _ModelBuilderUtils()
270270

@@ -276,7 +276,7 @@ def test_hmac_signing(self):
276276
# Create dummy serve.pkl
277277
(pkl_path / "serve.pkl").write_bytes(b"dummy content")
278278

279-
utils._hmac_signing()
279+
utils._compute_integrity_hash()
280280

281281
# Secret key is generated, not mocked
282282
self.assertIsNotNone(utils.secret_key)

0 commit comments

Comments
 (0)