Skip to content

Commit e164b19

Browse files
cleop-googlecopybara-github
authored andcommitted
feat: GenAI SDK client(multimodal) - Add metadata helpers to MultimodalDataset.
BREAKING CHANGE: `create_from_bigquery` and `update_multimodal_dataset` no longer automatically prepend a missing `bq://` prefix for BigQuery URIs. When using the new function `MultimodalDataset.set_bigquery_uri` the prefix will still be added if needed. PiperOrigin-RevId: 890413106
1 parent 2767273 commit e164b19

File tree

5 files changed

+208
-93
lines changed

5 files changed

+208
-93
lines changed

tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -106,25 +106,6 @@ def test_create_dataset_from_bigquery(client):
106106
)
107107

108108

109-
def test_create_dataset_from_bigquery_without_bq_prefix(client):
110-
dataset = client.datasets.create_from_bigquery(
111-
multimodal_dataset={
112-
"display_name": "test-from-bigquery",
113-
"description": "test-description-from-bigquery",
114-
"metadata": {
115-
"inputConfig": {
116-
"bigquerySource": {"uri": BIGQUERY_TABLE_NAME},
117-
},
118-
},
119-
},
120-
)
121-
assert isinstance(dataset, types.MultimodalDataset)
122-
assert dataset.display_name == "test-from-bigquery"
123-
assert dataset.metadata.input_config.bigquery_source.uri == (
124-
f"bq://{BIGQUERY_TABLE_NAME}"
125-
)
126-
127-
128109
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
129110
def test_create_dataset_from_pandas(client, is_replay_mode):
130111
dataframe = pd.DataFrame(
@@ -270,26 +251,6 @@ async def test_create_dataset_from_bigquery_async_with_timeout(client):
270251
)
271252

272253

273-
@pytest.mark.asyncio
274-
async def test_create_dataset_from_bigquery_async_without_bq_prefix(client):
275-
dataset = await client.aio.datasets.create_from_bigquery(
276-
multimodal_dataset={
277-
"display_name": "test-from-bigquery",
278-
"description": "test-description-from-bigquery",
279-
"metadata": {
280-
"inputConfig": {
281-
"bigquerySource": {"uri": BIGQUERY_TABLE_NAME},
282-
},
283-
},
284-
},
285-
)
286-
assert isinstance(dataset, types.MultimodalDataset)
287-
assert dataset.display_name == "test-from-bigquery"
288-
assert dataset.metadata.input_config.bigquery_source.uri == (
289-
f"bq://{BIGQUERY_TABLE_NAME}"
290-
)
291-
292-
293254
@pytest.mark.asyncio
294255
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
295256
async def test_create_dataset_from_pandas_async(client, is_replay_mode):
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
"""Tests for multimodal datasets."""
16+
17+
from vertexai._genai import types
18+
19+
20+
class TestMultimodalDataset:
21+
22+
def test_read_config(self):
23+
dataset = types.MultimodalDataset(
24+
metadata={
25+
"gemini_request_read_config": {
26+
"assembled_request_column_name": "test_column",
27+
},
28+
},
29+
)
30+
31+
assert isinstance(dataset.read_config, types.GeminiRequestReadConfig)
32+
assert dataset.read_config.assembled_request_column_name == "test_column"
33+
34+
def test_read_config_empty(self):
35+
dataset = types.MultimodalDataset()
36+
assert dataset.read_config is None
37+
38+
def test_set_read_config(self):
39+
dataset = types.MultimodalDataset()
40+
41+
dataset.set_read_config(
42+
read_config={
43+
"assembled_request_column_name": "test_column",
44+
},
45+
)
46+
47+
assert isinstance(dataset, types.MultimodalDataset)
48+
assert (
49+
dataset.metadata.gemini_request_read_config.assembled_request_column_name
50+
== "test_column"
51+
)
52+
53+
def test_set_read_config_preserves_other_fields(self):
54+
dataset = types.MultimodalDataset(
55+
metadata={
56+
"inputConfig": {
57+
"bigquerySource": {"uri": "bq://test_table"},
58+
},
59+
},
60+
)
61+
62+
dataset.set_read_config(
63+
read_config={
64+
"assembled_request_column_name": "test_column",
65+
},
66+
)
67+
68+
assert isinstance(dataset, types.MultimodalDataset)
69+
assert (
70+
dataset.metadata.gemini_request_read_config.assembled_request_column_name
71+
== "test_column"
72+
)
73+
assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table"
74+
75+
def test_bigquery_uri(self):
76+
dataset = types.MultimodalDataset(
77+
metadata={
78+
"inputConfig": {
79+
"bigquerySource": {"uri": "bq://project.dataset.table"},
80+
},
81+
},
82+
)
83+
84+
assert dataset.bigquery_uri == "bq://project.dataset.table"
85+
86+
def test_bigquery_uri_empty(self):
87+
dataset = types.MultimodalDataset()
88+
assert dataset.bigquery_uri is None
89+
90+
def test_set_bigquery_uri(self):
91+
dataset = types.MultimodalDataset()
92+
93+
dataset.set_bigquery_uri("bq://project.dataset.table")
94+
95+
assert isinstance(dataset, types.MultimodalDataset)
96+
assert (
97+
dataset.metadata.input_config.bigquery_source.uri
98+
== "bq://project.dataset.table"
99+
)
100+
101+
def test_set_bigquery_uri_without_prefix(self):
102+
dataset = types.MultimodalDataset()
103+
104+
dataset.set_bigquery_uri("project.dataset.table")
105+
106+
assert isinstance(dataset, types.MultimodalDataset)
107+
assert (
108+
dataset.metadata.input_config.bigquery_source.uri
109+
== "bq://project.dataset.table"
110+
)
111+
112+
def test_set_bigquery_uri_preserves_other_fields(self):
113+
dataset = types.MultimodalDataset(
114+
metadata={
115+
"gemini_request_read_config": {
116+
"assembled_request_column_name": "test_column",
117+
},
118+
},
119+
)
120+
121+
dataset.set_bigquery_uri("bq://test_table")
122+
123+
assert isinstance(dataset, types.MultimodalDataset)
124+
assert dataset.metadata.input_config.bigquery_source.uri == "bq://test_table"
125+
assert (
126+
dataset.metadata.gemini_request_read_config.assembled_request_column_name
127+
== "test_column"
128+
)

