Skip to content

Commit 63bf655

Browse files
committed
Merge #809: process_table_scalars / _aggregates (Phase 2)
2 parents c46ecba + 935e5cd commit 63bf655

4 files changed

Lines changed: 213 additions & 6 deletions

File tree

buckaroo/pluggable_analysis_framework/stat_pipeline.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -234,13 +234,18 @@ def __init__(self, stat_funcs: list, unit_test: bool = True, record_timings: boo
234234
self._unit_test_result = self.unit_test()
235235

236236
def process_column(self, column_name: str, column_dtype, raw_series=None, sampled_series=None, raw_dataframe=None,
237-
initial_stats: Optional[Dict[str, Any]] = None) -> Tuple[Dict[str, Any], List[StatError]]:
237+
initial_stats: Optional[Dict[str, Any]] = None,
238+
cost_classes=None) -> Tuple[Dict[str, Any], List[StatError]]:
238239
"""Process a single column through the stat DAG.
239240
240241
1. Filters stat functions by column dtype
241-
2. Executes in topological order with Ok/Err accumulator
242-
3. Returns (plain_dict, errors)
242+
2. Filters by ``cost_classes`` (default: all costs) — used by
243+
the JS-driven progressive-stats router to run scalars only
244+
3. Executes in topological order with Ok/Err accumulator
245+
4. Returns (plain_dict, errors)
243246
"""
247+
if cost_classes is None:
248+
cost_classes = {"scalar", "aggregate"}
244249
# Build column-specific DAG (filters by dtype)
245250
external = set(self.EXTERNAL_KEYS)
246251
if initial_stats:
@@ -255,6 +260,8 @@ def process_column(self, column_name: str, column_dtype, raw_series=None, sample
255260
accumulator[k] = Ok(v)
256261
record_timings = self.record_timings
257262
for sf in column_funcs:
263+
if sf.cost not in cost_classes:
264+
continue
258265
if record_timings:
259266
t0 = time.perf_counter()
260267
_execute_stat_func(sf, accumulator, column_name, raw_series=raw_series, sampled_series=sampled_series,
@@ -270,9 +277,12 @@ def process_column(self, column_name: str, column_dtype, raw_series=None, sample
270277

271278
return resolve_accumulator(accumulator, column_name, col_key_to_func)
272279

273-
def process_df(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, List[StatError]]:
280+
def process_df(self, df: pd.DataFrame, debug: bool = False, cost_classes=None) -> Tuple[SDType, List[StatError]]:
274281
"""Process all columns of a DataFrame.
275282
283+
``cost_classes`` (default: all) restricts which stat funcs run.
284+
Used by the JS-driven progressive-stats router.
285+
276286
Returns:
277287
(summary_dict, all_errors) where summary_dict is SDType-compatible
278288
(column_name -> {stat_name -> value}).
@@ -292,13 +302,39 @@ def process_df(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, Lis
292302

293303
col_result, col_errors = self.process_column(column_name=rewritten_col_name, column_dtype=col_dtype,
294304
raw_series=ser, sampled_series=ser, raw_dataframe=df,
295-
initial_stats={'orig_col_name': orig_col_name, 'rewritten_col_name': rewritten_col_name})
305+
initial_stats={'orig_col_name': orig_col_name, 'rewritten_col_name': rewritten_col_name},
306+
cost_classes=cost_classes)
296307

297308
summary[rewritten_col_name] = col_result
298309
all_errors.extend(col_errors)
299310

300311
return summary, all_errors
301312

313+
def process_df_scalars(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, List[StatError]]:
314+
"""Run only cost=scalar stats — the fast path used by the
315+
JS-driven progressive-stats router. Histograms etc. are
316+
skipped."""
317+
return self.process_df(df, debug=debug, cost_classes={"scalar"})
318+
319+
def process_df_aggregates(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, List[StatError]]:
320+
"""Run the full pipeline, return only aggregate-cost stats.
321+
322+
Aggregates depend on scalar inputs, so the full pipeline runs;
323+
the response is filtered to ship just the aggregate provides
324+
(typically histograms). Caching scalars across this and
325+
``process_df_scalars`` is a future PR.
326+
"""
327+
summary, errs = self.process_df(df, debug=debug)
328+
agg_provides = {sk.name for sf in self.ordered_stat_funcs
329+
if sf.cost == "aggregate" for sk in sf.provides}
330+
filtered = {col: {k: v for k, v in stats.items() if k in agg_provides}
331+
for col, stats in summary.items()}
332+
agg_func_names = {sf.name for sf in self.ordered_stat_funcs
333+
if sf.cost == "aggregate"}
334+
filtered_errs = [e for e in errs
335+
if e.stat_func is not None and e.stat_func.name in agg_func_names]
336+
return filtered, filtered_errs
337+
302338
def process_df_v1_compat(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, ErrDict]:
303339
"""Process DataFrame with v1-compatible error format.
304340

buckaroo/pluggable_analysis_framework/xorq_stat_pipeline.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,48 @@ def unit_test(self) -> Tuple[bool, List[StatError]]:
165165
finally:
166166
self.backend = saved_backend
167167

168-
def process_table(self, table) -> Tuple[SDType, List[StatError]]:
168+
def process_table_scalars(self, table) -> Tuple[SDType, List[StatError]]:
169+
"""Run only cost=scalar stats — the fast path.
170+
171+
Filters out cost=aggregate funcs (histograms, per-column-query
172+
stats) before running the pipeline. Used by the JS-driven
173+
progressive-stats router to ship cheap stats immediately on
174+
state_change. See plans/js-driven-stat-debounce.md.
175+
176+
Aggregate stats that depend only on scalars (e.g. a downstream
177+
compute that consumes ``histogram``) get their inputs missing
178+
and produce ``Err`` upstream — that's the expected shape; the
179+
consumer should ignore those when asking for scalars only.
180+
"""
181+
return self.process_table(table, cost_classes={"scalar"})
182+
183+
def process_table_aggregates(self, table) -> Tuple[SDType, List[StatError]]:
184+
"""Run the full pipeline, return only aggregate-cost stats.
185+
186+
Aggregates typically depend on scalar inputs (e.g. ``histogram``
187+
consumes ``value_counts``, ``length``), so the full pipeline
188+
has to run. Output is filtered to just aggregate-cost provides
189+
so the response payload is small. Caching of the scalar half
190+
across this and ``process_table_scalars`` is a future PR; here
191+
the compute is whole, only the shipping is filtered.
192+
"""
193+
summary, errs = self.process_table(table)
194+
agg_provides = {sk.name for sf in self.ordered_stat_funcs
195+
if sf.cost == "aggregate" for sk in sf.provides}
196+
filtered = {col: {k: v for k, v in stats.items() if k in agg_provides}
197+
for col, stats in summary.items()}
198+
# Errors are per-stat-func; keep only those from aggregate funcs.
199+
agg_func_names = {sf.name for sf in self.ordered_stat_funcs
200+
if sf.cost == "aggregate"}
201+
filtered_errs = [e for e in errs
202+
if e.stat_func is not None and e.stat_func.name in agg_func_names]
203+
return filtered, filtered_errs
204+
205+
def process_table(self, table, cost_classes=None) -> Tuple[SDType, List[StatError]]:
206+
"""Run the pipeline; ``cost_classes`` filter restricts which
207+
stat funcs execute (default: all costs)."""
208+
if cost_classes is None:
209+
cost_classes = {"scalar", "aggregate"}
169210
schema = table.schema()
170211
columns = list(table.columns)
171212

@@ -188,6 +229,8 @@ def process_table(self, table) -> Tuple[SDType, List[StatError]]:
188229
for sf in self.ordered_stat_funcs:
189230
if not _is_batch_func(sf):
190231
continue
232+
if sf.cost not in cost_classes:
233+
continue
191234
xorq_col_param = next(r.name for r in sf.requires if r.type is XorqColumn)
192235
for col in columns:
193236
col_dtype = schema[col]
@@ -253,6 +296,10 @@ def process_table(self, table) -> Tuple[SDType, List[StatError]]:
253296
# (typically the batch-phase stats).
254297
if sf.provides and all(sk.name in col_accum for sk in sf.provides):
255298
continue
299+
# Cost-class filter — the JS-driven router runs scalars
300+
# first, aggregates after a debounce.
301+
if sf.cost not in cost_classes:
302+
continue
256303
_execute_stat_func(sf, col_accum, col, raw_series=None, sampled_series=None, raw_dataframe=None,
257304
xorq_expr=table, xorq_execute=self._execute)
258305

tests/unit/test_paf_v2.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,54 @@ def test_basic_pipeline(self):
628628
assert 'distinct_per' in pipeline.provided_summary_facts_set
629629
assert 'length' in pipeline.provided_summary_facts_set
630630

631+
def test_process_df_cost_class_filter_scalars(self):
632+
"""Scalars-only path: ``process_df_scalars`` runs only cost=scalar
633+
stats. Aggregate funcs (and stats that depend on them) are
634+
skipped. The fast path the JS router uses for the initial
635+
state_change response."""
636+
@stat(cost="aggregate")
637+
def expensive_stat(ser: RawSeries) -> int:
638+
return 999 # would be slow in real life
639+
640+
df = pd.DataFrame({'a': [1, 2, 3, 1, 2]})
641+
pipeline = StatPipeline([length, expensive_stat], unit_test=False)
642+
summary, errs = pipeline.process_df_scalars(df)
643+
644+
assert summary['a']['length'] == 5
645+
# The aggregate stat was filtered out — it did not run.
646+
assert 'expensive_stat' not in summary['a']
647+
648+
def test_process_df_cost_class_filter_aggregates(self):
649+
"""Aggregates-only path: ``process_df_aggregates`` runs the
650+
pipeline but ships only aggregate provides. Used by the JS
651+
router for the slow follow-up after a debounce."""
652+
@stat(cost="aggregate")
653+
def expensive_stat(ser: RawSeries) -> int:
654+
return 999
655+
656+
df = pd.DataFrame({'a': [1, 2, 3, 1, 2]})
657+
pipeline = StatPipeline([length, expensive_stat], unit_test=False)
658+
summary, errs = pipeline.process_df_aggregates(df)
659+
660+
# Only aggregate-cost stat is shipped.
661+
assert summary['a']['expensive_stat'] == 999
662+
# Scalar 'length' is computed but filtered out of the response.
663+
assert 'length' not in summary['a']
664+
665+
def test_process_df_default_runs_all_costs(self):
666+
"""Default ``process_df()`` (no ``cost_classes`` arg) runs all
667+
cost classes — back-compat for every existing caller."""
668+
@stat(cost="aggregate")
669+
def expensive_stat(ser: RawSeries) -> int:
670+
return 999
671+
672+
df = pd.DataFrame({'a': [1, 2, 3]})
673+
pipeline = StatPipeline([length, expensive_stat], unit_test=False)
674+
summary, errs = pipeline.process_df(df)
675+
676+
assert summary['a']['length'] == 3
677+
assert summary['a']['expensive_stat'] == 999
678+
631679
def test_process_column(self):
632680
pipeline = StatPipeline([length, distinct_count, distinct_per], unit_test=False)
633681
ser = pd.Series([1, 2, 3, 1, 2])

tests/unit/test_xorq_buckaroo_widget.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,82 @@ def test_expr_count_pandas_path_unaffected(self):
121121
assert xorq_buckaroo._expr_count(df) == 4
122122

123123

124+
class TestCostClassFilter:
125+
"""Phase 2 of the JS-driven progressive-stats router. The
126+
``process_table_scalars`` / ``process_table_aggregates`` entry
127+
points filter by ``StatFunc.cost`` so the JS orchestrator can
128+
fetch cheap stats immediately on state_change and slow stats
129+
after a debounce.
130+
"""
131+
132+
def test_scalars_only_skips_histogram_queries(self):
133+
"""``process_table_scalars`` skips the expensive per-column
134+
histogram path entirely. Spy on ``_execute`` and assert the
135+
query count drops vs the full pipeline."""
136+
from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import XorqStatPipeline
137+
from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2
138+
139+
full_queries: list = []
140+
scalar_queries: list = []
141+
orig = XorqStatPipeline._execute
142+
143+
def counting(self, q, sink):
144+
sink.append(q)
145+
return orig(self, q)
146+
147+
expr = _expr()
148+
149+
# Full pipeline reference count
150+
XorqStatPipeline._execute = lambda self, q: counting(self, q, full_queries)
151+
try:
152+
p1 = XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False)
153+
p1.process_table(expr)
154+
finally:
155+
XorqStatPipeline._execute = orig
156+
157+
# Scalars-only count — must be strictly fewer
158+
XorqStatPipeline._execute = lambda self, q: counting(self, q, scalar_queries)
159+
try:
160+
p2 = XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False)
161+
p2.process_table_scalars(expr)
162+
finally:
163+
XorqStatPipeline._execute = orig
164+
165+
assert len(scalar_queries) < len(full_queries), (
166+
f"scalars-only must issue fewer queries than full pipeline; "
167+
f"got scalars={len(scalar_queries)} full={len(full_queries)}")
168+
169+
def test_scalars_only_omits_histogram_key(self):
170+
"""``process_table_scalars`` output has no ``histogram`` key for
171+
any column — that stat is cost=aggregate."""
172+
from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import XorqStatPipeline
173+
from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2
174+
175+
pipeline = XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False)
176+
summary, errs = pipeline.process_table_scalars(_expr())
177+
for col, stats in summary.items():
178+
assert "histogram" not in stats, (
179+
f"col {col} has histogram in scalars-only output: "
180+
f"{list(stats.keys())}")
181+
182+
def test_aggregates_only_ships_only_aggregate_keys(self):
183+
"""``process_table_aggregates`` returns only aggregate-cost
184+
stat provides. Scalars are computed (as dependencies) but
185+
filtered out of the response."""
186+
from buckaroo.pluggable_analysis_framework.xorq_stat_pipeline import XorqStatPipeline
187+
from buckaroo.customizations.xorq_stats_v2 import XORQ_STATS_V2
188+
189+
pipeline = XorqStatPipeline(list(XORQ_STATS_V2), unit_test=False)
190+
summary, errs = pipeline.process_table_aggregates(_expr())
191+
for col, stats in summary.items():
192+
# The boston-style histogram stat is aggregate.
193+
if "histogram" in stats:
194+
# Other obvious scalar keys must NOT be there.
195+
assert "length" not in stats, (
196+
f"col {col} aggregate response leaks scalar 'length'")
197+
assert "min" not in stats
198+
199+
124200
class TestInstantiation:
125201
def test_smoke(self):
126202
XorqBuckarooWidget(_expr())

0 commit comments

Comments
 (0)