Skip to content

Commit c9183dd

Browse files
timsaucerclaude
andcommitted
Parametrize grouping set tests for rollup and cube
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent c52b49c commit c9183dd

File tree

1 file changed

+28
-27
lines changed

1 file changed

+28
-27
lines changed

python/tests/test_functions.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1838,51 +1838,52 @@ def test_percentile_cont(filter_expr, expected):
18381838
assert result.column(0)[0].as_py() == expected
18391839

18401840

1841-
def test_rollup():
1842-
# With ROLLUP, per-group rows have grouping()=0 and the grand-total row
1843-
# (where the column is aggregated across) has grouping()=1.
1841+
@pytest.mark.parametrize(
1842+
("grouping_set_expr", "expected_grouping", "expected_sums"),
1843+
[
1844+
(GroupingSet.rollup(column("a")), [0, 0, 1], [30, 30, 60]),
1845+
(GroupingSet.cube(column("a")), [0, 0, 1], [30, 30, 60]),
1846+
],
1847+
ids=["rollup", "cube"],
1848+
)
1849+
def test_grouping_set_single_column(
1850+
grouping_set_expr, expected_grouping, expected_sums
1851+
):
18441852
ctx = SessionContext()
18451853
df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]})
18461854
result = df.aggregate(
1847-
[GroupingSet.rollup(column("a"))],
1855+
[grouping_set_expr],
18481856
[f.sum(column("b")).alias("s"), f.grouping(column("a"))],
18491857
).sort(column("a").sort(ascending=True, nulls_first=False))
18501858
batches = result.collect()
18511859
g = pa.concat_arrays([b.column(2) for b in batches]).to_pylist()
18521860
s = pa.concat_arrays([b.column("s") for b in batches]).to_pylist()
1853-
# Two per-group rows (g=0) plus one grand-total row (g=1)
1854-
assert g == [0, 0, 1]
1855-
assert s == [30, 30, 60]
1856-
1857-
1858-
def test_rollup_multi_column():
1859-
# rollup(a, b) produces grouping sets (a, b), (a), ().
1860-
ctx = SessionContext()
1861-
df = ctx.from_pydict({"a": [1, 1, 2], "b": ["x", "y", "x"], "c": [10, 20, 30]})
1862-
result = df.aggregate(
1863-
[GroupingSet.rollup(column("a"), column("b"))],
1864-
[f.sum(column("c")).alias("s")],
1865-
)
1866-
total_rows = sum(b.num_rows for b in result.collect())
1867-
# 3 detail (a,b) + 2 subtotal (a) + 1 grand total = 6
1868-
assert total_rows == 6
1861+
assert g == expected_grouping
1862+
assert s == expected_sums
18691863

18701864

1871-
def test_cube():
1872-
# cube(a, b) produces all subsets: (a,b), (a), (b), ().
1865+
@pytest.mark.parametrize(
1866+
("grouping_set_expr", "expected_rows"),
1867+
[
1868+
# rollup(a, b) => (a,b), (a), () => 3 + 2 + 1 = 6
1869+
(GroupingSet.rollup(column("a"), column("b")), 6),
1870+
# cube(a, b) => (a,b), (a), (b), () => 3 + 2 + 2 + 1 = 8
1871+
(GroupingSet.cube(column("a"), column("b")), 8),
1872+
],
1873+
ids=["rollup", "cube"],
1874+
)
1875+
def test_grouping_set_multi_column(grouping_set_expr, expected_rows):
18731876
ctx = SessionContext()
18741877
df = ctx.from_pydict({"a": [1, 1, 2], "b": ["x", "y", "x"], "c": [10, 20, 30]})
18751878
result = df.aggregate(
1876-
[GroupingSet.cube(column("a"), column("b"))],
1879+
[grouping_set_expr],
18771880
[f.sum(column("c")).alias("s")],
18781881
)
18791882
total_rows = sum(b.num_rows for b in result.collect())
1880-
# 3 (a,b) + 2 (a) + 2 (b) + 1 () = 8
1881-
assert total_rows == 8
1883+
assert total_rows == expected_rows
18821884

18831885

1884-
def test_grouping_sets():
1885-
# GROUPING SETS lets you choose exactly which column subsets to group by.
1886+
def test_grouping_sets_explicit():
18861887
# Each row's grouping() value tells you which columns are aggregated across.
18871888
ctx = SessionContext()
18881889
df = ctx.from_pydict({"a": ["x", "x", "y"], "b": ["m", "n", "m"], "c": [1, 2, 3]})

0 commit comments

Comments
 (0)