Skip to content

Commit a5f3cb3

Browse files
committed
init
1 parent bd1be95 commit a5f3cb3

6 files changed

Lines changed: 288 additions & 1 deletion

File tree

eval_protocol/pytest/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .default_agent_rollout_processor import AgentRolloutProcessor
22
from .default_dataset_adapter import default_dataset_adapter
3+
from .default_klavis_sandbox_rollout_processor import KlavisSandboxRolloutProcessor
34
from .default_mcp_gym_rollout_processor import MCPGymRolloutProcessor
45
from .default_no_op_rollout_processor import NoOpRolloutProcessor
56
from .default_single_turn_rollout_process import SingleTurnRolloutProcessor
@@ -31,6 +32,7 @@
3132

3233
__all__ = [
3334
"AgentRolloutProcessor",
35+
"KlavisSandboxRolloutProcessor",
3436
"MCPGymRolloutProcessor",
3537
"RolloutProcessor",
3638
"SingleTurnRolloutProcessor",
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import asyncio
2+
import json
3+
import logging
4+
import os
5+
import tempfile
6+
import time
7+
from typing import Any, Callable, Dict, List, Optional
8+
9+
from pydantic import BaseModel, Field
10+
11+
from eval_protocol.models import EvaluationRow
12+
from eval_protocol.pytest.rollout_processor import RolloutProcessor
13+
from eval_protocol.pytest.types import RolloutProcessorConfig
14+
15+
from eval_protocol.pytest.default_agent_rollout_processor import Agent
16+
from klavis import Klavis
17+
from klavis.types import CreateSandboxResponse, SandboxMcpServer
18+
from openai.types import CompletionUsage
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
class KlavisSandboxRolloutProcessor(RolloutProcessor):
    """Run agent rollouts against a Klavis sandbox MCP server.

    For every row the processor:

    1. obtains a sandbox that is used by that row only,
    2. optionally seeds it with the data returned by ``initialize_data_factory``,
    3. writes a temporary MCP config pointing at the sandbox URL and runs the agent,
    4. dumps the sandbox state into ``row.execution_metadata.extra``,
    5. cleans up the MCP client, the temporary config file and the sandbox.

    Args:
        server_name: Name of the sandbox MCP server (converted with ``SandboxMcpServer``).
        initialize_data_factory: Optional callable mapping a row to the keyword
            arguments passed to the ``initialize_<server>_sandbox`` client method.
    """

    def __init__(
        self,
        server_name: str,
        initialize_data_factory: Optional[Callable[[EvaluationRow], Dict[str, Any]]] = None,
    ):
        super().__init__()
        self.server_name = server_name
        self.initialize_data_factory = initialize_data_factory
        self.klavis_client = Klavis(api_key=os.environ.get("KLAVIS_API_KEY"))
        # Created eagerly so a bad server name or client error surfaces at
        # construction time. The first processed row takes ownership of it.
        self.sandbox: Optional[CreateSandboxResponse] = self._init_sandbox()

    def _init_sandbox(self) -> CreateSandboxResponse:
        """Create a new sandbox for ``self.server_name``.

        Raises:
            Exception: Whatever the enum conversion or the Klavis client raised;
                the error is logged and re-raised unchanged.
        """
        try:
            server_name_enum = SandboxMcpServer(self.server_name)
            return self.klavis_client.sandbox.create_sandbox(server_name=server_name_enum)
        except Exception as e:
            logger.error(f"Error creating sandbox: {str(e)}", exc_info=True)
            raise

    def _acquire_sandbox(self) -> CreateSandboxResponse:
        """Return a sandbox owned exclusively by the calling row.

        The sandbox created in ``__init__`` is handed to the first caller; every
        later caller gets a freshly created one. Rows release their sandbox when
        they finish, so sharing one instance would leave later rows talking to a
        sandbox that has already been deleted.
        """
        # Read-and-reset happens without an ``await`` in between, so two row
        # tasks on the same event loop cannot both claim the same sandbox.
        sandbox, self.sandbox = self.sandbox, None
        if not sandbox:
            sandbox = self._init_sandbox()
        return sandbox

    def _release_sandbox(self, sandbox: CreateSandboxResponse) -> None:
        """Delete ``sandbox`` on a best-effort basis; failures are only logged."""
        if not sandbox.sandbox_id:
            return
        try:
            logger.info(f"Releasing {self.server_name} sandbox {sandbox.sandbox_id}")
            self.klavis_client.sandbox.delete_sandbox(
                server_name=sandbox.server_name, sandbox_id=sandbox.sandbox_id
            )
            logger.info(f"Sandbox {sandbox.sandbox_id} released successfully")
        except Exception as e:
            logger.error(f"Error releasing sandbox {sandbox.sandbox_id}: {str(e)}", exc_info=True)

    @staticmethod
    def _remove_temp_config(path: Optional[str]) -> None:
        """Delete the temporary MCP config at ``path`` if one was created."""
        if not path:
            return
        try:
            os.unlink(path)
        except FileNotFoundError:
            # Already gone; nothing left to clean up.
            pass

    @staticmethod
    def create_mcp_config(server_url: str, server_key: str = "main", auth_token: Optional[str] = None) -> str:
        """Create a temporary MCP config file and return its path.

        Args:
            server_url: URL of the MCP server to register.
            server_key: Key under ``mcpServers`` the server is registered as.
            auth_token: When given, stored as an ``authorization`` bearer entry.

        Returns:
            Path of the JSON file. The caller is responsible for deleting it.
        """
        config = {
            "mcpServers": {
                server_key: {
                    "url": server_url,
                    "transport": "streamable_http",
                    **({"authorization": f"Bearer {auth_token}"} if auth_token else {}),
                }
            }
        }

        # mkstemp does not delete the file on close, so it outlives this call.
        fd, path = tempfile.mkstemp(suffix=".json", prefix="mcp_config_")
        with os.fdopen(fd, "w") as f:
            json.dump(config, f)
        return path

    def __call__(
        self, rows: List[EvaluationRow], config: RolloutProcessorConfig
    ) -> List[asyncio.Task[EvaluationRow]]:
        """Process evaluation rows with Klavis sandbox lifecycle management.

        Args:
            rows: Rows to roll out.
            config: Supplies the concurrency ``semaphore`` and the agent ``logger``.

        Returns:
            One task per row, each resolving to the processed row.
        """
        semaphore = config.semaphore

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            """Process a single row with complete sandbox lifecycle."""
            start_time = time.perf_counter()

            # Bound before the ``try`` so the ``finally`` block never hits an
            # unbound name (and masks the real error) when an early step fails.
            agent: Optional[Agent] = None
            temp_config_path: Optional[str] = None
            sandbox: Optional[CreateSandboxResponse] = None

            try:
                sandbox = self._acquire_sandbox()

                # Step 1: Initialize data in the sandbox
                if self.initialize_data_factory:
                    logger.info(f"Initializing {self.server_name} sandbox {sandbox.sandbox_id}")
                    init_data = self.initialize_data_factory(row)
                    initialize_method = getattr(
                        self.klavis_client.sandbox, f"initialize_{sandbox.server_name}_sandbox"
                    )
                    initialize_method(sandbox_id=sandbox.sandbox_id, **init_data)
                    logger.info("Sandbox initialized successfully")

                # Step 2: Create temporary MCP config with sandbox URL
                temp_config_path = self.create_mcp_config(
                    server_url=sandbox.server_url, server_key=sandbox.server_name
                )

                # Step 3: Run agent with sandbox MCP server
                logger.info(
                    f"Running agent for row {row.execution_metadata.rollout_id} with {self.server_name} sandbox"
                )
                agent = Agent(
                    model=row.input_metadata.completion_params["model"],
                    row=row,
                    config_path=temp_config_path,
                    logger=config.logger,
                )
                await agent.setup()
                await agent.call_agent()

                # Switch to the agent's row first so the usage is recorded on
                # the object that is actually returned.
                row = agent.evaluation_row
                row.execution_metadata.usage = CompletionUsage(
                    prompt_tokens=agent.usage.get("prompt_tokens", 0),
                    completion_tokens=agent.usage.get("completion_tokens", 0),
                    total_tokens=agent.usage.get("total_tokens", 0),
                )
                logger.info(f"Agent execution completed for row {row.execution_metadata.rollout_id}")

                # Step 4: Export sandbox data
                logger.info(f"Exporting {self.server_name} sandbox data")
                dump_method = getattr(self.klavis_client.sandbox, f"dump_{sandbox.server_name}_sandbox")
                dump_response = dump_method(sandbox_id=sandbox.sandbox_id)
                sandbox_data = dump_response.data

                # Store sandbox data in row metadata for evaluation
                if not row.execution_metadata.extra:
                    row.execution_metadata.extra = {}
                row.execution_metadata.extra["sandbox_data"] = sandbox_data
                row.execution_metadata.extra["sandbox_id"] = sandbox.sandbox_id
                row.execution_metadata.extra["server_name"] = self.server_name

            except Exception as e:
                logger.error(f"Error processing row {row.execution_metadata.rollout_id}: {str(e)}", exc_info=True)
                if not row.execution_metadata.extra:
                    row.execution_metadata.extra = {}
                row.execution_metadata.extra["error"] = str(e)
                raise

            finally:
                try:
                    # Cleanup agent MCP client
                    if agent is not None and agent.mcp_client:
                        await agent.mcp_client.cleanup()
                finally:
                    # Runs even when MCP cleanup raises, so neither the temp
                    # file nor the sandbox is leaked.
                    self._remove_temp_config(temp_config_path)
                    if sandbox is not None:
                        self._release_sandbox(sandbox)

            row.execution_metadata.rollout_duration_seconds = time.perf_counter() - start_time

            return row

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            """Run ``process_row`` while holding the shared concurrency semaphore."""
            async with semaphore:
                result = await process_row(r)
                return result

        # Create and return tasks
        tasks = [asyncio.create_task(_sem_wrapper(row)) for row in rows]
        return tasks

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ openenv = [
134134
dspy = [
135135
"dspy>=3.0.0",
136136
]
137+
klavis = [
138+
"klavis>=2.18.0",
139+
]
137140

138141
# Optional deps for LangGraph example/tests
139142
langgraph = [
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
{"messages": [{"role": "system", "content": "You are a helpful assistant with access to Gmail. You can send emails, draft emails, and manage messages."}, {"role": "user", "content": "Send an email to john@example.com with subject 'Meeting Tomorrow' and body 'Hi John, Just confirming our meeting tomorrow at 2pm. Best regards.'"}], "ground_truth": "One email sent to john@example.com with subject 'Meeting Tomorrow' containing meeting confirmation"}
2+
{"messages": [{"role": "system", "content": "You are a helpful assistant with access to Gmail. You can send emails, draft emails, and manage messages."}, {"role": "user", "content": "Draft an email to sarah@company.com with subject 'Project Update' and body 'Hi Sarah, The project is progressing well. I will send you the detailed report by Friday.'"}], "ground_truth": "One draft email created for sarah@company.com with subject 'Project Update' about project progress"}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import json
2+
import logging
3+
import os
4+
5+
from eval_protocol.models import EvaluateResult, EvaluationRow
6+
from eval_protocol.pytest import KlavisSandboxRolloutProcessor, evaluation_test
7+
from openai import AsyncOpenAI
8+
from pydantic import BaseModel
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class ResponseFormat(BaseModel):
    # Structured verdict requested from the LLM judge; the JSON schema of this
    # model is passed as the ``response_format`` of the judge request.
    # Documented with comments rather than a docstring on purpose: pydantic
    # copies a model docstring into the generated schema as its description,
    # which would change what is sent to the judge.

    # Judge score: the prompt asks for 1.0 (success), 0.5 (partial) or 0.0 (failed).
    score: float
    # Short explanation of the score (the prompt asks for 1-2 sentences).
    reasoning: str
16+
17+
18+
@evaluation_test(
    input_dataset=["tests/pytest/datasets/klavis_gmail_sandbox_test.jsonl"],
    rollout_processor=KlavisSandboxRolloutProcessor(
        server_name="gmail",
        # Optional: provide custom initialization data factory
        # initialize_data_factory=lambda row: {"messages": [], "drafts": []},
    ),
    completion_params=[{"model": "fireworks_ai/accounts/fireworks/models/deepseek-v3p2"}],
    mode="pointwise",
)
async def test_pytest_gmail_sandbox(row: EvaluationRow) -> EvaluationRow:
    """
    Evaluate Gmail sandbox results by comparing with ground truth using LLM judge.

    The sandbox data is exported after agent execution and compared with expected output.
    Sandbox data is available in row.execution_metadata.extra["sandbox_data"].

    Args:
        row: Row produced by the rollout processor.

    Returns:
        The same row with ``evaluation_result`` set. A failure while calling or
        parsing the judge yields a score of 0.0 with the error as the reason.

    Raises:
        KeyError: If ``FIREWORKS_API_KEY`` is not set in the environment.
    """
    ground_truth = row.ground_truth
    sandbox_data = row.execution_metadata.extra.get("sandbox_data", {}) if row.execution_metadata.extra else {}
    final_message = row.messages[-1].content if row.messages else ""

    # The dataset rows start with a system message, so the task is the first
    # *user* message; fall back to the first message when no user turn exists.
    task_description = next(
        (message.content for message in row.messages if message.role == "user"),
        row.messages[0].content if row.messages else "N/A",
    )

    # Serialized once and reused for both the log line and the judge prompt.
    sandbox_dump = json.dumps(sandbox_data, indent=2, default=str)

    logger.info(f"Evaluating row {row.execution_metadata.rollout_id}")
    logger.info(f"Final message: {final_message}")
    logger.info(f"Sandbox data: {sandbox_dump}")
    logger.info(f"Ground truth: {ground_truth}")

    async with AsyncOpenAI(
        api_key=os.environ["FIREWORKS_API_KEY"], base_url="https://api.fireworks.ai/inference/v1"
    ) as client:
        # Use LLM to judge if the sandbox data matches the ground truth
        evaluation_prompt = f"""You are evaluating an AI agent's performance on a Gmail task.

Task: {task_description}

Ground Truth: {ground_truth}

Agent's Final Response: {final_message}

Gmail Sandbox State After Execution:
{sandbox_dump}

Evaluate whether the agent successfully completed the task by checking:
1. Did the agent understand and attempt the task?
2. Does the sandbox data reflect the expected outcome described in the ground truth?
3. Are there any emails sent/drafted that match the task requirements?

Return:
- score: 1.0 if task completed successfully, 0.5 if partially completed, 0.0 if failed
- reasoning: Explain your evaluation in 1-2 sentences
"""

        try:
            response = await client.chat.completions.create(
                model="accounts/fireworks/models/deepseek-v3p2",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a precise evaluator of AI agent performance. Analyze the task, execution, and results carefully.",
                    },
                    {"role": "user", "content": evaluation_prompt},
                ],
                response_format={
                    "type": "json_schema",
                    "json_schema": {"name": "ResponseFormat", "schema": ResponseFormat.model_json_schema()},
                },
                temperature=0.0,
            )

            response_text = response.choices[0].message.content
            logger.info(f"LLM judge response: {response_text}")

            # An empty/None completion is treated as an empty verdict (score 0.0).
            parsed = json.loads(response_text or "{}")
            score = parsed.get("score", 0.0)
            reasoning = parsed.get("reasoning", "No reasoning provided")

            row.evaluation_result = EvaluateResult(
                score=score,
                reason=reasoning,
            )
        except Exception as e:
            # Judge failures must not abort the test run; record them as a zero score.
            logger.error(f"Error during LLM evaluation: {str(e)}", exc_info=True)
            row.evaluation_result = EvaluateResult(
                score=0.0,
                reason=f"Evaluation error: {str(e)}",
            )

    return row

uv.lock

Lines changed: 20 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)