Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 176c861

Browse files
refactor: Simplify @udf wrapper object
1 parent 677d6cc commit 176c861

File tree

7 files changed

+792
-1043
lines changed

7 files changed

+792
-1043
lines changed

bigframes/dataframe.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4662,7 +4662,13 @@ def _prepare_export(
46624662
return array_value, id_overrides
46634663

46644664
def map(self, func, na_action: Optional[str] = None) -> DataFrame:
4665-
if not isinstance(func, bigframes.functions.BigqueryCallableRoutine):
4665+
if not isinstance(
4666+
func,
4667+
(
4668+
bigframes.functions.BigqueryCallableRoutine,
4669+
bigframes.functions.UdfRoutine,
4670+
),
4671+
):
46664672
raise TypeError("the first argument must be callable")
46674673

46684674
if na_action not in {None, "ignore"}:
@@ -4690,14 +4696,14 @@ def apply(self, func, *, axis=0, args: typing.Tuple = (), **kwargs):
46904696
func,
46914697
(
46924698
bigframes.functions.BigqueryCallableRoutine,
4693-
bigframes.functions.BigqueryCallableRowRoutine,
4699+
bigframes.functions.UdfRoutine,
46944700
),
46954701
):
46964702
raise ValueError(
46974703
"For axis=1 a BigFrames BigQuery function must be used."
46984704
)
46994705

4700-
if func.is_row_processor:
4706+
if func.udf_def.signature.is_row_processor:
47014707
# Early check whether the dataframe dtypes are currently supported
47024708
# in the bigquery function
47034709
# NOTE: Keep in sync with the value converters used in the gcf code

bigframes/functions/__init__.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,9 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from bigframes.functions.function import (
15-
BigqueryCallableRoutine,
16-
BigqueryCallableRowRoutine,
17-
)
14+
from bigframes.functions.function import BigqueryCallableRoutine, UdfRoutine
1815

1916
__all__ = [
2017
"BigqueryCallableRoutine",
21-
"BigqueryCallableRowRoutine",
18+
"UdfRoutine",
2219
]

