Skip to content

Commit 1b25299

Browse files
feat(mcp): implement dedicated MCP database session management
- Create a dedicated async database engine for MCP handlers - Implement `get_mcp_db_session` to manage MCP database sessions - Replace direct database context usage with MCP session management in multiple functions - Ensure proper session closure to avoid connection issues
1 parent c66633d commit 1b25299

File tree

1 file changed

+82
-15
lines changed

1 file changed

+82
-15
lines changed

api/mcp/server.py

Lines changed: 82 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,66 @@
44
Provides tools for AI assistants to search plot specifications and fetch implementation code.
55
"""
66

7+
import os
78
from typing import Any
89

910
from fastmcp import FastMCP
11+
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
12+
from sqlalchemy.pool import NullPool
1013

1114
from api.schemas import ImplementationResponse, SpecDetailResponse, SpecListItem
12-
from core.database import ImplRepository, LibraryRepository, SpecRepository, get_db_context, is_db_configured
15+
from core.database import ImplRepository, LibraryRepository, SpecRepository, is_db_configured
1316

1417

1518
# Website URL for linking to pyplots.ai
1619
PYPLOTS_WEBSITE_URL = "https://pyplots.ai"
1720

21+
# MCP-specific database engine (created lazily)
22+
# This is separate from FastAPI's engine to avoid greenlet context issues
23+
_mcp_engine = None
24+
_mcp_session_factory = None
25+
26+
27+
def _get_mcp_engine():
28+
"""Create a dedicated engine for MCP handlers."""
29+
global _mcp_engine, _mcp_session_factory
30+
31+
if _mcp_engine is not None:
32+
return _mcp_engine
33+
34+
database_url = os.getenv("DATABASE_URL", "")
35+
if not database_url:
36+
raise ValueError("DATABASE_URL not configured")
37+
38+
# Ensure async driver
39+
if database_url.startswith("postgresql://"):
40+
database_url = database_url.replace("postgresql://", "postgresql+asyncpg://")
41+
elif database_url.startswith("postgres://"):
42+
database_url = database_url.replace("postgres://", "postgresql+asyncpg://")
43+
44+
# Use NullPool for MCP to avoid connection state issues across requests
45+
_mcp_engine = create_async_engine(database_url, poolclass=NullPool)
46+
_mcp_session_factory = async_sessionmaker(_mcp_engine, class_=AsyncSession, expire_on_commit=False)
47+
48+
return _mcp_engine
49+
50+
51+
async def get_mcp_db_session() -> AsyncSession:
52+
"""
53+
Get database session for MCP handlers.
54+
55+
Uses a dedicated engine to avoid greenlet context issues
56+
that occur when Streamable HTTP transport runs in a different
57+
async context than FastAPI's main event loop.
58+
"""
59+
_get_mcp_engine() # Ensure engine is created
60+
61+
if _mcp_session_factory is None:
62+
raise ValueError("Database not configured. Check DATABASE_URL.")
63+
64+
return _mcp_session_factory()
65+
66+
1867
# Initialize FastMCP server
1968
mcp_server = FastMCP("pyplots")
2069

@@ -34,8 +83,9 @@ async def list_specs(limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
3483
if not is_db_configured():
3584
raise ValueError("Database not configured. Check DATABASE_URL or INSTANCE_CONNECTION_NAME.")
3685

37-
async with get_db_context() as db:
38-
repo = SpecRepository(db)
86+
session = await get_mcp_db_session()
87+
try:
88+
repo = SpecRepository(session)
3989
specs = await repo.get_all()
4090

4191
# Apply pagination
@@ -51,6 +101,8 @@ async def list_specs(limit: int = 100, offset: int = 0) -> list[dict[str, Any]]:
51101
result.append({**item.model_dump(), "website_url": f"{PYPLOTS_WEBSITE_URL}/{spec.id}"})
52102

53103
return result
104+
finally:
105+
await session.close()
54106

55107

56108
@mcp_server.tool()
@@ -103,8 +155,9 @@ async def search_specs_by_tags(
103155
if not is_db_configured():
104156
raise ValueError("Database not configured. Check DATABASE_URL or INSTANCE_CONNECTION_NAME.")
105157

106-
async with get_db_context() as db:
107-
repo = SpecRepository(db)
158+
session = await get_mcp_db_session()
159+
try:
160+
repo = SpecRepository(session)
108161

109162
# Build filter dict (spec-level tags)
110163
filters: dict[str, list[str]] = {}
@@ -174,6 +227,8 @@ async def search_specs_by_tags(
174227
result.append({**item.model_dump(), "website_url": f"{PYPLOTS_WEBSITE_URL}/{spec.id}"})
175228

176229
return result
230+
finally:
231+
await session.close()
177232

178233

179234
@mcp_server.tool()
@@ -197,8 +252,9 @@ async def get_spec_detail(spec_id: str) -> dict[str, Any]:
197252
if not is_db_configured():
198253
raise ValueError("Database not configured. Check DATABASE_URL or INSTANCE_CONNECTION_NAME.")
199254

200-
async with get_db_context() as db:
201-
repo = SpecRepository(db)
255+
session = await get_mcp_db_session()
256+
try:
257+
repo = SpecRepository(session)
202258
spec = await repo.get_by_id(spec_id)
203259

204260
if spec is None:
@@ -250,6 +306,8 @@ async def get_spec_detail(spec_id: str) -> dict[str, Any]:
250306
)
251307

252308
return {**response.model_dump(), "website_url": f"{PYPLOTS_WEBSITE_URL}/{spec_id}"}
309+
finally:
310+
await session.close()
253311

254312

255313
@mcp_server.tool()
@@ -276,10 +334,11 @@ async def get_implementation(spec_id: str, library: str) -> dict[str, Any]:
276334
if not is_db_configured():
277335
raise ValueError("Database not configured. Check DATABASE_URL or INSTANCE_CONNECTION_NAME.")
278336

279-
async with get_db_context() as db:
280-
spec_repo = SpecRepository(db)
281-
library_repo = LibraryRepository(db)
282-
impl_repo = ImplRepository(db)
337+
session = await get_mcp_db_session()
338+
try:
339+
spec_repo = SpecRepository(session)
340+
library_repo = LibraryRepository(session)
341+
impl_repo = ImplRepository(session)
283342

284343
# Validate spec exists
285344
spec = await spec_repo.get_by_id(spec_id)
@@ -320,6 +379,8 @@ async def get_implementation(spec_id: str, library: str) -> dict[str, Any]:
320379
)
321380

322381
return {**response.model_dump(), "website_url": f"{PYPLOTS_WEBSITE_URL}/{spec_id}/{library}"}
382+
finally:
383+
await session.close()
323384

324385

325386
@mcp_server.tool()
@@ -333,15 +394,18 @@ async def list_libraries() -> list[dict[str, Any]]:
333394
if not is_db_configured():
334395
raise ValueError("Database not configured. Check DATABASE_URL or INSTANCE_CONNECTION_NAME.")
335396

336-
async with get_db_context() as db:
337-
repo = LibraryRepository(db)
397+
session = await get_mcp_db_session()
398+
try:
399+
repo = LibraryRepository(session)
338400
libraries = await repo.get_all()
339401

340402
result = []
341403
for lib in libraries:
342404
result.append({"id": lib.id, "name": lib.name, "description": lib.description})
343405

344406
return result
407+
finally:
408+
await session.close()
345409

346410

347411
@mcp_server.tool()
@@ -384,8 +448,9 @@ async def get_tag_values(category: str) -> list[str]:
384448
if not is_db_configured():
385449
raise ValueError("Database not configured. Check DATABASE_URL or INSTANCE_CONNECTION_NAME.")
386450

387-
async with get_db_context() as db:
388-
repo = SpecRepository(db)
451+
session = await get_mcp_db_session()
452+
try:
453+
repo = SpecRepository(session)
389454
specs = await repo.get_all()
390455

391456
# Collect unique tag values
@@ -408,3 +473,5 @@ async def get_tag_values(category: str) -> list[str]:
408473
values.update(tag_list)
409474

410475
return sorted(values)
476+
finally:
477+
await session.close()

0 commit comments

Comments
 (0)