This repository was archived by the owner on Apr 1, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 67
Expand file tree
/
Copy path_utils.py
More file actions
312 lines (257 loc) · 11.6 KB
/
_utils.py
File metadata and controls
312 lines (257 loc) · 11.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import inspect
import json
import sys
import typing
from typing import Any, cast, Optional, Sequence, Set
import warnings
import cloudpickle
import google.api_core.exceptions
from google.cloud import bigquery, functions_v2
import numpy
from packaging.requirements import Requirement
import pandas
import pyarrow
import bigframes.exceptions as bfe
import bigframes.formatting_helpers as bf_formatting
from bigframes.functions import function_typing
# Naming convention for the function artifacts.
# The prefix and the BQ separator are combined with a session id in
# clean_up_by_session_id to recognise the routines created for that session.
_BIGFRAMES_FUNCTION_PREFIX = "bigframes"
# NOTE(review): "SEPERATOR" is misspelled; the names are left unchanged because
# other modules may import them — confirm before renaming.
_BQ_FUNCTION_NAME_SEPERATOR = "_"
_GCF_FUNCTION_NAME_SEPERATOR = "-"
# Protocol version 4 is available in python version 3.4 and above
# https://docs.python.org/3/library/pickle.html#data-stream-format
# Passed explicitly to cloudpickle.dumps in get_hash, presumably so the hashed
# bytes do not vary with the interpreter's default pickle protocol.
_pickle_protocol_version = 4
def get_remote_function_locations(bq_location):
    """Derive the BigQuery location and Cloud Functions region to use.

    Args:
        bq_location: A BigQuery location such as "US", "EU" or
            "asia-northeast1". A falsy value (None or "") means no location
            was configured.

    Returns:
        A ``(bq_location, cloud_function_region)`` tuple. ``bq_location`` is
        lower-cased and defaults to ``"us"``. ``cloud_function_region`` equals
        it, except that the "us"/"eu" multi-regions are mapped to a single
        region inside them.
    """
    # TODO(shobs, b/274647164): Find the best way to determine default location.
    # For now let's assume that if no BQ location is set then it defaults to
    # the US multi region.
    normalized_location = bq_location.lower() if bq_location else "us"

    # BigQuery has multi regions but cloud functions does not, so pick a region
    # inside the multi region. Any region there that supports cloud functions
    # should work: https://cloud.google.com/functions/docs/locations
    multi_region_to_gcf_region = {
        "us": "us-central1",
        "eu": "europe-west1",
    }
    gcf_region = multi_region_to_gcf_region.get(
        normalized_location, normalized_location
    )
    return normalized_location, gcf_region
def _package_existed(package_requirements: list[str], package: str) -> bool:
    """Checks if a package (regardless of version) exists in a given list.

    Project names are compared after PEP 503 normalization (case-insensitive,
    with runs of "-", "_" and "." treated as equal), which is how pip matches
    project names. For example, "Pandas==2.0" counts as an existing "pandas"
    requirement, and "typing_extensions" matches "typing-extensions". A raw
    name comparison would miss these and let the caller append a second pin
    for the same project.

    Args:
        package_requirements: Requirement specifiers, e.g. "numpy>=1.24".
        package: A requirement specifier whose project name is looked up.

    Returns:
        True if some entry of ``package_requirements`` names the same project
        as ``package``; False otherwise, including for an empty list.

    Raises:
        packaging.requirements.InvalidRequirement: If a specifier that gets
            parsed is malformed.
    """
    from packaging.utils import canonicalize_name

    if not package_requirements:
        return False
    target_name = canonicalize_name(Requirement(package).name)
    # Build the full set (rather than short-circuiting) so every entry is still
    # parsed and validated, as before.
    return target_name in {
        canonicalize_name(Requirement(req).name) for req in package_requirements
    }
