@@ -3675,57 +3675,78 @@ def test_explain_with_format(capsys, fmt, verbose, analyze, expected_substring):
36753675 assert expected_substring in captured .out
36763676
36773677
3678- def test_window ():
3678+ @pytest .mark .parametrize (
3679+ ("window_exprs" , "expected_columns" ),
3680+ [
3681+ pytest .param (
3682+ lambda : [
3683+ f .row_number (partition_by = [column ("b" )], order_by = [column ("a" )]).alias (
3684+ "rn"
3685+ ),
3686+ ],
3687+ {"rn" : [1 , 2 , 1 ]},
3688+ id = "single window expression" ,
3689+ ),
3690+ pytest .param (
3691+ lambda : [
3692+ f .row_number (partition_by = [column ("b" )], order_by = [column ("a" )]).alias (
3693+ "rn"
3694+ ),
3695+ f .rank (partition_by = [column ("b" )], order_by = [column ("a" )]).alias ("rnk" ),
3696+ ],
3697+ {"rn" : [1 , 2 , 1 ], "rnk" : [1 , 2 , 1 ]},
3698+ id = "multiple window expressions" ,
3699+ ),
3700+ ],
3701+ )
3702+ def test_window (window_exprs , expected_columns ):
36793703 ctx = SessionContext ()
36803704 df = ctx .from_pydict ({"a" : [1 , 2 , 3 ], "b" : ["x" , "x" , "y" ]})
36813705 result = (
3682- df .window (
3683- f .row_number (partition_by = [column ("b" )], order_by = [column ("a" )]).alias ("rn" )
3684- )
3685- .sort (column ("a" ).sort (ascending = True ))
3686- .collect ()[0 ]
3706+ df .window (* window_exprs ()).sort (column ("a" ).sort (ascending = True )).collect ()[0 ]
36873707 )
3688- assert "rn" in result .schema .names
3689- assert result .column (result .schema .get_field_index ("rn" )).to_pylist () == [1 , 2 , 1 ]
3690-
3691-
3692- def test_unnest_columns_with_recursions ():
3693- ctx = SessionContext ()
3694- df = ctx .from_pydict ({"a" : [[1 , 2 ], [3 ]], "b" : ["x" , "y" ]})
3695- # Basic unnest still works
3696- result = df .unnest_columns ("a" ).collect ()[0 ]
3697- assert result .column (0 ).to_pylist () == [1 , 2 , 3 ]
3698- # With explicit recursion options
3699- result = df .unnest_columns ("a" , recursions = [("a" , "a" , 1 )]).collect ()[0 ]
3700- assert result .column (0 ).to_pylist () == [1 , 2 , 3 ]
3701-
3702-
3703- def test_unnest_columns_with_deep_recursion ():
3704- ctx = SessionContext ()
3705- # Nested list of lists — requires depth > 1 to fully flatten
3706- df = ctx .from_pydict ({"a" : [[[1 , 2 ], [3 ]], [[4 ]]], "b" : ["x" , "y" ]})
3707- # Depth 1 unnests the outer list, leaving inner lists intact
3708- result = df .unnest_columns ("a" , recursions = [("a" , "a" , 1 )]).collect ()[0 ]
3709- assert result .column (0 ).to_pylist () == [[1 , 2 ], [3 ], [4 ]]
3710- # Depth 2 fully flattens
3711- result = df .unnest_columns ("a" , recursions = [("a" , "a" , 2 )]).collect ()[0 ]
3712- assert result .column (0 ).to_pylist () == [1 , 2 , 3 , 4 ]
3708+ for col_name , expected_values in expected_columns .items ():
3709+ assert col_name in result .schema .names
3710+ assert (
3711+ result .column (result .schema .get_field_index (col_name )).to_pylist ()
3712+ == expected_values
3713+ )
37133714
37143715
3715- def test_window_multiple_expressions ():
3716+ @pytest .mark .parametrize (
3717+ ("input_data" , "recursions" , "expected_a" ),
3718+ [
3719+ pytest .param (
3720+ {"a" : [[1 , 2 ], [3 ]], "b" : ["x" , "y" ]},
3721+ None ,
3722+ [1 , 2 , 3 ],
3723+ id = "basic unnest without recursions" ,
3724+ ),
3725+ pytest .param (
3726+ {"a" : [[1 , 2 ], [3 ]], "b" : ["x" , "y" ]},
3727+ [("a" , "a" , 1 )],
3728+ [1 , 2 , 3 ],
3729+ id = "explicit depth 1 matches basic unnest" ,
3730+ ),
3731+ pytest .param (
3732+ {"a" : [[[1 , 2 ], [3 ]], [[4 ]]], "b" : ["x" , "y" ]},
3733+ [("a" , "a" , 1 )],
3734+ [[1 , 2 ], [3 ], [4 ]],
3735+ id = "depth 1 on nested lists keeps inner lists" ,
3736+ ),
3737+ pytest .param (
3738+ {"a" : [[[1 , 2 ], [3 ]], [[4 ]]], "b" : ["x" , "y" ]},
3739+ [("a" , "a" , 2 )],
3740+ [1 , 2 , 3 , 4 ],
3741+ id = "depth 2 fully flattens nested lists" ,
3742+ ),
3743+ ],
3744+ )
3745+ def test_unnest_columns_with_recursions (input_data , recursions , expected_a ):
37163746 ctx = SessionContext ()
3717- df = ctx .from_pydict ({"a" : [1 , 2 , 3 ], "b" : ["x" , "x" , "y" ]})
3718- result = (
3719- df .window (
3720- f .row_number (partition_by = [column ("b" )], order_by = [column ("a" )]).alias (
3721- "rn"
3722- ),
3723- f .rank (partition_by = [column ("b" )], order_by = [column ("a" )]).alias ("rnk" ),
3724- )
3725- .sort (column ("a" ).sort (ascending = True ))
3726- .collect ()[0 ]
3727- )
3728- assert "rn" in result .schema .names
3729- assert "rnk" in result .schema .names
3730- assert result .column (result .schema .get_field_index ("rn" )).to_pylist () == [1 , 2 , 1 ]
3731- assert result .column (result .schema .get_field_index ("rnk" )).to_pylist () == [1 , 2 , 1 ]
3747+ df = ctx .from_pydict (input_data )
3748+ kwargs = {}
3749+ if recursions is not None :
3750+ kwargs ["recursions" ] = recursions
3751+ result = df .unnest_columns ("a" , ** kwargs ).collect ()[0 ]
3752+ assert result .column (0 ).to_pylist () == expected_a
0 commit comments