vertexai/_genai/_datasets_utils.py

Lines changed: 9 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T:
4444
return model_type(**filtered_response)
4545

4646

47-
def multimodal_dataset_get_bigquery_uri(
47+
def validate_multimodal_dataset_bigquery_uri(
4848
multimodal_dataset: common.MultimodalDataset,
49-
) -> str:
50-
"""Gets the bigquery uri from a multimodal dataset or raises ValueError."""
49+
) -> None:
50+
"""Validates that a multimodal dataset has a bigquery uri or raises ValueError."""
5151
if (
5252
not hasattr(multimodal_dataset, "metadata")
5353
or multimodal_dataset.metadata is None
@@ -70,33 +70,12 @@ def multimodal_dataset_get_bigquery_uri(
7070
raise ValueError(
7171
"Multimodal dataset input config bigquery source uri is required."
7272
)
73-
return str(multimodal_dataset.metadata.input_config.bigquery_source.uri)
74-
75-
76-
def multimodal_dataset_set_bigquery_uri(
77-
multimodal_dataset: common.MultimodalDataset,
78-
bigquery_uri: str,
79-
) -> None:
80-
"""Sets the bigquery uri from a multimodal dataset or raises ValueError."""
81-
metadata = (
82-
common.SchemaTablesDatasetMetadata()
83-
if multimodal_dataset.metadata is None
84-
else multimodal_dataset.metadata
85-
)
86-
input_config = (
87-
common.SchemaTablesDatasetMetadataInputConfig()
88-
if metadata.input_config is None
89-
else metadata.input_config
90-
)
91-
bigquery_source = (
92-
common.SchemaTablesDatasetMetadataBigQuerySource()
93-
if input_config.bigquery_source is None
94-
else input_config.bigquery_source
95-
)
96-
bigquery_source.uri = bigquery_uri
97-
input_config.bigquery_source = bigquery_source
98-
metadata.input_config = input_config
99-
multimodal_dataset.metadata = metadata
73+
if not str(multimodal_dataset.metadata.input_config.bigquery_source.uri).startswith(
74+
"bq://"
75+
):
76+
raise ValueError(
77+
"Multimodal dataset bigquery source uri must start with 'bq://'."
78+
)
10079

10180

10281
def _try_import_bigframes() -> Any:

vertexai/_genai/datasets.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -790,12 +790,8 @@ def create_from_bigquery(
790790
"""
791791
if isinstance(multimodal_dataset, dict):
792792
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
793+
_datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset)
793794

794-
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
795-
if not uri.startswith("bq://"):
796-
_datasets_utils.multimodal_dataset_set_bigquery_uri(
797-
multimodal_dataset, f"bq://{uri}"
798-
)
799795
if isinstance(config, dict):
800796
config = types.CreateMultimodalDatasetConfig(**config)
801797
elif not config:
@@ -998,8 +994,11 @@ def to_bigframes(
998994
elif not multimodal_dataset:
999995
multimodal_dataset = types.MultimodalDataset()
1000996

1001-
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
1002-
return bigframes.pandas.read_gbq_table(uri.removeprefix("bq://"))
997+
if multimodal_dataset.bigquery_uri is None:
998+
raise ValueError("Multimodal dataset bigquery source uri is not set.")
999+
return bigframes.pandas.read_gbq_table(
1000+
multimodal_dataset.bigquery_uri.removeprefix("bq://")
1001+
)
10031002

10041003
def update_multimodal_dataset(
10051004
self,
@@ -1026,12 +1025,8 @@ def update_multimodal_dataset(
10261025
"""
10271026
if isinstance(multimodal_dataset, dict):
10281027
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
1028+
_datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset)
10291029

1030-
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
1031-
if not uri.startswith("bq://"):
1032-
_datasets_utils.multimodal_dataset_set_bigquery_uri(
1033-
multimodal_dataset, f"bq://{uri}"
1034-
)
10351030
if isinstance(config, dict):
10361031
config = types.CreateMultimodalDatasetConfig(**config)
10371032
elif not config:
@@ -1936,12 +1931,8 @@ async def create_from_bigquery(
19361931
"""
19371932
if isinstance(multimodal_dataset, dict):
19381933
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
1934+
_datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset)
19391935

1940-
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
1941-
if not uri.startswith("bq://"):
1942-
_datasets_utils.multimodal_dataset_set_bigquery_uri(
1943-
multimodal_dataset, f"bq://{uri}"
1944-
)
19451936
if isinstance(config, dict):
19461937
config = types.CreateMultimodalDatasetConfig(**config)
19471938
elif not config:
@@ -2148,9 +2139,11 @@ async def to_bigframes(
21482139
elif not multimodal_dataset:
21492140
multimodal_dataset = types.MultimodalDataset()
21502141

2151-
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
2142+
if multimodal_dataset.bigquery_uri is None:
2143+
raise ValueError("Multimodal dataset bigquery source uri is missing.")
21522144
return await asyncio.to_thread(
2153-
bigframes.pandas.read_gbq_table, uri.removeprefix("bq://")
2145+
bigframes.pandas.read_gbq_table,
2146+
multimodal_dataset.bigquery_uri.removeprefix("bq://"),
21542147
)
21552148

21562149
async def update_multimodal_dataset(
@@ -2174,12 +2167,8 @@ async def update_multimodal_dataset(
21742167
"""
21752168
if isinstance(multimodal_dataset, dict):
21762169
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
2170+
_datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset)
21772171

2178-
uri = _datasets_utils.multimodal_dataset_get_bigquery_uri(multimodal_dataset)
2179-
if not uri.startswith("bq://"):
2180-
_datasets_utils.multimodal_dataset_set_bigquery_uri(
2181-
multimodal_dataset, f"bq://{uri}"
2182-
)
21832172
if isinstance(config, dict):
21842173
config = types.CreateMultimodalDatasetConfig(**config)
21852174
elif not config:

vertexai/_genai/types/common.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12366,6 +12366,64 @@ class MultimodalDataset(_common.BaseModel):
1236612366
default=None, description="""The description of the multimodal dataset."""
1236712367
)
1236812368

12369+
@property
12370+
def read_config(self) -> Optional[GeminiRequestReadConfig]:
12371+
"""Gets the read config from the dataset metadata. Returns None if it's not set."""
12372+
if self.metadata is None or self.metadata.gemini_request_read_config is None:
12373+
return None
12374+
return self.metadata.gemini_request_read_config
12375+
12376+
def set_read_config(
12377+
self,
12378+
*,
12379+
read_config: GeminiRequestReadConfigOrDict,
12380+
) -> None:
12381+
"""Sets the read config in the dataset metadata."""
12382+
if isinstance(read_config, dict):
12383+
read_config = GeminiRequestReadConfig(**read_config)
12384+
12385+
if self.metadata is None:
12386+
self.metadata = SchemaTablesDatasetMetadata()
12387+
self.metadata.gemini_request_read_config = read_config
12388+
12389+
@property
12390+
def bigquery_uri(
12391+
self,
12392+
) -> Optional[str]:
12393+
"""Gets the bigquery uri from the dataset metadata. Returns None if it's not set."""
12394+
if (
12395+
self.metadata is None
12396+
or self.metadata.input_config is None
12397+
or self.metadata.input_config.bigquery_source is None
12398+
):
12399+
return None
12400+
return str(self.metadata.input_config.bigquery_source.uri)
12401+
12402+
def set_bigquery_uri(
12403+
self,
12404+
bigquery_uri: str,
12405+
) -> None:
12406+
"""Sets the bigquery uri in the dataset metadata. Prepends 'bq://' if it's not already present."""
12407+
if not bigquery_uri.startswith("bq://"):
12408+
bigquery_uri = f"bq://{bigquery_uri}"
12409+
metadata = (
12410+
SchemaTablesDatasetMetadata() if self.metadata is None else self.metadata
12411+
)
12412+
input_config = (
12413+
SchemaTablesDatasetMetadataInputConfig()
12414+
if metadata.input_config is None
12415+
else metadata.input_config
12416+
)
12417+
bigquery_source = (
12418+
SchemaTablesDatasetMetadataBigQuerySource()
12419+
if input_config.bigquery_source is None
12420+
else input_config.bigquery_source
12421+
)
12422+
bigquery_source.uri = bigquery_uri
12423+
input_config.bigquery_source = bigquery_source
12424+
metadata.input_config = input_config
12425+
self.metadata = metadata
12426+
1236912427

1237012428
class MultimodalDatasetDict(TypedDict, total=False):
1237112429
"""Represents a multimodal dataset."""

0 commit comments

Comments
 (0)