def get_updated_package_requirements(
    package_requirements: Sequence[str] = (),
    is_row_processor: bool = False,
    capture_references: bool = True,
    ignore_package_version: bool = False,
) -> Sequence[str]:
    """Add the packages the function runtime needs to ``package_requirements``.

    A pinned ``cloudpickle`` is needed when ``capture_references`` is set.
    ``pandas``, ``pyarrow`` and ``numpy`` are needed when ``is_row_processor``
    is set. These three are pinned to the local versions unless
    ``ignore_package_version`` is set, in which case they are left unpinned and
    a ``FunctionPackageVersionWarning`` is emitted. A needed package is skipped
    if ``package_requirements`` already lists it by name.

    Returns:
        ``package_requirements`` itself when neither flag calls for extra
        packages; otherwise a new, sorted list.
    """
    extra_packages: list[str] = []

    if capture_references:
        extra_packages.append(f"cloudpickle=={cloudpickle.__version__}")

    if is_row_processor:
        if ignore_package_version:
            # TODO(jialuo): Add back the version after b/410924784 is resolved.
            # Due to current limitations on the packages version in Python UDFs,
            # we use `ignore_package_version` to optionally omit the version for
            # managed functions only.
            warnings.warn(
                bfe.format_message(
                    "numpy, pandas, and pyarrow versions in the function execution"
                    " environment may not precisely match your local environment."
                ),
                category=bfe.FunctionPackageVersionWarning,
            )
            extra_packages.extend(["pandas", "pyarrow", "numpy"])
        else:
            # bigframes function will send an entire row of data as json, which
            # would be converted to a pandas series and processed. Ensure numpy
            # versions match to avoid unpickling problems. See internal issue
            # b/347934471.
            extra_packages.extend(
                [
                    f"pandas=={pandas.__version__}",
                    f"pyarrow=={pyarrow.__version__}",
                    f"numpy=={numpy.__version__}",
                ]
            )

    if not extra_packages:
        return package_requirements

    merged = list(package_requirements)
    for extra in extra_packages:
        if not _package_existed(merged, extra):
            merged.append(extra)
    return sorted(merged)
def clean_up_by_session_id(
    bqclient: bigquery.Client,
    gcfclient: functions_v2.FunctionServiceClient,
    dataset: bigquery.DatasetReference,
    session_id: str,
):
    """Delete remote function artifacts for a session id, where the session id
    was not necessarily created in the current runtime. This is useful if the
    user worked with a BigQuery DataFrames session previously and remembered the
    session id, and now wants to clean up its temporary resources at a later
    point in time.

    A routine in ``dataset`` is deleted only if it is a scalar function, its
    id starts with ``bigframes_<session_id>_``, and it has a remote endpoint.
    Then the cloud functions whose URI matches one of those endpoints are
    deleted.

    Args:
        bqclient: BigQuery client used to list and delete routines and to look
            up the dataset's location.
        gcfclient: Cloud Functions client used to list and delete functions.
        dataset: Dataset containing the session's remote functions.
        session_id: Id of the session whose artifacts should be removed.
    """
    # First clean up the BQ remote functions and then the underlying cloud
    # functions, so that at no point we are left with a remote function that is
    # pointing to a cloud function that does not exist
    endpoints_to_be_deleted: Set[str] = set()
    match_prefix = "".join(
        [
            _BIGFRAMES_FUNCTION_PREFIX,
            _BQ_FUNCTION_NAME_SEPERATOR,
            session_id,
            _BQ_FUNCTION_NAME_SEPERATOR,
        ]
    )
    for routine in bqclient.list_routines(dataset):
        routine = cast(bigquery.Routine, routine)

        # skip past the routines not belonging to the given session id, or
        # non-remote-function routines
        if (
            routine.type_ != bigquery.RoutineType.SCALAR_FUNCTION
            or not cast(str, routine.routine_id).startswith(match_prefix)
            or not routine.remote_function_options
            or not routine.remote_function_options.endpoint
        ):
            continue

        # Let's forgive the edge case possibility that the BQ remote function
        # may have been deleted at the same time directly by the user
        bqclient.delete_routine(routine, not_found_ok=True)
        endpoints_to_be_deleted.add(routine.remote_function_options.endpoint)

    # Nothing matched this session, so no cloud function can be attached to it.
    # Skip the dataset lookup and the listing of every cloud function in the
    # region.
    if not endpoints_to_be_deleted:
        return

    # Now clean up the cloud functions
    bq_location = bqclient.get_dataset(dataset).location
    bq_location, gcf_location = get_remote_function_locations(bq_location)
    parent_path = gcfclient.common_location_path(
        project=dataset.project, location=gcf_location
    )
    for gcf in gcfclient.list_functions(parent=parent_path):
        # skip past the cloud functions not attached to any BQ remote function
        # belonging to the given session id
        if gcf.service_config.uri not in endpoints_to_be_deleted:
            continue

        # Let's forgive the edge case possibility that the cloud function
        # may have been deleted at the same time directly by the user
        try:
            gcfclient.delete_function(name=gcf.name)
        except google.api_core.exceptions.NotFound:
            pass
