Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 6c96c84

Browse files
refactor: Define sql nodes and transform
1 parent e6de52d commit 6c96c84

File tree

299 files changed

+2877
-4400
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

299 files changed

+2877
-4400
lines changed

bigframes/core/compile/compiled.py

Lines changed: 5 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
from __future__ import annotations
1515

16-
import functools
1716
import itertools
1817
import typing
1918
from typing import Literal, Optional, Sequence
@@ -27,7 +26,7 @@
2726
from google.cloud import bigquery
2827
import pyarrow as pa
2928

30-
from bigframes.core import agg_expressions
29+
from bigframes.core import agg_expressions, rewrite
3130
import bigframes.core.agg_expressions as ex_types
3231
import bigframes.core.compile.googlesql
3332
import bigframes.core.compile.ibis_compiler.aggregate_compiler as agg_compiler
@@ -38,8 +37,6 @@
3837
import bigframes.core.sql
3938
from bigframes.core.window_spec import WindowSpec
4039
import bigframes.dtypes
41-
import bigframes.operations as ops
42-
import bigframes.operations.aggregations as agg_ops
4340

4441
op_compiler = op_compilers.scalar_op_compiler
4542

@@ -424,59 +421,11 @@ def project_window_op(
424421
output_name,
425422
)
426423

427-
if expression.op.order_independent and window_spec.is_unbounded:
428-
# notably percentile_cont does not support ordering clause
429-
window_spec = window_spec.without_order()
430-
431-
# TODO: Turn this logic into a true rewriter
432-
result_expr: ex.Expression = agg_expressions.WindowExpression(
433-
expression, window_spec
424+
rewritten_expr = rewrite.simplify_complex_windows(
425+
agg_expressions.WindowExpression(expression, window_spec)
434426
)
435-
clauses: list[tuple[ex.Expression, ex.Expression]] = []
436-
if window_spec.min_periods and len(expression.inputs) > 0:
437-
if not expression.op.nulls_count_for_min_values:
438-
is_observation = ops.notnull_op.as_expr()
439-
440-
# Most operations do not count NULL values towards min_periods
441-
per_col_does_count = (
442-
ops.notnull_op.as_expr(input) for input in expression.inputs
443-
)
444-
# All inputs must be non-null for observation to count
445-
is_observation = functools.reduce(
446-
lambda x, y: ops.and_op.as_expr(x, y), per_col_does_count
447-
)
448-
observation_sentinel = ops.AsTypeOp(bigframes.dtypes.INT_DTYPE).as_expr(
449-
is_observation
450-
)
451-
observation_count_expr = agg_expressions.WindowExpression(
452-
ex_types.UnaryAggregation(agg_ops.sum_op, observation_sentinel),
453-
window_spec,
454-
)
455-
else:
456-
# Operations like count treat even NULLs as valid observations for the sake of min_periods
457-
# notnull is just used to convert null values to non-null (FALSE) values to be counted
458-
is_observation = ops.notnull_op.as_expr(expression.inputs[0])
459-
observation_count_expr = agg_expressions.WindowExpression(
460-
agg_ops.count_op.as_expr(is_observation),
461-
window_spec,
462-
)
463-
clauses.append(
464-
(
465-
ops.lt_op.as_expr(
466-
observation_count_expr, ex.const(window_spec.min_periods)
467-
),
468-
ex.const(None),
469-
)
470-
)
471-
if clauses:
472-
case_inputs = [
473-
*itertools.chain.from_iterable(clauses),
474-
ex.const(True),
475-
result_expr,
476-
]
477-
result_expr = ops.CaseWhenOp().as_expr(*case_inputs)
478-
479-
ibis_expr = op_compiler.compile_expression(result_expr, self._ibis_bindings)
427+
428+
ibis_expr = op_compiler.compile_expression(rewritten_expr, self._ibis_bindings)
480429

481430
return UnorderedIR(self._table, (*self.columns, ibis_expr.name(output_name)))
482431

bigframes/core/compile/sqlglot/compiler.py

Lines changed: 34 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@
2020
import bigframes_vendored.sqlglot.expressions as sge
2121

2222
from bigframes.core import (
23-
agg_expressions,
2423
expression,
2524
guid,
2625
identifiers,
2726
nodes,
2827
pyarrow_utils,
2928
rewrite,
29+
sql_nodes,
3030
)
3131
from bigframes.core.compile import configs
3232
import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler
@@ -104,30 +104,10 @@ def _compile_result_node(root: nodes.ResultNode) -> str:
104104
root = typing.cast(nodes.ResultNode, rewrite.defer_selection(root))
105105

106106
# Have to bind schema as the final step before compilation.
107+
# Probably, should defer even further
107108
root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))
108109

109-
selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
110-
(name, scalar_compiler.scalar_op_compiler.compile_expression(ref))
111-
for ref, name in root.output_cols
112-
)
113-
sqlglot_ir = compile_node(root.child, uid_gen).select(selected_cols)
114-
115-
if root.order_by is not None:
116-
ordering_cols = tuple(
117-
sge.Ordered(
118-
this=scalar_compiler.scalar_op_compiler.compile_expression(
119-
ordering.scalar_expression
120-
),
121-
desc=ordering.direction.is_ascending is False,
122-
nulls_first=ordering.na_last is False,
123-
)
124-
for ordering in root.order_by.all_ordering_columns
125-
)
126-
sqlglot_ir = sqlglot_ir.order_by(ordering_cols)
127-
128-
if root.limit is not None:
129-
sqlglot_ir = sqlglot_ir.limit(root.limit)
130-
110+
sqlglot_ir = compile_node(rewrite.as_sql_nodes(root), uid_gen)
131111
return sqlglot_ir.sql
132112

