Skip to content

Commit 72a3982

Browse files
committed
Update after testing
1 parent d469f05 commit 72a3982

4 files changed

Lines changed: 51 additions & 101 deletions

File tree

sagemaker-core/src/sagemaker/core/remote_function/client.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,6 @@ def wrapper(*args, **kwargs):
369369
s3_uri=s3_path_join(
370370
job_settings.s3_root_uri, job.job_name, EXCEPTION_FOLDER
371371
),
372-
373372
)
374373
except ServiceError as serr:
375374
chained_e = serr.__cause__
@@ -406,7 +405,6 @@ def wrapper(*args, **kwargs):
406405
return serialization.deserialize_obj_from_s3(
407406
sagemaker_session=job_settings.sagemaker_session,
408407
s3_uri=s3_path_join(job_settings.s3_root_uri, job.job_name, RESULTS_FOLDER),
409-
410408
)
411409

412410
if job.describe()["TrainingJobStatus"] == "Stopped":
@@ -1008,7 +1006,6 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
10081006
job_return = serialization.deserialize_obj_from_s3(
10091007
sagemaker_session=sagemaker_session,
10101008
s3_uri=s3_path_join(job.s3_uri, RESULTS_FOLDER),
1011-
10121009
)
10131010
except DeserializationError as e:
10141011
client_exception = e
@@ -1020,7 +1017,6 @@ def from_describe_response(describe_training_job_response, sagemaker_session):
10201017
job_exception = serialization.deserialize_exception_from_s3(
10211018
sagemaker_session=sagemaker_session,
10221019
s3_uri=s3_path_join(job.s3_uri, EXCEPTION_FOLDER),
1023-
10241020
)
10251021
except ServiceError as serr:
10261022
chained_e = serr.__cause__
@@ -1110,7 +1106,6 @@ def result(self, timeout: float = None) -> Any:
11101106
self._return = serialization.deserialize_obj_from_s3(
11111107
sagemaker_session=self._job.sagemaker_session,
11121108
s3_uri=s3_path_join(self._job.s3_uri, RESULTS_FOLDER),
1113-
11141109
)
11151110
self._state = _FINISHED
11161111
return self._return
@@ -1119,7 +1114,6 @@ def result(self, timeout: float = None) -> Any:
11191114
self._exception = serialization.deserialize_exception_from_s3(
11201115
sagemaker_session=self._job.sagemaker_session,
11211116
s3_uri=s3_path_join(self._job.s3_uri, EXCEPTION_FOLDER),
1122-
11231117
)
11241118
except ServiceError as serr:
11251119
chained_e = serr.__cause__

sagemaker-core/src/sagemaker/core/remote_function/core/serialization.py

Lines changed: 40 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,6 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str) -> Callabl
216216
buffer=bytes_to_deserialize,
217217
sagemaker_session=sagemaker_session,
218218
secret_arn=metadata.secret_arn,
219-
s3_uri=s3_uri
220219
)
221220

222221
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
@@ -315,7 +314,6 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str) -> Any:
315314
buffer=bytes_to_deserialize,
316315
sagemaker_session=sagemaker_session,
317316
secret_arn=metadata.secret_arn,
318-
s3_uri=s3_uri
319317
)
320318

321319
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
@@ -411,7 +409,6 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str) -> An
411409
buffer=bytes_to_deserialize,
412410
sagemaker_session=sagemaker_session,
413411
secret_arn=metadata.secret_arn,
414-
s3_uri=s3_uri
415412
)
416413