def routine_ref_to_string_for_query(routine_ref: bigquery.RoutineReference) -> str:
    """Format a routine reference for use inside a SQL query.

    The result is the project and dataset joined by "." and wrapped in
    backticks, then a "." and the unquoted routine id, for example
    ``my-project.my_dataset`` in backticks followed by ``.my_routine``.
    """
    dataset_path = f"{routine_ref.project}.{routine_ref.dataset_id}"
    return f"`{dataset_path}`.{routine_ref.routine_id}"
# Deprecated: Use CodeDef.stable_hash() instead.
def get_hash(def_, package_requirements=None):
    """Get hash (32 digits alphanumeric) of a function.

    Args:
        def_: The function to fingerprint. It is copied via cloudpickle, and
            the copy's code filename is replaced before pickling.
        package_requirements: Optional requirement strings to fold into the
            hash. They are sorted first, so their order does not matter.

    Returns:
        The hex MD5 digest of the pickled function followed by the requirements.
    """
    # There is a known cell-id sensitivity of the cloudpickle serialization in
    # notebooks https://github.com/cloudpipe/cloudpickle/issues/538. Because of
    # this, if a cell contains a udf decorated with @remote_function, a unique
    # cloudpickle code is generated every time the cell is run, creating new
    # cloud artifacts every time. This is slow and wasteful.
    # A workaround of the same can be achieved by replacing the filename in the
    # code object to a static value
    # https://github.com/cloudpipe/cloudpickle/issues/120#issuecomment-338510661.
    #
    # To respect the user code/environment let's make this modification on a
    # copy of the udf, not on the original udf itself.
    def_copy = cloudpickle.loads(cloudpickle.dumps(def_))
    def_copy.__code__ = def_copy.__code__.replace(
        co_filename="bigframes_place_holder_filename"
    )
    def_repr = cloudpickle.dumps(def_copy, protocol=_pickle_protocol_version)
    if package_requirements:
        # NOTE(review): requirements are appended without a separator, so
        # e.g. ["ab", "c"] and ["a", "bc"] hash identically. Left as-is because
        # changing it would change every existing hash.
        for p in sorted(package_requirements):
            def_repr += p.encode()
    # MD5 is only a content fingerprint here, not a security measure. Saying so
    # keeps it usable on FIPS-restricted Python builds, which reject a plain
    # hashlib.md5(). The digest value is identical.
    return hashlib.md5(def_repr, usedforsecurity=False).hexdigest()
def get_python_output_type_str_from_bigframes_metadata(
    metadata_text: str,
) -> Optional[str]:
    """Extract the array output type name from serialized bigframes metadata.

    The expected shape is ``{"value": {"python_array_output_type": "<name>"}}``.

    Args:
        metadata_text: The serialized metadata. Anything else, including None,
            invalid JSON, or JSON of a different shape, is tolerated.

    Returns:
        The stored type name, or None if the text is not valid JSON, lacks the
        expected nesting, or the stored value is not a string.
    """
    try:
        metadata_dict = json.loads(metadata_text)
    except (TypeError, json.decoder.JSONDecodeError):
        return None

    try:
        output_type_str = metadata_dict["value"]["python_array_output_type"]
    except (KeyError, TypeError):
        # KeyError: a key is missing. TypeError: valid JSON that is not nested
        # objects (e.g. "null", "[]", or a string under "value") cannot be
        # indexed by a string key.
        return None

    return output_type_str if isinstance(output_type_str, str) else None
