Skip to content

Commit c46ecba

Browse files
committed
Merge #808: @stat(cost=...) decorator + tag known-expensive stats
2 parents 2f6a364 + 898414f commit c46ecba

5 files changed

Lines changed: 72 additions & 6 deletions

File tree

buckaroo/customizations/pd_stats_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def vc_nth(pos):
176176
HistogramSeriesResult = TypedDict('HistogramSeriesResult', {'histogram_args': dict, 'histogram_bins': list})
177177

178178

179-
@stat()
179+
@stat(cost="aggregate")
180180
def histogram_series(ser: RawSeries) -> HistogramSeriesResult:
181181
"""Compute histogram args from raw series (numeric path)."""
182182
if not pd.api.types.is_numeric_dtype(ser):
@@ -210,7 +210,7 @@ def histogram_series(ser: RawSeries) -> HistogramSeriesResult:
210210
}
211211

212212

213-
@stat()
213+
@stat(cost="aggregate")
214214
def histogram(value_counts: pd.Series, nan_per: float, is_numeric: bool, length: int, min: Any, max: Any,
215215
histogram_args: dict) -> list:
216216
"""Compute histogram from summary stats and histogram args."""

buckaroo/customizations/pl_stats_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def pl_numeric_stats(ser: RawSeries) -> NumericStatsResult:
113113
# Histogram Series (polars series API)
114114
# ============================================================
115115

116-
@stat()
116+
@stat(cost="aggregate")
117117
def pl_histogram_series(ser: RawSeries) -> HistogramSeriesResult:
118118
"""Compute histogram args from raw polars series (numeric path)."""
119119
if not ser.dtype.is_numeric():

