Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 943f6ac

Browse files
committed
fix tests
1 parent e2e21d4 commit 943f6ac

File tree

3 files changed

+141
-11
lines changed

3 files changed

+141
-11
lines changed

bigframes/bigquery/_operations/ai.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,7 @@ def generate_embedding(
454454
<https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-ai-generate-embedding#output>`_
455455
for details.
456456
"""
457+
data = _to_dataframe(data, series_rename="content")
457458
model_name, session = ml_utils.get_model_name_and_session(model, data)
458459
table_sql = ml_utils.to_sql(data)
459460

@@ -562,6 +563,7 @@ def generate_text(
562563
bigframes.pandas.DataFrame:
563564
The generated text.
564565
"""
566+
data = _to_dataframe(data, series_rename="prompt")
565567
model_name, session = ml_utils.get_model_name_and_session(model, data)
566568
table_sql = ml_utils.to_sql(data)
567569

@@ -914,3 +916,20 @@ def _resolve_connection_id(series: series.Series, connection_id: str | None):
914916
series._session._project,
915917
series._session._location,
916918
)
919+
920+
921+
def _to_dataframe(
922+
data: Union[dataframe.DataFrame, series.Series, pd.DataFrame, pd.Series],
923+
series_rename: str,
924+
) -> dataframe.DataFrame:
925+
if isinstance(data, (pd.DataFrame, pd.Series)):
926+
data = bpd.read_pandas(data)
927+
928+
if isinstance(data, series.Series):
929+
data = data.copy()
930+
data.name = series_rename
931+
return data.to_frame()
932+
elif isinstance(data, dataframe.DataFrame):
933+
return data
934+
935+
raise ValueError(f"Unsupported data type: {type(data)}")

bigframes/bigquery/_operations/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
import pandas as pd
1818

1919
import bigframes
20-
from bigframes import dataframe
20+
from bigframes import dataframe, series
21+
import bigframes.pandas as bpd
2122

2223

2324
def get_model_name_and_session(

tests/unit/bigquery/test_ai.py

Lines changed: 120 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,41 @@ def mock_dataframe(mock_session):
3333
df = mock.create_autospec(spec=bigframes.dataframe.DataFrame)
3434
df._session = mock_session
3535
df.sql = "SELECT * FROM my_table"
36+
df._to_sql_query.return_value = ("SELECT * FROM my_table", None, None)
3637
return df
3738

3839

3940
@pytest.fixture
40-
def mock_series(mock_session):
41+
def mock_embedding_series(mock_session):
4142
series = mock.create_autospec(spec=bigframes.series.Series)
4243
series._session = mock_session
4344
# Mock to_frame to return a mock dataframe
4445
df = mock.create_autospec(spec=bigframes.dataframe.DataFrame)
4546
df._session = mock_session
4647
df.sql = "SELECT my_col AS content FROM my_table"
48+
df._to_sql_query.return_value = (
49+
"SELECT my_col AS content FROM my_table",
50+
None,
51+
None,
52+
)
53+
series.copy.return_value = series
54+
series.to_frame.return_value = df
55+
return series
56+
57+
58+
@pytest.fixture
59+
def mock_text_series(mock_session):
60+
series = mock.create_autospec(spec=bigframes.series.Series)
61+
series._session = mock_session
62+
# Mock to_frame to return a mock dataframe
63+
df = mock.create_autospec(spec=bigframes.dataframe.DataFrame)
64+
df._session = mock_session
65+
df.sql = "SELECT my_col AS prompt FROM my_table"
66+
df._to_sql_query.return_value = (
67+
"SELECT my_col AS prompt FROM my_table",
68+
None,
69+
None,
70+
)
4771
series.copy.return_value = series
4872
series.to_frame.return_value = df
4973
return series
@@ -58,8 +82,8 @@ def test_generate_embedding_with_dataframe(mock_dataframe, mock_session):
5882
output_dimensionality=256,
5983
)
6084

61-
mock_session.read_gbq.assert_called_once()
62-
query = mock_session.read_gbq.call_args[0][0]
85+
mock_session.read_gbq_query.assert_called_once()
86+
query = mock_session.read_gbq_query.call_args[0][0]
6387

6488
# Normalize whitespace for comparison
6589
query = " ".join(query.split())
@@ -75,15 +99,19 @@ def test_generate_embedding_with_dataframe(mock_dataframe, mock_session):
7599
assert expected_part_4 in query
76100

77101

78-
def test_generate_embedding_with_series(mock_series, mock_session):
102+
def test_generate_embedding_with_series(mock_embedding_series, mock_session):
79103
model_name = "project.dataset.model"
80104

