Skip to content

Commit 9b7dc29

Browse files
cleop-googlecopybara-github
authored and committed
feat: GenAI SDK client(multimodal) - Support creating multimodal dataset from bigframe DataFrame
PiperOrigin-RevId: 888132573
1 parent 368a8f8 commit 9b7dc29

File tree

4 files changed

+352
-17
lines changed

4 files changed

+352
-17
lines changed

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@
6969
"pyyaml>=5.3.1,<7",
7070
]
7171
datasets_extra_require = [
72-
"pyarrow >= 3.0.0, < 8.0.0; python_version<'3.11'",
72+
"pyarrow >= 3.0.0, < 8.0.0; python_version<'3.10'",
73+
"pyarrow >= 10.0.1; python_version=='3.10'",
7374
"pyarrow >= 10.0.1; python_version=='3.11'",
7475
"pyarrow >= 14.0.0; python_version>='3.12'",
7576
]

tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414
#
1515
# pylint: disable=protected-access,bad-continuation,missing-function-docstring
1616

17+
import sys
1718
from unittest import mock
19+
1820
from google.cloud import bigquery
1921
from tests.unit.vertexai.genai.replays import pytest_helper
2022
from vertexai._genai import _datasets_utils
2123
from vertexai._genai import types
2224
import pandas as pd
2325
import pytest
2426

27+
2528
METADATA_SCHEMA_URI = (
2629
"gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml"
2730
)
@@ -156,6 +159,52 @@ def test_create_dataset_from_pandas(client, is_replay_mode):
156159
pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe)
157160

158161

162+
@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="bigframes requires python 3.10 or higher"
)
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
def test_create_dataset_from_bigframes(client, is_replay_mode):
    """Creates a multimodal dataset from a bigframes DataFrame and checks the result."""
    import bigframes.pandas

    # Expected contents of the target BigQuery table.
    source_df = pd.DataFrame({"col1": ["col1"], "col2": ["col2"]})

    if is_replay_mode:
        # Replay mode never touches BigQuery; stub out the `to_gbq` round-trip.
        bf_dataframe = mock.MagicMock()
        bf_dataframe.to_gbq.return_value = "temp_table_id"
    else:
        bf_dataframe = bigframes.pandas.DataFrame(source_df)

    dataset = client.datasets.create_from_bigframes(
        dataframe=bf_dataframe,
        target_table_id=BIGQUERY_TABLE_NAME,
        multimodal_dataset={"display_name": "test-from-bigframes"},
    )

    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.display_name == "test-from-bigframes"
    expected_uri = f"bq://{BIGQUERY_TABLE_NAME}"
    assert dataset.metadata.input_config.bigquery_source.uri == expected_uri

    if not is_replay_mode:
        bigquery_client = bigquery.Client(
            project=client._api_client.project,
            location=client._api_client.location,
            credentials=client._api_client._credentials,
        )
        # Strip the leading "bq://" scheme to get the raw table id.
        rows = bigquery_client.list_rows(expected_uri[5:])
        pd.testing.assert_frame_equal(
            rows.to_dataframe(), source_df, check_index_type=False
        )
