Skip to content

Commit b23641e

Browse files
author
Namrata Madan
committed
Revert "Bug fix for hmac key for V3 (aws#5379)"
This reverts commit fb0d789.
1 parent 0976df1 commit b23641e

18 files changed

Lines changed: 1117 additions & 1867 deletions

File tree

sagemaker-core/src/sagemaker/core/remote_function/client.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def wrapper(*args, **kwargs):
369369
s3_uri=s3_path_join(
370370
job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER
371371
),
372-
372+
hmac_key=job.hmac_key,
373373
)
374374
except ServiceError as serr:
375375
chained_e = serr.__cause__
@@ -406,7 +406,7 @@ def wrapper(*args, **kwargs):
406406
return serialization.deserialize_obj_from_s3(
407407
sagemaker_session=job_settings.sagemaker_session,
408408
s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER),
409-
409+
hmac_key=job.hmac_key,
410410
)
411411

412412
if job.describe()["TrainingJobStatus"] == "Stopped":
@@ -1008,7 +1008,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
10081008
job_return = serialization.deserialize_obj_from_s3(
10091009
sagemaker_session=sagemaker_session,
10101010
s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER),
1011-
1011+
hmac_key=job.hmac_key,
10121012
)
10131013
except DeserializationError as e:
10141014
client_exception = e
@@ -1020,7 +1020,7 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
10201020
job_exception = serialization.deserialize_exception_from_s3(
10211021
sagemaker_session=sagemaker_session,
10221022
s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER),
1023-
1023+
hmac_key=job.hmac_key,
10241024
)
10251025
except ServiceError as serr:
10261026
chained_e = serr.__cause__
@@ -1110,7 +1110,7 @@ def result(self, timeout: float = None) -> Any:
11101110
self._return = serialization.deserialize_obj_from_s3(
11111111
sagemaker_session=self._job.sagemaker_session,
11121112
s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER),
1113-
1113+
hmac_key=self._job.hmac_key,
11141114
)
11151115
self._state = _FINISHED
11161116
return self._return
@@ -1119,7 +1119,7 @@ def result(self, timeout: float = None) -> Any:
11191119
self._exception = serialization.deserialize_exception_from_s3(
11201120
sagemaker_session=self._job.sagemaker_session,
11211121
s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER),
1122-
1122+
hmac_key=self._job.hmac_key,
11231123
)
11241124
except ServiceError as serr:
11251125
chained_e = serr.__cause__

