Skip to content

Commit 935e5cd

Browse files
paddymulclaude
andcommitted
feat(paf): process_table_scalars / _aggregates entry points
Phase 2 of plans/js-driven-stat-debounce.md. Splits the stat pipeline by cost class so the JS-driven progressive-stats router can run cheap stats immediately on state_change and expensive stats after a debounce. Server-side changes: - ``StatPipeline.process_df(cost_classes=None)`` — new kwarg filtering which stat funcs execute. Default ``None`` = run all costs (back-compat for every existing caller). - ``StatPipeline.process_df_scalars(df)`` — convenience wrapper for ``cost_classes={"scalar"}``. Histograms etc. skipped entirely. - ``StatPipeline.process_df_aggregates(df)`` — runs full pipeline (aggregates depend on scalar inputs), filters the response to just aggregate-cost provides. - ``XorqStatPipeline.process_table(cost_classes=None)`` + ``process_table_scalars()`` + ``process_table_aggregates()``. For xorq the scalars-only path skips the per-column query loop in ``process_table`` phase 2 — exactly the ~6.5 s of histogram queries that dominates the boston state_change latency. Test ``test_scalars_only_skips_histogram_queries`` asserts strictly fewer queries vs the full pipeline. This commit ships the API surface; the WS handler that calls these new entry points is the next PR (Phase 4 wiring). No behavior change for any existing caller — all default to ``cost_classes=None``. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 898414f commit 935e5cd

4 files changed

Lines changed: 213 additions & 6 deletions

File tree