417414
return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
@@ -518,17 +515,24 @@ def _store_secret_arn_in_parameter_store(
518515
ssm_client = sagemaker_session.boto_session.client('ssm')
519516
parameter_name = f"/sagemaker/remote-function/{job_name}/secret-arn"
520517

521-
ssm_client.put_parameter(
522-
Name=parameter_name,
523-
Value=secret_arn,
524-
Type="String",
525-
Overwrite=True,
526-
Description=f"Secret ARN for SageMaker remote function job {job_name}",
527-
Tags=[
528-
{'Key': 'SageMaker:JobName', 'Value': job_name},
529-
{'Key': 'SageMaker:Purpose', 'Value': 'RemoteFunctionIntegrity'}
530-
]
531-
)
518+
try:
519+
ssm_client.put_parameter(
520+
Name=parameter_name,
521+
Value=secret_arn,
522+
Type="String",
523+
Description=f"Secret ARN for SageMaker remote function job {job_name}",
524+
Tags=[
525+
{'Key': 'SageMaker:JobName', 'Value': job_name},
526+
{'Key': 'SageMaker:Purpose', 'Value': 'RemoteFunctionIntegrity'}
527+
]
528+
)
529+
except ssm_client.exceptions.ParameterAlreadyExists:
530+
ssm_client.put_parameter(
531+
Name=parameter_name,
532+
Value=secret_arn,
533+
Type="String",
534+
Overwrite=True,
535+
)
532536

533537

534538
def _get_secret_arn_from_parameter_store(
@@ -560,46 +564,47 @@ def _get_secret_arn_from_parameter_store(
560564
)
561565

562566

563-
def _extract_job_name_from_s3_uri(s3_uri: str) -> str:
564-
"""Extract job name from S3 URI.
567+
def _extract_job_name_from_secret_arn(secret_arn: str) -> str:
568+
"""Extract job name from a Secrets Manager ARN.
565569
566-
S3 URI format: s3://bucket/path/to/job-name/results
567-
or: s3://bucket/job-name/function
570+
Secret name convention: sagemaker/remote-function/{job_name}/hmac-key
571+
ARN format: arn:aws:secretsmanager:region:account:secret:sagemaker/remote-function/{job_name}/hmac-key-XXXXXX
568572
569573
Args:
570-
s3_uri: S3 URI containing job name
574+
secret_arn: Full ARN of the secret
571575
572576
Returns:
573-
Job name extracted from URI
577+
Extracted job name
578+
579+
Raises:
580+
DeserializationError: If ARN doesn't match expected format
574581
"""
575-
# Remove s3:// prefix and split by /
576-
parts = s3_uri.replace("s3://", "").split("/")
577-
578-
# Try to find a part that looks like a job name
579-
# Job names typically contain execution IDs or timestamps
580-
for part in reversed(parts):
581-
if part and part not in ['function', 'arguments', 'results', 'exception', 'payload.pkl', 'metadata.json']:
582-
return part
583-
584-
# Fallback: use the last meaningful part
585-
return parts[-2] if len(parts) > 1 else parts[0]
582+
import re
583+
match = re.search(r":secret:sagemaker/remote-function/(.+)/hmac-key", secret_arn)
584+
if not match:
585+
raise DeserializationError(
586+
f"Secret ARN does not match expected format "
587+
f"'sagemaker/remote-function/{{job_name}}/hmac-key': {secret_arn}"
588+
)
589+
return match.group(1)
586590

587591

588592
def _validate_secret_arn(
589593
sagemaker_session: Session,
590594
metadata_secret_arn: str,
591-
job_name: str
592595
):
593596
"""Validate secret ARN from metadata against trusted sources.
594597
595598
Implements two mitigations:
596599
1. Validate secret is in same AWS account
597600
2. Validate secret ARN matches Parameter Store (trust anchor)
598601
602+
The job_name is derived from the secret ARN's naming convention, then
603+
independently validated against the SSM trust anchor.
604+
599605
Args:
600606
sagemaker_session: SageMaker session
601607
metadata_secret_arn: Secret ARN from S3 metadata (untrusted)
602-
job_name: Remote function job name
603608
604609
Raises:
605610
DeserializationError: If validation fails
@@ -623,6 +628,7 @@ def _validate_secret_arn(
623628
)
624629

625630
# Mitigation #3: Validate against Parameter Store (trust anchor)
631+
job_name = _extract_job_name_from_secret_arn(metadata_secret_arn)
626632
expected_secret_arn = _get_secret_arn_from_parameter_store(sagemaker_session, job_name)
627633

628634
if metadata_secret_arn != expected_secret_arn:
@@ -638,7 +644,6 @@ def _perform_integrity_check(
638644
buffer: bytes,
639645
sagemaker_session: Optional[Session] = None,
640646
secret_arn: Optional[str] = None,
641-
s3_uri: Optional[str] = None
642647
):
643648
"""Performs integrity checks for serialized code/arguments uploaded to s3.
644649
@@ -650,7 +655,6 @@ def _perform_integrity_check(
650655
buffer: Serialized data buffer
651656
sagemaker_session: SageMaker session (required if secret_arn is provided)
652657
secret_arn: ARN of secret containing HMAC key (None for legacy plain SHA-256)
653-
s3_uri: S3 URI for extracting job name (required if secret_arn is provided)
654658
"""
655659
if secret_arn:
656660
# New secure method: HMAC with key from Secrets Manager
@@ -659,16 +663,8 @@ def _perform_integrity_check(
659663
"sagemaker_session is required for HMAC integrity check"
660664
)
661665

662-
if not s3_uri:
663-
raise DeserializationError(
664-
"s3_uri is required for HMAC integrity check to extract job name"
665-
)
666-
667-
# Extract job name from S3 URI
668-
job_name = _extract_job_name_from_s3_uri(s3_uri)
669-
670666
# Validate secret ARN (Mitigations #1 and #3)
671-
_validate_secret_arn(sagemaker_session, secret_arn, job_name)
667+
_validate_secret_arn(sagemaker_session, secret_arn)
672668

673669
# Now safe to retrieve HMAC key
674670
hmac_key = _get_hmac_key_from_secret(sagemaker_session, secret_arn)

sagemaker-core/src/sagemaker/core/remote_function/core/stored_function.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,9 @@ def save_pipeline_step_function(self, serialized_data):
145145
)
146146
serialization._upload_payload_and_metadata_to_s3(
147147
bytes_to_upload=serialized_data.func,
148-
148+
job_name=self.job_name,
149149
s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER),
150150
sagemaker_session=self.sagemaker_session,
151-
job_name=self.job_name,
152151
s3_kms_key=self.s3_kms_key,
153152
)
154153

