Skip to content

Commit cccfad4

Browse files
feat: Tolerances for inner lists and arrays (#21)
1 parent 2a3010b commit cccfad4

File tree

3 files changed

+229
-116
lines changed

3 files changed

+229
-116
lines changed

diffly/_conditions.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import datetime as dt
55
from collections.abc import Mapping
6+
from typing import cast
67

78
import polars as pl
89
from polars.datatypes import DataType, DataTypeClass
@@ -206,12 +207,7 @@ def _compare_sequence_columns(
206207
n_elements = dtype_right.shape[0]
207208
has_same_length = col_left.list.len().eq(pl.lit(n_elements))
208209
else: # pl.List vs pl.List
209-
if not isinstance(max_list_length, int):
210-
# Fallback for nested list comparisons where no max_list_length is
211-
# available: perform a direct equality comparison without element-wise
212-
# unrolling.
213-
return _eq_missing(col_left.eq_missing(col_right), col_left, col_right)
214-
n_elements = max_list_length
210+
n_elements = cast(int, max_list_length)
215211
has_same_length = col_left.list.len().eq_missing(col_right.list.len())
216212

217213
if n_elements == 0:
@@ -232,7 +228,7 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
232228
abs_tol=abs_tol,
233229
rel_tol=rel_tol,
234230
abs_tol_temporal=abs_tol_temporal,
235-
max_list_length=None,
231+
max_list_length=max_list_length,
236232
)
237233
for i in range(n_elements)
238234
]

diffly/comparison.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -711,22 +711,30 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str]
711711

712712
@cached_property
713713
def _max_list_lengths_by_column(self) -> dict[str, int]:
714-
list_columns = [
715-
col
716-
for col in self._other_common_columns
717-
if isinstance(self.left_schema[col], pl.List)
718-
and isinstance(self.right_schema[col], pl.List)
719-
]
720-
if not list_columns:
714+
"""Max list length across all nesting levels, for columns where both sides
715+
contain a List anywhere in their type tree."""
716+
left_exprs: list[pl.Expr] = []
717+
right_exprs: list[pl.Expr] = []
718+
columns: list[str] = []
719+
720+
for col in self._other_common_columns:
721+
col_left = _list_length_exprs(pl.col(col), self.left_schema[col])
722+
col_right = _list_length_exprs(pl.col(col), self.right_schema[col])
723+
if not (col_left and col_right):
724+
continue
725+
columns.append(col)
726+
left_exprs.append(pl.max_horizontal(col_left).alias(col))
727+
right_exprs.append(pl.max_horizontal(col_right).alias(col))
728+
729+
if not columns:
721730
return {}
722731

723-
exprs = [pl.col(col).list.len().max().alias(col) for col in list_columns]
724732
[left_max, right_max] = pl.collect_all(
725-
[self.left.select(exprs), self.right.select(exprs)]
733+
[self.left.select(left_exprs), self.right.select(right_exprs)]
726734
)
727735
return {
728736
col: max(int(left_max[col].item() or 0), int(right_max[col].item() or 0))
729-
for col in list_columns
737+
for col in columns
730738
}
731739

732740
def _condition_equal_rows(self, columns: list[str]) -> pl.Expr:
@@ -833,3 +841,21 @@ def right_only(self) -> Schema:
833841
"""Columns that are only present in the right data frame, mapped to their data
834842
types."""
835843
return self.right() - self.left()
844+
845+
846+
def _list_length_exprs(
847+
expr: pl.Expr, dtype: pl.DataType | pl.datatypes.DataTypeClass
848+
) -> list[pl.Expr]:
849+
"""Collect max-list-length scalar expressions for every List level in the type
850+
tree."""
851+
if isinstance(dtype, pl.List):
852+
return [expr.list.len().max(), *_list_length_exprs(expr.explode(), dtype.inner)]
853+
if isinstance(dtype, pl.Array):
854+
return _list_length_exprs(expr.explode(), dtype.inner)
855+
if isinstance(dtype, pl.Struct):
856+
return [
857+
e
858+
for field in dtype.fields
859+
for e in _list_length_exprs(expr.struct[field.name], field.dtype)
860+
]
861+
return []

0 commit comments

Comments
 (0)