-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathtest_pytest_klavis_mcp.py
More file actions
54 lines (48 loc) · 2.02 KB
/
Copy pathtest_pytest_klavis_mcp.py
File metadata and controls
54 lines (48 loc) · 2.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from eval_protocol.models import EvaluateResult, EvaluationRow, Message
from eval_protocol.pytest import AgentRolloutProcessor, evaluation_test
from openai import AsyncOpenAI
import json
from pydantic import BaseModel
import logging
logger = logging.getLogger(__name__)
import os
class ResponseFormat(BaseModel):
    # Structured-output schema for the LLM judge: its reply must be a JSON object
    # with a single numeric "score" field. The schema is generated from this class
    # via model_json_schema() and sent as the judge's response_format.
    # NOTE: documented with comments rather than a docstring on purpose: pydantic
    # copies a class docstring into the generated schema's "description", which
    # would change what the judge is sent.
    score: float  # judge prompt asks for 1 if the output contains the ground truth, else 0
@evaluation_test(
    input_dataset=["tests/pytest/datasets/gmail_inbox.jsonl"],
    rollout_processor=AgentRolloutProcessor(),
    completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct"}],
    mode="pointwise",
    mcp_config_path="tests/pytest/mcp_configurations/klavis_strata_mcp.json",
)
async def test_pytest_klavis_mcp(row: EvaluationRow) -> EvaluationRow:
    """Grade an agent rollout against the row's ground truth with an LLM judge.

    The decorator presumably runs the agent rollout (``AgentRolloutProcessor``
    with the MCP servers from ``mcp_config_path``) for each dataset row before
    this function is called. Here, a Fireworks-hosted model is asked whether the
    rollout's final message contains ``row.ground_truth``, and its verdict is
    stored on the row.

    Args:
        row: The evaluated row. Its last message is treated as the agent's answer.

    Returns:
        The same ``row``, with ``evaluation_result`` set to the judge's score and
        the raw judge reply as the reason.

    Raises:
        KeyError: If ``FIREWORKS_API_KEY`` is not set in the environment.
        json.JSONDecodeError: If the judge reply is not valid JSON.
        pydantic.ValidationError: If the judge reply lacks a numeric ``score``.
    """
    ground_truth = row.ground_truth
    # Grade the agent's last message. An empty transcript (or a message with no
    # content) is graded as empty output instead of crashing with IndexError or
    # sending the string "None" to the judge.
    final_output = (row.messages[-1].content if row.messages else None) or ""

    async with AsyncOpenAI(
        api_key=os.environ["FIREWORKS_API_KEY"], base_url="https://api.fireworks.ai/inference/v1"
    ) as client:
        response = await client.chat.completions.create(
            model="accounts/fireworks/models/kimi-k2-instruct-0905",
            messages=[
                {
                    "role": "system",
                    "content": "You are judging the output of the model versus the ground truth. Return score = 1 if the output contains the ground truth, 0 otherwise.",
                },
                {
                    "role": "user",
                    # Must be an f-string. Without the prefix, the judge received the
                    # literal placeholder text instead of the output and ground truth.
                    "content": f"Final model output: {final_output}\nGround truth: {ground_truth}",
                },
            ],
            # Constrain the judge's reply to the ResponseFormat JSON schema.
            response_format={
                "type": "json_schema",
                "json_schema": {"name": "ResponseFormat", "schema": ResponseFormat.model_json_schema()},
            },
        )

    response_text = response.choices[0].message.content
    logger.info("response_text: %s", response_text)
    # Validate against the same schema the judge was given, so a malformed or
    # empty reply fails with a descriptive ValidationError rather than a bare KeyError.
    verdict = ResponseFormat.model_validate(json.loads(response_text or "{}"))
    row.evaluation_result = EvaluateResult(
        score=verdict.score,
        reason=response_text,
    )
    return row