This repository was archived by the owner on Apr 1, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 68
Expand file tree
/
Copy pathml.py
More file actions
226 lines (190 loc) · 7.66 KB
/
ml.py
File metadata and controls
226 lines (190 loc) · 7.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from typing import Dict, Mapping, Optional, Union
import bigframes.core.compile.googlesql as googlesql
import bigframes.core.sql
def create_model_ddl(
    model_name: str,
    *,
    replace: bool = False,
    if_not_exists: bool = False,
    transform: Optional[list[str]] = None,
    input_schema: Optional[Mapping[str, str]] = None,
    output_schema: Optional[Mapping[str, str]] = None,
    connection_name: Optional[str] = None,
    options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
    training_data: Optional[str] = None,
    custom_holiday: Optional[str] = None,
) -> str:
    """Encode the CREATE MODEL statement.

    See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create for reference.

    Args:
        model_name: Model name; rendered as a quoted SQL identifier.
        replace: Emit ``CREATE OR REPLACE MODEL``; takes precedence over
            ``if_not_exists``.
        if_not_exists: Emit ``CREATE MODEL IF NOT EXISTS``.
        transform: Pre-rendered select-list expressions for the TRANSFORM clause.
        input_schema: Field name -> field type pairs for the INPUT clause.
        output_schema: Field name -> field type pairs for the OUTPUT clause.
        connection_name: Connection for ``REMOTE WITH CONNECTION``; the string
            "DEFAULT" (any letter case) selects the default connection.
        options: OPTIONS(...) key/value pairs; list/tuple values are rendered
            as SQL array literals.
        training_data: Query statement supplying the training data.
        custom_holiday: Holiday query statement; when given, both it and
            ``training_data`` are emitted as named AS sub-clauses.

    Returns:
        The CREATE MODEL DDL string.
    """
    if replace:
        create = "CREATE OR REPLACE MODEL "
    elif if_not_exists:
        create = "CREATE MODEL IF NOT EXISTS "
    else:
        create = "CREATE MODEL "
    ddl = f"{create}{googlesql.identifier(model_name)}\n"

    # [TRANSFORM (select_list)]
    if transform:
        ddl += f"TRANSFORM ({', '.join(transform)})\n"

    # [INPUT (field_name field_type) OUTPUT (field_name field_type)]
    if input_schema:
        inputs = [f"{k} {v}" for k, v in input_schema.items()]
        ddl += f"INPUT ({', '.join(inputs)})\n"
    if output_schema:
        outputs = [f"{k} {v}" for k, v in output_schema.items()]
        ddl += f"OUTPUT ({', '.join(outputs)})\n"

    # [REMOTE WITH CONNECTION {connection_name | DEFAULT}]
    if connection_name:
        if connection_name.upper() == "DEFAULT":
            ddl += "REMOTE WITH CONNECTION DEFAULT\n"
        else:
            ddl += f"REMOTE WITH CONNECTION {googlesql.identifier(connection_name)}\n"

    # [OPTIONS(model_option_list)]
    if options:
        rendered_options = []
        for option_name, option_value in options.items():
            if isinstance(option_value, (list, tuple)):
                # Sequence-valued options are normalized to a list so they are
                # rendered as a SQL array literal.
                rendered_val = bigframes.core.sql.simple_literal(list(option_value))
            else:
                rendered_val = bigframes.core.sql.simple_literal(option_value)
            rendered_options.append(f"{option_name} = {rendered_val}")
        ddl += f"OPTIONS({', '.join(rendered_options)})\n"

    # [AS {query_statement | ( training_data AS (query_statement), custom_holiday AS (holiday_statement) )}]
    if training_data:
        if custom_holiday:
            # With a custom holiday query, both statements must be named clauses.
            parts = [
                f"training_data AS ({training_data})",
                f"custom_holiday AS ({custom_holiday})",
            ]
            ddl += f"AS (\n  {', '.join(parts)}\n)"
        else:
            # A lone training_data is treated as the query_statement.
            ddl += f"AS {training_data}\n"
    return ddl
def _build_struct_sql(
    struct_options: Mapping[str, Union[str, int, float, bool]]
) -> str:
    """Render the trailing ``, STRUCT(value AS name, ...)`` function argument.

    Returns an empty string when there are no options to encode, so the
    result can be appended unconditionally to an ML function call.
    """
    if not struct_options:
        return ""
    body = ", ".join(
        f"{bigframes.core.sql.simple_literal(value)} AS {name}"
        for name, value in struct_options.items()
    )
    return f", STRUCT({body})"
