Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 5 additions & 56 deletions bigframes/core/compile/compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.
from __future__ import annotations

import functools
import itertools
import typing
from typing import Literal, Optional, Sequence
Expand All @@ -27,7 +26,7 @@
from google.cloud import bigquery
import pyarrow as pa

from bigframes.core import agg_expressions
from bigframes.core import agg_expressions, rewrite
import bigframes.core.agg_expressions as ex_types
import bigframes.core.compile.googlesql
import bigframes.core.compile.ibis_compiler.aggregate_compiler as agg_compiler
Expand All @@ -38,8 +37,6 @@
import bigframes.core.sql
from bigframes.core.window_spec import WindowSpec
import bigframes.dtypes
import bigframes.operations as ops
import bigframes.operations.aggregations as agg_ops

op_compiler = op_compilers.scalar_op_compiler

Expand Down Expand Up @@ -424,59 +421,11 @@ def project_window_op(
output_name,
)

if expression.op.order_independent and window_spec.is_unbounded:
# notably percentile_cont does not support ordering clause
window_spec = window_spec.without_order()

# TODO: Turn this logic into a true rewriter
result_expr: ex.Expression = agg_expressions.WindowExpression(
expression, window_spec
rewritten_expr = rewrite.simplify_complex_windows(
agg_expressions.WindowExpression(expression, window_spec)
)
clauses: list[tuple[ex.Expression, ex.Expression]] = []
if window_spec.min_periods and len(expression.inputs) > 0:
if not expression.op.nulls_count_for_min_values:
is_observation = ops.notnull_op.as_expr()

# Most operations do not count NULL values towards min_periods
per_col_does_count = (
ops.notnull_op.as_expr(input) for input in expression.inputs
)
# All inputs must be non-null for observation to count
is_observation = functools.reduce(
lambda x, y: ops.and_op.as_expr(x, y), per_col_does_count
)
observation_sentinel = ops.AsTypeOp(bigframes.dtypes.INT_DTYPE).as_expr(
is_observation
)
observation_count_expr = agg_expressions.WindowExpression(
ex_types.UnaryAggregation(agg_ops.sum_op, observation_sentinel),
window_spec,
)
else:
# Operations like count treat even NULLs as valid observations for the sake of min_periods
# notnull is just used to convert null values to non-null (FALSE) values to be counted
is_observation = ops.notnull_op.as_expr(expression.inputs[0])
observation_count_expr = agg_expressions.WindowExpression(
agg_ops.count_op.as_expr(is_observation),
window_spec,
)
clauses.append(
(
ops.lt_op.as_expr(
observation_count_expr, ex.const(window_spec.min_periods)
),
ex.const(None),
)
)
if clauses:
case_inputs = [
*itertools.chain.from_iterable(clauses),
ex.const(True),
result_expr,
]
result_expr = ops.CaseWhenOp().as_expr(*case_inputs)

ibis_expr = op_compiler.compile_expression(result_expr, self._ibis_bindings)

ibis_expr = op_compiler.compile_expression(rewritten_expr, self._ibis_bindings)

return UnorderedIR(self._table, (*self.columns, ibis_expr.name(output_name)))

Expand Down
154 changes: 34 additions & 120 deletions bigframes/core/compile/sqlglot/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@
import bigframes_vendored.sqlglot.expressions as sge

from bigframes.core import (
agg_expressions,
expression,
guid,
identifiers,
nodes,
pyarrow_utils,
rewrite,
sql_nodes,
)
from bigframes.core.compile import configs
import bigframes.core.compile.sqlglot.aggregate_compiler as aggregate_compiler
Expand Down Expand Up @@ -104,30 +104,10 @@ def _compile_result_node(root: nodes.ResultNode) -> str:
root = typing.cast(nodes.ResultNode, rewrite.defer_selection(root))

# Have to bind schema as the final step before compilation.
# Probably, should defer even further
root = typing.cast(nodes.ResultNode, schema_binding.bind_schema_to_tree(root))

selected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
(name, scalar_compiler.scalar_op_compiler.compile_expression(ref))
for ref, name in root.output_cols
)
sqlglot_ir = compile_node(root.child, uid_gen).select(selected_cols)

