1919 EvalMetadata ,
2020 EvaluationRow ,
2121 EvaluationThreshold ,
22+ EvaluationThresholdDict ,
2223 InputMetadata ,
2324 Message ,
2425 Status ,
2526)
27+ from eval_protocol .pytest .parameterize import pytest_parametrize
28+ from eval_protocol .pytest .validate_signature import validate_signature
2629from eval_protocol .pytest .default_dataset_adapter import default_dataset_adapter
2730from eval_protocol .pytest .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
2831from eval_protocol .pytest .default_no_op_rollout_processor import NoOpRolloutProcessor
3841 RolloutProcessorInputParam ,
3942 TestFunction ,
4043)
44+
45+
4146from eval_protocol .pytest .utils import (
4247 AggregationMethod ,
4348 aggregate ,
@@ -237,15 +242,15 @@ def postprocess(
237242def evaluation_test (
238243 * ,
239244 completion_params : list [CompletionParams | None ] | None = None ,
240- input_messages : list [InputMessagesParam ] | None = None ,
245+ input_messages : list [InputMessagesParam | None ] | None = None ,
241246 input_dataset : list [DatasetPathParam ] | None = None ,
242247 input_rows : list [EvaluationRow ] | None = None ,
243248 dataset_adapter : Callable [[list [dict [str , Any ]]], Dataset ] = default_dataset_adapter , # pyright: ignore[reportExplicitAny]
244249 rollout_processor : RolloutProcessor | None = None ,
245- evaluation_test_kwargs : list [EvaluationInputParam ] | None = None ,
250+ evaluation_test_kwargs : list [EvaluationInputParam | None ] | None = None ,
246251 rollout_processor_kwargs : RolloutProcessorInputParam | None = None ,
247252 aggregation_method : AggregationMethod = "mean" ,
248- passed_threshold : EvaluationThreshold | float | dict [ str , Any ] | None = None , # pyright: ignore[reportExplicitAny]
253+ passed_threshold : EvaluationThreshold | float | EvaluationThresholdDict | None = None ,
249254 num_runs : int = 1 ,
250255 max_dataset_rows : int | None = None ,
251256 mcp_config_path : str | None = None ,
@@ -257,10 +262,7 @@ def evaluation_test(
257262 combine_datasets : bool = True ,
258263 logger : DatasetLogger | None = None ,
259264 exception_handler_config : ExceptionHandlerConfig | None = None ,
260- ) -> Callable [
261- [TestFunction ],
262- TestFunction ,
263- ]:
265+ ) -> Callable [[TestFunction ], TestFunction ]:
264266 """Decorator to create pytest-based evaluation tests.
265267
266268 Here are some key concepts to understand the terminology in EP:
@@ -328,6 +330,10 @@ def evaluation_test(
328330 exception_handler_config: Configuration for exception handling and backoff retry logic.
329331 If not provided, a default configuration will be used with common retryable exceptions.
330332 """
333+ if completion_params is None :
334+ completion_params = [None ]
335+ if rollout_processor is None :
336+ rollout_processor = NoOpRolloutProcessor ()
331337
332338 active_logger : DatasetLogger = logger if logger else default_logger
333339
@@ -337,148 +343,40 @@ def evaluation_test(
337343 num_runs = parse_ep_num_runs (num_runs )
338344 max_concurrent_rollouts = parse_ep_max_concurrent_rollouts (max_concurrent_rollouts )
339345 max_dataset_rows = parse_ep_max_rows (max_dataset_rows )
340- if completion_params is None :
341- completion_params = [None ]
342- if rollout_processor is None :
343- rollout_processor = NoOpRolloutProcessor ()
344346 completion_params = parse_ep_completion_params (completion_params )
345347 original_completion_params = completion_params
346348 passed_threshold = parse_ep_passed_threshold (passed_threshold )
347349
348350 def decorator (
349351 test_func : TestFunction ,
350- ):
351- if passed_threshold is not None :
352- if isinstance (passed_threshold , float ):
353- threshold = EvaluationThreshold (success = passed_threshold )
354- else :
355- threshold = EvaluationThreshold (** passed_threshold )
356- else :
357- threshold = None
358-
352+ ) -> TestFunction :
359353 sig = inspect .signature (test_func )
360-
361- # For pointwise/groupwise mode, we expect a different signature
362- # we expect single row to be passed in as the original row
363- if mode == "pointwise" :
364- # Pointwise mode: function should accept messages and other row-level params
365- if "row" not in sig .parameters :
366- raise ValueError ("In pointwise mode, your eval function must have a parameter named 'row'" )
367-
368- # validate that "Row" is of type EvaluationRow
369- if sig .parameters ["row" ].annotation is not EvaluationRow :
370- raise ValueError ("In pointwise mode, the 'row' parameter must be of type EvaluationRow" )
371-
372- # validate that the function has a return type of EvaluationRow
373- if sig .return_annotation is not EvaluationRow :
374- raise ValueError ("In pointwise mode, your eval function must return an EvaluationRow instance" )
375-
376- # additional check for groupwise evaluation
377- elif mode == "groupwise" :
378- if "rows" not in sig .parameters :
379- raise ValueError ("In groupwise mode, your eval function must have a parameter named 'rows'" )
380-
381- # validate that "Rows" is of type List[EvaluationRow]
382- if sig .parameters ["rows" ].annotation is not List [EvaluationRow ]:
383- raise ValueError ("In groupwise mode, the 'rows' parameter must be of type List[EvaluationRow" )
384-
385- # validate that the function has a return type of List[EvaluationRow]
386- if sig .return_annotation is not List [EvaluationRow ]:
387- raise ValueError ("In groupwise mode, your eval function must return a list of EvaluationRow instances" )
388- if len (completion_params ) < 2 :
389- raise ValueError ("In groupwise mode, you must provide at least 2 completion parameters" )
390- else :
391- # all mode: function should accept input_dataset and model
392- if "rows" not in sig .parameters :
393- raise ValueError ("In all mode, your eval function must have a parameter named 'rows'" )
394-
395- # validate that "Rows" is of type List[EvaluationRow]
396- if sig .parameters ["rows" ].annotation is not List [EvaluationRow ]:
397- raise ValueError ("In all mode, the 'rows' parameter must be of type List[EvaluationRow" )
398-
399- # validate that the function has a return type of List[EvaluationRow]
400- if sig .return_annotation is not List [EvaluationRow ]:
401- raise ValueError ("In all mode, your eval function must return a list of EvaluationRow instances" )
402-
403- async def execute_with_params (
404- test_func : TestFunction ,
405- processed_row : EvaluationRow | None = None ,
406- processed_dataset : List [EvaluationRow ] | None = None ,
407- evaluation_test_kwargs : Optional [EvaluationInputParam ] = None ,
408- ):
409- kwargs = {}
410- if processed_dataset is not None :
411- kwargs ["rows" ] = processed_dataset
412- if processed_row is not None :
413- kwargs ["row" ] = processed_row
414- if evaluation_test_kwargs is not None :
415- if "row" in evaluation_test_kwargs :
416- raise ValueError ("'row' is a reserved parameter for the evaluation function" )
417- if "rows" in evaluation_test_kwargs :
418- raise ValueError ("'rows' is a reserved parameter for the evaluation function" )
419- kwargs .update (evaluation_test_kwargs )
420-
421- # Handle both sync and async test functions
422- if asyncio .iscoroutinefunction (test_func ):
423- return await test_func (** kwargs )
424- else :
425- return test_func (** kwargs )
354+ validate_signature (sig , mode , completion_params )
426355
427356 # Calculate all possible combinations of parameters
428- if mode == "groupwise" :
429- combinations = generate_parameter_combinations (
430- input_dataset ,
431- completion_params ,
432- input_messages ,
433- input_rows ,
434- evaluation_test_kwargs ,
435- max_dataset_rows ,
436- combine_datasets ,
437- )
438- else :
439- combinations = generate_parameter_combinations (
440- input_dataset ,
441- completion_params ,
442- input_messages ,
443- input_rows ,
444- evaluation_test_kwargs ,
445- max_dataset_rows ,
446- combine_datasets ,
447- )
357+ combinations = generate_parameter_combinations (
358+ input_dataset ,
359+ completion_params ,
360+ input_messages ,
361+ input_rows ,
362+ evaluation_test_kwargs ,
363+ max_dataset_rows ,
364+ combine_datasets ,
365+ )
448366 if len (combinations ) == 0 :
449367 raise ValueError (
450368 "No combinations of parameters were found. Please provide at least a model and one of input_dataset, input_messages, or input_rows."
451369 )
452370
453371 # Create parameter tuples for pytest.mark.parametrize
454- param_tuples = []
455- for combo in combinations :
456- dataset , cp , messages , rows , etk = combo
457- param_tuple = []
458- if input_dataset is not None :
459- param_tuple .append (dataset )
460- if completion_params is not None :
461- param_tuple .append (cp )
462- if input_messages is not None :
463- param_tuple .append (messages )
464- if input_rows is not None :
465- param_tuple .append (rows )
466- if evaluation_test_kwargs is not None :
467- param_tuple .append (etk )
468- param_tuples .append (tuple (param_tuple ))
469-
470- # For all mode, preserve the original parameter names
471- test_param_names = []
472- if input_dataset is not None :
473- test_param_names .append ("dataset_path" )
474- if completion_params is not None :
475- test_param_names .append ("completion_params" )
476- if input_messages is not None :
477- test_param_names .append ("input_messages" )
478- if input_rows is not None :
479- test_param_names .append ("input_rows" )
480- if evaluation_test_kwargs is not None :
481- test_param_names .append ("evaluation_test_kwargs" )
372+ pytest_parametrize_args = pytest_parametrize (
373+ combinations ,
374+ input_dataset ,
375+ completion_params ,
376+ input_messages ,
377+ input_rows ,
378+ evaluation_test_kwargs ,
379+ )
482380
483381 # Create wrapper function with exact signature that pytest expects
484382 def create_wrapper_with_signature () -> Callable :
@@ -613,7 +511,7 @@ async def _execute_eval_with_semaphore(**inner_kwargs):
613511 # NOTE: we will still evaluate errored rows (give users control over this)
614512 # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
615513 if "row" in inner_kwargs :
616- result = await execute_with_params (
514+ result = await execute_pytest (
617515 test_func ,
618516 processed_row = inner_kwargs ["row" ],
619517 evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {},
@@ -624,7 +522,7 @@ async def _execute_eval_with_semaphore(**inner_kwargs):
624522 )
625523 return result
626524 if "rows" in inner_kwargs :
627- results = await execute_with_params (
525+ results = await execute_pytest (
628526 test_func ,
629527 processed_dataset = inner_kwargs ["rows" ],
630528 evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {},
@@ -696,7 +594,7 @@ async def _collect_result(config, lst):
696594 input_dataset .append (row )
697595 # NOTE: we will still evaluate errored rows (give users control over this)
698596 # i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
699- results = await execute_with_params (
597+ results = await execute_pytest (
700598 test_func ,
701599 processed_dataset = input_dataset ,
702600 evaluation_test_kwargs = kwargs .get ("evaluation_test_kwargs" ) or {},
@@ -795,7 +693,7 @@ async def _collect_result(config, lst):
795693
796694 # Create the pytest wrapper
797695 pytest_wrapper = create_wrapper_with_signature ()
798- pytest_wrapper = pytest .mark .parametrize (test_param_names , param_tuples )(pytest_wrapper )
696+ pytest_wrapper = pytest .mark .parametrize (** pytest_parametrize_args )(pytest_wrapper )
799697 pytest_wrapper = pytest .mark .asyncio (pytest_wrapper )
800698
801699 def create_dual_mode_wrapper () -> Callable :
0 commit comments