|
11 | 11 | import os |
12 | 12 | import re |
13 | 13 | import json |
| 14 | +import time |
| 15 | +import hashlib |
| 16 | +import threading |
14 | 17 | import httpx |
| 18 | +from collections import OrderedDict |
15 | 19 | from pathlib import Path |
16 | 20 | from typing import Optional |
17 | 21 |
|
18 | 22 | from tooldns.config import logger, TOOLDNS_HOME |
19 | 23 | from tooldns.fetcher import MCPFetcher |
20 | 24 |
|
21 | 25 |
|
| 26 | +# --------------------------------------------------------------------------- |
| 27 | +# MCP Session pool — reuse sessions to avoid handshake on every call |
| 28 | +# --------------------------------------------------------------------------- |
| 29 | + |
| 30 | +_session_lock = threading.Lock() |
| 31 | +_sessions: dict[str, dict] = {} # url -> {"session_id": str, "headers": dict, "client": httpx.Client, "created": float} |
| 32 | +_SESSION_TTL = 300.0 # 5 minutes |
| 33 | + |
| 34 | +# --------------------------------------------------------------------------- |
| 35 | +# Result cache for read-only tool calls |
| 36 | +# --------------------------------------------------------------------------- |
| 37 | + |
| 38 | +_result_cache_lock = threading.Lock() |
| 39 | +_result_cache: OrderedDict[str, tuple[float, dict]] = OrderedDict() |
| 40 | +_RESULT_CACHE_TTL = 600.0 # 10 minutes |
| 41 | +_RESULT_CACHE_MAX = 128 |
| 42 | + |
| 43 | +# Read-only tool name patterns (prefixes that indicate read/list/get operations) |
| 44 | +_READ_ONLY_PREFIXES = ( |
| 45 | + "GMAIL_FETCH", "GMAIL_LIST", "GMAIL_GET", |
| 46 | + "GOOGLECALENDAR_FIND", "GOOGLECALENDAR_LIST", "GOOGLECALENDAR_GET", |
| 47 | + "GOOGLE_CALENDAR_FIND", "GOOGLE_CALENDAR_LIST", "GOOGLE_CALENDAR_GET", |
| 48 | + "REDDIT_GET", "REDDIT_LIST", "REDDIT_FETCH", |
| 49 | + "SLACK_LIST", "SLACK_GET", |
| 50 | + "GITHUB_LIST", "GITHUB_GET", |
| 51 | +) |
| 52 | + |
| 53 | + |
| 54 | +def _is_read_only(tool_name: str) -> bool: |
| 55 | + """Check if a tool call is read-only and safe to cache.""" |
| 56 | + upper = tool_name.upper() |
| 57 | + return any(upper.startswith(p) or upper.endswith(p) for p in _READ_ONLY_PREFIXES) |
| 58 | + |
| 59 | + |
| 60 | +def _cache_key(tool_name: str, arguments: dict) -> str: |
| 61 | + """Generate a stable cache key from tool name + arguments.""" |
| 62 | + raw = f"{tool_name}:{json.dumps(arguments, sort_keys=True)}" |
| 63 | + return hashlib.sha256(raw.encode()).hexdigest() |
| 64 | + |
| 65 | + |
| 66 | +def _get_cached_result(tool_name: str, arguments: dict) -> Optional[dict]: |
| 67 | + """Return cached result if available and not expired.""" |
| 68 | + if not _is_read_only(tool_name): |
| 69 | + return None |
| 70 | + key = _cache_key(tool_name, arguments) |
| 71 | + with _result_cache_lock: |
| 72 | + entry = _result_cache.get(key) |
| 73 | + if entry is None: |
| 74 | + return None |
| 75 | + expires_at, result = entry |
| 76 | + if time.monotonic() > expires_at: |
| 77 | + del _result_cache[key] |
| 78 | + return None |
| 79 | + _result_cache.move_to_end(key) |
| 80 | + logger.debug("Result cache HIT for {}", tool_name) |
| 81 | + return result |
| 82 | + |
| 83 | + |
| 84 | +def _set_cached_result(tool_name: str, arguments: dict, result: dict) -> None: |
| 85 | + """Cache a read-only tool result.""" |
| 86 | + if not _is_read_only(tool_name): |
| 87 | + return |
| 88 | + key = _cache_key(tool_name, arguments) |
| 89 | + with _result_cache_lock: |
| 90 | + _result_cache[key] = (time.monotonic() + _RESULT_CACHE_TTL, result) |
| 91 | + if len(_result_cache) > _RESULT_CACHE_MAX: |
| 92 | + _result_cache.popitem(last=False) |
| 93 | + logger.debug("Result cache SET for {} (ttl={}s)", tool_name, _RESULT_CACHE_TTL) |
| 94 | + |
| 95 | + |
22 | 96 | # --------------------------------------------------------------------------- |
23 | 97 | # Argument resolution for workflows/macros |
24 | 98 | # --------------------------------------------------------------------------- |
@@ -287,65 +361,147 @@ def _lookup_http_config(source_info: dict, database) -> tuple: |
287 | 361 | return None, {} |
288 | 362 |
|
289 | 363 |
|
290 | | -def _http_tool_call(server_url: str, server_headers: dict, |
291 | | - tool_name: str, arguments: dict) -> dict: |
292 | | - """Send a tools/call request to an HTTP MCP server.""" |
| 364 | +def _get_or_create_session(server_url: str, server_headers: dict) -> tuple[httpx.Client, dict]: |
| 365 | + """ |
| 366 | + Get or create a pooled MCP session for the given server URL. |
| 367 | +
|
| 368 | + Reuses existing sessions (with their MCP session ID and persistent |
| 369 | + HTTP connection) to avoid the initialize+notify handshake on every call. |
| 370 | + Sessions expire after _SESSION_TTL seconds. |
| 371 | +
|
| 372 | + Returns (client, headers) ready for tools/call requests. |
| 373 | + """ |
| 374 | + now = time.monotonic() |
| 375 | + |
| 376 | + with _session_lock: |
| 377 | + cached = _sessions.get(server_url) |
| 378 | + if cached and (now - cached["created"]) < _SESSION_TTL: |
| 379 | + return cached["client"], cached["headers"] |
| 380 | + |
| 381 | + # Close old client if expired |
| 382 | + if cached: |
| 383 | + try: |
| 384 | + cached["client"].close() |
| 385 | + except Exception: |
| 386 | + pass |
| 387 | + |
| 388 | + # Create new session outside the lock (handshake is slow) |
293 | 389 | h = { |
294 | 390 | "Content-Type": "application/json", |
295 | 391 | "Accept": "application/json, text/event-stream", |
296 | 392 | **server_headers |
297 | 393 | } |
298 | 394 |
|
299 | | - init_resp = httpx.post( |
300 | | - server_url, headers=h, |
301 | | - json={ |
302 | | - "jsonrpc": "2.0", "id": 1, |
303 | | - "method": "initialize", |
304 | | - "params": { |
305 | | - "protocolVersion": "2024-11-05", |
306 | | - "capabilities": {}, |
307 | | - "clientInfo": {"name": "tooldns-proxy", "version": "1.0.0"} |
308 | | - } |
309 | | - }, |
310 | | - timeout=30 |
311 | | - ) |
312 | | - session_id = init_resp.headers.get("mcp-session-id") |
313 | | - if session_id: |
314 | | - h["mcp-session-id"] = session_id |
315 | | - |
316 | | - httpx.post( |
317 | | - server_url, headers=h, |
318 | | - json={"jsonrpc": "2.0", "method": "notifications/initialized"}, |
319 | | - timeout=10 |
320 | | - ) |
321 | | - |
322 | | - resp = httpx.post( |
323 | | - server_url, headers=h, |
324 | | - json={ |
325 | | - "jsonrpc": "2.0", "id": 2, |
326 | | - "method": "tools/call", |
327 | | - "params": {"name": tool_name, "arguments": arguments} |
328 | | - }, |
329 | | - timeout=60 |
330 | | - ) |
331 | | - resp.raise_for_status() |
| 395 | + client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0)) |
| 396 | + |
| 397 | + try: |
| 398 | + init_resp = client.post( |
| 399 | + server_url, headers=h, |
| 400 | + json={ |
| 401 | + "jsonrpc": "2.0", "id": 1, |
| 402 | + "method": "initialize", |
| 403 | + "params": { |
| 404 | + "protocolVersion": "2024-11-05", |
| 405 | + "capabilities": {}, |
| 406 | + "clientInfo": {"name": "tooldns-proxy", "version": "1.0.0"} |
| 407 | + } |
| 408 | + }, |
| 409 | + ) |
| 410 | + session_id = init_resp.headers.get("mcp-session-id") |
| 411 | + if session_id: |
| 412 | + h["mcp-session-id"] = session_id |
| 413 | + |
| 414 | + client.post( |
| 415 | + server_url, headers=h, |
| 416 | + json={"jsonrpc": "2.0", "method": "notifications/initialized"}, |
| 417 | + timeout=10 |
| 418 | + ) |
| 419 | + except Exception as e: |
| 420 | + logger.warning("MCP session init failed for {}: {}", server_url, e) |
| 421 | + # Continue anyway — some servers don't require handshake |
| 422 | + client.close() |
| 423 | + client = httpx.Client(timeout=httpx.Timeout(60.0, connect=10.0)) |
| 424 | + |
| 425 | + with _session_lock: |
| 426 | + _sessions[server_url] = { |
| 427 | + "session_id": h.get("mcp-session-id"), |
| 428 | + "headers": h, |
| 429 | + "client": client, |
| 430 | + "created": time.monotonic(), |
| 431 | + } |
| 432 | + |
| 433 | + logger.info("MCP session created for {} (session_id={})", server_url[:60], h.get("mcp-session-id", "none")) |
| 434 | + return client, h |
| 435 | + |
| 436 | + |
| 437 | +def _http_tool_call(server_url: str, server_headers: dict, |
| 438 | + tool_name: str, arguments: dict) -> dict: |
| 439 | + """Send a tools/call request to an HTTP MCP server with session pooling and result caching.""" |
| 440 | + |
| 441 | + # Check result cache for read-only tools |
| 442 | + cached = _get_cached_result(tool_name, arguments) |
| 443 | + if cached is not None: |
| 444 | + return cached |
| 445 | + |
| 446 | + t0 = time.monotonic() |
| 447 | + client, h = _get_or_create_session(server_url, server_headers) |
| 448 | + |
| 449 | + try: |
| 450 | + resp = client.post( |
| 451 | + server_url, headers=h, |
| 452 | + json={ |
| 453 | + "jsonrpc": "2.0", "id": 2, |
| 454 | + "method": "tools/call", |
| 455 | + "params": {"name": tool_name, "arguments": arguments} |
| 456 | + }, |
| 457 | + ) |
| 458 | + resp.raise_for_status() |
| 459 | + except (httpx.RemoteProtocolError, httpx.ConnectError, httpx.ReadError) as e: |
| 460 | + # Session may have expired server-side — retry with fresh session |
| 461 | + logger.warning("MCP session error, retrying with fresh session: {}", e) |
| 462 | + with _session_lock: |
| 463 | + _sessions.pop(server_url, None) |
| 464 | + try: |
| 465 | + client.close() |
| 466 | + except Exception: |
| 467 | + pass |
| 468 | + client, h = _get_or_create_session(server_url, server_headers) |
| 469 | + resp = client.post( |
| 470 | + server_url, headers=h, |
| 471 | + json={ |
| 472 | + "jsonrpc": "2.0", "id": 2, |
| 473 | + "method": "tools/call", |
| 474 | + "params": {"name": tool_name, "arguments": arguments} |
| 475 | + }, |
| 476 | + ) |
| 477 | + resp.raise_for_status() |
| 478 | + |
| 479 | + elapsed = time.monotonic() - t0 |
| 480 | + logger.info("Tool {} executed in {:.1f}s", tool_name, elapsed) |
332 | 481 |
|
333 | 482 | content_type = resp.headers.get("content-type", "") |
334 | 483 | if "text/event-stream" in content_type: |
| 484 | + result = None |
335 | 485 | for line in resp.text.split("\n"): |
336 | 486 | line = line.strip() |
337 | 487 | if line.startswith("data:"): |
338 | 488 | data = line[5:].strip() |
339 | 489 | if data: |
340 | 490 | try: |
341 | 491 | parsed = json.loads(data) |
342 | | - return parsed.get("result", parsed) |
| 492 | + result = parsed.get("result", parsed) |
| 493 | + break |
343 | 494 | except Exception: |
344 | 495 | continue |
345 | | - return {"raw": resp.text} |
| 496 | + if result is None: |
| 497 | + result = {"raw": resp.text} |
346 | 498 | else: |
347 | 499 | data = resp.json() |
348 | | - return data.get("result", data) |
| 500 | + result = data.get("result", data) |
| 501 | + |
| 502 | + # Cache read-only results |
| 503 | + _set_cached_result(tool_name, arguments, result) |
| 504 | + return result |
349 | 505 |
|
350 | 506 |
|
351 | 507 | def _resolve_env(val): |
|
0 commit comments