207+
159208
pytestmark = pytest_helper.setup(
160209
file=__file__,
161210
globals_for_file=globals(),
@@ -274,3 +323,50 @@ async def test_create_dataset_from_pandas_async(client, is_replay_mode):
274323
dataset.metadata.input_config.bigquery_source.uri[5:]
275324
)
276325
pd.testing.assert_frame_equal(rows.to_dataframe(), dataframe)
326+
327+
328+
@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="bigframes requires python 3.10 or higher"
)
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
async def test_create_dataset_from_bigframes_async(client, is_replay_mode):
    """Async variant: creates a multimodal dataset from a bigframes DataFrame."""
    import bigframes.pandas

    # Expected contents of the target BigQuery table.
    source_df = pd.DataFrame({"col1": ["col1"], "col2": ["col2"]})

    if is_replay_mode:
        # Replay mode never touches BigQuery; stub out the `to_gbq` round-trip.
        bf_dataframe = mock.MagicMock()
        bf_dataframe.to_gbq.return_value = "temp_table_id"
    else:
        bf_dataframe = bigframes.pandas.DataFrame(source_df)

    dataset = await client.aio.datasets.create_from_bigframes(
        dataframe=bf_dataframe,
        target_table_id=BIGQUERY_TABLE_NAME,
        multimodal_dataset={"display_name": "test-from-bigframes"},
    )

    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.display_name == "test-from-bigframes"
    expected_uri = f"bq://{BIGQUERY_TABLE_NAME}"
    assert dataset.metadata.input_config.bigquery_source.uri == expected_uri

    if not is_replay_mode:
        bigquery_client = bigquery.Client(
            project=client._api_client.project,
            location=client._api_client.location,
            credentials=client._api_client._credentials,
        )
        # Strip the leading "bq://" scheme to get the raw table id.
        rows = bigquery_client.list_rows(expected_uri[5:])
        pd.testing.assert_frame_equal(
            rows.to_dataframe(), source_df, check_index_type=False
        )

vertexai/_genai/_datasets_utils.py

Lines changed: 88 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
#
1515
"""Utility functions for multimodal dataset."""
1616

17+
import asyncio
1718
from typing import Any, Type, TypeVar
1819
import uuid
1920

2021
import google.auth.credentials
2122
from vertexai._genai.types import common
2223
from pydantic import BaseModel
2324

25+
2426
METADATA_SCHEMA_URI = (
2527
"gs://google-cloud-aiplatform/schema/dataset/metadata/multimodal_1.0.0.yaml"
2628
)
@@ -169,14 +171,48 @@ def _normalize_and_validate_table_id(
169171
return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}"
170172

171173

174+
async def _normalize_and_validate_table_id_async(
    *,
    table_id: str,
    project: str,
    location: str,
    credentials: google.auth.credentials.Credentials,
) -> str:
    """Async twin of `_normalize_and_validate_table_id`.

    Resolves `table_id` against `project`, verifies the table's project and its
    BigQuery dataset's location match the multimodal dataset, and returns the
    fully-qualified `project.dataset.table` id.

    Raises:
        ValueError: If the table is in a different project, or its BigQuery
            dataset is in a disallowed location.
    """
    bigquery = _try_import_bigquery()

    table_ref = bigquery.TableReference.from_string(table_id, default_project=project)
    if table_ref.project != project:
        raise ValueError(
            "The BigQuery table "
            f"`{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}`"
            " must be in the same project as the multimodal dataset."
            f" The multimodal dataset is in `{project}`, but the BigQuery table"
            f" is in `{table_ref.project}`."
        )

    dataset_ref = bigquery.DatasetReference(table_ref.project, table_ref.dataset_id)
    bq_client = bigquery.Client(project=project, credentials=credentials)
    # `get_dataset` is a blocking network call; keep it off the event loop.
    bq_dataset = await asyncio.to_thread(bq_client.get_dataset, dataset_ref=dataset_ref)
    if not _bq_dataset_location_allowed(location, bq_dataset.location):
        raise ValueError(
            "The BigQuery dataset"
            f" `{dataset_ref.project}.{dataset_ref.dataset_id}` must be in the"
            " same location as the multimodal dataset. The multimodal dataset"
            f" is in `{location}`, but the BigQuery dataset is in"
            f" `{bq_dataset.location}`."
        )
    return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}"
208+
172209
def _create_default_bigquery_dataset_if_not_exists(
173210
*,
174211
project: str,
175212
location: str,
176213
credentials: google.auth.credentials.Credentials,
177214
) -> str:
178-
# Loading bigquery lazily to avoid auto-loading it when importing vertexai
179-
from google.cloud import bigquery # pylint: disable=g-import-not-at-top
215+
bigquery = _try_import_bigquery()
180216

