Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 1696ddd

Browse files
committed
fix unit tests
1 parent a3ee429 commit 1696ddd

File tree

11 files changed

+46
-33
lines changed

11 files changed

+46
-33
lines changed

bigframes/dtypes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -724,10 +724,6 @@ def infer_literal_type(literal) -> typing.Optional[Dtype]:
724724
# Maybe also normalize literal to canonical python representation to remove this burden from compilers?
725725
if isinstance(literal, pa.Scalar):
726726
return arrow_dtype_to_bigframes_dtype(literal.type)
727-
if pd.api.types.is_list_like(literal):
728-
element_types = [infer_literal_type(i) for i in literal]
729-
common_type = lcd_type(*element_types)
730-
return list_type(common_type)
731727
if pd.api.types.is_dict_like(literal):
732728
fields = []
733729
for key in literal.keys():
@@ -738,6 +734,10 @@ def infer_literal_type(literal) -> typing.Optional[Dtype]:
738734
pa.field(key, field_type, nullable=(not pa.types.is_list(field_type)))
739735
)
740736
return pd.ArrowDtype(pa.struct(fields))
737+
if pd.api.types.is_list_like(literal):
738+
element_types = [infer_literal_type(i) for i in literal]
739+
common_type = lcd_type(*element_types)
740+
return list_type(common_type)
741741
if pd.isna(literal):
742742
return None # Null value without a definite type
743743
# Make sure to check datetime before date as datetimes are also dates

tests/unit/bigquery/test_ai.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_generate_embedding_with_dataframe(mock_dataframe, mock_session):
9191
expected_part_1 = "SELECT * FROM AI.GENERATE_EMBEDDING("
9292
expected_part_2 = f"MODEL `{model_name}`,"
9393
expected_part_3 = "(SELECT * FROM my_table),"
94-
expected_part_4 = "STRUCT(256 AS OUTPUT_DIMENSIONALITY)"
94+
expected_part_4 = "STRUCT(256 AS `OUTPUT_DIMENSIONALITY`)"
9595

9696
assert expected_part_1 in query
9797
assert expected_part_2 in query
@@ -117,7 +117,7 @@ def test_generate_embedding_with_series(mock_embedding_series, mock_session):
117117
assert f"MODEL `{model_name}`" in query
118118
assert "(SELECT my_col AS content FROM my_table)" in query
119119
assert (
120-
"STRUCT(0.0 AS START_SECOND, 10.0 AS END_SECOND, 5.0 AS INTERVAL_SECONDS)"
120+
"STRUCT(0.0 AS `START_SECOND`, 10.0 AS `END_SECOND`, 5.0 AS `INTERVAL_SECONDS`)"
121121
in query
122122
)
123123

@@ -180,7 +180,7 @@ def test_generate_text_with_dataframe(mock_dataframe, mock_session):
180180
expected_part_1 = "SELECT * FROM AI.GENERATE_TEXT("
181181
expected_part_2 = f"MODEL `{model_name}`,"
182182
expected_part_3 = "(SELECT * FROM my_table),"
183-
expected_part_4 = "STRUCT(256 AS MAX_OUTPUT_TOKENS)"
183+
expected_part_4 = "STRUCT(256 AS `MAX_OUTPUT_TOKENS`)"
184184

185185
assert expected_part_1 in query
186186
assert expected_part_2 in query
@@ -238,7 +238,7 @@ def test_generate_table_with_dataframe(mock_dataframe, mock_session):
238238
expected_part_1 = "SELECT * FROM AI.GENERATE_TABLE("
239239
expected_part_2 = f"MODEL `{model_name}`,"
240240
expected_part_3 = "(SELECT * FROM my_table),"
241-
expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS output_schema)"
241+
expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS `output_schema`)"
242242

