Skip to content

Commit 68dc630

Browse files
clean up
1 parent 1c0050f commit 68dc630

3 files changed

Lines changed: 35 additions & 36 deletions

File tree

diffly/_conditions.py

Lines changed: 27 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -19,27 +19,26 @@ def condition_equal_rows(
1919
columns: list[str],
2020
schema_left: pl.Schema,
2121
schema_right: pl.Schema,
22+
max_list_lengths_by_column: Mapping[str, int],
2223
abs_tol_by_column: Mapping[str, float],
2324
rel_tol_by_column: Mapping[str, float],
2425
abs_tol_temporal_by_column: Mapping[str, dt.timedelta],
25-
max_list_lengths_by_column: Mapping[str, int] | None = None,
2626
) -> pl.Expr:
2727
"""Build an expression whether two rows are equal, based on all columns' data
2828
types."""
2929
if not columns:
3030
return pl.lit(True)
3131

32-
_max_list_lengths = max_list_lengths_by_column or {}
3332
return pl.all_horizontal(
3433
[
3534
condition_equal_columns(
3635
column=column,
3736
dtype_left=schema_left[column],
3837
dtype_right=schema_right[column],
38+
max_list_length=max_list_lengths_by_column.get(column, 0),
3939
abs_tol=abs_tol_by_column[column],
4040
rel_tol=rel_tol_by_column[column],
4141
abs_tol_temporal=abs_tol_temporal_by_column[column],
42-
max_list_length=_max_list_lengths.get(column, 0),
4342
)
4443
for column in columns
4544
]
@@ -50,10 +49,10 @@ def condition_equal_columns(
5049
column: str,
5150
dtype_left: pl.DataType,
5251
dtype_right: pl.DataType,
52+
max_list_length: int,
5353
abs_tol: float = ABS_TOL_DEFAULT,
5454
rel_tol: float = REL_TOL_DEFAULT,
5555
abs_tol_temporal: dt.timedelta = ABS_TOL_TEMPORAL_DEFAULT,
56-
max_list_length: int = 0,
5756
) -> pl.Expr:
5857
"""Build an expression whether two columns are equal, depending on the columns' data
5958
types."""
@@ -62,10 +61,10 @@ def condition_equal_columns(
6261
col_right=pl.col(f"{column}_{Side.RIGHT}"),
6362
dtype_left=dtype_left,
6463
dtype_right=dtype_right,
64+
max_list_length=max_list_length,
6565
abs_tol=abs_tol,
6666
rel_tol=rel_tol,
6767
abs_tol_temporal=abs_tol_temporal,
68-
max_list_length=max_list_length,
6968
)
7069

7170

@@ -97,10 +96,10 @@ def _compare_columns(
9796
col_right: pl.Expr,
9897
dtype_left: DataType | DataTypeClass,
9998
dtype_right: DataType | DataTypeClass,
99+
max_list_length: int,
100100
abs_tol: float,
101101
rel_tol: float,
102102
abs_tol_temporal: dt.timedelta,
103-
max_list_length: int = 0,
104103
) -> pl.Expr:
105104
"""Build an expression whether two expressions yield the same value.
106105
@@ -129,6 +128,7 @@ def _compare_columns(
129128
col_right=col_right.struct[field],
130129
dtype_left=fields_left[field],
131130
dtype_right=fields_right[field],
131+
max_list_length=max_list_length,
132132
abs_tol=abs_tol,
133133
rel_tol=rel_tol,
134134
abs_tol_temporal=abs_tol_temporal,
@@ -139,7 +139,7 @@ def _compare_columns(
139139
elif isinstance(dtype_left, pl.List | pl.Array) and isinstance(
140140
dtype_right, pl.List | pl.Array
141141
):
142-
result = _compare_sequence_columns(
142+
return _compare_sequence_columns(
143143
col_left=col_left,
144144
col_right=col_right,
145145
dtype_left=dtype_left,
@@ -149,8 +149,6 @@ def _compare_columns(
149149
rel_tol=rel_tol,
150150
abs_tol_temporal=abs_tol_temporal,
151151
)
152-
if result is not None:
153-
return result
154152

155153
if (
156154
isinstance(dtype_left, pl.Enum)
@@ -168,6 +166,7 @@ def _compare_columns(
168166
abs_tol=abs_tol,
169167
rel_tol=rel_tol,
170168
abs_tol_temporal=abs_tol_temporal,
169+
max_list_length=max_list_length,
171170
)
172171

173172
return _compare_primitive_columns(
@@ -190,13 +189,8 @@ def _compare_sequence_columns(
190189
abs_tol: float,
191190
rel_tol: float,
192191
abs_tol_temporal: dt.timedelta,
193-
) -> pl.Expr | None:
194-
"""Compare Array/List columns element-wise with tolerance.
195-
196-
Returns ``None`` if the comparison cannot be performed element-wise (e.g. List vs
197-
List without a known ``max_list_length``), signalling to the caller that it should
198-
fall back to primitive comparison.
199-
"""
192+
) -> pl.Expr:
193+
"""Compare Array/List columns element-wise with tolerance."""
200194
assert isinstance(dtype_left, pl.List | pl.Array)
201195
assert isinstance(dtype_right, pl.List | pl.Array)
202196
inner_left = dtype_left.inner
@@ -207,29 +201,27 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
207201
return col.arr.get(i)
208202
return col.list.get(i, null_on_oob=True)
209203

210-
n: int | None = None
211-
length_check: pl.Expr | None = None
204+
n_elements: int | None = None
205+
has_same_length: pl.Expr | None = None
212206

213207
if isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.Array):
214208
if dtype_left.shape != dtype_right.shape:
215209
return pl.repeat(pl.lit(False), pl.len())
216-
n = dtype_left.shape[0]
210+
n_elements = dtype_left.shape[0]
217211
elif isinstance(dtype_left, pl.Array) and isinstance(dtype_right, pl.List):
218-
n = dtype_left.shape[0]
219-
length_check = col_right.list.len().eq(pl.lit(n))
212+
n_elements = dtype_left.shape[0]
213+
has_same_length = col_right.list.len().eq(pl.lit(n_elements))
220214
elif isinstance(dtype_left, pl.List) and isinstance(dtype_right, pl.Array):
221-
n = dtype_right.shape[0]
222-
length_check = col_left.list.len().eq(pl.lit(n))
215+
n_elements = dtype_right.shape[0]
216+
has_same_length = col_left.list.len().eq(pl.lit(n_elements))
223217
else:
224218
# List vs List
225-
if max_list_length == 0:
226-
return None
227-
n = max_list_length
228-
length_check = col_left.list.len().eq_missing(col_right.list.len())
229-
230-
if n == 0:
231-
if length_check is not None:
232-
return _eq_missing(length_check, col_left, col_right)
219+
n_elements = max_list_length
220+
has_same_length = col_left.list.len().eq_missing(col_right.list.len())
221+
222+
if n_elements == 0:
223+
if has_same_length is not None:
224+
return _eq_missing(has_same_length, col_left, col_right)
233225
return _eq_missing(pl.lit(True), col_left, col_right)
234226

235227
elements_match = pl.all_horizontal(
@@ -242,13 +234,14 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
242234
abs_tol=abs_tol,
243235
rel_tol=rel_tol,
244236
abs_tol_temporal=abs_tol_temporal,
237+
max_list_length=max_list_length,
245238
)
246-
for i in range(n)
239+
for i in range(n_elements)
247240
]
248241
)
249242

250-
if length_check is not None:
251-
return _eq_missing(length_check & elements_match, col_left, col_right)
243+
if has_same_length is not None:
244+
return _eq_missing(has_same_length & elements_match, col_left, col_right)
252245
return elements_match
253246

254247

diffly/comparison.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -508,10 +508,10 @@ def equal(self, *, check_dtypes: bool = True) -> bool:
508508
columns=common_columns,
509509
schema_left=self.left_schema,
510510
schema_right=self.right_schema,
511+
max_list_lengths_by_column=self._max_list_lengths,
511512
abs_tol_by_column=self.abs_tol_by_column,
512513
rel_tol_by_column=self.rel_tol_by_column,
513514
abs_tol_temporal_by_column=self.abs_tol_temporal_by_column,
514-
max_list_lengths_by_column=self._max_list_lengths,
515515
).all()
516516
)
517517
.item()
@@ -734,10 +734,10 @@ def _condition_equal_rows(self, columns: list[str]) -> pl.Expr:
734734
columns=columns,
735735
schema_left=self.left_schema,
736736
schema_right=self.right_schema,
737+
max_list_lengths_by_column=self._max_list_lengths,
737738
abs_tol_by_column=self.abs_tol_by_column,
738739
rel_tol_by_column=self.rel_tol_by_column,
739740
abs_tol_temporal_by_column=self.abs_tol_temporal_by_column,
740-
max_list_lengths_by_column=self._max_list_lengths,
741741
)
742742

743743
def _condition_equal_columns(self, column: str) -> pl.Expr:

tests/test_conditions.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def test_condition_equal_columns_struct() -> None:
3232
"a",
3333
dtype_left=lhs.schema["a_left"],
3434
dtype_right=rhs.schema["a_right"],
35+
max_list_length=0,
3536
abs_tol=0.5,
3637
rel_tol=0,
3738
)
@@ -66,6 +67,7 @@ def test_condition_equal_columns_different_struct_fields() -> None:
6667
"a",
6768
dtype_left=lhs.schema["a_left"],
6869
dtype_right=rhs.schema["a_right"],
70+
max_list_length=0,
6971
)
7072
)
7173
.to_series()
@@ -188,6 +190,7 @@ def test_condition_equal_columns_nested_dtype_mismatch() -> None:
188190
"a",
189191
dtype_left=lhs.schema["a_left"],
190192
dtype_right=rhs.schema["a_right"],
193+
max_list_length=0,
191194
)
192195
)
193196
.to_series()
@@ -220,6 +223,7 @@ def test_condition_equal_columns_exactly_one_nested() -> None:
220223
"a",
221224
dtype_left=lhs.schema["a_left"],
222225
dtype_right=rhs.schema["a_right"],
226+
max_list_length=0,
223227
)
224228
)
225229
.to_series()
@@ -262,6 +266,7 @@ def test_condition_equal_columns_temporal_tolerance() -> None:
262266
"a",
263267
dtype_left=lhs.schema["a_left"],
264268
dtype_right=rhs.schema["a_right"],
269+
max_list_length=0,
265270
abs_tol_temporal=dt.timedelta(seconds=2),
266271
)
267272
)
@@ -354,6 +359,7 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None:
354359
"a",
355360
dtype_left=lhs.schema["a_left"],
356361
dtype_right=rhs.schema["a_right"],
362+
max_list_length=0,
357363
abs_tol=0.5,
358364
rel_tol=0,
359365
)

0 commit comments

Comments
 (0)