Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 142 additions & 8 deletions src/basic_memory/mcp/async_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,34 @@
import os
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import AsyncIterator, Callable, Optional
from dataclasses import dataclass
from inspect import isawaitable
from threading import RLock
from typing import TYPE_CHECKING, AsyncIterator, Callable, Optional, cast

from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient, Timeout
from loguru import logger

import logfire
from basic_memory.config import ConfigManager, ProjectMode

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker

LocalDatabaseState = tuple["AsyncEngine", "async_sessionmaker[AsyncSession]"]
_MISSING_STATE_VALUE = object()


@dataclass
class _PreparedLocalAsgiDatabase:
active_count: int
previous_engine: object
previous_session_maker: object


_prepared_local_asgi_database_lock = RLock()
_prepared_local_asgi_databases: dict[FastAPI, _PreparedLocalAsgiDatabase] = {}


def _force_local_mode() -> bool:
"""Check if local mode is forced via environment variable."""
Expand All @@ -34,15 +55,12 @@ def _build_timeout() -> Timeout:
)


def _asgi_client(timeout: Timeout) -> AsyncClient:
"""Create a local ASGI client."""
# Import on first local-client use so CLI help/version paths can import
# routing helpers without constructing the full FastAPI router graph.
from basic_memory.api.app import app as fastapi_app
def _build_asgi_client(app: FastAPI, timeout: Timeout) -> AsyncClient:
"""Create a local ASGI client for an already-prepared FastAPI app."""
from basic_memory.workspace_context import workspace_permalink_headers

return AsyncClient(
transport=ASGITransport(app=fastapi_app),
transport=ASGITransport(app=app),
base_url="http://test",
timeout=timeout,
# Local ASGI calls still cross the HTTP boundary, so request handlers need
Expand All @@ -51,6 +69,117 @@ def _asgi_client(timeout: Timeout) -> AsyncClient:
)


async def _resolve_local_asgi_database(app: FastAPI) -> LocalDatabaseState:
"""Resolve database state for a local ASGI request."""
from basic_memory.deps import get_engine_factory

override = app.dependency_overrides.get(get_engine_factory)
if override is not None:
result = override()
if isawaitable(result):
result = await result
Comment thread
phernandez marked this conversation as resolved.
Outdated
return cast(LocalDatabaseState, result)

from basic_memory import db

config = ConfigManager().config
return await db.get_or_create_db(config.database_path)


def _retain_prepared_local_asgi_database(app: FastAPI) -> bool:
"""Retain an active local ASGI database preparation if one exists."""
with _prepared_local_asgi_database_lock:
active = _prepared_local_asgi_databases.get(app)
if active is None:
return False

active.active_count += 1
return True


def _install_prepared_local_asgi_database(
app: FastAPI,
database_state: LocalDatabaseState,
) -> None:
"""Install local ASGI database state or retain an overlapping installation."""
with _prepared_local_asgi_database_lock:
active = _prepared_local_asgi_databases.get(app)
if active is not None:
active.active_count += 1
return

previous_engine = getattr(app.state, "engine", _MISSING_STATE_VALUE)
previous_session_maker = getattr(app.state, "session_maker", _MISSING_STATE_VALUE)
engine, session_maker = database_state

app.state.engine = engine
app.state.session_maker = session_maker
_prepared_local_asgi_databases[app] = _PreparedLocalAsgiDatabase(
active_count=1,
previous_engine=previous_engine,
previous_session_maker=previous_session_maker,
)


def _restore_local_asgi_state_attribute(app: FastAPI, name: str, previous_value: object) -> None:
"""Restore a FastAPI app.state attribute captured before local ASGI preparation."""
if previous_value is _MISSING_STATE_VALUE:
if hasattr(app.state, name):
delattr(app.state, name)
else:
setattr(app.state, name, previous_value)


def _release_prepared_local_asgi_database(app: FastAPI) -> None:
"""Release local ASGI database state after a client context exits."""
with _prepared_local_asgi_database_lock:
active = _prepared_local_asgi_databases.get(app)
if active is None:
raise RuntimeError("Local ASGI database state released without a matching retain")

active.active_count -= 1
if active.active_count > 0:
return

del _prepared_local_asgi_databases[app]
_restore_local_asgi_state_attribute(app, "engine", active.previous_engine)
_restore_local_asgi_state_attribute(
app,
"session_maker",
active.previous_session_maker,
)