243243
assert expected_part_1 in query
244244
assert expected_part_2 in query
@@ -264,7 +264,7 @@ def test_generate_table_with_options(mock_dataframe, mock_session):
264264
assert f"MODEL `{model_name}`" in query
265265
assert "(SELECT * FROM my_table)" in query
266266
assert (
267-
"STRUCT('col1 STRING' AS output_schema, 0.5 AS temperature, 100 AS max_output_tokens)"
267+
"STRUCT('col1 STRING' AS `output_schema`, 0.5 AS `temperature`, 100 AS `max_output_tokens`)"
268268
in query
269269
)
270270

@@ -287,7 +287,7 @@ def test_generate_table_with_mapping_schema(mock_dataframe, mock_session):
287287
expected_part_1 = "SELECT * FROM AI.GENERATE_TABLE("
288288
expected_part_2 = f"MODEL `{model_name}`,"
289289
expected_part_3 = "(SELECT * FROM my_table),"
290-
expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS output_schema)"
290+
expected_part_4 = "STRUCT('col1 STRING, col2 INT64' AS `output_schema`)"
291291

292292
assert expected_part_1 in query
293293
assert expected_part_2 in query

tests/unit/bigquery/test_ml.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -167,14 +167,14 @@ def test_generate_text_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mo
167167
assert "ML.GENERATE_TEXT" in generated_sql
168168
assert f"MODEL `{MODEL_NAME}`" in generated_sql
169169
assert "(SELECT * FROM `pandas_df`)" in generated_sql
170-
assert "STRUCT(0.5 AS temperature" in generated_sql
171-
assert "128 AS max_output_tokens" in generated_sql
172-
assert "20 AS top_k" in generated_sql
173-
assert "0.9 AS top_p" in generated_sql
174-
assert "true AS flatten_json_output" in generated_sql
175-
assert "['a', 'b'] AS stop_sequences" in generated_sql
176-
assert "true AS ground_with_google_search" in generated_sql
177-
assert "'TYPE' AS request_type" in generated_sql
170+
assert "STRUCT(\n 0.5 AS `temperature`" in generated_sql
171+
assert "128 AS `max_output_tokens`" in generated_sql
172+
assert "20 AS `top_k`" in generated_sql
173+
assert "0.9 AS `top_p`" in generated_sql
174+
assert "TRUE AS `flatten_json_output`" in generated_sql
175+
assert "['a', 'b'] AS `stop_sequences`" in generated_sql
176+
assert "TRUE AS `ground_with_google_search`" in generated_sql
177+
assert "'TYPE' AS `request_type`" in generated_sql
178178

179179

180180
@mock.patch("bigframes.pandas.read_gbq_query")
@@ -210,6 +210,6 @@ def test_generate_embedding_with_pandas_dataframe(
210210
assert "ML.GENERATE_EMBEDDING" in generated_sql
211211
assert f"MODEL `{MODEL_NAME}`" in generated_sql
212212
assert "(SELECT * FROM `pandas_df`)" in generated_sql
213-
assert "true AS flatten_json_output" in generated_sql
214-
assert "'RETRIEVAL_DOCUMENT' AS task_type" in generated_sql
215-
assert "256 AS output_dimensionality" in generated_sql
213+
assert "STRUCT(\n TRUE AS `flatten_json_output`" in generated_sql
214+
assert "'RETRIEVAL_DOCUMENT' AS `task_type`" in generated_sql
215+
assert "256 AS `output_dimensionality`" in generated_sql
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(false AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level))
1+
SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(FALSE AS `perform_aggregation`, 10 AS `horizon`, 0.95 AS `confidence_level`))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(5 AS top_k_features))
1+
SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(5 AS `top_k_features`))
Original file line numberDiff line numberDiff line change
@@ -1 +1,5 @@
1-
SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(true AS flatten_json_output, 'RETRIEVAL_DOCUMENT' AS task_type, 256 AS output_dimensionality))
1+
SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(
2+
TRUE AS `flatten_json_output`,
3+
'RETRIEVAL_DOCUMENT' AS `task_type`,
4+
256 AS `output_dimensionality`
5+
))
Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,10 @@
1-
SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(0.5 AS temperature, 128 AS max_output_tokens, 20 AS top_k, 0.9 AS top_p, true AS flatten_json_output, ['a', 'b'] AS stop_sequences, true AS ground_with_google_search, 'TYPE' AS request_type))
1+
SELECT * FROM ML.GENERATE_TEXT(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(
2+
0.5 AS `temperature`,
3+
128 AS `max_output_tokens`,
4+
20 AS `top_k`,
5+
0.9 AS `top_p`,
6+
TRUE AS `flatten_json_output`,
7+
['a', 'b'] AS `stop_sequences`,
8+
TRUE AS `ground_with_google_search`,
9+
'TYPE' AS `request_type`
10+
))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(true AS class_level_explain))
1+
SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(TRUE AS `class_level_explain`))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(true AS keep_original_columns))
1+
SELECT * FROM ML.PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(TRUE AS `keep_original_columns`))

