Skip to content

Commit 187eecb

Browse files
author
Pravali Uppugunduri
committed
fix: Add HMAC integrity verification for Triton inference handler
- Add an HMAC integrity check before pickle deserialization in TritonPythonModel.initialize()
- Replace the hardcoded secret key with generate_secret_key() in the _prepare_for_triton() ONNX path
- Add _hmac_signing() after ONNX export for both the PyTorch and TensorFlow frameworks
- Add secret key validation in _start_triton_server() to reject None/empty keys

Fixes RCE vulnerabilities in the Triton handler by aligning it with the HMAC verification patterns used by the TorchServe, MMS, TF Serving, and SMD handlers.
1 parent 6a174f4 commit 187eecb

File tree

4 files changed

+40
-7
lines changed

4 files changed

+40
-7
lines changed

sagemaker-serve/src/sagemaker/serve/model_builder_utils.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2879,10 +2879,12 @@ def _is_gpu_instance(self, instance_type: str) -> bool:
28792879
return instance_family in GPU_INSTANCE_FAMILIES
28802880

28812881
def _save_inference_spec(self) -> None:
2882-
"""Save inference specification to pickle file."""
2882+
"""Save inference specification or model to pickle file."""
2883+
pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")
28832884
if self.inference_spec:
2884-
pkl_path = Path(self.model_path).joinpath("model_repository").joinpath("model")
28852885
save_pkl(pkl_path, (self.inference_spec, self.schema_builder))
2886+
elif self.model:
2887+
save_pkl(pkl_path, (self.model, self.schema_builder))
28862888

28872889
def _hmac_signing(self):
28882890
"""Perform HMAC signing on picke file for integrity check"""
@@ -3075,18 +3077,20 @@ def _prepare_for_triton(self):
30753077
export_path.mkdir(parents=True)
30763078

30773079
if self.model:
3078-
self.secret_key = "dummy secret key for onnx backend"
3080+
self.secret_key = generate_secret_key()
30793081

30803082
if self.framework == Framework.PYTORCH:
30813083
self._export_pytorch_to_onnx(
30823084
export_path=export_path, model=self.model, schema_builder=self.schema_builder
30833085
)
3086+
self._hmac_signing()
30843087
return
30853088

30863089
if self.framework == Framework.TENSORFLOW:
30873090
self._export_tf_to_onnx(
30883091
export_path=export_path, model=self.model, schema_builder=self.schema_builder
30893092
)
3093+
self._hmac_signing()
30903094
return
30913095

30923096
raise ValueError("%s is not supported" % self.framework)

sagemaker-serve/src/sagemaker/serve/model_server/triton/model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,13 @@ def auto_complete_config(auto_complete_model_config):
2626
def initialize(self, args: dict) -> None:
2727
"""Placeholder docstring"""
2828
serve_path = Path(TRITON_MODEL_DIR).joinpath("serve.pkl")
29+
metadata_path = Path(TRITON_MODEL_DIR).joinpath("metadata.json")
2930
with open(str(serve_path), mode="rb") as f:
30-
inference_spec, schema_builder = cloudpickle.load(f)
31+
buffer = f.read()
32+
perform_integrity_check(buffer=buffer, metadata_path=str(metadata_path))
3133

32-
# TODO: HMAC signing for integrity check
34+
with open(str(serve_path), mode="rb") as f:
35+
inference_spec, schema_builder = cloudpickle.load(f)
3336

3437
self.inference_spec = inference_spec
3538
self.schema_builder = schema_builder

sagemaker-serve/src/sagemaker/serve/model_server/triton/server.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,12 @@ def _start_triton_server(
3636
env_vars: dict,
3737
):
3838
"""Placeholder docstring"""
39+
if not isinstance(secret_key, str) or not secret_key.strip():
40+
raise ValueError(
41+
"A valid secret key is required for Triton deployments. "
42+
"The secret key must be a non-empty string generated by generate_secret_key(). "
43+
f"Received: {type(secret_key).__name__}"
44+
)
3945
self.container_name = "triton" + uuid.uuid1().hex
4046
model_repository = model_path + "/model_repository"
4147
env_vars.update(

sagemaker-serve/tests/unit/test_model_builder_utils_triton.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,9 @@ class TestPrepareForTriton(unittest.TestCase):
8181
"""Test _prepare_for_triton method."""
8282

8383
@patch('shutil.copy2')
84+
@patch.object(_ModelBuilderUtils, '_hmac_signing')
8485
@patch.object(_ModelBuilderUtils, '_export_pytorch_to_onnx')
85-
def test_prepare_for_triton_pytorch(self, mock_export, mock_copy):
86+
def test_prepare_for_triton_pytorch(self, mock_export, mock_hmac, mock_copy):
8687
"""Test preparing PyTorch model for Triton."""
8788
utils = _ModelBuilderUtils()
8889
utils.framework = Framework.PYTORCH
@@ -94,10 +95,12 @@ def test_prepare_for_triton_pytorch(self, mock_export, mock_copy):
9495
utils._prepare_for_triton()
9596

9697
mock_export.assert_called_once()
98+
mock_hmac.assert_called_once()
9799

98100
@patch('shutil.copy2')
101+
@patch.object(_ModelBuilderUtils, '_hmac_signing')
99102
@patch.object(_ModelBuilderUtils, '_export_tf_to_onnx')
100-
def test_prepare_for_triton_tensorflow(self, mock_export, mock_copy):
103+
def test_prepare_for_triton_tensorflow(self, mock_export, mock_hmac, mock_copy):
101104
"""Test preparing TensorFlow model for Triton."""
102105
utils = _ModelBuilderUtils()
103106
utils.framework = Framework.TENSORFLOW
@@ -109,6 +112,7 @@ def test_prepare_for_triton_tensorflow(self, mock_export, mock_copy):
109112
utils._prepare_for_triton()
110113

111114
mock_export.assert_called_once()
115+
mock_hmac.assert_called_once()
112116

113117
@patch('shutil.copy2')
114118
@patch.object(_ModelBuilderUtils, '_generate_config_pbtxt')
@@ -259,6 +263,22 @@ def test_save_inference_spec(self):
259263

260264
# Check that serve.pkl was created
261265
self.assertTrue(os.path.exists(os.path.join(pkl_path, "serve.pkl")))
266+
def test_save_inference_spec_with_model(self):
267+
"""Test saving model when inference_spec is None."""
268+
utils = _ModelBuilderUtils()
269+
utils.inference_spec = None
270+
utils.model = Mock()
271+
utils.schema_builder = Mock()
272+
273+
with tempfile.TemporaryDirectory() as tmpdir:
274+
utils.model_path = tmpdir
275+
pkl_path = os.path.join(tmpdir, "model_repository", "model")
276+
os.makedirs(pkl_path, exist_ok=True)
277+
278+
utils._save_inference_spec()
279+
280+
# Check that serve.pkl was created for the model path
281+
self.assertTrue(os.path.exists(os.path.join(pkl_path, "serve.pkl")))
262282

263283

264284
class TestHMACSignin(unittest.TestCase):

0 commit comments

Comments
 (0)