Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions buckaroo/pluggable_analysis_framework/stat_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import inspect
from dataclasses import dataclass, field
from typing import Any, Callable, List, Optional, TypedDict, get_type_hints
from typing import Any, Callable, List, Optional, Tuple, TypedDict, get_type_hints


class MultipleProvides(TypedDict):
Expand Down Expand Up @@ -145,6 +145,10 @@ class StatFunc:
column_filter: optional predicate on column dtype
quiet: suppress error reporting
default: fallback value on failure (MISSING = no fallback)
pushdown: backend identifiers (e.g. ``("xorq", "polars")``) whose
engines can compute this stat via aggregation push-down without
in-process materialization. Empty default means pandas-only /
requires materialization.
"""
name: str
func: Callable
Expand All @@ -154,6 +158,7 @@ class StatFunc:
column_filter: Optional[Callable] = None
quiet: bool = False
default: Any = field(default_factory=lambda: MISSING)
pushdown: Tuple[str, ...] = ()
spread_dict_result: bool = False # v1 compat: spread all dict keys into accumulator
v1_computed: bool = False # v1 compat: pass full accumulator as single dict arg

Expand Down Expand Up @@ -250,7 +255,7 @@ def _get_requires_from_params(sig: inspect.Signature, hints: dict) -> tuple:
# @stat decorator
# ---------------------------------------------------------------------------

def stat(column_filter=None, quiet=False, default=MISSING):
def stat(column_filter=None, quiet=False, default=MISSING, pushdown=()):
"""Decorator that converts a function into a StatFunc.

The function signature IS the contract:
Expand All @@ -263,6 +268,12 @@ def stat(column_filter=None, quiet=False, default=MISSING):
key the rest of the DAG expects. Use ``MultipleProvides`` (a TypedDict
alias) when one function should write several keys.

``pushdown=`` declares which backends can compute this stat via
engine-side aggregation push-down (no in-process materialization).
The default, an empty tuple, means pandas-only / requires materialization.
Recognised identifiers: ``"xorq"``, ``"polars"``. The value is normalised
to a tuple, so callers may also pass a list or a single bare string
(``pushdown="xorq"`` is stored as ``("xorq",)``).

Usage::

@stat()
Expand All @@ -277,6 +288,10 @@ def mean(ser: RawSeries) -> float:
def safe_ratio(a: int, b: int) -> float:
return a / b

@stat(pushdown=("xorq", "polars"))
def mean(ser: RawSeries) -> float:
return ser.mean()

class TypingResult(MultipleProvides):
is_numeric: bool
is_integer: bool
Expand All @@ -297,8 +312,12 @@ def decorator(func):
requires, needs_raw = _get_requires_from_params(sig, hints)
provides_keys = _get_provides_from_return_type(func.__name__, return_type)

# A bare string is the natural single-backend form; ``tuple(str)``
# would silently expand it to a tuple of characters.
pushdown_norm = (pushdown,) if isinstance(pushdown, str) else tuple(pushdown)
stat_func = StatFunc(name=func.__name__, func=func, requires=requires, provides=provides_keys,
needs_raw=needs_raw, column_filter=column_filter, quiet=quiet, default=default)
needs_raw=needs_raw, column_filter=column_filter, quiet=quiet, default=default,
pushdown=pushdown_norm)

# Attach metadata to the function so pipeline can find it
func._stat_func = stat_func
Expand Down
28 changes: 28 additions & 0 deletions tests/unit/test_paf_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,34 @@ def test_no_default(self):
sf = distinct_per._stat_func
assert sf.default is MISSING

def test_stat_default_pushdown_is_empty(self):
    """The ``length`` stat is expected to carry the empty default ``pushdown``."""
    # ``length`` is defined elsewhere in this test module.
    assert length._stat_func.pushdown == ()

def test_stat_pushdown_stored_on_stat_func(self):
    """A tuple passed as ``pushdown=`` ends up on the attached StatFunc."""
    backends = ("xorq", "polars")

    @stat(pushdown=backends)
    def pushdown_mean(ser: RawSeries) -> float:
        return float(ser.mean())

    stored = pushdown_mean._stat_func.pushdown
    assert stored == ("xorq", "polars")

def test_stat_pushdown_normalises_list_to_tuple(self):
    """A list passed as ``pushdown=`` comes back as an equal tuple."""
    @stat(pushdown=["xorq"])
    def pushdown_sum(ser: RawSeries) -> float:
        return float(ser.sum())

    stored = pushdown_sum._stat_func.pushdown
    # Check both the value and the concrete container type.
    assert stored == ("xorq",)
    assert isinstance(stored, tuple)

def test_stat_pushdown_accepts_bare_string(self):
    """A bare string names one backend and must be wrapped, not split.

    ``tuple("xorq")`` would yield ``('x', 'o', 'r', 'q')``; the stored value
    is expected to be the one-element tuple ``("xorq",)`` instead.
    """
    @stat(pushdown="xorq")
    def pushdown_count(ser: RawSeries) -> int:
        return int(ser.count())

    stored = pushdown_count._stat_func.pushdown
    assert stored == ("xorq",)


class _MultiSizeStats(MultipleProvides):
row_count: int
Expand Down
Loading