Skip to content

Commit a41c883

Browse files
cleop-google authored and copybara-github committed
feat: GenAI SDK client(multimodal) - Accept an explicit bigquery_uri parameter in create_from_bigquery
PiperOrigin-RevId: 900174983
1 parent f5c4f8f commit a41c883

3 files changed

Lines changed: 133 additions & 22 deletions

File tree

tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py

Lines changed: 72 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,40 @@ def test_create_dataset_from_bigquery(client):
115115
)
116116

117117

118+
@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name")
119+
def test_create_dataset_from_bigquery_with_uri(client):
120+
dataset = client.datasets.create_from_bigquery(
121+
bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}",
122+
)
123+
assert isinstance(dataset, types.MultimodalDataset)
124+
assert dataset.metadata.input_config.bigquery_source.uri == (
125+
f"bq://{BIGQUERY_TABLE_NAME}"
126+
)
127+
128+
129+
def test_create_dataset_from_bigquery_preserves_other_metadata(client):
130+
dataset = client.datasets.create_from_bigquery(
131+
bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}",
132+
multimodal_dataset={
133+
"display_name": "test-from-bigquery-uri",
134+
"metadata": {
135+
"gemini_request_read_config": {
136+
"assembled_request_column_name": "test_column"
137+
}
138+
},
139+
},
140+
)
141+
assert isinstance(dataset, types.MultimodalDataset)
142+
assert dataset.display_name == "test-from-bigquery-uri"
143+
assert (
144+
dataset.metadata.gemini_request_read_config.assembled_request_column_name
145+
== "test_column"
146+
)
147+
assert dataset.metadata.input_config.bigquery_source.uri == (
148+
f"bq://{BIGQUERY_TABLE_NAME}"
149+
)
150+
151+
118152
@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name")
119153
def test_create_dataset_from_bigquery_no_display_name(client):
120154
dataset = client.datasets.create_from_bigquery(
@@ -254,6 +288,44 @@ async def test_create_dataset_from_bigquery_async(client):
254288
)
255289

256290

291+
@pytest.mark.asyncio
292+
@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name")
293+
async def test_create_dataset_from_bigquery_with_uri_async(client):
294+
dataset = await client.aio.datasets.create_from_bigquery(
295+
bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}",
296+
)
297+
assert isinstance(dataset, types.MultimodalDataset)
298+
assert dataset.metadata.input_config.bigquery_source.uri == (
299+
f"bq://{BIGQUERY_TABLE_NAME}"
300+
)
301+
302+
303+
@pytest.mark.asyncio
304+
async def test_create_dataset_from_bigquery_preserves_other_metadata_async(
305+
client,
306+
):
307+
dataset = await client.aio.datasets.create_from_bigquery(
308+
bigquery_uri=f"bq://{BIGQUERY_TABLE_NAME}",
309+
multimodal_dataset={
310+
"display_name": "test-from-bigquery-uri",
311+
"metadata": {
312+
"gemini_request_read_config": {
313+
"assembled_request_column_name": "test_column"
314+
}
315+
},
316+
},
317+
)
318+
assert isinstance(dataset, types.MultimodalDataset)
319+
assert dataset.display_name == "test-from-bigquery-uri"
320+
assert (
321+
dataset.metadata.gemini_request_read_config.assembled_request_column_name
322+
== "test_column"
323+
)
324+
assert dataset.metadata.input_config.bigquery_source.uri == (
325+
f"bq://{BIGQUERY_TABLE_NAME}"
326+
)
327+
328+
257329
@pytest.mark.asyncio
258330
@pytest.mark.usefixtures("mock_generate_multimodal_dataset_display_name")
259331
async def test_create_dataset_from_bigquery_no_display_name_async(client):

vertexai/_genai/_datasets_utils.py

Lines changed: 19 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121

2222
import google.auth.credentials
2323
from vertexai._genai.types import common
24-
from pydantic import BaseModel
24+
from google.genai import _common
2525

2626

