|
| 1 | +# Copyright (c) QuantCo 2025-2026 |
| 2 | +# SPDX-License-Identifier: BSD-3-Clause |
| 3 | + |
| 4 | +from __future__ import annotations |
| 5 | + |
| 6 | +from collections.abc import Callable |
| 7 | +from dataclasses import dataclass |
| 8 | + |
| 9 | +import polars as pl |
| 10 | +import polars.selectors as cs |
| 11 | + |
| 12 | + |
| 13 | +@dataclass(frozen=True) |
| 14 | +class Metric: |
| 15 | + """A metric function paired with a column-applicability selector. |
| 16 | +
|
| 17 | + Internal only. |
| 18 | + """ |
| 19 | + |
| 20 | + fn: MetricFn |
| 21 | + selector: cs.Selector |
| 22 | + |
| 23 | + |
| 24 | +MetricFn = Callable[[pl.Expr, pl.Expr], pl.Expr] |
| 25 | +"""A metric function maps ``(left_expr, right_expr)`` to a scalar aggregation |
| 26 | +expression. |
| 27 | +
|
| 28 | +The expressions refer to the left-side and right-side values of a single column across |
| 29 | +all joined rows. |
| 30 | +""" |
| 31 | + |
| 32 | + |
| 33 | +def _make_numeric_metric(fn: MetricFn) -> Metric: |
| 34 | + return Metric(fn=fn, selector=cs.numeric()) |
| 35 | + |
| 36 | + |
| 37 | +def mean(left: pl.Expr, right: pl.Expr) -> pl.Expr: |
| 38 | + """Mean of ``right - left``.""" |
| 39 | + return (right - left).mean() |
| 40 | + |
| 41 | + |
| 42 | +def median(left: pl.Expr, right: pl.Expr) -> pl.Expr: |
| 43 | + """Median of ``right - left``.""" |
| 44 | + return (right - left).median() |
| 45 | + |
| 46 | + |
| 47 | +def min(left: pl.Expr, right: pl.Expr) -> pl.Expr: |
| 48 | + """Minimum of ``right - left``.""" |
| 49 | + return (right - left).min() |
| 50 | + |
| 51 | + |
| 52 | +def max(left: pl.Expr, right: pl.Expr) -> pl.Expr: |
| 53 | + """Maximum of ``right - left``.""" |
| 54 | + return (right - left).max() |
| 55 | + |
| 56 | + |
| 57 | +def std(left: pl.Expr, right: pl.Expr) -> pl.Expr: |
| 58 | + """Standard deviation of ``right - left``.""" |
| 59 | + return (right - left).std() |
| 60 | + |
| 61 | + |
| 62 | +def mean_absolute_deviation(left: pl.Expr, right: pl.Expr) -> pl.Expr: |
| 63 | + """Mean of ``|right - left|``.""" |
| 64 | + return (right - left).abs().mean() |
| 65 | + |
| 66 | + |
| 67 | +def mean_relative_deviation(left: pl.Expr, right: pl.Expr) -> pl.Expr: |
| 68 | + """Mean of ``|(right - left) / left|``. Yields ``inf`` or ``null`` where |
| 69 | + ``left`` is zero.""" |
| 70 | + return ((right - left) / left).abs().mean() |
| 71 | + |
| 72 | + |
| 73 | +def quantile(q: float) -> MetricFn: |
| 74 | + """Factory returning a metric that computes the ``q``-quantile of |
| 75 | + ``right - left``.""" |
| 76 | + if not 0 <= q <= 1: |
| 77 | + raise ValueError(f"q must be in [0, 1], got {q}") |
| 78 | + |
| 79 | + def _quantile(left: pl.Expr, right: pl.Expr) -> pl.Expr: |
| 80 | + return (right - left).quantile(q) |
| 81 | + |
| 82 | + return _quantile |
| 83 | + |
| 84 | + |
| 85 | +DEFAULT_METRICS: dict[str, MetricFn] = { |
| 86 | + "Mean": mean, |
| 87 | + "Median": median, |
| 88 | + "Min": min, |
| 89 | + "Max": max, |
| 90 | + "Std": std, |
| 91 | + "Mean absolute deviation": mean_absolute_deviation, |
| 92 | + "Mean relative deviation": mean_relative_deviation, |
| 93 | +} |
0 commit comments