Skip to content

Commit ff42f74

Browse files
committed
feat: Introduce session pooling for HTTP MCP calls and result caching for read-only tools.
1 parent 086344d commit ff42f74

1 file changed

Lines changed: 195 additions & 39 deletions

File tree

tooldns/caller.py

Lines changed: 195 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,88 @@
1111
import os
1212
import re
1313
import json
14+
import time
15+
import hashlib
16+
import threading
1417
import httpx
18+
from collections import OrderedDict
1519
from pathlib import Path
1620
from typing import Optional
1721

1822
from tooldns.config import logger, TOOLDNS_HOME
1923
from tooldns.fetcher import MCPFetcher
2024

2125

26+
# ---------------------------------------------------------------------------
27+
# MCP Session pool — reuse sessions to avoid handshake on every call
28+
# ---------------------------------------------------------------------------
29+
30+
# Pool of live MCP sessions, keyed by server URL. Each entry is a dict that
# holds at least "session_id" (str or None), "headers" (dict), "client"
# (httpx.Client) and "created" (float timestamp from time.monotonic()).
_sessions: dict[str, dict] = dict()
# Held while reading or writing ``_sessions``.
_session_lock = threading.Lock()
# Maximum age of a pooled session, in seconds (5 minutes).
_SESSION_TTL = 5 * 60.0
33+
34+
# ---------------------------------------------------------------------------
35+
# Result cache for read-only tool calls
36+
# ---------------------------------------------------------------------------
37+
38+
# Upper bound on the number of cached results.
_RESULT_CACHE_MAX = 128
# Lifetime of a cached result, in seconds (10 minutes).
_RESULT_CACHE_TTL = 10 * 60.0
# Maps cache key -> (expiry timestamp from time.monotonic(), result dict).
# Insertion order doubles as recency order for eviction.
_result_cache: OrderedDict[str, tuple[float, dict]] = OrderedDict()
# Held while reading or writing ``_result_cache``.
_result_cache_lock = threading.Lock()
42+
43+
# Read-only tool name patterns (prefixes that indicate read/list/get operations)
44+
# Built as "<SERVICE>_<VERB>" for each service and the verbs treated as
# read-only for it. Must stay a tuple of str.
_READ_ONLY_PREFIXES = tuple(
    f"{service}_{verb}"
    for service, verbs in (
        ("GMAIL", ("FETCH", "LIST", "GET")),
        ("GOOGLECALENDAR", ("FIND", "LIST", "GET")),
        ("GOOGLE_CALENDAR", ("FIND", "LIST", "GET")),
        ("REDDIT", ("GET", "LIST", "FETCH")),
        ("SLACK", ("LIST", "GET")),
        ("GITHUB", ("LIST", "GET")),
    )
    for verb in verbs
)
52+
53+
54+
def _is_read_only(tool_name: str) -> bool:
    """Tell whether a tool name marks a read-only call whose result may be cached.

    The name is upper-cased and compared against ``_READ_ONLY_PREFIXES``;
    a match at either the start or the end of the name counts.
    """
    candidate = tool_name.upper()
    # str.startswith / str.endswith accept a tuple of alternatives.
    if candidate.startswith(_READ_ONLY_PREFIXES):
        return True
    return candidate.endswith(_READ_ONLY_PREFIXES)
58+
59+
60+
def _cache_key(tool_name: str, arguments: dict) -> str:
    """Build a deterministic cache key for a tool call.

    The arguments are serialized as JSON with sorted keys so that dicts with
    the same content but different insertion order map to the same key. The
    key is the hex SHA-256 digest of ``"<tool_name>:<json>"``.
    """
    canonical_args = json.dumps(arguments, sort_keys=True)
    digest = hashlib.sha256()
    digest.update(f"{tool_name}:{canonical_args}".encode())
    return digest.hexdigest()
64+
65+
66+
def _get_cached_result(tool_name: str, arguments: dict) -> Optional[dict]:
    """Look up a previously cached result for a read-only tool call.

    Returns the stored result object itself (not a copy) when an unexpired
    entry exists, and marks that entry as most recently used. Returns
    ``None`` when the tool is not read-only, when nothing is cached for this
    name/arguments pair, or when the entry has expired (an expired entry is
    removed from the cache).
    """
    if not _is_read_only(tool_name):
        return None

    lookup = _cache_key(tool_name, arguments)
    with _result_cache_lock:
        try:
            deadline, payload = _result_cache[lookup]
        except KeyError:
            return None

        if time.monotonic() <= deadline:
            # Still valid: refresh its position so it is evicted last.
            _result_cache.move_to_end(lookup)
            logger.debug("Result cache HIT for {}", tool_name)
            return payload

        # Past its deadline: drop it and report a miss.
        del _result_cache[lookup]
        return None
