Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 45b4a80

Browse files
committed
chore: add a function to traverse BFET and encode type usage
1 parent 328d048 commit 45b4a80

File tree

4 files changed

+197
-19
lines changed

4 files changed

+197
-19
lines changed

bigframes/core/logging/data_types.py

Lines changed: 82 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,94 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from __future__ import annotations
16+
17+
import functools
1518

1619
from bigframes import dtypes
20+
from bigframes.core import agg_expressions, bigframe_node, expression, nodes
21+
from bigframes.core.rewrite import schema_binding
22+
23+
24+
def encode_type_refs(root: bigframe_node.BigFrameNode) -> str:
25+
return f"{root.reduce_up(_encode_type_refs_from_node):x}"
26+
27+
28+
def _encode_type_refs_from_node(
29+
node: bigframe_node.BigFrameNode, child_results: tuple[int, ...]
30+
) -> int:
31+
child_result = functools.reduce(lambda x, y: x | y, child_results, 0)
32+
33+
curr_result = 0
34+
if isinstance(node, nodes.FilterNode):
35+
curr_result = _encode_type_refs_from_expr(node.predicate, node.child)
36+
elif isinstance(node, nodes.ProjectionNode):
37+
for assignment in node.assignments:
38+
expr = assignment[0]
39+
if isinstance(expr, (expression.DerefOp)):
40+
# Ignore direct assignments in projection nodes.
41+
continue
42+
curr_result = curr_result | _encode_type_refs_from_expr(
43+
assignment[0], node.child
44+
)
45+
elif isinstance(node, nodes.SelectionNode):
46+
# Do nothing
47+
pass
48+
elif isinstance(node, nodes.OrderByNode):
49+
for by in node.by:
50+
curr_result = curr_result | _encode_type_refs_from_expr(
51+
by.scalar_expression, node.child
52+
)
53+
elif isinstance(node, nodes.JoinNode):
54+
for left, right in node.conditions:
55+
curr_result = (
56+
curr_result
57+
| _encode_type_refs_from_expr(left, node.left_child)
58+
| _encode_type_refs_from_expr(right, node.right_child)
59+
)
60+
elif isinstance(node, nodes.InNode):
61+
curr_result = _encode_type_refs_from_expr(node.left_col, node.left_child)
62+
elif isinstance(node, nodes.AggregateNode):
63+
for agg, _ in node.aggregations:
64+
curr_result = curr_result | _encode_type_refs_from_expr(agg, node.child)
65+
elif isinstance(node, nodes.WindowOpNode):
66+
for grouping_key in node.window_spec.grouping_keys:
67+
curr_result = curr_result | _encode_type_refs_from_expr(
68+
grouping_key, node.child
69+
)
70+
for ordering_expr in node.window_spec.ordering:
71+
curr_result = curr_result | _encode_type_refs_from_expr(
72+
ordering_expr.scalar_expression, node.child
73+
)
74+
for col_def in node.agg_exprs:
75+
curr_result = curr_result | _encode_type_refs_from_expr(
76+
col_def.expression, node.child
77+
)
78+
79+
return child_result | curr_result
80+
81+
82+
def _encode_type_refs_from_expr(
83+
expr: expression.Expression, child_node: bigframe_node.BigFrameNode
84+
) -> int:
85+
# TODO(b/409387790): Remove this branch once SQLGlot compiler fully replaces Ibis compiler
86+
if not expr.is_resolved:
87+
if isinstance(expr, agg_expressions.Aggregation):
88+
expr = schema_binding._bind_schema_to_aggregation_expr(expr, child_node)
89+
else:
90+
expr = expression.bind_schema_fields(expr, child_node.field_by_id)
1791

92+
result = _get_dtype_mask(expr.output_type)
93+
for child_expr in expr.children:
94+
result = result | _encode_type_refs_from_expr(child_expr, child_node)
1895

19-
def _add_data_type(existing_types: int, curr_type: dtypes.Dtype) -> int:
20-
return existing_types | _get_dtype_mask(curr_type)
96+
return result
2197

2298

