Skip to content

Commit 0303716

Browse files
timsaucer and claude committed
Improve aggregate function tests and docstrings per review feedback
Add docstring example to grouping(), parametrize percentile_cont tests, and add multi-column grouping test case. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 271e86b commit 0303716

File tree

2 files changed

+46
-20
lines changed

2 files changed

+46
-20
lines changed

python/datafusion/functions.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4427,6 +4427,23 @@ def grouping(
44274427
expression: The column to check grouping status for
44284428
distinct: If True, compute on distinct values only
44294429
filter: If provided, only compute against rows for which the filter is True
4430+
4431+
Examples:
4432+
In a simple ``GROUP BY`` (no grouping sets), ``grouping()`` always
4433+
returns 0, indicating the column is part of the grouping key:
4434+
4435+
>>> import pyarrow as pa
4436+
>>> ctx = dfn.SessionContext()
4437+
>>> df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]})
4438+
>>> result = df.aggregate(
4439+
... [dfn.col("a")],
4440+
... [dfn.functions.grouping(dfn.col("a")),
4441+
... dfn.functions.sum(dfn.col("b")).alias("s")])
4442+
>>> batches = result.collect()
4443+
>>> grouping_vals = pa.concat_arrays(
4444+
... [batch.column(1) for batch in batches]).to_pylist()
4445+
>>> all(v == 0 for v in grouping_vals)
4446+
True
44304447
"""
44314448
filter_raw = filter.expr if filter is not None else None
44324449
return Expr(f.grouping(expression.expr, distinct=distinct, filter=filter_raw))

python/tests/test_functions.py

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1820,43 +1820,52 @@ def test_conditional_functions(df_with_nulls, expr, expected):
18201820
assert result.column(0) == expected
18211821

18221822

1823-
@pytest.mark.parametrize(
    ("filter_expr", "expected"),
    [
        (None, 3.0),
        (column("a") > literal(1.0), 3.5),
    ],
    ids=["no_filter", "with_filter"],
)
def test_percentile_cont(filter_expr, expected):
    """Check percentile_cont at the 0.5 quantile, with and without a filter.

    With no filter the median of [1..5] is 3.0; filtering to ``a > 1``
    leaves [2, 3, 4, 5], whose interpolated median is 3.5.
    """
    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
    result = df.aggregate(
        [], [f.percentile_cont(column("a"), 0.5, filter=filter_expr).alias("v")]
    ).collect()[0]
    assert result.column(0)[0].as_py() == expected
18441838

18451839

18461840
def test_grouping():
    """grouping() returns 0 for every row in a plain GROUP BY.

    With no grouping sets, each grouped column is always part of the
    grouping key, so ``grouping()`` reports 0 for all result rows.
    """
    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]})
    # In a simple GROUP BY (no grouping sets), grouping() returns 0 for all rows.
    result = df.aggregate(
        [column("a")], [f.grouping(column("a")), f.sum(column("b")).alias("s")]
    ).collect()
    # Concatenate across batches since collect() may return multiple record batches.
    grouping_col = pa.concat_arrays([batch.column(1) for batch in result]).to_pylist()
    assert all(v == 0 for v in grouping_col)
18581849

18591850

1851+
def test_grouping_multiple_columns():
    """grouping() works when multiple columns appear in the GROUP BY clause.

    Both grouped columns are always part of the grouping key in a plain
    GROUP BY, so each ``grouping()`` column is all zeros.
    """
    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 10, 30], "c": [100, 200, 300]})
    result = df.aggregate(
        [column("a"), column("b")],
        [
            f.grouping(column("a")),
            f.grouping(column("b")),
            f.sum(column("c")).alias("s"),
        ],
    ).collect()
    # Output columns 0-1 are the group keys; 2 and 3 are the grouping() results.
    grouping_a = pa.concat_arrays([batch.column(2) for batch in result]).to_pylist()
    grouping_b = pa.concat_arrays([batch.column(3) for batch in result]).to_pylist()
    assert all(v == 0 for v in grouping_a)
    assert all(v == 0 for v in grouping_b)
1867+
1868+
18601869
def test_var_population():
18611870
ctx = SessionContext()
18621871
df = ctx.from_pydict({"a": [-1.0, 0.0, 2.0]})

0 commit comments

Comments (0)