tests/unit/ml/test_golden_sql.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def test_linear_regression_default_fit(
124124
model.fit(mock_X, mock_y)
125125

126126
mock_session._start_query_ml_ddl.assert_called_once_with(
127-
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='NO_SPLIT',\n optimize_strategy='auto_strategy',\n fit_intercept=True,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql"
127+
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='NO_SPLIT',\n optimize_strategy='auto_strategy',\n fit_intercept=TRUE,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=FALSE,\n enable_global_explain=FALSE,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql"
128128
)
129129

130130

@@ -134,7 +134,7 @@ def test_linear_regression_params_fit(bqml_model_factory, mock_session, mock_X,
134134
model.fit(mock_X, mock_y)
135135

136136
mock_session._start_query_ml_ddl.assert_called_once_with(
137-
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='NO_SPLIT',\n optimize_strategy='auto_strategy',\n fit_intercept=False,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql"
137+
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LINEAR_REG',\n data_split_method='NO_SPLIT',\n optimize_strategy='auto_strategy',\n fit_intercept=FALSE,\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=FALSE,\n enable_global_explain=FALSE,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql"
138138
)
139139

140140

@@ -169,7 +169,7 @@ def test_logistic_regression_default_fit(
169169
model.fit(mock_X, mock_y)
170170

171171
mock_session._start_query_ml_ddl.assert_called_once_with(
172-
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=True,\n auto_class_weights=False,\n optimize_strategy='auto_strategy',\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=False,\n enable_global_explain=False,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql",
172+
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=TRUE,\n auto_class_weights=FALSE,\n optimize_strategy='auto_strategy',\n l2_reg=0.0,\n max_iterations=20,\n learn_rate_strategy='line_search',\n min_rel_progress=0.01,\n calculate_p_values=FALSE,\n enable_global_explain=FALSE,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql",
173173
)
174174

175175

@@ -191,7 +191,7 @@ def test_logistic_regression_params_fit(
191191
model.fit(mock_X, mock_y)
192192

193193
mock_session._start_query_ml_ddl.assert_called_once_with(
194-
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=False,\n auto_class_weights=True,\n optimize_strategy='batch_gradient_descent',\n l2_reg=0.2,\n max_iterations=30,\n learn_rate_strategy='constant',\n min_rel_progress=0.02,\n calculate_p_values=False,\n enable_global_explain=False,\n l1_reg=0.2,\n learn_rate=0.2,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql"
194+
"CREATE OR REPLACE MODEL `test-project`.`_anon123`.`temp_model_id`\nOPTIONS(\n model_type='LOGISTIC_REG',\n data_split_method='NO_SPLIT',\n fit_intercept=FALSE,\n auto_class_weights=TRUE,\n optimize_strategy='batch_gradient_descent',\n l2_reg=0.2,\n max_iterations=30,\n learn_rate_strategy='constant',\n min_rel_progress=0.02,\n calculate_p_values=FALSE,\n enable_global_explain=FALSE,\n l1_reg=0.2,\n learn_rate=0.2,\n INPUT_LABEL_COLS=['input_column_label'])\nAS input_X_y_no_index_sql"
195195
)
196196

197197

0 commit comments

Comments
 (0)