@@ -19,6 +19,7 @@ def condition_equal_rows(
1919 columns : list [str ],
2020 schema_left : pl .Schema ,
2121 schema_right : pl .Schema ,
22+ max_list_lengths_by_column : Mapping [str , int ],
2223 abs_tol_by_column : Mapping [str , float ],
2324 rel_tol_by_column : Mapping [str , float ],
2425 abs_tol_temporal_by_column : Mapping [str , dt .timedelta ],
@@ -34,6 +35,7 @@ def condition_equal_rows(
3435 column = column ,
3536 dtype_left = schema_left [column ],
3637 dtype_right = schema_right [column ],
38+ max_list_length = max_list_lengths_by_column .get (column ),
3739 abs_tol = abs_tol_by_column [column ],
3840 rel_tol = rel_tol_by_column [column ],
3941 abs_tol_temporal = abs_tol_temporal_by_column [column ],
@@ -47,6 +49,7 @@ def condition_equal_columns(
4749 column : str ,
4850 dtype_left : pl .DataType ,
4951 dtype_right : pl .DataType ,
52+ max_list_length : int | None ,
5053 abs_tol : float = ABS_TOL_DEFAULT ,
5154 rel_tol : float = REL_TOL_DEFAULT ,
5255 abs_tol_temporal : dt .timedelta = ABS_TOL_TEMPORAL_DEFAULT ,
@@ -58,6 +61,7 @@ def condition_equal_columns(
5861 col_right = pl .col (f"{ column } _{ Side .RIGHT } " ),
5962 dtype_left = dtype_left ,
6063 dtype_right = dtype_right ,
64+ max_list_length = max_list_length ,
6165 abs_tol = abs_tol ,
6266 rel_tol = rel_tol ,
6367 abs_tol_temporal = abs_tol_temporal ,
@@ -92,6 +96,7 @@ def _compare_columns(
9296 col_right : pl .Expr ,
9397 dtype_left : DataType | DataTypeClass ,
9498 dtype_right : DataType | DataTypeClass ,
99+ max_list_length : int | None ,
95100 abs_tol : float ,
96101 rel_tol : float ,
97102 abs_tol_temporal : dt .timedelta ,
@@ -123,6 +128,7 @@ def _compare_columns(
123128 col_right = col_right .struct [field ],
124129 dtype_left = fields_left [field ],
125130 dtype_right = fields_right [field ],
131+ max_list_length = max_list_length ,
126132 abs_tol = abs_tol ,
127133 rel_tol = rel_tol ,
128134 abs_tol_temporal = abs_tol_temporal ,
@@ -133,10 +139,16 @@ def _compare_columns(
133139 elif isinstance (dtype_left , pl .List | pl .Array ) and isinstance (
134140 dtype_right , pl .List | pl .Array
135141 ):
136- # As of polars 1.28, there is no way to access another column within
137- # `list.eval`. Hence, we necessarily need to resort to a primitive
138- # comparison in this case.
139- pass
142+ return _compare_sequence_columns (
143+ col_left = col_left ,
144+ col_right = col_right ,
145+ dtype_left = dtype_left ,
146+ dtype_right = dtype_right ,
147+ max_list_length = max_list_length ,
148+ abs_tol = abs_tol ,
149+ rel_tol = rel_tol ,
150+ abs_tol_temporal = abs_tol_temporal ,
151+ )
140152
141153 if (
142154 isinstance (dtype_left , pl .Enum )
@@ -154,6 +166,7 @@ def _compare_columns(
154166 abs_tol = abs_tol ,
155167 rel_tol = rel_tol ,
156168 abs_tol_temporal = abs_tol_temporal ,
169+ max_list_length = max_list_length ,
157170 )
158171
159172 return _compare_primitive_columns (
@@ -167,6 +180,67 @@ def _compare_columns(
167180 )
168181
169182
def _compare_sequence_columns(
    col_left: pl.Expr,
    col_right: pl.Expr,
    dtype_left: pl.List | pl.Array,
    dtype_right: pl.List | pl.Array,
    max_list_length: int | None,
    abs_tol: float,
    rel_tol: float,
    abs_tol_temporal: dt.timedelta,
) -> pl.Expr:
    """Compare Array/List columns element-wise with tolerance.

    The sequences are unrolled into one expression per element position; each
    pair of elements is compared via `_compare_columns` using the inner dtypes
    and the given tolerances, and the per-position results are AND-ed together
    with a check that both sequences have the same length.

    Args:
        col_left: Expression yielding the left-hand sequence column.
        col_right: Expression yielding the right-hand sequence column.
        dtype_left: Data type of `col_left` (`pl.List` or `pl.Array`).
        dtype_right: Data type of `col_right` (`pl.List` or `pl.Array`).
        max_list_length: Number of element positions to unroll when both sides
            are `pl.List`. If it is not an integer, List-vs-List comparison
            falls back to a plain `eq_missing` without tolerances. It is not
            consulted when at least one side is a `pl.Array`, in which case the
            first dimension of the array shape is used instead.
        abs_tol: Absolute tolerance forwarded to the element comparison.
        rel_tol: Relative tolerance forwarded to the element comparison.
        abs_tol_temporal: Absolute temporal tolerance forwarded to the element
            comparison.

    Returns:
        A boolean expression indicating, per row, whether the two sequences are
        considered equal.
    """
    n_elements: int
    has_same_length: pl.Expr

    if isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.Array):
        # Fixed-size arrays of different shapes can never be equal.
        if dtype_left.shape != dtype_right.shape:
            return pl.repeat(pl.lit(False), pl.len())
        n_elements = dtype_left.shape[0]
        has_same_length = pl.repeat(pl.lit(True), pl.len())
    elif isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.List):
        # The array fixes the expected length; the list must match it per row.
        n_elements = dtype_left.shape[0]
        has_same_length = col_right.list.len().eq(pl.lit(n_elements))
    elif isinstance(dtype_left, pl.List) and isinstance(dtype_right, pl.Array):
        n_elements = dtype_right.shape[0]
        has_same_length = col_left.list.len().eq(pl.lit(n_elements))
    else:  # pl.List vs pl.List
        if not isinstance(max_list_length, int):
            # Fallback for nested list comparisons where no max_list_length is
            # available: perform a direct equality comparison without
            # element-wise unrolling.
            return _eq_missing(col_left.eq_missing(col_right), col_left, col_right)
        n_elements = max_list_length
        has_same_length = col_left.list.len().eq_missing(col_right.list.len())

    if n_elements == 0:
        # There are no element positions to unroll, so equality is decided by
        # the length check alone. Using `has_same_length` (rather than a literal
        # True) ensures that a zero-width array is not reported as equal to a
        # non-empty list.
        return _eq_missing(has_same_length, col_left, col_right)

    def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Expr:
        # Arrays are indexed within their fixed width; lists may be shorter than
        # `n_elements`, so out-of-bounds positions are mapped to null.
        if isinstance(dtype, pl.Array):
            return col.arr.get(i)
        return col.list.get(i, null_on_oob=True)

    elements_match = pl.all_horizontal(
        [
            _compare_columns(
                col_left=_get_element(col_left, dtype_left, i),
                col_right=_get_element(col_right, dtype_right, i),
                dtype_left=dtype_left.inner,
                dtype_right=dtype_right.inner,
                abs_tol=abs_tol,
                rel_tol=rel_tol,
                abs_tol_temporal=abs_tol_temporal,
                # No length bound is known for nested lists, so inner
                # List-vs-List comparisons use the direct-equality fallback.
                max_list_length=None,
            )
            for i in range(n_elements)
        ]
    )

    return _eq_missing(has_same_length & elements_match, col_left, col_right)
243+
170244def _compare_primitive_columns (
171245 col_left : pl .Expr ,
172246 col_right : pl .Expr ,
0 commit comments