Skip to content
38 changes: 33 additions & 5 deletions altk_evolve/frontend/mcp/__main__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
import argparse
import logging
import os
import sys
import threading
import uvicorn

from altk_evolve.frontend.mcp.mcp_server import mcp, app
from altk_evolve.frontend.mcp.mcp_server import app, get_client, mcp
from altk_evolve.frontend.mcp.http_transport import create_resilient_sse_app

logger = logging.getLogger("evolve-mcp")


def _is_truthy_env(name: str, default: bool) -> bool:
raw = os.getenv(name)
if raw is None:
return default
return raw.strip().lower() not in {"0", "false", "no", "off"}


def warmup_mcp_runtime() -> None:
"""Pre-initialize MCP backend state to reduce first-tool-call latency."""
logger.info("Warming up MCP runtime...")
get_client()
logger.info("MCP runtime warmup complete")


def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Run the Evolve MCP server")
parser.add_argument(
Expand Down Expand Up @@ -42,9 +57,22 @@ def run_api_server():


def run_sse_server(host: str, port: int) -> None:
"""Run the MCP server over SSE with disconnect-safe transport handling."""
sse_app = create_resilient_sse_app(mcp)
uvicorn.run(sse_app, host=host, port=port, log_level="warning")
"""Run the MCP server over SSE with disconnect-tolerant teardown."""
if _is_truthy_env("EVOLVE_MCP_WARMUP", True):
try:
warmup_mcp_runtime()
except Exception as exc:
# Keep startup resilient: failed warmup should not block server boot.
logger.warning("MCP warmup failed; continuing without warmup: %s", exc)

uvicorn.run(
create_resilient_sse_app(mcp),
host=host,
port=port,
lifespan="on",
timeout_graceful_shutdown=3,
ws="websockets-sansio",
)


def main():
Expand All @@ -61,7 +89,7 @@ def main():
# Start FastMCP using stdio (which blocks)
mcp.run()
else:
run_sse_server(host=args.host, port=args.port)
run_sse_server(args.host, args.port)
except KeyboardInterrupt:
logger.info("MCP server stopped by user (KeyboardInterrupt)")
sys.exit(0)
Expand Down
56 changes: 47 additions & 9 deletions altk_evolve/frontend/mcp/http_transport.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
Expand All @@ -10,7 +11,6 @@
from mcp.server.auth.routes import build_resource_metadata_url
from mcp.server.lowlevel.server import LifespanResultT
from mcp.server.sse import SseServerTransport
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Mount, Route

Expand All @@ -28,6 +28,13 @@ def _is_benign_disconnect_exception(exc: BaseException) -> bool:
"""Return True when an exception only represents a dropped SSE client."""
if isinstance(exc, (anyio.ClosedResourceError, anyio.BrokenResourceError)):
return True
if isinstance(exc, asyncio.CancelledError):
# Uvicorn cancels outstanding request tasks during Ctrl+C shutdown.
return True
if isinstance(exc, AssertionError) and str(exc) == "Request already responded to":
# MCP low-level request responder can assert during SSE teardown races
# when cancellation/close wins over the normal response path.
return True

if isinstance(exc, BaseExceptionGroup):
return len(exc.exceptions) > 0 and all(_is_benign_disconnect_exception(child) for child in exc.exceptions)
Expand Down Expand Up @@ -59,6 +66,23 @@ async def _run_sse_session(
return True


async def _handle_sse(
server: FastMCP[LifespanResultT],
sse: SseServerTransport,
scope,
receive,
send,
) -> None:
"""
Serve SSE directly as ASGI and avoid sending any follow-up HTTP response.

`connect_sse(...)(scope, receive, send)` owns the HTTP response lifecycle.
Returning an additional Response after it exits can race with teardown and
trigger duplicate MCP request completion assertions.
"""
await _run_sse_session(server, sse, scope, receive, send)


def create_resilient_sse_app(
server: FastMCP[LifespanResultT],
message_path: str | None = None,
Expand All @@ -78,9 +102,27 @@ def create_resilient_sse_app(

sse = SseServerTransport(message_path)

async def handle_sse(scope, receive, send) -> Response:
await _run_sse_session(server, sse, scope, receive, send)
return Response(status_code=204)
async def handle_sse(scope, receive, send) -> None:
await _handle_sse(server, sse, scope, receive, send)

class SseEndpoint:
"""ASGI app wrapping handle_sse that tracks whether a response was started."""

async def __call__(self, scope, receive, send) -> None:
response_started = False

async def tracked_send(message) -> None:
nonlocal response_started
if message.get("type") == "http.response.start":
response_started = True
await send(message)

await handle_sse(scope, receive, tracked_send)
if not response_started:
response = Response(status_code=204)
await response(scope, receive, send)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

sse_endpoint = SseEndpoint()

if auth:
auth_middleware = auth.get_middleware()
Expand All @@ -95,7 +137,7 @@ async def handle_sse(scope, receive, send) -> Response:
Route(
sse_path,
endpoint=RequireAuthMiddleware(
handle_sse,
sse_endpoint,
auth.required_scopes,
resource_metadata_url,
),
Expand All @@ -113,10 +155,6 @@ async def handle_sse(scope, receive, send) -> Response:
)
)
else:

async def sse_endpoint(request: Request) -> Response:
return await handle_sse(request.scope, request.receive, request._send)

server_routes.append(Route(sse_path, endpoint=sse_endpoint, methods=["GET"]))
server_routes.append(Mount(message_path, app=sse.handle_post_message))

Expand Down
87 changes: 86 additions & 1 deletion altk_evolve/frontend/mcp/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import threading
import uuid
import os
from typing import Any

from fastmcp import FastMCP
from fastapi import FastAPI
Expand Down Expand Up @@ -214,6 +215,22 @@ def get_entities_logic(
return "\n".join(response_lines)


def _parse_metadata(metadata: str | None) -> dict[str, Any]:
if not metadata:
return {}

try:
parsed = json.loads(metadata)
except json.JSONDecodeError as e:
logger.warning("Invalid JSON in metadata parameter: %s", e)
raise ValueError(f"Failed to parse metadata: {str(e)}") from e

if not isinstance(parsed, dict):
raise ValueError("Metadata must decode to a JSON object")

return parsed


@mcp.tool()
def get_entities(
task: str,
Expand Down Expand Up @@ -261,6 +278,74 @@ def get_guidelines(
return get_entities_logic(task, "guideline", user_id=user_id, namespace_id=namespace_id, session_id=session_id)


@mcp.tool()
def store_user_facts(
user_id: str,
message: str,
metadata: str | None = None,
enable_conflict_resolution: bool = False,
) -> str:
"""Extract and store user facts/preferences for a durable user identity."""
try:
metadata_dict = _parse_metadata(metadata)
except ValueError as e:
return json.dumps(
{
"error": "Invalid JSON",
"message": str(e),
"invalid_metadata": metadata,
}
)

updates = get_client().store_user_facts(
namespace_id=evolve_config.namespace_id,
message=message,
user_id=user_id,
metadata=metadata_dict,
enable_conflict_resolution=enable_conflict_resolution,
)

serialized_updates = [
{
"event": update.event,
"id": update.id,
"type": update.type,
"content": update.content,
"metadata": update.metadata,
}
for update in updates
]

return json.dumps(
{
"user_id": user_id,
"stored_count": len(serialized_updates),
"updates": serialized_updates,
}
)


@mcp.tool()
def retrieve_user_facts(user_id: str, query: str | None = None, limit: int = 5) -> str:
"""Retrieve categorized user facts/preferences for a durable user identity."""
categories = get_client().retrieve_user_facts(
namespace_id=evolve_config.namespace_id,
user_id=user_id,
query=query,
limit=limit,
)
matched_count = sum(len(items) for items in categories.values())

return json.dumps(
{
"user_id": user_id,
"query": query,
"matched_count": matched_count,
"categories": categories,
}
)


@mcp.tool()
def save_trajectory(
trajectory_data: str,
Expand Down Expand Up @@ -413,7 +498,7 @@ def create_entity(
try:
metadata_dict = json.loads(metadata)
except json.JSONDecodeError as e:
logger.exception(f"Invalid JSON in metadata parameter: {str(e)}")
logger.warning("Invalid JSON in metadata parameter: %s", e)
return json.dumps({"error": "Invalid JSON", "message": f"Failed to parse metadata: {str(e)}", "invalid_metadata": metadata})
if not isinstance(metadata_dict, dict):
return json.dumps(
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ warn_unused_configs = true
disallow_untyped_defs = false
explicit_package_bases = true
exclude = [
"build/",
"platform-integrations/",
"examples/",
]
Expand Down
31 changes: 31 additions & 0 deletions tests/e2e/test_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,34 @@ async def test_create_entity_with_invalid_json_metadata(mcp):
assert result["error"] == "Invalid JSON"
assert "message" in result
assert "invalid_metadata" in result


@pytest.mark.e2e
async def test_store_and_retrieve_user_facts(mcp):
async with Client(transport=mcp) as evolve_mcp:
store_response = await evolve_mcp.call_tool_mcp(
"store_user_facts",
{
"user_id": "user-123",
"message": "I prefer concise answers with bullet points.",
"metadata": json.dumps({"source": "cuga-lite"}),
"enable_conflict_resolution": False,
},
)
stored = json.loads(store_response.content[0].text)
assert stored["user_id"] == "user-123"
assert stored["stored_count"] >= 1

retrieve_response = await evolve_mcp.call_tool_mcp(
"retrieve_user_facts",
{
"user_id": "user-123",
"query": "How should I format the answer?",
"limit": 5,
},
)
retrieved = json.loads(retrieve_response.content[0].text)

assert retrieved["user_id"] == "user-123"
assert "categories" in retrieved
assert retrieved["matched_count"] >= 0
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
Loading
Loading