82+
83+
84+
def _set_cached_result(tool_name: str, arguments: dict, result: dict) -> None:
    """Store the result of a read-only tool call in the bounded result cache.

    Nothing is stored when the tool is not read-only (per ``_is_read_only``)
    or when ``result`` is an error payload, so that a transient failure is
    not replayed to callers for the whole ``_RESULT_CACHE_TTL``.

    Args:
        tool_name: Name of the tool that was called.
        arguments: Arguments of the call; together with ``tool_name`` they
            form the cache key.
        result: The call's result. The object is stored as-is (no copy).
    """
    if not _is_read_only(tool_name):
        return

    # An MCP tool result flags failure with a truthy "isError"; a JSON-RPC
    # error response carries a top-level "error" member instead of "result".
    if isinstance(result, dict) and (result.get("isError") or "error" in result):
        logger.debug("Result cache SKIP for {} (error result)", tool_name)
        return

    key = _cache_key(tool_name, arguments)
    with _result_cache_lock:
        _result_cache[key] = (time.monotonic() + _RESULT_CACHE_TTL, result)
        # Assigning to an existing key keeps its old position in an
        # OrderedDict; move it to the end so a refreshed entry is not the
        # next one evicted.
        _result_cache.move_to_end(key)
        while len(_result_cache) > _RESULT_CACHE_MAX:
            # Evict from the front: the least recently used entry.
            _result_cache.popitem(last=False)
    logger.debug("Result cache SET for {} (ttl={}s)", tool_name, _RESULT_CACHE_TTL)
94+
95+
2296
# ---------------------------------------------------------------------------
2397
# Argument resolution for workflows/macros
2498
# ---------------------------------------------------------------------------
@@ -287,65 +361,147 @@ def _lookup_http_config(source_info: dict, database) -> tuple:
287361
return None, {}
288362

289363

290-
def _http_tool_call(server_url: str, server_headers: dict,
291-
tool_name: str, arguments: dict) -> dict:
292-
"""Send a tools/call request to an HTTP MCP server."""
364+
def _get_or_create_session(server_url: str, server_headers: dict) -> tuple[httpx.Client, dict]:
    """
    Get or create a pooled MCP session for the given server URL.

    Reuses an existing session (its MCP session ID and persistent HTTP
    connection) to avoid the initialize+notify handshake on every call.
    A pooled session is reused only while it is younger than _SESSION_TTL
    seconds AND was created for the same ``server_headers``; otherwise it is
    removed from the pool, its client is closed, and a new session is made.

    If the handshake raises, a warning is logged and a fresh client without
    a completed handshake is pooled anyway (some servers don't require one).

    Args:
        server_url: URL of the HTTP MCP server; also the pool key.
        server_headers: Extra headers (merged over the JSON defaults) to
            send with every request of this session.

    Returns:
        (client, headers) ready for tools/call requests.
    """
    with _session_lock:
        cached = _sessions.get(server_url)
        if _session_is_reusable(cached, server_headers):
            return cached["client"], cached["headers"]

        if cached is not None:
            # Expired, or built for different headers: take it out of the
            # pool first so no other thread picks it up or closes it again.
            del _sessions[server_url]
            try:
                cached["client"].close()
            except Exception as e:
                logger.debug("Ignoring error while closing stale MCP client for {}: {}",
                             server_url[:60], e)

    # Create new session outside the lock (handshake is slow)
    h = {
        "Content-Type": "application/json",
        "Accept": "application/json, text/event-stream",
        **server_headers
    }

    client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0))

    try:
        init_resp = client.post(
            server_url, headers=h,
            json={
                "jsonrpc": "2.0", "id": 1,
                "method": "initialize",
                "params": {
                    "protocolVersion": "2024-11-05",
                    "capabilities": {},
                    "clientInfo": {"name": "tooldns-proxy", "version": "1.0.0"}
                }
            },
        )
        session_id = init_resp.headers.get("mcp-session-id")
        if session_id:
            h["mcp-session-id"] = session_id

        client.post(
            server_url, headers=h,
            json={"jsonrpc": "2.0", "method": "notifications/initialized"},
            timeout=10
        )
    except Exception as e:
        logger.warning("MCP session init failed for {}: {}", server_url, e)
        # Continue anyway — some servers don't require handshake
        client.close()
        client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0))

    with _session_lock:
        current = _sessions.get(server_url)
        if _session_is_reusable(current, server_headers):
            # Another thread completed an equivalent handshake while this
            # one ran outside the lock; keep theirs.
            winner = current
        else:
            winner = None
            _sessions[server_url] = {
                "session_id": h.get("mcp-session-id"),
                "headers": h,
                "client": client,
                "created": time.monotonic(),
                # Caller-supplied headers this session was built for; used to
                # refuse reuse by a call that passes different headers.
                "base_headers": dict(server_headers),
            }

    if winner is not None:
        # Our client never entered the pool, so nobody else can be using it.
        client.close()
        return winner["client"], winner["headers"]

    logger.info("MCP session created for {} (session_id={})", server_url[:60], h.get("mcp-session-id", "none"))
    return client, h