81105
bbq.ai.generate_embedding(
82-
model_name, mock_series, start_second=0.0, end_second=10.0, interval_seconds=5.0
106+
model_name,
107+
mock_embedding_series,
108+
start_second=0.0,
109+
end_second=10.0,
110+
interval_seconds=5.0,
83111
)
84112

85-
mock_session.read_gbq.assert_called_once()
86-
query = mock_session.read_gbq.call_args[0][0]
113+
mock_session.read_gbq_query.assert_called_once()
114+
query = mock_session.read_gbq_query.call_args[0][0]
87115
query = " ".join(query.split())
88116

89117
assert f"MODEL `{model_name}`" in query
@@ -102,8 +130,8 @@ def test_generate_embedding_defaults(mock_dataframe, mock_session):
102130
mock_dataframe,
103131
)
104132

105-
mock_session.read_gbq.assert_called_once()
106-
query = mock_session.read_gbq.call_args[0][0]
133+
mock_session.read_gbq_query.assert_called_once()
134+
query = mock_session.read_gbq_query.call_args[0][0]
107135
query = " ".join(query.split())
108136

109137
assert f"MODEL `{model_name}`" in query
@@ -131,4 +159,86 @@ def test_generate_embedding_with_pandas_dataframe(
131159
# Check that read_pandas was called with something (the pandas df)
132160
assert read_pandas_mock.call_args[0][0] is pandas_df
133161

134-
mock_session.read_gbq.assert_called_once()
162+
mock_session.read_gbq_query.assert_called_once()
163+
164+
165+
def test_generate_text_with_dataframe(mock_dataframe, mock_session):
166+
model_name = "project.dataset.model"
167+
168+
bbq.ai.generate_text(
169+
model_name,
170+
mock_dataframe,
171+
max_output_tokens=256,
172+
)
173+
174+
mock_session.read_gbq_query.assert_called_once()
175+
query = mock_session.read_gbq_query.call_args[0][0]
176+
177+
# Normalize whitespace for comparison
178+
query = " ".join(query.split())
179+
180+
expected_part_1 = "SELECT * FROM AI.GENERATE_TEXT("
181+
expected_part_2 = f"MODEL `{model_name}`,"
182+
expected_part_3 = "(SELECT * FROM my_table),"
183+
expected_part_4 = "STRUCT(256 AS MAX_OUTPUT_TOKENS)"
184+
185+
assert expected_part_1 in query
186+
assert expected_part_2 in query
187+
assert expected_part_3 in query
188+
assert expected_part_4 in query
189+
190+
191+
def test_generate_text_with_series(mock_text_series, mock_session):
192+
model_name = "project.dataset.model"
193+
194+
bbq.ai.generate_text(
195+
model_name,
196+
mock_text_series,
197+
)
198+
199+
mock_session.read_gbq_query.assert_called_once()
200+
query = mock_session.read_gbq_query.call_args[0][0]
201+
query = " ".join(query.split())
202+
203+
assert f"MODEL `{model_name}`" in query
204+
assert "(SELECT my_col AS prompt FROM my_table)" in query
205+
206+
207+
def test_generate_text_defaults(mock_dataframe, mock_session):
208+
model_name = "project.dataset.model"
209+
210+
bbq.ai.generate_text(
211+
model_name,
212+
mock_dataframe,
213+
)
214+
215+
mock_session.read_gbq_query.assert_called_once()
216+
query = mock_session.read_gbq_query.call_args[0][0]
217+
query = " ".join(query.split())
218+
219+
assert f"MODEL `{model_name}`" in query
220+
assert "STRUCT()" in query
221+
222+
223+
@mock.patch("bigframes.pandas.read_pandas")
224+
def test_generate_text_with_pandas_dataframe(
225+
read_pandas_mock, mock_dataframe, mock_session
226+
):
227+
# This tests that pandas input path works and calls read_pandas
228+
model_name = "project.dataset.model"
229+
230+
# Mock return value of read_pandas to be a BigFrames DataFrame
231+
read_pandas_mock.return_value = mock_dataframe
232+
233+
pandas_df = pd.DataFrame({"content": ["test"]})
234+
235+
bbq.ai.generate_text(
236+
model_name,
237+
pandas_df,
238+
)
239+
240+
read_pandas_mock.assert_called_once()
241+
# Check that read_pandas was called with something (the pandas df)
242+
assert read_pandas_mock.call_args[0][0] is pandas_df
243+
244+
mock_session.read_gbq_query.assert_called_once()

0 commit comments

Comments
 (0)