-
Notifications
You must be signed in to change notification settings - Fork 17
Expand file tree
/
Copy pathdefault_mcp_gym_rollout_processor.py
More file actions
235 lines (192 loc) · 7.76 KB
/
Copy pathdefault_mcp_gym_rollout_processor.py
File metadata and controls
235 lines (192 loc) · 7.76 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
import asyncio
import atexit
import os
import signal
import socket
import subprocess
import sys
import time
from pathlib import Path
from typing import List, Optional

import eval_protocol as ep
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.types import RolloutProcessorConfig
class MCPServerManager:
    """Manage the lifecycle of an MCP server subprocess for testing.

    Every instance is recorded in a class-level registry so that an ``atexit``
    hook (plus SIGINT/SIGTERM handlers, when they can be installed) stops any
    server still running at interpreter shutdown. Instances are also context
    managers: ``with MCPServerManager(...) as server: ...``.
    """

    # Class-level tracking of all server instances that may still need stopping.
    _active_servers = []
    # Set once the atexit/signal handlers have been installed for this process.
    _cleanup_registered = False

    def __init__(self, server_script: str, port: int = 8000, domain: str = "airline"):
        """Create a manager; the server is not launched until ``start()``.

        Args:
            server_script: Path to the server script. Relative paths are
                resolved against the working directory at construction time
                (the process is launched with ``cwd=self.base_dir``).
            port: Port passed to the server via ``--port`` and the ``PORT``
                environment variable, and polled for readiness.
            domain: Label used only to name the log file.
        """
        self.server_script = server_script
        self.port = port
        self.domain = domain
        self.process: Optional[subprocess.Popen] = None
        self.base_dir = Path(".").resolve()
        self._log_file = None
        self._log_file_path = None

        # Register this server for cleanup
        MCPServerManager._active_servers.append(self)

        # Register cleanup handlers only once per process
        if not MCPServerManager._cleanup_registered:
            MCPServerManager._register_cleanup_handlers()
            MCPServerManager._cleanup_registered = True

    def start(self) -> None:
        """Start the MCP server and block until it accepts TCP connections.

        Does nothing if this instance already has a process. Server stdout and
        stderr are written to ``<base_dir>/server_output_<domain>_<port>.log``
        (truncated on each start).

        Raises:
            RuntimeError: if the server exits early or is not ready within
                15 seconds. The process is stopped and the log file closed
                before raising, so nothing is leaked.
        """
        if self.process:
            return

        # stop() unregisters the instance; re-register so a restarted server is
        # still found by the atexit/signal cleanup.
        if self not in MCPServerManager._active_servers:
            MCPServerManager._active_servers.append(self)

        # Set environment for server
        env = os.environ.copy()
        env["PORT"] = str(self.port)

        # Use the running interpreter instead of whatever "python" is on PATH so
        # the server sees the same environment/packages as the caller.
        # (No domain argument needed for tau2_mcp server.)
        cmd = [sys.executable, self.server_script, "--port", str(self.port)]

        # Mode "w" truncates any log left over from a previous run.
        log_file_path = os.path.join(self.base_dir, f"server_output_{self.domain}_{self.port}.log")
        log_file = open(log_file_path, "w")
        try:
            self.process = subprocess.Popen(
                cmd,
                cwd=self.base_dir,
                env=env,
                stdout=log_file,
                stderr=log_file,
                text=True,
            )
        except Exception:
            # Don't leak the log handle if the process could not be spawned.
            log_file.close()
            raise

        # Store log file reference for cleanup
        self._log_file = log_file
        self._log_file_path = log_file_path

        if not self._wait_for_server_ready(timeout=15):
            self._fail_startup()
        print(f"✅ Server started successfully on port {self.port}")

    def _fail_startup(self):
        """Stop the half-started server, print its log, and raise RuntimeError."""
        log_path = self._log_file_path
        # Stop first: after a timeout the process may still be running (the
        # old code blocked forever in communicate() here). Stopping also ends
        # further writes to the log and closes our handle.
        self.stop()
        try:
            with open(log_path, "r", errors="replace") as f:
                log_content = f.read()
        except OSError as e:
            raise RuntimeError(
                f"Server failed to start or become ready. Could not read log {log_path}: {e}"
            ) from e
        print("❌ Server failed to start!")
        print(f"📋 Server log ({log_path}):")
        print("=" * 50)
        print(log_content)
        print("=" * 50)
        raise RuntimeError("Server failed to start or become ready. Check log above for details.")

    def _wait_for_server_ready(self, timeout: int = 15) -> bool:
        """Poll ``localhost:<port>`` until it accepts a TCP connection.

        Returns:
            True once a connection succeeds and our process is still alive
            after a 0.5 s grace period; False if the process exits first or
            ``timeout`` seconds elapse.
        """
        deadline = time.monotonic() + timeout
        health_check_failures = 0
        while time.monotonic() < deadline:
            # Check if process is still running
            if self.process.poll() is not None:
                print("Server process exited early")
                return False
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                    s.settimeout(1)
                    result = s.connect_ex(("localhost", self.port))
                if result == 0:
                    time.sleep(0.5)
                    # A connection can succeed because some *other* process owns
                    # the port (e.g. a stale server); don't report ready if ours
                    # has already died.
                    if self.process.poll() is not None:
                        print("Server process exited early")
                        return False
                    return True
            except OSError as e:
                health_check_failures += 1
                # Print first few failures for debugging
                if health_check_failures <= 3:
                    print(f"Health check failed: {e}")
            # Wait before next check
            time.sleep(0.1)
        print(f"Server failed to become ready within {timeout} seconds")
        return False

    def stop(self) -> None:
        """Stop the server, close its log file, and unregister this instance.

        The process is terminated, then killed if it has not exited after
        5 seconds. Safe to call repeatedly or when the server was never started.
        """
        if self.process:
            print(f"🛑 Stopping server on port {self.port}...")
            self.process.terminate()
            try:
                self.process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                print(f"⚡ Force killing server on port {self.port}...")
                self.process.kill()
                self.process.wait()
            self.process = None

        # Clean up log file
        if self._log_file:
            try:
                self._log_file.close()
            except Exception:
                pass  # best-effort: closing the log must not mask shutdown
            self._log_file = None

        # Remove from active servers list
        if self in MCPServerManager._active_servers:
            MCPServerManager._active_servers.remove(self)

    @classmethod
    def _cleanup_all_servers(cls):
        """Stop every registered server; errors are reported, not raised."""
        print(f"\n🧹 Cleaning up {len(cls._active_servers)} active servers...")
        # Iterate a copy: stop() removes each server from the list.
        for server in cls._active_servers.copy():
            try:
                server.stop()
            except Exception as e:
                print(f"⚠️ Error stopping server: {e}")
        cls._active_servers.clear()

    @classmethod
    def _signal_handler(cls, signum, frame):
        """Stop all servers, then exit with status 1."""
        print(f"\n🛑 Received signal {signum}, cleaning up...")
        cls._cleanup_all_servers()
        # sys.exit rather than the site-provided exit(), which may be absent.
        sys.exit(1)

    @classmethod
    def _register_cleanup_handlers(cls):
        """Install the atexit hook and SIGINT/SIGTERM handlers (called once).

        ``signal.signal`` raises ValueError outside the main thread; in that
        case only the atexit hook is installed instead of failing construction.
        """
        atexit.register(cls._cleanup_all_servers)
        try:
            signal.signal(signal.SIGINT, cls._signal_handler)  # Ctrl+C
            signal.signal(signal.SIGTERM, cls._signal_handler)  # Termination signal
        except ValueError:
            print("⚠️ Not in main thread; relying on atexit for server cleanup")

    def __enter__(self):
        """Context manager entry: start the server and return self."""
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - ensures cleanup even on exceptions."""
        self.stop()
        if exc_type:
            print(f"⚠️ Server cleanup after exception: {exc_type.__name__}")
        return False  # Don't suppress exceptions
async def default_mcp_gym_rollout_processor(
    rows: List[EvaluationRow], config: RolloutProcessorConfig
) -> List[EvaluationRow]:
    """
    Rollout processor for tau bench environments.

    Starts the MCP server from ``config.server_script_path`` on port 9700,
    creates MCP environments for ``rows`` via ``ep.make``, and runs rollouts
    with a LiteLLM policy via ``ep.rollout``, following the pattern from
    test_tau2_e2e.py. The server is always stopped, even if a step raises.

    Args:
        rows: List of EvaluationRow objects containing messages and dataset info in input_metadata
        config: RolloutProcessorConfig with the model, ``steps``,
            ``max_concurrent_rollouts`` and ``input_params`` (``temperature``
            defaults to 0.0, ``max_tokens`` to 4096, ``reasoning_effort`` to None)

    Returns:
        List of EvaluationRow objects with completed conversations
    """
    server = MCPServerManager(config.server_script_path, port=9700)
    try:
        server.start()

        policy = ep.LiteLLMPolicy(
            model_id=config.model,
            temperature=config.input_params.get("temperature", 0.0),
            max_tokens=config.input_params.get("max_tokens", 4096),
            reasoning_effort=config.input_params.get("reasoning_effort", None),
        )

        # Derive the URL from the server's port so the two cannot drift apart.
        server_url = f"http://localhost:{server.port}/mcp/"

        # Create MCP environments directly from evaluation_rows
        envs = await ep.make(
            server_url,
            evaluation_rows=rows,
            model_id=policy.model_id,
        )

        # Run rollout with environments and policy
        evaluation_rows = await ep.rollout(
            envs,
            policy=policy,
            evaluation_rows=rows,
            steps=config.steps,
            max_concurrent_rollouts=config.max_concurrent_rollouts,
        )
        return evaluation_rows
    finally:
        # Always clean up the server (stop() is safe if start() failed).
        server.stop()