Skip to content

Commit d4cb54d

Browse files
gaodan-fang and claude committed
feat(mcp): add user facts tools, SSE transport hardening, and warmup
Add store_user_facts/retrieve_user_facts MCP tools for durable user identity. Harden SSE transport by replacing private Starlette API (request._send) with public ASGI-level SseEndpoint class, suppressing CancelledError and teardown assertion races, and bumping graceful shutdown timeout to 3s. Add optional warmup on SSE server boot to reduce first-tool-call latency. Exclude build/ from mypy to fix pre-existing stale artifact errors. Test coverage: warmup disabled/failure paths, auth-enabled SSE route wiring, user facts store/retrieve/validation. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent ea24a31 commit d4cb54d

8 files changed

Lines changed: 386 additions & 20 deletions

File tree

altk_evolve/frontend/mcp/__main__.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,30 @@
11
import argparse
22
import logging
3+
import os
34
import sys
45
import threading
56
import uvicorn
67

7-
from altk_evolve.frontend.mcp.mcp_server import mcp, app
8+
from altk_evolve.frontend.mcp.mcp_server import app, get_client, mcp
89
from altk_evolve.frontend.mcp.http_transport import create_resilient_sse_app
910

1011
logger = logging.getLogger("evolve-mcp")
1112

1213

14+
def _is_truthy_env(name: str, default: bool) -> bool:
    """Interpret the environment variable *name* as a boolean flag.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is unset or blank.

    Returns:
        ``default`` when the variable is missing or contains only whitespace.
        Otherwise ``False`` for the case-insensitive, whitespace-trimmed
        spellings ``0``, ``false``, ``no`` and ``off``, and ``True`` for any
        other value.
    """
    raw = os.getenv(name)
    if raw is None:
        return default

    value = raw.strip().lower()
    if not value:
        # An exported-but-empty variable (``FLAG=``) expresses no preference,
        # so it must not silently flip a default of False to True.
        return default
    return value not in {"0", "false", "no", "off"}
19+
20+
21+
def warmup_mcp_runtime() -> None:
    """Eagerly initialise the MCP backend client.

    Calls ``get_client()`` once, bracketed by info-level log lines, so the
    initialisation cost is paid at startup rather than on the first tool call.
    The returned client is discarded; any exception from ``get_client()``
    propagates to the caller.
    """
    logger.info("Warming up MCP runtime...")
    _ = get_client()  # result intentionally unused; only the call matters here
    logger.info("MCP runtime warmup complete")
