Skip to content

Commit 7d6f5ef

Browse files
gaodan-fang and claude committed
feat(mcp): add user facts tools, SSE transport hardening, and warmup
Add store_user_facts/retrieve_user_facts MCP tools for durable user identity. Harden SSE transport by replacing private Starlette API (request._send) with public ASGI-level SseEndpoint class, suppressing CancelledError and teardown assertion races, and bumping graceful shutdown timeout to 3s. Add optional warmup on SSE server boot to reduce first-tool-call latency. Exclude build/ from mypy to fix pre-existing stale artifact errors. Test coverage: warmup disabled/failure paths, auth-enabled SSE route wiring, user facts store/retrieve/validation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 42316f3 commit 7d6f5ef

8 files changed

Lines changed: 380 additions & 19 deletions

File tree

altk_evolve/frontend/mcp/__main__.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,30 @@
11
import argparse
22
import logging
3+
import os
34
import sys
45
import threading
56
import uvicorn
67

7-
from altk_evolve.frontend.mcp.mcp_server import mcp, app
8+
from altk_evolve.frontend.mcp.mcp_server import app, get_client, mcp
89
from altk_evolve.frontend.mcp.http_transport import create_resilient_sse_app
910

1011
logger = logging.getLogger("evolve-mcp")
1112

1213

14+
def _is_truthy_env(name: str, default: bool) -> bool:
15+
raw = os.getenv(name)
16+
if raw is None:
17+
return default
18+
return raw.strip().lower() not in {"0", "false", "no", "off"}
19+
20+
21+
def warmup_mcp_runtime() -> None:
    """Initialize the MCP backend ahead of time.

    Calls ``get_client()`` once during startup so that the first tool call
    does not have to pay for that initialization. Progress is reported on the
    module logger before and after the call; any exception raised by
    ``get_client()`` propagates to the caller.
    """
    logger.info("Warming up MCP runtime...")
    # Only the side effect of the call matters here; the result is discarded.
    get_client()
    logger.info("MCP runtime warmup complete")
