Skip to content

Commit 1d4df7b

Browse files
improve
1 parent 68dc630 commit 1d4df7b

3 files changed

Lines changed: 110 additions & 62 deletions

File tree

diffly/_conditions.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def condition_equal_rows(
3535
column=column,
3636
dtype_left=schema_left[column],
3737
dtype_right=schema_right[column],
38-
max_list_length=max_list_lengths_by_column.get(column, 0),
38+
max_list_length=max_list_lengths_by_column.get(column),
3939
abs_tol=abs_tol_by_column[column],
4040
rel_tol=rel_tol_by_column[column],
4141
abs_tol_temporal=abs_tol_temporal_by_column[column],
@@ -49,7 +49,7 @@ def condition_equal_columns(
4949
column: str,
5050
dtype_left: pl.DataType,
5151
dtype_right: pl.DataType,
52-
max_list_length: int,
52+
max_list_length: int | None = None,
5353
abs_tol: float = ABS_TOL_DEFAULT,
5454
rel_tol: float = REL_TOL_DEFAULT,
5555
abs_tol_temporal: dt.timedelta = ABS_TOL_TEMPORAL_DEFAULT,
@@ -96,7 +96,7 @@ def _compare_columns(
9696
col_right: pl.Expr,
9797
dtype_left: DataType | DataTypeClass,
9898
dtype_right: DataType | DataTypeClass,
99-
max_list_length: int,
99+
max_list_length: int | None,
100100
abs_tol: float,
101101
rel_tol: float,
102102
abs_tol_temporal: dt.timedelta,
@@ -185,7 +185,7 @@ def _compare_sequence_columns(
185185
col_right: pl.Expr,
186186
dtype_left: DataType | DataTypeClass,
187187
dtype_right: DataType | DataTypeClass,
188-
max_list_length: int,
188+
max_list_length: int | None,
189189
abs_tol: float,
190190
rel_tol: float,
191191
abs_tol_temporal: dt.timedelta,
@@ -216,6 +216,7 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
216216
has_same_length = col_left.list.len().eq(pl.lit(n_elements))
217217
else:
218218
# List vs List
219+
assert max_list_length is not None
219220
n_elements = max_list_length
220221
has_same_length = col_left.list.len().eq_missing(col_right.list.len())
221222

diffly/comparison.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -748,7 +748,7 @@ def _condition_equal_columns(self, column: str) -> pl.Expr:
748748
abs_tol=self.abs_tol_by_column[column],
749749
rel_tol=self.rel_tol_by_column[column],
750750
abs_tol_temporal=self.abs_tol_temporal_by_column[column],
751-
max_list_length=self._max_list_lengths.get(column, 0),
751+
max_list_length=self._max_list_lengths.get(column),
752752
)
753753

754754
def _equal_rows(self) -> bool:

tests/test_conditions.py

Lines changed: 104 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def test_condition_equal_columns_struct() -> None:
3232
"a",
3333
dtype_left=lhs.schema["a_left"],
3434
dtype_right=rhs.schema["a_right"],
35-
max_list_length=0,
35+
max_list_length=None,
3636
abs_tol=0.5,
3737
rel_tol=0,
3838
)
@@ -67,7 +67,7 @@ def test_condition_equal_columns_different_struct_fields() -> None:
6767
"a",
6868
dtype_left=lhs.schema["a_left"],
6969
dtype_right=rhs.schema["a_right"],
70-
max_list_length=0,
70+
max_list_length=None,
7171
)
7272
)
7373
.to_series()
@@ -89,60 +89,15 @@ def test_condition_equal_columns_list_array_with_tolerance(
8989
# Arrange
9090
lhs = pl.DataFrame(
9191
{
92-
"pk": [1, 2],
93-
"a_left": [[1.0, 1.1], [2.0, 2.1]],
94-
},
95-
schema={"pk": pl.Int64, "a_left": lhs_type},
96-
)
97-
rhs = pl.DataFrame(
98-
{
99-
"pk": [1, 2],
100-
"a_right": [[1.0, 1.1], [2.0, 2.2]],
101-
},
102-
schema={"pk": pl.Int64, "a_right": rhs_type},
103-
)
104-
105-
# Act
106-
actual = (
107-
lhs.join(rhs, on="pk", maintain_order="left")
108-
.select(
109-
condition_equal_columns(
110-
"a",
111-
dtype_left=lhs.schema["a_left"],
112-
dtype_right=rhs.schema["a_right"],
113-
abs_tol=0.5,
114-
rel_tol=0,
115-
max_list_length=2,
116-
)
117-
)
118-
.to_series()
119-
)
120-
121-
# Assert: diff is 0.1, within abs_tol=0.5
122-
assert actual.to_list() == [True, True]
123-
124-
125-
@pytest.mark.parametrize(
126-
"lhs_type", [pl.Array(pl.Float64, shape=2), pl.List(pl.Float64)]
127-
)
128-
@pytest.mark.parametrize(
129-
"rhs_type", [pl.Array(pl.Float64, shape=2), pl.List(pl.Float64)]
130-
)
131-
def test_condition_equal_columns_list_array_exceeds_tolerance(
132-
lhs_type: pl.DataType, rhs_type: pl.DataType
133-
) -> None:
134-
# Arrange
135-
lhs = pl.DataFrame(
136-
{
137-
"pk": [1, 2],
138-
"a_left": [[1.0, 1.1], [2.0, 2.1]],
92+
"pk": [1, 2, 3],
93+
"a_left": [[1.0, 1.1], [2.0, 2.1], [3.0, 3.0]],
13994
},
14095
schema={"pk": pl.Int64, "a_left": lhs_type},
14196
)
14297
rhs = pl.DataFrame(
14398
{
144-
"pk": [1, 2],
145-
"a_right": [[1.0, 1.1], [2.0, 2.8]],
99+
"pk": [1, 2, 3],
100+
"a_right": [[1.0, 1.1], [2.0, 2.2], [3.0, 3.7]],
146101
},
147102
schema={"pk": pl.Int64, "a_right": rhs_type},
148103
)
@@ -163,8 +118,7 @@ def test_condition_equal_columns_list_array_exceeds_tolerance(
163118
.to_series()
164119
)
165120

166-
# Assert: diff is 0.7, exceeds abs_tol=0.5
167-
assert actual.to_list() == [True, False]
121+
assert actual.to_list() == [True, True, False]
168122

169123

170124
def test_condition_equal_columns_nested_dtype_mismatch() -> None:
@@ -190,7 +144,7 @@ def test_condition_equal_columns_nested_dtype_mismatch() -> None:
190144
"a",
191145
dtype_left=lhs.schema["a_left"],
192146
dtype_right=rhs.schema["a_right"],
193-
max_list_length=0,
147+
max_list_length=None,
194148
)
195149
)
196150
.to_series()
@@ -223,7 +177,7 @@ def test_condition_equal_columns_exactly_one_nested() -> None:
223177
"a",
224178
dtype_left=lhs.schema["a_left"],
225179
dtype_right=rhs.schema["a_right"],
226-
max_list_length=0,
180+
max_list_length=None,
227181
)
228182
)
229183
.to_series()
@@ -266,7 +220,7 @@ def test_condition_equal_columns_temporal_tolerance() -> None:
266220
"a",
267221
dtype_left=lhs.schema["a_left"],
268222
dtype_right=rhs.schema["a_right"],
269-
max_list_length=0,
223+
max_list_length=None,
270224
abs_tol_temporal=dt.timedelta(seconds=2),
271225
)
272226
)
@@ -359,7 +313,7 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None:
359313
"a",
360314
dtype_left=lhs.schema["a_left"],
361315
dtype_right=rhs.schema["a_right"],
362-
max_list_length=0,
316+
max_list_length=None,
363317
abs_tol=0.5,
364318
rel_tol=0,
365319
)
@@ -369,6 +323,99 @@ def test_condition_equal_columns_array_vs_list_length_mismatch() -> None:
369323
assert actual.to_list() == [True, False]
370324