133113

@@ -160,6 +140,37 @@ def _compile_node(
160140
raise ValueError(f"Can't compile unrecognized node: {node}")
161141

162142

143+
@_compile_node.register
144+
def compile_sql_select(node: sql_nodes.SelectNode, child: ir.SQLGlotIR):
145+
sqlglot_ir = child
146+
if node.sorting is not None:
147+
ordering_cols = tuple(
148+
sge.Ordered(
149+
this=scalar_compiler.scalar_op_compiler.compile_expression(
150+
ordering.scalar_expression
151+
),
152+
desc=ordering.direction.is_ascending is False,
153+
nulls_first=ordering.na_last is False,
154+
)
155+
for ordering in node.sorting
156+
)
157+
sqlglot_ir = sqlglot_ir.order_by(ordering_cols)
158+
159+
projected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
160+
(
161+
cdef.id.sql,
162+
scalar_compiler.scalar_op_compiler.compile_expression(cdef.expression),
163+
)
164+
for cdef in node.selections
165+
)
166+
sqlglot_ir = sqlglot_ir.project(projected_cols)
167+
168+
if node.limit is not None:
169+
sqlglot_ir = sqlglot_ir.limit(node.limit)
170+
171+
return sqlglot_ir
172+
173+
163174
@_compile_node.register
164175
def compile_readlocal(node: nodes.ReadLocalNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
165176
pa_table = node.local_data_source.data
@@ -188,30 +199,6 @@ def compile_readtable(node: nodes.ReadTableNode, child: ir.SQLGlotIR):
188199
)
189200

190201

191-
@_compile_node.register
192-
def compile_selection(node: nodes.SelectionNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
193-
selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
194-
(id.sql, scalar_compiler.scalar_op_compiler.compile_expression(expr))
195-
for expr, id in node.input_output_pairs
196-
)
197-
return child.select(selected_cols)
198-
199-
200-
@_compile_node.register
201-
def compile_projection(node: nodes.ProjectionNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
202-
projected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
203-
(id.sql, scalar_compiler.scalar_op_compiler.compile_expression(expr))
204-
for expr, id in node.assignments
205-
)
206-
return child.project(projected_cols)
207-
208-
209-
@_compile_node.register
210-
def compile_filter(node: nodes.FilterNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
211-
condition = scalar_compiler.scalar_op_compiler.compile_expression(node.predicate)
212-
return child.filter(tuple([condition]))
213-
214-
215202
@_compile_node.register
216203
def compile_join(
217204
node: nodes.JoinNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
@@ -325,79 +312,6 @@ def compile_aggregate(node: nodes.AggregateNode, child: ir.SQLGlotIR) -> ir.SQLG
325312
return child.aggregate(aggregations, by_cols, tuple(dropna_cols))
326313

327314

328-
@_compile_node.register
329-
def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
330-
window_spec = node.window_spec
331-
result = child
332-
for cdef in node.agg_exprs:
333-
assert isinstance(cdef.expression, agg_expressions.Aggregation)
334-
if cdef.expression.op.order_independent and window_spec.is_unbounded:
335-
# notably percentile_cont does not support ordering clause
336-
window_spec = window_spec.without_order()
337-
338-
window_op = aggregate_compiler.compile_analytic(cdef.expression, window_spec)
339-
340-
inputs: tuple[sge.Expression, ...] = tuple(
341-
scalar_compiler.scalar_op_compiler.compile_expression(
342-
expression.DerefOp(column)
343-
)
344-
for column in cdef.expression.column_references
345-
)
346-
347-
clauses: list[tuple[sge.Expression, sge.Expression]] = []
348-
if window_spec.min_periods and len(inputs) > 0:
349-
if not cdef.expression.op.nulls_count_for_min_values:
350-
# Most operations do not count NULL values towards min_periods
351-
not_null_columns = [
352-
sge.Not(this=sge.Is(this=column, expression=sge.Null()))
353-
for column in inputs
354-
]
355-
# All inputs must be non-null for observation to count
356-
if not not_null_columns:
357-
is_observation_expr: sge.Expression = sge.convert(True)
358-
else:
359-
is_observation_expr = not_null_columns[0]
360-
for expr in not_null_columns[1:]:
361-
is_observation_expr = sge.And(
362-
this=is_observation_expr, expression=expr
363-
)
364-
is_observation = ir._cast(is_observation_expr, "INT64")
365-
observation_count = windows.apply_window_if_present(
366-
sge.func("SUM", is_observation), window_spec
367-
)
368-
observation_count = sge.func(
369-
"COALESCE", observation_count, sge.convert(0)
370-
)
371-
else:
372-
# Operations like count treat even NULLs as valid observations
373-
# for the sake of min_periods notnull is just used to convert
374-
# null values to non-null (FALSE) values to be counted.
375-
is_observation = ir._cast(
376-
sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
377-
"INT64",
378-
)
379-
observation_count = windows.apply_window_if_present(
380-
sge.func("COUNT", is_observation), window_spec
381-
)
382-
383-
clauses.append(
384-
(
385-
observation_count < sge.convert(window_spec.min_periods),
386-
sge.Null(),
387-
)
388-
)
389-
if clauses:
390-
when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses]
391-
window_op = sge.Case(ifs=when_expressions, default=window_op)
392-
393-
# TODO: check if we can directly window the expression.
394-
result = result.window(
395-
window_op=window_op,
396-
output_column_id=cdef.id.sql,
397-
)
398-
return result
399-
400-
401315
def _replace_unsupported_ops(node: nodes.BigFrameNode):
402316
node = nodes.bottom_up(node, rewrite.rewrite_slice)
403317
node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)

0 commit comments

Comments
 (0)