@@ -1838,51 +1838,52 @@ def test_percentile_cont(filter_expr, expected):
18381838 assert result .column (0 )[0 ].as_py () == expected
18391839
18401840
@pytest.mark.parametrize(
    ("grouping_set_expr", "expected_grouping", "expected_sums"),
    [
        pytest.param(
            GroupingSet.rollup(column("a")), [0, 0, 1], [30, 30, 60], id="rollup"
        ),
        pytest.param(
            GroupingSet.cube(column("a")), [0, 0, 1], [30, 30, 60], id="cube"
        ),
    ],
)
def test_grouping_set_single_column(
    grouping_set_expr, expected_grouping, expected_sums
):
    """Rollup/cube over a single column yields per-group rows plus a grand total.

    Per-group rows carry grouping() == 0; the grand-total row (where the
    column is aggregated across) carries grouping() == 1.  For one column,
    rollup and cube produce identical grouping sets, so both parameter
    cases expect the same output.
    """
    ctx = SessionContext()
    frame = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]})
    aggregated = frame.aggregate(
        [grouping_set_expr],
        [f.sum(column("b")).alias("s"), f.grouping(column("a"))],
    )
    # Sort with nulls last so the grand-total row (a is NULL) comes last.
    ordered = aggregated.sort(column("a").sort(ascending=True, nulls_first=False))
    collected = ordered.collect()
    grouping_flags = pa.concat_arrays(
        [batch.column(2) for batch in collected]
    ).to_pylist()
    sums = pa.concat_arrays(
        [batch.column("s") for batch in collected]
    ).to_pylist()
    assert grouping_flags == expected_grouping
    assert sums == expected_sums
18701864
@pytest.mark.parametrize(
    ("grouping_set_expr", "expected_rows"),
    [
        # rollup(a, b) => grouping sets (a, b), (a), () => 3 + 2 + 1 = 6 rows
        pytest.param(GroupingSet.rollup(column("a"), column("b")), 6, id="rollup"),
        # cube(a, b) => grouping sets (a, b), (a), (b), () => 3 + 2 + 2 + 1 = 8 rows
        pytest.param(GroupingSet.cube(column("a"), column("b")), 8, id="cube"),
    ],
)
def test_grouping_set_multi_column(grouping_set_expr, expected_rows):
    """Check total row counts produced by two-column rollup vs. cube."""
    ctx = SessionContext()
    frame = ctx.from_pydict(
        {"a": [1, 1, 2], "b": ["x", "y", "x"], "c": [10, 20, 30]}
    )
    aggregated = frame.aggregate(
        [grouping_set_expr],
        [f.sum(column("c")).alias("s")],
    )
    row_count = sum(batch.num_rows for batch in aggregated.collect())
    assert row_count == expected_rows
18831885
1884- def test_grouping_sets ():
1885- # GROUPING SETS lets you choose exactly which column subsets to group by.
1886+ def test_grouping_sets_explicit ():
18861887 # Each row's grouping() value tells you which columns are aggregated across.
18871888 ctx = SessionContext ()
18881889 df = ctx .from_pydict ({"a" : ["x" , "x" , "y" ], "b" : ["m" , "n" , "m" ], "c" : [1 , 2 , 3 ]})
0 commit comments