def evaluate(
    model_name: str,
    *,
    table: Optional[str] = None,
    perform_aggregation: Optional[bool] = None,
    horizon: Optional[int] = None,
    confidence_level: Optional[float] = None,
) -> str:
    """Encode the ML.EVALUATE statement.

    See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-evaluate for reference.

    Args:
        model_name: Model to evaluate; rendered as a quoted SQL identifier.
        table: Optional query statement supplying evaluation data; when
            omitted, the model's own evaluation data is used.
        perform_aggregation: Forecasting option; included only when not None.
        horizon: Forecasting horizon; included only when not None.
        confidence_level: Forecasting confidence level; included only when
            not None.

    Returns:
        The ML.EVALUATE SQL string.
    """
    struct_options: Dict[str, Union[str, int, float, bool]] = {}
    if perform_aggregation is not None:
        struct_options["perform_aggregation"] = perform_aggregation
    if horizon is not None:
        struct_options["horizon"] = horizon
    if confidence_level is not None:
        struct_options["confidence_level"] = confidence_level
    sql = f"SELECT * FROM ML.EVALUATE(MODEL {googlesql.identifier(model_name)}"
    if table:
        sql += f", ({table})"
    sql += _build_struct_sql(struct_options)
    sql += ")\n"
    return sql
def predict(
    model_name: str,
    table: str,
    *,
    threshold: Optional[float] = None,
    keep_original_columns: Optional[bool] = None,
    trial_id: Optional[int] = None,
) -> str:
    """Encode the ML.PREDICT statement.

    See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict for reference.

    Args:
        model_name: Model to predict with; rendered as a quoted SQL identifier.
        table: Query statement supplying the input rows.
        threshold: Classification threshold; included only when not None.
        keep_original_columns: Whether input columns pass through to the
            output; included only when not None.
        trial_id: Hyperparameter-tuning trial to use; included only when
            not None.

    Returns:
        The ML.PREDICT SQL string.
    """
    # Annotated for consistency with evaluate() / explain_predict().
    struct_options: Dict[str, Union[str, int, float, bool]] = {}
    if threshold is not None:
        struct_options["threshold"] = threshold
    if keep_original_columns is not None:
        struct_options["keep_original_columns"] = keep_original_columns
    if trial_id is not None:
        struct_options["trial_id"] = trial_id
    sql = (
        f"SELECT * FROM ML.PREDICT(MODEL {googlesql.identifier(model_name)}, ({table})"
    )
    sql += _build_struct_sql(struct_options)
    sql += ")\n"
    return sql
def explain_predict(
    model_name: str,
    table: str,
    *,
    top_k_features: Optional[int] = None,
    threshold: Optional[float] = None,
    integrated_gradients_num_steps: Optional[int] = None,
    approx_feature_contrib: Optional[bool] = None,
) -> str:
    """Encode the ML.EXPLAIN_PREDICT statement.
    See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict for reference.
    """
    # Keep only the options the caller actually supplied; insertion order
    # matches the parameter order, as before.
    candidates = {
        "top_k_features": top_k_features,
        "threshold": threshold,
        "integrated_gradients_num_steps": integrated_gradients_num_steps,
        "approx_feature_contrib": approx_feature_contrib,
    }
    struct_options: Dict[str, Union[str, int, float, bool]] = {
        name: value for name, value in candidates.items() if value is not None
    }
    pieces = [
        f"SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {googlesql.identifier(model_name)}, ({table})",
        _build_struct_sql(struct_options),
        ")\n",
    ]
    return "".join(pieces)
def global_explain(
    model_name: str,
    *,
    class_level_explain: Optional[bool] = None,
) -> str:
    """Encode the ML.GLOBAL_EXPLAIN statement.

    See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain for reference.

    Args:
        model_name: Model to explain; rendered as a quoted SQL identifier.
        class_level_explain: Whether to produce per-class explanations;
            included only when not None.

    Returns:
        The ML.GLOBAL_EXPLAIN SQL string.
    """
    # Annotated for consistency with evaluate() / explain_predict().
    struct_options: Dict[str, Union[str, int, float, bool]] = {}
    if class_level_explain is not None:
        struct_options["class_level_explain"] = class_level_explain
    sql = f"SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL {googlesql.identifier(model_name)}"
    sql += _build_struct_sql(struct_options)
    sql += ")\n"
    return sql
def transform(
    model_name: str,
    table: str,
) -> str:
    """Encode the ML.TRANSFORM statement.
    See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-transform for reference.
    """
    model_id = googlesql.identifier(model_name)
    return f"SELECT * FROM ML.TRANSFORM(MODEL {model_id}, ({table}))\n"