26+
27+
1328
def _build_parser() -> argparse.ArgumentParser:
1429
parser = argparse.ArgumentParser(description="Run the Evolve MCP server")
1530
parser.add_argument(
@@ -42,9 +57,22 @@ def run_api_server():
4257

4358

4459
def run_sse_server(host: str, port: int) -> None:
45-
"""Run the MCP server over SSE with disconnect-safe transport handling."""
46-
sse_app = create_resilient_sse_app(mcp)
47-
uvicorn.run(sse_app, host=host, port=port, log_level="warning")
60+
"""Run the MCP server over SSE with disconnect-tolerant teardown."""
61+
if _is_truthy_env("EVOLVE_MCP_WARMUP", True):
62+
try:
63+
warmup_mcp_runtime()
64+
except Exception as exc:
65+
# Keep startup resilient: failed warmup should not block server boot.
66+
logger.warning("MCP warmup failed; continuing without warmup: %s", exc)
67+
68+
uvicorn.run(
69+
create_resilient_sse_app(mcp),
70+
host=host,
71+
port=port,
72+
lifespan="on",
73+
timeout_graceful_shutdown=3,
74+
ws="websockets-sansio",
75+
)
4876

4977

5078
def main():
@@ -61,7 +89,7 @@ def main():
6189
# Start FastMCP using stdio (which blocks)
6290
mcp.run()
6391
else:
64-
run_sse_server(host=args.host, port=args.port)
92+
run_sse_server(args.host, args.port)
6593
except KeyboardInterrupt:
6694
logger.info("MCP server stopped by user (KeyboardInterrupt)")
6795
sys.exit(0)

altk_evolve/frontend/mcp/http_transport.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import logging
45
from collections.abc import AsyncGenerator
56
from contextlib import asynccontextmanager
@@ -10,7 +11,6 @@
1011
from mcp.server.auth.routes import build_resource_metadata_url
1112
from mcp.server.lowlevel.server import LifespanResultT
1213
from mcp.server.sse import SseServerTransport
13-
from starlette.requests import Request
1414
from starlette.responses import Response
1515
from starlette.routing import BaseRoute, Mount, Route
1616

@@ -28,6 +28,13 @@ def _is_benign_disconnect_exception(exc: BaseException) -> bool:
2828
"""Return True when an exception only represents a dropped SSE client."""
2929
if isinstance(exc, (anyio.ClosedResourceError, anyio.BrokenResourceError)):
3030
return True
31+
if isinstance(exc, asyncio.CancelledError):
32+
# Uvicorn cancels outstanding request tasks during Ctrl+C shutdown.
33+
return True
34+
if isinstance(exc, AssertionError) and str(exc) == "Request already responded to":
35+
# MCP low-level request responder can assert during SSE teardown races
36+
# when cancellation/close wins over the normal response path.
37+
return True
3138

3239
if isinstance(exc, BaseExceptionGroup):
3340
return len(exc.exceptions) > 0 and all(_is_benign_disconnect_exception(child) for child in exc.exceptions)
@@ -59,6 +66,23 @@ async def _run_sse_session(
5966
return True
6067

6168

69+
async def _handle_sse(
    server: FastMCP[LifespanResultT],
    sse: SseServerTransport,
    scope,
    receive,
    send,
) -> None:
    """
    Run one SSE session as a raw ASGI call and return nothing afterwards.

    The HTTP response lifecycle belongs to `connect_sse(...)(scope, receive, send)`.
    Emitting a second Response once it has finished can race with teardown and
    trip duplicate MCP request completion assertions, so the outcome reported by
    `_run_sse_session` is deliberately dropped here.
    """
    _ = await _run_sse_session(server, sse, scope, receive, send)
    return None
84+
85+
6286
def create_resilient_sse_app(
6387
server: FastMCP[LifespanResultT],
6488
message_path: str | None = None,
@@ -78,9 +102,27 @@ def create_resilient_sse_app(
78102

79103
sse = SseServerTransport(message_path)
80104

81-
async def handle_sse(scope, receive, send) -> Response:
82-
await _run_sse_session(server, sse, scope, receive, send)
83-
return Response(status_code=204)
105+
async def handle_sse(scope, receive, send) -> None:
106+
await _handle_sse(server, sse, scope, receive, send)
107+
108+
class SseEndpoint:
    """ASGI callable for the SSE route.

    Forwards the connection to ``handle_sse`` through a ``send`` wrapper that
    records whether an ``http.response.start`` message was emitted. If the
    handler returns without starting a response, an empty 204 is sent instead.
    """

    async def __call__(self, scope, receive, send) -> None:
        started = False

        async def tracked_send(message) -> None:
            # Remember the first response-start message, then pass everything through.
            nonlocal started
            started = started or message.get("type") == "http.response.start"
            await send(message)

        await handle_sse(scope, receive, tracked_send)
        if started:
            return

        # Nothing was written by the SSE handler: close the request with 204.
        fallback = Response(status_code=204)
        await fallback(scope, receive, send)
124+
125+
sse_endpoint = SseEndpoint()
84126

85127
if auth:
86128
auth_middleware = auth.get_middleware()
@@ -95,7 +137,7 @@ async def handle_sse(scope, receive, send) -> Response:
95137
Route(
96138
sse_path,
97139
endpoint=RequireAuthMiddleware(
98-
handle_sse,
140+
sse_endpoint,
99141
auth.required_scopes,
100142
resource_metadata_url,
101143
),
@@ -113,10 +155,6 @@ async def handle_sse(scope, receive, send) -> Response:
113155
)
114156
)
115157
else:
116-
117-
async def sse_endpoint(request: Request) -> Response:
118-
return await handle_sse(request.scope, request.receive, request._send)
119-
120158
server_routes.append(Route(sse_path, endpoint=sse_endpoint, methods=["GET"]))
121159
server_routes.append(Mount(message_path, app=sse.handle_post_message))
122160

altk_evolve/frontend/mcp/mcp_server.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import threading
1010
import uuid
1111
import os
12+
from typing import Any
1213

1314
from fastmcp import FastMCP
1415
from fastapi import FastAPI
@@ -214,6 +215,22 @@ def get_entities_logic(
214215
return "\n".join(response_lines)
215216

216217

218+
def _parse_metadata(metadata: str | None) -> dict[str, Any]:
219+
if not metadata:
220+
return {}
221+
222+
try:
223+
parsed = json.loads(metadata)
224+
except json.JSONDecodeError as e:
225+
logger.exception(f"Invalid JSON in metadata parameter: {str(e)}")
226+
raise ValueError(f"Failed to parse metadata: {str(e)}") from e
227+
228+
if not isinstance(parsed, dict):
229+
raise ValueError("Metadata must decode to a JSON object")
230+
231+
return parsed
232+
233+
217234
@mcp.tool()
218235
def get_entities(
219236
task: str,
@@ -261,6 +278,74 @@ def get_guidelines(
261278
return get_entities_logic(task, "guideline", user_id=user_id, namespace_id=namespace_id, session_id=session_id)
262279

263280

281+
@mcp.tool()
def store_user_facts(
    user_id: str,
    message: str,
    metadata: str | None = None,
    enable_conflict_resolution: bool = False,
) -> str:
    """Extract and store user facts/preferences for a durable user identity.

    ``metadata`` is an optional JSON object encoded as a string. When it cannot
    be parsed, a JSON error payload (``error``, ``message``,
    ``invalid_metadata``) is returned and nothing is stored. Otherwise the
    client's ``store_user_facts`` is called in the configured namespace and the
    result is returned as JSON with ``user_id``, ``stored_count`` and the
    serialized ``updates``.
    """
    try:
        parsed_metadata = _parse_metadata(metadata)
    except ValueError as err:
        failure = {
            "error": "Invalid JSON",
            "message": str(err),
            "invalid_metadata": metadata,
        }
        return json.dumps(failure)

    client = get_client()
    results = client.store_user_facts(
        namespace_id=evolve_config.namespace_id,
        message=message,
        user_id=user_id,
        metadata=parsed_metadata,
        enable_conflict_resolution=enable_conflict_resolution,
    )

    # Flatten each update into plain fields so the payload can be JSON-encoded.
    records = []
    for item in results:
        records.append(
            {
                "event": item.event,
                "id": item.id,
                "type": item.type,
                "content": item.content,
                "metadata": item.metadata,
            }
        )

    summary = {
        "user_id": user_id,
        "stored_count": len(records),
        "updates": records,
    }
    return json.dumps(summary)
326+
327+
328+
@mcp.tool()
def retrieve_user_facts(user_id: str, query: str | None = None, limit: int = 5) -> str:
    """Retrieve categorized user facts/preferences for a durable user identity.

    Calls the client's ``retrieve_user_facts`` in the configured namespace and
    returns a JSON string with ``user_id``, the ``query`` as given, the
    ``categories`` mapping returned by the client, and ``matched_count``, the
    total number of items across all categories.
    """
    client = get_client()
    facts_by_category = client.retrieve_user_facts(
        namespace_id=evolve_config.namespace_id,
        user_id=user_id,
        query=query,
        limit=limit,
    )

    total = 0
    for entries in facts_by_category.values():
        total += len(entries)

    payload = {
        "user_id": user_id,
        "query": query,
        "matched_count": total,
        "categories": facts_by_category,
    }
    return json.dumps(payload)
347+
348+
264349
@mcp.tool()
265350
def save_trajectory(
266351
trajectory_data: str,

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ warn_unused_configs = true
159159
disallow_untyped_defs = false
160160
explicit_package_bases = true
161161
exclude = [
162+
"build/",
162163
"platform-integrations/",
163164
"examples/",
164165
]

tests/e2e/test_mcp.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,34 @@ async def test_create_entity_with_invalid_json_metadata(mcp):
219219
assert result["error"] == "Invalid JSON"
220220
assert "message" in result
221221
assert "invalid_metadata" in result
222+
223+
224+
@pytest.mark.e2e
async def test_store_and_retrieve_user_facts(mcp):
    """Store a preference for a user, then retrieve facts for the same user.

    Checks the JSON envelope of both tools: the store response reports the user
    and at least one stored update; the retrieve response echoes the user and
    query, and its ``matched_count`` agrees with the returned ``categories``.
    """
    async with Client(transport=mcp) as evolve_mcp:
        store_response = await evolve_mcp.call_tool_mcp(
            "store_user_facts",
            {
                "user_id": "user-123",
                "message": "I prefer concise answers with bullet points.",
                "metadata": json.dumps({"source": "cuga-lite"}),
                "enable_conflict_resolution": False,
            },
        )
        stored = json.loads(store_response.content[0].text)
        assert stored["user_id"] == "user-123"
        assert stored["stored_count"] >= 1

        retrieve_response = await evolve_mcp.call_tool_mcp(
            "retrieve_user_facts",
            {
                "user_id": "user-123",
                "query": "How should I format the answer?",
                "limit": 5,
            },
        )
        retrieved = json.loads(retrieve_response.content[0].text)

        assert retrieved["user_id"] == "user-123"
        assert retrieved["query"] == "How should I format the answer?"
        assert "categories" in retrieved
        # A bare ``>= 0`` check can never fail; verify the count is consistent
        # with the categories actually returned instead.
        assert retrieved["matched_count"] == sum(len(items) for items in retrieved["categories"].values())

0 commit comments

Comments
 (0)