Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 59 additions & 14 deletions diffly/_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,16 +140,18 @@ def _compare_columns(
elif isinstance(dtype_left, pl.List | pl.Array) and isinstance(
dtype_right, pl.List | pl.Array
):
return _compare_sequence_columns(
col_left=col_left,
col_right=col_right,
dtype_left=dtype_left,
dtype_right=dtype_right,
max_list_length=max_list_length,
abs_tol=abs_tol,
rel_tol=rel_tol,
abs_tol_temporal=abs_tol_temporal,
)
if _needs_element_wise_comparison(dtype_left.inner, dtype_right.inner):
Comment thread
MariusMerkleQC marked this conversation as resolved.
return _compare_sequence_columns(
col_left=col_left,
col_right=col_right,
dtype_left=dtype_left,
dtype_right=dtype_right,
max_list_length=max_list_length,
abs_tol=abs_tol,
rel_tol=rel_tol,
abs_tol_temporal=abs_tol_temporal,
)
return col_left.eq_missing(col_right)
Comment thread
MariusMerkleQC marked this conversation as resolved.

if (
isinstance(dtype_left, pl.Enum)
Expand Down Expand Up @@ -237,6 +239,51 @@ def _get_element(col: pl.Expr, dtype: DataType | DataTypeClass, i: int) -> pl.Ex
return _eq_missing(has_same_length & elements_match, col_left, col_right)


def _is_float_numeric_pair(
    dtype_left: DataType | DataTypeClass,
    dtype_right: DataType | DataTypeClass,
) -> bool:
    """Return whether the pair is numeric on both sides with a float involved.

    Args:
        dtype_left: Data type of the left column.
        dtype_right: Data type of the right column.

    Returns:
        ``True`` if at least one of the two dtypes is a float type and both
        dtypes are numeric, ``False`` otherwise.
    """
    # Without a float on either side there is nothing more to check.
    if not (dtype_left.is_float() or dtype_right.is_float()):
        return False
    return dtype_left.is_numeric() and dtype_right.is_numeric()


def _is_temporal_pair(
    dtype_left: DataType | DataTypeClass,
    dtype_right: DataType | DataTypeClass,
) -> bool:
    """Return whether both dtypes are temporal types.

    Args:
        dtype_left: Data type of the left column.
        dtype_right: Data type of the right column.

    Returns:
        ``True`` only if ``is_temporal()`` holds for both dtypes.
    """
    if not dtype_left.is_temporal():
        return False
    return dtype_right.is_temporal()


def _needs_element_wise_comparison(
    dtype_left: DataType | DataTypeClass,
    dtype_right: DataType | DataTypeClass,
) -> bool:
    """Decide whether a dtype pair must be compared element by element.

    Element-wise comparison is needed when tolerances or special handling
    apply to the pair. A ``False`` result is meant to signal that calling
    ``eq_missing()`` on the whole column yields identical results, so the
    expensive element-wise iteration for list/array columns can be skipped.

    Args:
        dtype_left: Data type on the left side (e.g. the inner type of a list).
        dtype_right: Data type on the right side.

    Returns:
        ``True`` if the pair is a float/numeric pair or a temporal pair, or if
        any such pair is nested inside matching struct fields or inside
        list/array inner types; ``False`` otherwise.
    """
    # Leaf cases: these are the pairs for which tolerances come into play.
    if _is_float_numeric_pair(dtype_left, dtype_right) or _is_temporal_pair(
        dtype_left, dtype_right
    ):
        return True

    # Structs: recurse into the fields that exist on both sides; fields that
    # are present on only one side are not inspected here.
    if isinstance(dtype_left, pl.Struct) and isinstance(dtype_right, pl.Struct):
        left_by_name = {field.name: field.dtype for field in dtype_left.fields}
        right_by_name = {field.name: field.dtype for field in dtype_right.fields}
        for name, left_field_dtype in left_by_name.items():
            if name not in right_by_name:
                continue
            if _needs_element_wise_comparison(left_field_dtype, right_by_name[name]):
                return True
        return False

    # Nested sequences: the decision depends solely on the inner dtypes.
    sequence_types = (pl.List, pl.Array)
    if isinstance(dtype_left, sequence_types) and isinstance(
        dtype_right, sequence_types
    ):
        return _needs_element_wise_comparison(dtype_left.inner, dtype_right.inner)

    return False


def _compare_primitive_columns(
col_left: pl.Expr,
col_right: pl.Expr,
Expand All @@ -246,13 +293,11 @@ def _compare_primitive_columns(
rel_tol: float,
abs_tol_temporal: dt.timedelta,
) -> pl.Expr:
if (dtype_left.is_float() or dtype_right.is_float()) and (
dtype_left.is_numeric() and dtype_right.is_numeric()
):
if _is_float_numeric_pair(dtype_left, dtype_right):
return col_left.is_close(col_right, abs_tol=abs_tol, rel_tol=rel_tol).pipe(
_eq_missing_with_nan, lhs=col_left, rhs=col_right
)
elif dtype_left.is_temporal() and dtype_right.is_temporal():
elif _is_temporal_pair(dtype_left, dtype_right):
diff_less_than_tolerance = (col_left - col_right).abs() <= abs_tol_temporal
return diff_less_than_tolerance.pipe(_eq_missing, lhs=col_left, rhs=col_right)

Expand Down
57 changes: 57 additions & 0 deletions tests/test_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
import polars as pl

from diffly import compare_frames
from diffly._conditions import condition_equal_columns
from diffly._utils import (
ABS_TOL_DEFAULT,
ABS_TOL_TEMPORAL_DEFAULT,
REL_TOL_DEFAULT,
Side,
)


def test_summary_lazyframe_not_slower_than_dataframe() -> None:
Expand Down Expand Up @@ -74,3 +81,53 @@ def expensive_computation(col: pl.Expr) -> pl.Expr:
f"({mean_time_lf:.3f}s vs {mean_time_df:.3f}s). "
f"This suggests unnecessary re-collection of LazyFrames."
)


def test_eq_missing_not_slower_than_element_wise_for_list_columns() -> None:
    """Check that ``condition_equal_columns`` keeps pace with plain ``eq_missing()``.

    For list columns whose inner type needs no tolerance handling, the
    condition is expected to be about as fast as calling ``eq_missing()``
    directly rather than paying for the element-wise
    ``_compare_sequence_columns()`` path.
    """
    n_rows = 500_000
    list_len = 20
    num_runs_measured = 10
    num_runs_warmup = 2
    total_runs = num_runs_warmup + num_runs_measured

    col_left = f"val_{Side.LEFT}"
    col_right = f"val_{Side.RIGHT}"
    # Both sides hold identical integer lists of fixed length.
    df = pl.DataFrame(
        {
            name: [list(range(list_len)) for _ in range(n_rows)]
            for name in (col_left, col_right)
        }
    )

    def run_eq_missing() -> None:
        # Baseline: compare the two list columns as a whole.
        df.select(pl.col(col_left).eq_missing(pl.col(col_right))).to_series()

    def run_condition() -> None:
        # Candidate: the library's column-equality condition with defaults.
        df.select(
            condition_equal_columns(
                column="val",
                dtype_left=df.schema[col_left],
                dtype_right=df.schema[col_right],
                max_list_length=list_len,
                abs_tol=ABS_TOL_DEFAULT,
                rel_tol=REL_TOL_DEFAULT,
                abs_tol_temporal=ABS_TOL_TEMPORAL_DEFAULT,
            )
        ).to_series()

    times_eq: list[float] = []
    times_cond: list[float] = []
    # Interleave both measurements within each run, baseline first.
    for _ in range(total_runs):
        started = time.perf_counter()
        run_eq_missing()
        times_eq.append(time.perf_counter() - started)

        started = time.perf_counter()
        run_condition()
        times_cond.append(time.perf_counter() - started)

    # Discard the warm-up runs before averaging.
    mean_time_eq = statistics.mean(times_eq[num_runs_warmup:])
    mean_time_cond = statistics.mean(times_cond[num_runs_warmup:])

    ratio = mean_time_cond / mean_time_eq
    assert ratio < 1.25, (
        f"condition_equal_columns was {ratio:.1f}x slower than eq_missing "
        f"({mean_time_cond:.3f}s vs {mean_time_eq:.3f}s). "
        f"Expected comparable performance since list<i64> should use eq_missing directly."
    )
Loading