@@ -158,10 +157,9 @@ def save_pipeline_step_function(self, serialized_data):
158157
)
159158
serialization._upload_payload_and_metadata_to_s3(
160159
bytes_to_upload=serialized_data.args,
161-
160+
job_name=self.job_name,
162161
s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER),
163162
sagemaker_session=self.sagemaker_session,
164-
job_name=self.job_name,
165163
s3_kms_key=self.s3_kms_key,
166164
)
167165

sagemaker-core/tests/unit/remote_function/test_serialization_security.py

Lines changed: 9 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
_MetaData,
2525
_compute_hash,
2626
_compute_hmac,
27+
_extract_job_name_from_secret_arn,
2728
_get_or_create_hmac_secret,
2829
_get_hmac_key_from_secret,
2930
_store_secret_arn_in_parameter_store,
3031
_get_secret_arn_from_parameter_store,
31-
_extract_job_name_from_s3_uri,
3232
_validate_secret_arn,
3333
_perform_integrity_check,
3434
_upload_payload_and_metadata_to_s3,
@@ -186,7 +186,7 @@ def test_store_secret_arn(self):
186186
call_kwargs = ssm_client.put_parameter.call_args[1]
187187
assert call_kwargs["Name"] == f"/sagemaker/remote-function/{MOCK_JOB_NAME}/secret-arn"
188188
assert call_kwargs["Value"] == MOCK_SECRET_ARN
189-
assert call_kwargs["Overwrite"] is True
189+
assert "Tags" in call_kwargs
190190

191191
def test_get_secret_arn(self):
192192
session, _, ssm_client, _ = _mock_sagemaker_session()
@@ -208,28 +208,6 @@ def test_get_secret_arn_not_found_raises(self):
208208
_get_secret_arn_from_parameter_store(session, MOCK_JOB_NAME)
209209

210210

211-
class TestExtractJobName:
212-
"""Tests for S3 URI job name extraction."""
213-
214-
def test_extract_from_results_uri(self):
215-
result = _extract_job_name_from_s3_uri(
216-
"s3://bucket/remote-function/my-job-123/results"
217-
)
218-
assert result == "my-job-123"
219-
220-
def test_extract_from_function_uri(self):
221-
result = _extract_job_name_from_s3_uri(
222-
"s3://bucket/remote-function/my-job-123/function"
223-
)
224-
assert result == "my-job-123"
225-
226-
def test_extract_from_exception_uri(self):
227-
result = _extract_job_name_from_s3_uri(
228-
"s3://bucket/remote-function/my-job-123/exception"
229-
)
230-
assert result == "my-job-123"
231-
232-
233211
class TestValidateSecretArn:
234212
"""Tests for secret ARN validation (Mitigations #1 and #3)."""
235213

@@ -238,7 +216,7 @@ def test_valid_secret_arn_passes(self):
238216
session, _, _, _ = _mock_sagemaker_session()
239217

240218
# Should not raise
241-
_validate_secret_arn(session, MOCK_SECRET_ARN, MOCK_JOB_NAME)
219+
_validate_secret_arn(session, MOCK_SECRET_ARN)
242220

243221
def test_cross_account_arn_rejected(self):
244222
"""Mitigation #1: Secret ARN from different account should be rejected."""
@@ -247,7 +225,7 @@ def test_cross_account_arn_rejected(self):
247225
attacker_arn = "arn:aws:secretsmanager:us-west-2:999999999999:secret:evil-secret"
248226