bigframes/functions/_function_session.py

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -631,25 +631,15 @@ def wrapper(func):
631631
if udf_sig.is_row_processor:
632632
msg = bfe.format_message("input_types=Series is in preview.")
633633
warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning)
634-
return decorator(
635-
bq_functions.BigqueryCallableRowRoutine(
636-
udf_definition,
637-
session,
638-
cloud_function_ref=bigframes_cloud_function,
639-
local_func=func,
640-
is_managed=False,
641-
)
642-
)
643-
else:
644-
return decorator(
645-
bq_functions.BigqueryCallableRoutine(
646-
udf_definition,
647-
session,
648-
cloud_function_ref=bigframes_cloud_function,
649-
local_func=func,
650-
is_managed=False,
651-
)
634+
return decorator(
635+
bq_functions.BigqueryCallableRoutine(
636+
udf_definition,
637+
session,
638+
cloud_function_ref=bigframes_cloud_function,
639+
local_func=func,
640+
is_managed=False,
652641
)
642+
)
653643

654644
return wrapper
655645

@@ -835,8 +825,9 @@ def wrapper(func):
835825
bq_connection_manager,
836826
session=session, # type: ignore
837827
)
828+
code_def = udf_def.CodeDef.from_func(func)
838829
config = udf_def.ManagedFunctionConfig(
839-
code=udf_def.CodeDef.from_func(func),
830+
code=code_def,
840831
signature=udf_sig,
841832
max_batching_rows=max_batching_rows,
842833
container_cpu=container_cpu,
@@ -863,26 +854,11 @@ def wrapper(func):
863854
if not name:
864855
self._update_temp_artifacts(full_rf_name, "")
865856

866-
decorator = functools.wraps(func)
867857
if udf_sig.is_row_processor:
868858
msg = bfe.format_message("input_types=Series is in preview.")
869859
warnings.warn(msg, stacklevel=1, category=bfe.PreviewWarning)
870-
assert session is not None # appease mypy
871-
return decorator(
872-
bq_functions.BigqueryCallableRowRoutine(
873-
udf_definition, session, local_func=func, is_managed=True
874-
)
875-
)
876-
else:
877-
assert session is not None # appease mypy
878-
return decorator(
879-
bq_functions.BigqueryCallableRoutine(
880-
udf_definition,
881-
session,
882-
local_func=func,
883-
is_managed=True,
884-
)
885-
)
860+
861+
return bq_functions.UdfRoutine(func=func, _udf_def=udf_definition)
886862

887863
return wrapper
888864

bigframes/functions/function.py

Lines changed: 14 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
from bigframes.session import Session
2222
import bigframes.series
2323

24+
import dataclasses
25+
import functools
26+
2427
import google.api_core.exceptions
2528
from google.cloud import bigquery
2629

@@ -90,13 +93,13 @@ def _try_import_routine(
9093

9194
def _try_import_row_routine(
9295
routine: bigquery.Routine, session: bigframes.Session
93-
) -> BigqueryCallableRowRoutine:
96+
) -> BigqueryCallableRoutine:
9497
udf_def = _routine_as_udf_def(routine, is_row_processor=True)
9598

9699
is_remote = (
97100
hasattr(routine, "remote_function_options") and routine.remote_function_options
98101
)
99-
return BigqueryCallableRowRoutine(udf_def, session, is_managed=not is_remote)
102+
return BigqueryCallableRoutine(udf_def, session, is_managed=not is_remote)
100103

101104

102105
def _routine_as_udf_def(
@@ -117,7 +120,6 @@ def _routine_as_udf_def(
117120
)
118121

119122

120-
# TODO(b/399894805): Support managed function.
121123
def read_gbq_function(
122124
function_name: str,
123125
*,
@@ -202,7 +204,7 @@ def bigframes_remote_function(self):
202204

203205
@property
204206
def is_row_processor(self) -> bool:
205-
return False
207+
return self.udf_def.signature.is_row_processor
206208

207209
@property
208210
def udf_def(self) -> udf_def.BigqueryUdf:
@@ -225,75 +227,17 @@ def bigframes_bigquery_function_output_dtype(self):
225227
return self.udf_def.signature.output.emulating_type.bf_type
226228

227229

228-
class BigqueryCallableRowRoutine:
229-
"""
230-
A reference to a routine in the context of a session.
231-
232-
Can be used either directly as a callable or as an input to dataframe ops that take a callable.
233-
"""
234-
235-
def __init__(
236-
self,
237-
udf_def: udf_def.BigqueryUdf,
238-
session: bigframes.Session,
239-
*,
240-
local_func: Optional[Callable] = None,
241-
cloud_function_ref: Optional[str] = None,
242-
is_managed: bool = False,
243-
):
244-
assert udf_def.signature.is_row_processor
245-
self._udf_def = udf_def
246-
self._session = session
247-
self._local_fun = local_func
248-
self._cloud_function = cloud_function_ref
249-
self._is_managed = is_managed
230+
@dataclasses.dataclass(frozen=True)
231+
class UdfRoutine:
232+
func: Callable
233+
# Try not to depend on this: BigQuery managed function creation will be deferred in the future,
234+
# and this reference will instead be replaced with requirements, to support lazy creation.
235+
_udf_def: udf_def.BigqueryUdf
250236

237+
@functools.partial
251238
def __call__(self, *args, **kwargs):
252-
if self._local_fun:
253-
return self._local_fun(*args, **kwargs)
254-
# avoid circular imports
255-
from bigframes.core.compile.sqlglot import sql as sg_sql
256-
import bigframes.session._io.bigquery as bf_io_bigquery
257-
258-
args_string = ", ".join([sg_sql.to_sql(sg_sql.literal(v)) for v in args])
259-
sql = f"SELECT `{str(self._udf_def.routine_ref)}`({args_string})"
260-
iter, job = bf_io_bigquery.start_query_with_client(
261-
self._session.bqclient,
262-
sql=sql,
263-
query_with_job=True,
264-
job_config=bigquery.QueryJobConfig(),
265-
publisher=self._session._publisher,
266-
) # type: ignore
267-
return list(iter.to_arrow().to_pydict().values())[0][0]
268-
269-
@property
270-
def bigframes_bigquery_function(self) -> str:
271-
return str(self._udf_def.routine_ref)
272-
273-
@property
274-
def bigframes_remote_function(self):
275-
return None if self._is_managed else str(self._udf_def.routine_ref)
276-
277-
@property
278-
def is_row_processor(self) -> bool:
279-
return True
239+
return self.func(*args, **kwargs)
280240

281241
@property
282242
def udf_def(self) -> udf_def.BigqueryUdf:
283243
return self._udf_def
284-
285-
@property
286-
def bigframes_cloud_function(self) -> Optional[str]:
287-
return self._cloud_function
288-
289-
@property
290-
def input_dtypes(self):
291-
return tuple(arg.bf_type for arg in self.udf_def.signature.inputs)
292-
293-
@property
294-
def output_dtype(self):
295-
return self.udf_def.signature.output.bf_type
296-
297-
@property
298-
def bigframes_bigquery_function_output_dtype(self):
299-
return self.udf_def.signature.output.emulating_type.bf_type

bigframes/functions/udf_def.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,14 @@ def stable_hash(self) -> bytes:
455455

456456
return hash_val.digest()
457457

458+
def to_callable(self):
459+
"""
460+
Reconstructs the Python callable from the pickled code.
461+
462+
Assumption: package_requirements match local environment
463+
"""
464+
return cloudpickle.loads(self.pickled_code)
465+
458466

459467
@dataclasses.dataclass(frozen=True)
460468
class ManagedFunctionConfig:

bigframes/series.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2029,7 +2029,13 @@ def apply(
20292029
" are supported."
20302030
)
20312031

2032-
if isinstance(func, bigframes.functions.BigqueryCallableRoutine):
2032+
if isinstance(
2033+
func,
2034+
(
2035+
bigframes.functions.BigqueryCallableRoutine,
2036+
bigframes.functions.UdfRoutine,
2037+
),
2038+
):
20332039
# We are working with bigquery function at this point
20342040
if args:
20352041
result_series = self._apply_nary_op(
@@ -2090,7 +2096,13 @@ def combine(
20902096
" are supported."
20912097
)
20922098

2093-
if isinstance(func, bigframes.functions.BigqueryCallableRoutine):
2099+
if isinstance(
2100+
func,
2101+
(
2102+
bigframes.functions.BigqueryCallableRoutine,
2103+
bigframes.functions.UdfRoutine,
2104+
),
2105+
):
20942106
result_series = self._apply_binary_op(
20952107
other, ops.BinaryRemoteFunctionOp(function_def=func.udf_def)
20962108
)

0 commit comments

Comments
 (0)