Skip to content

Commit 01b7476

Browse files
author
Dylan Huang
authored
Default agent rollout processor (#84)
* save * Add dataset adapter support in evaluation_test and new test cases - Included helper function `gsm8k_to_evaluation_row` for transforming GSM8K dataset entries into evaluation rows. * Refactor pytest rollout processors and enhance evaluation testing - Introduced new rollout processors: `default_agent_rollout_processor`, `default_no_op_rollout_processor`, and `default_single_turn_rollout_processor`. - Updated `Agent` class to support initial messages and improved model call handling. - Enhanced `RolloutProcessorConfig` to encapsulate model parameters and initial messages. - Refactored `evaluation_test` decorator to support async functions and improved parameter handling. - Added utility functions for evaluating and processing datasets in pytest tests. - Updated existing tests to utilize new rollout processors and ensure compatibility with the refactored structure. * Implement validation for empty messages in default_single_turn_rollout_processor to ensure non-empty datasets are provided
1 parent 4151522 commit 01b7476

16 files changed

Lines changed: 560 additions & 81 deletions

eval_protocol/mcp/execution/policy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,7 @@ async def _make_llm_call(self, messages: List[Dict], tools: List[Dict]) -> Dict:
185185
"choices": [
186186
{
187187
"message": {
188+
"role": response.choices[0].message.role,
188189
"content": response.choices[0].message.content,
189190
"tool_calls": (
190191
[
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
import json
2+
import os
3+
from contextlib import AsyncExitStack
4+
from dataclasses import dataclass
5+
from typing import Any, Dict, List, Optional
6+
7+
from dotenv import load_dotenv
8+
from mcp import ClientSession, StdioServerParameters
9+
from mcp.client.stdio import stdio_client
10+
from mcp.types import CallToolResult
11+
from openai.types import FunctionDefinition
12+
from openai.types.chat import ChatCompletionToolParam
13+
14+
from eval_protocol.types.types import MCPMultiClientConfiguration
15+
16+
load_dotenv() # load environment variables from .env
17+
18+
19+
class MCPMultiClient:
20+
"""
21+
Implements what clients like Cursor and Claude Desktop do when you configure
22+
them to use multiple MCP servers. The difference is that it validates
23+
against a list of environment variables rather than injects them into the
24+
MCP server process. This is so you can version control your configuration
25+
without exposing your environment variables to the MCP server process.
26+
27+
Environment variables should instead be set in a .env file
28+
"""
29+
30+
def __init__(self, config_path: Optional[str] = None):
31+
# Initialize session and client objects
32+
self.sessions: Dict[str, ClientSession] = {}
33+
self.tools_to_sessions: Dict[str, ClientSession] = {}
34+
self.exit_stack = AsyncExitStack()
35+
self.config = self._load_config(config_path)
36+
37+
def _load_config(self, config_path: Optional[str] = None) -> MCPMultiClientConfiguration:
38+
"""Load MCP server configuration from file or use default"""
39+
if config_path and os.path.exists(config_path):
40+
with open(config_path, "r") as f:
41+
return json.load(f)
42+
43+
# Default configuration - can be overridden by config file
44+
return {"mcpServers": {}}
45+
46+
def _validate_environment_variables(self, server_name: str, required_env: List[str]) -> None:
47+
"""Validate that required environment variables are set in os.environ"""
48+
missing_vars = []
49+
for env_var in required_env:
50+
if env_var not in os.environ:
51+
missing_vars.append(env_var)
52+
53+
if missing_vars:
54+
raise ValueError(
55+
f"Server '{server_name}' requires the following environment variables "
56+
f"to be set in os.environ: {missing_vars}. "
57+
f"Please set these variables in your environment or .env file."
58+
)
59+
60+
async def connect_to_servers(self):
61+
"""Connect to all configured MCP servers"""
62+
if not self.config.get("mcpServers"):
63+
print("No MCP servers configured. Please provide a configuration file.")
64+
return
65+
66+
for server_name, server_config in self.config["mcpServers"].items():
67+
try:
68+
await self._connect_to_server(server_name, server_config)
69+
except Exception as e:
70+
print(f"Failed to connect to server '{server_name}': {e}")
71+
72+
async def _connect_to_server(self, server_name: str, server_config: Dict[str, Any]):
73+
"""Connect to a specific MCP server using its configuration"""
74+
command = server_config.get("command")
75+
args = server_config.get("args", [])
76+
env_config = server_config.get("env", [])
77+
78+
if not command:
79+
raise ValueError(f"Server '{server_name}' must have a 'command' specified")
80+
81+
# Validate that required environment variables are set
82+
if env_config:
83+
self._validate_environment_variables(server_name, env_config)
84+
85+
# Use the current system environment (os.environ) - don't override with config
86+
server_params = StdioServerParameters(command=command, args=args, env=os.environ)
87+
88+
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
89+
stdio, write = stdio_transport
90+
session = await self.exit_stack.enter_async_context(ClientSession(stdio, write))
91+
92+
await session.initialize()
93+
self.sessions[server_name] = session
94+
95+
# List available tools
96+
response = await session.list_tools()
97+
tools = response.tools
98+
for tool in tools:
99+
if tool.name in self.tools_to_sessions:
100+
raise ValueError(f"Tool '{tool.name}' already exists")
101+
self.tools_to_sessions[tool.name] = session
102+
print(
103+
f"\nConnected to server '{server_name}' with tools:",
104+
[tool.name for tool in tools],
105+
)
106+
107+
async def get_available_tools(self) -> List[ChatCompletionToolParam]:
108+
"""Get all available tools from all connected servers"""
109+
all_tools = []
110+
for server_name, session in self.sessions.items():
111+
try:
112+
response = await session.list_tools()
113+
for tool in response.tools:
114+
all_tools.append(
115+
ChatCompletionToolParam(
116+
function=FunctionDefinition(
117+
name=tool.name, # Prefix with server name
118+
description=tool.description,
119+
parameters=tool.inputSchema,
120+
),
121+
type="function",
122+
)
123+
)
124+
except Exception as e:
125+
print(f"Error listing tools from server '{server_name}': {e}")
126+
127+
return all_tools
128+
129+
async def call_tool(self, tool_name: str, tool_args: Dict[str, Any]) -> CallToolResult:
130+
"""Call a specific tool by name with arguments"""
131+
132+
session = self.tools_to_sessions[tool_name]
133+
try:
134+
result = await session.call_tool(tool_name, tool_args)
135+
return result
136+
except Exception as e:
137+
return f"Error calling tool {tool_name}: {e}"
138+
139+
async def cleanup(self):
140+
"""Clean up resources"""
141+
await self.exit_stack.aclose()

eval_protocol/pytest/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from .default_agent_rollout_processor import default_agent_rollout_processor
2+
from .default_no_op_rollout_process import default_no_op_rollout_processor
3+
from .default_single_turn_rollout_process import default_single_turn_rollout_processor
4+
from .pytest_utils import evaluate, evaluation_test
5+
from .types import RolloutProcessor, RolloutProcessorConfig
6+
7+
__all__ = [
8+
"default_agent_rollout_processor",
9+
"default_no_op_rollout_processor",
10+
"default_single_turn_rollout_processor",
11+
"RolloutProcessor",
12+
"RolloutProcessorConfig",
13+
"evaluate",
14+
"evaluation_test",
15+
]
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import json
2+
import os
3+
from typing import Any, List, Optional
4+
5+
from mcp.types import CallToolResult
6+
from openai.types.chat import ChatCompletionMessage, ChatCompletionToolParam
7+
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
8+
9+
from eval_protocol.mcp.execution.policy import LiteLLMPolicy
10+
from eval_protocol.mcp.mcp_multi_client import MCPMultiClient
11+
from eval_protocol.models import EvaluationRow, Message
12+
from eval_protocol.pytest.types import RolloutProcessorConfig
13+
14+
15+
class Agent:
16+
"""
17+
A really simple agent that calls the model until no more tool calls are needed.
18+
"""
19+
20+
def __init__(self, model: str, initial_messages: list[Message], config_path: str):
21+
self.model = model
22+
self.messages: list[Message] = initial_messages
23+
self._policy = LiteLLMPolicy(model_id=model)
24+
self.mcp_client = MCPMultiClient(config_path=config_path) if config_path else None
25+
26+
async def setup(self):
27+
if self.mcp_client:
28+
await self.mcp_client.connect_to_servers()
29+
30+
async def call_agent(self) -> str:
31+
"""
32+
Call the assistant with the user query.
33+
"""
34+
tools = await self.mcp_client.get_available_tools() if self.mcp_client else None
35+
36+
message = await self._call_model(self.messages, tools)
37+
self.messages.append(message)
38+
if message["tool_calls"]:
39+
for tool_call in message["tool_calls"]:
40+
tool_call_id = tool_call["id"]
41+
tool_name = tool_call["function"]["name"]
42+
tool_args = tool_call["function"]["arguments"]
43+
tool_args_dict = json.loads(tool_args)
44+
tool_result = await self.mcp_client.call_tool(tool_name, tool_args_dict)
45+
content = self._get_content_from_tool_result(tool_result)
46+
self.messages.append(
47+
{
48+
"role": "tool",
49+
"content": content,
50+
"tool_call_id": tool_call_id,
51+
}
52+
)
53+
return message["content"]
54+
55+
async def _call_model(
56+
self, messages: list[Message], tools: Optional[list[ChatCompletionToolParam]]
57+
) -> ChatCompletionMessage:
58+
messages = [message.model_dump() if hasattr(message, "model_dump") else message for message in messages]
59+
response = await self._policy._make_llm_call(
60+
messages=messages,
61+
tools=tools,
62+
)
63+
return response["choices"][0]["message"]
64+
65+
def _get_content_from_tool_result(self, tool_result: CallToolResult) -> str:
66+
if tool_result.structuredContent:
67+
return json.dumps(tool_result.structuredContent)
68+
if len(tool_result.content) > 1:
69+
raise NotImplementedError("Multiple content is not supported yet")
70+
first_content = tool_result.content[0]
71+
if first_content.type != "text":
72+
raise NotImplementedError("Non-text content is not supported yet")
73+
return first_content.text
74+
75+
76+
async def default_agent_rollout_processor(row: EvaluationRow, config: RolloutProcessorConfig) -> List[EvaluationRow]:
77+
agent = Agent(model=config.model, initial_messages=config.initial_messages, config_path=config.mcp_config_path)
78+
await agent.setup()
79+
await agent.call_agent()
80+
return [EvaluationRow(messages=agent.messages)]
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from typing import List
2+
3+
from eval_protocol.models import EvaluationRow
4+
from eval_protocol.pytest.types import ModelParam, RolloutProcessorConfig
5+
6+
7+
def default_no_op_rollout_processor(row: EvaluationRow, config: RolloutProcessorConfig) -> List[EvaluationRow]:
8+
"""
9+
Simply passes input dataset through to the test function. This can be useful
10+
if you want to run the rollout yourself.
11+
"""
12+
return [row]
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
from typing import List
2+
3+
from openai import OpenAI
4+
5+
from eval_protocol.auth import get_fireworks_api_base, get_fireworks_api_key
6+
from eval_protocol.models import CompletionParams, EvaluationRow, InputMetadata, Message
7+
from eval_protocol.pytest.types import ModelParam, RolloutProcessorConfig
8+
9+
10+
def default_single_turn_rollout_processor(row: EvaluationRow, config: RolloutProcessorConfig) -> List[EvaluationRow]:
11+
"""Generate a single response from a Fireworks model."""
12+
13+
api_key = get_fireworks_api_key()
14+
api_base = get_fireworks_api_base()
15+
client = OpenAI(api_key=api_key, base_url=f"{api_base}/inference/v1")
16+
17+
if len(row.messages) == 0:
18+
raise ValueError("Messages is empty. Please provide a non-empty dataset")
19+
20+
messages_payload = [{"role": m.role, "content": m.content} for m in row.messages]
21+
22+
response = client.chat.completions.create(model=config.model, messages=messages_payload, **config.input_params)
23+
assistant_content = response.choices[0].message.content or ""
24+
messages = list(row.messages) + [Message(role="assistant", content=assistant_content)]
25+
processed = EvaluationRow(
26+
messages=messages,
27+
ground_truth=row.ground_truth,
28+
input_metadata=InputMetadata(completion_params=CompletionParams(model=config.model)),
29+
)
30+
return [processed]

0 commit comments

Comments
 (0)