26+
27+
1328
def _build_parser() -> argparse.ArgumentParser:
1429
parser = argparse.ArgumentParser(description="Run the Evolve MCP server")
1530
parser.add_argument(
@@ -42,9 +57,22 @@ def run_api_server():
4257

4358

4459
def run_sse_server(host: str, port: int) -> None:
    """Serve the MCP server over SSE, tolerating client disconnects on teardown.

    Unless the ``EVOLVE_MCP_WARMUP`` environment variable disables it (it is
    enabled by default), the runtime is warmed up before the server starts.
    A failing warmup is logged and otherwise ignored.

    Args:
        host: Interface address passed to uvicorn.
        port: TCP port passed to uvicorn.
    """
    warmup_enabled = _is_truthy_env("EVOLVE_MCP_WARMUP", True)
    if warmup_enabled:
        try:
            warmup_mcp_runtime()
        except Exception as exc:
            # Best effort only: a broken warmup must never prevent the server
            # from booting, so log it and carry on.
            logger.warning("MCP warmup failed; continuing without warmup: %s", exc)

    sse_app = create_resilient_sse_app(mcp)
    uvicorn.run(
        sse_app,
        host=host,
        port=port,
        lifespan="on",
        # Graceful shutdown timeout, in seconds.
        timeout_graceful_shutdown=3,
        ws="websockets-sansio",
    )
4876

4977

5078
def main():
@@ -61,7 +89,7 @@ def main():
6189
# Start FastMCP using stdio (which blocks)
6290
mcp.run()
6391
else:
64-
run_sse_server(host=args.host, port=args.port)
92+
run_sse_server(args.host, args.port)
6593
except KeyboardInterrupt:
6694
logger.info("MCP server stopped by user (KeyboardInterrupt)")
6795
sys.exit(0)

altk_evolve/frontend/mcp/http_transport.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import logging
45
from collections.abc import AsyncGenerator
56
from contextlib import asynccontextmanager
@@ -10,7 +11,6 @@
1011
from mcp.server.auth.routes import build_resource_metadata_url
1112
from mcp.server.lowlevel.server import LifespanResultT
1213
from mcp.server.sse import SseServerTransport
13-
from starlette.requests import Request
1414
from starlette.responses import Response
1515
from starlette.routing import BaseRoute, Mount, Route
1616

@@ -28,6 +28,13 @@ def _is_benign_disconnect_exception(exc: BaseException) -> bool:
2828
"""Return True when an exception only represents a dropped SSE client."""
2929
if isinstance(exc, (anyio.ClosedResourceError, anyio.BrokenResourceError)):
3030
return True
31+
if isinstance(exc, asyncio.CancelledError):
32+
# Uvicorn cancels outstanding request tasks during Ctrl+C shutdown.
33+
return True
34+
if isinstance(exc, AssertionError) and str(exc) == "Request already responded to":
35+
# MCP low-level request responder can assert during SSE teardown races
36+
# when cancellation/close wins over the normal response path.
37+
return True
3138

3239
if isinstance(exc, BaseExceptionGroup):
3340
return len(exc.exceptions) > 0 and all(_is_benign_disconnect_exception(child) for child in exc.exceptions)
@@ -59,6 +66,23 @@ async def _run_sse_session(
5966
return True
6067

6168

69+
async def _handle_sse(
    server: FastMCP[LifespanResultT],
    sse: SseServerTransport,
    scope,
    receive,
    send,
) -> None:
    """
    Run one SSE session as a raw ASGI call, without emitting a trailing response.

    The HTTP response lifecycle belongs to `connect_sse(...)(scope, receive, send)`.
    Sending another Response once it has finished can race with teardown and
    provoke duplicate MCP request completion assertions, so this coroutine
    simply awaits the session and returns nothing.
    """
    # The session helper's return value is intentionally not propagated.
    await _run_sse_session(server, sse, scope, receive, send)
84+
85+
6286
def create_resilient_sse_app(
6387
server: FastMCP[LifespanResultT],
6488
message_path: str | None = None,
@@ -78,9 +102,27 @@ def create_resilient_sse_app(
78102

79103
sse = SseServerTransport(message_path)
80104

81-
async def handle_sse(scope, receive, send) -> Response:
82-
await _run_sse_session(server, sse, scope, receive, send)
83-
return Response(status_code=204)
105+
async def handle_sse(scope, receive, send) -> None:
106+
await _handle_sse(server, sse, scope, receive, send)
107+
108+
class SseEndpoint:
    """ASGI callable around handle_sse that remembers if a response began."""

    async def __call__(self, scope, receive, send) -> None:
        started = False

        async def tracked_send(message) -> None:
            # Forward every ASGI message unchanged, noting the response start.
            nonlocal started
            if message.get("type") == "http.response.start":
                started = True
            await send(message)

        await handle_sse(scope, receive, tracked_send)

        if started:
            # The SSE session already owned the HTTP response; add nothing.
            return

        # Nothing was sent at all, so close the request with an empty 204.
        fallback = Response(status_code=204)
        await fallback(scope, receive, send)
124+
125+
sse_endpoint = SseEndpoint()
84126

85127
if auth:
86128
auth_middleware = auth.get_middleware()
@@ -95,7 +137,7 @@ async def handle_sse(scope, receive, send) -> Response:
95137
Route(
96138
sse_path,
97139
endpoint=RequireAuthMiddleware(
98-
handle_sse,
140+
sse_endpoint,
99141
auth.required_scopes,
100142
resource_metadata_url,
101143
),
@@ -113,10 +155,6 @@ async def handle_sse(scope, receive, send) -> Response:
113155
)
114156
)
115157
else:
116-
117-
async def sse_endpoint(request: Request) -> Response:
118-
return await handle_sse(request.scope, request.receive, request._send)
119-
120158
server_routes.append(Route(sse_path, endpoint=sse_endpoint, methods=["GET"]))
121159
server_routes.append(Mount(message_path, app=sse.handle_post_message))
122160

altk_evolve/frontend/mcp/mcp_server.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import threading
1010
import uuid
1111
import os
12+
from typing import Any
1213

1314
from fastmcp import FastMCP
1415
from fastapi import FastAPI
@@ -160,6 +161,22 @@ def get_entities_logic(task: str, entity_type: str = "guideline", include_public
160161
return "\n".join(response_lines)
161162

162163

164+
def _parse_metadata(metadata: str | None) -> dict[str, Any]:
165+
if not metadata:
166+
return {}
167+
168+
try:
169+
parsed = json.loads(metadata)
170+
except json.JSONDecodeError as e:
171+
logger.exception(f"Invalid JSON in metadata parameter: {str(e)}")
172+
raise ValueError(f"Failed to parse metadata: {str(e)}") from e
173+
174+
if not isinstance(parsed, dict):
175+
raise ValueError("Metadata must decode to a JSON object")
176+
177+
return parsed
178+
179+
163180
@mcp.tool()
164181
def get_entities(task: str, entity_type: str = "guideline", include_public: bool = False, limit: int = 10) -> str:
165182
"""
@@ -188,6 +205,74 @@ def get_guidelines(task: str) -> str:
188205
return get_entities_logic(task, "guideline")
189206

190207

208+
@mcp.tool()
def store_user_facts(
    user_id: str,
    message: str,
    metadata: str | None = None,
    enable_conflict_resolution: bool = False,
) -> str:
    """Extract user facts/preferences from a message and persist them for a durable user identity.

    Returns a JSON string. On success it holds ``user_id``, ``stored_count``
    and the list of ``updates`` reported by the client. When ``metadata``
    cannot be parsed it holds ``error``, ``message`` and ``invalid_metadata``.
    """
    try:
        extra = _parse_metadata(metadata)
    except ValueError as parse_error:
        failure = {
            "error": "Invalid JSON",
            "message": str(parse_error),
            "invalid_metadata": metadata,
        }
        return json.dumps(failure)

    client = get_client()
    results = client.store_user_facts(
        namespace_id=evolve_config.namespace_id,
        message=message,
        user_id=user_id,
        metadata=extra,
        enable_conflict_resolution=enable_conflict_resolution,
    )

    # Flatten each returned update object into a plain JSON-ready dict.
    exported_fields = ("event", "id", "type", "content", "metadata")
    updates_payload = []
    for item in results:
        updates_payload.append({field: getattr(item, field) for field in exported_fields})

    summary = {
        "user_id": user_id,
        "stored_count": len(updates_payload),
        "updates": updates_payload,
    }
    return json.dumps(summary)
253+
254+
255+
@mcp.tool()
def retrieve_user_facts(user_id: str, query: str | None = None, limit: int = 5) -> str:
    """Look up categorized user facts/preferences for a durable user identity.

    Returns a JSON string holding ``user_id``, the ``query`` as given,
    ``matched_count`` (the combined length of every category's items) and the
    ``categories`` mapping returned by the client.
    """
    client = get_client()
    facts_by_category = client.retrieve_user_facts(
        namespace_id=evolve_config.namespace_id,
        user_id=user_id,
        query=query,
        limit=limit,
    )

    # Count items across all categories.
    total_matches = 0
    for entries in facts_by_category.values():
        total_matches += len(entries)

    result = {
        "user_id": user_id,
        "query": query,
        "matched_count": total_matches,
        "categories": facts_by_category,
    }
    return json.dumps(result)
274+
275+
191276
@mcp.tool()
192277
def save_trajectory(trajectory_data: str, task_id: str | None = None, owner_id: str | None = None) -> list[RecordedEntity]:
193278
"""

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ warn_unused_configs = true
159159
disallow_untyped_defs = false
160160
explicit_package_bases = true
161161
exclude = [
162+
"build/",
162163
"platform-integrations/",
163164
"examples/",
164165
]

tests/e2e/test_mcp.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,34 @@ async def test_create_entity_with_invalid_json_metadata(mcp):
219219
assert result["error"] == "Invalid JSON"
220220
assert "message" in result
221221
assert "invalid_metadata" in result
222+
223+
224+
@pytest.mark.e2e
async def test_store_and_retrieve_user_facts(mcp):
    """Store a fact for a user through MCP, then read facts back for the same user."""
    async with Client(transport=mcp) as evolve_mcp:
        store_args = {
            "user_id": "user-123",
            "message": "I prefer concise answers with bullet points.",
            "metadata": json.dumps({"source": "cuga-lite"}),
            "enable_conflict_resolution": False,
        }
        store_response = await evolve_mcp.call_tool_mcp("store_user_facts", store_args)
        stored = json.loads(store_response.content[0].text)
        assert stored["user_id"] == "user-123"
        assert stored["stored_count"] >= 1

        retrieve_args = {
            "user_id": "user-123",
            "query": "How should I format the answer?",
            "limit": 5,
        }
        retrieve_response = await evolve_mcp.call_tool_mcp("retrieve_user_facts", retrieve_args)
        retrieved = json.loads(retrieve_response.content[0].text)

        # Retrieval may legitimately match nothing; only the response shape is pinned.
        assert retrieved["user_id"] == "user-123"
        assert "categories" in retrieved
        assert retrieved["matched_count"] >= 0

0 commit comments

Comments
 (0)