This repository was archived by the owner on Apr 1, 2026. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 68
Expand file tree
/
Copy path test_aggregation.py
More file actions
153 lines (136 loc) · 4.97 KB
/
test_aggregation.py
File metadata and controls
153 lines (136 loc) · 4.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from google.cloud import bigquery
import pytest
from bigframes.core import agg_expressions, array_value, expression, identifiers, nodes
import bigframes.operations.aggregations as agg_ops
from bigframes.session import direct_gbq_execution, polars_executor
from bigframes.testing.engine_utils import assert_equivalence_execution
# Skip this whole test module when polars is not installed, since polars backs
# the reference engine defined below.
# NOTE(review): this skip runs after `polars_executor` is imported above; it can
# only take effect if that import does not itself require polars — confirm.
pytest.importorskip("polars")
# Polars is used as the reference engine because it's fast and local. Where the
# engines disagree, though, generally prefer the gbq (BigQuery) engine's result.
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
def apply_agg_to_all_valid(
    array: array_value.ArrayValue,
    op: agg_ops.UnaryAggregateOp,
    excluded_cols=None,
) -> array_value.ArrayValue:
    """Apply ``op`` to every column of ``array`` whose datatype ``op`` accepts.

    Each column is probed with ``op.output_type``. A column for which that
    raises ``TypeError`` is treated as incompatible and skipped. Each resulting
    aggregate column is named ``"{column}-{op.name}"``.

    Args:
        array: The array value whose columns are aggregated.
        op: The unary aggregation to apply to each compatible column.
        excluded_cols: Optional iterable of column ids to leave out even when
            they are compatible. ``None`` (the default) excludes nothing.

    Returns:
        A new array value produced by ``array.aggregate`` with one aggregate
        expression per compatible, non-excluded column.

    Raises:
        AssertionError: If no column is compatible with ``op``.
    """
    # Use ``None`` rather than a mutable ``[]`` default so that no list object
    # is shared across calls. A frozenset also gives O(1) membership checks.
    excluded = frozenset(excluded_cols or ())
    exprs_by_name = []
    for arg in array.column_ids:
        if arg in excluded:
            continue
        try:
            # Only the dtype compatibility probe is expected to raise TypeError.
            # Keeping the try this narrow avoids masking unrelated errors.
            op.output_type(array.get_column_type(arg))
        except TypeError:
            continue
        expr = agg_expressions.UnaryAggregation(op, expression.deref(arg))
        exprs_by_name.append((expr, f"{arg}-{op.name}"))
    assert len(exprs_by_name) > 0
    return array.aggregate(exprs_by_name)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_aggregate_post_filter_size(
    scalars_array_value: array_value.ArrayValue,
    engine,
):
    """Check a row count taken after filtering and offset promotion.

    The plan keeps two columns, filters on ``bool_col``, and promotes offsets.
    It then counts rows; the result must match the reference engine.
    """
    narrowed = scalars_array_value.select_columns(("bool_col", "string_col"))
    filtered = narrowed.filter(expression.deref("bool_col"))
    with_offsets, offsets_col = filtered.promote_offsets()
    projected = with_offsets.select_columns((offsets_col, "bool_col", "string_col"))
    counted_plan = projected.row_count().node
    assert_equivalence_execution(counted_plan, REFERENCE_ENGINE, engine)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
def test_engines_aggregate_size(
    scalars_array_value: array_value.ArrayValue,
    engine,
):
    """Check ungrouped ``SizeOp`` and ``SizeUnaryOp`` aggregations.

    ``SizeUnaryOp`` is applied to ``string_col``. Both results must match the
    reference engine.
    """
    size_agg = agg_expressions.NullaryAggregation(agg_ops.SizeOp())
    unary_size_agg = agg_expressions.UnaryAggregation(
        agg_ops.SizeUnaryOp(), expression.deref("string_col")
    )
    aggregate_node = nodes.AggregateNode(
        scalars_array_value.node,
        aggregations=(
            (size_agg, identifiers.ColumnId("size_op")),
            (unary_size_agg, identifiers.ColumnId("unary_size_op")),
        ),
    )
    assert_equivalence_execution(aggregate_node, REFERENCE_ENGINE, engine)
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
@pytest.mark.parametrize(
    "op",
    [agg_ops.min_op, agg_ops.max_op, agg_ops.mean_op, agg_ops.sum_op, agg_ops.count_op],
)
def test_engines_unary_aggregates(
    scalars_array_value: array_value.ArrayValue,
    engine,
    op,
):
    """Check each parametrized unary aggregate over all compatible columns.

    The result must match the reference engine.
    """
    aggregated = apply_agg_to_all_valid(scalars_array_value, op)
    assert_equivalence_execution(aggregated.node, REFERENCE_ENGINE, engine)
def test_sql_engines_median_op_aggregates(
    scalars_array_value: array_value.ArrayValue,
    bigquery_client: bigquery.Client,
):
    """Check that ``MedianOp`` agrees across the two BigQuery executor configs.

    One executor uses no explicit compiler and the other uses the sqlglot
    compiler; the aggregated results must be equivalent.
    """
    median_plan = apply_agg_to_all_valid(scalars_array_value, agg_ops.MedianOp()).node
    # Compare the two BigQuery executors against each other rather than
    # against the polars reference engine.
    default_compiler_engine = direct_gbq_execution.DirectGbqExecutor(bigquery_client)
    sqlglot_compiler_engine = direct_gbq_execution.DirectGbqExecutor(
        bigquery_client, compiler="sqlglot"
    )
    assert_equivalence_execution(
        median_plan, default_compiler_engine, sqlglot_compiler_engine
    )
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
@pytest.mark.parametrize(
    "grouping_cols",
    [
        ["bool_col"],
        ["string_col", "int64_col"],
        ["date_col"],
        ["datetime_col"],
        ["timestamp_col"],
        ["bytes_col"],
    ],
)
def test_engines_grouped_aggregate(
    scalars_array_value: array_value.ArrayValue, engine, grouping_cols
):
    """Check ``SizeOp`` and ``SizeUnaryOp`` grouped by each key-column set.

    ``grouping_cols`` gives the key columns, and ``SizeUnaryOp`` is applied to
    ``string_col``. The grouped results must match the reference engine.
    """
    node = nodes.AggregateNode(
        scalars_array_value.node,
        aggregations=(
            (
                agg_expressions.NullaryAggregation(agg_ops.SizeOp()),
                identifiers.ColumnId("size_op"),
            ),
            (
                agg_expressions.UnaryAggregation(
                    agg_ops.SizeUnaryOp(), expression.deref("string_col")
                ),
                identifiers.ColumnId("unary_size_op"),
            ),
        ),
        # Name the loop variable ``col_id``, not ``id``, so the builtin ``id``
        # is not shadowed.
        by_column_ids=tuple(expression.deref(col_id) for col_id in grouping_cols),
    )
    assert_equivalence_execution(node, REFERENCE_ENGINE, engine)