Skip to content
38 changes: 33 additions & 5 deletions altk_evolve/frontend/mcp/__main__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,30 @@
import argparse
import logging
import os
import sys
import threading
import uvicorn

from altk_evolve.frontend.mcp.mcp_server import mcp, app
from altk_evolve.frontend.mcp.mcp_server import app, get_client, mcp
from altk_evolve.frontend.mcp.http_transport import create_resilient_sse_app

logger = logging.getLogger("evolve-mcp")


def _is_truthy_env(name: str, default: bool) -> bool:
raw = os.getenv(name)
if raw is None:
return default
return raw.strip().lower() not in {"0", "false", "no", "off"}


def warmup_mcp_runtime() -> None:
"""Pre-initialize MCP backend state to reduce first-tool-call latency."""
logger.info("Warming up MCP runtime...")
get_client()
logger.info("MCP runtime warmup complete")


def _build_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description="Run the Evolve MCP server")
parser.add_argument(
Expand Down Expand Up @@ -42,9 +57,22 @@ def run_api_server():


def run_sse_server(host: str, port: int) -> None:
"""Run the MCP server over SSE with disconnect-safe transport handling."""
sse_app = create_resilient_sse_app(mcp)
uvicorn.run(sse_app, host=host, port=port, log_level="warning")
"""Run the MCP server over SSE with disconnect-tolerant teardown."""
if _is_truthy_env("EVOLVE_MCP_WARMUP", True):
try:
warmup_mcp_runtime()
except Exception as exc:
# Keep startup resilient: failed warmup should not block server boot.
logger.warning("MCP warmup failed; continuing without warmup: %s", exc)

uvicorn.run(
create_resilient_sse_app(mcp),
host=host,
port=port,
lifespan="on",
timeout_graceful_shutdown=3,
ws="websockets-sansio",
)


def main():
Expand All @@ -61,7 +89,7 @@ def main():
# Start FastMCP using stdio (which blocks)
mcp.run()
else:
run_sse_server(host=args.host, port=args.port)
run_sse_server(args.host, args.port)
except KeyboardInterrupt:
logger.info("MCP server stopped by user (KeyboardInterrupt)")
sys.exit(0)
Expand Down
60 changes: 50 additions & 10 deletions altk_evolve/frontend/mcp/http_transport.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
import logging
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
Expand All @@ -10,7 +11,6 @@
from mcp.server.auth.routes import build_resource_metadata_url
from mcp.server.lowlevel.server import LifespanResultT
from mcp.server.sse import SseServerTransport
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import BaseRoute, Mount, Route

Expand All @@ -28,6 +28,13 @@ def _is_benign_disconnect_exception(exc: BaseException) -> bool:
"""Return True when an exception only represents a dropped SSE client."""
if isinstance(exc, (anyio.ClosedResourceError, anyio.BrokenResourceError)):
return True
if isinstance(exc, asyncio.CancelledError):
# Uvicorn cancels outstanding request tasks during Ctrl+C shutdown.
return True
if isinstance(exc, AssertionError) and str(exc) == "Request already responded to":
# MCP low-level request responder can assert during SSE teardown races
# when cancellation/close wins over the normal response path.
return True

if isinstance(exc, BaseExceptionGroup):
return len(exc.exceptions) > 0 and all(_is_benign_disconnect_exception(child) for child in exc.exceptions)
Expand All @@ -50,7 +57,7 @@ async def _run_sse_session(
streams[1],
server._mcp_server.create_initialization_options(),
)
except Exception as exc:
except BaseException as exc:
if _is_benign_disconnect_exception(exc):
logger.debug("Suppressing benign SSE disconnect during response flush")
return False
Expand All @@ -59,6 +66,25 @@ async def _run_sse_session(
return True


async def _handle_sse(
server: FastMCP[LifespanResultT],
sse: SseServerTransport,
scope,
receive,
send,
) -> bool:
"""
Serve SSE directly as ASGI and avoid sending any follow-up HTTP response.

`connect_sse(...)(scope, receive, send)` owns the HTTP response lifecycle.
Returning an additional Response after it exits can race with teardown and
trigger duplicate MCP request completion assertions.

Returns False when the session ended due to a benign client disconnect.
"""
return await _run_sse_session(server, sse, scope, receive, send)


def create_resilient_sse_app(
server: FastMCP[LifespanResultT],
message_path: str | None = None,
Expand All @@ -78,9 +104,27 @@ def create_resilient_sse_app(

sse = SseServerTransport(message_path)

async def handle_sse(scope, receive, send) -> Response:
await _run_sse_session(server, sse, scope, receive, send)
return Response(status_code=204)
async def handle_sse(scope, receive, send) -> bool:
return await _handle_sse(server, sse, scope, receive, send)

class SseEndpoint:
"""ASGI app wrapping handle_sse that tracks whether a response was started."""

async def __call__(self, scope, receive, send) -> None:
response_started = False

async def tracked_send(message) -> None:
nonlocal response_started
if message.get("type") == "http.response.start":
response_started = True
await send(message)

completed = await handle_sse(scope, receive, tracked_send)
if completed and not response_started:
response = Response(status_code=204)
await response(scope, receive, send)

sse_endpoint = SseEndpoint()

if auth:
auth_middleware = auth.get_middleware()
Expand All @@ -95,7 +139,7 @@ async def handle_sse(scope, receive, send) -> Response:
Route(
sse_path,
endpoint=RequireAuthMiddleware(
handle_sse,
sse_endpoint,
auth.required_scopes,
resource_metadata_url,
),
Expand All @@ -113,10 +157,6 @@ async def handle_sse(scope, receive, send) -> Response:
)
)
else:

async def sse_endpoint(request: Request) -> Response:
return await handle_sse(request.scope, request.receive, request._send)

server_routes.append(Route(sse_path, endpoint=sse_endpoint, methods=["GET"]))
server_routes.append(Mount(message_path, app=sse.handle_post_message))

Expand Down
87 changes: 86 additions & 1 deletion altk_evolve/frontend/mcp/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import threading
import uuid
import os
from typing import Any

from fastmcp import FastMCP
from fastapi import FastAPI
Expand Down Expand Up @@ -214,6 +215,22 @@ def get_entities_logic(
return "\n".join(response_lines)


def _parse_metadata(metadata: str | None) -> dict[str, Any]:
if not metadata:
return {}

try:
parsed = json.loads(metadata)
except json.JSONDecodeError as e:
logger.warning("Invalid JSON in metadata parameter: %s", e)
raise ValueError(f"Failed to parse metadata: {str(e)}") from e

if not isinstance(parsed, dict):
raise ValueError("Metadata must decode to a JSON object")

return parsed


@mcp.tool()
def get_entities(
task: str,
Expand Down Expand Up @@ -261,6 +278,74 @@ def get_guidelines(
return get_entities_logic(task, "guideline", user_id=user_id, namespace_id=namespace_id, session_id=session_id)


@mcp.tool()
def store_user_facts(
user_id: str,
message: str,
metadata: str | None = None,
enable_conflict_resolution: bool = False,
) -> str:
"""Extract and store user facts/preferences for a durable user identity."""
try:
metadata_dict = _parse_metadata(metadata)
except ValueError as e:
return json.dumps(
{
"error": "Invalid JSON",
"message": str(e),
"invalid_metadata": metadata,
}
)

updates = get_client().store_user_facts(
namespace_id=evolve_config.namespace_id,
message=message,
user_id=user_id,
metadata=metadata_dict,
enable_conflict_resolution=enable_conflict_resolution,
)

serialized_updates = [
{
"event": update.event,
"id": update.id,
"type": update.type,
"content": update.content,
"metadata": update.metadata,
}
for update in updates
]

return json.dumps(
{
"user_id": user_id,
"stored_count": len(serialized_updates),
"updates": serialized_updates,
}
)


@mcp.tool()
def retrieve_user_facts(user_id: str, query: str | None = None, limit: int = 5) -> str:
"""Retrieve categorized user facts/preferences for a durable user identity."""
categories = get_client().retrieve_user_facts(
namespace_id=evolve_config.namespace_id,
user_id=user_id,
query=query,
limit=limit,
)
matched_count = sum(len(items) for items in categories.values())

return json.dumps(
{
"user_id": user_id,
"query": query,
"matched_count": matched_count,
"categories": categories,
}
)


@mcp.tool()
def save_trajectory(
trajectory_data: str,
Expand Down Expand Up @@ -413,7 +498,7 @@ def create_entity(
try:
metadata_dict = json.loads(metadata)
except json.JSONDecodeError as e:
logger.exception(f"Invalid JSON in metadata parameter: {str(e)}")
logger.warning("Invalid JSON in metadata parameter: %s", e)
return json.dumps({"error": "Invalid JSON", "message": f"Failed to parse metadata: {str(e)}", "invalid_metadata": metadata})
if not isinstance(metadata_dict, dict):
return json.dumps(
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ warn_unused_configs = true
disallow_untyped_defs = false
explicit_package_bases = true
exclude = [
"build/",
"platform-integrations/",
"examples/",
]
Expand Down
14 changes: 14 additions & 0 deletions tests/e2e/conftest.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,22 @@
import os
import uuid
import warnings
import pytest
from altk_evolve.config.milvus import milvus_client_settings


def pytest_configure(config):
"""Warn early when LLM env vars needed by e2e tests are missing."""
has_openai = bool(os.environ.get("OPENAI_API_KEY"))
has_evolve_model = bool(os.environ.get("EVOLVE_MODEL_NAME"))
if not has_openai and not has_evolve_model:
warnings.warn(
"No OPENAI_API_KEY or EVOLVE_MODEL_NAME set — e2e tests that depend on "
"LLM fact-extraction will fail. See docs/guides/configuration.md.",
stacklevel=1,
)


_EVOLVE_ENV_KEYS = ("EVOLVE_NAMESPACE_ID", "EVOLVE_BACKEND", "EVOLVE_SQLITE_PATH", "EVOLVE_DATA_DIR")


Expand Down
39 changes: 39 additions & 0 deletions tests/e2e/test_mcp.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import asyncio
import uuid

import pytest
import json
from fastmcp.client import Client
Expand Down Expand Up @@ -219,3 +222,39 @@ async def test_create_entity_with_invalid_json_metadata(mcp):
assert result["error"] == "Invalid JSON"
assert "message" in result
assert "invalid_metadata" in result


@pytest.mark.e2e
async def test_store_and_retrieve_user_facts(mcp):
async with Client(transport=mcp) as evolve_mcp:
user_id = f"user-{uuid.uuid4()}"
store_response = await evolve_mcp.call_tool_mcp(
"store_user_facts",
{
"user_id": user_id,
"message": "I prefer concise answers with bullet points.",
"metadata": json.dumps({"source": "cuga-lite"}),
"enable_conflict_resolution": False,
},
)
stored = json.loads(store_response.content[0].text)
assert stored["user_id"] == user_id
assert stored["stored_count"] >= 1

for _ in range(10):
retrieve_response = await evolve_mcp.call_tool_mcp(
"retrieve_user_facts",
{
"user_id": user_id,
"query": "How should I format the answer?",
"limit": 5,
},
)
retrieved = json.loads(retrieve_response.content[0].text)
if retrieved["matched_count"] > 0:
break
await asyncio.sleep(0.25)

assert retrieved["user_id"] == user_id
assert "categories" in retrieved
assert retrieved["matched_count"] > 0
Loading
Loading