def _session_is_reusable(entry: Optional[dict], server_headers: dict) -> bool:
    """Return True when pool *entry* is unexpired and was built for *server_headers*."""
    if entry is None:
        return False
    if (time.monotonic() - entry["created"]) >= _SESSION_TTL:
        return False
    # An entry stored without "base_headers" is accepted as-is.
    return entry.get("base_headers", server_headers) == server_headers
435+
436+
437+
def _http_tool_call(server_url: str, server_headers: dict,
                    tool_name: str, arguments: dict) -> dict:
    """Send a tools/call request to an HTTP MCP server with session pooling and result caching.

    A cached result is returned without any network traffic when one exists
    for this server, these headers, this tool and these arguments. Otherwise
    the call goes through a pooled session from ``_get_or_create_session``.
    The call is retried once on a fresh session when the transport fails or
    when the server answers 404 to a request that carried an MCP session ID.

    Args:
        server_url: URL of the HTTP MCP server.
        server_headers: Extra headers to send to that server.
        tool_name: Name of the tool to invoke.
        arguments: Arguments passed to the tool.

    Returns:
        The "result" member of the JSON-RPC response when present, otherwise
        the whole decoded message; ``{"raw": <body text>}`` when an
        event-stream body contains no decodable message.

    Raises:
        httpx.HTTPStatusError: The server answered with an error status
            (after the retry, where one applies).
        httpx.HTTPError: The retry itself failed at the transport level.
    """
    # Scope the cache entry to this server and these headers so that a result
    # fetched from one server/account is never returned for another. Only a
    # SHA-256 digest of this structure is kept by _cache_key.
    cache_args = {
        "server_url": server_url,
        "headers": sorted((str(k), str(v)) for k, v in server_headers.items()),
        "arguments": arguments,
    }

    # Check result cache for read-only tools
    cached = _get_cached_result(tool_name, cache_args)
    if cached is not None:
        return cached

    t0 = time.monotonic()
    client, h = _get_or_create_session(server_url, server_headers)

    try:
        resp = _send_tools_call(client, server_url, h, tool_name, arguments)
    except (httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadError,
            httpx.HTTPStatusError) as e:
        if isinstance(e, httpx.HTTPStatusError):
            # A 404 for a request that carried a session ID means the server
            # no longer knows the session. Any other error status is a real
            # failure and is passed to the caller unchanged.
            session_rejected = (e.response.status_code == 404
                                and "mcp-session-id" in h)
            if not session_rejected:
                raise

        # Session may have expired server-side — retry with fresh session
        # NOTE(review): after a transport error the server may already have
        # processed the call; confirm a retry is acceptable for tools with
        # side effects.
        logger.warning("MCP session error, retrying with fresh session: {}", e)
        with _session_lock:
            entry = _sessions.get(server_url)
            # Drop the pooled session only if it is still the one that just
            # failed; another thread may already have replaced it.
            if entry is not None and entry.get("client") is client:
                del _sessions[server_url]
        try:
            client.close()
        except Exception as close_err:
            logger.debug("Ignoring error while closing failed MCP client: {}", close_err)
        client, h = _get_or_create_session(server_url, server_headers)
        resp = _send_tools_call(client, server_url, h, tool_name, arguments)

    elapsed = time.monotonic() - t0
    logger.info("Tool {} executed in {:.1f}s", tool_name, elapsed)

    result = _extract_tool_result(resp)

    # Cache read-only results
    _set_cached_result(tool_name, cache_args, result)
    return result


def _send_tools_call(client, server_url: str, headers: dict,
                     tool_name: str, arguments: dict):
    """POST one JSON-RPC tools/call request and raise on an HTTP error status."""
    resp = client.post(
        server_url, headers=headers,
        json={
            "jsonrpc": "2.0", "id": 2,
            "method": "tools/call",
            "params": {"name": tool_name, "arguments": arguments}
        },
    )
    resp.raise_for_status()
    return resp


def _extract_tool_result(resp) -> dict:
    """Decode a tools/call HTTP response (plain JSON or event-stream) into its result."""
    content_type = resp.headers.get("content-type", "")
    if "text/event-stream" not in content_type:
        data = resp.json()
        return data.get("result", data)

    response_msg = None
    first_msg = None
    for line in resp.text.split("\n"):
        line = line.strip()
        if not line.startswith("data:"):
            continue
        payload = line[5:].strip()
        if not payload:
            continue
        try:
            parsed = json.loads(payload)
        except ValueError:
            continue
        if not isinstance(parsed, dict):
            continue
        if "result" in parsed or "error" in parsed:
            # This is the JSON-RPC response itself; stop looking.
            response_msg = parsed
            break
        if first_msg is None:
            # Not a response (e.g. a notification sent ahead of it); keep it
            # only as a fallback in case no response message follows.
            first_msg = parsed

    message = response_msg if response_msg is not None else first_msg
    result = message.get("result", message) if message is not None else None
    if result is None:
        result = {"raw": resp.text}
    return result
349505

350506

351507
def _resolve_env(val):

0 commit comments

Comments
 (0)