if root.order_by is not None:
ordering_cols = tuple(
sge.Ordered(
this=scalar_compiler.scalar_op_compiler.compile_expression(
ordering.scalar_expression
),
desc=ordering.direction.is_ascending is False,
nulls_first=ordering.na_last is False,
)
for ordering in root.order_by.all_ordering_columns
)
sqlglot_ir = sqlglot_ir.order_by(ordering_cols)

if root.limit is not None:
sqlglot_ir = sqlglot_ir.limit(root.limit)

sqlglot_ir = compile_node(rewrite.as_sql_nodes(root), uid_gen)
return sqlglot_ir.sql


Expand Down Expand Up @@ -160,6 +140,37 @@ def _compile_node(
raise ValueError(f"Can't compile unrecognized node: {node}")


@_compile_node.register
def compile_sql_select(
    node: sql_nodes.SelectNode, child: ir.SQLGlotIR
) -> ir.SQLGlotIR:
    """Compile a SelectNode into ORDER BY, projection, and LIMIT clauses.

    Clauses are applied to the child IR in this order:
    1. ORDER BY (when ``node.sorting`` is present) — applied before the
       projection, so sort expressions are compiled against the child's
       columns rather than the projected output.
    2. The projection itself (``node.selections``).
    3. LIMIT (when ``node.limit`` is present).

    Args:
        node: The SQL select node carrying sorting, selections, and limit.
        child: Compiled IR for the node's input.

    Returns:
        The child IR with ordering, projection, and limit applied.
    """
    sqlglot_ir = child

    if node.sorting is not None:
        ordering_cols = tuple(
            sge.Ordered(
                this=scalar_compiler.scalar_op_compiler.compile_expression(
                    ordering.scalar_expression
                ),
                # sqlglot expects desc/nulls_first booleans; our ordering
                # spec expresses the inverse notions (ascending / na_last).
                desc=ordering.direction.is_ascending is False,
                nulls_first=ordering.na_last is False,
            )
            for ordering in node.sorting
        )
        sqlglot_ir = sqlglot_ir.order_by(ordering_cols)

    projected_cols: tuple[tuple[str, sge.Expression], ...] = tuple(
        (
            cdef.id.sql,
            scalar_compiler.scalar_op_compiler.compile_expression(cdef.expression),
        )
        for cdef in node.selections
    )
    sqlglot_ir = sqlglot_ir.select(projected_cols)

    if node.limit is not None:
        sqlglot_ir = sqlglot_ir.limit(node.limit)

    return sqlglot_ir


@_compile_node.register
def compile_readlocal(node: nodes.ReadLocalNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
pa_table = node.local_data_source.data
Expand Down Expand Up @@ -188,30 +199,6 @@ def compile_readtable(node: nodes.ReadTableNode, child: ir.SQLGlotIR):
)


@_compile_node.register
def compile_selection(node: nodes.SelectionNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
    """Compile a selection node by projecting each (expression, id) pair."""
    compile_expr = scalar_compiler.scalar_op_compiler.compile_expression
    name_expr_pairs: list[tuple[str, sge.Expression]] = []
    for source_expr, output_id in node.input_output_pairs:
        name_expr_pairs.append((output_id.sql, compile_expr(source_expr)))
    return child.select(tuple(name_expr_pairs))


@_compile_node.register
def compile_projection(node: nodes.ProjectionNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
    """Compile a projection node by adding each assignment as a named column."""
    compile_expr = scalar_compiler.scalar_op_compiler.compile_expression
    new_columns: list[tuple[str, sge.Expression]] = []
    for assigned_expr, target_id in node.assignments:
        new_columns.append((target_id.sql, compile_expr(assigned_expr)))
    return child.project(tuple(new_columns))


@_compile_node.register
def compile_filter(node: nodes.FilterNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
    """Compile a filter node by applying its predicate to the child IR."""
    predicate_sql = scalar_compiler.scalar_op_compiler.compile_expression(
        node.predicate
    )
    return child.filter((predicate_sql,))


@_compile_node.register
def compile_join(
node: nodes.JoinNode, left: ir.SQLGlotIR, right: ir.SQLGlotIR
Expand Down Expand Up @@ -325,79 +312,6 @@ def compile_aggregate(node: nodes.AggregateNode, child: ir.SQLGlotIR) -> ir.SQLG
return child.aggregate(aggregations, by_cols, tuple(dropna_cols))


@_compile_node.register
def compile_window(node: nodes.WindowOpNode, child: ir.SQLGlotIR) -> ir.SQLGlotIR:
    """Compile a window-op node: one analytic expression per agg definition.

    For each aggregation in ``node.agg_exprs`` this builds the windowed SQL
    expression and, when ``min_periods`` is set on the window spec, wraps it
    in a CASE that yields NULL until enough observations are in the window.
    """
    window_spec = node.window_spec
    result = child
    for cdef in node.agg_exprs:
        assert isinstance(cdef.expression, agg_expressions.Aggregation)
        if cdef.expression.op.order_independent and window_spec.is_unbounded:
            # notably percentile_cont does not support ordering clause
            # NOTE(review): this rebinds window_spec for all LATER loop
            # iterations too, not just the current aggregation — presumably
            # intentional since the spec is shared, but worth confirming.
            window_spec = window_spec.without_order()

        window_op = aggregate_compiler.compile_analytic(cdef.expression, window_spec)

        # Compiled SQL for each input column the aggregation references;
        # used below to build the null-observation test for min_periods.
        inputs: tuple[sge.Expression, ...] = tuple(
            scalar_compiler.scalar_op_compiler.compile_expression(
                expression.DerefOp(column)
            )
            for column in cdef.expression.column_references
        )

        # (condition, value) pairs for the CASE wrapper; only populated
        # when a min_periods threshold applies.
        clauses: list[tuple[sge.Expression, sge.Expression]] = []
        if window_spec.min_periods and len(inputs) > 0:
            if not cdef.expression.op.nulls_count_for_min_values:
                # Most operations do not count NULL values towards min_periods
                not_null_columns = [
                    sge.Not(this=sge.Is(this=column, expression=sge.Null()))
                    for column in inputs
                ]
                # All inputs must be non-null for observation to count
                # NOTE(review): not_null_columns cannot be empty here since
                # len(inputs) > 0 was checked above — the True branch below
                # looks unreachable.
                if not not_null_columns:
                    is_observation_expr: sge.Expression = sge.convert(True)
                else:
                    # AND the per-column not-null checks together.
                    is_observation_expr = not_null_columns[0]
                    for expr in not_null_columns[1:]:
                        is_observation_expr = sge.And(
                            this=is_observation_expr, expression=expr
                        )
                # Cast the boolean to INT64 so SUM over the window counts
                # qualifying observations.
                is_observation = ir._cast(is_observation_expr, "INT64")
                observation_count = windows.apply_window_if_present(
                    sge.func("SUM", is_observation), window_spec
                )
                # SUM over an all-NULL window is NULL; treat that as zero
                # observations.
                observation_count = sge.func(
                    "COALESCE", observation_count, sge.convert(0)
                )
            else:
                # Operations like count treat even NULLs as valid observations
                # for the sake of min_periods notnull is just used to convert
                # null values to non-null (FALSE) values to be counted.
                is_observation = ir._cast(
                    sge.Not(this=sge.Is(this=inputs[0], expression=sge.Null())),
                    "INT64",
                )
                observation_count = windows.apply_window_if_present(
                    sge.func("COUNT", is_observation), window_spec
                )

            # Below the threshold, the window result is NULL.
            clauses.append(
                (
                    observation_count < sge.convert(window_spec.min_periods),
                    sge.Null(),
                )
            )
        if clauses:
            # Wrap the analytic expression: WHEN below-threshold THEN NULL
            # ELSE the windowed aggregation.
            when_expressions = [sge.When(this=cond, true=res) for cond, res in clauses]
            window_op = sge.Case(ifs=when_expressions, default=window_op)

        # TODO: check if we can directly window the expression.
        result = result.window(
            window_op=window_op,
            output_column_id=cdef.id.sql,
        )
    return result


def _replace_unsupported_ops(node: nodes.BigFrameNode):
node = nodes.bottom_up(node, rewrite.rewrite_slice)
node = nodes.bottom_up(node, rewrite.rewrite_range_rolling)
Expand Down
Loading
Loading