@@ -32,7 +32,7 @@ def test_condition_equal_columns_struct() -> None:
3232 "a" ,
3333 dtype_left = lhs .schema ["a_left" ],
3434 dtype_right = rhs .schema ["a_right" ],
35- max_list_length = 0 ,
35+ max_list_length = None ,
3636 abs_tol = 0.5 ,
3737 rel_tol = 0 ,
3838 )
@@ -67,7 +67,7 @@ def test_condition_equal_columns_different_struct_fields() -> None:
6767 "a" ,
6868 dtype_left = lhs .schema ["a_left" ],
6969 dtype_right = rhs .schema ["a_right" ],
70- max_list_length = 0 ,
70+ max_list_length = None ,
7171 )
7272 )
7373 .to_series ()
@@ -89,60 +89,15 @@ def test_condition_equal_columns_list_array_with_tolerance(
8989 # Arrange
9090 lhs = pl .DataFrame (
9191 {
92- "pk" : [1 , 2 ],
93- "a_left" : [[1.0 , 1.1 ], [2.0 , 2.1 ]],
94- },
95- schema = {"pk" : pl .Int64 , "a_left" : lhs_type },
96- )
97- rhs = pl .DataFrame (
98- {
99- "pk" : [1 , 2 ],
100- "a_right" : [[1.0 , 1.1 ], [2.0 , 2.2 ]],
101- },
102- schema = {"pk" : pl .Int64 , "a_right" : rhs_type },
103- )
104-
105- # Act
106- actual = (
107- lhs .join (rhs , on = "pk" , maintain_order = "left" )
108- .select (
109- condition_equal_columns (
110- "a" ,
111- dtype_left = lhs .schema ["a_left" ],
112- dtype_right = rhs .schema ["a_right" ],
113- abs_tol = 0.5 ,
114- rel_tol = 0 ,
115- max_list_length = 2 ,
116- )
117- )
118- .to_series ()
119- )
120-
121- # Assert: diff is 0.1, within abs_tol=0.5
122- assert actual .to_list () == [True , True ]
123-
124-
125- @pytest .mark .parametrize (
126- "lhs_type" , [pl .Array (pl .Float64 , shape = 2 ), pl .List (pl .Float64 )]
127- )
128- @pytest .mark .parametrize (
129- "rhs_type" , [pl .Array (pl .Float64 , shape = 2 ), pl .List (pl .Float64 )]
130- )
131- def test_condition_equal_columns_list_array_exceeds_tolerance (
132- lhs_type : pl .DataType , rhs_type : pl .DataType
133- ) -> None :
134- # Arrange
135- lhs = pl .DataFrame (
136- {
137- "pk" : [1 , 2 ],
138- "a_left" : [[1.0 , 1.1 ], [2.0 , 2.1 ]],
92+ "pk" : [1 , 2 , 3 ],
93+ "a_left" : [[1.0 , 1.1 ], [2.0 , 2.1 ], [3.0 , 3.0 ]],
13994 },
14095 schema = {"pk" : pl .Int64 , "a_left" : lhs_type },
14196 )
14297 rhs = pl .DataFrame (
14398 {
144- "pk" : [1 , 2 ],
145- "a_right" : [[1.0 , 1.1 ], [2.0 , 2.8 ]],
99+ "pk" : [1 , 2 , 3 ],
100+ "a_right" : [[1.0 , 1.1 ], [2.0 , 2.2 ], [ 3.0 , 3.7 ]],
146101 },
147102 schema = {"pk" : pl .Int64 , "a_right" : rhs_type },
148103 )
@@ -163,8 +118,7 @@ def test_condition_equal_columns_list_array_exceeds_tolerance(
163118 .to_series ()
164119 )
165120
166- # Assert: diff is 0.7, exceeds abs_tol=0.5
167- assert actual .to_list () == [True , False ]
121+ assert actual .to_list () == [True , True , False ]
168122
169123
170124def test_condition_equal_columns_nested_dtype_mismatch () -> None :
@@ -190,7 +144,7 @@ def test_condition_equal_columns_nested_dtype_mismatch() -> None:
190144 "a" ,
191145 dtype_left = lhs .schema ["a_left" ],
192146 dtype_right = rhs .schema ["a_right" ],
193- max_list_length = 0 ,
147+ max_list_length = None ,
194148 )
195149 )
196150 .to_series ()
@@ -223,7 +177,7 @@ def test_condition_equal_columns_exactly_one_nested() -> None:
223177 "a" ,
224178 dtype_left = lhs .schema ["a_left" ],
225179 dtype_right = rhs .schema ["a_right" ],
226- max_list_length = 0 ,
180+ max_list_length = None ,
227181 )
228182 )
229183 .to_series ()
@@ -266,7 +220,7 @@ def test_condition_equal_columns_temporal_tolerance() -> None:
266220 "a" ,
267221 dtype_left = lhs .schema ["a_left" ],
268222 dtype_right = rhs .schema ["a_right" ],
269- max_list_length = 0 ,
223+ max_list_length = None ,
270224 abs_tol_temporal = dt .timedelta (seconds = 2 ),
271225 )
272226 )
@@ -359,7 +313,7 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None:
359313 "a" ,
360314 dtype_left = lhs .schema ["a_left" ],
361315 dtype_right = rhs .schema ["a_right" ],
362- max_list_length = 0 ,
316+ max_list_length = None ,
363317 abs_tol = 0.5 ,
364318 rel_tol = 0 ,
365319 )
@@ -369,6 +323,99 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None:
369323 assert actual .to_list () == [True , False ]
370324
371325
326+ def test_condition_equal_columns_array_different_shapes () -> None :
327+ lhs = pl .DataFrame (
328+ {
329+ "pk" : [1 ],
330+ "a_left" : [[1.0 , 2.0 ]],
331+ },
332+ schema = {"pk" : pl .Int64 , "a_left" : pl .Array (pl .Float64 , shape = 2 )},
333+ )
334+ rhs = pl .DataFrame (
335+ {
336+ "pk" : [1 ],
337+ "a_right" : [[1.0 , 2.0 , 3.0 ]],
338+ },
339+ schema = {"pk" : pl .Int64 , "a_right" : pl .Array (pl .Float64 , shape = 3 )},
340+ )
341+
342+ actual = (
343+ lhs .join (rhs , on = "pk" , maintain_order = "left" )
344+ .select (
345+ condition_equal_columns (
346+ "a" ,
347+ dtype_left = lhs .schema ["a_left" ],
348+ dtype_right = rhs .schema ["a_right" ],
349+ max_list_length = None ,
350+ )
351+ )
352+ .to_series ()
353+ )
354+ assert actual .to_list () == [False ]
355+
356+
357+ def test_condition_equal_columns_empty_arrays () -> None :
358+ lhs = pl .DataFrame (
359+ {
360+ "pk" : [1 , 2 ],
361+ "a_left" : [[], None ],
362+ },
363+ schema = {"pk" : pl .Int64 , "a_left" : pl .Array (pl .Float64 , shape = 0 )},
364+ )
365+ rhs = pl .DataFrame (
366+ {
367+ "pk" : [1 , 2 ],
368+ "a_right" : [[], None ],
369+ },
370+ schema = {"pk" : pl .Int64 , "a_right" : pl .Array (pl .Float64 , shape = 0 )},
371+ )
372+
373+ actual = (
374+ lhs .join (rhs , on = "pk" , maintain_order = "left" )
375+ .select (
376+ condition_equal_columns (
377+ "a" ,
378+ dtype_left = lhs .schema ["a_left" ],
379+ dtype_right = rhs .schema ["a_right" ],
380+ max_list_length = None ,
381+ )
382+ )
383+ .to_series ()
384+ )
385+ assert actual .to_list () == [True , True ]
386+
387+
388+ def test_condition_equal_columns_empty_lists () -> None :
389+ lhs = pl .DataFrame (
390+ {
391+ "pk" : [1 , 2 , 3 ],
392+ "a_left" : [[], None , []],
393+ },
394+ schema = {"pk" : pl .Int64 , "a_left" : pl .List (pl .Float64 )},
395+ )
396+ rhs = pl .DataFrame (
397+ {
398+ "pk" : [1 , 2 , 3 ],
399+ "a_right" : [[], None , None ],
400+ },
401+ schema = {"pk" : pl .Int64 , "a_right" : pl .List (pl .Float64 )},
402+ )
403+
404+ actual = (
405+ lhs .join (rhs , on = "pk" , maintain_order = "left" )
406+ .select (
407+ condition_equal_columns (
408+ "a" ,
409+ dtype_left = lhs .schema ["a_left" ],
410+ dtype_right = rhs .schema ["a_right" ],
411+ max_list_length = 0 ,
412+ )
413+ )
414+ .to_series ()
415+ )
416+ assert actual .to_list () == [True , True , False ]
417+
418+
372419@pytest .mark .parametrize (
373420 ("dtype_left" , "dtype_right" , "can_compare_dtypes" ),
374421 [
0 commit comments