Skip to content

Commit a10a813

Browse files
refactor(bigframes): Introduce GoogleSqlScalarOp (#17037)
Thank you for opening a Pull Request! Before submitting your PR, there are a few things you can do to make sure it goes smoothly: - [ ] Make sure to open an issue as a [bug/issue](https://github.com/googleapis/google-cloud-python/issues) before writing your code! That way we can discuss the change, evaluate designs, and agree on the general idea - [ ] Ensure the tests and linter pass - [ ] Code coverage does not decrease (if any source code was changed) - [ ] Appropriate docs were updated (if necessary) Fixes #<issue_number_goes_here> 🦕
1 parent 7f17ff1 commit a10a813

24 files changed

Lines changed: 351 additions & 94 deletions

File tree

packages/bigframes/bigframes/bigquery/_operations/geo.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def st_area(
9999
bigframes.pandas.Series:
100100
Series of float representing the areas.
101101
"""
102-
series = series._apply_unary_op(ops.geo_area_op)
102+
series = series._apply_nary_op(ops.googlesql.ST_AREA, [])
103103
series.name = None
104104
return series
105105

@@ -223,7 +223,7 @@ def st_centroid(
223223
bigframes.pandas.Series:
224224
A series of geography objects representing the centroids.
225225
"""
226-
series = series._apply_unary_op(ops.geo_st_centroid_op)
226+
series = series._apply_nary_op(ops.googlesql.ST_CENTROID, [])
227227
series.name = None
228228
return series
229229

@@ -753,6 +753,4 @@ def st_simplify(
753753
Returns:
754754
a Series containing the simplified GEOGRAPHY data.
755755
"""
756-
return geography._apply_unary_op(
757-
ops.GeoStSimplifyOp(tolerance_meters=tolerance_meters)
758-
)
756+
return geography._apply_nary_op(ops.googlesql.ST_SIMPLIFY, [tolerance_meters])

packages/bigframes/bigframes/bigquery/_operations/mathematical.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import bigframes.core.expression
2121
from bigframes import dtypes
2222
from bigframes import operations as ops
23+
from bigframes.operations import googlesql
2324

2425

2526
def rand() -> bigframes.core.col.Expression:
@@ -47,12 +48,9 @@ def rand() -> bigframes.core.col.Expression:
4748
:func:`~bigframes.pandas.DataFrame.assign` and other methods. See
4849
:func:`bigframes.pandas.col`.
4950
"""
50-
op = ops.SqlScalarOp(
51-
_output_type=dtypes.FLOAT_DTYPE,
52-
sql_template="RAND()",
53-
is_deterministic=False,
51+
return bigframes.core.col.Expression(
52+
bigframes.core.expression.OpExpression(googlesql.RAND, ())
5453
)
55-
return bigframes.core.col.Expression(bigframes.core.expression.OpExpression(op, ()))
5654

5755

5856
def hparam_range(min: float, max: float) -> bigframes.core.col.Expression:

packages/bigframes/bigframes/core/compile/ibis_compiler/ibis_compiler.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
import bigframes.core.nodes as nodes
3131
import bigframes.core.ordering as bf_ordering
3232
import bigframes.core.rewrite as rewrites
33+
import bigframes.core.rewrite.schema_binding as schema_binding
3334
from bigframes import dtypes, operations
3435
from bigframes.core import bq_data, expression, pyarrow_utils
3536
from bigframes.core.logging import data_types as data_type_logger
@@ -59,6 +60,11 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
5960
if request.sort_rows:
6061
result_node = cast(nodes.ResultNode, rewrites.column_pruning(result_node))
6162
encoded_type_refs = data_type_logger.encode_type_refs(result_node)
63+
# Have to bind schema as the final step before compilation.
64+
# Probably, should defer even further
65+
result_node = typing.cast(
66+
nodes.ResultNode, schema_binding.bind_schema_to_tree(result_node)
67+
)
6268
sql = compile_result_node(result_node)
6369
return configs.CompileResult(
6470
sql,
@@ -72,6 +78,11 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
7278
result_node = cast(nodes.ResultNode, rewrites.column_pruning(result_node))
7379
result_node = cast(nodes.ResultNode, rewrites.defer_selection(result_node))
7480
encoded_type_refs = data_type_logger.encode_type_refs(result_node)
81+
# Have to bind schema as the final step before compilation.
82+
# Probably, should defer even further
83+
result_node = typing.cast(
84+
nodes.ResultNode, schema_binding.bind_schema_to_tree(result_node)
85+
)
7586
sql = compile_result_node(result_node)
7687
# Return the ordering iff no extra columns are needed to define the row order
7788
if ordering is not None:

packages/bigframes/bigframes/core/compile/ibis_compiler/operations/geo_ops.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,6 @@
3030

3131

3232
# Geo Ops
33-
@register_unary_op(ops.geo_area_op)
34-
def geo_area_op_impl(x: ibis_types.Value):
35-
return cast(ibis_types.GeoSpatialValue, x).area()
36-
37-
3833
@register_unary_op(ops.geo_st_astext_op)
3934
def geo_st_astext_op_impl(x: ibis_types.Value):
4035
return cast(ibis_types.GeoSpatialValue, x).as_text()
@@ -55,11 +50,6 @@ def geo_st_buffer_op_impl(x: ibis_types.Value, op: ops.GeoStBufferOp):
5550
)
5651

5752

58-
@register_unary_op(ops.geo_st_centroid_op, pass_op=False)
59-
def geo_st_centroid_op_impl(x: ibis_types.Value):
60-
return cast(ibis_types.GeoSpatialValue, x).centroid()
61-
62-
6353
@register_unary_op(ops.geo_st_convexhull_op, pass_op=False)
6454
def geo_st_convexhull_op_impl(x: ibis_types.Value):
6555
return st_convexhull(x)
@@ -132,12 +122,6 @@ def geo_st_regionstats_op_impl(
132122
).to_expr()
133123

134124

135-
@register_unary_op(ops.GeoStSimplifyOp, pass_op=True)
136-
def st_simplify_op_impl(x: ibis_types.Value, op: ops.GeoStSimplifyOp):
137-
x = cast(ibis_types.GeoSpatialValue, x)
138-
return st_simplify(x, op.tolerance_meters)
139-
140-
141125
@register_unary_op(ops.geo_x_op)
142126
def geo_x_op_impl(x: ibis_types.Value):
143127
return cast(ibis_types.GeoSpatialValue, x).x()

packages/bigframes/bigframes/core/compile/ibis_compiler/scalar_op_compiler.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,13 @@
2121
from typing import TYPE_CHECKING
2222

2323
import bigframes_vendored.ibis
24+
import bigframes_vendored.ibis.expr.operations.generic as ibis_generic
2425
import bigframes_vendored.ibis.expr.types as ibis_types
2526

2627
import bigframes.core.compile.ibis_types
2728
import bigframes.core.expression as ex
2829
from bigframes.core import agg_expressions, ordering
30+
from bigframes.operations import googlesql as gsql_ops
2931
from bigframes.operations import numeric_ops
3032

3133
if TYPE_CHECKING:
@@ -92,8 +94,20 @@ def _(
9294
self.compile_expression(sub_expr, bindings)
9395
for sub_expr in expression.inputs
9496
]
97+
if isinstance(expression.op, gsql_ops.GoogleSqlScalarOp):
98+
return googlesql_scalar_op_impl(
99+
*inputs, op=expression.op, output_type=expression.output_type
100+
)
95101
return self.compile_row_op(expression.op, inputs)
96102

103+
@compile_expression.register
104+
def _(
105+
self,
106+
expression: ex.OmittedArg,
107+
bindings: typing.Dict[str, ibis_types.Value],
108+
) -> ibis_types.Value:
109+
return bigframes_vendored.ibis.omitted()
110+
97111
def compile_row_op(
98112
self, op: ops.RowOp, inputs: typing.Sequence[ibis_types.Value]
99113
) -> ibis_types.Value:
@@ -278,3 +292,39 @@ def isnanornull(arg):
278292
@scalar_op_compiler.register_unary_op(numeric_ops.isfinite_op)
279293
def isfinite(arg):
280294
return arg.isinf().negate() & arg.isnan().negate()
295+
296+
297+
def googlesql_scalar_op_impl(
298+
*operands: ibis_types.Value, op: ops.GoogleSqlScalarOp, output_type
299+
):
300+
final_operands: list[ibis_types.Value] = []
301+
arg_templates = []
302+
for i, operand in enumerate(operands):
303+
if i < len(op.args):
304+
arg_spec = op.args[i]
305+
else:
306+
assert op.args[-1].is_vararg, (
307+
f"Too many arguments, for {op.sql_name}, expected {len(op.args)}"
308+
)
309+
arg_spec = op.args[-1]
310+
if isinstance(operand.op(), ibis_generic.OmittedArg):
311+
assert arg_spec.optional, f"Argument omitted, but not optional"
312+
continue
313+
314+
target_idx = len(final_operands)
315+
final_operands.append(operand)
316+
if arg_spec.arg_name:
317+
arg_templates.append(f"{arg_spec.arg_name} => {{{target_idx}}}")
318+
else:
319+
arg_templates.append(f"{{{target_idx}}}")
320+
args_template = ", ".join(arg_templates)
321+
sql_template = f"{op.sql_name}({args_template})"
322+
return ibis_generic.SqlScalar(
323+
sql_template,
324+
values=tuple(
325+
typing.cast(ibis_generic.Value, expr.op()) for expr in final_operands
326+
),
327+
output_type=bigframes.core.compile.ibis_types.bigframes_dtype_to_ibis_dtype(
328+
output_type
329+
),
330+
).to_expr()

packages/bigframes/bigframes/core/compile/sqlglot/expression_compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,10 @@ def _(self, expr: agg_exprs.WindowExpression) -> sge.Expression:
9090

9191
@compile_expression.register
9292
def _(self, expr: ex.OpExpression) -> sge.Expression:
93-
# Non-recursively compiles the children scalar expressions.
9493
inputs = tuple(
9594
TypedExpr(self.compile_expression(sub_expr), sub_expr.output_type)
95+
if not isinstance(sub_expr, ex.OmittedArg)
96+
else TypedExpr(sge.Null(), None, is_omitted=True)
9697
for sub_expr in expr.inputs
9798
)
9899
return self.compile_row_op(expr.op, inputs)

packages/bigframes/bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,27 @@ def _(expr: TypedExpr) -> sge.Expression:
8282
return sge.BitwiseNot(this=sge.paren(expr.expr))
8383

8484

85+
@register_nary_op(ops.GoogleSqlScalarOp, pass_op=True)
86+
def _(*operands: TypedExpr, op: ops.GoogleSqlScalarOp) -> sge.Expression:
87+
args: list[sge.Expression] = []
88+
for i, operand in enumerate(operands):
89+
if i < len(op.args):
90+
arg_spec = op.args[i]
91+
else:
92+
assert op.args[-1].is_vararg, (
93+
f"Too many arguments, for {op.sql_name}, expected {len(op.args)}"
94+
)
95+
arg_spec = op.args[-1]
96+
if operand.is_omitted:
97+
assert arg_spec.optional, f"Argument omitted, but not optional"
98+
continue
99+
elif arg_spec.arg_name:
100+
args.append(sge.Kwarg(this=arg_spec.arg_name, expression=operand.expr))
101+
else:
102+
args.append(operand.expr)
103+
return sg.func(op.sql_name, *args)
104+
105+
85106
@register_nary_op(ops.SqlScalarOp, pass_op=True)
86107
def _(*operands: TypedExpr, op: ops.SqlScalarOp) -> sge.Expression:
87108
return sg.parse_one(

packages/bigframes/bigframes/core/compile/sqlglot/expressions/geo_ops.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,6 @@
2424
register_binary_op = expression_compiler.expression_compiler.register_binary_op
2525

2626

27-
@register_unary_op(ops.geo_area_op)
28-
def _(expr: TypedExpr) -> sge.Expression:
29-
return sge.func("ST_AREA", expr.expr)
30-
31-
3227
@register_unary_op(ops.geo_st_astext_op)
3328
def _(expr: TypedExpr) -> sge.Expression:
3429
return sge.func("ST_ASTEXT", expr.expr)
@@ -50,11 +45,6 @@ def _(expr: TypedExpr, op: ops.GeoStBufferOp) -> sge.Expression:
5045
)
5146

5247

53-
@register_unary_op(ops.geo_st_centroid_op)
54-
def _(expr: TypedExpr) -> sge.Expression:
55-
return sge.func("ST_CENTROID", expr.expr)
56-
57-
5848
@register_unary_op(ops.geo_st_convexhull_op)
5949
def _(expr: TypedExpr) -> sge.Expression:
6050
return sge.func("ST_CONVEXHULL", expr.expr)
@@ -97,15 +87,6 @@ def _(
9787
return sge.func("ST_REGIONSTATS", *args)
9888

9989

100-
@register_unary_op(ops.GeoStSimplifyOp, pass_op=True)
101-
def _(expr: TypedExpr, op: ops.GeoStSimplifyOp) -> sge.Expression:
102-
return sge.func(
103-
"ST_SIMPLIFY",
104-
expr.expr,
105-
sge.convert(op.tolerance_meters),
106-
)
107-
108-
10990
@register_unary_op(ops.geo_x_op)
11091
def _(expr: TypedExpr) -> sge.Expression:
11192
return sge.func("ST_X", expr.expr)

packages/bigframes/bigframes/core/compile/sqlglot/expressions/typed_expr.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ class TypedExpr:
2525

2626
expr: sge.Expression
2727
dtype: dtypes.ExpressionType
28+
29+
# kludge to support optional args in argument lists
30+
is_omitted: bool = False

packages/bigframes/bigframes/core/expression.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,56 @@ def output_type(self) -> dtypes.ExpressionType:
364364
return self.dtype
365365

366366

367+
@dataclasses.dataclass(frozen=True)
368+
class OmittedArg(Expression):
369+
"""Represents an omitted optional arg used calling a function."""
370+
371+
@property
372+
def free_variables(self) -> typing.Tuple[Hashable, ...]:
373+
return ()
374+
375+
@property
376+
def is_const(self) -> bool:
377+
return True
378+
379+
@property
380+
def column_references(self) -> typing.Tuple[ids.ColumnId, ...]:
381+
return ()
382+
383+
@property
384+
def is_resolved(self):
385+
return True # vacuously
386+
387+
@property
388+
def output_type(self) -> dtypes.ExpressionType:
389+
return None
390+
391+
def bind_refs(
392+
self,
393+
bindings: Mapping[ids.ColumnId, Expression],
394+
allow_partial_bindings: bool = False,
395+
) -> OmittedArg:
396+
return self
397+
398+
def bind_variables(
399+
self,
400+
bindings: Mapping[Hashable, Expression],
401+
allow_partial_bindings: bool = False,
402+
) -> Expression:
403+
return self
404+
405+
@property
406+
def is_bijective(self) -> bool:
407+
return True
408+
409+
@property
410+
def is_identity(self) -> bool:
411+
return True
412+
413+
def transform_children(self, t: Callable[[Expression], Expression]) -> Expression:
414+
return self
415+
416+
367417
@dataclasses.dataclass(frozen=True)
368418
class OpExpression(Expression):
369419
"""An expression representing a scalar operation applied to 1 or more argument sub-expressions."""

0 commit comments

Comments
 (0)