Skip to content
This repository was archived by the owner on Apr 1, 2026. It is now read-only.

Commit 1ff716d

Browse files
committed
refactor: add agg_ops.MeanOp for sqlglot compiler
1 parent 17b5d3e commit 1ff716d

File tree

3 files changed

+42
-1
lines changed

3 files changed

+42
-1
lines changed

bigframes/core/compile/sqlglot/aggregations/unary_compiler.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,26 @@ def _(
5555
return apply_window_if_present(sge.func("MAX", column.expr), window)
5656

5757

58+
@UNARY_OP_REGISTRATION.register(agg_ops.MeanOp)
59+
def _(
60+
op: agg_ops.MeanOp,
61+
column: typed_expr.TypedExpr,
62+
window: typing.Optional[window_spec.WindowSpec] = None,
63+
) -> sge.Expression:
64+
expr = column.expr
65+
if column.dtype == dtypes.BOOL_DTYPE:
66+
expr = sge.Cast(this=expr, to="INT64")
67+
68+
expr = sge.func("AVG", expr)
69+
70+
should_floor_result = (
71+
op.should_floor_result or column.dtype == dtypes.TIMEDELTA_DTYPE
72+
)
73+
if should_floor_result:
74+
expr = sge.func("FLOOR", expr)
75+
return apply_window_if_present(expr, window)
76+
77+
5878
@UNARY_OP_REGISTRATION.register(agg_ops.MinOp)
5979
def _(
6080
op: agg_ops.MinOp,

tests/system/small/engines/test_aggregation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_engines_aggregate_size(
7070
assert_equivalence_execution(node, REFERENCE_ENGINE, engine)
7171

7272

73-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
73+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
7474
@pytest.mark.parametrize(
7575
"op",
7676
[agg_ops.min_op, agg_ops.max_op, agg_ops.mean_op, agg_ops.sum_op, agg_ops.count_op],

tests/unit/core/compile/sqlglot/aggregations/test_unary_compiler.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,27 @@ def test_max(scalar_types_df: bpd.DataFrame, snapshot):
5656
snapshot.assert_match(sql, "out.sql")
5757

5858

59+
def test_mean(scalar_types_df: bpd.DataFrame, snapshot):
60+
bf_df = scalar_types_df[["duration_col"]]
61+
bf_df["duration_col"] = bpd.to_timedelta(bf_df["duration_col"], unit="us")
62+
63+
# sql = bf_df.sql
64+
# snapshot.assert_match(sql, "out.sql")
65+
agg_ops_map = {
66+
# "int64_col": agg_ops.MeanOp().as_expr("int64_col"),
67+
# "bool_col": agg_ops.MeanOp().as_expr("bool_col"),
68+
"duration_col": agg_ops.MeanOp().as_expr("duration_col"),
69+
# "int64_col_w_floor": agg_ops.MeanOp(should_floor_result=True).as_expr(
70+
# "int64_col"
71+
# ),
72+
}
73+
sql = _apply_unary_agg_ops(
74+
bf_df, list(agg_ops_map.values()), list(agg_ops_map.keys())
75+
)
76+
77+
snapshot.assert_match(sql, "out.sql")
78+
79+
5980
def test_min(scalar_types_df: bpd.DataFrame, snapshot):
6081
col_name = "int64_col"
6182
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)