Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 78 additions & 8 deletions src/basic_memory/mcp/async_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,21 @@
import os
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import AsyncIterator, Callable, Optional
from inspect import isawaitable
from typing import TYPE_CHECKING, AsyncIterator, Callable, Optional, cast

from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient, Timeout
from loguru import logger

import logfire
from basic_memory.config import ConfigManager, ProjectMode

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker

LocalDatabaseState = tuple["AsyncEngine", "async_sessionmaker[AsyncSession]"]
_MISSING_STATE_VALUE = object()


def _force_local_mode() -> bool:
"""Check if local mode is forced via environment variable."""
Expand All @@ -34,15 +42,12 @@ def _build_timeout() -> Timeout:
)


def _asgi_client(timeout: Timeout) -> AsyncClient:
"""Create a local ASGI client."""
# Import on first local-client use so CLI help/version paths can import
# routing helpers without constructing the full FastAPI router graph.
from basic_memory.api.app import app as fastapi_app
def _build_asgi_client(app: FastAPI, timeout: Timeout) -> AsyncClient:
"""Create a local ASGI client for an already-prepared FastAPI app."""
from basic_memory.workspace_context import workspace_permalink_headers

return AsyncClient(
transport=ASGITransport(app=fastapi_app),
transport=ASGITransport(app=app),
base_url="http://test",
timeout=timeout,
# Local ASGI calls still cross the HTTP boundary, so request handlers need
Expand All @@ -51,6 +56,66 @@ def _asgi_client(timeout: Timeout) -> AsyncClient:
)


async def _resolve_local_asgi_database(app: FastAPI) -> LocalDatabaseState:
"""Resolve database state for a local ASGI request."""
from basic_memory.deps import get_engine_factory

override = app.dependency_overrides.get(get_engine_factory)
if override is not None:
result = override()
if isawaitable(result):
result = await result
Comment thread
phernandez marked this conversation as resolved.
Outdated
return cast(LocalDatabaseState, result)

from basic_memory import db

config = ConfigManager().config
return await db.get_or_create_db(config.database_path)


@asynccontextmanager
async def _prepared_local_asgi_database(app: FastAPI) -> AsyncIterator[None]:
"""Initialize local ASGI database state before the first request."""
previous_engine = getattr(app.state, "engine", _MISSING_STATE_VALUE)
previous_session_maker = getattr(app.state, "session_maker", _MISSING_STATE_VALUE)

engine, session_maker = await _resolve_local_asgi_database(app)
app.state.engine = engine
app.state.session_maker = session_maker

try:
yield
finally:
if previous_engine is _MISSING_STATE_VALUE:
if hasattr(app.state, "engine"):
delattr(app.state, "engine")
else:
app.state.engine = previous_engine
Comment thread
phernandez marked this conversation as resolved.
Outdated

if previous_session_maker is _MISSING_STATE_VALUE:
if hasattr(app.state, "session_maker"):
delattr(app.state, "session_maker")
else:
app.state.session_maker = previous_session_maker


@asynccontextmanager
async def _asgi_client(timeout: Timeout) -> AsyncIterator[AsyncClient]:
"""Create a local ASGI client."""
# Import on first local-client use so CLI help/version paths can import
# routing helpers without constructing the full FastAPI router graph.
from basic_memory.api.app import app as fastapi_app

# Trigger: local ASGITransport does not execute FastAPI lifespan startup.
# Why: letting request dependencies initialize Postgres can run asyncpg DDL
# under Starlette's request loop and trigger CPython's empty-ready-queue race.
# Outcome: request handling sees the same app.state database objects as API
# lifespan startup would have provided.
async with _prepared_local_asgi_database(fastapi_app):
async with _build_asgi_client(fastapi_app, timeout) as client:
yield client