buckaroo/pluggable_analysis_framework/stat_pipeline.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -234,13 +234,18 @@ def __init__(self, stat_funcs: list, unit_test: bool = True, record_timings: boo
234234
self._unit_test_result = self.unit_test()
235235

236236
def process_column(self, column_name: str, column_dtype, raw_series=None, sampled_series=None, raw_dataframe=None,
237-
initial_stats: Optional[Dict[str, Any]] = None) -> Tuple[Dict[str, Any], List[StatError]]:
237+
initial_stats: Optional[Dict[str, Any]] = None,
238+
cost_classes=None) -> Tuple[Dict[str, Any], List[StatError]]:
238239
"""Process a single column through the stat DAG.
239240
240241
1. Filters stat functions by column dtype
241-
2. Executes in topological order with Ok/Err accumulator
242-
3. Returns (plain_dict, errors)
242+
2. Filters by ``cost_classes`` (default: all costs) — used by
243+
the JS-driven progressive-stats router to run scalars only
244+
3. Executes in topological order with Ok/Err accumulator
245+
4. Returns (plain_dict, errors)
243246
"""
247+
if cost_classes is None:
248+
cost_classes = {"scalar", "aggregate"}
244249
# Build column-specific DAG (filters by dtype)
245250
external = set(self.EXTERNAL_KEYS)
246251
if initial_stats:
@@ -255,6 +260,8 @@ def process_column(self, column_name: str, column_dtype, raw_series=None, sample
255260
accumulator[k] = Ok(v)
256261
record_timings = self.record_timings
257262
for sf in column_funcs:
263+
if sf.cost not in cost_classes:
264+
continue
258265
if record_timings:
259266
t0 = time.perf_counter()
260267
_execute_stat_func(sf, accumulator, column_name, raw_series=raw_series, sampled_series=sampled_series,
@@ -270,9 +277,12 @@ def process_column(self, column_name: str, column_dtype, raw_series=None, sample
270277

271278
return resolve_accumulator(accumulator, column_name, col_key_to_func)
272279

273-
def process_df(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, List[StatError]]:
280+
def process_df(self, df: pd.DataFrame, debug: bool = False, cost_classes=None) -> Tuple[SDType, List[StatError]]:
274281
"""Process all columns of a DataFrame.
275282
283+
``cost_classes`` (default: all) restricts which stat funcs run.
284+
Used by the JS-driven progressive-stats router.
285+
276286
Returns:
277287
(summary_dict, all_errors) where summary_dict is SDType-compatible
278288
(column_name -> {stat_name -> value}).
@@ -292,13 +302,39 @@ def process_df(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, Lis
292302

293303
col_result, col_errors = self.process_column(column_name=rewritten_col_name, column_dtype=col_dtype,
294304
raw_series=ser, sampled_series=ser, raw_dataframe=df,
295-
initial_stats={'orig_col_name': orig_col_name, 'rewritten_col_name': rewritten_col_name})
305+
initial_stats={'orig_col_name': orig_col_name, 'rewritten_col_name': rewritten_col_name},
306+
cost_classes=cost_classes)
296307

297308
summary[rewritten_col_name] = col_result
298309
all_errors.extend(col_errors)
299310

300311
return summary, all_errors
301312

313+
def process_df_scalars(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, List[StatError]]:
314+
"""Run only cost=scalar stats — the fast path used by the
315+
JS-driven progressive-stats router. Histograms etc. are
316+
skipped."""
317+
return self.process_df(df, debug=debug, cost_classes={"scalar"})
318+
319+
def process_df_aggregates(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, List[StatError]]:
320+
"""Run the full pipeline, return only aggregate-cost stats.
321+
322+
Aggregates depend on scalar inputs, so the full pipeline runs;
323+
the response is filtered to ship just the aggregate provides
324+
(typically histograms). Caching scalars across this and
325+
``process_df_scalars`` is a future PR.
326+
"""
327+
summary, errs = self.process_df(df, debug=debug)
328+
agg_provides = {sk.name for sf in self.ordered_stat_funcs
329+
if sf.cost == "aggregate" for sk in sf.provides}
330+
filtered = {col: {k: v for k, v in stats.items() if k in agg_provides}
331+
for col, stats in summary.items()}
332+
agg_func_names = {sf.name for sf in self.ordered_stat_funcs
333+
if sf.cost == "aggregate"}
334+
filtered_errs = [e for e in errs
335+
if e.stat_func is not None and e.stat_func.name in agg_func_names]
336+
return filtered, filtered_errs
337+
302338
def process_df_v1_compat(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, ErrDict]:
303339
"""Process DataFrame with v1-compatible error format.
304340

buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,48 @@ def unit_test(self) -> Tuple[bool, List[StatError]]:
165165
finally:
166166
self.backend = saved_backend
167167

168-
def process_table(self, table) -> Tuple[SDType, List[StatError]]:
168+
def process_table_scalars(self, table) -> Tuple[SDType, List[StatError]]:
169+
"""Run only cost=scalar stats — the fast path.
170+
171+
Filters out cost=aggregate funcs (histograms, per-column-query
172+
stats) before running the pipeline. Used by the JS-driven
173+
progressive-stats router to ship cheap stats immediately on
174+
state_change. See plans/js-driven-stat-debounce.md.
175+
176+
Aggregate stats that depend only on scalars (e.g. a downstream
177+
compute that consumes ``histogram``) get their inputs missing
178+
and produce ``Err`` upstream — that's the expected shape; the
179+
consumer should ignore those when asking for scalars only.
180+
"""
181+
return self.process_table(table, cost_classes={"scalar"})
182+
183+
def process_table_aggregates(self, table) -> Tuple[SDType, List[StatError]]:
184+
"""Run the full pipeline, return only aggregate-cost stats.
185+
186+
Aggregates typically depend on scalar inputs (e.g. ``histogram``
187+
consumes ``value_counts``, ``length``), so the full pipeline
188+
has to run. Output is filtered to just aggregate-cost provides
189+
so the response payload is small. Caching of the scalar half
190+
across this and ``process_table_scalars`` is a future PR; here
191+
the compute is whole, only the shipping is filtered.
192+
"""
193+
summary, errs = self.process_table(table)
194+
agg_provides = {sk.name for sf in self.ordered_stat_funcs
195+
if sf.cost == "aggregate" for sk in sf.provides}
196+
filtered = {col: {k: v for k, v in stats.items() if k in agg_provides}
197+
for col, stats in summary.items()}
198+
# Errors are per-stat-func; keep only those from aggregate funcs.
199+
agg_func_names = {sf.name for sf in self.ordered_stat_funcs
200+
if sf.cost == "aggregate"}
201+
filtered_errs = [e for e in errs
202+
if e.stat_func is not None and e.stat_func.name in agg_func_names]
203+
return filtered, filtered_errs
204+
205+
def process_table(self, table, cost_classes=None) -> Tuple[SDType, List[StatError]]:
206+
"""Run the pipeline; ``cost_classes`` filter restricts which
207+
stat funcs execute (default: all costs)."""
208+
if cost_classes is None:
209+
cost_classes = {"scalar", "aggregate"}
169210
schema = table.schema()
170211
columns = list(table.columns)
171212

@@ -188,6 +229,8 @@ def process_table(self, table) -> Tuple[SDType, List[StatError]]:
188229
for sf in self.ordered_stat_funcs:
189230
if not _is_batch_func(sf):
190231
continue
232+
if sf.cost not in cost_classes:
233+
continue
191234
xorq_col_param = next(r.name for r in sf.requires if r.type is XorqColumn)
192235
for col in columns:
193236
col_dtype = schema[col]
@@ -253,6 +296,10 @@ def process_table(self, table) -> Tuple[SDType, List[StatError]]:
253296
# (typically the batch-phase stats).
254297
if sf.provides and all(sk.name in col_accum for sk in sf.provides):
255298
continue
299+
# Cost-class filter — the JS-driven router runs scalars
300+
# first, aggregates after a debounce.
301+
if sf.cost not in cost_classes:
302+
continue
256303
_execute_stat_func(sf, col_accum, col, raw_series=None, sampled_series=None, raw_dataframe=None,
257304
xorq_expr=table, xorq_execute=self._execute)
258305

tests/unit/test_paf_v2.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,54 @@ def test_basic_pipeline(self):
600600
assert 'distinct_per' in pipeline.provided_summary_facts_set
601601
assert 'length' in pipeline.provided_summary_facts_set
602602

603+
def test_process_df_cost_class_filter_scalars(self):
604+
"""Scalars-only path: ``process_df_scalars`` runs only cost=scalar
605+
stats. Aggregate funcs (and stats that depend on them) are
606+
skipped. The fast path the JS router uses for the initial
607+
state_change response."""
608+
@stat(cost="aggregate")
609+
def expensive_stat(ser: RawSeries) -> int:
610+
return 999 # would be slow in real life
611+
612+
df = pd.DataFrame({'a': [1, 2, 3, 1, 2]})
613+
pipeline = StatPipeline([length, expensive_stat], unit_test=False)
614+
summary, errs = pipeline.process_df_scalars(df)
615+
616+
assert summary['a']['length'] == 5
617+
# The aggregate stat was filtered out — it did not run.
618+
assert 'expensive_stat' not in summary['a']
619+
620+
def test_process_df_cost_class_filter_aggregates(self):
621+
"""Aggregates-only path: ``process_df_aggregates`` runs the
622+
pipeline but ships only aggregate provides. Used by the JS
623+
router for the slow follow-up after a debounce."""
624+
@stat(cost="aggregate")
625+
def expensive_stat(ser: RawSeries) -> int:
626+
return 999
627+
628+
df = pd.DataFrame({'a': [1, 2, 3, 1, 2]})
629+
pipeline = StatPipeline([length, expensive_stat], unit_test=False)
630+
summary, errs = pipeline.process_df_aggregates(df)
631+
632+
# Only aggregate-cost stat is shipped.
633+
assert summary['a']['expensive_stat'] == 999
634+
# Scalar 'length' is computed but filtered out of the response.
635+
assert 'length' not in summary['a']
636+
637+
def test_process_df_default_runs_all_costs(self):
638+
"""Default ``process_df()`` (no ``cost_classes`` arg) runs all
639+
cost classes — back-compat for every existing caller."""
640+
@stat(cost="aggregate")
641+
def expensive_stat(ser: RawSeries) -> int:
642+
return 999
643+
644+
df = pd.DataFrame({'a': [1, 2, 3]})
645+
pipeline = StatPipeline([length, expensive_stat], unit_test=False)
646+
summary, errs = pipeline.process_df(df)
647+
648+
assert summary['a']['length'] == 3
649+
assert summary['a']['expensive_stat'] == 999
650+
603651
def test_process_column(self):
604652
pipeline = StatPipeline([length, distinct_count, distinct_per], unit_test=False)
605653
ser = pd.Series([1, 2, 3, 1, 2])

tests/unit/test_xorq_buckaroo_widget.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,82 @@ def counting(self, q):
6060
)
6161

6262

63+
class TestCostClassFilter:
64+
"""Phase 2 of the JS-driven progressive-stats router. The
65+
``process_table_scalars`` / ``process_table_aggregates`` entry
66+
points filter by ``StatFunc.cost`` so the JS orchestrator can
67+
fetch cheap stats immediately on state_change and slow stats
68+
after a debounce.
69+
"""
70+
71+
def test_scalars_only_skips_histogram_queries(self):
72+
"""``process_table_scalars`` skips the expensive per-column
73+
histogram path entirely. Spy on ``_execute`` and assert the
74+
query count drops vs the full pipeline."""
75+
from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import XorqStatPipeline
76+
from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2
77+
78+
full_queries: list = []
79+
scalar_queries: list = []
80+
orig = XorqStatPipeline._execute
81+
82+
def counting(self, q, sink):
83+
sink.append(q)
84+
return orig(self, q)
85+
86+
expr = _expr()
87+
88+
# Full pipeline reference count
89+
XorqStatPipeline._execute = lambda self, q: counting(self, q, full_queries)
90+
try:
91+
p1 = XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False)
92+
p1.process_table(expr)
93+
finally:
94+
XorqStatPipeline._execute = orig
95+
96+
# Scalars-only count — must be strictly fewer
97+
XorqStatPipeline._execute = lambda self, q: counting(self, q, scalar_queries)
98+
try:
99+
p2 = XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False)
100+
p2.process_table_scalars(expr)
101+
finally:
102+
XorqStatPipeline._execute = orig
103+
104+
assert len(scalar_queries) < len(full_queries), (
105+
f"scalars-only must issue fewer queries than full pipeline; "
106+
f"got scalars={len(scalar_queries)} full={len(full_queries)}")
107+
108+
def test_scalars_only_omits_histogram_key(self):
109+
"""``process_table_scalars`` output has no ``histogram`` key for
110+
any column — that stat is cost=aggregate."""
111+
from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import XorqStatPipeline
112+
from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2
113+
114+
pipeline = XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False)
115+
summary, errs = pipeline.process_table_scalars(_expr())
116+
for col, stats in summary.items():
117+
assert "histogram" not in stats, (
118+
f"col {col} has histogram in scalars-only output: "
119+
f"{list(stats.keys())}")
120+
121+
def test_aggregates_only_ships_only_aggregate_keys(self):
122+
"""``process_table_aggregates`` returns only aggregate-cost
123+
stat provides. Scalars are computed (as dependencies) but
124+
filtered out of the response."""
125+
from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import XorqStatPipeline
126+
from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2
127+
128+
pipeline = XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False)
129+
summary, errs = pipeline.process_table_aggregates(_expr())
130+
for col, stats in summary.items():
131+
# The boston-style histogram stat is aggregate.
132+
if "histogram" in stats:
133+
# Other obvious scalar keys must NOT be there.
134+
assert "length" not in stats, (
135+
f"col {col} aggregate response leaks scalar 'length'")
136+
assert "min" not in stats
137+
138+
63139
class TestInstantiation:
64140
def test_smoke(self):
65141
XorqBuckarooWidget(_expr())

0 commit comments

Comments
 (0)