Skip to content

Commit 216c055

Browse files
cleop-googlecopybara-github
authored and committed
fix: GenAI SDK client(multimodal) - preserve existing metadata when creating from bigframes
PiperOrigin-RevId: 901268958
1 parent f2d73fd commit 216c055

4 files changed

Lines changed: 293 additions & 57 deletions

File tree

tests/unit/vertexai/genai/replays/test_create_multimodal_datasets.py

Lines changed: 81 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -210,6 +210,46 @@ def test_create_dataset_from_bigframes(client, is_replay_mode):
210210
)
211211

212212

213+
@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="bigframes requires python 3.10 or higher"
)
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
def test_create_dataset_from_bigframes_preserves_other_metadata(client, is_replay_mode):
    """Check caller-supplied metadata is kept by ``create_from_bigframes``.

    The dataset is created with a ``gemini_request_read_config`` already set
    in its metadata. The returned dataset must still carry that value and
    must also point its BigQuery source at ``BIGQUERY_TABLE_NAME``.
    """
    import bigframes.pandas

    dataframe = pd.DataFrame({"col1": ["col1"], "col2": ["col2"]})

    if not is_replay_mode:
        bf_dataframe = bigframes.pandas.DataFrame(dataframe)
    else:
        # Replay mode uses a stand-in frame whose upload reports a fixed
        # temporary table id instead of building a real BigFrames object.
        bf_dataframe = mock.MagicMock()
        bf_dataframe.to_gbq.return_value = "temp_table_id"

    read_config = {"assembled_request_column_name": "test_column"}
    requested_dataset = {
        "display_name": "test-from-bigframes",
        "metadata": {"gemini_request_read_config": read_config},
    }

    dataset = client.datasets.create_from_bigframes(
        dataframe=bf_dataframe,
        target_table_id=BIGQUERY_TABLE_NAME,
        multimodal_dataset=requested_dataset,
    )

    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.display_name == "test-from-bigframes"
    # The metadata supplied by the caller must be present on the result...
    assert (
        dataset.metadata.gemini_request_read_config.assembled_request_column_name
        == "test_column"
    )
    # ...together with the BigQuery source for the target table.
    expected_uri = f"bq://{BIGQUERY_TABLE_NAME}"
    assert dataset.metadata.input_config.bigquery_source.uri == expected_uri
251+
252+
213253
pytestmark = pytest_helper.setup(
214254
file=__file__,
215255
globals_for_file=globals(),
@@ -371,3 +411,44 @@ async def test_create_dataset_from_bigframes_async(client, is_replay_mode):
371411
pd.testing.assert_frame_equal(
372412
rows.to_dataframe(), dataframe, check_index_type=False
373413
)
414+
415+
416+
@pytest.mark.skipif(
    sys.version_info < (3, 10), reason="bigframes requires python 3.10 or higher"
)
@pytest.mark.asyncio
@pytest.mark.usefixtures("mock_bigquery_client", "mock_import_bigframes")
async def test_create_dataset_from_bigframes_preserves_other_metadata_async(client, is_replay_mode):
    """Async check that ``create_from_bigframes`` keeps supplied metadata.

    Calls ``client.aio.datasets.create_from_bigframes`` with a dataset whose
    metadata already has a ``gemini_request_read_config``. Asserts that the
    returned dataset keeps it and that its BigQuery source URI targets
    ``BIGQUERY_TABLE_NAME``.
    """
    import bigframes.pandas

    columns = {
        "col1": ["col1"],
        "col2": ["col2"],
    }
    dataframe = pd.DataFrame(columns)

    if not is_replay_mode:
        bf_dataframe = bigframes.pandas.DataFrame(dataframe)
    else:
        # In replay mode a mock frame is used; its upload call returns a
        # fixed temporary table id.
        bf_dataframe = mock.MagicMock()
        bf_dataframe.to_gbq.return_value = "temp_table_id"

    dataset_spec = {
        "display_name": "test-from-bigframes",
        "metadata": {
            "gemini_request_read_config": {
                "assembled_request_column_name": "test_column",
            },
        },
    }

    dataset = await client.aio.datasets.create_from_bigframes(
        dataframe=bf_dataframe,
        target_table_id=BIGQUERY_TABLE_NAME,
        multimodal_dataset=dataset_spec,
    )

    assert isinstance(dataset, types.MultimodalDataset)
    assert dataset.display_name == "test-from-bigframes"
    # Caller-provided read config must survive dataset creation.
    assert (
        dataset.metadata.gemini_request_read_config.assembled_request_column_name
        == "test_column"
    )
    # The input config must reference the requested BigQuery table.
    expected_uri = f"bq://{BIGQUERY_TABLE_NAME}"
    assert dataset.metadata.input_config.bigquery_source.uri == expected_uri

vertexai/_genai/_datasets_utils.py

Lines changed: 0 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -34,17 +34,6 @@
3434
T = TypeVar("T", bound=BaseModel)
3535

3636

37-
def create_from_response(model_type: Type[T], response: dict[str, Any]) -> T:
    """Build a ``model_type`` instance from a response mapping.

    Every key in ``response`` is converted with ``common.camel_to_snake``.
    Entries whose converted key is one of ``model_type.model_fields`` are
    passed to the ``model_type`` constructor as keyword arguments; all other
    entries are dropped.

    Args:
        model_type: The model class to instantiate.
        response: The response mapping to read values from.

    Returns:
        An instance of ``model_type`` built from the matching entries.
    """
    known_fields = model_type.model_fields.keys()
    # Convert each key once, then keep only the pairs the model declares.
    converted_items = (
        (common.camel_to_snake(raw_key), payload)
        for raw_key, payload in response.items()
    )
    constructor_kwargs = {
        field_name: payload
        for field_name, payload in converted_items
        if field_name in known_fields
    }
    return model_type(**constructor_kwargs)
46-
47-
4837
def validate_multimodal_dataset_bigquery_uri(
4938
multimodal_dataset: common.MultimodalDataset,
5039
) -> None:

0 commit comments

Comments
 (0)