Skip to content

Commit 1243498

Browse files
timsaucer and claude committed
Add GroupingSet.rollup, .cube, and .grouping_sets factory methods
Expose ROLLUP, CUBE, and GROUPING SETS via the DataFrame API by adding static methods on GroupingSet that construct the corresponding Expr variants. Update grouping() docstring and tests to use the new API. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 0303716 commit 1243498

File tree

4 files changed

+251
-34
lines changed

4 files changed

+251
-34
lines changed

crates/core/src/expr/grouping_set.rs

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use datafusion::logical_expr::GroupingSet;
18+
use datafusion::logical_expr::{Expr, GroupingSet};
1919
use pyo3::prelude::*;
2020

21+
use crate::expr::PyExpr;
22+
2123
#[pyclass(
2224
from_py_object,
2325
frozen,
@@ -30,6 +32,39 @@ pub struct PyGroupingSet {
3032
grouping_set: GroupingSet,
3133
}
3234

35+
#[pymethods]
36+
impl PyGroupingSet {
37+
#[staticmethod]
38+
#[pyo3(signature = (*exprs))]
39+
fn rollup(exprs: Vec<PyExpr>) -> PyExpr {
40+
Expr::GroupingSet(GroupingSet::Rollup(
41+
exprs.into_iter().map(|e| e.expr).collect(),
42+
))
43+
.into()
44+
}
45+
46+
#[staticmethod]
47+
#[pyo3(signature = (*exprs))]
48+
fn cube(exprs: Vec<PyExpr>) -> PyExpr {
49+
Expr::GroupingSet(GroupingSet::Cube(
50+
exprs.into_iter().map(|e| e.expr).collect(),
51+
))
52+
.into()
53+
}
54+
55+
#[staticmethod]
56+
#[pyo3(signature = (*expr_lists))]
57+
fn grouping_sets(expr_lists: Vec<Vec<PyExpr>>) -> PyExpr {
58+
Expr::GroupingSet(GroupingSet::GroupingSets(
59+
expr_lists
60+
.into_iter()
61+
.map(|list| list.into_iter().map(|e| e.expr).collect())
62+
.collect(),
63+
))
64+
.into()
65+
}
66+
}
67+
3368
impl From<PyGroupingSet> for GroupingSet {
3469
fn from(grouping_set: PyGroupingSet) -> Self {
3570
grouping_set.grouping_set

python/datafusion/expr.py

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@
9191
Extension = expr_internal.Extension
9292
FileType = expr_internal.FileType
9393
Filter = expr_internal.Filter
94-
GroupingSet = expr_internal.GroupingSet
94+
_GroupingSetInternal = expr_internal.GroupingSet
9595
Join = expr_internal.Join
9696
ILike = expr_internal.ILike
9797
InList = expr_internal.InList
@@ -1430,3 +1430,135 @@ def __repr__(self) -> str:
14301430

14311431

14321432
SortKey = Expr | SortExpr | str
1433+
1434+
1435+
class GroupingSet:
1436+
"""Factory for creating grouping set expressions.
1437+
1438+
Grouping sets control how
1439+
:py:meth:`~datafusion.dataframe.DataFrame.aggregate` groups rows.
1440+
Instead of a single ``GROUP BY``, they produce multiple grouping
1441+
levels in one pass — subtotals, cross-tabulations, or arbitrary
1442+
column subsets.
1443+
1444+
Use :py:func:`~datafusion.functions.grouping` in the aggregate list
1445+
to tell which columns are aggregated across in each result row.
1446+
"""
1447+
1448+
@staticmethod
1449+
def rollup(*exprs: Expr) -> Expr:
1450+
"""Create a ``ROLLUP`` grouping set for use with ``aggregate()``.
1451+
1452+
``ROLLUP`` generates all prefixes of the given column list as
1453+
grouping sets. For example, ``rollup(a, b)`` produces grouping
1454+
sets ``(a, b)``, ``(a)``, and ``()`` (grand total).
1455+
1456+
This is equivalent to ``GROUP BY ROLLUP(a, b)`` in SQL.
1457+
1458+
Args:
1459+
*exprs: Column expressions to include in the rollup.
1460+
1461+
Examples:
1462+
>>> import pyarrow as pa
1463+
>>> import datafusion as dfn
1464+
>>> from datafusion.expr import GroupingSet
1465+
>>> ctx = dfn.SessionContext()
1466+
>>> df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]})
1467+
>>> result = df.aggregate(
1468+
... [GroupingSet.rollup(dfn.col("a"))],
1469+
... [dfn.functions.sum(dfn.col("b")).alias("s"),
1470+
... dfn.functions.grouping(dfn.col("a"))],
1471+
... ).sort(dfn.col("a").sort(nulls_first=False))
1472+
>>> batches = result.collect()
1473+
>>> pa.concat_arrays([b.column("s") for b in batches]).to_pylist()
1474+
[30, 30, 60]
1475+
1476+
See Also:
1477+
:py:meth:`cube`, :py:meth:`grouping_sets`,
1478+
:py:func:`~datafusion.functions.grouping`
1479+
"""
1480+
args = [e.expr for e in exprs]
1481+
return Expr(_GroupingSetInternal.rollup(*args))
1482+
1483+
@staticmethod
1484+
def cube(*exprs: Expr) -> Expr:
1485+
"""Create a ``CUBE`` grouping set for use with ``aggregate()``.
1486+
1487+
``CUBE`` generates all possible subsets of the given column list
1488+
as grouping sets. For example, ``cube(a, b)`` produces grouping
1489+
sets ``(a, b)``, ``(a)``, ``(b)``, and ``()`` (grand total).
1490+
1491+
This is equivalent to ``GROUP BY CUBE(a, b)`` in SQL.
1492+
1493+
Args:
1494+
*exprs: Column expressions to include in the cube.
1495+
1496+
Examples:
1497+
With a single column, ``cube`` behaves identically to
1498+
:py:meth:`rollup`:
1499+
1500+
>>> import pyarrow as pa
1501+
>>> import datafusion as dfn
1502+
>>> from datafusion.expr import GroupingSet
1503+
>>> ctx = dfn.SessionContext()
1504+
>>> df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]})
1505+
>>> result = df.aggregate(
1506+
... [GroupingSet.cube(dfn.col("a"))],
1507+
... [dfn.functions.sum(dfn.col("b")).alias("s"),
1508+
... dfn.functions.grouping(dfn.col("a"))],
1509+
... ).sort(dfn.col("a").sort(nulls_first=False))
1510+
>>> batches = result.collect()
1511+
>>> pa.concat_arrays([b.column(2) for b in batches]).to_pylist()
1512+
[0, 0, 1]
1513+
1514+
See Also:
1515+
:py:meth:`rollup`, :py:meth:`grouping_sets`,
1516+
:py:func:`~datafusion.functions.grouping`
1517+
"""
1518+
args = [e.expr for e in exprs]
1519+
return Expr(_GroupingSetInternal.cube(*args))
1520+
1521+
@staticmethod
1522+
def grouping_sets(*expr_lists: list[Expr]) -> Expr:
1523+
"""Create explicit grouping sets for use with ``aggregate()``.
1524+
1525+
Each argument is a list of column expressions representing one
1526+
grouping set. For example, ``grouping_sets([a], [b])`` groups
1527+
by ``a`` alone and by ``b`` alone in a single query.
1528+
1529+
This is equivalent to ``GROUP BY GROUPING SETS ((a), (b))`` in
1530+
SQL.
1531+
1532+
Args:
1533+
*expr_lists: Each positional argument is a list of
1534+
expressions forming one grouping set.
1535+
1536+
Examples:
1537+
>>> import pyarrow as pa
1538+
>>> import datafusion as dfn
1539+
>>> from datafusion.expr import GroupingSet
1540+
>>> ctx = dfn.SessionContext()
1541+
>>> df = ctx.from_pydict(
1542+
... {"a": ["x", "x", "y"], "b": ["m", "n", "m"],
1543+
... "c": [1, 2, 3]})
1544+
>>> result = df.aggregate(
1545+
... [GroupingSet.grouping_sets(
1546+
... [dfn.col("a")], [dfn.col("b")])],
1547+
... [dfn.functions.sum(dfn.col("c")).alias("s"),
1548+
... dfn.functions.grouping(dfn.col("a")),
1549+
... dfn.functions.grouping(dfn.col("b"))],
1550+
... ).sort(
1551+
... dfn.col("a").sort(nulls_first=False),
1552+
... dfn.col("b").sort(nulls_first=False),
1553+
... )
1554+
>>> batches = result.collect()
1555+
>>> pa.concat_arrays(
1556+
... [b.column("s") for b in batches]).to_pylist()
1557+
[3, 3, 4, 2]
1558+
1559+
See Also:
1560+
:py:meth:`rollup`, :py:meth:`cube`,
1561+
:py:func:`~datafusion.functions.grouping`
1562+
"""
1563+
raw_lists = [[e.expr for e in lst] for lst in expr_lists]
1564+
return Expr(_GroupingSetInternal.grouping_sets(*raw_lists))

python/datafusion/functions.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4413,37 +4413,47 @@ def grouping(
44134413
distinct: bool = False,
44144414
filter: Expr | None = None,
44154415
) -> Expr:
4416-
"""Returns 1 if the data is aggregated across the specified column, or 0 otherwise.
4416+
"""Indicates whether a column is aggregated across in the current row.
44174417
4418-
This function is used with ``GROUPING SETS``, ``CUBE``, or ``ROLLUP`` to
4419-
distinguish between aggregated and non-aggregated rows. In a regular
4420-
``GROUP BY`` without grouping sets, it always returns 0.
4418+
Returns 0 when the column is part of the grouping key for that row
4419+
(i.e., the row contains per-group results for that column). Returns 1
4420+
when the column is *not* part of the grouping key (i.e., the row's
4421+
aggregate spans all values of that column).
44214422
4422-
Note: The ``grouping`` aggregate function is rewritten by the query
4423-
optimizer before execution, so it works correctly even though its
4424-
physical plan is not directly implemented.
4423+
This function is meaningful with
4424+
:py:meth:`GroupingSet.rollup <datafusion.expr.GroupingSet.rollup>`,
4425+
:py:meth:`GroupingSet.cube <datafusion.expr.GroupingSet.cube>`, or
4426+
:py:meth:`GroupingSet.grouping_sets <datafusion.expr.GroupingSet.grouping_sets>`,
4427+
where different rows are grouped by different subsets of columns. In a
4428+
regular ``GROUP BY`` without grouping sets every column is always part
4429+
of the key, so ``grouping()`` always returns 0.
44254430
44264431
Args:
44274432
expression: The column to check grouping status for
44284433
distinct: If True, compute on distinct values only
44294434
filter: If provided, only compute against rows for which the filter is True
44304435
44314436
Examples:
4432-
In a simple ``GROUP BY`` (no grouping sets), ``grouping()`` always
4433-
returns 0, indicating the column is part of the grouping key:
4437+
With :py:meth:`~datafusion.expr.GroupingSet.rollup`, the result
4438+
includes both per-group rows (``grouping(a) = 0``) and a
4439+
grand-total row where ``a`` is aggregated across
4440+
(``grouping(a) = 1``):
44344441
44354442
>>> import pyarrow as pa
4443+
>>> from datafusion.expr import GroupingSet
44364444
>>> ctx = dfn.SessionContext()
44374445
>>> df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]})
44384446
>>> result = df.aggregate(
4439-
... [dfn.col("a")],
4440-
... [dfn.functions.grouping(dfn.col("a")),
4441-
... dfn.functions.sum(dfn.col("b")).alias("s")])
4447+
... [GroupingSet.rollup(dfn.col("a"))],
4448+
... [dfn.functions.sum(dfn.col("b")).alias("s"),
4449+
... dfn.functions.grouping(dfn.col("a"))],
4450+
... ).sort(dfn.col("a").sort(nulls_first=False))
44424451
>>> batches = result.collect()
4443-
>>> grouping_vals = pa.concat_arrays(
4444-
... [batch.column(1) for batch in batches]).to_pylist()
4445-
>>> all(v == 0 for v in grouping_vals)
4446-
True
4452+
>>> pa.concat_arrays([b.column(2) for b in batches]).to_pylist()
4453+
[0, 0, 1]
4454+
4455+
See Also:
4456+
:py:class:`~datafusion.expr.GroupingSet`
44474457
"""
44484458
filter_raw = filter.expr if filter is not None else None
44494459
return Expr(f.grouping(expression.expr, distinct=distinct, filter=filter_raw))

python/tests/test_functions.py

Lines changed: 56 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import pytest
2323
from datafusion import SessionContext, column, literal
2424
from datafusion import functions as f
25+
from datafusion.expr import GroupingSet
2526

2627
np.seterr(invalid="ignore")
2728

@@ -1837,33 +1838,72 @@ def test_percentile_cont(filter_expr, expected):
18371838
assert result.column(0)[0].as_py() == expected
18381839

18391840

1840-
def test_grouping():
1841+
def test_rollup():
1842+
# With ROLLUP, per-group rows have grouping()=0 and the grand-total row
1843+
# (where the column is aggregated across) has grouping()=1.
18411844
ctx = SessionContext()
18421845
df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]})
1843-
# In a simple GROUP BY (no grouping sets), grouping() returns 0 for all rows.
18441846
result = df.aggregate(
1845-
[column("a")], [f.grouping(column("a")), f.sum(column("b")).alias("s")]
1846-
).collect()
1847-
grouping_col = pa.concat_arrays([batch.column(1) for batch in result]).to_pylist()
1848-
assert all(v == 0 for v in grouping_col)
1847+
[GroupingSet.rollup(column("a"))],
1848+
[f.sum(column("b")).alias("s"), f.grouping(column("a"))],
1849+
).sort(column("a").sort(ascending=True, nulls_first=False))
1850+
batches = result.collect()
1851+
g = pa.concat_arrays([b.column(2) for b in batches]).to_pylist()
1852+
s = pa.concat_arrays([b.column("s") for b in batches]).to_pylist()
1853+
# Two per-group rows (g=0) plus one grand-total row (g=1)
1854+
assert g == [0, 0, 1]
1855+
assert s == [30, 30, 60]
1856+
1857+
1858+
def test_rollup_multi_column():
1859+
# rollup(a, b) produces grouping sets (a, b), (a), ().
1860+
ctx = SessionContext()
1861+
df = ctx.from_pydict({"a": [1, 1, 2], "b": ["x", "y", "x"], "c": [10, 20, 30]})
1862+
result = df.aggregate(
1863+
[GroupingSet.rollup(column("a"), column("b"))],
1864+
[f.sum(column("c")).alias("s")],
1865+
)
1866+
total_rows = sum(b.num_rows for b in result.collect())
1867+
# 3 detail (a,b) + 2 subtotal (a) + 1 grand total = 6
1868+
assert total_rows == 6
18491869

18501870

1851-
def test_grouping_multiple_columns():
1852-
# Verify grouping() works when multiple columns are in the GROUP BY clause.
1871+
def test_cube():
1872+
# cube(a, b) produces all subsets: (a,b), (a), (b), ().
18531873
ctx = SessionContext()
1854-
df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 10, 30], "c": [100, 200, 300]})
1874+
df = ctx.from_pydict({"a": [1, 1, 2], "b": ["x", "y", "x"], "c": [10, 20, 30]})
18551875
result = df.aggregate(
1856-
[column("a"), column("b")],
1876+
[GroupingSet.cube(column("a"), column("b"))],
1877+
[f.sum(column("c")).alias("s")],
1878+
)
1879+
total_rows = sum(b.num_rows for b in result.collect())
1880+
# 3 (a,b) + 2 (a) + 2 (b) + 1 () = 8
1881+
assert total_rows == 8
1882+
1883+
1884+
def test_grouping_sets():
1885+
# GROUPING SETS lets you choose exactly which column subsets to group by.
1886+
# Each row's grouping() value tells you which columns are aggregated across.
1887+
ctx = SessionContext()
1888+
df = ctx.from_pydict({"a": ["x", "x", "y"], "b": ["m", "n", "m"], "c": [1, 2, 3]})
1889+
result = df.aggregate(
1890+
[GroupingSet.grouping_sets([column("a")], [column("b")])],
18571891
[
1892+
f.sum(column("c")).alias("s"),
18581893
f.grouping(column("a")),
18591894
f.grouping(column("b")),
1860-
f.sum(column("c")).alias("s"),
18611895
],
1862-
).collect()
1863-
grouping_a = pa.concat_arrays([batch.column(2) for batch in result]).to_pylist()
1864-
grouping_b = pa.concat_arrays([batch.column(3) for batch in result]).to_pylist()
1865-
assert all(v == 0 for v in grouping_a)
1866-
assert all(v == 0 for v in grouping_b)
1896+
).sort(
1897+
column("a").sort(ascending=True, nulls_first=False),
1898+
column("b").sort(ascending=True, nulls_first=False),
1899+
)
1900+
batches = result.collect()
1901+
ga = pa.concat_arrays([b.column(3) for b in batches]).to_pylist()
1902+
gb = pa.concat_arrays([b.column(4) for b in batches]).to_pylist()
1903+
# Rows grouped by (a): ga=0 (a is a key), gb=1 (b is aggregated across)
1904+
# Rows grouped by (b): ga=1 (a is aggregated across), gb=0 (b is a key)
1905+
assert ga == [0, 0, 1, 1]
1906+
assert gb == [1, 1, 0, 0]
18671907

18681908

18691909
def test_var_population():

0 commit comments

Comments
 (0)