@@ -33,17 +33,41 @@ def mock_dataframe(mock_session):
3333 df = mock .create_autospec (spec = bigframes .dataframe .DataFrame )
3434 df ._session = mock_session
3535 df .sql = "SELECT * FROM my_table"
36+ df ._to_sql_query .return_value = ("SELECT * FROM my_table" , None , None )
3637 return df
3738
3839
def _make_mock_series(session, sql):
    """Build an autospecced Series whose ``to_frame()`` yields a mock DataFrame.

    Args:
        session: Session mock stored as ``_session`` on both the series and
            the frame it converts to.
        sql: SQL text exposed as ``df.sql`` and as the first element of the
            3-tuple returned by ``df._to_sql_query()`` (the other two
            elements are ``None``).

    Returns:
        The mock series. ``copy()`` returns the series itself and
        ``to_frame()`` returns the mock frame described above.
    """
    series = mock.create_autospec(spec=bigframes.series.Series)
    series._session = session
    # Mock to_frame to return a mock dataframe
    df = mock.create_autospec(spec=bigframes.dataframe.DataFrame)
    df._session = session
    df.sql = sql
    df._to_sql_query.return_value = (sql, None, None)
    series.copy.return_value = series
    series.to_frame.return_value = df
    return series


@pytest.fixture
def mock_embedding_series(mock_session):
    """Mock Series whose frame SQL aliases ``my_col`` as ``content``."""
    return _make_mock_series(
        mock_session, "SELECT my_col AS content FROM my_table"
    )


@pytest.fixture
def mock_text_series(mock_session):
    """Mock Series whose frame SQL aliases ``my_col`` as ``prompt``."""
    return _make_mock_series(
        mock_session, "SELECT my_col AS prompt FROM my_table"
    )
@@ -58,8 +82,8 @@ def test_generate_embedding_with_dataframe(mock_dataframe, mock_session):
5882 output_dimensionality = 256 ,
5983 )
6084
61- mock_session .read_gbq .assert_called_once ()
62- query = mock_session .read_gbq .call_args [0 ][0 ]
85+ mock_session .read_gbq_query .assert_called_once ()
86+ query = mock_session .read_gbq_query .call_args [0 ][0 ]
6387
6488 # Normalize whitespace for comparison
6589 query = " " .join (query .split ())
@@ -75,15 +99,19 @@ def test_generate_embedding_with_dataframe(mock_dataframe, mock_session):
7599 assert expected_part_4 in query
76100
77101
78- def test_generate_embedding_with_series (mock_series , mock_session ):
102+ def test_generate_embedding_with_series (mock_embedding_series , mock_session ):
79103 model_name = "project.dataset.model"
80104
81105 bbq .ai .generate_embedding (
82- model_name , mock_series , start_second = 0.0 , end_second = 10.0 , interval_seconds = 5.0
106+ model_name ,
107+ mock_embedding_series ,
108+ start_second = 0.0 ,
109+ end_second = 10.0 ,
110+ interval_seconds = 5.0 ,
83111 )
84112
85- mock_session .read_gbq .assert_called_once ()
86- query = mock_session .read_gbq .call_args [0 ][0 ]
113+ mock_session .read_gbq_query .assert_called_once ()
114+ query = mock_session .read_gbq_query .call_args [0 ][0 ]
87115 query = " " .join (query .split ())
88116
89117 assert f"MODEL `{ model_name } `" in query
@@ -102,8 +130,8 @@ def test_generate_embedding_defaults(mock_dataframe, mock_session):
102130 mock_dataframe ,
103131 )
104132
105- mock_session .read_gbq .assert_called_once ()
106- query = mock_session .read_gbq .call_args [0 ][0 ]
133+ mock_session .read_gbq_query .assert_called_once ()
134+ query = mock_session .read_gbq_query .call_args [0 ][0 ]
107135 query = " " .join (query .split ())
108136
109137 assert f"MODEL `{ model_name } `" in query
@@ -131,4 +159,86 @@ def test_generate_embedding_with_pandas_dataframe(
131159 # Check that read_pandas was called with something (the pandas df)
132160 assert read_pandas_mock .call_args [0 ][0 ] is pandas_df
133161
134- mock_session .read_gbq .assert_called_once ()
162+ mock_session .read_gbq_query .assert_called_once ()
163+
164+
def test_generate_text_with_dataframe(mock_dataframe, mock_session):
    """generate_text on a DataFrame issues exactly one query with the expected pieces."""
    model_name = "project.dataset.model"

    bbq.ai.generate_text(model_name, mock_dataframe, max_output_tokens=256)

    read_query = mock_session.read_gbq_query
    read_query.assert_called_once()
    positional, _ = read_query.call_args

    # Collapse every whitespace run to a single space so SQL layout is irrelevant.
    normalized = " ".join(positional[0].split())

    fragments = (
        "SELECT * FROM AI.GENERATE_TEXT(",
        f"MODEL `{model_name}`,",
        "(SELECT * FROM my_table),",
        "STRUCT(256 AS MAX_OUTPUT_TOKENS)",
    )
    for fragment in fragments:
        assert fragment in normalized
189+
190+
def test_generate_text_with_series(mock_text_series, mock_session):
    """generate_text on a Series queries the SQL of the frame the series converts to."""
    model_name = "project.dataset.model"

    bbq.ai.generate_text(model_name, mock_text_series)

    read_query = mock_session.read_gbq_query
    read_query.assert_called_once()
    raw_sql = read_query.call_args[0][0]
    # Whitespace-normalize before substring checks.
    normalized = " ".join(raw_sql.split())

    assert f"MODEL `{model_name}`" in normalized
    assert "(SELECT my_col AS prompt FROM my_table)" in normalized
205+
206+
def test_generate_text_defaults(mock_dataframe, mock_session):
    """With no options passed, the generated query carries an empty STRUCT()."""
    model_name = "project.dataset.model"

    bbq.ai.generate_text(model_name, mock_dataframe)

    mock_session.read_gbq_query.assert_called_once()
    call_positional, _ = mock_session.read_gbq_query.call_args
    tokens = call_positional[0].split()
    normalized = " ".join(tokens)

    for expected in (f"MODEL `{model_name}`", "STRUCT()"):
        assert expected in normalized
221+
222+
@mock.patch("bigframes.pandas.read_pandas")
def test_generate_text_with_pandas_dataframe(
    read_pandas_mock, mock_dataframe, mock_session
):
    """A pandas input is converted via read_pandas and the converted frame is queried."""
    model_name = "project.dataset.model"

    # Mock return value of read_pandas to be a BigFrames DataFrame
    read_pandas_mock.return_value = mock_dataframe

    # Text generation input uses a `prompt` column (as mock_text_series does);
    # `content` is the embedding input name.
    pandas_df = pd.DataFrame({"prompt": ["test"]})

    bbq.ai.generate_text(
        model_name,
        pandas_df,
    )

    read_pandas_mock.assert_called_once()
    # Check that read_pandas was called with the very same pandas object
    assert read_pandas_mock.call_args[0][0] is pandas_df

    mock_session.read_gbq_query.assert_called_once()
    # The query must be built from the frame read_pandas returned, not merely
    # "some" query: mock_dataframe's SQL has to appear as the source subquery.
    query = " ".join(mock_session.read_gbq_query.call_args[0][0].split())
    assert "(SELECT * FROM my_table)" in query
0 commit comments