Skip to content

Commit e7149c5

Browse files
xzrderekDylan Huang
andauthored
Optimizing High Concurrency (#85)
* working * changing tests * updating llm usage * bug with accessing msg * temp * not finished yet * adding tau2 checks * removing erroneous tau2 subfolder * remove workflows folder * test * revert * fix test * tool warmup + concurrent uvicorn server --------- Co-authored-by: Dylan Huang <dhuang@fireworks.ai>
1 parent 01b7476 commit e7149c5

2 files changed

Lines changed: 95 additions & 26 deletions

File tree

eval_protocol/mcp/client/connection.py

Lines changed: 67 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
class MCPConnectionManager:
2424
"""Manages MCP client connections and session lifecycle."""
2525

26+
def __init__(self):
27+
self._tools_cache: Dict[str, List[Dict]] = {}
28+
self._tools_cache_lock = asyncio.Lock()
29+
2630
async def initialize_session(self, session: MCPSession) -> None:
2731
"""
2832
Initialize a persistent MCP session.
@@ -99,9 +103,40 @@ async def initialize_session(self, session: MCPSession) -> None:
99103
session.session_id = server_session_id
100104
logger.debug(f"Updated session ID to match server: {server_session_id}")
101105

106+
# PRE-WARM: Discover and cache tools immediately after session initialization
107+
# This prevents concurrent list_tools() calls later
108+
await self._prewarm_tools_cache(session)
109+
110+
async def _prewarm_tools_cache(self, session: MCPSession) -> None:
111+
"""
112+
Pre-warm the tools cache for this session's base URL.
113+
This prevents concurrent list_tools() calls during discover_tools().
114+
"""
115+
cache_key = session.base_url
116+
117+
async with self._tools_cache_lock:
118+
# Only fetch tools if not already cached for this base_url
119+
if cache_key not in self._tools_cache:
120+
logger.debug(f"Pre-warming tools cache for {cache_key}")
121+
tools_response = await session._mcp_session.list_tools()
122+
tools = tools_response.tools if hasattr(tools_response, "tools") else []
123+
124+
tool_schemas = []
125+
for tool in tools:
126+
tool_schema = {
127+
"name": tool.name,
128+
"description": tool.description,
129+
"input_schema": (tool.inputSchema if hasattr(tool, "inputSchema") else {}),
130+
}
131+
tool_schemas.append(tool_schema)
132+
133+
self._tools_cache[cache_key] = tool_schemas
134+
logger.debug(f"✅ PRE-WARMED {len(tool_schemas)} tools for{cache_key}")
135+
102136
async def discover_tools(self, session: MCPSession) -> List[Dict]:
103137
"""
104138
Discover available tools from an MCP session.
139+
Now uses pre-warmed cache to avoid concurrent list_tools() calls.
105140
106141
Args:
107142
session: The MCPSession to discover tools from
@@ -112,9 +147,19 @@ async def discover_tools(self, session: MCPSession) -> List[Dict]:
112147
if not session._mcp_session:
113148
raise RuntimeError("Session not initialized")
114149

150+
cache_key = session.base_url
151+
152+
# Check cache first (should be pre-warmed during initialization)
153+
async with self._tools_cache_lock:
154+
if cache_key in self._tools_cache:
155+
cached_tools = self._tools_cache[cache_key]
156+
logger.debug(f"Using cached tools for session {session.session_id} ({len(cached_tools)} tools)")
157+
return cached_tools
158+
159+
# Fallback: if cache miss (shouldn't happen with pre-warming), fetch directly
160+
logger.warning(f"Cache miss for {cache_key} - this shouldn't happen with pre-warming")
115161
mcp_session = session._mcp_session
116162

117-
# Get available tools from MCP server
118163
tools_response = await mcp_session.list_tools()
119164
tools = tools_response.tools if hasattr(tools_response, "tools") else []
120165

@@ -129,8 +174,26 @@ async def discover_tools(self, session: MCPSession) -> List[Dict]:
129174
}
130175
tool_schemas.append(tool_schema)
131176

177+
# Cache the result for future use
178+
async with self._tools_cache_lock:
179+
self._tools_cache[cache_key] = tool_schemas
180+
132181
return tool_schemas
133182

