Skip to content

Commit 6029271

Browse files
author
Dylan Huang
authored
Merge pull request #4 from eval-protocol/pytest_for_pointwise
feat: Add pointwise evaluation mode with pytest integration
2 parents 3f3161e + 56bf3ab commit 6029271

15 files changed

Lines changed: 623 additions & 360 deletions

eval_protocol/pytest/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
from .default_agent_rollout_processor import default_agent_rollout_processor
22
from .default_no_op_rollout_process import default_no_op_rollout_processor
33
from .default_single_turn_rollout_process import default_single_turn_rollout_processor
4-
from .pytest_utils import evaluate, evaluation_test
4+
from .evaluation_test import evaluation_test
55
from .types import RolloutProcessor, RolloutProcessorConfig
6+
from .utils import evaluate
67

78
__all__ = [
89
"default_agent_rollout_processor",
Lines changed: 243 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,243 @@
1+
import inspect
2+
from typing import Any, Callable, Dict, List, Optional
3+
4+
import pytest
5+
6+
from eval_protocol.models import EvaluationRow
7+
from eval_protocol.pytest.default_no_op_rollout_process import default_no_op_rollout_processor
8+
from eval_protocol.pytest.types import (
9+
Dataset,
10+
DatasetPathParam,
11+
EvaluationTestMode,
12+
InputMessagesParam,
13+
InputParam,
14+
ModelParam,
15+
RolloutProcessor,
16+
RolloutProcessorConfig,
17+
TestFunction,
18+
)
19+
from eval_protocol.pytest.utils import (
20+
AggregationMethod,
21+
aggregate,
22+
create_dynamically_parameterized_wrapper,
23+
execute_function,
24+
)
25+
26+
from ..common_utils import load_jsonl
27+
28+
29+
def evaluation_test(
    *,
    model: List[ModelParam],
    input_messages: Optional[List[InputMessagesParam]] = None,
    input_dataset: Optional[List[DatasetPathParam]] = None,
    dataset_adapter: Optional[Callable[[List[Dict[str, Any]]], Dataset]] = lambda x: x,
    input_params: Optional[List[InputParam]] = None,
    rollout_processor: RolloutProcessor = default_no_op_rollout_processor,
    aggregation_method: AggregationMethod = "mean",
    threshold_of_success: Optional[float] = None,
    num_runs: int = 1,
    max_dataset_rows: Optional[int] = None,
    mcp_config_path: Optional[str] = None,
    mode: EvaluationTestMode = "batch",
) -> Callable[
    [TestFunction],
    TestFunction,
]:
    """Decorator to create pytest-based evaluation tests.

    The decorated function is wrapped in a test parametrized over every
    combination of ``model``, ``input_dataset``, ``input_params`` and
    ``input_messages`` that was supplied. For each combination the wrapper
    loads the rows, runs ``rollout_processor`` on each row, calls the decorated
    function ``num_runs`` times, aggregates the resulting scores and optionally
    asserts that the aggregate reaches ``threshold_of_success``.

    Args:
        model: Model identifiers to query.
        input_messages: Messages to send to the model. This is useful if you
            don't have a dataset but can hard-code the messages. Each entry is
            wrapped in a single EvaluationRow before the rollout.
        input_dataset: Paths to JSONL datasets. This is useful if you have a
            dataset already. Provide a dataset_adapter to convert the input dataset
            to a list of EvaluationRows if you have a custom dataset format.
        dataset_adapter: Function to convert the input dataset to a list of
            EvaluationRows. This is useful if you have a custom dataset format.
            ``None`` is treated as the identity function.
        input_params: Generation parameters for the model.
        rollout_processor: Function used to perform the rollout.
        aggregation_method: How to aggregate scores across rows.
        threshold_of_success: If set, fail the test if the aggregated score is
            below this threshold.
        num_runs: Number of times to repeat the evaluation.
        max_dataset_rows: Limit dataset to the first N rows.
        mcp_config_path: Path to MCP config file that follows MCPMultiClientConfiguration schema
        mode: Evaluation mode. "batch" (default) expects test function to handle
            full dataset. "pointwise" applies test function to each row. If your evaluation requires
            the full rollout of all rows to compute the score, use "batch" mode.

    Returns:
        A decorator that turns an evaluation function into a parametrized
        pytest test.

    Raises:
        ValueError: At decoration time, if the decorated function does not have
            the parameter name, parameter annotation and return annotation
            required by ``mode`` (``row: EvaluationRow -> EvaluationRow`` for
            "pointwise", ``rows: List[EvaluationRow] -> List[EvaluationRow]``
            otherwise).
    """
    # The signature allows dataset_adapter=None; fall back to a pass-through so
    # the wrapper does not fail with "'NoneType' object is not callable".
    adapter = dataset_adapter if dataset_adapter is not None else (lambda x: x)

    def decorator(
        test_func: TestFunction,
    ):
        sig = inspect.signature(test_func)

        # Validate the decorated function's signature up front so mistakes are
        # reported at import/collection time instead of in the middle of a run.
        if mode == "pointwise":
            if "row" not in sig.parameters:
                raise ValueError("In pointwise mode, your eval function must have a parameter named 'row'")

            if sig.parameters["row"].annotation is not EvaluationRow:
                raise ValueError("In pointwise mode, the 'row' parameter must be of type EvaluationRow")

            if sig.return_annotation is not EvaluationRow:
                raise ValueError("In pointwise mode, your eval function must return an EvaluationRow instance")
        else:
            if "rows" not in sig.parameters:
                raise ValueError("In batch mode, your eval function must have a parameter named 'rows'")

            # Compare generic aliases by equality: two List[EvaluationRow]
            # expressions are equal but are not guaranteed to be the same object.
            if sig.parameters["rows"].annotation != List[EvaluationRow]:
                raise ValueError("In batch mode, the 'rows' parameter must be of type List[EvaluationRow]")

            if sig.return_annotation != List[EvaluationRow]:
                raise ValueError("In batch mode, your eval function must return a list of EvaluationRow instances")

        def execute_with_params(
            test_func: TestFunction,
            row: Optional[EvaluationRow] = None,
            input_dataset: Optional[List[EvaluationRow]] = None,
        ):
            """Call ``test_func`` with whichever of ``rows`` / ``row`` was supplied."""
            kwargs = {}
            if input_dataset is not None:
                kwargs["rows"] = input_dataset
            if row is not None:
                kwargs["row"] = row
            return execute_function(test_func, **kwargs)

        def generate_combinations():
            """Return every (model, dataset, params, messages) combination to test."""
            combinations = []

            # An omitted argument contributes a single ``None`` placeholder so
            # the nested loops still run once for that dimension.
            datasets: List[Optional[DatasetPathParam]] = input_dataset if input_dataset is not None else [None]  # type: ignore
            params: List[Optional[InputParam]] = input_params if input_params is not None else [None]  # type: ignore
            messages: List[Optional[InputMessagesParam]] = input_messages if input_messages is not None else [None]  # type: ignore

            for m in model:
                for ds in datasets:
                    for ip in params:
                        for im in messages:
                            # A dataset without rollout params is skipped. Note that this
                            # means passing input_dataset without input_params yields no
                            # combinations at all (an empty parametrization).
                            if ds is not None and ip is None:
                                continue
                            combinations.append((m, ds, ip, im))

            return combinations

        combinations = generate_combinations()

        # Build the value tuples for pytest.mark.parametrize. Only arguments that
        # were actually supplied to the decorator become test parameters, and the
        # order here must match test_param_names below.
        param_tuples = []
        for model_name, dataset, params, messages in combinations:
            param_tuple = [model_name]
            if input_dataset is not None:
                param_tuple.append(dataset)
            if input_params is not None:
                param_tuple.append(params)
            if input_messages is not None:
                param_tuple.append(messages)
            param_tuples.append(tuple(param_tuple))

        # Parameter names exposed to pytest (used in both modes).
        test_param_names = ["model"]
        if input_dataset is not None:
            test_param_names.append("dataset_path")
        if input_params is not None:
            test_param_names.append("input_params")
        if input_messages is not None:
            test_param_names.append("input_messages")

        def create_wrapper_with_signature():
            """Build the pytest-facing wrapper whose parameters are test_param_names."""

            def wrapper_body(**kwargs):
                model_name = kwargs["model"]

                # Load the rows to evaluate: a JSONL dataset takes precedence
                # over hard-coded input messages.
                if kwargs.get("dataset_path") is not None:
                    data = load_jsonl(kwargs["dataset_path"])
                    if max_dataset_rows is not None:
                        data = data[:max_dataset_rows]
                    data = adapter(data)
                elif kwargs.get("input_messages") is not None:
                    data = [EvaluationRow(messages=kwargs["input_messages"])]
                else:
                    raise ValueError("No input dataset or input messages provided")

                config = RolloutProcessorConfig(
                    model=model_name,
                    input_params=kwargs.get("input_params") or {},
                    mcp_config_path=mcp_config_path or "",
                    # Never hand None to the config; an absent value becomes [].
                    initial_messages=kwargs.get("input_messages") or [],
                )

                # The rollout happens once per parameter combination; only the
                # evaluation below is repeated num_runs times.
                rollout_rows: List[EvaluationRow] = []
                for row in data:
                    processed: List[EvaluationRow] = execute_function(rollout_processor, row=row, config=config)
                    rollout_rows.extend(processed)

                all_results: List[EvaluationRow] = []
                for _ in range(num_runs):
                    if mode == "pointwise":
                        # Pointwise mode: apply the evaluator function to each row
                        for row in rollout_rows:
                            result = execute_with_params(
                                test_func,
                                row=row,
                            )
                            # isinstance() is False for None, so this also covers a missing return.
                            if not isinstance(result, EvaluationRow):
                                raise ValueError(
                                    f"Test function {test_func.__name__} did not return an EvaluationRow instance. You must return an EvaluationRow instance from your test function decorated with @evaluation_test."
                                )
                            all_results.append(result)
                    else:
                        # Batch mode: call the test function with the full dataset
                        results = execute_with_params(
                            test_func,
                            input_dataset=rollout_rows,
                        )
                        # Covers a None return as well as any other non-list value.
                        if not isinstance(results, list):
                            raise ValueError(
                                f"Test function {test_func.__name__} did not return a list of EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
                            )
                        if not results:
                            raise ValueError(
                                f"Test function {test_func.__name__} returned an empty list. You must return a non-empty list of EvaluationRow instances from your test function decorated with @evaluation_test."
                            )
                        if not all(isinstance(r, EvaluationRow) for r in results):
                            raise ValueError(
                                f"Test function {test_func.__name__} returned a list containing non-EvaluationRow instances. You must return a list of EvaluationRow instances from your test function decorated with @evaluation_test."
                            )
                        all_results.extend(results)

                # Rows without an evaluation_result are left out of the aggregate.
                # NOTE(review): if no row carries a result, ``scores`` is empty; confirm
                # how aggregate() behaves on an empty list.
                scores = [r.evaluation_result.score for r in all_results if r.evaluation_result]
                agg_score = aggregate(scores, aggregation_method)
                if threshold_of_success is not None:
                    # A plain assert is intentional: this is the pytest test assertion.
                    assert (
                        agg_score >= threshold_of_success
                    ), f"Aggregated score {agg_score:.3f} below threshold {threshold_of_success}"

            return create_dynamically_parameterized_wrapper(test_func, wrapper_body, test_param_names)

        wrapper = create_wrapper_with_signature()
        wrapper = pytest.mark.parametrize(test_param_names, param_tuples)(wrapper)

        return wrapper

    return decorator

0 commit comments

Comments
 (0)