Skip to content

Commit 70d5369

Browse files
feat: Perform tolerance-based comparison for lists and arrays (#19)
1 parent 2ae4c11 commit 70d5369

File tree

3 files changed

+303
-11
lines changed

3 files changed

+303
-11
lines changed

diffly/_conditions.py

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def condition_equal_rows(
1919
columns: list[str],
2020
schema_left: pl.Schema,
2121
schema_right: pl.Schema,
22+
max_list_lengths_by_column: Mapping[str, int],
2223
abs_tol_by_column: Mapping[str, float],
2324
rel_tol_by_column: Mapping[str, float],
2425
abs_tol_temporal_by_column: Mapping[str, dt.timedelta],
@@ -34,6 +35,7 @@ def condition_equal_rows(
3435
column=column,
3536
dtype_left=schema_left[column],
3637
dtype_right=schema_right[column],
38+
max_list_length=max_list_lengths_by_column.get(column),
3739
abs_tol=abs_tol_by_column[column],
3840
rel_tol=rel_tol_by_column[column],
3941
abs_tol_temporal=abs_tol_temporal_by_column[column],
@@ -47,6 +49,7 @@ def condition_equal_columns(
4749
column: str,
4850
dtype_left: pl.DataType,
4951
dtype_right: pl.DataType,
52+
max_list_length: int | None,
5053
abs_tol: float = ABS_TOL_DEFAULT,
5154
rel_tol: float = REL_TOL_DEFAULT,
5255
abs_tol_temporal: dt.timedelta = ABS_TOL_TEMPORAL_DEFAULT,
@@ -58,6 +61,7 @@ def condition_equal_columns(
5861
col_right=pl.col(f"{column}_{Side.RIGHT}"),
5962
dtype_left=dtype_left,
6063
dtype_right=dtype_right,
64+
max_list_length=max_list_length,
6165
abs_tol=abs_tol,
6266
rel_tol=rel_tol,
6367
abs_tol_temporal=abs_tol_temporal,
@@ -92,6 +96,7 @@ def _compare_columns(
9296
col_right: pl.Expr,
9397
dtype_left: DataType | DataTypeClass,
9498
dtype_right: DataType | DataTypeClass,
99+
max_list_length: int | None,
95100
abs_tol: float,
96101
rel_tol: float,
97102
abs_tol_temporal: dt.timedelta,
@@ -123,6 +128,7 @@ def _compare_columns(
123128
col_right=col_right.struct[field],
124129
dtype_left=fields_left[field],
125130
dtype_right=fields_right[field],
131+
max_list_length=max_list_length,
126132
abs_tol=abs_tol,
127133
rel_tol=rel_tol,
128134
abs_tol_temporal=abs_tol_temporal,
@@ -133,10 +139,16 @@ def _compare_columns(
133139
elif isinstance(dtype_left, pl.List | pl.Array) and isinstance(
134140
dtype_right, pl.List | pl.Array
135141
):
136-
# As of polars 1.28, there is no way to access another column within
137-
# `list.eval`. Hence, we necessarily need to resort to a primitive
138-
# comparison in this case.
139-
pass
142+
return _compare_sequence_columns(
143+
col_left=col_left,
144+
col_right=col_right,
145+
dtype_left=dtype_left,
146+
dtype_right=dtype_right,
147+
max_list_length=max_list_length,
148+
abs_tol=abs_tol,
149+
rel_tol=rel_tol,
150+
abs_tol_temporal=abs_tol_temporal,
151+
)
140152

141153
if (
142154
isinstance(dtype_left, pl.Enum)
@@ -154,6 +166,7 @@ def _compare_columns(
154166
abs_tol=abs_tol,
155167
rel_tol=rel_tol,
156168
abs_tol_temporal=abs_tol_temporal,
169+
max_list_length=max_list_length,
157170
)
158171

159172
return _compare_primitive_columns(
@@ -167,6 +180,67 @@ def _compare_columns(
167180
)
168181

169182

183+
def _compare_sequence_columns(
    col_left: pl.Expr,
    col_right: pl.Expr,
    dtype_left: pl.List | pl.Array,
    dtype_right: pl.List | pl.Array,
    max_list_length: int | None,
    abs_tol: float,
    rel_tol: float,
    abs_tol_temporal: dt.timedelta,
) -> pl.Expr:
    """Build an expression comparing two List/Array columns position by position.

    The comparison is unrolled into one ``_compare_columns`` call per element
    index, so the number of positions to inspect must be known up front:

    * Array vs. Array: arrays whose shapes differ never match (an all-``False``
      expression is returned); otherwise ``shape[0]`` positions are compared.
    * Array vs. List (either order): the Array's ``shape[0]`` gives the number
      of positions and the List side must have exactly that length.
    * List vs. List: ``max_list_length`` gives the number of positions and both
      lists must report the same length. When ``max_list_length`` is not an
      ``int`` (e.g. ``None`` for nested lists), no unrolling is possible and a
      plain ``eq_missing`` comparison is handed to ``_eq_missing`` instead.

    Args:
        col_left: Expression yielding the left-hand sequence column.
        col_right: Expression yielding the right-hand sequence column.
        dtype_left: Data type of ``col_left``.
        dtype_right: Data type of ``col_right``.
        max_list_length: Number of positions to unroll for List vs. List
            comparisons; ignored when either side is an Array.
        abs_tol: Absolute tolerance forwarded to the element comparison.
        rel_tol: Relative tolerance forwarded to the element comparison.
        abs_tol_temporal: Temporal tolerance forwarded to the element
            comparison.

    Returns:
        A polars expression describing whether the two columns match.
    """
    left_is_array = isinstance(dtype_left, pl.Array)
    right_is_array = isinstance(dtype_right, pl.Array)
    left_is_list = isinstance(dtype_left, pl.List)
    right_is_list = isinstance(dtype_right, pl.List)

    width: int
    lengths_agree: pl.Expr

    if left_is_array and right_is_array:
        if dtype_left.shape != dtype_right.shape:
            return pl.repeat(pl.lit(False), pl.len())
        width = dtype_left.shape[0]
        # Identical shapes: the length check is trivially satisfied.
        lengths_agree = pl.repeat(pl.lit(True), pl.len())
    elif left_is_array and right_is_list:
        width = dtype_left.shape[0]
        lengths_agree = col_right.list.len().eq(pl.lit(width))
    elif left_is_list and right_is_array:
        width = dtype_right.shape[0]
        lengths_agree = col_left.list.len().eq(pl.lit(width))
    else:
        # List vs. List
        if not isinstance(max_list_length, int):
            # Without a known maximum length (nested lists) the elements
            # cannot be unrolled, so compare the sequences as a whole.
            whole_sequence_equal = col_left.eq_missing(col_right)
            return _eq_missing(whole_sequence_equal, col_left, col_right)
        width = max_list_length
        lengths_agree = col_left.list.len().eq_missing(col_right.list.len())

    if width == 0:
        # No positions to inspect; the outcome is left entirely to
        # ``_eq_missing``.
        return _eq_missing(pl.lit(True), col_left, col_right)

    def _nth(expr: pl.Expr, dtype: DataType | DataTypeClass, index: int) -> pl.Expr:
        # Lists may be shorter than ``width``; ``null_on_oob=True`` is passed
        # so out-of-range positions are requested as null rather than failing.
        return (
            expr.arr.get(index)
            if isinstance(dtype, pl.Array)
            else expr.list.get(index, null_on_oob=True)
        )

    per_position: list[pl.Expr] = []
    for index in range(width):
        per_position.append(
            _compare_columns(
                col_left=_nth(col_left, dtype_left, index),
                col_right=_nth(col_right, dtype_right, index),
                dtype_left=dtype_left.inner,
                dtype_right=dtype_right.inner,
                # Inner sequences have no precomputed maximum length.
                max_list_length=None,
                abs_tol=abs_tol,
                rel_tol=rel_tol,
                abs_tol_temporal=abs_tol_temporal,
            )
        )
    every_position_equal = pl.all_horizontal(per_position)

    return _eq_missing(lengths_agree & every_position_equal, col_left, col_right)
242+
243+
170244
def _compare_primitive_columns(
171245
col_left: pl.Expr,
172246
col_right: pl.Expr,

diffly/comparison.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,7 @@ def equal(self, *, check_dtypes: bool = True) -> bool:
508508
columns=common_columns,
509509
schema_left=self.left_schema,
510510
schema_right=self.right_schema,
511+
max_list_lengths_by_column=self._max_list_lengths_by_column,
511512
abs_tol_by_column=self.abs_tol_by_column,
512513
rel_tol_by_column=self.rel_tol_by_column,
513514
abs_tol_temporal_by_column=self.abs_tol_temporal_by_column,
@@ -708,11 +709,32 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str]
708709
raise ValueError(f"{difference} are not common columns.")
709710
return list(subset)
710711

712+
@cached_property
def _max_list_lengths_by_column(self) -> dict[str, int]:
    """Return the longest list length per List-typed common column.

    Only columns from ``self._other_common_columns`` whose dtype is
    ``pl.List`` in both the left and the right schema are considered. For
    each of them the maximum list length is computed on both frames (in a
    single ``pl.collect_all`` call) and the larger of the two is reported.
    A missing maximum (``None``) is treated as ``0``.

    The result is cached on the instance via ``cached_property``.

    Returns:
        Mapping from column name to the maximum list length across both
        sides; empty when there are no such List columns.
    """
    shared_list_columns: list[str] = []
    for name in self._other_common_columns:
        if isinstance(self.left_schema[name], pl.List) and isinstance(
            self.right_schema[name], pl.List
        ):
            shared_list_columns.append(name)

    if not shared_list_columns:
        return {}

    longest_per_column = [
        pl.col(name).list.len().max().alias(name) for name in shared_list_columns
    ]
    left_stats, right_stats = pl.collect_all(
        [self.left.select(longest_per_column), self.right.select(longest_per_column)]
    )

    max_lengths: dict[str, int] = {}
    for name in shared_list_columns:
        # ``or 0`` maps a null maximum to zero.
        left_longest = int(left_stats[name].item() or 0)
        right_longest = int(right_stats[name].item() or 0)
        max_lengths[name] = max(left_longest, right_longest)
    return max_lengths
731+
711732
def _condition_equal_rows(self, columns: list[str]) -> pl.Expr:
712733
return condition_equal_rows(
713734
columns=columns,
714735
schema_left=self.left_schema,
715736
schema_right=self.right_schema,
737+
max_list_lengths_by_column=self._max_list_lengths_by_column,
716738
abs_tol_by_column=self.abs_tol_by_column,
717739
rel_tol_by_column=self.rel_tol_by_column,
718740
abs_tol_temporal_by_column=self.abs_tol_temporal_by_column,
@@ -726,6 +748,7 @@ def _condition_equal_columns(self, column: str) -> pl.Expr:
726748
abs_tol=self.abs_tol_by_column[column],
727749
rel_tol=self.rel_tol_by_column[column],
728750
abs_tol_temporal=self.abs_tol_temporal_by_column[column],
751+
max_list_length=self._max_list_lengths_by_column.get(column),
729752
)
730753

731754
def _equal_rows(self) -> bool:

0 commit comments

Comments
 (0)