@@ -19,27 +19,26 @@ def condition_equal_rows(
1919 columns : list [str ],
2020 schema_left : pl .Schema ,
2121 schema_right : pl .Schema ,
22+ max_list_lengths_by_column : Mapping [str , int ],
2223 abs_tol_by_column : Mapping [str , float ],
2324 rel_tol_by_column : Mapping [str , float ],
2425 abs_tol_temporal_by_column : Mapping [str , dt .timedelta ],
25- max_list_lengths_by_column : Mapping [str , int ] | None = None ,
2626) -> pl .Expr :
2727 """Build an expression whether two rows are equal, based on all columns' data
2828 types."""
2929 if not columns :
3030 return pl .lit (True )
3131
32- _max_list_lengths = max_list_lengths_by_column or {}
3332 return pl .all_horizontal (
3433 [
3534 condition_equal_columns (
3635 column = column ,
3736 dtype_left = schema_left [column ],
3837 dtype_right = schema_right [column ],
38+ max_list_length = max_list_lengths_by_column .get (column , 0 ),
3939 abs_tol = abs_tol_by_column [column ],
4040 rel_tol = rel_tol_by_column [column ],
4141 abs_tol_temporal = abs_tol_temporal_by_column [column ],
42- max_list_length = _max_list_lengths .get (column , 0 ),
4342 )
4443 for column in columns
4544 ]
@@ -50,10 +49,10 @@ def condition_equal_columns(
5049 column : str ,
5150 dtype_left : pl .DataType ,
5251 dtype_right : pl .DataType ,
52+ max_list_length : int ,
5353 abs_tol : float = ABS_TOL_DEFAULT ,
5454 rel_tol : float = REL_TOL_DEFAULT ,
5555 abs_tol_temporal : dt .timedelta = ABS_TOL_TEMPORAL_DEFAULT ,
56- max_list_length : int = 0 ,
5756) -> pl .Expr :
5857 """Build an expression whether two columns are equal, depending on the columns' data
5958 types."""
@@ -62,10 +61,10 @@ def condition_equal_columns(
6261 col_right = pl .col (f"{ column } _{ Side .RIGHT } " ),
6362 dtype_left = dtype_left ,
6463 dtype_right = dtype_right ,
64+ max_list_length = max_list_length ,
6565 abs_tol = abs_tol ,
6666 rel_tol = rel_tol ,
6767 abs_tol_temporal = abs_tol_temporal ,
68- max_list_length = max_list_length ,
6968 )
7069
7170
@@ -97,10 +96,10 @@ def _compare_columns(
9796 col_right : pl .Expr ,
9897 dtype_left : DataType | DataTypeClass ,
9998 dtype_right : DataType | DataTypeClass ,
99+ max_list_length : int ,
100100 abs_tol : float ,
101101 rel_tol : float ,
102102 abs_tol_temporal : dt .timedelta ,
103- max_list_length : int = 0 ,
104103) -> pl .Expr :
105104 """Build an expression whether two expressions yield the same value.
106105
@@ -129,6 +128,7 @@ def _compare_columns(
129128 col_right = col_right .struct [field ],
130129 dtype_left = fields_left [field ],
131130 dtype_right = fields_right [field ],
131+ max_list_length = max_list_length ,
132132 abs_tol = abs_tol ,
133133 rel_tol = rel_tol ,
134134 abs_tol_temporal = abs_tol_temporal ,
@@ -139,7 +139,7 @@ def _compare_columns(
139139 elif isinstance (dtype_left , pl .List | pl .Array ) and isinstance (
140140 dtype_right , pl .List | pl .Array
141141 ):
142- result = _compare_sequence_columns (
142+ return _compare_sequence_columns (
143143 col_left = col_left ,
144144 col_right = col_right ,
145145 dtype_left = dtype_left ,
@@ -149,8 +149,6 @@ def _compare_columns(
149149 rel_tol = rel_tol ,
150150 abs_tol_temporal = abs_tol_temporal ,
151151 )
152- if result is not None :
153- return result
154152
155153 if (
156154 isinstance (dtype_left , pl .Enum )
@@ -168,6 +166,7 @@ def _compare_columns(
168166 abs_tol = abs_tol ,
169167 rel_tol = rel_tol ,
170168 abs_tol_temporal = abs_tol_temporal ,
169+ max_list_length = max_list_length ,
171170 )
172171
173172 return _compare_primitive_columns (
@@ -190,13 +189,8 @@ def _compare_sequence_columns(
190189 abs_tol : float ,
191190 rel_tol : float ,
192191 abs_tol_temporal : dt .timedelta ,
193- ) -> pl .Expr | None :
194- """Compare Array/List columns element-wise with tolerance.
195-
196- Returns ``None`` if the comparison cannot be performed element-wise (e.g. List vs
197- List without a known ``max_list_length``), signalling to the caller that it should
198- fall back to primitive comparison.
199- """
192+ ) -> pl .Expr :
193+ """Compare Array/List columns element-wise with tolerance."""
200194 assert isinstance (dtype_left , pl .List | pl .Array )
201195 assert isinstance (dtype_right , pl .List | pl .Array )
202196 inner_left = dtype_left .inner
@@ -207,29 +201,27 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
207201 return col .arr .get (i )
208202 return col .list .get (i , null_on_oob = True )
209203
210- n : int | None = None
211- length_check : pl .Expr | None = None
204+ n_elements : int | None = None
205+ has_same_length : pl .Expr | None = None
212206
213207 if isinstance (dtype_left , pl .Array ) and isinstance (dtype_right , pl .Array ):
214208 if dtype_left .shape != dtype_right .shape :
215209 return pl .repeat (pl .lit (False ), pl .len ())
216- n = dtype_left .shape [0 ]
210+ n_elements = dtype_left .shape [0 ]
217211 elif isinstance (dtype_left , pl .Array ) and isinstance (dtype_right , pl .List ):
218- n = dtype_left .shape [0 ]
219- length_check = col_right .list .len ().eq (pl .lit (n ))
212+ n_elements = dtype_left .shape [0 ]
213+ has_same_length = col_right .list .len ().eq (pl .lit (n_elements ))
220214 elif isinstance (dtype_left , pl .List ) and isinstance (dtype_right , pl .Array ):
221- n = dtype_right .shape [0 ]
222- length_check = col_left .list .len ().eq (pl .lit (n ))
215+ n_elements = dtype_right .shape [0 ]
216+ has_same_length = col_left .list .len ().eq (pl .lit (n_elements ))
223217 else :
224218 # List vs List
225- if max_list_length == 0 :
226- return None
227- n = max_list_length
228- length_check = col_left .list .len ().eq_missing (col_right .list .len ())
229-
230- if n == 0 :
231- if length_check is not None :
232- return _eq_missing (length_check , col_left , col_right )
219+ n_elements = max_list_length
220+ has_same_length = col_left .list .len ().eq_missing (col_right .list .len ())
221+
222+ if n_elements == 0 :
223+ if has_same_length is not None :
224+ return _eq_missing (has_same_length , col_left , col_right )
233225 return _eq_missing (pl .lit (True ), col_left , col_right )
234226
235227 elements_match = pl .all_horizontal (
@@ -242,13 +234,14 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
242234 abs_tol = abs_tol ,
243235 rel_tol = rel_tol ,
244236 abs_tol_temporal = abs_tol_temporal ,
237+ max_list_length = max_list_length ,
245238 )
246- for i in range (n )
239+ for i in range (n_elements )
247240 ]
248241 )
249242
250- if length_check is not None :
251- return _eq_missing (length_check & elements_match , col_left , col_right )
243+ if has_same_length is not None :
244+ return _eq_missing (has_same_length & elements_match , col_left , col_right )
252245 return elements_match
253246
254247
0 commit comments