Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 35f3f5e

Browse files
authored
feat: add bigquery.ml.generate_embedding function (#2422)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly:
- [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/python-bigquery-dataframes/issues/new/choose) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea
- [ ] Ensure the tests and linter pass
- [ ] Code coverage does not decrease (if any source code was changed)
- [ ] Appropriate docs were updated (if necessary)

Fixes #<issue_number_goes_here> 🦕
1 parent 248c8ea commit 35f3f5e

File tree

8 files changed

+204
-0
lines changed

8 files changed

+204
-0
lines changed

bigframes/bigquery/_operations/ml.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -520,3 +520,63 @@ def generate_text(
520520
return bpd.read_gbq_query(sql)
521521
else:
522522
return session.read_gbq_query(sql)
523+
524+
525+
@log_adapter.method_logger(custom_base_name="bigquery_ml")
def generate_embedding(
    model: Union[bigframes.ml.base.BaseEstimator, str, pd.Series],
    input_: Union[pd.DataFrame, dataframe.DataFrame, str],
    *,
    flatten_json_output: Optional[bool] = None,
    task_type: Optional[str] = None,
    output_dimensionality: Optional[int] = None,
) -> dataframe.DataFrame:
    """
    Generates text embeddings using a BigQuery ML model.

    The model and input are resolved to a model name, an optional session and
    a table SQL expression; these are encoded into an
    ``ML.GENERATE_EMBEDDING`` query, which is then read back with
    ``read_gbq_query``.

    See the `BigQuery ML GENERATE_EMBEDDING function syntax
    <https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-embedding>`_
    for additional reference.

    Args:
        model (bigframes.ml.base.BaseEstimator, str, or pandas.Series):
            The model to use for text embedding.
        input_ (Union[pandas.DataFrame, bigframes.pandas.DataFrame, str]):
            The DataFrame or query to use for text embedding.
        flatten_json_output (bool, optional):
            A BOOL value that determines the content of the generated JSON column.
        task_type (str, optional):
            A STRING value that specifies the intended downstream application task.
            Supported values are:
            - `RETRIEVAL_QUERY`
            - `RETRIEVAL_DOCUMENT`
            - `SEMANTIC_SIMILARITY`
            - `CLASSIFICATION`
            - `CLUSTERING`
            - `QUESTION_ANSWERING`
            - `FACT_VERIFICATION`
            - `CODE_RETRIEVAL_QUERY`
        output_dimensionality (int, optional):
            An INT64 value that specifies the size of the output embedding.

    Returns:
        bigframes.pandas.DataFrame:
            The generated text embedding.
    """
    import bigframes.pandas as bpd

    model_name, session = _get_model_name_and_session(model, input_)

    embedding_query = bigframes.core.sql.ml.generate_embedding(
        model_name=model_name,
        table=_to_sql(input_),
        flatten_json_output=flatten_json_output,
        task_type=task_type,
        output_dimensionality=output_dimensionality,
    )

    # Prefer the session resolved from the arguments; fall back to the
    # module-level bigframes.pandas entry point when none was resolved.
    if session is not None:
        return session.read_gbq_query(embedding_query)
    return bpd.read_gbq_query(embedding_query)

bigframes/bigquery/ml.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
create_model,
2424
evaluate,
2525
explain_predict,
26+
generate_embedding,
2627
generate_text,
2728
global_explain,
2829
predict,
@@ -37,4 +38,5 @@
3738
"global_explain",
3839
"transform",
3940
"generate_text",
41+
"generate_embedding",
4042
]

bigframes/core/sql/ml.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,3 +296,31 @@ def generate_text(
296296
sql += _build_struct_sql(struct_options)
297297
sql += ")\n"
298298
return sql
299+
300+
301+
def generate_embedding(
    model_name: str,
    table: str,
    *,
    flatten_json_output: Optional[bool] = None,
    task_type: Optional[str] = None,
    output_dimensionality: Optional[int] = None,
) -> str:
    """Encode the ML.GENERATE_EMBEDDING statement.

    See https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-generate-embedding for reference.

    Args:
        model_name: Name of the model; rendered through ``googlesql.identifier``.
        table: SQL text of the input query; embedded inside parentheses.
        flatten_json_output: Optional ``flatten_json_output`` struct option.
        task_type: Optional ``task_type`` struct option.
        output_dimensionality: Optional ``output_dimensionality`` struct option.

    Returns:
        The SQL statement, terminated by a newline. Options left as ``None``
        are not passed on to ``_build_struct_sql``.
    """
    # Insertion order here fixes the order of the options handed to
    # ``_build_struct_sql``.
    requested: Dict[str, Union[str, int, bool, None]] = {
        "flatten_json_output": flatten_json_output,
        "task_type": task_type,
        "output_dimensionality": output_dimensionality,
    }
    struct_options: Dict[
        str,
        Union[str, int, float, bool, Mapping[str, str], List[str], Mapping[str, Any]],
    ] = {name: value for name, value in requested.items() if value is not None}

    return "".join(
        (
            f"SELECT * FROM ML.GENERATE_EMBEDDING(MODEL {googlesql.identifier(model_name)}, ({table})",
            _build_struct_sql(struct_options),
            ")\n",
        )
    )
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import pytest
16+
17+
import bigframes.bigquery.ml as ml
18+
import bigframes.pandas as bpd
19+
20+
21+
@pytest.fixture(scope="session")
def embedding_model(bq_connection, dataset_id):
    """Create the ``gemini-embedding-001`` model in the test dataset, once per session."""
    return ml.create_model(
        model_name=f"{dataset_id}.embedding_model",
        options={"endpoint": "gemini-embedding-001"},
        connection_name=bq_connection,
    )
29+
30+
31+
def test_generate_embedding(embedding_model):
    """Default options: one output row per input row, with result and status columns."""
    prompts = ["What is BigQuery?", "What is BQML?"]
    df = bpd.DataFrame({"content": prompts})

    result = ml.generate_embedding(embedding_model, df)

    assert len(result) == 2
    for expected_column in (
        "ml_generate_embedding_result",
        "ml_generate_embedding_status",
    ):
        assert expected_column in result.columns
45+
46+
47+
def test_generate_embedding_with_options(embedding_model):
    """Explicit task type and output dimensionality are honoured for every row."""
    df = bpd.DataFrame(
        {
            "content": [
                "What is BigQuery?",
                "What is BQML?",
            ]
        }
    )

    result = ml.generate_embedding(
        embedding_model, df, task_type="RETRIEVAL_DOCUMENT", output_dimensionality=256
    )
    assert len(result) == 2
    assert "ml_generate_embedding_result" in result.columns
    assert "ml_generate_embedding_status" in result.columns

    embeddings = result["ml_generate_embedding_result"].to_pandas()
    # Iterate over the values instead of using ``embeddings[0]``: integer
    # ``[]`` on a pandas Series is a label lookup, so it depends on the result
    # index containing the label 0 and would only check a single row.
    assert all(len(embedding) == 256 for embedding in embeddings)

tests/unit/bigquery/test_ml.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,32 @@ def test_generate_text_with_pandas_dataframe(read_pandas_mock, read_gbq_query_mo
200200
assert "['a', 'b'] AS stop_sequences" in generated_sql
201201
assert "true AS ground_with_google_search" in generated_sql
202202
assert "'TYPE' AS request_type" in generated_sql
203+
204+
205+
@mock.patch("bigframes.pandas.read_gbq_query")
@mock.patch("bigframes.pandas.read_pandas")
def test_generate_embedding_with_pandas_dataframe(
    read_pandas_mock, read_gbq_query_mock
):
    """A pandas input is loaded via ``read_pandas`` and all options reach the SQL."""
    pandas_input = pd.DataFrame({"col1": [1, 2, 3]})
    uploaded_frame = read_pandas_mock.return_value
    uploaded_frame._to_sql_query.return_value = (
        "SELECT * FROM `pandas_df`",
        [],
        [],
    )

    ml_ops.generate_embedding(
        MODEL_SERIES,
        input_=pandas_input,
        flatten_json_output=True,
        task_type="RETRIEVAL_DOCUMENT",
        output_dimensionality=256,
    )

    read_pandas_mock.assert_called_once()
    read_gbq_query_mock.assert_called_once()

    # The generated SQL is the first positional argument of the query call.
    positional_args, _ = read_gbq_query_mock.call_args
    generated_sql = positional_args[0]

    expected_fragments = (
        "ML.GENERATE_EMBEDDING",
        f"MODEL `{MODEL_NAME}`",
        "(SELECT * FROM `pandas_df`)",
        "true AS flatten_json_output",
        "'RETRIEVAL_DOCUMENT' AS task_type",
        "256 AS output_dimensionality",
    )
    for fragment in expected_fragments:
        assert fragment in generated_sql
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
SELECT * FROM ML.GENERATE_EMBEDDING(MODEL `my_project.my_dataset.my_model`, (SELECT * FROM new_data), STRUCT(true AS flatten_json_output, 'RETRIEVAL_DOCUMENT' AS task_type, 256 AS output_dimensionality))

tests/unit/core/sql/test_ml.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,22 @@ def test_generate_text_model_with_options(snapshot):
201201
request_type="TYPE",
202202
)
203203
snapshot.assert_match(sql, "generate_text_model_with_options.sql")
204+
205+
206+
def test_generate_embedding_model_basic(snapshot):
    """Snapshot the statement produced when no struct options are supplied."""
    arguments = {
        "model_name": "my_project.my_dataset.my_model",
        "table": "SELECT * FROM new_data",
    }
    generated_sql = bigframes.core.sql.ml.generate_embedding(**arguments)
    snapshot.assert_match(generated_sql, "generate_embedding_model_basic.sql")
212+
213+
214+
def test_generate_embedding_model_with_options(snapshot):
    """Snapshot the statement produced when every struct option is supplied."""
    arguments = {
        "model_name": "my_project.my_dataset.my_model",
        "table": "SELECT * FROM new_data",
        "flatten_json_output": True,
        "task_type": "RETRIEVAL_DOCUMENT",
        "output_dimensionality": 256,
    }
    generated_sql = bigframes.core.sql.ml.generate_embedding(**arguments)
    snapshot.assert_match(generated_sql, "generate_embedding_model_with_options.sql")

0 commit comments

Comments
 (0)