This repository was archived by the owner on Apr 1, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 67
Expand file tree
/
Copy path test_ml.py
More file actions
86 lines (77 loc) · 3.19 KB
/
test_ml.py
File metadata and controls
86 lines (77 loc) · 3.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
import bigframes.core.sql.ml
def test_create_model_basic(snapshot):
    """Snapshot the CREATE MODEL DDL for a fully-qualified model name.

    Supplies a ``project.dataset.model`` style name, a model type plus a
    label column via ``options``, and a ``query_statement``. The rendered
    DDL text is compared against the stored ``create_model_basic.sql``
    snapshot.
    """
    model_options = {"model_type": "LINEAR_REG", "input_label_cols": ["label"]}
    rendered = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_project.my_dataset.my_model",
        options=model_options,
        query_statement="SELECT * FROM my_table",
    )
    snapshot.assert_match(rendered, "create_model_basic.sql")
def test_create_model_replace(snapshot):
    """Snapshot the DDL produced when ``replace=True`` is requested.

    The output is compared against the ``create_model_replace.sql``
    snapshot. Presumably this exercises a ``CREATE OR REPLACE MODEL`` form;
    confirm against the snapshot file.
    """
    build_ddl = bigframes.core.sql.ml.create_model_ddl
    rendered = build_ddl(
        model_name="my_model",
        replace=True,
        options={"model_type": "LOGISTIC_REG"},
        query_statement="SELECT * FROM t",
    )
    snapshot.assert_match(rendered, "create_model_replace.sql")
def test_create_model_if_not_exists(snapshot):
    """Snapshot the DDL produced when ``if_not_exists=True`` is requested.

    Uses a KMEANS model type and a simple query statement. The output is
    compared against the ``create_model_if_not_exists.sql`` snapshot.
    """
    kmeans_options = {"model_type": "KMEANS"}
    rendered = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_model",
        if_not_exists=True,
        options=kmeans_options,
        query_statement="SELECT * FROM t",
    )
    snapshot.assert_match(rendered, "create_model_if_not_exists.sql")
def test_create_model_transform(snapshot):
    """Snapshot the DDL produced when a ``transform`` list is supplied.

    The transform entries are passed as raw SQL expression strings: one
    window-function expression with an alias and one bare column. The
    output is compared against the ``create_model_transform.sql`` snapshot.
    """
    transform_exprs = ["ML.STANDARD_SCALER(c1) OVER() AS c1_scaled", "c2"]
    rendered = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_model",
        transform=transform_exprs,
        options={"model_type": "LINEAR_REG"},
        query_statement="SELECT c1, c2, label FROM t",
    )
    snapshot.assert_match(rendered, "create_model_transform.sql")
def test_create_model_remote(snapshot):
    """Snapshot the DDL for a model defined with a connection and schemas.

    Supplies a ``connection_name``, an ``endpoint`` option, and explicit
    input/output schema mappings. No query statement is given. The output
    is compared against the ``create_model_remote.sql`` snapshot.
    """
    inputs = {"prompt": "STRING"}
    outputs = {"content": "STRING"}
    build_ddl = bigframes.core.sql.ml.create_model_ddl
    rendered = build_ddl(
        model_name="my_remote_model",
        connection_name="my_project.us.my_connection",
        options={"endpoint": "gemini-pro"},
        input_schema=inputs,
        output_schema=outputs,
    )
    snapshot.assert_match(rendered, "create_model_remote.sql")
def test_create_model_remote_default(snapshot):
    """Snapshot the DDL when ``connection_name`` is the literal ``"DEFAULT"``.

    No input/output schema and no query statement are supplied. The output
    is compared against the ``create_model_remote_default.sql`` snapshot.
    """
    endpoint_options = {"endpoint": "gemini-pro"}
    rendered = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_remote_model",
        connection_name="DEFAULT",
        options=endpoint_options,
    )
    snapshot.assert_match(rendered, "create_model_remote_default.sql")
def test_create_model_training_data_and_holiday(snapshot):
    """Snapshot the DDL when ``training_data`` and ``custom_holiday`` are set.

    Both are supplied as SQL query strings alongside an ARIMA_PLUS model
    type, with no ``query_statement``. The output is compared against the
    ``create_model_training_data_and_holiday.sql`` snapshot.
    """
    sales_query = "SELECT * FROM sales"
    holiday_query = "SELECT * FROM holidays"
    rendered = bigframes.core.sql.ml.create_model_ddl(
        model_name="my_arima_model",
        options={"model_type": "ARIMA_PLUS"},
        training_data=sales_query,
        custom_holiday=holiday_query,
    )
    snapshot.assert_match(rendered, "create_model_training_data_and_holiday.sql")
def test_create_model_list_option(snapshot):
    """Snapshot the DDL when ``options`` mixes list and float values.

    ``hidden_units`` is a list of ints and ``dropout`` is a float, which
    checks how non-string option values are rendered. The output is
    compared against the ``create_model_list_option.sql`` snapshot.
    """
    mixed_options = {"hidden_units": [32, 16], "dropout": 0.2}
    build_ddl = bigframes.core.sql.ml.create_model_ddl
    rendered = build_ddl(
        model_name="my_model",
        options=mixed_options,
        query_statement="SELECT * FROM t",
    )
    snapshot.assert_match(rendered, "create_model_list_option.sql")