-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathtest_pytest_propagate_error.py
More file actions
68 lines (55 loc) · 2.53 KB
/
Copy pathtest_pytest_propagate_error.py
File metadata and controls
68 lines (55 loc) · 2.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from typing import Set
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.default_agent_rollout_processor import AgentRolloutProcessor
from eval_protocol.dataset_logger import DatasetLogger
class TrackingLogger(DatasetLogger):
    """In-memory logger that records the most recently logged version of each row.

    Every row passed to ``log`` is stored in the caller-supplied ``rollouts``
    dict under its ``execution_metadata.rollout_id``. A later log for the same
    rollout overwrites the earlier entry, so after a run the dict holds the
    final logged state of each rollout. The test inspects that final state.
    """

    def __init__(self, rollouts: dict[str, EvaluationRow]):
        # Shared with the test by reference; the test reads it after the run.
        self.rollouts = rollouts

    def log(self, row: EvaluationRow) -> None:
        """Record ``row``, replacing any previous entry for the same rollout id."""
        self.rollouts[row.execution_metadata.rollout_id] = row

    def read(self, *args, **kwargs) -> list[EvaluationRow]:
        """Report no persisted rows.

        Accepts and ignores any arguments, so callers using the base
        interface's read signature do not hit a TypeError.
        NOTE(review): confirm against ``DatasetLogger.read``.
        """
        return []
async def test_pytest_propagate_error():
    """
    Properly propagate errors from rollout processing to eval_metadata.status.

    To test this, we use a broken MCP configuration that should fail during
    rollout processing; the final eval_metadata.status of every rollout should
    then be an error. This way the UI can properly render an error state for
    the rollout and a developer can identify and investigate the error.
    """
    from eval_protocol.pytest.evaluation_test import evaluation_test

    # Single source of truth for the run count: the decorator argument and
    # the rollout-count assertion below must agree.
    num_runs = 5
    # Pieces of the underlying HTTP failure that must surface in every
    # rollout's status message so a developer can diagnose it.
    expected_fragments = (
        "HTTPStatusError",
        "405 Method Not Allowed",
        "https://docs.fireworks.ai/mcp-non-existent",
    )

    input_messages = [
        [
            Message(
                role="system",
                content="You are a helpful assistant that can answer questions about Fireworks.",
            ),
        ]
    ]
    completion_params_list = [
        {"model": "dummy/local-model"},
    ]

    rollouts: dict[str, EvaluationRow] = {}
    logger = TrackingLogger(rollouts)

    @evaluation_test(
        input_messages=input_messages,
        completion_params=completion_params_list,
        rollout_processor=AgentRolloutProcessor(),
        mode="pointwise",
        num_runs=num_runs,
        mcp_config_path="tests/pytest/mcp_configurations/docs_mcp_config_broken.json",
        logger=logger,
    )
    def eval_fn(row: EvaluationRow) -> EvaluationRow:
        return row

    # Manually invoke all parameter combinations within a single test
    for params in completion_params_list:
        await eval_fn(input_messages=input_messages, completion_params=params)

    # The expected count assumes one input row and one completion-params
    # entry, i.e. exactly num_runs distinct rollouts.
    assert len(rollouts) == num_runs, f"expected {num_runs} rollouts, got {len(rollouts)}"

    # Check each rollout individually so a failure names the offending rollout.
    for rollout_id, row in rollouts.items():
        assert row.eval_metadata.status.is_error(), (
            f"rollout {rollout_id} did not end in an error state: {row.eval_metadata.status!r}"
        )
        # A missing message should fail the assertion, not raise TypeError.
        message = row.rollout_status.message or ""
        for fragment in expected_fragments:
            assert fragment in message, (
                f"rollout {rollout_id}: {fragment!r} not found in status message {message!r}"
            )