1919import io
2020
2121import sys
22+ import hmac
2223import hashlib
2324import pickle
2425
@@ -155,14 +156,15 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any:
155156
156157# TODO: use dask serializer in case dask distributed is installed in users' environment.
157158def serialize_func_to_s3 (
158- func : Callable , sagemaker_session : Session , s3_uri : str , s3_kms_key : str = None
159+ func : Callable , sagemaker_session : Session , s3_uri : str , hmac_key : str , s3_kms_key : str = None
159160):
160161 """Serializes function and uploads it to S3.
161162
162163 Args:
163164 sagemaker_session (sagemaker.core.helper.session.Session):
164165 The underlying Boto3 session which AWS service calls are delegated to.
165166 s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
167+ hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
166168 s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
167169 func: function to be serialized and persisted
168170 Raises:
@@ -171,13 +173,14 @@ def serialize_func_to_s3(
171173
172174 _upload_payload_and_metadata_to_s3 (
173175 bytes_to_upload = CloudpickleSerializer .serialize (func ),
176+ hmac_key = hmac_key ,
174177 s3_uri = s3_uri ,
175178 sagemaker_session = sagemaker_session ,
176179 s3_kms_key = s3_kms_key ,
177180 )
178181
179182
180- def deserialize_func_from_s3 (sagemaker_session : Session , s3_uri : str ) -> Callable :
183+ def deserialize_func_from_s3 (sagemaker_session : Session , s3_uri : str , hmac_key : str ) -> Callable :
181184 """Downloads from S3 and then deserializes data objects.
182185
183186 This method downloads the serialized training job outputs to a temporary directory and
@@ -187,6 +190,7 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callabl
187190 sagemaker_session (sagemaker.core.helper.session.Session):
188191 The underlying sagemaker session which AWS service calls are delegated to.
189192 s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
193+ hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
190194 Returns :
191195 The deserialized function.
192196 Raises:
@@ -199,14 +203,14 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callabl
199203 bytes_to_deserialize = _read_bytes_from_s3 (f"{ s3_uri } /payload.pkl" , sagemaker_session )
200204
201205 _perform_integrity_check (
202- expected_hash_value = metadata .sha256_hash , buffer = bytes_to_deserialize
206+ expected_hash_value = metadata .sha256_hash , secret_key = hmac_key , buffer = bytes_to_deserialize
203207 )
204208
205209 return CloudpickleSerializer .deserialize (f"{ s3_uri } /payload.pkl" , bytes_to_deserialize )
206210
207211
208212def serialize_obj_to_s3 (
209- obj : Any , sagemaker_session : Session , s3_uri : str , s3_kms_key : str = None
213+ obj : Any , sagemaker_session : Session , s3_uri : str , hmac_key : str , s3_kms_key : str = None
210214):
211215 """Serializes data object and uploads it to S3.
212216
@@ -215,13 +219,15 @@ def serialize_obj_to_s3(
215219 The underlying Boto3 session which AWS service calls are delegated to.
216220 s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
217221 s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
222+ hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
218223 obj: object to be serialized and persisted
219224 Raises:
220225 SerializationError: when fail to serialize object to bytes.
221226 """
222227
223228 _upload_payload_and_metadata_to_s3 (
224229 bytes_to_upload = CloudpickleSerializer .serialize (obj ),
230+ hmac_key = hmac_key ,
225231 s3_uri = s3_uri ,
226232 sagemaker_session = sagemaker_session ,
227233 s3_kms_key = s3_kms_key ,
@@ -268,13 +274,14 @@ def json_serialize_obj_to_s3(
268274 )
269275
270276
271- def deserialize_obj_from_s3 (sagemaker_session : Session , s3_uri : str ) -> Any :
277+ def deserialize_obj_from_s3 (sagemaker_session : Session , s3_uri : str , hmac_key : str ) -> Any :
272278 """Downloads from S3 and then deserializes data objects.
273279
274280 Args:
275281 sagemaker_session (sagemaker.core.helper.session.Session):
276282 The underlying sagemaker session which AWS service calls are delegated to.
277283 s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
284+ hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
278285 Returns :
279286 Deserialized python objects.
280287 Raises:
@@ -288,14 +295,14 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
288295 bytes_to_deserialize = _read_bytes_from_s3 (f"{ s3_uri } /payload.pkl" , sagemaker_session )
289296
290297 _perform_integrity_check (
291- expected_hash_value = metadata .sha256_hash , buffer = bytes_to_deserialize
298+ expected_hash_value = metadata .sha256_hash , secret_key = hmac_key , buffer = bytes_to_deserialize
292299 )
293300
294301 return CloudpickleSerializer .deserialize (f"{ s3_uri } /payload.pkl" , bytes_to_deserialize )
295302
296303
297304def serialize_exception_to_s3 (
298- exc : Exception , sagemaker_session : Session , s3_uri : str , s3_kms_key : str = None
305+ exc : Exception , sagemaker_session : Session , s3_uri : str , hmac_key : str , s3_kms_key : str = None
299306):
300307 """Serializes exception with traceback and uploads it to S3.
301308
@@ -304,6 +311,7 @@ def serialize_exception_to_s3(
304311 The underlying Boto3 session which AWS service calls are delegated to.
305312 s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
306313 s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
314+ hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
307315 exc: Exception to be serialized and persisted
308316 Raises:
309317 SerializationError: when fail to serialize object to bytes.
@@ -312,6 +320,7 @@ def serialize_exception_to_s3(
312320
313321 _upload_payload_and_metadata_to_s3 (
314322 bytes_to_upload = CloudpickleSerializer .serialize (exc ),
323+ hmac_key = hmac_key ,
315324 s3_uri = s3_uri ,
316325 sagemaker_session = sagemaker_session ,
317326 s3_kms_key = s3_kms_key ,
@@ -320,6 +329,7 @@ def serialize_exception_to_s3(
320329
321330def _upload_payload_and_metadata_to_s3 (
322331 bytes_to_upload : Union [bytes , io .BytesIO ],
332+ hmac_key : str ,
323333 s3_uri : str ,
324334 sagemaker_session : Session ,
325335 s3_kms_key ,
@@ -328,14 +338,15 @@ def _upload_payload_and_metadata_to_s3(
328338
329339 Args:
330340 bytes_to_upload (bytes): Serialized bytes to upload.
341+ hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
331342 s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
332343 sagemaker_session (sagemaker.core.helper.session.Session):
333344 The underlying Boto3 session which AWS service calls are delegated to.
334345 s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
335346 """
336347 _upload_bytes_to_s3 (bytes_to_upload , f"{ s3_uri } /payload.pkl" , s3_kms_key , sagemaker_session )
337348
338- sha256_hash = _compute_hash (bytes_to_upload )
349+ sha256_hash = _compute_hash (bytes_to_upload , secret_key = hmac_key )
339350
340351 _upload_bytes_to_s3 (
341352 _MetaData (sha256_hash ).to_json (),
@@ -345,13 +356,14 @@ def _upload_payload_and_metadata_to_s3(
345356 )
346357
347358
348- def deserialize_exception_from_s3 (sagemaker_session : Session , s3_uri : str ) -> Any :
359+ def deserialize_exception_from_s3 (sagemaker_session : Session , s3_uri : str , hmac_key : str ) -> Any :
349360 """Downloads from S3 and then deserializes exception.
350361
351362 Args:
352363 sagemaker_session (sagemaker.core.helper.session.Session):
353364 The underlying sagemaker session which AWS service calls are delegated to.
354365 s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
366+ hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
355367 Returns :
356368 Deserialized exception with traceback.
357369 Raises:
@@ -365,7 +377,7 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> An
365377 bytes_to_deserialize = _read_bytes_from_s3 (f"{ s3_uri } /payload.pkl" , sagemaker_session )
366378
367379 _perform_integrity_check (
368- expected_hash_value = metadata .sha256_hash , buffer = bytes_to_deserialize
380+ expected_hash_value = metadata .sha256_hash , secret_key = hmac_key , buffer = bytes_to_deserialize
369381 )
370382
371383 return CloudpickleSerializer .deserialize (f"{ s3_uri } /payload.pkl" , bytes_to_deserialize )
@@ -391,19 +403,19 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session):
391403 ) from e
392404
393405
def _compute_hash(buffer: bytes, secret_key: str) -> str:
    """Compute the hmac-sha256 hash of the given bytes.

    Args:
        buffer (bytes): Bytes over which the hash is calculated.
        secret_key (str): Key used to calculate the hmac-sha256 hash. It is
            converted to bytes with ``str.encode()`` (UTF-8 by default)
            before being used as the HMAC key.
    Returns:
        The hmac-sha256 digest of ``buffer`` as a hexadecimal string.
    """
    # A keyed hash (rather than plain sha256) means the digest cannot be
    # recomputed for a modified buffer without knowing ``secret_key``.
    return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
397409
398410
399- def _perform_integrity_check (expected_hash_value : str , buffer : bytes ):
411+ def _perform_integrity_check (expected_hash_value : str , secret_key : str , buffer : bytes ):
400412 """Performs integrity checks for serialized code/arguments uploaded to s3.
401413
402414 Verifies whether the hash read from s3 matches the hash calculated
403415 during remote function execution.
404416 """
405- actual_hash_value = _compute_hash (buffer = buffer )
406- if expected_hash_value != actual_hash_value :
417+ actual_hash_value = _compute_hash (buffer = buffer , secret_key = secret_key )
418+ if not hmac . compare_digest ( expected_hash_value , actual_hash_value ) :
407419 raise DeserializationError (
408420 "Integrity check for the serialized function or data failed. "
409421 "Please restrict access to your S3 bucket"
0 commit comments