Skip to content

Commit d7da250

Browse files
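Clean up: rename `_max_list_lengths` to `_max_list_lengths_by_column`, drop the `0` default in its `.get()` lookup, and add a nested list/array tolerance test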
1 parent d528ecd commit d7da250

2 files changed

Lines changed: 58 additions & 4 deletions

File tree

diffly/comparison.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -508,7 +508,7 @@ def equal(self, *, check_dtypes: bool = True) -> bool:
508508
columns=common_columns,
509509
schema_left=self.left_schema,
510510
schema_right=self.right_schema,
511-
max_list_lengths_by_column=self._max_list_lengths,
511+
max_list_lengths_by_column=self._max_list_lengths_by_column,
512512
abs_tol_by_column=self.abs_tol_by_column,
513513
rel_tol_by_column=self.rel_tol_by_column,
514514
abs_tol_temporal_by_column=self.abs_tol_temporal_by_column,
@@ -710,7 +710,7 @@ def _validate_subset_of_common_columns(self, subset: Iterable[str]) -> list[str]
710710
return list(subset)
711711

712712
@cached_property
713-
def _max_list_lengths(self) -> dict[str, int]:
713+
def _max_list_lengths_by_column(self) -> dict[str, int]:
714714
list_columns = [
715715
col
716716
for col in self._other_common_columns
@@ -734,7 +734,7 @@ def _condition_equal_rows(self, columns: list[str]) -> pl.Expr:
734734
columns=columns,
735735
schema_left=self.left_schema,
736736
schema_right=self.right_schema,
737-
max_list_lengths_by_column=self._max_list_lengths,
737+
max_list_lengths_by_column=self._max_list_lengths_by_column,
738738
abs_tol_by_column=self.abs_tol_by_column,
739739
rel_tol_by_column=self.rel_tol_by_column,
740740
abs_tol_temporal_by_column=self.abs_tol_temporal_by_column,
@@ -748,7 +748,7 @@ def _condition_equal_columns(self, column: str) -> pl.Expr:
748748
abs_tol=self.abs_tol_by_column[column],
749749
rel_tol=self.rel_tol_by_column[column],
750750
abs_tol_temporal=self.abs_tol_temporal_by_column[column],
751-
max_list_length=self._max_list_lengths.get(column, 0),
751+
max_list_length=self._max_list_lengths_by_column.get(column),
752752
)
753753

754754
def _equal_rows(self) -> bool:

tests/test_conditions.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,60 @@ def test_condition_equal_columns_list_array_with_tolerance(
121121
assert actual.to_list() == [True, True, False]
122122

123123

124+
@pytest.mark.parametrize(
    "lhs_type",
    [pl.Array(pl.Float64, shape=(2, 2)), pl.List(pl.List(pl.Float64))],
)
@pytest.mark.parametrize(
    "rhs_type",
    [pl.Array(pl.Float64, shape=(2, 2)), pl.List(pl.List(pl.Float64))],
)
def test_condition_equal_columns_nested_list_array_with_tolerance(
    lhs_type: pl.DataType, rhs_type: pl.DataType
) -> None:
    """Check tolerance-based equality on two-level nested list/array columns.

    Every combination of a ``(2, 2)`` array and a list-of-lists dtype is
    exercised for the left and right column. The three rows are:

    * pk 1: identical values on both sides,
    * pk 2: one inner element differs by 0.4 (inside ``abs_tol=0.5``),
    * pk 3: one inner element differs by 0.8 (outside ``abs_tol=0.5``).
    """
    # Arrange
    keys = [1, 2, 3]
    left_values = [
        [[1.0, 1.1], [2.0, 2.1]],
        [[3.0, 3.0], [4.0, 4.0]],
        [[5.0, 5.0], [6.0, 6.0]],
    ]
    right_values = [
        [[1.0, 1.1], [2.0, 2.1]],
        [[3.0, 3.0], [4.0, 4.4]],
        [[5.0, 5.0], [6.0, 6.8]],
    ]
    lhs = pl.DataFrame(
        {"pk": keys, "a_left": left_values},
        schema={"pk": pl.Int64, "a_left": lhs_type},
    )
    rhs = pl.DataFrame(
        {"pk": keys, "a_right": right_values},
        schema={"pk": pl.Int64, "a_right": rhs_type},
    )

    # Act
    # The condition is built from the frames' actual schemas so that the
    # parametrized dtypes are what the comparison sees.
    condition = condition_equal_columns(
        "a",
        dtype_left=lhs.schema["a_left"],
        dtype_right=rhs.schema["a_right"],
        abs_tol=0.5,
        rel_tol=0,
        max_list_length=2,
    )
    joined = lhs.join(rhs, on="pk", maintain_order="left")
    actual = joined.select(condition).to_series()

    # Assert
    assert actual.to_list() == [True, True, False]
176+
177+
124178
def test_condition_equal_columns_nested_dtype_mismatch() -> None:
125179
# Arrange
126180
lhs = pl.DataFrame(

0 commit comments

Comments
 (0)