181217
bigquery_client = bigquery.Client(project=project, credentials=credentials)
182218
location_str = location.lower().replace("-", "_")
@@ -189,5 +225,55 @@ def _create_default_bigquery_dataset_if_not_exists(
189225
return f"{dataset_id.project}.{dataset_id.dataset_id}"
190226

191227

228+
async def _create_default_bigquery_dataset_if_not_exists_async(
    *,
    project: str,
    location: str,
    credentials: google.auth.credentials.Credentials,
) -> str:
    """Ensures the default per-location BigQuery dataset exists (async).

    Returns the fully-qualified `project.dataset` id of the default dataset.
    """
    bigquery = _try_import_bigquery()
    bq_client = bigquery.Client(project=project, credentials=credentials)

    # Dataset ids may not contain "-", so normalize the location string.
    dataset_name = f"{_DEFAULT_BQ_DATASET_PREFIX}_{location.lower().replace('-', '_')}"
    dataset_ref = bigquery.DatasetReference(project, dataset_name)
    new_dataset = bigquery.Dataset(dataset_ref=dataset_ref)
    new_dataset.location = location
    # `create_dataset` is a blocking network call; keep it off the event loop.
    await asyncio.to_thread(bq_client.create_dataset, new_dataset, exists_ok=True)
    return f"{dataset_ref.project}.{dataset_ref.dataset_id}"
245+
246+
192247
def _generate_target_table_id(dataset_id: str) -> str:
    """Returns a unique staging table id under *dataset_id*."""
    suffix = uuid.uuid4()
    return f"{dataset_id}.{_DEFAULT_BQ_TABLE_PREFIX}_{suffix}"
249+
250+
251+
def save_dataframe_to_bigquery(
    dataframe: "bigframes.pandas.DataFrame",  # type: ignore # noqa: F821
    target_table_id: str,
    bq_client: "bigquery.Client",  # type: ignore # noqa: F821
) -> None:
    """Saves a bigframes DataFrame into the BigQuery table `target_table_id`.

    `to_gbq` does not support cross-region use cases, so the dataframe is first
    materialized into a temporary table and then copied with `copy_table`.

    Args:
        dataframe: The bigframes DataFrame to save.
        target_table_id: Fully-qualified id of the destination table.
        bq_client: BigQuery client used for the copy and the cleanup.
    """
    temp_table_id = dataframe.to_gbq()
    try:
        copy_job = bq_client.copy_table(
            sources=temp_table_id,
            destination=target_table_id,
        )
        copy_job.result()
    finally:
        # Always drop the staging table, even when the copy job fails,
        # so a failed copy does not leak temporary tables.
        bq_client.delete_table(temp_table_id)
264+
265+
266+
async def save_dataframe_to_bigquery_async(
    dataframe: "bigframes.pandas.DataFrame",  # type: ignore # noqa: F821
    target_table_id: str,
    bq_client: "bigquery.Client",  # type: ignore # noqa: F821
) -> None:
    """Async variant of `save_dataframe_to_bigquery`.

    `to_gbq` does not support cross-region use cases, so the dataframe is first
    materialized into a temporary table and then copied with `copy_table`. All
    blocking BigQuery calls run in `asyncio.to_thread` to keep the event loop
    responsive.

    Args:
        dataframe: The bigframes DataFrame to save.
        target_table_id: Fully-qualified id of the destination table.
        bq_client: BigQuery client used for the copy and the cleanup.
    """
    temp_table_id = await asyncio.to_thread(dataframe.to_gbq)
    try:
        copy_job = await asyncio.to_thread(
            bq_client.copy_table,
            sources=temp_table_id,
            destination=target_table_id,
        )
        await asyncio.to_thread(copy_job.result)
    finally:
        # Always drop the staging table, even when the copy job fails,
        # so a failed copy does not leak temporary tables.
        await asyncio.to_thread(bq_client.delete_table, temp_table_id)

0 commit comments

Comments
 (0)