Skip to content

Commit a6992fa

Browse files
perf: Element-wise comparison only for tolerance-requiring data types (#26)
1 parent ff8439c commit a6992fa

File tree

3 files changed

+193
-28
lines changed

3 files changed

+193
-28
lines changed

diffly/_conditions.py

Lines changed: 72 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -140,22 +140,22 @@ def _compare_columns(
140140
elif isinstance(dtype_left, pl.List | pl.Array) and isinstance(
141141
dtype_right, pl.List | pl.Array
142142
):
143-
return _compare_sequence_columns(
144-
col_left=col_left,
145-
col_right=col_right,
146-
dtype_left=dtype_left,
147-
dtype_right=dtype_right,
148-
max_list_length=max_list_length,
149-
abs_tol=abs_tol,
150-
rel_tol=rel_tol,
151-
abs_tol_temporal=abs_tol_temporal,
152-
)
153-
154-
if (
155-
isinstance(dtype_left, pl.Enum)
156-
and isinstance(dtype_right, pl.Enum)
157-
and dtype_left != dtype_right
158-
) or _enum_and_categorical(dtype_left, dtype_right):
143+
if _needs_element_wise_comparison(dtype_left.inner, dtype_right.inner):
144+
return _compare_sequence_columns(
145+
col_left=col_left,
146+
col_right=col_right,
147+
dtype_left=dtype_left,
148+
dtype_right=dtype_right,
149+
max_list_length=max_list_length,
150+
abs_tol=abs_tol,
151+
rel_tol=rel_tol,
152+
abs_tol_temporal=abs_tol_temporal,
153+
)
154+
return col_left.eq_missing(col_right)
155+
156+
if _different_enums(dtype_left, dtype_right) or _enum_and_categorical(
157+
dtype_left, dtype_right
158+
):
159159
# Enums with different categories as well as enums and categoricals
160160
# can't be compared directly.
161161
# Fall back to comparison of strings.
@@ -237,6 +237,54 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
237237
return _eq_missing(has_same_length & elements_match, col_left, col_right)
238238

239239

240+
def _is_float_numeric_pair(
241+
dtype_left: DataType | DataTypeClass,
242+
dtype_right: DataType | DataTypeClass,
243+
) -> bool:
244+
return (dtype_left.is_float() or dtype_right.is_float()) and (
245+
dtype_left.is_numeric() and dtype_right.is_numeric()
246+
)
247+
248+
249+
def _is_temporal_pair(
250+
dtype_left: DataType | DataTypeClass,
251+
dtype_right: DataType | DataTypeClass,
252+
) -> bool:
253+
return dtype_left.is_temporal() and dtype_right.is_temporal()
254+
255+
256+
def _needs_element_wise_comparison(
257+
dtype_left: DataType | DataTypeClass,
258+
dtype_right: DataType | DataTypeClass,
259+
) -> bool:
260+
"""Check if two dtypes require element-wise comparison (tolerances or special
261+
handling).
262+
263+
Returns False when eq_missing() on the whole column would produce identical results,
264+
allowing us to skip the expensive element-wise iteration for list/array columns.
265+
"""
266+
if (
267+
_is_float_numeric_pair(dtype_left, dtype_right)
268+
or _is_temporal_pair(dtype_left, dtype_right)
269+
or _different_enums(dtype_left, dtype_right)
270+
or _enum_and_categorical(dtype_left, dtype_right)
271+
):
272+
return True
273+
if isinstance(dtype_left, pl.Struct) and isinstance(dtype_right, pl.Struct):
274+
fields_left = {f.name: f.dtype for f in dtype_left.fields}
275+
fields_right = {f.name: f.dtype for f in dtype_right.fields}
276+
return any(
277+
_needs_element_wise_comparison(fields_left[name], fields_right[name])
278+
for name in fields_left
279+
if name in fields_right
280+
)
281+
if isinstance(dtype_left, pl.List | pl.Array) and isinstance(
282+
dtype_right, pl.List | pl.Array
283+
):
284+
return _needs_element_wise_comparison(dtype_left.inner, dtype_right.inner)
285+
return False
286+
287+
240288
def _compare_primitive_columns(
241289
col_left: pl.Expr,
242290
col_right: pl.Expr,
@@ -246,13 +294,11 @@ def _compare_primitive_columns(
246294
rel_tol: float,
247295
abs_tol_temporal: dt.timedelta,
248296
) -> pl.Expr:
249-
if (dtype_left.is_float() or dtype_right.is_float()) and (
250-
dtype_left.is_numeric() and dtype_right.is_numeric()
251-
):
297+
if _is_float_numeric_pair(dtype_left, dtype_right):
252298
return col_left.is_close(col_right, abs_tol=abs_tol, rel_tol=rel_tol).pipe(
253299
_eq_missing_with_nan, lhs=col_left, rhs=col_right
254300
)
255-
elif dtype_left.is_temporal() and dtype_right.is_temporal():
301+
elif _is_temporal_pair(dtype_left, dtype_right):
256302
diff_less_than_tolerance = (col_left - col_right).abs() <= abs_tol_temporal
257303
return diff_less_than_tolerance.pipe(_eq_missing, lhs=col_left, rhs=col_right)
258304

@@ -270,6 +316,12 @@ def _eq_missing_with_nan(expr: pl.Expr, lhs: pl.Expr, rhs: pl.Expr) -> pl.Expr:
270316
return _eq_missing(expr, lhs, rhs) | both_nan
271317

272318

319+
def _different_enums(
320+
left: DataType | DataTypeClass, right: DataType | DataTypeClass
321+
) -> bool:
322+
return isinstance(left, pl.Enum) and isinstance(right, pl.Enum) and left != right
323+
324+
273325
def _enum_and_categorical(
274326
left: DataType | DataTypeClass, right: DataType | DataTypeClass
275327
) -> bool:

tests/test_conditions.py

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
import polars as pl
77
import pytest
88

9-
from diffly._conditions import _can_compare_dtypes, condition_equal_columns
9+
from diffly._conditions import (
10+
_can_compare_dtypes,
11+
_needs_element_wise_comparison,
12+
condition_equal_columns,
13+
)
1014
from diffly.comparison import compare_frames
1115

1216

@@ -512,6 +516,45 @@ def test_condition_equal_columns_lists_only_inner() -> None:
512516
assert actual.to_list() == [True, False]
513517

514518

519+
def test_condition_equal_columns_list_of_different_enums() -> None:
520+
# Arrange
521+
first_enum = pl.Enum(["one", "two"])
522+
second_enum = pl.Enum(["one", "two", "three"])
523+
524+
lhs = pl.DataFrame(
525+
{"pk": [1, 2], "a": [["one", "two"], ["one", "one"]]},
526+
schema_overrides={"a": pl.List(first_enum)},
527+
)
528+
rhs = pl.DataFrame(
529+
{"pk": [1, 2], "a": [["one", "two"], ["one", "three"]]},
530+
schema_overrides={"a": pl.List(second_enum)},
531+
)
532+
c = compare_frames(lhs, rhs, primary_key="pk")
533+
534+
# Act
535+
lhs = lhs.rename({"a": "a_left"})
536+
rhs = rhs.rename({"a": "a_right"})
537+
actual = (
538+
lhs.join(rhs, on="pk", maintain_order="left")
539+
.select(
540+
condition_equal_columns(
541+
"a",
542+
dtype_left=lhs.schema["a_left"],
543+
dtype_right=rhs.schema["a_right"],
544+
max_list_length=c._max_list_lengths_by_column.get("a"),
545+
abs_tol=c.abs_tol_by_column["a"],
546+
rel_tol=c.rel_tol_by_column["a"],
547+
)
548+
)
549+
.to_series()
550+
)
551+
552+
# Assert
553+
assert c._max_list_lengths_by_column == {"a": 2}
554+
assert _needs_element_wise_comparison(first_enum, second_enum)
555+
assert actual.to_list() == [True, False]
556+
557+
515558
@pytest.mark.parametrize(
516559
("dtype_left", "dtype_right", "can_compare_dtypes"),
517560
[
@@ -534,3 +577,73 @@ def test_can_compare_dtypes(
534577
dtype_left=dtype_left, dtype_right=dtype_right
535578
)
536579
assert can_compare_dtypes_actual == can_compare_dtypes
580+
581+
582+
@pytest.mark.parametrize(
583+
("dtype_left", "dtype_right", "expected"),
584+
[
585+
# Primitives that don't need element-wise comparison
586+
(pl.Int64, pl.Int64, False),
587+
(pl.String, pl.String, False),
588+
(pl.Boolean, pl.Boolean, False),
589+
# Float/numeric pairs
590+
(pl.Float64, pl.Float64, True),
591+
(pl.Int64, pl.Float64, True),
592+
(pl.Float32, pl.Int32, True),
593+
# Temporal pairs
594+
(pl.Datetime, pl.Datetime, True),
595+
(pl.Date, pl.Date, True),
596+
(pl.Datetime, pl.Date, True),
597+
# Enum/categorical
598+
(pl.Enum(["a", "b"]), pl.Enum(["a", "b"]), False),
599+
(pl.Enum(["a", "b"]), pl.Enum(["a", "b", "c"]), True),
600+
(pl.Enum(["a"]), pl.Categorical(), True),
601+
(pl.Categorical(), pl.Enum(["a"]), True),
602+
# Struct with no tolerance-requiring fields
603+
(
604+
pl.Struct({"x": pl.Int64, "y": pl.String}),
605+
pl.Struct({"x": pl.Int64, "y": pl.String}),
606+
False,
607+
),
608+
# Struct with a float field
609+
(
610+
pl.Struct({"x": pl.Int64, "y": pl.Float64}),
611+
pl.Struct({"x": pl.Int64, "y": pl.Float64}),
612+
True,
613+
),
614+
# Struct with different-category enums
615+
(
616+
pl.Struct({"x": pl.Enum(["a"])}),
617+
pl.Struct({"x": pl.Enum(["b"])}),
618+
True,
619+
),
620+
# List/Array with non-tolerance inner type
621+
(pl.List(pl.Int64), pl.List(pl.Int64), False),
622+
(pl.Array(pl.String, shape=3), pl.Array(pl.String, shape=3), False),
623+
# List/Array with tolerance-requiring inner type
624+
(pl.List(pl.Float64), pl.List(pl.Float64), True),
625+
(pl.Array(pl.Datetime, shape=2), pl.Array(pl.Datetime, shape=2), True),
626+
# Nested: list of structs with a float field
627+
(
628+
pl.List(pl.Struct({"x": pl.Float64})),
629+
pl.List(pl.Struct({"x": pl.Float64})),
630+
True,
631+
),
632+
# Nested: list of structs without tolerance-requiring fields
633+
(
634+
pl.List(pl.Struct({"x": pl.Int64})),
635+
pl.List(pl.Struct({"x": pl.Int64})),
636+
False,
637+
),
638+
# Deeply nested: struct with a list of structs with a float field
639+
(
640+
pl.List(pl.Struct({"x": pl.String, "y": pl.List(pl.Float64)})),
641+
pl.List(pl.Struct({"x": pl.String, "y": pl.List(pl.Float64)})),
642+
True,
643+
),
644+
],
645+
)
646+
def test_needs_element_wise_comparison(
647+
dtype_left: pl.DataType, dtype_right: pl.DataType, expected: bool
648+
) -> None:
649+
assert _needs_element_wise_comparison(dtype_left, dtype_right) == expected

tests/test_performance.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,10 @@ def expensive_computation(col: pl.Expr) -> pl.Expr:
8383
)
8484

8585

86-
def test_element_wise_comparison_slower_than_eq_missing_for_list_columns() -> None:
87-
"""Confirm that comparing list columns with non-tolerance inner types via
88-
eq_missing() is significantly faster than the element-wise
89-
_compare_sequence_columns() path."""
86+
def test_eq_missing_not_slower_than_element_wise_for_list_columns() -> None:
87+
"""Ensure that comparing list columns with non-tolerance inner types via
88+
eq_missing() is not slower than the element-wise _compare_sequence_columns()
89+
path."""
9090
n_rows = 500_000
9191
list_len = 20
9292
num_runs_measured = 10
@@ -126,8 +126,8 @@ def test_element_wise_comparison_slower_than_eq_missing_for_list_columns() -> No
126126
mean_time_cond = statistics.mean(times_cond[num_runs_warmup:])
127127

128128
ratio = mean_time_cond / mean_time_eq
129-
assert ratio > 2.0, (
130-
f"Element-wise comparison was only {ratio:.1f}x slower than eq_missing "
129+
assert ratio < 1.25, (
130+
f"condition_equal_columns was {ratio:.1f}x slower than eq_missing "
131131
f"({mean_time_cond:.3f}s vs {mean_time_eq:.3f}s). "
132-
f"Expected at least 2x slowdown to justify the optimization."
132+
f"Expected comparable performance since list<i64> should use eq_missing directly."
133133
)

0 commit comments

Comments
 (0)