Skip to content

Commit e83668c

Browse files
authored
feat(mcp): user facts tools, SSE transport hardening, and warmup (#238)
* feat(mcp): add user facts tools, SSE transport hardening, and warmup Add store_user_facts/retrieve_user_facts MCP tools for durable user identity. Harden SSE transport by replacing private Starlette API (request._send) with public ASGI-level SseEndpoint class, suppressing CancelledError and teardown assertion races, and bumping graceful shutdown timeout to 3s. Add optional warmup on SSE server boot to reduce first-tool-call latency. Exclude build/ from mypy to fix pre-existing stale artifact errors. Test coverage: warmup disabled/failure paths, auth-enabled SSE route wiring, user facts store/retrieve/validation.
1 parent 0134371 commit e83668c

9 files changed

Lines changed: 413 additions & 22 deletions

File tree

altk_evolve/frontend/mcp/__main__.py

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,30 @@
11
import argparse
22
import logging
3+
import os
34
import sys
45
import threading
56
import uvicorn
67

7-
from altk_evolve.frontend.mcp.mcp_server import mcp, app
8+
from altk_evolve.frontend.mcp.mcp_server import app, get_client, mcp
89
from altk_evolve.frontend.mcp.http_transport import create_resilient_sse_app
910

1011
logger = logging.getLogger("evolve-mcp")
1112

1213

14+
def _is_truthy_env(name: str, default: bool) -> bool:
15+
raw = os.getenv(name)
16+
if raw is None:
17+
return default
18+
return raw.strip().lower() not in {"0", "false", "no", "off"}
19+
20+
21+
def warmup_mcp_runtime() -> None:
22+
"""Pre-initialize MCP backend state to reduce first-tool-call latency."""
23+
logger.info("Warming up MCP runtime...")
24+
get_client()
25+
logger.info("MCP runtime warmup complete")
26+
27+
1328
def _build_parser() -> argparse.ArgumentParser:
1429
parser = argparse.ArgumentParser(description="Run the Evolve MCP server")
1530
parser.add_argument(
@@ -42,9 +57,22 @@ def run_api_server():
4257

4358

4459
def run_sse_server(host: str, port: int) -> None:
45-
"""Run the MCP server over SSE with disconnect-safe transport handling."""
46-
sse_app = create_resilient_sse_app(mcp)
47-
uvicorn.run(sse_app, host=host, port=port, log_level="warning")
60+
"""Run the MCP server over SSE with disconnect-tolerant teardown."""
61+
if _is_truthy_env("EVOLVE_MCP_WARMUP", True):
62+
try:
63+
warmup_mcp_runtime()
64+
except Exception as exc:
65+
# Keep startup resilient: failed warmup should not block server boot.
66+
logger.warning("MCP warmup failed; continuing without warmup: %s", exc)
67+
68+
uvicorn.run(
69+
create_resilient_sse_app(mcp),
70+
host=host,
71+
port=port,
72+
lifespan="on",
73+
timeout_graceful_shutdown=3,
74+
ws="websockets-sansio",
75+
)
4876

4977

5078
def main():
@@ -61,7 +89,7 @@ def main():
6189
# Start FastMCP using stdio (which blocks)
6290
mcp.run()
6391
else:
64-
run_sse_server(host=args.host, port=args.port)
92+
run_sse_server(args.host, args.port)
6593
except KeyboardInterrupt:
6694
logger.info("MCP server stopped by user (KeyboardInterrupt)")
6795
sys.exit(0)

altk_evolve/frontend/mcp/http_transport.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import logging
45
from collections.abc import AsyncGenerator
56
from contextlib import asynccontextmanager
@@ -10,7 +11,6 @@
1011
from mcp.server.auth.routes import build_resource_metadata_url
1112
from mcp.server.lowlevel.server import LifespanResultT
1213
from mcp.server.sse import SseServerTransport
13-
from starlette.requests import Request
1414
from starlette.responses import Response
1515
from starlette.routing import BaseRoute, Mount, Route
1616

@@ -28,6 +28,13 @@ def _is_benign_disconnect_exception(exc: BaseException) -> bool:
2828
"""Return True when an exception only represents a dropped SSE client."""
2929
if isinstance(exc, (anyio.ClosedResourceError, anyio.BrokenResourceError)):
3030
return True
31+
if isinstance(exc, asyncio.CancelledError):
32+
# Uvicorn cancels outstanding request tasks during Ctrl+C shutdown.
33+
return True
34+
if isinstance(exc, AssertionError) and str(exc) == "Request already responded to":
35+
# MCP low-level request responder can assert during SSE teardown races
36+
# when cancellation/close wins over the normal response path.
37+
return True
3138

3239
if isinstance(exc, BaseExceptionGroup):
3340
return len(exc.exceptions) > 0 and all(_is_benign_disconnect_exception(child) for child in exc.exceptions)
@@ -50,7 +57,7 @@ async def _run_sse_session(
5057
streams[1],
5158
server._mcp_server.create_initialization_options(),
5259
)
53-
except Exception as exc:
60+
except BaseException as exc:
5461
if _is_benign_disconnect_exception(exc):
5562
logger.debug("Suppressing benign SSE disconnect during response flush")
5663
return False
@@ -59,6 +66,25 @@ async def _run_sse_session(
5966
return True
6067

6168

69+
async def _handle_sse(
70+
server: FastMCP[LifespanResultT],
71+
sse: SseServerTransport,
72+
scope,
73+
receive,
74+
send,
75+
) -> bool:
76+
"""
77+
Serve SSE directly as ASGI and avoid sending any follow-up HTTP response.
78+
79+
`connect_sse(...)(scope, receive, send)` owns the HTTP response lifecycle.
80+
Returning an additional Response after it exits can race with teardown and
81+
trigger duplicate MCP request completion assertions.
82+
83+
Returns False when the session ended due to a benign client disconnect.
84+
"""
85+
return await _run_sse_session(server, sse, scope, receive, send)
86+
87+
6288
def create_resilient_sse_app(
6389
server: FastMCP[LifespanResultT],
6490
message_path: str | None = None,
@@ -78,9 +104,27 @@ def create_resilient_sse_app(
78104

79105
sse = SseServerTransport(message_path)
80106

81-
async def handle_sse(scope, receive, send) -> Response:
82-
await _run_sse_session(server, sse, scope, receive, send)
83-
return Response(status_code=204)
107+
async def handle_sse(scope, receive, send) -> bool:
108+
return await _handle_sse(server, sse, scope, receive, send)
109+
110+
class SseEndpoint:
111+
"""ASGI app wrapping handle_sse that tracks whether a response was started."""
112+
113+
async def __call__(self, scope, receive, send) -> None:
114+
response_started = False
115+
116+
async def tracked_send(message) -> None:
117+
nonlocal response_started
118+
if message.get("type") == "http.response.start":
119+
response_started = True
120+
await send(message)
121+
122+
completed = await handle_sse(scope, receive, tracked_send)
123+
if completed and not response_started:
124+
response = Response(status_code=204)
125+
await response(scope, receive, send)
126+
127+
sse_endpoint = SseEndpoint()
84128

85129
if auth:
86130
auth_middleware = auth.get_middleware()
@@ -95,7 +139,7 @@ async def handle_sse(scope, receive, send) -> Response:
95139
Route(
96140
sse_path,
97141
endpoint=RequireAuthMiddleware(
98-
handle_sse,
142+
sse_endpoint,
99143
auth.required_scopes,
100144
resource_metadata_url,
101145
),
@@ -113,10 +157,6 @@ async def handle_sse(scope, receive, send) -> Response:
113157
)
114158
)
115159
else:
116-
117-
async def sse_endpoint(request: Request) -> Response:
118-
return await handle_sse(request.scope, request.receive, request._send)
119-
120160
server_routes.append(Route(sse_path, endpoint=sse_endpoint, methods=["GET"]))
121161
server_routes.append(Mount(message_path, app=sse.handle_post_message))
122162

altk_evolve/frontend/mcp/mcp_server.py

Lines changed: 86 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import threading
1010
import uuid
1111
import os
12+
from typing import Any
1213

1314
from fastmcp import FastMCP
1415
from fastapi import FastAPI
@@ -214,6 +215,22 @@ def get_entities_logic(
214215
return "\n".join(response_lines)
215216

216217

218+
def _parse_metadata(metadata: str | None) -> dict[str, Any]:
219+
if not metadata:
220+
return {}
221+
222+
try:
223+
parsed = json.loads(metadata)
224+
except json.JSONDecodeError as e:
225+
logger.warning("Invalid JSON in metadata parameter: %s", e)
226+
raise ValueError(f"Failed to parse metadata: {str(e)}") from e
227+
228+
if not isinstance(parsed, dict):
229+
raise ValueError("Metadata must decode to a JSON object")
230+
231+
return parsed
232+
233+
217234
@mcp.tool()
218235
def get_entities(
219236
task: str,
@@ -261,6 +278,74 @@ def get_guidelines(
261278
return get_entities_logic(task, "guideline", user_id=user_id, namespace_id=namespace_id, session_id=session_id)
262279

263280

281+
@mcp.tool()
282+
def store_user_facts(
283+
user_id: str,
284+
message: str,
285+
metadata: str | None = None,
286+
enable_conflict_resolution: bool = False,
287+
) -> str:
288+
"""Extract and store user facts/preferences for a durable user identity."""
289+
try:
290+
metadata_dict = _parse_metadata(metadata)
291+
except ValueError as e:
292+
return json.dumps(
293+
{
294+
"error": "Invalid JSON",
295+
"message": str(e),
296+
"invalid_metadata": metadata,
297+
}
298+
)
299+
300+
updates = get_client().store_user_facts(
301+
namespace_id=evolve_config.namespace_id,
302+
message=message,
303+
user_id=user_id,
304+
metadata=metadata_dict,
305+
enable_conflict_resolution=enable_conflict_resolution,
306+
)
307+
308+
serialized_updates = [
309+
{
310+
"event": update.event,
311+
"id": update.id,
312+
"type": update.type,
313+
"content": update.content,
314+
"metadata": update.metadata,
315+
}
316+
for update in updates
317+
]
318+
319+
return json.dumps(
320+
{
321+
"user_id": user_id,
322+
"stored_count": len(serialized_updates),
323+
"updates": serialized_updates,
324+
}
325+
)
326+
327+
328+
@mcp.tool()
329+
def retrieve_user_facts(user_id: str, query: str | None = None, limit: int = 5) -> str:
330+
"""Retrieve categorized user facts/preferences for a durable user identity."""
331+
categories = get_client().retrieve_user_facts(
332+
namespace_id=evolve_config.namespace_id,
333+
user_id=user_id,
334+
query=query,
335+
limit=limit,
336+
)
337+
matched_count = sum(len(items) for items in categories.values())
338+
339+
return json.dumps(
340+
{
341+
"user_id": user_id,
342+
"query": query,
343+
"matched_count": matched_count,
344+
"categories": categories,
345+
}
346+
)
347+
348+
264349
@mcp.tool()
265350
def save_trajectory(
266351
trajectory_data: str,
@@ -413,7 +498,7 @@ def create_entity(
413498
try:
414499
metadata_dict = json.loads(metadata)
415500
except json.JSONDecodeError as e:
416-
logger.exception(f"Invalid JSON in metadata parameter: {str(e)}")
501+
logger.warning("Invalid JSON in metadata parameter: %s", e)
417502
return json.dumps({"error": "Invalid JSON", "message": f"Failed to parse metadata: {str(e)}", "invalid_metadata": metadata})
418503
if not isinstance(metadata_dict, dict):
419504
return json.dumps(

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ warn_unused_configs = true
159159
disallow_untyped_defs = false
160160
explicit_package_bases = true
161161
exclude = [
162+
"build/",
162163
"platform-integrations/",
163164
"examples/",
164165
]

tests/e2e/conftest.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,22 @@
11
import os
22
import uuid
3+
import warnings
34
import pytest
45
from altk_evolve.config.milvus import milvus_client_settings
56

7+
8+
def pytest_configure(config):
9+
"""Warn early when LLM env vars needed by e2e tests are missing."""
10+
has_openai = bool(os.environ.get("OPENAI_API_KEY"))
11+
has_evolve_model = bool(os.environ.get("EVOLVE_MODEL_NAME"))
12+
if not has_openai and not has_evolve_model:
13+
warnings.warn(
14+
"No OPENAI_API_KEY or EVOLVE_MODEL_NAME set — e2e tests that depend on "
15+
"LLM fact-extraction will fail. See docs/guides/configuration.md.",
16+
stacklevel=1,
17+
)
18+
19+
620
_EVOLVE_ENV_KEYS = ("EVOLVE_NAMESPACE_ID", "EVOLVE_BACKEND", "EVOLVE_SQLITE_PATH", "EVOLVE_DATA_DIR")
721

822

tests/e2e/test_mcp.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import asyncio
2+
import uuid
3+
14
import pytest
25
import json
36
from fastmcp.client import Client
@@ -219,3 +222,39 @@ async def test_create_entity_with_invalid_json_metadata(mcp):
219222
assert result["error"] == "Invalid JSON"
220223
assert "message" in result
221224
assert "invalid_metadata" in result
225+
226+
227+
@pytest.mark.e2e
228+
async def test_store_and_retrieve_user_facts(mcp):
229+
async with Client(transport=mcp) as evolve_mcp:
230+
user_id = f"user-{uuid.uuid4()}"
231+
store_response = await evolve_mcp.call_tool_mcp(
232+
"store_user_facts",
233+
{
234+
"user_id": user_id,
235+
"message": "I prefer concise answers with bullet points.",
236+
"metadata": json.dumps({"source": "cuga-lite"}),
237+
"enable_conflict_resolution": False,
238+
},
239+
)
240+
stored = json.loads(store_response.content[0].text)
241+
assert stored["user_id"] == user_id
242+
assert stored["stored_count"] >= 1
243+
244+
for _ in range(10):
245+
retrieve_response = await evolve_mcp.call_tool_mcp(
246+
"retrieve_user_facts",
247+
{
248+
"user_id": user_id,
249+
"query": "How should I format the answer?",
250+
"limit": 5,
251+
},
252+
)
253+
retrieved = json.loads(retrieve_response.content[0].text)
254+
if retrieved["matched_count"] > 0:
255+
break
256+
await asyncio.sleep(0.25)
257+
258+
assert retrieved["user_id"] == user_id
259+
assert "categories" in retrieved
260+
assert retrieved["matched_count"] > 0

0 commit comments

Comments
 (0)