@@ -1820,43 +1820,52 @@ def test_conditional_functions(df_with_nulls, expr, expected):
18201820 assert result .column (0 ) == expected
18211821
18221822
@pytest.mark.parametrize(
    ("filter_expr", "expected"),
    [
        (None, 3.0),
        (column("a") > literal(1.0), 3.5),
    ],
    ids=["no_filter", "with_filter"],
)
def test_percentile_cont(filter_expr, expected):
    """percentile_cont(0.5) computes the median; an optional filter drops rows first."""
    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1.0, 2.0, 3.0, 4.0, 5.0]})
    # Build the aggregate expression once so the filter handling is easy to see.
    median_expr = f.percentile_cont(column("a"), 0.5, filter=filter_expr).alias("v")
    batch = df.aggregate([], [median_expr]).collect()[0]
    assert batch.column(0)[0].as_py() == expected
18441838
18451839
def test_grouping():
    """grouping() returns 0 for every row in a plain GROUP BY (no grouping sets)."""
    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 20, 30]})
    batches = df.aggregate(
        [column("a")], [f.grouping(column("a")), f.sum(column("b")).alias("s")]
    ).collect()
    # Column 1 holds the grouping() flag; flatten all result batches and check it.
    flags = pa.concat_arrays([b.column(1) for b in batches]).to_pylist()
    assert all(v == 0 for v in flags)
18581849
18591850
def test_grouping_multiple_columns():
    """grouping() also works when the GROUP BY clause has more than one column."""
    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1, 1, 2], "b": [10, 10, 30], "c": [100, 200, 300]})
    aggs = [
        f.grouping(column("a")),
        f.grouping(column("b")),
        f.sum(column("c")).alias("s"),
    ]
    batches = df.aggregate([column("a"), column("b")], aggs).collect()
    # Columns 2 and 3 carry the per-column grouping() flags.
    flags_a = pa.concat_arrays([b.column(2) for b in batches]).to_pylist()
    flags_b = pa.concat_arrays([b.column(3) for b in batches]).to_pylist()
    assert all(v == 0 for v in flags_a)
    assert all(v == 0 for v in flags_b)
1867+
1868+
18601869def test_var_population ():
18611870 ctx = SessionContext ()
18621871 df = ctx .from_pydict ({"a" : [- 1.0 , 0.0 , 2.0 ]})
0 commit comments