371325

326+
def test_condition_equal_columns_array_different_shapes() -> None:
327+
lhs = pl.DataFrame(
328+
{
329+
"pk": [1],
330+
"a_left": [[1.0, 2.0]],
331+
},
332+
schema={"pk": pl.Int64, "a_left": pl.Array(pl.Float64, shape=2)},
333+
)
334+
rhs = pl.DataFrame(
335+
{
336+
"pk": [1],
337+
"a_right": [[1.0, 2.0, 3.0]],
338+
},
339+
schema={"pk": pl.Int64, "a_right": pl.Array(pl.Float64, shape=3)},
340+
)
341+
342+
actual = (
343+
lhs.join(rhs, on="pk", maintain_order="left")
344+
.select(
345+
condition_equal_columns(
346+
"a",
347+
dtype_left=lhs.schema["a_left"],
348+
dtype_right=rhs.schema["a_right"],
349+
max_list_length=None,
350+
)
351+
)
352+
.to_series()
353+
)
354+
assert actual.to_list() == [False]
355+
356+
357+
def test_condition_equal_columns_empty_arrays() -> None:
358+
lhs = pl.DataFrame(
359+
{
360+
"pk": [1, 2],
361+
"a_left": [[], None],
362+
},
363+
schema={"pk": pl.Int64, "a_left": pl.Array(pl.Float64, shape=0)},
364+
)
365+
rhs = pl.DataFrame(
366+
{
367+
"pk": [1, 2],
368+
"a_right": [[], None],
369+
},
370+
schema={"pk": pl.Int64, "a_right": pl.Array(pl.Float64, shape=0)},
371+
)
372+
373+
actual = (
374+
lhs.join(rhs, on="pk", maintain_order="left")
375+
.select(
376+
condition_equal_columns(
377+
"a",
378+
dtype_left=lhs.schema["a_left"],
379+
dtype_right=rhs.schema["a_right"],
380+
max_list_length=None,
381+
)
382+
)
383+
.to_series()
384+
)
385+
assert actual.to_list() == [True, True]
386+
387+
388+
def test_condition_equal_columns_empty_lists() -> None:
389+
lhs = pl.DataFrame(
390+
{
391+
"pk": [1, 2, 3],
392+
"a_left": [[], None, []],
393+
},
394+
schema={"pk": pl.Int64, "a_left": pl.List(pl.Float64)},
395+
)
396+
rhs = pl.DataFrame(
397+
{
398+
"pk": [1, 2, 3],
399+
"a_right": [[], None, None],
400+
},
401+
schema={"pk": pl.Int64, "a_right": pl.List(pl.Float64)},
402+
)
403+
404+
actual = (
405+
lhs.join(rhs, on="pk", maintain_order="left")
406+
.select(
407+
condition_equal_columns(
408+
"a",
409+
dtype_left=lhs.schema["a_left"],
410+
dtype_right=rhs.schema["a_right"],
411+
max_list_length=0,
412+
)
413+
)
414+
.to_series()
415+
)
416+
assert actual.to_list() == [True, True, False]
417+
418+
372419
@pytest.mark.parametrize(
373420
("dtype_left", "dtype_right", "can_compare_dtypes"),
374421
[

0 commit comments

Comments
 (0)