Skip to content

Commit 9feda02

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Add bucket ownership verification to prevent bucket squatting in Model.upload()
PiperOrigin-RevId: 899642927
1 parent dc89de2 commit 9feda02

3 files changed

Lines changed: 116 additions & 0 deletions

File tree

google/cloud/aiplatform/utils/gcs_utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,35 @@
6363
_DEFAULT_STAGING_BUCKET_SALT = str(uuid.uuid4())
6464

6565

66+
def _verify_bucket_ownership(
67+
bucket: storage.Bucket,
68+
expected_project: str,
69+
client: storage.Client,
70+
) -> bool:
71+
"""Verifies that a GCS bucket belongs to the expected project.
72+
73+
This check mitigates bucket squatting attacks where an attacker creates a
74+
bucket with a predictable name in their own project before the victim does.
75+
76+
Args:
77+
bucket: The GCS bucket to verify.
78+
expected_project: The project ID that should own the bucket.
79+
client: Storage client instance.
80+
81+
Returns:
82+
True if the bucket belongs to the expected project, False otherwise.
83+
"""
84+
try:
85+
bucket.reload(client=client)
86+
bucket_project_number = str(bucket.project_number)
87+
expected_project_number = str(
88+
resource_manager_utils.get_project_number(expected_project)
89+
)
90+
return bucket_project_number == expected_project_number
91+
except Exception:
92+
return False
93+
94+
6695
def blob_from_uri(uri: str, client: storage.Client) -> storage.Blob:
6796
"""Create a Blob from a GCS URI, compatible with v2 and v3.
6897
@@ -221,6 +250,17 @@ def stage_local_data_in_gcs(
221250
project=project,
222251
location=location,
223252
)
253+
else:
254+
# Verify bucket ownership to prevent bucket squatting attacks.
255+
# See b/469987320 for details.
256+
if not _verify_bucket_ownership(staging_bucket, project, client):
257+
raise ValueError(
258+
f'Staging bucket "{staging_bucket_name}" exists but does '
259+
f'not belong to project "{project}". This may indicate a '
260+
f"bucket squatting attack. Please provide an explicit "
261+
f"staging_bucket parameter or configure one via "
262+
f"aiplatform.init(staging_bucket='gs://your-bucket')."
263+
)
224264
staging_gcs_dir = "gs://" + staging_bucket_name
225265

226266
timestamp = datetime.datetime.now().isoformat(sep="-", timespec="milliseconds")

tests/unit/aiplatform/test_models.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -899,6 +899,9 @@ def mock_storage_blob_upload_from_filename():
899899
"google.cloud.storage.Blob.upload_from_filename"
900900
) as mock_blob_upload_from_filename, patch(
901901
"google.cloud.storage.Bucket.exists", return_value=True
902+
), patch(
903+
"google.cloud.aiplatform.utils.gcs_utils._verify_bucket_ownership",
904+
return_value=True,
902905
):
903906
yield mock_blob_upload_from_filename
904907

tests/unit/aiplatform/test_utils.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,79 @@ def test_validate_gcs_path(self):
668668
with pytest.raises(ValueError, match=err_msg):
669669
gcs_utils.validate_gcs_path(test_invalid_path)
670670

671+
@patch.object(
672+
gcs_utils.resource_manager_utils,
673+
"get_project_number",
674+
return_value=12345,
675+
)
676+
@patch.object(storage.Bucket, "reload")
677+
def test_verify_bucket_ownership_matching_project(
678+
self, mock_reload, mock_get_project_number
679+
):
680+
mock_client = mock.MagicMock(spec=storage.Client)
681+
mock_bucket = mock.MagicMock(spec=storage.Bucket)
682+
mock_bucket.project_number = 12345
683+
assert gcs_utils._verify_bucket_ownership(
684+
mock_bucket, "test-project", mock_client
685+
)
686+
687+
@patch.object(
688+
gcs_utils.resource_manager_utils,
689+
"get_project_number",
690+
return_value=12345,
691+
)
692+
@patch.object(storage.Bucket, "reload")
693+
def test_verify_bucket_ownership_different_project(
694+
self, mock_reload, mock_get_project_number
695+
):
696+
mock_client = mock.MagicMock(spec=storage.Client)
697+
mock_bucket = mock.MagicMock(spec=storage.Bucket)
698+
mock_bucket.project_number = 99999
699+
assert not gcs_utils._verify_bucket_ownership(
700+
mock_bucket, "test-project", mock_client
701+
)
702+
703+
@patch.object(storage.Bucket, "exists", return_value=True)
704+
@patch.object(storage, "Client")
705+
@patch.object(gcs_utils, "_verify_bucket_ownership", return_value=False)
706+
def test_stage_local_data_in_gcs_rejects_squatted_bucket(
707+
self, mock_verify, mock_storage_client, mock_bucket_exists, json_file
708+
):
709+
mock_config = mock.MagicMock()
710+
mock_config.project = "victim-project"
711+
mock_config.location = "us-central1"
712+
mock_config.staging_bucket = None
713+
mock_config.credentials = None
714+
with patch.object(gcs_utils.initializer, "global_config", mock_config):
715+
with pytest.raises(
716+
ValueError,
717+
match="bucket squatting",
718+
):
719+
gcs_utils.stage_local_data_in_gcs(json_file)
720+
721+
@patch.object(storage.Bucket, "exists", return_value=True)
722+
@patch.object(storage, "Client")
723+
@patch.object(gcs_utils, "_verify_bucket_ownership", return_value=True)
724+
@patch("google.cloud.storage.Blob.upload_from_filename")
725+
def test_stage_local_data_in_gcs_accepts_owned_bucket(
726+
self,
727+
mock_upload,
728+
mock_verify,
729+
mock_storage_client,
730+
mock_bucket_exists,
731+
json_file,
732+
mock_datetime,
733+
):
734+
mock_config = mock.MagicMock()
735+
mock_config.project = "my-project"
736+
mock_config.location = "us-central1"
737+
mock_config.staging_bucket = None
738+
mock_config.credentials = None
739+
with patch.object(gcs_utils.initializer, "global_config", mock_config):
740+
result = gcs_utils.stage_local_data_in_gcs(json_file)
741+
assert result.startswith("gs://")
742+
mock_verify.assert_called_once()
743+
671744

672745
class TestPipelineUtils:
673746
SAMPLE_JOB_SPEC = {

0 commit comments

Comments
 (0)