Skip to content

Commit 6725f56

Browse files
author
Dylan Huang
authored
Add dataset adapter support in evaluation_test and new test cases (#83)
- Included helper function `gsm8k_to_evaluation_row` for transforming GSM8K dataset entries into evaluation rows.
1 parent 3d63c7b commit 6725f56

6 files changed

Lines changed: 27 additions & 12 deletions

File tree

eval_protocol/pytest_utils.py

Lines changed: 11 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -74,6 +74,7 @@ def evaluation_test(
7474
model: List[ModelParam],
7575
input_messages: Optional[List[InputMessagesParam]] = None,
7676
input_dataset: Optional[List[DatasetPathParam]] = None,
77+
dataset_adapter: Optional[Callable[[List[Dict[str, Any]]], Dataset]] = lambda x: x,
7778
input_params: Optional[List[InputParam]] = None,
7879
rollout_processor: Callable[
7980
[EvaluationRow, ModelParam, InputParam], List[EvaluationRow]
@@ -90,8 +91,13 @@ def evaluation_test(
9091
9192
Args:
9293
model: Model identifiers to query.
93-
input_messages: Messages to send to the model.
94-
input_dataset: Paths to JSONL datasets.
94+
input_messages: Messages to send to the model. This is useful if you
95+
don't have a dataset but can hard-code the messages.
96+
input_dataset: Paths to JSONL datasets. This is useful if you have a
97+
dataset already. Provide a dataset_adapter to convert the input dataset
98+
to a list of EvaluationRows if you have a custom dataset format.
99+
dataset_adapter: Function to convert the input dataset to a list of
100+
EvaluationRows. This is useful if you have a custom dataset format.
95101
input_params: Generation parameters for the model.
96102
rollout_processor: Function used to perform the rollout.
97103
aggregation_method: How to aggregate scores across rows.
@@ -240,16 +246,9 @@ def wrapper_body(**kwargs):
240246
data = load_jsonl(kwargs["dataset_path"])
241247
if max_dataset_rows is not None:
242248
data = data[:max_dataset_rows]
243-
input_dataset = []
244-
for entry in data:
245-
user_query = entry.get("user_query") or entry.get("prompt")
246-
if not user_query:
247-
continue
248-
messages = [Message(role="user", content=user_query)]
249-
row = EvaluationRow(
250-
messages=messages,
251-
ground_truth=entry.get("ground_truth_for_eval"),
252-
)
249+
data = dataset_adapter(data)
250+
input_dataset: List[EvaluationRow] = []
251+
for row in data:
253252
processed = rollout_processor(
254253
row, model=model_name, input_params=kwargs.get("input_params") or {}
255254
)
Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,12 @@
1+
from typing import Any, Dict, List
2+
3+
from eval_protocol.models import EvaluationRow, Message
4+
5+
6+
def gsm8k_to_evaluation_row(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
7+
return [
8+
EvaluationRow(
9+
messages=[Message(role="user", content=row["user_query"])], ground_truth=row["ground_truth_for_eval"]
10+
)
11+
for row in data
12+
]
File renamed without changes.
Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,11 @@
11
from eval_protocol.pytest_utils import evaluate, evaluation_test
22
from examples.math_example.main import evaluate as math_evaluate
3+
from tests.pytest.helper.gsm8k_to_evaluation_row import gsm8k_to_evaluation_row
34

45

56
@evaluation_test(
67
input_dataset=["development/gsm8k_sample.jsonl"],
8+
dataset_adapter=gsm8k_to_evaluation_row,
79
model=["accounts/fireworks/models/kimi-k2-instruct"],
810
input_params=[{"temperature": 0.0}],
911
max_dataset_rows=5,

tests/test_pytest_math_format_length.py renamed to tests/pytest/test_pytest_math_format_length.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,9 +1,11 @@
11
from eval_protocol.pytest_utils import evaluate, evaluation_test
22
from examples.math_with_format_and_length.main import evaluate as math_fl_evaluate
3+
from tests.pytest.helper.gsm8k_to_evaluation_row import gsm8k_to_evaluation_row
34

45

56
@evaluation_test(
67
input_dataset=["development/gsm8k_sample.jsonl"],
8+
dataset_adapter=gsm8k_to_evaluation_row,
79
model=["accounts/fireworks/models/kimi-k2-instruct"],
810
input_params=[{"temperature": 0.0}],
911
max_dataset_rows=5,

0 commit comments

Comments
 (0)