2727
METADATA_SCHEMA_URI = (
@@ -31,18 +31,27 @@
3131
_DEFAULT_BQ_DATASET_PREFIX = "vertex_datasets"
3232
_DEFAULT_BQ_TABLE_PREFIX = "multimodal_dataset"
3333

34-
T = TypeVar("T", bound=BaseModel)
34+
T = TypeVar("T", bound=_common.BaseModel)
3535

3636

37-
def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T:
37+
def create_from_response(
38+
model_type: Type[T],
39+
response: dict[str, Any],
40+
config: Any | None = None,
41+
) -> T:
3842
"""Creates a model from a response."""
39-
model_field_names = model_type.model_fields.keys()
40-
filtered_response = {}
41-
for key, value in response.items():
42-
snake_key = common.camel_to_snake(key)
43-
if snake_key in model_field_names:
44-
filtered_response[snake_key] = value
45-
return model_type(**filtered_response)
43+
kwargs = (
44+
{
45+
"config": {
46+
"response_schema": getattr(config, "response_schema", None),
47+
"response_json_schema": getattr(config, "response_json_schema", None),
48+
"include_all_fields": getattr(config, "include_all_fields", None),
49+
}
50+
}
51+
if config
52+
else {}
53+
)
54+
return model_type._from_response(response=response, kwargs=kwargs)
4655

4756

4857
def validate_multimodal_dataset_bigquery_uri(

vertexai/_genai/datasets.py

Lines changed: 42 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -924,23 +924,34 @@ def _wait_for_operation(
924924
def create_from_bigquery(
925925
self,
926926
*,
927-
multimodal_dataset: types.MultimodalDatasetOrDict,
927+
bigquery_uri: Optional[str] = None,
928+
multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
928929
config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
929930
) -> types.MultimodalDataset:
930931
"""Creates a multimodal dataset from a BigQuery table.
931932
932933
Args:
934+
bigquery_uri:
935+
Optional. The BigQuery URI of the table to create the dataset from.
936+
e.g. "bq://project.dataset.table".
933937
multimodal_dataset:
934-
Required. A representation of a multimodal dataset.
938+
Optional. A representation of a multimodal dataset.
935939
config:
936940
Optional. A configuration for creating the multimodal dataset. If not
937941
provided, the default configuration will be used.
938942
939943
Returns:
940944
A types.MultimodalDataset object representing a multimodal dataset.
941945
"""
942-
if isinstance(multimodal_dataset, dict):
946+
if multimodal_dataset is None:
947+
multimodal_dataset = types.MultimodalDataset()
948+
elif isinstance(multimodal_dataset, dict):
943949
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
950+
951+
if bigquery_uri:
952+
multimodal_dataset = multimodal_dataset.model_copy(deep=True)
953+
multimodal_dataset.set_bigquery_uri(bigquery_uri)
954+
944955
_datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset)
945956

946957
if isinstance(config, dict):
@@ -963,7 +974,9 @@ def create_from_bigquery(
963974
operation=multimodal_dataset_operation,
964975
timeout_seconds=config.timeout,
965976
)
966-
return _datasets_utils.create_from_response(types.MultimodalDataset, response)
977+
return _datasets_utils.create_from_response(
978+
types.MultimodalDataset, response, config
979+
)
967980

968981
def create_from_pandas(
969982
self,
@@ -1302,6 +1315,7 @@ def assess_tuning_resources(
13021315
return _datasets_utils.create_from_response(
13031316
types.TuningResourceUsageAssessmentResult,
13041317
response["tuningResourceUsageAssessmentResult"],
1318+
config,
13051319
)
13061320

13071321
def assess_tuning_validity(
@@ -1368,6 +1382,7 @@ def assess_tuning_validity(
13681382
return _datasets_utils.create_from_response(
13691383
types.TuningValidationAssessmentResult,
13701384
response["tuningValidationAssessmentResult"],
1385+
config,
13711386
)
13721387

13731388
def assess_batch_prediction_resources(
@@ -1430,7 +1445,7 @@ def assess_batch_prediction_resources(
14301445
)
14311446
result = response["batchPredictionResourceUsageAssessmentResult"]
14321447
return _datasets_utils.create_from_response(
1433-
types.BatchPredictionResourceUsageAssessmentResult, result
1448+
types.BatchPredictionResourceUsageAssessmentResult, result, config
14341449
)
14351450

14361451
def assess_batch_prediction_validity(
@@ -1493,7 +1508,7 @@ def assess_batch_prediction_validity(
14931508
)
14941509
result = response["batchPredictionValidationAssessmentResult"]
14951510
return _datasets_utils.create_from_response(
1496-
types.BatchPredictionValidationAssessmentResult, result
1511+
types.BatchPredictionValidationAssessmentResult, result, config
14971512
)
14981513

14991514

@@ -2192,23 +2207,34 @@ async def _wait_for_operation(
21922207
async def create_from_bigquery(
21932208
self,
21942209
*,
2195-
multimodal_dataset: types.MultimodalDatasetOrDict,
2210+
bigquery_uri: Optional[str] = None,
2211+
multimodal_dataset: Optional[types.MultimodalDatasetOrDict] = None,
21962212
config: Optional[types.CreateMultimodalDatasetConfigOrDict] = None,
21972213
) -> types.MultimodalDataset:
21982214
"""Creates a multimodal dataset from a BigQuery table.
21992215
22002216
Args:
2217+
bigquery_uri:
2218+
Optional. The BigQuery URI of the table to create the dataset from.
2219+
e.g. "bq://project.dataset.table".
22012220
multimodal_dataset:
2202-
Required. A representation of a multimodal dataset.
2221+
Optional. A representation of a multimodal dataset.
22032222
config:
22042223
Optional. A configuration for creating the multimodal dataset. If not
22052224
provided, the default configuration will be used.
22062225
22072226
Returns:
22082227
A types.MultimodalDataset object representing a multimodal dataset.
22092228
"""
2210-
if isinstance(multimodal_dataset, dict):
2229+
if multimodal_dataset is None:
2230+
multimodal_dataset = types.MultimodalDataset()
2231+
elif isinstance(multimodal_dataset, dict):
22112232
multimodal_dataset = types.MultimodalDataset(**multimodal_dataset)
2233+
2234+
if bigquery_uri:
2235+
multimodal_dataset = multimodal_dataset.model_copy(deep=True)
2236+
multimodal_dataset.set_bigquery_uri(bigquery_uri)
2237+
22122238
_datasets_utils.validate_multimodal_dataset_bigquery_uri(multimodal_dataset)
22132239

22142240
if isinstance(config, dict):
@@ -2231,7 +2257,9 @@ async def create_from_bigquery(
22312257
operation=multimodal_dataset_operation,
22322258
timeout_seconds=config.timeout,
22332259
)
2234-
return _datasets_utils.create_from_response(types.MultimodalDataset, response)
2260+
return _datasets_utils.create_from_response(
2261+
types.MultimodalDataset, response, config
2262+
)
22352263

22362264
async def create_from_pandas(
22372265
self,
@@ -2568,6 +2596,7 @@ async def assess_tuning_resources(
25682596
return _datasets_utils.create_from_response(
25692597
types.TuningResourceUsageAssessmentResult,
25702598
response["tuningResourceUsageAssessmentResult"],
2599+
config,
25712600
)
25722601

25732602
async def assess_tuning_validity(
@@ -2634,6 +2663,7 @@ async def assess_tuning_validity(
26342663
return _datasets_utils.create_from_response(
26352664
types.TuningValidationAssessmentResult,
26362665
response["tuningValidationAssessmentResult"],
2666+
config,
26372667
)
26382668

26392669
async def assess_batch_prediction_resources(
@@ -2696,7 +2726,7 @@ async def assess_batch_prediction_resources(
26962726
)
26972727
result = response["batchPredictionResourceUsageAssessmentResult"]
26982728
return _datasets_utils.create_from_response(
2699-
types.BatchPredictionResourceUsageAssessmentResult, result
2729+
types.BatchPredictionResourceUsageAssessmentResult, result, config
27002730
)
27012731

27022732
async def assess_batch_prediction_validity(
@@ -2759,5 +2789,5 @@ async def assess_batch_prediction_validity(
27592789
)
27602790
result = response["batchPredictionValidationAssessmentResult"]
27612791
return _datasets_utils.create_from_response(
2762-
types.BatchPredictionValidationAssessmentResult, result
2792+
types.BatchPredictionValidationAssessmentResult, result, config
27632793
)

0 commit comments

Comments
 (0)