@@ -234,13 +234,18 @@ def __init__(self, stat_funcs: list, unit_test: bool = True, record_timings: boo
234234 self ._unit_test_result = self .unit_test ()
235235
236236 def process_column (self , column_name : str , column_dtype , raw_series = None , sampled_series = None , raw_dataframe = None ,
237- initial_stats : Optional [Dict [str , Any ]] = None ) -> Tuple [Dict [str , Any ], List [StatError ]]:
237+ initial_stats : Optional [Dict [str , Any ]] = None ,
238+ cost_classes = None ) -> Tuple [Dict [str , Any ], List [StatError ]]:
238239 """Process a single column through the stat DAG.
239240
240241 1. Filters stat functions by column dtype
241- 2. Executes in topological order with Ok/Err accumulator
242- 3. Returns (plain_dict, errors)
242+ 2. Filters by ``cost_classes`` (default: all costs) — used by
243+ the JS-driven progressive-stats router to run scalars only
244+ 3. Executes in topological order with Ok/Err accumulator
245+ 4. Returns (plain_dict, errors)
243246 """
247+ if cost_classes is None :
248+ cost_classes = {"scalar" , "aggregate" }
244249 # Build column-specific DAG (filters by dtype)
245250 external = set (self .EXTERNAL_KEYS )
246251 if initial_stats :
@@ -255,6 +260,8 @@ def process_column(self, column_name: str, column_dtype, raw_series=None, sample
255260 accumulator [k ] = Ok (v )
256261 record_timings = self .record_timings
257262 for sf in column_funcs :
263+ if sf .cost not in cost_classes :
264+ continue
258265 if record_timings :
259266 t0 = time .perf_counter ()
260267 _execute_stat_func (sf , accumulator , column_name , raw_series = raw_series , sampled_series = sampled_series ,
@@ -270,9 +277,12 @@ def process_column(self, column_name: str, column_dtype, raw_series=None, sample
270277
271278 return resolve_accumulator (accumulator , column_name , col_key_to_func )
272279
273- def process_df (self , df : pd .DataFrame , debug : bool = False ) -> Tuple [SDType , List [StatError ]]:
280+ def process_df (self , df : pd .DataFrame , debug : bool = False , cost_classes = None ) -> Tuple [SDType , List [StatError ]]:
274281 """Process all columns of a DataFrame.
275282
283+ ``cost_classes`` (default: all) restricts which stat funcs run.
284+ Used by the JS-driven progressive-stats router.
285+
276286 Returns:
277287 (summary_dict, all_errors) where summary_dict is SDType-compatible
278288 (column_name -> {stat_name -> value}).
@@ -292,13 +302,39 @@ def process_df(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, Lis
292302
293303 col_result , col_errors = self .process_column (column_name = rewritten_col_name , column_dtype = col_dtype ,
294304 raw_series = ser , sampled_series = ser , raw_dataframe = df ,
295- initial_stats = {'orig_col_name' : orig_col_name , 'rewritten_col_name' : rewritten_col_name })
305+ initial_stats = {'orig_col_name' : orig_col_name , 'rewritten_col_name' : rewritten_col_name },
306+ cost_classes = cost_classes )
296307
297308 summary [rewritten_col_name ] = col_result
298309 all_errors .extend (col_errors )
299310
300311 return summary , all_errors
301312
def process_df_scalars(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, List[StatError]]:
    """Process ``df`` with only the ``cost == "scalar"`` stat funcs enabled.

    This is the fast path used by the JS-driven progressive-stats
    router: the call is forwarded to ``process_df`` with
    ``cost_classes`` restricted to ``{"scalar"}``, so stat funcs of any
    other cost class (histograms etc.) are skipped.

    Args:
        df: DataFrame to summarize.
        debug: Forwarded unchanged to ``process_df``.

    Returns:
        Whatever ``process_df`` returns for the scalar-only run, i.e. a
        ``(summary_dict, errors)`` tuple.
    """
    scalar_only = {"scalar"}
    return self.process_df(df, debug=debug, cost_classes=scalar_only)
318+
def process_df_aggregates(self, df: pd.DataFrame, debug: bool = False) -> Tuple[SDType, List[StatError]]:
    """Run the full pipeline and return only the aggregate-cost results.

    Aggregates depend on scalar inputs, so ``process_df`` is invoked
    without a cost restriction and its output is trimmed afterwards to
    what the ``cost == "aggregate"`` stat funcs provide (typically
    histograms). Caching scalars across this and ``process_df_scalars``
    is a future PR.

    Args:
        df: DataFrame to summarize.
        debug: Forwarded unchanged to ``process_df``.

    Returns:
        ``(summary, errors)`` where ``summary`` keeps every column but
        only the stat keys provided by aggregate-cost funcs, and
        ``errors`` keeps only errors whose ``stat_func`` is one of those
        funcs. Errors with ``stat_func`` set to ``None`` are dropped.
    """
    full_summary, all_errors = self.process_df(df, debug=debug)

    # One filtered list of aggregate funcs feeds both lookup sets below.
    aggregate_funcs = [func for func in self.ordered_stat_funcs
                       if func.cost == "aggregate"]
    aggregate_keys = set()
    aggregate_names = set()
    for func in aggregate_funcs:
        aggregate_names.add(func.name)
        aggregate_keys.update(key.name for key in func.provides)

    # Every column stays in the result, even when no aggregate stat
    # survives the filter (its entry is then an empty dict).
    aggregate_summary = {}
    for column, column_stats in full_summary.items():
        aggregate_summary[column] = {
            stat: value
            for stat, value in column_stats.items()
            if stat in aggregate_keys
        }

    aggregate_errors = []
    for error in all_errors:
        origin = error.stat_func
        if origin is None:
            continue
        if origin.name in aggregate_names:
            aggregate_errors.append(error)

    return aggregate_summary, aggregate_errors
337+
302338 def process_df_v1_compat (self , df : pd .DataFrame , debug : bool = False ) -> Tuple [SDType , ErrDict ]:
303339 """Process DataFrame with v1-compatible error format.
304340
0 commit comments