Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions google/cloud/aiplatform/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5240,8 +5240,8 @@ def __init__(
# Create ModelRegistry with the unversioned resource name
self._registry = ModelRegistry(
self.resource_name,
location=location,
project=project,
location=location or self.location,
project=project or self.project,
credentials=credentials,
)

Expand Down
27 changes: 27 additions & 0 deletions tests/unit/aiplatform/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4701,6 +4701,33 @@ def test_init_with_version_arg(self, get_model_with_version):
# The Model yielded from upload SHOULD have a version in the versioned resource name
assert model.versioned_resource_name.endswith(f"@{_TEST_VERSION_ID}")

def test_versioning_registry_uses_location_from_resource_name(
self, create_client_mock
):
# Regression test for https://github.com/googleapis/python-aiplatform/issues/2608:
# When a Model is initialized with a fully-qualified resource name that encodes a
# non-default location, the versioning registry client must use that location, not
# the global default from aiplatform.init().
models.Model(_TEST_MODEL_RESOURCE_NAME_CUSTOM_LOCATION)
create_client_mock.assert_any_call(
client_class=utils.ModelClientWithOverride,
credentials=initializer.global_config.credentials,
location_override=_TEST_LOCATION_2,
appended_user_agent=None,
)

def test_versioning_registry_uses_project_from_resource_name(
self, get_model_with_custom_project_mock
):
# Regression test for https://github.com/googleapis/python-aiplatform/issues/2608:
# When a Model is initialized with a fully-qualified resource name that encodes a
# non-default project, the versioning registry must use that project, not the
# global default from aiplatform.init().
model = models.Model(_TEST_MODEL_RESOURCE_NAME_CUSTOM_PROJECT)
assert model._registry.model_resource_name.startswith(
f"projects/{_TEST_PROJECT_2}/"
)

@pytest.mark.parametrize(
"parent,location,project",
[
Expand Down