async def _resolve_cloud_token(config) -> str:
"""Resolve cloud token with API key preferred, OAuth fallback."""
with logfire.span(
Expand Down Expand Up @@ -260,7 +325,12 @@ def create_client() -> AsyncClient:

if _force_local_mode() or not _force_cloud_mode():
logger.info("Creating ASGI client for local Basic Memory API")
return _asgi_client(timeout)
# Deprecated sync path: create_client() cannot await the local ASGI
# pre-initialization used by get_client(), so callers that need proper
# resource setup should use the async context manager instead.
from basic_memory.api.app import app as fastapi_app

return _build_asgi_client(fastapi_app, timeout)

logger.info("Creating HTTP client for cloud proxy (legacy create_client path)")
config = ConfigManager().config
Expand Down
100 changes: 100 additions & 0 deletions tests/mcp/test_async_client_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,106 @@ async def test_get_client_default_uses_local_asgi_transport(config_manager):
assert isinstance(client._transport, httpx.ASGITransport) # pyright: ignore[reportPrivateUsage]


@pytest.mark.asyncio
async def test_get_client_preinitializes_local_asgi_database(config_manager, monkeypatch):
"""Local ASGI routing initializes DB state before request handling."""
from basic_memory import db
from basic_memory.api.app import app as fastapi_app

cfg = config_manager.load_config()
config_manager.save_config(cfg)

previous_engine = getattr(fastapi_app.state, "engine", None)
previous_session_maker = getattr(fastapi_app.state, "session_maker", None)
fastapi_app.state._state.pop("engine", None) # pyright: ignore[reportPrivateUsage]
fastapi_app.state._state.pop("session_maker", None) # pyright: ignore[reportPrivateUsage]

engine = object()
session_maker = object()
calls = []

async def fake_get_or_create_db(db_path):
calls.append(db_path)
return engine, session_maker

monkeypatch.setattr(db, "get_or_create_db", fake_get_or_create_db)

try:
async with get_client() as client:
assert isinstance(client._transport, httpx.ASGITransport) # pyright: ignore[reportPrivateUsage]
assert calls == [cfg.database_path]
assert fastapi_app.state.engine is engine
assert fastapi_app.state.session_maker is session_maker
assert not hasattr(fastapi_app.state, "engine")
assert not hasattr(fastapi_app.state, "session_maker")
finally:
if previous_engine is None:
fastapi_app.state._state.pop("engine", None) # pyright: ignore[reportPrivateUsage]
else:
fastapi_app.state.engine = previous_engine
if previous_session_maker is None:
fastapi_app.state._state.pop("session_maker", None) # pyright: ignore[reportPrivateUsage]
else:
fastapi_app.state.session_maker = previous_session_maker


@pytest.mark.parametrize("async_override", [False, True])
@pytest.mark.asyncio
async def test_get_client_uses_existing_local_asgi_database_override(
config_manager,
async_override,
):
"""Local ASGI routing honors FastAPI test dependency overrides."""
from basic_memory.api.app import app as fastapi_app
from basic_memory.deps import get_engine_factory

cfg = config_manager.load_config()
config_manager.save_config(cfg)

previous_overrides = dict(fastapi_app.dependency_overrides)
previous_engine = getattr(fastapi_app.state, "engine", None)
previous_session_maker = getattr(fastapi_app.state, "session_maker", None)
fastapi_app.state._state.pop("engine", None) # pyright: ignore[reportPrivateUsage]
fastapi_app.state._state.pop("session_maker", None) # pyright: ignore[reportPrivateUsage]

engine = object()
session_maker = object()
calls = []

if async_override:

async def override_engine_factory():
calls.append("override")
return engine, session_maker

else:

def override_engine_factory():
calls.append("override")
return engine, session_maker

fastapi_app.dependency_overrides[get_engine_factory] = override_engine_factory

try:
async with get_client() as client:
assert isinstance(client._transport, httpx.ASGITransport) # pyright: ignore[reportPrivateUsage]
assert calls == ["override"]
assert fastapi_app.state.engine is engine
assert fastapi_app.state.session_maker is session_maker
assert not hasattr(fastapi_app.state, "engine")
assert not hasattr(fastapi_app.state, "session_maker")
finally:
fastapi_app.dependency_overrides = previous_overrides
if previous_engine is None:
fastapi_app.state._state.pop("engine", None) # pyright: ignore[reportPrivateUsage]
else:
fastapi_app.state.engine = previous_engine
if previous_session_maker is None:
fastapi_app.state._state.pop("session_maker", None) # pyright: ignore[reportPrivateUsage]
else:
fastapi_app.state.session_maker = previous_session_maker


@pytest.mark.asyncio
async def test_get_client_explicit_cloud_uses_api_key(config_manager, monkeypatch):
cfg = config_manager.load_config()
Expand Down
Loading