Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/bigframes/bigframes/core/array_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,7 @@ def relational_join(
for l_col, r_col in conditions
),
type=type,
nulls_equal=True, # pandas semantics
propogate_order=propogate_order or self.session._strictly_ordered,
)
return ArrayValue(join_node), (l_mapping, r_mapping)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
# Need to do this before replacing unsupported ops, as that will rewrite slice ops
result_node = rewrites.pull_up_limits(result_node)
result_node = _replace_unsupported_ops(result_node)
result_node = result_node.bottom_up(rewrites.simplify_join)
# prune before pulling up order to avoid unnecessary row_number() ops
result_node = cast(nodes.ResultNode, rewrites.column_pruning(result_node))
result_node = rewrites.defer_order(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def compile_sql(request: configs.CompileRequest) -> configs.CompileResult:
# Need to do this before replacing unsupported ops, as that will rewrite slice ops
result_node = rewrite.pull_up_limits(result_node)
result_node = _replace_unsupported_ops(result_node)
result_node = result_node.bottom_up(rewrite.simplify_join)
# prune before pulling up order to avoid unnecessary row_number() ops
result_node = typing.cast(nodes.ResultNode, rewrite.column_pruning(result_node))
result_node = rewrite.defer_order(
Expand Down
3 changes: 3 additions & 0 deletions packages/bigframes/bigframes/core/local_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def to_arrow(
else:
return schema, batches

def is_nullable(self, column_id: identifiers.ColumnId) -> bool:
    """Report whether the column ``column_id`` currently holds any nulls.

    The answer is derived from the stored data itself (its null count),
    not from a declared schema.

    Args:
        column_id: Identifier passed to ``self.data.column`` to look up
            the column.

    Returns:
        True if the column's ``null_count`` is greater than zero,
        otherwise False.
    """
    column = self.data.column(column_id)
    return column.null_count > 0

def to_pyarrow_table(
self,
*,
Expand Down
22 changes: 10 additions & 12 deletions packages/bigframes/bigframes/core/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,7 @@ class InNode(BigFrameNode, AdditiveNode):
right_child: BigFrameNode
left_col: ex.DerefOp
indicator_col: identifiers.ColumnId
nulls_equal: bool = True
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment on what this field means? It's not very obvious to me

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added comment


def _validate(self):
assert len(self.right_child.fields) == 1
Expand Down Expand Up @@ -271,10 +272,7 @@ def additive_base(self) -> BigFrameNode:

@property
def joins_nulls(self) -> bool:
left_nullable = self.left_child.field_by_id[self.left_col.id].nullable
# assumption: right side has one column
right_nullable = self.right_child.fields[0].nullable
return left_nullable or right_nullable
return self.nulls_equal

@property
def _node_expressions(self):
Expand Down Expand Up @@ -316,6 +314,8 @@ class JoinNode(BigFrameNode):
right_child: BigFrameNode
conditions: typing.Tuple[typing.Tuple[ex.DerefOp, ex.DerefOp], ...]
type: typing.Literal["inner", "outer", "left", "right", "cross"]
# pandas treats nulls as equal, sql does not
nulls_equal: bool
propogate_order: bool

def _validate(self):
Expand Down Expand Up @@ -355,13 +355,7 @@ def fields(self) -> Sequence[Field]:

@property
def joins_nulls(self) -> bool:
for left_ref, right_ref in self.conditions:
if (
self.left_child.field_by_id[left_ref.id].nullable
and self.right_child.field_by_id[right_ref.id].nullable
):
return True
return False
return self.nulls_equal

@functools.cached_property
def variables_introduced(self) -> int:
Expand Down Expand Up @@ -675,7 +669,11 @@ class ReadLocalNode(LeafNode):
@property
def fields(self) -> Sequence[Field]:
fields = tuple(
Field(col_id, self.local_data_source.schema.get_type(source_id))
Field(
col_id,
self.local_data_source.schema.get_type(source_id),
nullable=self.local_data_source.is_nullable(source_id),
)
for col_id, source_id in self.scan_list.items
)

Expand Down
2 changes: 2 additions & 0 deletions packages/bigframes/bigframes/core/rewrite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
rewrite_range_rolling,
simplify_complex_windows,
)
from bigframes.core.rewrite.nullity import simplify_join

__all__ = [
"as_sql_nodes",
Expand All @@ -55,4 +56,5 @@
"defer_selection",
"simplify_complex_windows",
"lower_udfs",
"simplify_join",
]
42 changes: 42 additions & 0 deletions packages/bigframes/bigframes/core/rewrite/nullity.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from bigframes.core import nodes
import dataclasses


def simplify_join(node: nodes.BigFrameNode) -> nodes.BigFrameNode:
    """Turn off null-matching on joins whose keys provably do not need it.

    pandas treats nulls as equal when matching keys, SQL does not, so nodes
    carry ``nulls_equal=True`` by default. When the nullability of the key
    columns shows that the distinction cannot matter, the flag is cleared.

    Intended to be applied per node (e.g. via ``bottom_up``); nodes other
    than ``JoinNode`` / ``InNode`` are returned untouched.

    Args:
        node: Any plan node.

    Returns:
        ``node`` itself when it is not a join-like node or when null
        matching may still be required; otherwise a copy of ``node`` with
        ``nulls_equal=False``.
    """
    if isinstance(node, nodes.JoinNode):
        # A null key can only pair with another null key, so null matching
        # matters only if some condition is nullable on *both* sides.
        # NOTE: a more granular option would be to always clear nulls_equal
        # and coalesce just the nullable key pairs.
        needs_null_matching = any(
            node.left_child.field_by_id[left_ref.id].nullable
            and node.right_child.field_by_id[right_ref.id].nullable
            for left_ref, right_ref in node.conditions
        )
    elif isinstance(node, nodes.InNode):
        # SQL `IN` yields NULL (not FALSE) when the probe value is NULL, and
        # also when nothing matches but the searched set contains a NULL.
        # Either side being nullable is therefore enough to require the
        # null-aware form, so this check is `or`, unlike the join case.
        # The right child is assumed to expose exactly one column.
        needs_null_matching = (
            node.left_child.field_by_id[node.left_col.id].nullable
            or node.right_child.fields[0].nullable
        )
    else:
        return node

    if needs_null_matching:
        return node
    return dataclasses.replace(node, nulls_equal=False)
106 changes: 53 additions & 53 deletions packages/bigframes/tests/unit/core/compile/sqlglot/tpch/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,80 +24,80 @@

TPCH_SCHEMAS = {
"LINEITEM": [
bigquery.SchemaField("L_ORDERKEY", "INTEGER"),
bigquery.SchemaField("L_PARTKEY", "INTEGER"),
bigquery.SchemaField("L_SUPPKEY", "INTEGER"),
bigquery.SchemaField("L_LINENUMBER", "INTEGER"),
bigquery.SchemaField("L_QUANTITY", "FLOAT"),
bigquery.SchemaField("L_EXTENDEDPRICE", "FLOAT"),
bigquery.SchemaField("L_DISCOUNT", "FLOAT"),
bigquery.SchemaField("L_TAX", "FLOAT"),
bigquery.SchemaField("L_RETURNFLAG", "STRING"),
bigquery.SchemaField("L_LINESTATUS", "STRING"),
bigquery.SchemaField("L_SHIPDATE", "DATE"),
bigquery.SchemaField("L_COMMITDATE", "DATE"),
bigquery.SchemaField("L_RECEIPTDATE", "DATE"),
bigquery.SchemaField("L_SHIPINSTRUCT", "STRING"),
bigquery.SchemaField("L_SHIPMODE", "STRING"),
bigquery.SchemaField("L_ORDERKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("L_PARTKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("L_SUPPKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("L_LINENUMBER", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("L_QUANTITY", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("L_EXTENDEDPRICE", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("L_DISCOUNT", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("L_TAX", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("L_RETURNFLAG", "STRING", mode="REQUIRED"),
bigquery.SchemaField("L_LINESTATUS", "STRING", mode="REQUIRED"),
bigquery.SchemaField("L_SHIPDATE", "DATE", mode="REQUIRED"),
bigquery.SchemaField("L_COMMITDATE", "DATE", mode="REQUIRED"),
bigquery.SchemaField("L_RECEIPTDATE", "DATE", mode="REQUIRED"),
bigquery.SchemaField("L_SHIPINSTRUCT", "STRING", mode="REQUIRED"),
bigquery.SchemaField("L_SHIPMODE", "STRING", mode="REQUIRED"),
bigquery.SchemaField("L_COMMENT", "STRING"),
],
"ORDERS": [
bigquery.SchemaField("O_ORDERKEY", "INTEGER"),
bigquery.SchemaField("O_CUSTKEY", "INTEGER"),
bigquery.SchemaField("O_ORDERSTATUS", "STRING"),
bigquery.SchemaField("O_TOTALPRICE", "FLOAT"),
bigquery.SchemaField("O_ORDERDATE", "DATE"),
bigquery.SchemaField("O_ORDERPRIORITY", "STRING"),
bigquery.SchemaField("O_CLERK", "STRING"),
bigquery.SchemaField("O_SHIPPRIORITY", "INTEGER"),
bigquery.SchemaField("O_ORDERKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("O_CUSTKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("O_ORDERSTATUS", "STRING", mode="REQUIRED"),
bigquery.SchemaField("O_TOTALPRICE", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("O_ORDERDATE", "DATE", mode="REQUIRED"),
bigquery.SchemaField("O_ORDERPRIORITY", "STRING", mode="REQUIRED"),
bigquery.SchemaField("O_CLERK", "STRING", mode="REQUIRED"),
bigquery.SchemaField("O_SHIPPRIORITY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("O_COMMENT", "STRING"),
],
"PART": [
bigquery.SchemaField("P_PARTKEY", "INTEGER"),
bigquery.SchemaField("P_NAME", "STRING"),
bigquery.SchemaField("P_MFGR", "STRING"),
bigquery.SchemaField("P_BRAND", "STRING"),
bigquery.SchemaField("P_TYPE", "STRING"),
bigquery.SchemaField("P_SIZE", "INTEGER"),
bigquery.SchemaField("P_CONTAINER", "STRING"),
bigquery.SchemaField("P_RETAILPRICE", "FLOAT"),
bigquery.SchemaField("P_PARTKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("P_NAME", "STRING", mode="REQUIRED"),
bigquery.SchemaField("P_MFGR", "STRING", mode="REQUIRED"),
bigquery.SchemaField("P_BRAND", "STRING", mode="REQUIRED"),
bigquery.SchemaField("P_TYPE", "STRING", mode="REQUIRED"),
bigquery.SchemaField("P_SIZE", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("P_CONTAINER", "STRING", mode="REQUIRED"),
bigquery.SchemaField("P_RETAILPRICE", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("P_COMMENT", "STRING"),
],
"SUPPLIER": [
bigquery.SchemaField("S_SUPPKEY", "INTEGER"),
bigquery.SchemaField("S_NAME", "STRING"),
bigquery.SchemaField("S_ADDRESS", "STRING"),
bigquery.SchemaField("S_NATIONKEY", "INTEGER"),
bigquery.SchemaField("S_PHONE", "STRING"),
bigquery.SchemaField("S_ACCTBAL", "FLOAT"),
bigquery.SchemaField("S_SUPPKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("S_NAME", "STRING", mode="REQUIRED"),
bigquery.SchemaField("S_ADDRESS", "STRING", mode="REQUIRED"),
bigquery.SchemaField("S_NATIONKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("S_PHONE", "STRING", mode="REQUIRED"),
bigquery.SchemaField("S_ACCTBAL", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("S_COMMENT", "STRING"),
],
"PARTSUPP": [
bigquery.SchemaField("PS_PARTKEY", "INTEGER"),
bigquery.SchemaField("PS_SUPPKEY", "INTEGER"),
bigquery.SchemaField("PS_AVAILQTY", "INTEGER"),
bigquery.SchemaField("PS_SUPPLYCOST", "FLOAT"),
bigquery.SchemaField("PS_PARTKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("PS_SUPPKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("PS_AVAILQTY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("PS_SUPPLYCOST", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("PS_COMMENT", "STRING"),
],
"CUSTOMER": [
bigquery.SchemaField("C_CUSTKEY", "INTEGER"),
bigquery.SchemaField("C_NAME", "STRING"),
bigquery.SchemaField("C_ADDRESS", "STRING"),
bigquery.SchemaField("C_NATIONKEY", "INTEGER"),
bigquery.SchemaField("C_PHONE", "STRING"),
bigquery.SchemaField("C_ACCTBAL", "FLOAT"),
bigquery.SchemaField("C_MKTSEGMENT", "STRING"),
bigquery.SchemaField("C_CUSTKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("C_NAME", "STRING", mode="REQUIRED"),
bigquery.SchemaField("C_ADDRESS", "STRING", mode="REQUIRED"),
bigquery.SchemaField("C_NATIONKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("C_PHONE", "STRING", mode="REQUIRED"),
bigquery.SchemaField("C_ACCTBAL", "FLOAT", mode="REQUIRED"),
bigquery.SchemaField("C_MKTSEGMENT", "STRING", mode="REQUIRED"),
bigquery.SchemaField("C_COMMENT", "STRING"),
],
"NATION": [
bigquery.SchemaField("N_NATIONKEY", "INTEGER"),
bigquery.SchemaField("N_NAME", "STRING"),
bigquery.SchemaField("N_REGIONKEY", "INTEGER"),
bigquery.SchemaField("N_NATIONKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("N_NAME", "STRING", mode="REQUIRED"),
bigquery.SchemaField("N_REGIONKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("N_COMMENT", "STRING"),
],
"REGION": [
bigquery.SchemaField("R_REGIONKEY", "INTEGER"),
bigquery.SchemaField("R_NAME", "STRING"),
bigquery.SchemaField("R_REGIONKEY", "INTEGER", mode="REQUIRED"),
bigquery.SchemaField("R_NAME", "STRING", mode="REQUIRED"),
bigquery.SchemaField("R_COMMENT", "STRING"),
],
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@ WITH `bfcte_0` AS (
AVG(`bfcol_43`) AS `bfcol_61`,
COUNT(`bfcol_41`) AS `bfcol_62`
FROM `bfcte_0`
WHERE
NOT `bfcol_44` IS NULL AND NOT `bfcol_45` IS NULL
GROUP BY
`bfcol_44`,
`bfcol_45`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ WITH `bfcte_0` AS (
`bfcol_8` AS `bfcol_24`
FROM `bfcte_3`
INNER JOIN `bfcte_2`
ON COALESCE(`bfcol_9`, 0) = COALESCE(`bfcol_7`, 0)
AND COALESCE(`bfcol_9`, 1) = COALESCE(`bfcol_7`, 1)
ON `bfcol_9` = `bfcol_7`
), `bfcte_5` AS (
SELECT
`bfcol_16` AS `bfcol_25`,
Expand All @@ -56,8 +55,7 @@ WITH `bfcte_0` AS (
`bfcol_5` AS `bfcol_35`
FROM `bfcte_4`
INNER JOIN `bfcte_1`
ON COALESCE(`bfcol_23`, 0) = COALESCE(`bfcol_2`, 0)
AND COALESCE(`bfcol_23`, 1) = COALESCE(`bfcol_2`, 1)
ON `bfcol_23` = `bfcol_2`
), `bfcte_6` AS (
SELECT
`bfcol_25`,
Expand Down Expand Up @@ -107,8 +105,7 @@ WITH `bfcte_0` AS (
), 2) AS `bfcol_83`
FROM `bfcte_5`
INNER JOIN `bfcte_0`
ON COALESCE(`bfcol_28`, 0) = COALESCE(`bfcol_0`, 0)
AND COALESCE(`bfcol_28`, 1) = COALESCE(`bfcol_0`, 1)
ON `bfcol_28` = `bfcol_0`
WHERE
(
(
Expand All @@ -133,13 +130,7 @@ WITH `bfcte_0` AS (
COALESCE(SUM(`bfcol_83`), 0) AS `bfcol_92`
FROM `bfcte_6`
WHERE
NOT `bfcol_76` IS NULL
AND NOT `bfcol_77` IS NULL
AND NOT `bfcol_80` IS NULL
AND NOT `bfcol_79` IS NULL
AND NOT `bfcol_82` IS NULL
AND NOT `bfcol_78` IS NULL
AND NOT `bfcol_81` IS NULL
NOT `bfcol_81` IS NULL
GROUP BY
`bfcol_76`,
`bfcol_77`,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,7 @@ WITH `bfcte_0` AS (
`bfcol_3` AS `bfcol_19`
FROM `bfcte_4`
INNER JOIN `bfcte_3`
ON COALESCE(`bfcol_18`, 0) = COALESCE(`bfcol_4`, 0)
AND COALESCE(`bfcol_18`, 1) = COALESCE(`bfcol_4`, 1)
ON `bfcol_18` = `bfcol_4`
), `bfcte_6` AS (
SELECT
`bfcol_19`,
Expand All @@ -46,8 +45,7 @@ WITH `bfcte_0` AS (
`bfcol_2` * `bfcol_1` AS `bfcol_40`
FROM `bfcte_5`
INNER JOIN `bfcte_1`
ON COALESCE(`bfcol_19`, 0) = COALESCE(`bfcol_0`, 0)
AND COALESCE(`bfcol_19`, 1) = COALESCE(`bfcol_0`, 1)
ON `bfcol_19` = `bfcol_0`
), `bfcte_7` AS (
SELECT
`bfcol_19`,
Expand All @@ -59,8 +57,7 @@ WITH `bfcte_0` AS (
`bfcol_13` * `bfcol_12` AS `bfcol_28`
FROM `bfcte_5`
INNER JOIN `bfcte_2`
ON COALESCE(`bfcol_19`, 0) = COALESCE(`bfcol_11`, 0)
AND COALESCE(`bfcol_19`, 1) = COALESCE(`bfcol_11`, 1)
ON `bfcol_19` = `bfcol_11`
), `bfcte_8` AS (
SELECT
COALESCE(SUM(`bfcol_40`), 0) AS `bfcol_44`
Expand All @@ -70,8 +67,6 @@ WITH `bfcte_0` AS (
`bfcol_27`,
COALESCE(SUM(`bfcol_28`), 0) AS `bfcol_35`
FROM `bfcte_7`
WHERE
NOT `bfcol_27` IS NULL
GROUP BY
`bfcol_27`
), `bfcte_10` AS (
Expand Down Expand Up @@ -101,8 +96,6 @@ WITH `bfcte_0` AS (
`bfcol_8`,
ANY_VALUE(`bfcol_51`) AS `bfcol_55`
FROM `bfcte_12`
WHERE
NOT `bfcol_7` IS NULL AND NOT `bfcol_8` IS NULL
GROUP BY
`bfcol_7`,
`bfcol_8`
Expand Down
Loading
Loading