sagemaker-core/src/sagemaker/core/remote_function/core/pipeline_variables.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ class _DelayedReturnResolver:
164164
def __init__(
165165
self,
166166
delayed_returns: List[_DelayedReturn],
167+
hmac_key: str,
167168
properties_resolver: _PropertiesResolver,
168169
parameter_resolver: _ParameterResolver,
169170
execution_variable_resolver: _ExecutionVariableResolver,
@@ -174,6 +175,7 @@ def __init__(
174175
175176
Args:
176177
delayed_returns: list of delayed returns to resolve.
178+
hmac_key: key used to compute the hmac-sha256 hash of the serialized artifacts for the integrity check.
177179
properties_resolver: resolver used to resolve step properties.
178180
parameter_resolver: resolver used to resolve pipeline parameters.
179181
execution_variable_resolver: resolver used to resolve execution variables.
@@ -195,6 +197,7 @@ def deserialization_task(uri):
195197
return uri, deserialize_obj_from_s3(
196198
sagemaker_session=settings["sagemaker_session"],
197199
s3_uri=uri,
200+
hmac_key=hmac_key,
198201
)
199202

200203
with ThreadPoolExecutor() as executor:
@@ -244,6 +247,7 @@ def resolve_pipeline_variables(
244247
context: Context,
245248
func_args: Tuple,
246249
func_kwargs: Dict,
250+
hmac_key: str,
247251
s3_base_uri: str,
248252
**settings,
249253
):
@@ -253,6 +257,7 @@ def resolve_pipeline_variables(
253257
context: context for the execution.
254258
func_args: function args.
255259
func_kwargs: function kwargs.
260+
hmac_key: key used to compute the hmac-sha256 hash of the serialized artifacts for the integrity check.
256261
s3_base_uri: the s3 base uri of the function step that the serialized artifacts
257262
will be uploaded to. The s3_base_uri = s3_root_uri + pipeline_name.
258263
**settings: settings to pass to the deserialization function.
@@ -275,6 +280,7 @@ def resolve_pipeline_variables(
275280
properties_resolver = _PropertiesResolver(context)
276281
delayed_return_resolver = _DelayedReturnResolver(
277282
delayed_returns=delayed_returns,
283+
hmac_key=hmac_key,
278284
properties_resolver=properties_resolver,
279285
parameter_resolver=parameter_resolver,
280286
execution_variable_resolver=execution_variable_resolver,

sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import io
2020

2121
import sys
22+
import hmac
2223
import hashlib
2324
import pickle
2425

@@ -155,14 +156,15 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any:
155156

156157
# TODO: use dask serializer in case dask distributed is installed in users' environment.
157158
def serialize_func_to_s3(
158-
func: Callable, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
159+
func: Callable, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
159160
):
160161
"""Serializes function and uploads it to S3.
161162
162163
Args:
163164
sagemaker_session (sagemaker.core.helper.session.Session):
164165
The underlying Boto3 session which AWS service calls are delegated to.
165166
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
167+
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
166168
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
167169
func: function to be serialized and persisted
168170
Raises:
@@ -171,13 +173,14 @@ def serialize_func_to_s3(
171173

172174
_upload_payload_and_metadata_to_s3(
173175
bytes_to_upload=CloudpickleSerializer.serialize(func),
176+
hmac_key=hmac_key,
174177
s3_uri=s3_uri,
175178
sagemaker_session=sagemaker_session,
176179
s3_kms_key=s3_kms_key,
177180
)
178181

179182

180-
def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callable:
183+
def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Callable:
181184
"""Downloads from S3 and then deserializes data objects.
182185
183186
This method downloads the serialized training job outputs to a temporary directory and
@@ -187,6 +190,7 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callabl
187190
sagemaker_session (sagemaker.core.helper.session.Session):
188191
The underlying sagemaker session which AWS service calls are delegated to.
189192
s3_uri (str): S3 root uri from which the serialized artifacts will be downloaded.
193+
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized func.
190194
Returns :
191195
The deserialized function.
192196
Raises:
@@ -199,14 +203,14 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callabl
199203
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
200204

201205
_perform_integrity_check(
202-
expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
206+
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
203207
)
204208

205209
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
206210

207211

208212
def serialize_obj_to_s3(
209-
obj: Any, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
213+
obj: Any, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
210214
):
211215
"""Serializes data object and uploads it to S3.
212216
@@ -215,13 +219,15 @@ def serialize_obj_to_s3(
215219
The underlying Boto3 session which AWS service calls are delegated to.
216220
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
217221
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
222+
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
218223
obj: object to be serialized and persisted
219224
Raises:
220225
SerializationError: when the object fails to serialize to bytes.
221226
"""
222227

223228
_upload_payload_and_metadata_to_s3(
224229
bytes_to_upload=CloudpickleSerializer.serialize(obj),
230+
hmac_key=hmac_key,
225231
s3_uri=s3_uri,
226232
sagemaker_session=sagemaker_session,
227233
s3_kms_key=s3_kms_key,
@@ -268,13 +274,14 @@ def json_serialize_obj_to_s3(
268274
)
269275

270276

271-
def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
277+
def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
272278
"""Downloads from S3 and then deserializes data objects.
273279
274280
Args:
275281
sagemaker_session (sagemaker.core.helper.session.Session):
276282
The underlying sagemaker session which AWS service calls are delegated to.
277283
s3_uri (str): S3 root uri from which the serialized artifacts will be downloaded.
284+
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
278285
Returns :
279286
Deserialized python objects.
280287
Raises:
@@ -288,14 +295,14 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
288295
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
289296

290297
_perform_integrity_check(
291-
expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
298+
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
292299
)
293300

294301
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
295302

296303

297304
def serialize_exception_to_s3(
298-
exc: Exception, sagemaker_session: Session, s3_uri: str, s3_kms_key: str = None
305+
exc: Exception, sagemaker_session: Session, s3_uri: str, hmac_key: str, s3_kms_key: str = None
299306
):
300307
"""Serializes exception with traceback and uploads it to S3.
301308
@@ -304,6 +311,7 @@ def serialize_exception_to_s3(
304311
The underlying Boto3 session which AWS service calls are delegated to.
305312
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
306313
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
314+
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
307315
exc: Exception to be serialized and persisted
308316
Raises:
309317
SerializationError: when the object fails to serialize to bytes.
@@ -312,6 +320,7 @@ def serialize_exception_to_s3(
312320

313321
_upload_payload_and_metadata_to_s3(
314322
bytes_to_upload=CloudpickleSerializer.serialize(exc),
323+
hmac_key=hmac_key,
315324
s3_uri=s3_uri,
316325
sagemaker_session=sagemaker_session,
317326
s3_kms_key=s3_kms_key,
@@ -320,6 +329,7 @@ def serialize_exception_to_s3(
320329

321330
def _upload_payload_and_metadata_to_s3(
322331
bytes_to_upload: Union[bytes, io.BytesIO],
332+
hmac_key: str,
323333
s3_uri: str,
324334
sagemaker_session: Session,
325335
s3_kms_key,
@@ -328,14 +338,15 @@ def _upload_payload_and_metadata_to_s3(
328338
329339
Args:
330340
bytes_to_upload (bytes): Serialized bytes to upload.
341+
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj.
331342
s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded.
332343
sagemaker_session (sagemaker.core.helper.session.Session):
333344
The underlying Boto3 session which AWS service calls are delegated to.
334345
s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3.
335346
"""
336347
_upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
337348

338-
sha256_hash = _compute_hash(bytes_to_upload)
349+
sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
339350

340351
_upload_bytes_to_s3(
341352
_MetaData(sha256_hash).to_json(),
@@ -345,13 +356,14 @@ def _upload_payload_and_metadata_to_s3(
345356
)
346357

347358

348-
def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
359+
def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: str) -> Any:
349360
"""Downloads from S3 and then deserializes exception.
350361
351362
Args:
352363
sagemaker_session (sagemaker.core.helper.session.Session):
353364
The underlying sagemaker session which AWS service calls are delegated to.
354365
s3_uri (str): S3 root uri from which the serialized artifacts will be downloaded.
366+
hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized exception.
355367
Returns :
356368
Deserialized exception with traceback.
357369
Raises:
@@ -365,7 +377,7 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> An
365377
bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
366378

367379
_perform_integrity_check(
368-
expected_hash_value=metadata.sha256_hash, buffer=bytes_to_deserialize
380+
expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
369381
)
370382

371383
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
@@ -391,19 +403,19 @@ def _read_bytes_from_s3(s3_uri, sagemaker_session):
391403
) from e
392404

393405

394-
def _compute_hash(buffer: bytes) -> str:
395-
"""Compute the sha256 hash"""
396-
return hashlib.sha256(buffer).hexdigest()
406+
def _compute_hash(buffer: bytes, secret_key: str) -> str:
407+
"""Compute the hmac-sha256 hash"""
408+
return hmac.new(secret_key.encode(), msg=buffer, digestmod=hashlib.sha256).hexdigest()
397409

398410

399-
def _perform_integrity_check(expected_hash_value: str, buffer: bytes):
411+
def _perform_integrity_check(expected_hash_value: str, secret_key: str, buffer: bytes):
400412
"""Performs integrity checks for serialized code/arguments uploaded to s3.
401413
402414
Verifies whether the hash read from s3 matches the hash calculated
403415
during remote function execution.
404416
"""
405-
actual_hash_value = _compute_hash(buffer=buffer)
406-
if expected_hash_value != actual_hash_value:
417+
actual_hash_value = _compute_hash(buffer=buffer, secret_key=secret_key)
418+
if not hmac.compare_digest(expected_hash_value, actual_hash_value):
407419
raise DeserializationError(
408420
"Integrity check for the serialized function or data failed. "
409421
"Please restrict access to your S3 bucket"

0 commit comments

Comments
 (0)