@asynccontextmanager
async def _prepared_local_asgi_database(app: FastAPI) -> AsyncIterator[None]:
"""Initialize local ASGI database state before the first request."""
if not _retain_prepared_local_asgi_database(app):
database_state = await _resolve_local_asgi_database(app)
_install_prepared_local_asgi_database(app, database_state)

try:
yield
finally:
_release_prepared_local_asgi_database(app)


@asynccontextmanager
async def _asgi_client(timeout: Timeout) -> AsyncIterator[AsyncClient]:
"""Create a local ASGI client."""
# Import on first local-client use so CLI help/version paths can import
# routing helpers without constructing the full FastAPI router graph.
from basic_memory.api.app import app as fastapi_app

# Trigger: local ASGITransport does not execute FastAPI lifespan startup.
# Why: letting request dependencies initialize Postgres can run asyncpg DDL
# under Starlette's request loop and trigger CPython's empty-ready-queue race.
# Outcome: request handling sees the same app.state database objects as API
# lifespan startup would have provided.
async with _prepared_local_asgi_database(fastapi_app):
async with _build_asgi_client(fastapi_app, timeout) as client:
yield client


async def _resolve_cloud_token(config) -> str:
"""Resolve cloud token with API key preferred, OAuth fallback."""
with logfire.span(
Expand Down Expand Up @@ -260,7 +389,12 @@ def create_client() -> AsyncClient:

if _force_local_mode() or not _force_cloud_mode():
logger.info("Creating ASGI client for local Basic Memory API")
return _asgi_client(timeout)
# Deprecated sync path: create_client() cannot await the local ASGI
# pre-initialization used by get_client(), so callers that need proper
# resource setup should use the async context manager instead.
from basic_memory.api.app import app as fastapi_app

return _build_asgi_client(fastapi_app, timeout)

logger.info("Creating HTTP client for cloud proxy (legacy create_client path)")
config = ConfigManager().config
Expand Down
172 changes: 172 additions & 0 deletions tests/mcp/test_async_client_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
@pytest.fixture(autouse=True)
def _reset_async_client_state(monkeypatch):
async_client_module._client_factory = None
async_client_module._prepared_local_asgi_databases.clear()
monkeypatch.delenv("BASIC_MEMORY_FORCE_LOCAL", raising=False)
monkeypatch.delenv("BASIC_MEMORY_FORCE_CLOUD", raising=False)
monkeypatch.delenv("BASIC_MEMORY_EXPLICIT_ROUTING", raising=False)
yield
async_client_module._client_factory = None
async_client_module._prepared_local_asgi_databases.clear()


@pytest.mark.asyncio
Expand Down Expand Up @@ -51,6 +53,176 @@ async def test_get_client_default_uses_local_asgi_transport(config_manager):
assert isinstance(client._transport, httpx.ASGITransport) # pyright: ignore[reportPrivateUsage]


@pytest.mark.asyncio
async def test_get_client_preinitializes_local_asgi_database(config_manager, monkeypatch):
"""Local ASGI routing initializes DB state before request handling."""
from basic_memory import db
from basic_memory.api.app import app as fastapi_app

cfg = config_manager.load_config()
config_manager.save_config(cfg)

previous_engine = getattr(fastapi_app.state, "engine", None)
previous_session_maker = getattr(fastapi_app.state, "session_maker", None)
fastapi_app.state._state.pop("engine", None) # pyright: ignore[reportPrivateUsage]
fastapi_app.state._state.pop("session_maker", None) # pyright: ignore[reportPrivateUsage]

engine = object()
session_maker = object()
calls = []

async def fake_get_or_create_db(db_path):
calls.append(db_path)
return engine, session_maker

monkeypatch.setattr(db, "get_or_create_db", fake_get_or_create_db)

try:
async with get_client() as client:
assert isinstance(client._transport, httpx.ASGITransport) # pyright: ignore[reportPrivateUsage]
assert calls == [cfg.database_path]
assert fastapi_app.state.engine is engine
assert fastapi_app.state.session_maker is session_maker
assert not hasattr(fastapi_app.state, "engine")
assert not hasattr(fastapi_app.state, "session_maker")
finally:
if previous_engine is None:
fastapi_app.state._state.pop("engine", None) # pyright: ignore[reportPrivateUsage]
else:
fastapi_app.state.engine = previous_engine
if previous_session_maker is None:
fastapi_app.state._state.pop("session_maker", None) # pyright: ignore[reportPrivateUsage]
else:
fastapi_app.state.session_maker = previous_session_maker


@pytest.mark.parametrize("async_override", [False, True])
@pytest.mark.asyncio
async def test_get_client_uses_existing_local_asgi_database_override(
config_manager,
async_override,
):
"""Local ASGI routing honors FastAPI test dependency overrides."""
from basic_memory.api.app import app as fastapi_app
from basic_memory.deps import get_engine_factory

cfg = config_manager.load_config()
config_manager.save_config(cfg)

previous_overrides = dict(fastapi_app.dependency_overrides)
previous_engine = getattr(fastapi_app.state, "engine", None)
previous_session_maker = getattr(fastapi_app.state, "session_maker", None)
fastapi_app.state._state.pop("engine", None) # pyright: ignore[reportPrivateUsage]
fastapi_app.state._state.pop("session_maker", None) # pyright: ignore[reportPrivateUsage]

engine = object()
session_maker = object()
calls = []

if async_override:

async def override_engine_factory():
calls.append("override")
return engine, session_maker

else:

def override_engine_factory():
calls.append("override")
return engine, session_maker

fastapi_app.dependency_overrides[get_engine_factory] = override_engine_factory

try:
async with get_client() as client:
assert isinstance(client._transport, httpx.ASGITransport) # pyright: ignore[reportPrivateUsage]
assert calls == ["override"]
assert fastapi_app.state.engine is engine
assert fastapi_app.state.session_maker is session_maker
assert not hasattr(fastapi_app.state, "engine")
assert not hasattr(fastapi_app.state, "session_maker")
finally:
fastapi_app.dependency_overrides = previous_overrides
if previous_engine is None:
fastapi_app.state._state.pop("engine", None) # pyright: ignore[reportPrivateUsage]
else:
fastapi_app.state.engine = previous_engine
if previous_session_maker is None:
fastapi_app.state._state.pop("session_maker", None) # pyright: ignore[reportPrivateUsage]
else:
fastapi_app.state.session_maker = previous_session_maker


@pytest.mark.asyncio
async def test_get_client_keeps_local_asgi_database_during_overlapping_contexts(
config_manager,
monkeypatch,
):
"""Local ASGI database state stays installed until the last overlapping client exits."""
from basic_memory import db
from basic_memory.api.app import app as fastapi_app

cfg = config_manager.load_config()
config_manager.save_config(cfg)

previous_engine = getattr(fastapi_app.state, "engine", None)
previous_session_maker = getattr(fastapi_app.state, "session_maker", None)
fastapi_app.state._state.pop("engine", None) # pyright: ignore[reportPrivateUsage]
fastapi_app.state._state.pop("session_maker", None) # pyright: ignore[reportPrivateUsage]

engine = object()
session_maker = object()
calls = []

async def fake_get_or_create_db(db_path):
calls.append(db_path)
return engine, session_maker

monkeypatch.setattr(db, "get_or_create_db", fake_get_or_create_db)

first_context = get_client()
second_context = get_client()
first_entered = False
second_entered = False
first_exited = False

try:
first_client = await first_context.__aenter__()
first_entered = True
second_client = await second_context.__aenter__()
second_entered = True

assert isinstance(first_client._transport, httpx.ASGITransport) # pyright: ignore[reportPrivateUsage]
assert isinstance(second_client._transport, httpx.ASGITransport) # pyright: ignore[reportPrivateUsage]
assert calls == [cfg.database_path]

await first_context.__aexit__(None, None, None)
first_exited = True

assert fastapi_app.state.engine is engine
assert fastapi_app.state.session_maker is session_maker

await second_context.__aexit__(None, None, None)
second_entered = False

assert not hasattr(fastapi_app.state, "engine")
assert not hasattr(fastapi_app.state, "session_maker")
finally:
if second_entered:
await second_context.__aexit__(None, None, None)
if first_entered and not first_exited:
await first_context.__aexit__(None, None, None)

if previous_engine is None:
fastapi_app.state._state.pop("engine", None) # pyright: ignore[reportPrivateUsage]
else:
fastapi_app.state.engine = previous_engine
if previous_session_maker is None:
fastapi_app.state._state.pop("session_maker", None) # pyright: ignore[reportPrivateUsage]
else:
fastapi_app.state.session_maker = previous_session_maker


@pytest.mark.asyncio
async def test_get_client_explicit_cloud_uses_api_key(config_manager, monkeypatch):
cfg = config_manager.load_config()
Expand Down
Loading