Skip to content

Commit 0747156

Browse files
authored
fix: s3 bucket operations in V2 (#5803)
* fix: s3 bucket operations in V2
  Add ExpectedBucketOwner to S3 data-plane calls
* Follow up test fixes
* fix unit tests
* fix unit test
1 parent f416fa7 commit 0747156

35 files changed

Lines changed: 1155 additions & 52 deletions

src/sagemaker/amazon/amazon_estimator.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,8 @@ def record_set(
317317
key_prefix = key_prefix.lstrip("/")
318318
logger.debug("Uploading to bucket %s and key_prefix %s", bucket, key_prefix)
319319
manifest_s3_file = upload_numpy_to_s3_shards(
320-
self.instance_count, s3, bucket, key_prefix, train, labels, encrypt
320+
self.instance_count, s3, bucket, key_prefix, train, labels, encrypt,
321+
sagemaker_session=self.sagemaker_session
321322
)
322323
logger.debug("Created manifest file %s", manifest_s3_file)
323324
return RecordSet(
@@ -455,7 +456,7 @@ def _build_shards(num_shards, array):
455456

456457

457458
def upload_numpy_to_s3_shards(
458-
num_shards, s3, bucket, key_prefix, array, labels=None, encrypt=False
459+
num_shards, s3, bucket, key_prefix, array, labels=None, encrypt=False, sagemaker_session=None
459460
):
460461
"""Upload the training ``array`` and ``labels`` arrays to ``num_shards``.
461462
@@ -470,6 +471,8 @@ def upload_numpy_to_s3_shards(
470471
array:
471472
labels:
472473
encrypt:
474+
sagemaker_session: Optional. SageMaker session used to resolve the
475+
ExpectedBucketOwner spot check for the default bucket.
473476
"""
474477
shards = _build_shards(num_shards, array)
475478
if labels is not None:
@@ -478,6 +481,12 @@ def upload_numpy_to_s3_shards(
478481
if key_prefix[-1] != "/":
479482
key_prefix = key_prefix + "/"
480483
extra_put_kwargs = {"ServerSideEncryption": "AES256"} if encrypt else {}
484+
# Spot check: enforce ownership only when uploading to the session's default
485+
# bucket. Cross-account destinations are left untouched.
486+
if sagemaker_session is not None:
487+
expected_owner = sagemaker_session._get_account_id_if_default_bucket(bucket)
488+
if expected_owner:
489+
extra_put_kwargs["ExpectedBucketOwner"] = expected_owner
481490
try:
482491
for shard_index, shard in enumerate(shards):
483492
with tempfile.TemporaryFile() as file:

src/sagemaker/async_inference/async_inference_response.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,11 @@ def _get_result_from_s3_output_path(self, output_path):
9797
"""Get inference result from the output Amazon S3 path"""
9898
bucket, key = parse_s3_url(output_path)
9999
try:
100-
response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
100+
get_kwargs = {"Bucket": bucket, "Key": key}
101+
expected_owner = self.predictor_async.sagemaker_session._get_account_id_if_default_bucket(bucket)
102+
if expected_owner:
103+
get_kwargs["ExpectedBucketOwner"] = expected_owner
104+
response = self.predictor_async.s3_client.get_object(**get_kwargs)
101105
return self.predictor_async.predictor._handle_response(response)
102106
except ClientError as ex:
103107
if ex.response["Error"]["Code"] == "NoSuchKey":
@@ -113,14 +117,22 @@ def _get_result_from_s3_output_failure_paths(self, output_path, failure_path):
113117
"""Get inference result from the output & failure Amazon S3 path"""
114118
bucket, key = parse_s3_url(output_path)
115119
try:
116-
response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
120+
get_kwargs = {"Bucket": bucket, "Key": key}
121+
expected_owner = self.predictor_async.sagemaker_session._get_account_id_if_default_bucket(bucket)
122+
if expected_owner:
123+
get_kwargs["ExpectedBucketOwner"] = expected_owner
124+
response = self.predictor_async.s3_client.get_object(**get_kwargs)
117125
return self.predictor_async.predictor._handle_response(response)
118126
except ClientError as e:
119127
if e.response["Error"]["Code"] == "NoSuchKey":
120128
try:
121129
failure_bucket, failure_key = parse_s3_url(failure_path)
130+
fail_kwargs = {"Bucket": failure_bucket, "Key": failure_key}
131+
fail_owner = self.predictor_async.sagemaker_session._get_account_id_if_default_bucket(failure_bucket)
132+
if fail_owner:
133+
fail_kwargs["ExpectedBucketOwner"] = fail_owner
122134
failure_response = self.predictor_async.s3_client.get_object(
123-
Bucket=failure_bucket, Key=failure_key
135+
**fail_kwargs
124136
)
125137
failure_response = self.predictor_async.predictor._handle_response(
126138
failure_response

src/sagemaker/estimator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,9 @@ def _stage_user_code_in_s3(self) -> UploadedCode:
11051105
kms_key=kms_key,
11061106
s3_resource=self.sagemaker_session.s3_resource,
11071107
settings=self.sagemaker_session.settings,
1108+
expected_bucket_owner=self.sagemaker_session._get_account_id_if_default_bucket(
1109+
code_bucket
1110+
),
11081111
)
11091112

11101113
def _assign_s3_prefix(self, key_prefix=""):

src/sagemaker/experiments/_helper.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,16 @@ def upload_artifact(self, file_path, extra_args=None):
9595
artifact_s3_key = "{}/{}/{}".format(
9696
self.artifact_prefix, self.trial_component_name, artifact_name
9797
)
98+
99+
# Spot check: enforce ownership only when uploading to the session's default
100+
# bucket. Cross-account destinations are left untouched.
101+
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(
102+
self.artifact_bucket
103+
)
104+
if expected_owner:
105+
extra_args = dict(extra_args) if extra_args else {}
106+
extra_args["ExpectedBucketOwner"] = expected_owner
107+
98108
self._s3_client.upload_file(
99109
file_path,
100110
self.artifact_bucket,
@@ -133,9 +143,21 @@ def upload_object_artifact(self, artifact_name, artifact_object, file_extension=
133143
artifact_s3_key = "{}/{}/{}".format(
134144
self.artifact_prefix, self.trial_component_name, artifact_name
135145
)
136-
self._s3_client.put_object(
137-
Body=json.dumps(artifact_object), Bucket=self.artifact_bucket, Key=artifact_s3_key
146+
147+
# Spot check: enforce ownership only when uploading to the session's default
148+
# bucket. Cross-account destinations are left untouched.
149+
put_kwargs = {
150+
"Body": json.dumps(artifact_object),
151+
"Bucket": self.artifact_bucket,
152+
"Key": artifact_s3_key,
153+
}
154+
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(
155+
self.artifact_bucket
138156
)
157+
if expected_owner:
158+
put_kwargs["ExpectedBucketOwner"] = expected_owner
159+
160+
self._s3_client.put_object(**put_kwargs)
139161
etag = self._try_get_etag(artifact_s3_key)
140162
return "s3://{}/{}".format(self.artifact_bucket, artifact_s3_key), etag
141163

@@ -149,7 +171,14 @@ def _try_get_etag(self, key):
149171
str: The S3 object ETag if it can be read, otherwise None.
150172
"""
151173
try:
152-
response = self._s3_client.head_object(Bucket=self.artifact_bucket, Key=key)
174+
head_kwargs = {"Bucket": self.artifact_bucket, "Key": key}
175+
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(
176+
self.artifact_bucket
177+
)
178+
if expected_owner:
179+
head_kwargs["ExpectedBucketOwner"] = expected_owner
180+
181+
response = self._s3_client.head_object(**head_kwargs)
153182
return response["ETag"]
154183
except botocore.exceptions.ClientError as error:
155184
# requires read permissions

src/sagemaker/fw_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,7 @@ def tar_and_upload_dir(
401401
kms_key=None,
402402
s3_resource=None,
403403
settings: Optional[SessionSettings] = None,
404+
expected_bucket_owner: Optional[str] = None,
404405
) -> UploadedCode:
405406
"""Package source files and upload a compress tar file to S3.
406407
@@ -431,6 +432,12 @@ def tar_and_upload_dir(
431432
settings (sagemaker.session_settings.SessionSettings): Optional. The settings
432433
of the SageMaker ``Session``, can be used to override the default encryption
433434
behavior (default: None).
435+
expected_bucket_owner (str): Optional. AWS account id passed as
436+
``ExpectedBucketOwner`` on the upload. Callers should supply this when
437+
``bucket`` is the session's default bucket (via
438+
``Session._get_account_id_if_default_bucket``) to defend against
439+
bucket-squatting on the predictable default name. Leave as ``None`` for
440+
cross-account destination buckets.
434441
Returns:
435442
sagemaker.fw_utils.UploadedCode: An object with the S3 bucket and key (S3 prefix) and
436443
script name.
@@ -472,6 +479,10 @@ def tar_and_upload_dir(
472479
else:
473480
extra_args = None
474481

482+
if expected_bucket_owner:
483+
extra_args = dict(extra_args) if extra_args else {}
484+
extra_args["ExpectedBucketOwner"] = expected_bucket_owner
485+
475486
if s3_resource is None:
476487
s3_resource = session.resource("s3", region_name=session.region_name)
477488
else:

src/sagemaker/lambda_helper.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,17 @@ def create(self):
123123
bucket, key_prefix = s3.determine_bucket_and_prefix(
124124
bucket=self.s3_bucket, key_prefix=None, sagemaker_session=self.session
125125
)
126+
# Spot check: if the resolved bucket is the session's default bucket,
127+
# enforce ownership on the upload so an attacker cannot squat on the
128+
# predictable default name.
129+
expected_owner = self.session._get_account_id_if_default_bucket(bucket)
126130
key = _upload_to_s3(
127131
s3_client=_get_s3_client(self.session),
128132
function_name=self.function_name,
129133
zipped_code_dir=self.zipped_code_dir,
130134
s3_bucket=bucket,
131135
s3_key_prefix=key_prefix,
136+
expected_bucket_owner=expected_owner,
132137
)
133138
code = {"S3Bucket": bucket, "S3Key": key}
134139

@@ -179,6 +184,13 @@ def update(self):
179184
else:
180185
function_name_for_s3 = self.function_name
181186

187+
# Spot check: enforce ownership only when the resolved bucket is
188+
# the session's default bucket (defends against squatting on the
189+
# predictable default name). Other buckets are left untouched.
190+
expected_owner = self.session._get_account_id_if_default_bucket(
191+
bucket
192+
)
193+
182194
response = lambda_client.update_function_code(
183195
FunctionName=(self.function_name or self.function_arn),
184196
S3Bucket=bucket,
@@ -188,6 +200,7 @@ def update(self):
188200
zipped_code_dir=self.zipped_code_dir,
189201
s3_bucket=bucket,
190202
s3_key_prefix=key_prefix,
203+
expected_bucket_owner=expected_owner,
191204
),
192205
)
193206
return response
@@ -276,13 +289,29 @@ def _get_lambda_client(session):
276289
return lambda_client
277290

278291

279-
def _upload_to_s3(s3_client, function_name, zipped_code_dir, s3_bucket, s3_key_prefix=None):
292+
def _upload_to_s3(
293+
s3_client,
294+
function_name,
295+
zipped_code_dir,
296+
s3_bucket,
297+
s3_key_prefix=None,
298+
expected_bucket_owner=None,
299+
):
280300
"""Upload the zipped code to S3 bucket provided in the Lambda instance.
281301
282302
Lambda instance must have a path to the zipped code folder and a S3 bucket to upload
283303
the code. The key will be lambda/function_name/code and the S3 URI where the code is
284304
uploaded is in this format: s3://bucket_name/lambda/function_name/code.
285305
306+
Args:
307+
s3_client: boto3 S3 client used for the upload.
308+
function_name (str): Lambda function name used to build the S3 key.
309+
zipped_code_dir (str): Local path to the zipped Lambda code.
310+
s3_bucket (str): Destination S3 bucket.
311+
s3_key_prefix (str): Optional S3 key prefix.
312+
expected_bucket_owner (str): Optional account id passed as ``ExpectedBucketOwner``
313+
on the upload when the destination bucket should belong to that account.
314+
286315
Returns: the S3 key where the code is uploaded.
287316
"""
288317

@@ -292,7 +321,10 @@ def _upload_to_s3(s3_client, function_name, zipped_code_dir, s3_bucket, s3_key_p
292321
function_name,
293322
"code",
294323
)
295-
s3_client.upload_file(zipped_code_dir, s3_bucket, key)
324+
extra_args = None
325+
if expected_bucket_owner:
326+
extra_args = {"ExpectedBucketOwner": expected_bucket_owner}
327+
s3_client.upload_file(zipped_code_dir, s3_bucket, key, ExtraArgs=extra_args)
296328
return key
297329

298330

src/sagemaker/model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,9 @@ def _upload_code(self, key_prefix: str, repack: bool = False) -> None:
788788
dependencies=self.dependencies,
789789
kms_key=self.model_kms_key,
790790
settings=self.sagemaker_session.settings,
791+
expected_bucket_owner=self.sagemaker_session._get_account_id_if_default_bucket(
792+
bucket
793+
),
791794
)
792795

793796
if repack and self.model_data is not None and self.entry_point is not None:

src/sagemaker/multidatamodel.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,17 @@ def add_model(self, model_data_source, model_data_path=None):
330330
dst_s3_uri = s3.s3_path_join(dst_prefix, model_data_path)
331331
else:
332332
dst_s3_uri = s3.s3_path_join(dst_prefix, os.path.basename(model_data_source))
333-
self.s3_client.upload_file(model_data_source, destination_bucket, dst_s3_uri)
333+
# Spot check: enforce ownership only when uploading to the session's default
334+
# bucket. Cross-account destinations are left untouched.
335+
extra_args = None
336+
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(
337+
destination_bucket
338+
)
339+
if expected_owner:
340+
extra_args = {"ExpectedBucketOwner": expected_owner}
341+
self.s3_client.upload_file(
342+
model_data_source, destination_bucket, dst_s3_uri, ExtraArgs=extra_args
343+
)
334344
# return upload_path
335345
return s3.s3_path_join("s3://", destination_bucket, dst_s3_uri)
336346

src/sagemaker/predictor_async.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,18 @@ def _upload_data_to_s3(
176176
)
177177

178178
data = self.serializer.serialize(data)
179-
self.s3_client.put_object(
180-
Body=data, Bucket=bucket, Key=key, ContentType=self.serializer.CONTENT_TYPE
181-
)
179+
# Spot check: enforce ownership only when uploading to the session's default
180+
# bucket. Cross-account destinations are left untouched.
181+
put_kwargs = {
182+
"Body": data,
183+
"Bucket": bucket,
184+
"Key": key,
185+
"ContentType": self.serializer.CONTENT_TYPE,
186+
}
187+
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(bucket)
188+
if expected_owner:
189+
put_kwargs["ExpectedBucketOwner"] = expected_owner
190+
self.s3_client.put_object(**put_kwargs)
182191
input_path = input_path or "s3://{}/{}".format(bucket, key)
183192

184193
return input_path
@@ -241,7 +250,13 @@ def _check_output_path(self, output_path, waiter_config):
241250
output_path=output_path,
242251
seconds=waiter_config.delay * waiter_config.max_attempts,
243252
)
244-
s3_object = self.s3_client.get_object(Bucket=bucket, Key=key)
253+
# Spot check: enforce ownership only when reading from the session's default
254+
# bucket. Cross-account reads are left untouched.
255+
get_kwargs = {"Bucket": bucket, "Key": key}
256+
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(bucket)
257+
if expected_owner:
258+
get_kwargs["ExpectedBucketOwner"] = expected_owner
259+
s3_object = self.s3_client.get_object(**get_kwargs)
245260
result = self.predictor._handle_response(response=s3_object)
246261
return result
247262

@@ -311,12 +326,22 @@ def check_failure_file():
311326
time.sleep(1)
312327

313328
if output_file_found.is_set():
314-
s3_object = self.s3_client.get_object(Bucket=output_bucket, Key=output_key)
329+
get_kwargs = {"Bucket": output_bucket, "Key": output_key}
330+
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(
331+
output_bucket
332+
)
333+
if expected_owner:
334+
get_kwargs["ExpectedBucketOwner"] = expected_owner
335+
s3_object = self.s3_client.get_object(**get_kwargs)
315336
result = self.predictor._handle_response(response=s3_object)
316337
return result
317338

318339
if failure_file_found.is_set():
319-
failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key)
340+
fail_kwargs = {"Bucket": failure_bucket, "Key": failure_key}
341+
fail_owner = self.sagemaker_session._get_account_id_if_default_bucket(failure_bucket)
342+
if fail_owner:
343+
fail_kwargs["ExpectedBucketOwner"] = fail_owner
344+
failure_object = self.s3_client.get_object(**fail_kwargs)
320345
failure_response = self.predictor._handle_response(response=failure_object)
321346
raise AsyncInferenceModelError(message=failure_response)
322347

src/sagemaker/pytorch/estimator.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -744,7 +744,14 @@ def _create_recipe_copy(self, original_s3_uri):
744744
# Copy the object with the new name
745745
copy_source = {"Bucket": bucket, "Key": original_key}
746746

747-
s3_client.copy_object(CopySource=copy_source, Bucket=bucket, Key=new_key)
747+
# Spot check: enforce ownership only when copying within the session's
748+
# default bucket. Cross-account buckets are left untouched.
749+
copy_kwargs = {"CopySource": copy_source, "Bucket": bucket, "Key": new_key}
750+
expected_owner = self.sagemaker_session._get_account_id_if_default_bucket(bucket)
751+
if expected_owner:
752+
copy_kwargs["ExpectedBucketOwner"] = expected_owner
753+
754+
s3_client.copy_object(**copy_kwargs)
748755

749756
return f"s3://{bucket}/{new_key}"
750757

0 commit comments

Comments
 (0)