23-
def _get_dtype_mask(dtype: dtypes.Dtype) -> int:
99+
def _get_dtype_mask(dtype: dtypes.Dtype | None) -> int:
100+
if dtype is None:
101+
# If the dtype is not given, ignore
102+
return 0
24103
if dtype == dtypes.INT_DTYPE:
25104
return 1 << 1
26105
if dtype == dtypes.FLOAT_DTYPE:
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Sequence
16+
17+
from bigframes import dtypes
18+
from bigframes.core.logging import data_types
19+
20+
21+
def encode_types(inputs: Sequence[dtypes.Dtype]) -> str:
22+
encoded_val = 0
23+
for t in inputs:
24+
encoded_val = encoded_val | data_types._get_dtype_mask(t)
25+
26+
return f"{encoded_val:x}"
27+
28+
29+
def test_get_type_refs_no_op(scalars_df_index):
30+
node = scalars_df_index._block._expr.node
31+
expected_types: list[dtypes.Dtype] = []
32+
33+
assert data_types.encode_type_refs(node) == encode_types(expected_types)
34+
35+
36+
def test_get_type_refs_projection(scalars_df_index):
37+
node = (
38+
scalars_df_index["datetime_col"] - scalars_df_index["datetime_col"]
39+
)._block._expr.node
40+
expected_types = [dtypes.DATETIME_DTYPE, dtypes.TIMEDELTA_DTYPE]
41+
42+
assert data_types.encode_type_refs(node) == encode_types(expected_types)
43+
44+
45+
def test_get_type_refs_filter(scalars_df_index):
46+
node = scalars_df_index[scalars_df_index["int64_col"] > 0]._block._expr.node
47+
expected_types = [dtypes.INT_DTYPE, dtypes.BOOL_DTYPE]
48+
49+
assert data_types.encode_type_refs(node) == encode_types(expected_types)
50+
51+
52+
def test_get_type_refs_order_by(scalars_df_index):
53+
node = scalars_df_index.sort_index()._block._expr.node
54+
expected_types = [dtypes.INT_DTYPE]
55+
56+
assert data_types.encode_type_refs(node) == encode_types(expected_types)
57+
58+
59+
def test_get_type_refs_join(scalars_df_index):
60+
node = (
61+
scalars_df_index[["int64_col"]].merge(
62+
scalars_df_index[["float64_col"]],
63+
left_on="int64_col",
64+
right_on="float64_col",
65+
)
66+
)._block._expr.node
67+
expected_types = [dtypes.INT_DTYPE, dtypes.FLOAT_DTYPE]
68+
69+
assert data_types.encode_type_refs(node) == encode_types(expected_types)
70+
71+
72+
def test_get_type_refs_isin(scalars_df_index):
73+
node = scalars_df_index["string_col"].isin(["a"])._block._expr.node
74+
expected_types = [dtypes.STRING_DTYPE, dtypes.BOOL_DTYPE]
75+
76+
assert data_types.encode_type_refs(node) == encode_types(expected_types)
77+
78+
79+
def test_get_type_refs_agg(scalars_df_index):
80+
node = scalars_df_index[["bool_col", "string_col"]].count()._block._expr.node
81+
expected_types = [
82+
dtypes.INT_DTYPE,
83+
dtypes.BOOL_DTYPE,
84+
dtypes.STRING_DTYPE,
85+
dtypes.FLOAT_DTYPE,
86+
]
87+
88+
assert data_types.encode_type_refs(node) == encode_types(expected_types)
89+
90+
91+
def test_get_type_refs_window(scalars_df_index):
92+
node = (
93+
scalars_df_index[["string_col", "bool_col"]]
94+
.groupby("string_col")
95+
.rolling(window=3)
96+
.count()
97+
._block._expr.node
98+
)
99+
expected_types = [dtypes.STRING_DTYPE, dtypes.BOOL_DTYPE, dtypes.INT_DTYPE]
100+
101+
assert data_types.encode_type_refs(node) == encode_types(expected_types)

tests/unit/core/logging/test_data_types.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
@pytest.mark.parametrize(
3030
("dtype", "expected_mask"),
3131
[
32+
(None, 0),
3233
(UNKNOWN_TYPE, 1 << 0),
3334
(dtypes.INT_DTYPE, 1 << 1),
3435
(dtypes.FLOAT_DTYPE, 1 << 2),
@@ -51,19 +52,3 @@
5152
)
5253
def test_get_dtype_mask(dtype, expected_mask):
5354
assert data_types._get_dtype_mask(dtype) == expected_mask
54-
55-
56-
def test_add_data_type__type_overlap_no_op():
57-
curr_type = dtypes.STRING_DTYPE
58-
existing_types = data_types._get_dtype_mask(curr_type)
59-
60-
assert data_types._add_data_type(existing_types, curr_type) == existing_types
61-
62-
63-
def test_add_data_type__new_type_updated():
64-
curr_type = dtypes.STRING_DTYPE
65-
existing_types = data_types._get_dtype_mask(dtypes.INT_DTYPE)
66-
67-
assert data_types._add_data_type(
68-
existing_types, curr_type
69-
) == existing_types | data_types._get_dtype_mask(curr_type)

0 commit comments

Comments
 (0)