11import inspect
22import os
3- import os
3+ import copy
4+ import math
5+ import statistics
46from typing import Any , Callable , Dict , List , Optional
57
68import pytest
@@ -91,11 +93,11 @@ def decorator(
9193 if mode == "pointwise" :
9294 # Pointwise mode: function should accept messages and other row-level params
9395 if "row" not in sig .parameters :
94- raise ValueError (f "In pointwise mode, your eval function must have a parameter named 'row'" )
96+ raise ValueError ("In pointwise mode, your eval function must have a parameter named 'row'" )
9597
9698 # validate that "Row" is of type EvaluationRow
9799 if sig .parameters ["row" ].annotation is not EvaluationRow :
98- raise ValueError (f "In pointwise mode, the 'row' parameter must be of type EvaluationRow" )
100+ raise ValueError ("In pointwise mode, the 'row' parameter must be of type EvaluationRow" )
99101
100102 # validate that the function has a return type of EvaluationRow
101103 if sig .return_annotation is not EvaluationRow :
@@ -107,7 +109,7 @@ def decorator(
107109
108110 # validate that "Rows" is of type List[EvaluationRow]
109111 if sig .parameters ["rows" ].annotation is not List [EvaluationRow ]:
110- raise ValueError (f "In batch mode, the 'rows' parameter must be of type List[EvaluationRow] " )
112+ raise ValueError ("In batch mode, the 'rows' parameter must be of type List[EvaluationRow" )
111113
112114 # validate that the function has a return type of List[EvaluationRow]
113115 if sig .return_annotation is not List [EvaluationRow ]:
@@ -150,7 +152,13 @@ def generate_combinations():
150152 combinations = []
151153
152154 # Handle optional parameters with defaults
153- datasets : List [Optional [DatasetPathParam ]] = input_dataset if input_dataset is not None else [None ] # type: ignore
155+ # Treat multiple dataset paths as a single combined dataset rather than
156+ # parameterizing over each path separately. This produces one summary
157+ # that reflects the aggregate of all provided files (e.g., AIME I+II).
158+ if input_dataset is not None :
159+ datasets : List [Optional [List [DatasetPathParam ]]] = [input_dataset ] # type: ignore
160+ else :
161+ datasets = [None ]
154162 params : List [Optional [RolloutInputParam ]] = rollout_input_params if rollout_input_params is not None else [None ] # type: ignore
155163 # Apply EP_MAX_DATASET_ROWS to input_messages to uniformly control row count when messages are provided
156164 if input_messages is not None and isinstance (input_messages , list ):
@@ -222,7 +230,15 @@ def wrapper_body(**kwargs):
222230 # Handle dataset loading
223231 data : List [EvaluationRow ] = []
224232 if "dataset_path" in kwargs and kwargs ["dataset_path" ] is not None :
225- data_jsonl = load_jsonl (kwargs ["dataset_path" ])
233+ ds_arg = kwargs ["dataset_path" ]
234+ # Support either a single path or a list of paths; if a list is provided,
235+ # concatenate the rows from each file in order.
236+ if isinstance (ds_arg , list ):
237+ data_jsonl = []
238+ for p in ds_arg :
239+ data_jsonl .extend (load_jsonl (p ))
240+ else :
241+ data_jsonl = load_jsonl (ds_arg )
226242 # Apply env override for max rows if present
227243 effective_max_rows = _parse_ep_max_rows (max_dataset_rows )
228244 if effective_max_rows is not None :
@@ -270,7 +286,7 @@ def wrapper_body(**kwargs):
270286 row .pid = os .getpid ()
271287 default_logger .log (row )
272288
273- # Now run the rollout processor with metadata-initialized data
289+ # Prepare rollout processor config once; we will generate fresh outputs per run
274290 config = RolloutProcessorConfig (
275291 model = model_name ,
276292 input_params = input_params ,
@@ -279,9 +295,12 @@ def wrapper_body(**kwargs):
279295 server_script_path = server_script_path ,
280296 steps = steps ,
281297 )
282- input_dataset = execute_function (rollout_processor , rows = data , config = config )
283298
284299 for _ in range (num_runs ):
300+ # Regenerate outputs each run by deep-copying the pristine dataset
301+ # so model responses are not reused across runs.
302+ fresh_rows = [copy .deepcopy (r ) for r in data ]
303+ input_dataset = execute_function (rollout_processor , rows = fresh_rows , config = config )
285304 if mode == "pointwise" :
286305 # Pointwise mode: apply the evaluator function to each row
287306 for row in input_dataset :
@@ -323,6 +342,23 @@ def wrapper_body(**kwargs):
323342 scores = [r .evaluation_result .score for r in all_results if r .evaluation_result ]
324343 agg_score = aggregate (scores , aggregation_method )
325344
345+ # Compute 95% confidence interval for mean aggregation
346+ # TODO bchen: remove after Derek has his stuff
347+ ci_low : float | None = None
348+ ci_high : float | None = None
349+ if aggregation_method == "mean" :
350+ n = len (scores )
351+ if n >= 2 :
352+ try :
353+ sample_std = statistics .stdev (scores )
354+ se = sample_std / math .sqrt (n )
355+ margin = 1.96 * se
356+ ci_low = float (max (0.0 , (agg_score or 0.0 ) - margin )) if agg_score is not None else None
357+ ci_high = float (min (1.0 , (agg_score or 0.0 ) + margin )) if agg_score is not None else None
358+ except Exception :
359+ ci_low = None
360+ ci_high = None
361+
326362 # Determine if the evaluation passed based on threshold
327363 passed = None
328364 if threshold_of_success is not None :
@@ -335,6 +371,86 @@ def wrapper_body(**kwargs):
335371 r .eval_metadata .passed = passed
336372 default_logger .log (r )
337373
374+ # Optional: print and/or persist a summary artifact for CI
375+ try :
376+ should_print = os .getenv ("EP_PRINT_SUMMARY" ) == "1"
377+ summary_path = os .getenv ("EP_SUMMARY_JSON" )
378+ suite_name = test_func .__name__
379+ model_used = model_name
380+ total_rows = len (all_results )
381+ summary_obj = {
382+ "suite" : suite_name ,
383+ "model" : model_used ,
384+ "agg_score" : float (agg_score ) if agg_score is not None else None ,
385+ "num_runs" : num_runs ,
386+ "rows" : total_rows ,
387+ }
388+ if ci_low is not None and ci_high is not None :
389+ summary_obj ["agg_ci_low" ] = ci_low
390+ summary_obj ["agg_ci_high" ] = ci_high
391+
392+ # Aggregate per-metric mean and 95% CI when available
393+ metrics_summary : Dict [str , Dict [str , float ]] = {}
394+ from collections import defaultdict
395+ metric_scores : Dict [str , list ] = defaultdict (list )
396+ for r in all_results :
397+ if r .evaluation_result and r .evaluation_result .metrics :
398+ for m_name , m_res in r .evaluation_result .metrics .items ():
399+ if m_res is not None and getattr (m_res , "score" , None ) is not None :
400+ metric_scores [m_name ].append (m_res .score )
401+ for m_name , vals in metric_scores .items ():
402+ if len (vals ) == 0 :
403+ continue
404+ m_mean = sum (vals ) / len (vals )
405+ m_low = None
406+ m_high = None
407+ if len (vals ) >= 2 :
408+ try :
409+ m_std = statistics .stdev (vals )
410+ m_se = m_std / math .sqrt (len (vals ))
411+ m_margin = 1.96 * m_se
412+ m_low = max (0.0 , m_mean - m_margin )
413+ m_high = min (1.0 , m_mean + m_margin )
414+ except Exception :
415+ m_low = None
416+ m_high = None
417+ entry : Dict [str , float ] = {"mean" : float (m_mean )}
418+ if m_low is not None and m_high is not None :
419+ entry ["ci_low" ] = float (m_low )
420+ entry ["ci_high" ] = float (m_high )
421+ metrics_summary [m_name ] = entry
422+ if metrics_summary :
423+ summary_obj ["metrics_agg" ] = metrics_summary
424+ if should_print :
425+ if ci_low is not None and ci_high is not None :
426+ print (
427+ f"EP Summary | suite={ suite_name } model={ model_used } agg={ summary_obj ['agg_score' ]:.3f} ci95=[{ ci_low :.3f} ,{ ci_high :.3f} ] runs={ num_runs } rows={ total_rows } "
428+ )
429+ else :
430+ print (
431+ f"EP Summary | suite={ suite_name } model={ model_used } agg={ summary_obj ['agg_score' ]:.3f} runs={ num_runs } rows={ total_rows } "
432+ )
433+ # Print per-metric aggregations concisely (only names present)
434+ if metrics_summary :
435+ parts = []
436+ for m_name , entry in metrics_summary .items ():
437+ if "ci_low" in entry and "ci_high" in entry :
438+ parts .append (f"{ m_name } ={ entry ['mean' ]:.3f} ci95=[{ entry ['ci_low' ]:.3f} ,{ entry ['ci_high' ]:.3f} ]" )
439+ else :
440+ parts .append (f"{ m_name } ={ entry ['mean' ]:.3f} " )
441+ print (f"EP Metrics | " + ", " .join (parts ))
442+ if summary_path :
443+ import json , pathlib , time
444+
445+ p = pathlib .Path (summary_path )
446+ p .parent .mkdir (parents = True , exist_ok = True )
447+ summary_obj ["timestamp" ] = int (time .time ())
448+ with p .open ("w" , encoding = "utf-8" ) as f :
449+ json .dump (summary_obj , f )
450+ except Exception :
451+ # Do not fail evaluation if summary writing fails
452+ pass
453+
338454 # Check threshold after logging
339455 if threshold_of_success is not None and not passed :
340456 assert (
0 commit comments