Skip to content

Commit 11dbd13

Browse files
feat: Add per-column metrics to summary (#34)
Co-authored-by: Marius Merkle <122545105+MariusMerkleQC@users.noreply.github.com>
1 parent ef4c5e9 commit 11dbd13

144 files changed

Lines changed: 4846 additions & 608 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

diffly/cli.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from ._compat import typer
1313
from ._utils import ABS_TOL_DEFAULT, ABS_TOL_TEMPORAL_DEFAULT, REL_TOL_DEFAULT
14+
from .metrics import DEFAULT_METRICS
1415

1516
app = typer.Typer()
1617

@@ -129,8 +130,24 @@ def main(
129130
)
130131
),
131132
] = [],
133+
metric: Annotated[
134+
list[str],
135+
typer.Option(
136+
help=(
137+
"Metric presets to display per numerical column. Repeatable. "
138+
f"Available: {', '.join(DEFAULT_METRICS)}."
139+
)
140+
),
141+
] = [],
132142
) -> None:
133143
"""Compare two `parquet` files and print the comparison result."""
144+
for name in metric:
145+
if name not in DEFAULT_METRICS:
146+
raise typer.BadParameter(
147+
f"Unknown metric: {name!r}. Available: {', '.join(DEFAULT_METRICS)}."
148+
)
149+
metrics = {name: DEFAULT_METRICS[name] for name in metric}
150+
134151
comparison = compare_frames(
135152
pl.scan_parquet(left),
136153
pl.scan_parquet(right),
@@ -148,6 +165,7 @@ def main(
148165
right_name=right_name,
149166
slim=slim,
150167
hidden_columns=hidden_columns,
168+
metrics=metrics,
151169
)
152170
if output_json:
153171
typer.echo(summary.to_json())

diffly/comparison.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
lazy_len,
2626
make_and_validate_mapping,
2727
)
28+
from .metrics import MetricFn, _make_numeric_metric
2829

2930
if TYPE_CHECKING: # pragma: no cover
3031
# NOTE: We cannot import at runtime as we're otherwise running into circular
@@ -919,6 +920,7 @@ def summary(
919920
right_name: str = Side.RIGHT,
920921
slim: bool = False,
921922
hidden_columns: list[str] | None = None,
923+
metrics: Mapping[str, MetricFn] | None = None,
922924
) -> Summary:
923925
"""Generate a summary of all aspects of the comparison.
924926
@@ -948,6 +950,16 @@ def summary(
948950
advanced users who are familiar with the summary format.
949951
hidden_columns: Columns for which no values are printed, e.g. because they
950952
contain sensitive information.
953+
metrics: Optional mapping from display label to a metric callable
954+
``(left_expr, right_expr) -> pl.Expr``. Each callable receives two
955+
:class:`polars.Expr` referring to the left and right values of a single
956+
numerical column across all joined rows, and must return a scalar
957+
aggregation expression. See :doc:`/api/metrics` for the full list of
958+
presets and the :data:`~diffly.metrics.MetricFn` type. When ``None``
959+
(default), no metrics are computed; presets are not applied
960+
automatically. Metrics are only computed for numerical columns. Prefer
961+
short labels — the summary has a fixed width and many or long labels
962+
degrade rendering.
951963
952964
Returns:
953965
A summary which can be printed or written to a file.
@@ -963,6 +975,12 @@ def summary(
963975
# NOTE: We're importing here to prevent circular imports
964976
from .summary import Summary
965977

978+
resolved_metrics = (
979+
{label: _make_numeric_metric(fn) for label, fn in metrics.items()}
980+
if metrics is not None
981+
else None
982+
)
983+
966984
return Summary(
967985
self,
968986
show_perfect_column_matches=show_perfect_column_matches,
@@ -973,6 +991,7 @@ def summary(
973991
right_name=right_name,
974992
slim=slim,
975993
hidden_columns=hidden_columns,
994+
metrics=resolved_metrics,
976995
)
977996

978997
# ----------------------------------- UTILITIES ----------------------------------- #

diffly/metrics.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
# Copyright (c) QuantCo 2025-2026
2+
# SPDX-License-Identifier: BSD-3-Clause
3+
4+
from __future__ import annotations
5+
6+
from collections.abc import Callable
7+
from dataclasses import dataclass
8+
9+
import polars as pl
10+
import polars.selectors as cs
11+
12+
13+
@dataclass(frozen=True)
14+
class Metric:
15+
"""A metric function paired with a column-applicability selector.
16+
17+
Internal only.
18+
"""
19+
20+
fn: MetricFn
21+
selector: cs.Selector
22+
23+
24+
MetricFn = Callable[[pl.Expr, pl.Expr], pl.Expr]
25+
"""A metric function maps ``(left_expr, right_expr)`` to a scalar aggregation
26+
expression.
27+
28+
The expressions refer to the left-side and right-side values of a single column across
29+
all joined rows.
30+
"""
31+
32+
33+
def _make_numeric_metric(fn: MetricFn) -> Metric:
34+
return Metric(fn=fn, selector=cs.numeric())
35+
36+
37+
def mean(left: pl.Expr, right: pl.Expr) -> pl.Expr:
38+
"""Mean of ``right - left``."""
39+
return (right - left).mean()
40+
41+
42+
def median(left: pl.Expr, right: pl.Expr) -> pl.Expr:
43+
"""Median of ``right - left``."""
44+
return (right - left).median()
45+
46+
47+
def min(left: pl.Expr, right: pl.Expr) -> pl.Expr:
48+
"""Minimum of ``right - left``."""
49+
return (right - left).min()
50+
51+
52+
def max(left: pl.Expr, right: pl.Expr) -> pl.Expr:
53+
"""Maximum of ``right - left``."""
54+
return (right - left).max()
55+
56+
57+
def std(left: pl.Expr, right: pl.Expr) -> pl.Expr:
58+
"""Standard deviation of ``right - left``."""
59+
return (right - left).std()
60+
61+
62+
def mean_absolute_deviation(left: pl.Expr, right: pl.Expr) -> pl.Expr:
63+
"""Mean of ``|right - left|``."""
64+
return (right - left).abs().mean()
65+
66+
67+
def mean_relative_deviation(left: pl.Expr, right: pl.Expr) -> pl.Expr:
68+
"""Mean of ``|(right - left) / left|``. Yields ``inf`` or ``null`` where
69+
``left`` is zero."""
70+
return ((right - left) / left).abs().mean()
71+
72+
73+
def quantile(q: float) -> MetricFn:
74+
"""Factory returning a metric that computes the ``q``-quantile of
75+
``right - left``."""
76+
if not 0 <= q <= 1:
77+
raise ValueError(f"q must be in [0, 1], got {q}")
78+
79+
def _quantile(left: pl.Expr, right: pl.Expr) -> pl.Expr:
80+
return (right - left).quantile(q)
81+
82+
return _quantile
83+
84+
85+
DEFAULT_METRICS: dict[str, MetricFn] = {
86+
"Mean": mean,
87+
"Median": median,
88+
"Min": min,
89+
"Max": max,
90+
"Std": std,
91+
"Mean absolute deviation": mean_absolute_deviation,
92+
"Mean relative deviation": mean_relative_deviation,
93+
}

0 commit comments

Comments
 (0)