Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 33f5727

Browse files
fixes
1 parent 2efc89e commit 33f5727

File tree

8 files changed

+94
-63
lines changed

8 files changed

+94
-63
lines changed

bigframes/core/compile/ibis_compiler/scalar_op_registry.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,6 +1037,7 @@ def timedelta_floor_op_impl(x: ibis_types.NumericValue):
10371037
@scalar_op_compiler.register_unary_op(ops.RemoteFunctionOp, pass_op=True)
10381038
def remote_function_op_impl(x: ibis_types.Value, op: ops.RemoteFunctionOp):
10391039
udf_sig = op.function_def.signature
1040+
assert not udf_sig.is_virtual # should have been devirtualized in lowering pass
10401041
ibis_py_sig = (tuple(arg.py_type for arg in udf_sig.inputs), udf_sig.output.py_type)
10411042

10421043
@ibis_udf.scalar.builtin(
@@ -1056,6 +1057,7 @@ def binary_remote_function_op_impl(
10561057
x: ibis_types.Value, y: ibis_types.Value, op: ops.BinaryRemoteFunctionOp
10571058
):
10581059
udf_sig = op.function_def.signature
1060+
assert not udf_sig.is_virtual # should have been devirtualized in lowering pass
10591061
ibis_py_sig = (tuple(arg.py_type for arg in udf_sig.inputs), udf_sig.output.py_type)
10601062

10611063
@ibis_udf.scalar.builtin(
@@ -1073,6 +1075,7 @@ def nary_remote_function_op_impl(
10731075
*operands: ibis_types.Value, op: ops.NaryRemoteFunctionOp
10741076
):
10751077
udf_sig = op.function_def.signature
1078+
assert not udf_sig.is_virtual # should have been devirtualized in lowering pass
10761079
ibis_py_sig = (tuple(arg.py_type for arg in udf_sig.inputs), udf_sig.output.py_type)
10771080
arg_names = tuple(arg.name for arg in udf_sig.inputs)
10781081

bigframes/core/rewrite/udfs.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,14 @@ def op(self) -> type[ops.ScalarOp]:
3030
def lower(self, expr: expression.OpExpression) -> expression.Expression:
3131
assert isinstance(expr.op, ops.RemoteFunctionOp)
3232
func_def = expr.op.function_def
33-
if isinstance(func_def.signature.output, udf_def.DirectScalarType):
34-
return expr.op.as_expr(*expr.children)
35-
assert isinstance(func_def.signature.output, udf_def.VirtualListTypeV1)
3633
devirtualized_expr = ops.RemoteFunctionOp(
3734
func_def.with_devirtualize(),
3835
apply_on_null=expr.op.apply_on_null,
3936
).as_expr(*expr.children)
40-
return func_def.signature.output.out_expr(devirtualized_expr)
37+
if isinstance(func_def.signature.output, udf_def.VirtualListTypeV1):
38+
return func_def.signature.output.out_expr(devirtualized_expr)
39+
else:
40+
return devirtualized_expr
4141

4242

4343
@dataclasses.dataclass
@@ -49,14 +49,13 @@ def op(self) -> type[ops.ScalarOp]:
4949
def lower(self, expr: expression.OpExpression) -> expression.Expression:
5050
assert isinstance(expr.op, ops.BinaryRemoteFunctionOp)
5151
func_def = expr.op.function_def
52-
53-
if isinstance(func_def.signature.output, udf_def.DirectScalarType):
54-
return expr.op.as_expr(*expr.children)
55-
assert isinstance(func_def.signature.output, udf_def.VirtualListTypeV1)
5652
devirtualized_expr = ops.BinaryRemoteFunctionOp(
5753
func_def.with_devirtualize(),
5854
).as_expr(*expr.children)
59-
return func_def.signature.output.out_expr(devirtualized_expr)
55+
if isinstance(func_def.signature.output, udf_def.VirtualListTypeV1):
56+
return func_def.signature.output.out_expr(devirtualized_expr)
57+
else:
58+
return devirtualized_expr
6059

6160

6261
@dataclasses.dataclass
@@ -68,13 +67,13 @@ def op(self) -> type[ops.ScalarOp]:
6867
def lower(self, expr: expression.OpExpression) -> expression.Expression:
6968
assert isinstance(expr.op, ops.NaryRemoteFunctionOp)
7069
func_def = expr.op.function_def
71-
if isinstance(func_def.signature.output, udf_def.DirectScalarType):
72-
return expr.op.as_expr(*expr.children)
73-
assert isinstance(func_def.signature.output, udf_def.VirtualListTypeV1)
7470
devirtualized_expr = ops.NaryRemoteFunctionOp(
7571
func_def.with_devirtualize(),
7672
).as_expr(*expr.children)
77-
return func_def.signature.output.out_expr(devirtualized_expr)
73+
if isinstance(func_def.signature.output, udf_def.VirtualListTypeV1):
74+
return func_def.signature.output.out_expr(devirtualized_expr)
75+
else:
76+
return devirtualized_expr
7877

7978

8079
UDF_LOWERING_RULES = (

bigframes/functions/_function_client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,6 @@ def create_bq_remote_function(
177177

178178
# Create BQ function
179179
# https://cloud.google.com/bigquery/docs/reference/standard-sql/remote-functions#create_a_remote_function_2
180-
bq_function_return_type = udf_def.signature.output.sql_type
181180

182181
remote_function_options = {
183182
"endpoint": udf_def.endpoint,
@@ -203,7 +202,7 @@ def create_bq_remote_function(
203202
bq_function_name_escaped = bigframes.core.sql.identifier(sql_func_legal_name)
204203
create_function_ddl = f"""
205204
CREATE OR REPLACE FUNCTION `{self._gcp_project_id}.{self._bq_dataset}`.{bq_function_name_escaped}({udf_def.signature.to_sql_input_signature()})
206-
RETURNS {bq_function_return_type}
205+
RETURNS {udf_def.signature.with_devirtualize().output.sql_type}
207206
REMOTE WITH CONNECTION `{self._gcp_project_id}.{self._bq_location}.{self._bq_connection_id}`
208207
OPTIONS ({remote_function_options_str})"""
209208

@@ -658,6 +657,7 @@ def provision_bq_remote_function(
658657
connection_id=self._bq_connection_id,
659658
max_batching_rows=max_batching_rows,
660659
signature=func_signature,
660+
bq_metadata=func_signature.protocol_metadata,
661661
)
662662
remote_function_name = name or get_bigframes_function_name(
663663
intended_rf_spec,

bigframes/functions/_function_session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -596,7 +596,9 @@ def wrapper(func):
596596
session=session, # type: ignore
597597
)
598598

599-
udf_sig = udf_def.UdfSignature.from_py_signature(py_sig)
599+
udf_sig = udf_def.UdfSignature.from_py_signature(
600+
py_sig
601+
).to_remote_function_compatible()
600602

601603
(
602604
rf_name,

bigframes/functions/function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def __init__(
235235
cloud_function_ref: Optional[str] = None,
236236
is_managed: bool = False,
237237
):
238-
assert self.udf_def.signature.is_row_processor
238+
assert udf_def.signature.is_row_processor
239239
self._udf_def = udf_def
240240
self._session = session
241241
self._local_fun = local_func

bigframes/functions/function_typing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def __init__(self, type_, supported_types):
8181

8282

8383
def sdk_type_from_python_type(
84-
t: type, allow_lists: bool = False
84+
t: type, allow_lists: bool = True
8585
) -> bigquery.StandardSqlDataType:
8686
if (get_origin(t) is list) and allow_lists:
8787
return sdk_array_output_type_from_python_type(t)

bigframes/functions/udf_def.py

Lines changed: 55 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,12 @@ def stable_hash(self) -> bytes:
9292

9393
@dataclasses.dataclass(frozen=True)
9494
class DirectScalarType:
95+
"""
96+
Represents a scalar value that is passed directly to the remote function.
97+
98+
For these values, BigQuery handles the serialization and deserialization without any additional processing.
99+
"""
100+
95101
_py_type: type
96102

97103
@property
@@ -119,16 +125,26 @@ def stable_hash(self) -> bytes:
119125
def from_sdk_type(cls, sdk_type: bigquery.StandardSqlDataType) -> DirectScalarType:
120126
return cls(function_typing.sdk_type_to_py_type(sdk_type))
121127

128+
@property
129+
def emulating_type(self) -> DirectScalarType:
130+
return self
131+
122132

123133
@dataclasses.dataclass(frozen=True)
124134
class VirtualListTypeV1:
135+
"""
136+
Represents a list of scalar values that is emulated as a JSON array string in the remote function.
137+
138+
Only works as output paramter right now where array -> string in function runtime, and then string -> array in SQL post-processing (defined in out_expr()).
139+
"""
140+
125141
_PROTOCOL_ID = "virtual_list_v1"
126142

127143
inner_dtype: DirectScalarType
128144

129145
@property
130146
def py_type(self) -> Type[list[Any]]:
131-
return list[function_typing.sdk_type_to_py_type(self.inner_dtype)] # type: ignore
147+
return list[self.inner_dtype.py_type] # type: ignore
132148

133149
# TODO: Specify emulating type and mapping expressions between said types
134150
@property
@@ -163,6 +179,8 @@ def stable_hash(self) -> bytes:
163179
class RowSeriesInputFieldV1:
164180
"""
165181
Used to handle functions that logically take a series as an input, but handled via a string protocol in the remote function.
182+
183+
For these, the serialization is dependent on index metadata, which must be provided by the caller.
166184
"""
167185

168186
_PROTOCOL_ID = "row_series_input_v1"
@@ -180,6 +198,11 @@ def bf_type(self) -> bigframes.dtypes.Dtype:
180198
def sql_type(self) -> str:
181199
return "STRING"
182200

201+
@property
202+
def emulating_type(self) -> DirectScalarType:
203+
# Regardless of list inner type, string is used to emulate the list in the remote function.
204+
return DirectScalarType(str)
205+
183206
def stable_hash(self) -> bytes:
184207
hash_val = hashlib.md5()
185208
hash_val.update(self._PROTOCOL_ID.encode())
@@ -196,14 +219,14 @@ class UdfSignature:
196219
output: DirectScalarType | VirtualListTypeV1
197220

198221
def __post_init__(self):
199-
if any(isinstance(arg, RowSeriesInputFieldV1) for arg in self.inputs):
200-
if len(self.inputs) != 1:
201-
raise ValueError("Row processor functions must have exactly one input.")
202222
assert all(isinstance(arg, UdfArg) for arg in self.inputs)
203223
assert isinstance(self.output, (DirectScalarType, VirtualListTypeV1))
204224

205225
def to_sql_input_signature(self) -> str:
206-
return ",".join(f"{field.name} {field.sql_type}" for field in self.inputs)
226+
return ",".join(
227+
f"{field.name} {field.sql_type}"
228+
for field in self.with_devirtualize().inputs
229+
)
207230

208231
@property
209232
def protocol_metadata(self) -> str:
@@ -214,16 +237,20 @@ def protocol_metadata(self) -> str:
214237
python_output_type=self.output.py_type
215238
)
216239

240+
@property
241+
def is_virtual(self) -> bool:
242+
dtypes = (self.output,) + tuple(arg.dtype for arg in self.inputs)
243+
return not all(isinstance(dtype, DirectScalarType) for dtype in dtypes)
244+
217245
@property
218246
def is_row_processor(self) -> bool:
219-
return any(isinstance(arg, RowSeriesInputFieldV1) for arg in self.inputs)
247+
return any(isinstance(arg.dtype, RowSeriesInputFieldV1) for arg in self.inputs)
220248

221249
def with_devirtualize(self) -> UdfSignature:
222-
if isinstance(self.output, DirectScalarType):
223-
return self
224-
assert isinstance(self.output, VirtualListTypeV1)
225250
return UdfSignature(
226-
inputs=self.inputs,
251+
inputs=tuple(
252+
UdfArg(arg.name, arg.dtype.emulating_type) for arg in self.inputs
253+
),
227254
output=self.output.emulating_type,
228255
)
229256

@@ -245,7 +272,7 @@ def from_routine(cls, routine: bigquery.Routine) -> UdfSignature:
245272
if python_output_type := bigframes.functions._utils.get_python_output_type_from_bigframes_metadata(
246273
routine.description
247274
):
248-
if routine.return_type is None or bq_return_type.type_kind != "STRING":
275+
if bq_return_type.type_kind != "STRING":
249276
raise bf_formatting.create_exception_with_feedback_link(
250277
TypeError,
251278
"An explicit output_type should be provided only for a BigQuery function with STRING output.",
@@ -280,7 +307,7 @@ def from_py_signature(cls, signature: inspect.Signature):
280307
ValueError,
281308
"'input_types' was not set and parameter "
282309
f"'{parameter.name}' is missing a type annotation. "
283-
"Types are required to use @remote_function.",
310+
"Types are required to use udfs.",
284311
)
285312

286313
input_types.append(UdfArg.from_py_param(parameter))
@@ -290,18 +317,22 @@ def from_py_signature(cls, signature: inspect.Signature):
290317
ValueError,
291318
"'output_type' was not set and function is missing a "
292319
"return type annotation. Types are required to use "
293-
"@remote_function.",
320+
"udfs.",
294321
)
295322

296-
if get_origin(signature.return_annotation) is list:
297-
inner_py_type = get_args(signature.return_annotation)[0]
298-
virtual_list_output_type = VirtualListTypeV1(
299-
DirectScalarType(inner_py_type)
300-
)
301-
return cls(tuple(input_types), virtual_list_output_type)
302-
else:
303-
direct_output_type = DirectScalarType(signature.return_annotation)
304-
return cls(tuple(input_types), direct_output_type)
323+
output_type = DirectScalarType(signature.return_annotation)
324+
return cls(tuple(input_types), output_type)
325+
326+
def to_remote_function_compatible(self) -> UdfSignature:
327+
# need to virtualize list outputs
328+
if isinstance(self.output, DirectScalarType):
329+
if get_origin(self.output.py_type) is list:
330+
inner_py_type = get_args(self.output.py_type)[0]
331+
return UdfSignature(
332+
inputs=self.inputs,
333+
output=VirtualListTypeV1(DirectScalarType(inner_py_type)),
334+
)
335+
return self
305336

306337
def stable_hash(self) -> bytes:
307338
hash_val = hashlib.md5()
@@ -321,6 +352,8 @@ class BigqueryUdf:
321352
signature: UdfSignature
322353

323354
def with_devirtualize(self) -> BigqueryUdf:
355+
if not self.signature.is_virtual:
356+
return self
324357
return BigqueryUdf(
325358
routine_ref=self.routine_ref,
326359
signature=self.signature.with_devirtualize(),

0 commit comments

Comments
 (0)