-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathtest_preprocess_fn_data_loaders.py
More file actions
92 lines (68 loc) · 2.86 KB
/
Copy pathtest_preprocess_fn_data_loaders.py
File metadata and controls
92 lines (68 loc) · 2.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from eval_protocol.data_loader import DynamicDataLoader
from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
from eval_protocol.models import EvaluateResult, EvaluationRow, Message
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.default_no_op_rollout_processor import NoOpRolloutProcessor
class InMemoryLogger(DatasetLogger):
    """DatasetLogger stub that records nothing.

    ``log`` drops every row it is given and ``read`` always hands back an
    empty list, so using it in a test leaves no state behind.
    """

    def log(self, row: EvaluationRow) -> None:
        """Accept ``row`` and discard it; the result is always ``None``."""

    def read(self) -> list[EvaluationRow]:
        """Return a fresh empty list, since no rows are ever stored."""
        return list()
class StopAfterPreprocess(Exception):
    """Sentinel exception used by the tests in this module to abort a run early."""
class StopAfterPreprocessRolloutProcessor(NoOpRolloutProcessor):
    """NoOpRolloutProcessor variant whose ``setup`` always raises.

    NOTE(review): the name and message suggest ``setup`` is invoked only
    after preprocessing has run -- confirm against ``evaluation_test``.
    """

    def setup(self) -> None:
        """Raise ``StopAfterPreprocess`` unconditionally.

        Raises:
            StopAfterPreprocess: on every call.
        """
        reason = "Stop after preprocessing for focused test assertions"
        raise StopAfterPreprocess(reason)
def _build_rows() -> list[EvaluationRow]:
    """Return a one-element list holding a single two-message row.

    The row's messages are a user turn ("question") followed by an
    assistant turn ("answer").
    """
    user_turn = Message(role="user", content="question")
    assistant_turn = Message(role="assistant", content="answer")
    conversation = [user_turn, assistant_turn]
    return [EvaluationRow(messages=conversation)]
async def test_preprocess_fn_runs_with_data_loader_without_loader_preprocess():
    """Decorator-level ``preprocess_fn`` runs once when the loader has none.

    The loader is built without its own ``preprocess_fn``; the test then
    asserts that the function passed to ``evaluation_test`` was invoked
    exactly one time.
    """
    invocations = {"decorator_preprocess": 0}

    def decorator_preprocess(rows: list[EvaluationRow]) -> list[EvaluationRow]:
        # Pass-through: only records that it was called.
        invocations["decorator_preprocess"] = invocations["decorator_preprocess"] + 1
        return rows

    loader = DynamicDataLoader(generators=[_build_rows])
    halting_processor = StopAfterPreprocessRolloutProcessor()
    null_logger = InMemoryLogger()

    @evaluation_test(
        data_loaders=loader,
        preprocess_fn=decorator_preprocess,
        rollout_processor=halting_processor,
        logger=null_logger,
    )
    def eval_fn(row: EvaluationRow) -> EvaluationRow:
        outcome = EvaluateResult(score=1.0, reason="ok")
        row.evaluation_result = outcome
        return row

    try:
        await eval_fn(data_loaders=loader)
    except StopAfterPreprocess:
        # Raised by the rollout processor's setup(); swallowing it is
        # intentional -- only the call count below is under test.
        pass

    assert invocations["decorator_preprocess"] == 1
async def test_preprocess_fn_not_double_applied_when_data_loader_preprocess_exists():
    """Loader-level ``preprocess_fn`` runs once; the decorator-level one is skipped.

    Both the loader and ``evaluation_test`` are given a counting
    pass-through. The test asserts the loader's function was called exactly
    once and the decorator's function was never called.
    """
    invocations = {"loader_preprocess": 0, "decorator_preprocess": 0}

    def loader_preprocess(rows: list[EvaluationRow]) -> list[EvaluationRow]:
        # Pass-through attached to the data loader; counts its calls.
        invocations["loader_preprocess"] = invocations["loader_preprocess"] + 1
        return rows

    def decorator_preprocess(rows: list[EvaluationRow]) -> list[EvaluationRow]:
        # Pass-through attached to the decorator; counts its calls.
        invocations["decorator_preprocess"] = invocations["decorator_preprocess"] + 1
        return rows

    loader = DynamicDataLoader(generators=[_build_rows], preprocess_fn=loader_preprocess)
    halting_processor = StopAfterPreprocessRolloutProcessor()
    null_logger = InMemoryLogger()

    @evaluation_test(
        data_loaders=loader,
        preprocess_fn=decorator_preprocess,
        rollout_processor=halting_processor,
        logger=null_logger,
    )
    def eval_fn(row: EvaluationRow) -> EvaluationRow:
        outcome = EvaluateResult(score=1.0, reason="ok")
        row.evaluation_result = outcome
        return row

    try:
        await eval_fn(data_loaders=loader)
    except StopAfterPreprocess:
        # Raised by the rollout processor's setup(); swallowing it is
        # intentional -- only the call counts below are under test.
        pass

    assert invocations["loader_preprocess"] == 1
    assert invocations["decorator_preprocess"] == 0