Skip to content

Commit 1722e0f

Browse files
author
Dylan Huang
committed
Refactor evaluation_test function to enforce parameter type validation for pointwise and batch modes, updating tests to use 'rows' instead of 'input_dataset' for consistency.
1 parent 9dc3a2a commit 1722e0f

8 files changed

Lines changed: 33 additions & 29 deletions

eval_protocol/pytest/pytest_utils.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -204,15 +204,21 @@ def decorator(
204204
if "row" not in sig.parameters:
205205
raise ValueError(f"In pointwise mode, your eval function must have a parameter named 'row'")
206206

207+
# validate that "Row" is of type EvaluationRow
208+
if sig.parameters["row"].annotation is not EvaluationRow:
209+
raise ValueError(f"In pointwise mode, the 'row' parameter must be of type EvaluationRow")
210+
207211
# validate that the function has a return type of EvaluationRow
208212
if sig.return_annotation is not EvaluationRow:
209213
raise ValueError("In pointwise mode, your eval function must return an EvaluationRow instance")
210214
else:
211215
# Batch mode: function should accept input_dataset and model
212-
if "input_dataset" not in sig.parameters:
213-
raise ValueError("In batch mode, your eval function must have a parameter named 'input_dataset'")
214-
if "model" not in sig.parameters:
215-
raise ValueError("In batch mode, your eval function must have a parameter named 'model'")
216+
if "rows" not in sig.parameters:
217+
raise ValueError("In batch mode, your eval function must have a parameter named 'rows'")
218+
219+
# validate that "Rows" is of type List[EvaluationRow]
220+
if sig.parameters["rows"].annotation is not List[EvaluationRow]:
221+
raise ValueError(f"In batch mode, the 'rows' parameter must be of type List[EvaluationRow]")
216222

217223
# validate that the function has a return type of List[EvaluationRow]
218224
if sig.return_annotation is not List[EvaluationRow]:
@@ -227,7 +233,7 @@ def execute_with_params(
227233
):
228234
kwargs = {}
229235
if input_dataset is not None:
230-
kwargs["input_dataset"] = list(input_dataset)
236+
kwargs["rows"] = input_dataset
231237
if input_params is not None:
232238
kwargs["input_params"] = input_params
233239
if model is not None:

tests/pytest/test_markdown_highlighting.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
This test demonstrates how to check if model responses contain the required number of highlighted sections.
55
"""
66

7-
import json
87
import re
98
from typing import Any, Dict, List, Optional
109

@@ -69,8 +68,8 @@ def markdown_format_evaluate(messages: List[Message], ground_truth: Optional[str
6968
rollout_processor=default_single_turn_rollout_processor,
7069
num_runs=1,
7170
)
72-
def test_markdown_highlighting_evaluation(input_dataset, input_params, model) -> List[EvaluationRow]:
71+
def test_markdown_highlighting_evaluation(rows: List[EvaluationRow]) -> List[EvaluationRow]:
7372
"""
7473
Test markdown highlighting validation using batch mode with evaluate().
7574
"""
76-
return evaluate(input_dataset, markdown_format_evaluate)
75+
return evaluate(rows, markdown_format_evaluate)

tests/pytest/test_pytest_async.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
11
from typing import List
22

3-
from eval_protocol.models import EvaluationRow
3+
from eval_protocol.models import EvaluationRow, Message
44
from eval_protocol.pytest import evaluation_test
55
from examples.math_example.main import evaluate as math_evaluate
66

77

88
@evaluation_test(
99
input_messages=[
1010
[
11-
{"role": "user", "content": "What is the capital of France?"},
11+
Message(role="user", content="What is the capital of France?"),
1212
],
1313
[
14-
{"role": "user", "content": "What is the capital of the moon?"},
14+
Message(role="user", content="What is the capital of the moon?"),
1515
],
1616
],
1717
model=["accounts/fireworks/models/kimi-k2-instruct"],
1818
)
19-
async def test_pytest_async(input_dataset: List[EvaluationRow], model) -> List[EvaluationRow]:
19+
async def test_pytest_async(rows: List[EvaluationRow]) -> List[EvaluationRow]:
2020
"""Run math evaluation on sample dataset using pytest interface."""
21-
return input_dataset
21+
return rows
Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,24 @@
11
from datetime import datetime
22
from typing import List
33

4-
from eval_protocol.models import EvaluationRow, Message
4+
from eval_protocol.models import Message, EvaluationRow
55
from eval_protocol.pytest import default_agent_rollout_processor, evaluation_test
66

77

88
@evaluation_test(
99
input_messages=[
1010
[
11-
{
12-
"role": "user",
13-
"content": "Can you give a summary of the past week in the 'general, model-requests, bug-reports, questions, and feature-requests' channels. For EVERY message or thread has not been resolved, please list them at the end of your response in a table. Be sure to include the exact message, severity, and current status so far. Current Date & Time: {current_date_time}".format(
11+
Message(
12+
role="user",
13+
content="Can you give a summary of the past week in the 'general, model-requests, bug-reports, questions, and feature-requests' channels. For EVERY message or thread has not been resolved, please list them at the end of your response in a table. Be sure to include the exact message, severity, and current status so far. Current Date & Time: {current_date_time}".format(
1414
current_date_time=datetime.now().strftime("%B %d, %Y at %I:%M %p")
1515
),
16-
}
16+
)
1717
]
1818
],
1919
rollout_processor=default_agent_rollout_processor,
2020
model=["fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"],
2121
)
22-
def test_pytest_default_agent_rollout_processor(input_dataset: List[EvaluationRow], model) -> List[EvaluationRow]:
22+
def test_pytest_default_agent_rollout_processor(rows: List[EvaluationRow]) -> List[EvaluationRow]:
2323
"""Run math evaluation on sample dataset using pytest interface."""
24-
return input_dataset
24+
return rows
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
from typing import List
22

3-
from eval_protocol.models import EvaluationRow
3+
from eval_protocol.models import Message, EvaluationRow
44
from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
55

66

77
@evaluation_test(
88
input_messages=[
99
[
10-
{"role": "user", "content": "What is the capital of France?"},
10+
Message(role="user", content="What is the capital of France?"),
1111
]
1212
],
1313
model=["accounts/fireworks/models/kimi-k2-instruct"],
1414
rollout_processor=default_single_turn_rollout_processor,
1515
)
16-
def test_input_messages_in_decorator(input_dataset: List[EvaluationRow], model) -> List[EvaluationRow]:
16+
def test_input_messages_in_decorator(rows: List[EvaluationRow]) -> List[EvaluationRow]:
1717
"""Run math evaluation on sample dataset using pytest interface."""
18-
return input_dataset
18+
return rows

tests/pytest/test_pytest_math_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@
1414
threshold_of_success=0.0,
1515
rollout_processor=default_single_turn_rollout_processor,
1616
)
17-
def test_math_dataset(input_dataset, input_params, model) -> List[EvaluationRow]:
17+
def test_math_dataset(rows: List[EvaluationRow]) -> List[EvaluationRow]:
1818
"""Run math evaluation on sample dataset using pytest interface."""
19-
return evaluate(input_dataset, math_evaluate)
19+
return evaluate(rows, math_evaluate)

tests/pytest/test_pytest_math_format_length.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,6 @@
1414
threshold_of_success=0.0,
1515
rollout_processor=default_single_turn_rollout_processor,
1616
)
17-
def test_math_format_length_dataset(input_dataset, input_params, model) -> List[EvaluationRow]:
17+
def test_math_format_length_dataset(rows: List[EvaluationRow]) -> List[EvaluationRow]:
1818
"""Run math with format and length evaluation on sample dataset."""
19-
return evaluate(input_dataset, math_fl_evaluate)
19+
return evaluate(rows, math_fl_evaluate)

tests/pytest/test_pytest_word_count_example.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from typing import List
21
from eval_protocol.pytest import default_single_turn_rollout_processor, evaluation_test
32
from eval_protocol.models import EvaluateResult, MetricResult, EvaluationRow
43
from tests.pytest.helper.word_count_to_evaluation_row import word_count_to_evaluation_row

0 commit comments

Comments
 (0)