249227
with pytest.raises(DeserializationError, match="same AWS account"):
250-
_validate_secret_arn(session, attacker_arn, MOCK_JOB_NAME)
228+
_validate_secret_arn(session, attacker_arn)
251229

252230
def test_tampered_arn_rejected(self):
253231
"""Mitigation #3: ARN not matching Parameter Store should be rejected."""
@@ -261,15 +239,15 @@ def test_tampered_arn_rejected(self):
261239
# Attacker's ARN (same account but different secret)
262240
tampered_arn = f"arn:aws:secretsmanager:us-west-2:{MOCK_ACCOUNT_ID}:secret:attacker-created-secret"
263241

264-
with pytest.raises(DeserializationError, match="Secret ARN mismatch"):
265-
_validate_secret_arn(session, tampered_arn, MOCK_JOB_NAME)
242+
with pytest.raises(DeserializationError, match="does not match expected format"):
243+
_validate_secret_arn(session, tampered_arn)
266244

267245
def test_invalid_arn_format_rejected(self):
268246
"""Malformed ARN should be rejected."""
269247
session, _, _, _ = _mock_sagemaker_session()
270248

271249
with pytest.raises(DeserializationError, match="Invalid secret ARN format"):
272-
_validate_secret_arn(session, "not-an-arn", MOCK_JOB_NAME)
250+
_validate_secret_arn(session, "not-an-arn")
273251

274252

275253
class TestPerformIntegrityCheck:
@@ -288,7 +266,6 @@ def test_hmac_integrity_check_passes(self):
288266
buffer=payload,
289267
sagemaker_session=session,
290268
secret_arn=MOCK_SECRET_ARN,
291-
s3_uri=MOCK_S3_URI,
292269
)
293270

294271
def test_hmac_integrity_check_fails_on_tampered_payload(self):
@@ -305,7 +282,6 @@ def test_hmac_integrity_check_fails_on_tampered_payload(self):
305282
buffer=tampered_payload,
306283
sagemaker_session=session,
307284
secret_arn=MOCK_SECRET_ARN,
308-
s3_uri=MOCK_S3_URI,
309285
)
310286

311287
def test_legacy_sha256_check_passes_with_warning(self):
@@ -340,19 +316,6 @@ def test_hmac_check_requires_session(self):
340316
secret_arn=MOCK_SECRET_ARN,
341317
)
342318

343-
def test_hmac_check_requires_s3_uri(self):
344-
"""HMAC check should require s3_uri."""
345-
session, _, _, _ = _mock_sagemaker_session()
346-
347-
with pytest.raises(DeserializationError, match="s3_uri is required"):
348-
_perform_integrity_check(
349-
expected_hash_value="hash",
350-
buffer=b"data",
351-
sagemaker_session=session,
352-
secret_arn=MOCK_SECRET_ARN,
353-
)
354-
355-
356319
class TestAttackScenarios:
357320
"""Tests simulating actual attack scenarios."""
358321

@@ -373,7 +336,6 @@ def test_attacker_replaces_payload_and_metadata_plain_hash(self):
373336
buffer=malicious_payload,
374337
sagemaker_session=session,
375338
secret_arn=MOCK_SECRET_ARN,
376-
s3_uri=MOCK_S3_URI,
377339
)
378340

379341
def test_attacker_points_to_cross_account_secret(self):
@@ -383,7 +345,7 @@ def test_attacker_points_to_cross_account_secret(self):
383345
attacker_secret_arn = "arn:aws:secretsmanager:us-west-2:999999999999:secret:attacker-secret"
384346

385347
with pytest.raises(DeserializationError, match="same AWS account"):
386-
_validate_secret_arn(session, attacker_secret_arn, MOCK_JOB_NAME)
348+
_validate_secret_arn(session, attacker_secret_arn)
387349

388350
def test_attacker_creates_secret_in_same_account(self):
389351
"""Attacker creates secret in same account but ARN doesn't match Parameter Store."""
@@ -398,4 +360,4 @@ def test_attacker_creates_secret_in_same_account(self):
398360
attacker_arn = f"arn:aws:secretsmanager:us-west-2:{MOCK_ACCOUNT_ID}:secret:sagemaker/remote-function/evil-job/hmac-key"
399361

400362
with pytest.raises(DeserializationError, match="Secret ARN mismatch"):
401-
_validate_secret_arn(session, attacker_arn, MOCK_JOB_NAME)
363+
_validate_secret_arn(session, attacker_arn)

0 commit comments

Comments
 (0)