buckaroo/customizations/xorq_stats_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def _categorical_histogram(execute: Callable[[Any], pd.DataFrame], expr: Any, co
255255
return out
256256

257257

258-
@stat(default=[])
258+
@stat(default=[], cost="aggregate")
259259
def histogram(expr: XorqExpr, execute: XorqExecute, orig_col_name: str, is_numeric: bool, is_bool: bool, length: int,
260260
distinct_count: int, min: float, max: float) -> list:
261261
"""10-bucket numeric histogram or top-10 categorical histogram.

buckaroo/pluggable_analysis_framework/stat_func.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ def __repr__(self):
132132
# StatFunc — a registered stat computation
133133
# ---------------------------------------------------------------------------
134134

135+
VALID_COSTS = ("scalar", "aggregate")
136+
137+
135138
@dataclass
136139
class StatFunc:
137140
"""A registered stat computation.
@@ -149,6 +152,11 @@ class StatFunc:
149152
engines can compute this stat via aggregation push-down without
150153
in-process materialization. Empty default means pandas-only /
151154
requires materialization.
155+
cost: cost class — ``"scalar"`` (cheap, ships in the initial
156+
state_change response) or ``"aggregate"`` (slow path that the
157+
JS orchestrator fetches via a separate round-trip after a
158+
debounce). Histograms, value_counts and other per-column
159+
queries belong in ``"aggregate"``.
152160
"""
153161
name: str
154162
func: Callable
@@ -159,6 +167,7 @@ class StatFunc:
159167
quiet: bool = False
160168
default: Any = field(default_factory=lambda: MISSING)
161169
pushdown: Tuple[str, ...] = ()
170+
cost: str = "scalar"
162171
spread_dict_result: bool = False # v1 compat: spread all dict keys into accumulator
163172
v1_computed: bool = False # v1 compat: pass full accumulator as single dict arg
164173

@@ -255,7 +264,7 @@ def _get_requires_from_params(sig: inspect.Signature, hints: dict) -> tuple:
255264
# @stat decorator
256265
# ---------------------------------------------------------------------------
257266

258-
def stat(column_filter=None, quiet=False, default=MISSING, pushdown=()):
267+
def stat(column_filter=None, quiet=False, default=MISSING, pushdown=(), cost="scalar"):
259268
"""Decorator that converts a function into a StatFunc.
260269
261270
The function signature IS the contract:
@@ -274,6 +283,12 @@ def stat(column_filter=None, quiet=False, default=MISSING, pushdown=()):
274283
identifiers: ``"xorq"``, ``"polars"``. Normalised to a tuple so callers
275284
may pass a list.
276285
286+
``cost=`` declares the compute-cost class. ``"scalar"`` (default)
287+
means cheap — ships in the initial state_change response. ``"aggregate"``
288+
means slow (histograms, value_counts, per-column queries) — the JS
289+
orchestrator fetches these via a separate round-trip after a debounce.
290+
See plans/js-driven-stat-debounce.md.
291+
277292
Usage::
278293
279294
@stat()
@@ -292,6 +307,10 @@ def safe_ratio(a: int, b: int) -> float:
292307
def mean(col):
293308
return col.mean()
294309
310+
@stat(cost="aggregate")
311+
def histogram(...) -> list:
312+
...
313+
295314
class TypingResult(MultipleProvides):
296315
is_numeric: bool
297316
is_integer: bool
@@ -300,6 +319,10 @@ class TypingResult(MultipleProvides):
300319
def typing_stats(dtype: str) -> TypingResult:
301320
...
302321
"""
322+
if cost not in VALID_COSTS:
323+
raise ValueError(
324+
f"@stat(cost={cost!r}): invalid cost class. "
325+
f"Must be one of {VALID_COSTS}.")
303326
def decorator(func):
304327
sig = inspect.signature(func)
305328
try:
@@ -317,7 +340,7 @@ def decorator(func):
317340
pushdown_norm = (pushdown,) if isinstance(pushdown, str) else tuple(pushdown)
318341
stat_func = StatFunc(name=func.__name__, func=func, requires=requires, provides=provides_keys,
319342
needs_raw=needs_raw, column_filter=column_filter, quiet=quiet, default=default,
320-
pushdown=pushdown_norm)
343+
pushdown=pushdown_norm, cost=cost)
321344

322345
# Attach metadata to the function so pipeline can find it
323346
func._stat_func = stat_func

tests/unit/test_paf_v2.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,49 @@ def pushdown_count(ser: RawSeries) -> int:
204204

205205
assert pushdown_count._stat_func.pushdown == ("xorq",)
206206

207+
def test_stat_default_cost_is_scalar(self):
208+
# Default cost class is "scalar" — the bulk of stats are cheap.
209+
# Only known-expensive ones opt in to "aggregate".
210+
sf = length._stat_func
211+
assert sf.cost == "scalar"
212+
213+
def test_stat_explicit_cost_aggregate(self):
214+
@stat(cost="aggregate")
215+
def big_compute(ser: RawSeries) -> float:
216+
return float(ser.mean())
217+
218+
assert big_compute._stat_func.cost == "aggregate"
219+
220+
def test_stat_invalid_cost_rejected(self):
221+
# Only "scalar" and "aggregate" are recognised. A typo should
222+
# fail loud at decoration time — silently dropping an invalid
223+
# cost would leak into the cost-class router as an unscheduled
224+
# stat group.
225+
import pytest as _pt
226+
with _pt.raises(ValueError, match="cost"):
227+
@stat(cost="bigly")
228+
def bad(ser: RawSeries) -> float:
229+
return float(ser.mean())
230+
231+
def test_known_expensive_stats_marked_aggregate(self):
232+
"""The expensive built-in stat funcs (histogram producers across
233+
all three engines) are tagged ``cost="aggregate"`` so a
234+
downstream router can schedule them on the slow path."""
235+
from buckaroo.customizations.pd_stats_v2 import histogram as pd_histogram
236+
from buckaroo.customizations.pd_stats_v2 import histogram_series as pd_hs
237+
from buckaroo.customizations.pl_stats_v2 import pl_histogram_series
238+
239+
assert pd_histogram._stat_func.cost == "aggregate"
240+
assert pd_hs._stat_func.cost == "aggregate"
241+
assert pl_histogram_series._stat_func.cost == "aggregate"
242+
243+
# xorq histogram (optional dep — skip if not installed)
244+
try:
245+
from buckaroo.customizations.xorq_stats_v2 import histogram as xq_histogram
246+
except ImportError:
247+
return
248+
assert xq_histogram._stat_func.cost == "aggregate"
249+
207250

208251
class _MultiSizeStats(MultipleProvides):
209252
row_count: int

0 commit comments

Comments
 (0)