Skip to content

Commit 3202461

Browse files
author
Dylan Huang
committed
part 2
1 parent 16300bb commit 3202461

8 files changed

Lines changed: 282 additions & 200 deletions

File tree

.vscode/settings.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,9 @@
66
"python.defaultInterpreterPath": "./.venv/bin/python",
77
"python.testing.cwd": "${workspaceFolder}",
88
"cursorpyright.analysis.diagnosticMode": "openFilesOnly",
9-
"editor.defaultFormatter": "charliermarsh.ruff"
9+
"editor.defaultFormatter": "charliermarsh.ruff",
10+
"editor.formatOnSave": true,
11+
"[python]": {
12+
"editor.defaultFormatter": "charliermarsh.ruff"
13+
}
1014
}

eval_protocol/models.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -439,11 +439,19 @@ class EvaluationThreshold(BaseModel):
439439
success: float = Field(
440440
..., description="Minimum success rate threshold (fraction of total score, 0.0 to 1.0)", ge=0.0, le=1.0
441441
)
442-
standard_error: Optional[float] = Field(
443-
None, description="Maximum standard error threshold (fraction of total score, 0.0 to 1.0)", ge=0.0, le=1.0
442+
standard_error: float | None = Field(
443+
default=None,
444+
description="Maximum standard error threshold (fraction of total score, 0.0 to 1.0)",
445+
ge=0.0,
446+
le=1.0,
444447
)
445448

446449

450+
class EvaluationThresholdDict(TypedDict):
451+
success: float
452+
standard_error: float | None
453+
454+
447455
class EvalMetadata(BaseModel):
448456
"""Metadata about the evaluation that was run."""
449457

eval_protocol/pytest/evaluation_test.py

Lines changed: 36 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,13 @@
1919
EvalMetadata,
2020
EvaluationRow,
2121
EvaluationThreshold,
22+
EvaluationThresholdDict,
2223
InputMetadata,
2324
Message,
2425
Status,
2526
)
27+
from eval_protocol.pytest.parameterize import pytest_parametrize
28+
from eval_protocol.pytest.validate_signature import validate_signature
2629
from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
2730
from eval_protocol.pytest.default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
2831
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
@@ -38,6 +41,8 @@
3841
RolloutProcessorInputParam,
3942
TestFunction,
4043
)
44+
45+
4146
from eval_protocol.pytest.utils import (
4247
AggregationMethod,
4348
aggregate,
@@ -237,15 +242,15 @@ def postprocess(
237242
def evaluation_test(
238243
*,
239244
completion_params: list[CompletionParams | None] | None = None,
240-
input_messages: list[InputMessagesParam] | None = None,
245+
input_messages: list[InputMessagesParam | None] | None = None,
241246
input_dataset: list[DatasetPathParam] | None = None,
242247
input_rows: list[EvaluationRow] | None = None,
243248
dataset_adapter: Callable[[list[dict[str, Any]]], Dataset] = default_dataset_adapter, # pyright: ignore[reportExplicitAny]
244249
rollout_processor: RolloutProcessor | None = None,
245-
evaluation_test_kwargs: list[EvaluationInputParam] | None = None,
250+
evaluation_test_kwargs: list[EvaluationInputParam | None] | None = None,
246251
rollout_processor_kwargs: RolloutProcessorInputParam | None = None,
247252
aggregation_method: AggregationMethod = "mean",
248-
passed_threshold: EvaluationThreshold | float | dict[str, Any] | None = None, # pyright: ignore[reportExplicitAny]
253+
passed_threshold: EvaluationThreshold | float | EvaluationThresholdDict | None = None,
249254
num_runs: int = 1,
250255
max_dataset_rows: int | None = None,
251256
mcp_config_path: str | None = None,
@@ -257,10 +262,7 @@ def evaluation_test(
257262
combine_datasets: bool = True,
258263
logger: DatasetLogger | None = None,
259264
exception_handler_config: ExceptionHandlerConfig | None = None,
260-
) -> Callable[
261-
[TestFunction],
262-
TestFunction,
263-
]:
265+
) -> Callable[[TestFunction], TestFunction]:
264266
"""Decorator to create pytest-based evaluation tests.
265267
266268
Here are some key concepts to understand the terminology in EP:
@@ -328,6 +330,10 @@ def evaluation_test(
328330
exception_handler_config: Configuration for exception handling and backoff retry logic.
329331
If not provided, a default configuration will be used with common retryable exceptions.
330332
"""
333+
if completion_params is None:
334+
completion_params = [None]
335+
if rollout_processor is None:
336+
rollout_processor = NoOpRolloutProcessor()
331337

332338
active_logger: DatasetLogger = logger if logger else default_logger
333339

@@ -337,148 +343,40 @@ def evaluation_test(
337343
num_runs = parse_ep_num_runs(num_runs)
338344
max_concurrent_rollouts = parse_ep_max_concurrent_rollouts(max_concurrent_rollouts)
339345
max_dataset_rows = parse_ep_max_rows(max_dataset_rows)
340-
if completion_params is None:
341-
completion_params = [None]
342-
if rollout_processor is None:
343-
rollout_processor = NoOpRolloutProcessor()
344346
completion_params = parse_ep_completion_params(completion_params)
345347
original_completion_params = completion_params
346348
passed_threshold = parse_ep_passed_threshold(passed_threshold)
347349

348350
def decorator(
349351
test_func: TestFunction,
350-
):
351-
if passed_threshold is not None:
352-
if isinstance(passed_threshold, float):
353-
threshold = EvaluationThreshold(success=passed_threshold)
354-
else:
355-
threshold = EvaluationThreshold(**passed_threshold)
356-
else:
357-
threshold = None
358-
352+
) -> TestFunction:
359353
sig = inspect.signature(test_func)
360-
361-
# For pointwise/groupwise mode, we expect a different signature
362-
# we expect single row to be passed in as the original row
363-
if mode == "pointwise":
364-
# Pointwise mode: function should accept messages and other row-level params
365-
if "row" not in sig.parameters:
366-
raise ValueError("In pointwise mode, your eval function must have a parameter named 'row'")
367-
368-
# validate that "Row" is of type EvaluationRow
369-
if sig.parameters["row"].annotation is not EvaluationRow:
370-
raise ValueError("In pointwise mode, the 'row' parameter must be of type EvaluationRow")
371-
372-
# validate that the function has a return type of EvaluationRow
373-
if sig.return_annotation is not EvaluationRow:
374-
raise ValueError("In pointwise mode, your eval function must return an EvaluationRow instance")
375-
376-
# additional check for groupwise evaluation
377-
elif mode == "groupwise":
378-
if "rows" not in sig.parameters:
379-
raise ValueError("In groupwise mode, your eval function must have a parameter named 'rows'")
380-
381-
# validate that "Rows" is of type List[EvaluationRow]
382-
if sig.parameters["rows"].annotation is not List[EvaluationRow]:
383-
raise ValueError("In groupwise mode, the 'rows' parameter must be of type List[EvaluationRow")
384-
385-
# validate that the function has a return type of List[EvaluationRow]
386-
if sig.return_annotation is not List[EvaluationRow]:
387-
raise ValueError("In groupwise mode, your eval function must return a list of EvaluationRow instances")
388-
if len(completion_params) < 2:
389-
raise ValueError("In groupwise mode, you must provide at least 2 completion parameters")
390-
else:
391-
# all mode: function should accept input_dataset and model
392-
if "rows" not in sig.parameters:
393-
raise ValueError("In all mode, your eval function must have a parameter named 'rows'")
394-
395-
# validate that "Rows" is of type List[EvaluationRow]
396-
if sig.parameters["rows"].annotation is not List[EvaluationRow]:
397-
raise ValueError("In all mode, the 'rows' parameter must be of type List[EvaluationRow")
398-
399-
# validate that the function has a return type of List[EvaluationRow]
400-
if sig.return_annotation is not List[EvaluationRow]:
401-
raise ValueError("In all mode, your eval function must return a list of EvaluationRow instances")
402-
403-
async def execute_with_params(
404-
test_func: TestFunction,
405-
processed_row: EvaluationRow | None = None,
406-
processed_dataset: List[EvaluationRow] | None = None,
407-
evaluation_test_kwargs: Optional[EvaluationInputParam] = None,
408-
):
409-
kwargs = {}
410-
if processed_dataset is not None:
411-
kwargs["rows"] = processed_dataset
412-
if processed_row is not None:
413-
kwargs["row"] = processed_row
414-
if evaluation_test_kwargs is not None:
415-
if "row" in evaluation_test_kwargs:
416-
raise ValueError("'row' is a reserved parameter for the evaluation function")
417-
if "rows" in evaluation_test_kwargs:
418-
raise ValueError("'rows' is a reserved parameter for the evaluation function")
419-
kwargs.update(evaluation_test_kwargs)
420-
421-
# Handle both sync and async test functions
422-
if asyncio.iscoroutinefunction(test_func):
423-
return await test_func(**kwargs)
424-
else:
425-
return test_func(**kwargs)
354+
validate_signature(sig, mode, completion_params)
426355

427356
# Calculate all possible combinations of parameters
428-
if mode == "groupwise":
429-
combinations = generate_parameter_combinations(
430-
input_dataset,
431-
completion_params,
432-
input_messages,
433-
input_rows,
434-
evaluation_test_kwargs,
435-
max_dataset_rows,
436-
combine_datasets,
437-
)
438-
else:
439-
combinations = generate_parameter_combinations(
440-
input_dataset,
441-
completion_params,
442-
input_messages,
443-
input_rows,
444-
evaluation_test_kwargs,
445-
max_dataset_rows,
446-
combine_datasets,
447-
)
357+
combinations = generate_parameter_combinations(
358+
input_dataset,
359+
completion_params,
360+
input_messages,
361+
input_rows,
362+
evaluation_test_kwargs,
363+
max_dataset_rows,
364+
combine_datasets,
365+
)
448366
if len(combinations) == 0:
449367
raise ValueError(
450368
"No combinations of parameters were found. Please provide at least a model and one of input_dataset, input_messages, or input_rows."
451369
)
452370

453371
# Create parameter tuples for pytest.mark.parametrize
454-
param_tuples = []
455-
for combo in combinations:
456-
dataset, cp, messages, rows, etk = combo
457-
param_tuple = []
458-
if input_dataset is not None:
459-
param_tuple.append(dataset)
460-
if completion_params is not None:
461-
param_tuple.append(cp)
462-
if input_messages is not None:
463-
param_tuple.append(messages)
464-
if input_rows is not None:
465-
param_tuple.append(rows)
466-
if evaluation_test_kwargs is not None:
467-
param_tuple.append(etk)
468-
param_tuples.append(tuple(param_tuple))
469-
470-
# For all mode, preserve the original parameter names
471-
test_param_names = []
472-
if input_dataset is not None:
473-
test_param_names.append("dataset_path")
474-
if completion_params is not None:
475-
test_param_names.append("completion_params")
476-
if input_messages is not None:
477-
test_param_names.append("input_messages")
478-
if input_rows is not None:
479-
test_param_names.append("input_rows")
480-
if evaluation_test_kwargs is not None:
481-
test_param_names.append("evaluation_test_kwargs")
372+
pytest_parametrize_args = pytest_parametrize(
373+
combinations,
374+
input_dataset,
375+
completion_params,
376+
input_messages,
377+
input_rows,
378+
evaluation_test_kwargs,
379+
)
482380

483381
# Create wrapper function with exact signature that pytest expects
484382
def create_wrapper_with_signature() -> Callable:
@@ -613,7 +511,7 @@ async def _execute_eval_with_semaphore(**inner_kwargs):
613511
# NOTE: we will still evaluate errored rows (give users control over this)
614512
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
615513
if "row" in inner_kwargs:
616-
result = await execute_with_params(
514+
result = await execute_pytest(
617515
test_func,
618516
processed_row=inner_kwargs["row"],
619517
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
@@ -624,7 +522,7 @@ async def _execute_eval_with_semaphore(**inner_kwargs):
624522
)
625523
return result
626524
if "rows" in inner_kwargs:
627-
results = await execute_with_params(
525+
results = await execute_pytest(
628526
test_func,
629527
processed_dataset=inner_kwargs["rows"],
630528
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
@@ -696,7 +594,7 @@ async def _collect_result(config, lst):
696594
input_dataset.append(row)
697595
# NOTE: we will still evaluate errored rows (give users control over this)
698596
# i.e., they can choose to give EvaluateResult.score = 0 for errored rows in their test_func
699-
results = await execute_with_params(
597+
results = await execute_pytest(
700598
test_func,
701599
processed_dataset=input_dataset,
702600
evaluation_test_kwargs=kwargs.get("evaluation_test_kwargs") or {},
@@ -795,7 +693,7 @@ async def _collect_result(config, lst):
795693

796694
# Create the pytest wrapper
797695
pytest_wrapper = create_wrapper_with_signature()
798-
pytest_wrapper = pytest.mark.parametrize(test_param_names, param_tuples)(pytest_wrapper)
696+
pytest_wrapper = pytest.mark.parametrize(**pytest_parametrize_args)(pytest_wrapper)
799697
pytest_wrapper = pytest.mark.asyncio(pytest_wrapper)
800698

801699
def create_dual_mode_wrapper() -> Callable:

eval_protocol/pytest/execution.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import asyncio
2+
from collections.abc import Awaitable
3+
from eval_protocol.models import EvaluationRow
4+
from eval_protocol.pytest.types import EvaluationInputParam, TestFunction
5+
6+
7+
async def execute_pytest(
    test_func: TestFunction,
    processed_row: EvaluationRow | None = None,
    processed_dataset: list[EvaluationRow] | None = None,
    evaluation_test_kwargs: EvaluationInputParam | None = None,
) -> EvaluationRow | list[EvaluationRow]:
    """Invoke the user's evaluation function, awaiting it when it is a coroutine function.

    Args:
        test_func: The evaluation function to call; may be sync or async.
        processed_row: Single row passed as the first positional argument.
            Takes precedence over ``processed_dataset`` when both are given.
        processed_dataset: List of rows passed as the first positional argument
            when ``processed_row`` is ``None``.
        evaluation_test_kwargs: Extra keyword arguments forwarded to
            ``test_func``. Must not contain the reserved names ``row``/``rows``.

    Returns:
        Whatever ``test_func`` returns.

    Raises:
        ValueError: If ``evaluation_test_kwargs`` contains ``row`` or ``rows``.
    """
    extra_kwargs = dict(evaluation_test_kwargs) if evaluation_test_kwargs else {}
    for reserved in ("row", "rows"):
        if reserved in extra_kwargs:
            raise ValueError(f"'{reserved}' is a reserved parameter for the evaluation function")

    # Pick the single positional payload once so the sync and async paths
    # cannot drift apart (the sync path previously dropped the row result).
    if processed_row is not None:
        args = (processed_row,)
    elif processed_dataset is not None:
        args = (processed_dataset,)
    else:
        args = ()

    # Handle both sync and async test functions
    if asyncio.iscoroutinefunction(test_func):
        return await test_func(*args, **extra_kwargs)
    return test_func(*args, **extra_kwargs)

0 commit comments

Comments
 (0)