183+
def clear_tools_cache(self, base_url: Optional[str] = None):
184+
"""
185+
Clear the tools cache for debugging or when server tools change.
186+
187+
Args:
188+
base_url: If provided, clear cache only for this URL. If None, clear all.
189+
"""
190+
if base_url:
191+
self._tools_cache.pop(base_url, None)
192+
logger.debug(f"Cleared tools cache for {base_url}")
193+
else:
194+
self._tools_cache.clear()
195+
logger.debug("Cleared all tools cache")
196+
134197
async def get_initial_state(self, session: MCPSession) -> Any:
135198
"""
136199
Get initial state from session-aware control plane endpoint.
@@ -160,8 +223,9 @@ async def get_initial_state(self, session: MCPSession) -> Any:
160223

161224
# Query initial state endpoint
162225
try:
163-
# Use shorter timeout for playback mode
164-
timeout = 3.0 if hasattr(session, "_is_playback_mode") and session._is_playback_mode else 5.0
226+
# Use shorter timeout for playback mode, longer timeout for high-concurrency initialization
227+
# (50+ concurrent sessions need more time for initial state setup)
228+
timeout = 3.0 if hasattr(session, "_is_playback_mode") and session._is_playback_mode else 15.0
165229
async with httpx.AsyncClient(timeout=timeout) as client:
166230
initial_state_response = await client.get(
167231
f"{base_url}/control/initial_state",

eval_protocol/mcp/mcpgym.py

Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
- Session-aware control plane endpoints via @control_plane_endpoint decorator
1313
"""
1414

15+
import asyncio
1516
import hashlib
1617
import inspect
1718
import json
@@ -21,6 +22,7 @@
2122
from abc import ABC, abstractmethod
2223
from typing import Any, Callable, Dict, Optional, Tuple
2324

25+
import uvicorn
2426
from mcp.server.fastmcp import Context, FastMCP
2527
from starlette.requests import Request
2628
from starlette.responses import JSONResponse
@@ -553,29 +555,32 @@ def format_observation(self, obs: Any, env: Any) -> Dict[str, Any]:
553555
return {"observation": serialized_obs}
554556

555557
def run(self, transport: str = "streamable-http", **kwargs):
556-
"""
557-
Run the unified MCP-Gym server.
558-
559-
Args:
560-
transport: MCP transport protocol ("stdio", "sse", "streamable-http")
561-
**kwargs: Additional arguments passed to FastMCP.run()
562-
"""
563-
print(f"🚀 {self.mcp.name} MCP-Gym Server Starting...")
564-
print(f"📡 Transport: {transport}")
565-
print("🎯 MCP Pattern: HTTP endpoints for control plane, tools for data plane")
566-
print("🔗 Session-aware control plane endpoints:")
567-
568-
# List registered control plane endpoints
569-
for endpoint_name, endpoint_func in self._control_plane_endpoints.items():
570-
print(f" - {endpoint_name}: {endpoint_func._control_plane_path}")
571-
572-
if not self._control_plane_endpoints:
573-
print(" - No control plane endpoints registered")
574-
575-
print()
576-
577-
# Run the unified server
578-
self.mcp.run(transport=transport, **kwargs)
558+
"""Run the unified MCP-Gym server with high concurrency settings."""
559+
if transport == "streamable-http":
560+
# Run with custom high-concurrency uvicorn config
561+
562+
async def run_with_high_concurrency():
563+
starlette_app = self.mcp.streamable_http_app()
564+
565+
config = uvicorn.Config(
566+
starlette_app,
567+
host=self.mcp.settings.host,
568+
port=self.mcp.settings.port,
569+
log_level=self.mcp.settings.log_level.lower(),
570+
# HIGH CONCURRENCY SETTINGS
571+
limit_concurrency=200, # Increase for HTTP endpoints + MCP
572+
limit_max_requests=100000, # Higher request limit
573+
timeout_keep_alive=120, # Longer keep-alive for control plane
574+
timeout_notify=180,
575+
h11_max_incomplete_event_size=4 * 1024 * 1024, # Handle larger events
576+
)
577+
server = uvicorn.Server(config)
578+
await server.serve()
579+
580+
asyncio.run(run_with_high_concurrency())
581+
else:
582+
# Use default FastMCP run for other transports
583+
self.mcp.run(transport=transport, **kwargs)
579584

580585
def _to_json_serializable(self, obj: Any) -> Any:
581586
"""Convert any object to JSON-serializable format.

0 commit comments

Comments
 (0)