diff --git a/buckaroo/pluggable_analysis_framework/stat_func.py b/buckaroo/pluggable_analysis_framework/stat_func.py index 8e97a3f28..3ddec4f6e 100644 --- a/buckaroo/pluggable_analysis_framework/stat_func.py +++ b/buckaroo/pluggable_analysis_framework/stat_func.py @@ -11,7 +11,7 @@ import inspect from dataclasses import dataclass, field -from typing import Any, Callable, List, Optional, TypedDict, get_type_hints +from typing import Any, Callable, List, Optional, Tuple, TypedDict, get_type_hints class MultipleProvides(TypedDict): @@ -145,6 +145,10 @@ class StatFunc: column_filter: optional predicate on column dtype quiet: suppress error reporting default: fallback value on failure (MISSING = no fallback) + pushdown: backend identifiers (e.g. ``("xorq", "polars")``) whose + engines can compute this stat via aggregation push-down without + in-process materialization. Empty default means pandas-only / + requires materialization. """ name: str func: Callable @@ -154,6 +158,7 @@ class StatFunc: column_filter: Optional[Callable] = None quiet: bool = False default: Any = field(default_factory=lambda: MISSING) + pushdown: Tuple[str, ...] = () spread_dict_result: bool = False # v1 compat: spread all dict keys into accumulator v1_computed: bool = False # v1 compat: pass full accumulator as single dict arg @@ -250,7 +255,7 @@ def _get_requires_from_params(sig: inspect.Signature, hints: dict) -> tuple: # @stat decorator # --------------------------------------------------------------------------- -def stat(column_filter=None, quiet=False, default=MISSING): +def stat(column_filter=None, quiet=False, default=MISSING, pushdown=()): """Decorator that converts a function into a StatFunc. The function signature IS the contract: @@ -263,6 +268,12 @@ def stat(column_filter=None, quiet=False, default=MISSING): key the rest of the DAG expects. Use ``MultipleProvides`` (a TypedDict alias) when one function should write several keys. + ``pushdown=`` declares which backends can compute this stat via + engine-side aggregation push-down (no in-process materialization). + Empty default means pandas-only / requires materialization. Recognised + identifiers: ``"xorq"``, ``"polars"``. Normalised to a tuple so callers + may pass a list. + Usage:: @stat() @@ -277,6 +288,10 @@ def mean(ser: RawSeries) -> float: def safe_ratio(a: int, b: int) -> float: return a / b + @stat(pushdown=("xorq", "polars")) + def mean(col): + return col.mean() + class TypingResult(MultipleProvides): is_numeric: bool is_integer: bool @@ -297,8 +312,12 @@ def decorator(func): requires, needs_raw = _get_requires_from_params(sig, hints) provides_keys = _get_provides_from_return_type(func.__name__, return_type) + # A bare string is the natural single-backend form; ``tuple(str)`` + # would silently expand it to a tuple of characters. + pushdown_norm = (pushdown,) if isinstance(pushdown, str) else tuple(pushdown) stat_func = StatFunc(name=func.__name__, func=func, requires=requires, provides=provides_keys, - needs_raw=needs_raw, column_filter=column_filter, quiet=quiet, default=default) + needs_raw=needs_raw, column_filter=column_filter, quiet=quiet, default=default, + pushdown=pushdown_norm) # Attach metadata to the function so pipeline can find it func._stat_func = stat_func diff --git a/tests/unit/test_paf_v2.py b/tests/unit/test_paf_v2.py index 7e8cbc1d6..55933644b 100644 --- a/tests/unit/test_paf_v2.py +++ b/tests/unit/test_paf_v2.py @@ -176,6 +176,34 @@ def test_no_default(self): sf = distinct_per._stat_func assert sf.default is MISSING + def test_stat_default_pushdown_is_empty(self): + sf = length._stat_func + assert sf.pushdown == () + + def test_stat_pushdown_stored_on_stat_func(self): + @stat(pushdown=("xorq", "polars")) + def pushdown_mean(ser: RawSeries) -> float: + return float(ser.mean()) + + assert pushdown_mean._stat_func.pushdown == ("xorq", "polars") + + def test_stat_pushdown_normalises_list_to_tuple(self): + @stat(pushdown=["xorq"]) + def pushdown_sum(ser: RawSeries) -> float: + return float(ser.sum()) + + assert pushdown_sum._stat_func.pushdown == ("xorq",) + assert isinstance(pushdown_sum._stat_func.pushdown, tuple) + + def test_stat_pushdown_accepts_bare_string(self): + # A bare string is the natural form for one backend. + # ``tuple("xorq")`` would silently store ``('x','o','r','q')``. + @stat(pushdown="xorq") + def pushdown_count(ser: RawSeries) -> int: + return int(ser.count()) + + assert pushdown_count._stat_func.pushdown == ("xorq",) + class _MultiSizeStats(MultipleProvides): row_count: int