def get_python_output_type_from_bigframes_metadata(
    metadata_text: str,
) -> Optional[type]:
    """Rebuild the ``list[T]`` output type recorded in bigframes metadata.

    Returns:
        ``list[T]``, where ``T`` is the first entry of
        ``function_typing.RF_SUPPORTED_ARRAY_OUTPUT_PYTHON_TYPES`` whose
        ``__name__`` equals the stored type name, or None if nothing matches.
    """
    recorded_name = get_python_output_type_str_from_bigframes_metadata(metadata_text)
    candidates = (
        supported
        for supported in function_typing.RF_SUPPORTED_ARRAY_OUTPUT_PYTHON_TYPES
        if supported.__name__ == recorded_name
    )
    element_type = next(candidates, None)
    if element_type is None:
        return None
    return list[element_type]  # type: ignore
def get_bigframes_metadata(*, python_output_type: Optional[type] = None) -> str:
    """Serialize bigframes-specific function metadata to JSON text.

    Only ``list[T]`` output types are recorded, and only when ``T`` is in
    ``function_typing.RF_SUPPORTED_ARRAY_OUTPUT_PYTHON_TYPES``.

    Args:
        python_output_type: The function's Python output type, or None.

    Returns:
        JSON text of the form ``{"value": {...}}``.

    Raises:
        ValueError: If ``python_output_type`` does not survive a round trip
            through the serialized metadata. Examples are a non-list type, an
            unsupported element type, or a bare ``typing.List``.
    """
    # Let's keep the actual metadata inside one level of nesting so that in
    # future we can use a top level key "version" (parallel to "value"), based
    # on which "value" can be interpreted according to the "version". The
    # absence of "version" should be interpreted as default version.
    inner_metadata: dict[str, str] = {}
    if typing.get_origin(python_output_type) is list:
        type_args = typing.get_args(python_output_type)
        # A bare ``typing.List`` has origin ``list`` but no type arguments.
        # Guard the indexing so that case reaches the ValueError below instead
        # of raising IndexError here.
        if (
            type_args
            and type_args[0]
            in function_typing.RF_SUPPORTED_ARRAY_OUTPUT_PYTHON_TYPES
        ):
            inner_metadata["python_array_output_type"] = type_args[0].__name__

    metadata_ser = json.dumps({"value": inner_metadata})

    # let's make sure the serialized value is deserializable
    if (
        get_python_output_type_from_bigframes_metadata(metadata_ser)
        != python_output_type
    ):
        raise bf_formatting.create_exception_with_feedback_link(
            ValueError, f"python_output_type {python_output_type} is not serializable."
        )
    return metadata_ser
def get_python_version(is_compat: bool = False) -> str:
    """Name the running interpreter's major.minor version as a runtime string.

    Args:
        is_compat: When True, return the compact form (e.g. ``python311``).
            When False, return the dotted form (e.g. ``python-3.11``).
    """
    # Cloud Run functions use the 'compat' format (e.g., python311, see more
    # from https://cloud.google.com/functions/docs/runtime-support#python),
    # while managed functions use the standard format (e.g., python-3.11).
    version = sys.version_info
    if is_compat:
        return f"python{version.major}{version.minor}"
    return f"python-{version.major}.{version.minor}"
def has_conflict_input_type(
    signature: inspect.Signature,
    input_types: Sequence[Any],
) -> bool:
    """Checks if the parameters have any conflict with the input_types.

    There is a conflict if the parameter count differs from
    ``len(input_types)``. There is also a conflict if any annotated parameter's
    annotation differs from the input type at the same position. Unannotated
    parameters never conflict.
    """
    parameters = tuple(signature.parameters.values())
    if len(parameters) != len(input_types):
        return True

    # Positional pairing of parameters with the declared input types.
    return any(
        parameter.annotation is not inspect.Parameter.empty
        and parameter.annotation != expected
        for parameter, expected in zip(parameters, input_types)
    )
def has_conflict_output_type(
    signature: inspect.Signature,
    output_type: Any,
) -> bool:
    """Checks if the return type annotation conflicts with the output_type.

    A signature without a return annotation never conflicts. Otherwise there
    is a conflict when the annotation is not equal to ``output_type``.
    """
    annotated_return = signature.return_annotation
    return (
        annotated_return is not inspect.Signature.empty
        and annotated_return != output_type
    )