Skip to content

Commit 410f4c4

Browse files
timsaucerclaude
andcommitted
Accept string column names in GroupingSet factory methods
GroupingSet.rollup(), .cube(), and .grouping_sets() now accept both Expr objects and string column names, consistent with DataFrame.aggregate(). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 74c1485 commit 410f4c4

File tree

2 files changed

+32
-16
lines changed

2 files changed

+32
-16
lines changed

python/datafusion/expr.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1445,7 +1445,7 @@ class GroupingSet:
14451445
"""
14461446

14471447
@staticmethod
1448-
def rollup(*exprs: Expr) -> Expr:
1448+
def rollup(*exprs: Expr | str) -> Expr:
14491449
"""Create a ``ROLLUP`` grouping set for use with ``aggregate()``.
14501450
14511451
``ROLLUP`` generates all prefixes of the given column list as
@@ -1455,7 +1455,8 @@ def rollup(*exprs: Expr) -> Expr:
14551455
This is equivalent to ``GROUP BY ROLLUP(a, b)`` in SQL.
14561456
14571457
Args:
1458-
*exprs: Column expressions to include in the rollup.
1458+
*exprs: Column expressions or column name strings to
1459+
include in the rollup.
14591460
14601461
Examples:
14611462
>>> import datafusion as dfn
@@ -1474,11 +1475,11 @@ def rollup(*exprs: Expr) -> Expr:
14741475
:py:meth:`cube`, :py:meth:`grouping_sets`,
14751476
:py:func:`~datafusion.functions.grouping`
14761477
"""
1477-
args = [e.expr for e in exprs]
1478+
args = [_to_raw_expr(e) for e in exprs]
14781479
return Expr(expr_internal.GroupingSet.rollup(*args))
14791480

14801481
@staticmethod
1481-
def cube(*exprs: Expr) -> Expr:
1482+
def cube(*exprs: Expr | str) -> Expr:
14821483
"""Create a ``CUBE`` grouping set for use with ``aggregate()``.
14831484
14841485
``CUBE`` generates all possible subsets of the given column list
@@ -1488,7 +1489,8 @@ def cube(*exprs: Expr) -> Expr:
14881489
This is equivalent to ``GROUP BY CUBE(a, b)`` in SQL.
14891490
14901491
Args:
1491-
*exprs: Column expressions to include in the cube.
1492+
*exprs: Column expressions or column name strings to
1493+
include in the cube.
14921494
14931495
Examples:
14941496
With a single column, ``cube`` behaves identically to
@@ -1510,23 +1512,25 @@ def cube(*exprs: Expr) -> Expr:
15101512
:py:meth:`rollup`, :py:meth:`grouping_sets`,
15111513
:py:func:`~datafusion.functions.grouping`
15121514
"""
1513-
args = [e.expr for e in exprs]
1515+
args = [_to_raw_expr(e) for e in exprs]
15141516
return Expr(expr_internal.GroupingSet.cube(*args))
15151517

15161518
@staticmethod
1517-
def grouping_sets(*expr_lists: list[Expr]) -> Expr:
1519+
def grouping_sets(*expr_lists: list[Expr | str]) -> Expr:
15181520
"""Create explicit grouping sets for use with ``aggregate()``.
15191521
1520-
Each argument is a list of column expressions representing one
1521-
grouping set. For example, ``grouping_sets([a], [b])`` groups
1522-
by ``a`` alone and by ``b`` alone in a single query.
1522+
Each argument is a list of column expressions or column name
1523+
strings representing one grouping set. For example,
1524+
``grouping_sets([a], [b])`` groups by ``a`` alone and by ``b``
1525+
alone in a single query.
15231526
15241527
This is equivalent to ``GROUP BY GROUPING SETS ((a), (b))`` in
15251528
SQL.
15261529
15271530
Args:
15281531
*expr_lists: Each positional argument is a list of
1529-
expressions forming one grouping set.
1532+
expressions or column name strings forming one
1533+
grouping set.
15301534
15311535
Examples:
15321536
>>> import datafusion as dfn
@@ -1552,5 +1556,5 @@ def grouping_sets(*expr_lists: list[Expr]) -> Expr:
15521556
:py:meth:`rollup`, :py:meth:`cube`,
15531557
:py:func:`~datafusion.functions.grouping`
15541558
"""
1555-
raw_lists = [[e.expr for e in lst] for lst in expr_lists]
1559+
raw_lists = [[_to_raw_expr(e) for e in lst] for lst in expr_lists]
15561560
return Expr(expr_internal.GroupingSet.grouping_sets(*raw_lists))

python/tests/test_functions.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1844,8 +1844,10 @@ def test_percentile_cont(func, filter_expr, expected):
18441844
[
18451845
(GroupingSet.rollup(column("a")), [0, 0, 1], [30, 30, 60]),
18461846
(GroupingSet.cube(column("a")), [0, 0, 1], [30, 30, 60]),
1847+
(GroupingSet.rollup("a"), [0, 0, 1], [30, 30, 60]),
1848+
(GroupingSet.cube("a"), [0, 0, 1], [30, 30, 60]),
18471849
],
1848-
ids=["rollup", "cube"],
1850+
ids=["rollup", "cube", "rollup_str", "cube_str"],
18491851
)
18501852
def test_grouping_set_single_column(
18511853
grouping_set_expr, expected_grouping, expected_sums
@@ -1870,8 +1872,10 @@ def test_grouping_set_single_column(
18701872
(GroupingSet.rollup(column("a"), column("b")), 6),
18711873
# cube(a, b) => (a,b), (a), (b), () => 3 + 2 + 2 + 1 = 8
18721874
(GroupingSet.cube(column("a"), column("b")), 8),
1875+
(GroupingSet.rollup("a", "b"), 6),
1876+
(GroupingSet.cube("a", "b"), 8),
18731877
],
1874-
ids=["rollup", "cube"],
1878+
ids=["rollup", "cube", "rollup_str", "cube_str"],
18751879
)
18761880
def test_grouping_set_multi_column(grouping_set_expr, expected_rows):
18771881
ctx = SessionContext()
@@ -1884,12 +1888,20 @@ def test_grouping_set_multi_column(grouping_set_expr, expected_rows):
18841888
assert total_rows == expected_rows
18851889

18861890

1887-
def test_grouping_sets_explicit():
1891+
@pytest.mark.parametrize(
1892+
"grouping_set_expr",
1893+
[
1894+
GroupingSet.grouping_sets([column("a")], [column("b")]),
1895+
GroupingSet.grouping_sets(["a"], ["b"]),
1896+
],
1897+
ids=["expr", "str"],
1898+
)
1899+
def test_grouping_sets_explicit(grouping_set_expr):
18881900
# Each row's grouping() value tells you which columns are aggregated across.
18891901
ctx = SessionContext()
18901902
df = ctx.from_pydict({"a": ["x", "x", "y"], "b": ["m", "n", "m"], "c": [1, 2, 3]})
18911903
result = df.aggregate(
1892-
[GroupingSet.grouping_sets([column("a")], [column("b")])],
1904+
[grouping_set_expr],
18931905
[
18941906
f.sum(column("c")).alias("s"),
18951907
f.grouping(column("a")),

0 commit comments

Comments
 (0)