Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 197 additions & 9 deletions src/basic_memory/mcp/async_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,36 @@
import os
from contextlib import AbstractAsyncContextManager, asynccontextmanager
from typing import AsyncIterator, Callable, Optional
from asyncio import Lock
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
from dataclasses import dataclass
from threading import RLock
from typing import TYPE_CHECKING, Annotated, Any, AsyncIterator, Callable, Optional

from fastapi import Depends, FastAPI, Request
from httpx import ASGITransport, AsyncClient, Timeout
from loguru import logger

import logfire
from basic_memory.config import ConfigManager, ProjectMode

if TYPE_CHECKING:
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker

LocalDatabaseState = tuple["AsyncEngine", "async_sessionmaker[AsyncSession]"]
_MISSING_STATE_VALUE = object()


@dataclass
class _PreparedLocalAsgiDatabase:
active_count: int
previous_engine: object
previous_session_maker: object
dependency_context: AbstractAsyncContextManager[LocalDatabaseState]


_prepared_local_asgi_database_lock = RLock()
_prepared_local_asgi_database_prepare_locks: dict[FastAPI, Lock] = {}
_prepared_local_asgi_databases: dict[FastAPI, _PreparedLocalAsgiDatabase] = {}


def _force_local_mode() -> bool:
"""Check if local mode is forced via environment variable."""
Expand All @@ -34,15 +57,12 @@ def _build_timeout() -> Timeout:
)


def _asgi_client(timeout: Timeout) -> AsyncClient:
"""Create a local ASGI client."""
# Import on first local-client use so CLI help/version paths can import
# routing helpers without constructing the full FastAPI router graph.
from basic_memory.api.app import app as fastapi_app
def _build_asgi_client(app: FastAPI, timeout: Timeout) -> AsyncClient:
"""Create a local ASGI client for an already-prepared FastAPI app."""
from basic_memory.workspace_context import workspace_permalink_headers

return AsyncClient(
transport=ASGITransport(app=fastapi_app),
transport=ASGITransport(app=app),
base_url="http://test",
timeout=timeout,
# Local ASGI calls still cross the HTTP boundary, so request handlers need
Expand All @@ -51,6 +71,169 @@ def _asgi_client(timeout: Timeout) -> AsyncClient:
)


def _get_prepared_local_asgi_database_prepare_lock(app: FastAPI) -> Lock:
"""Get the async lock that serializes first-time DB preparation for an app."""
with _prepared_local_asgi_database_lock:
prepare_lock = _prepared_local_asgi_database_prepare_locks.get(app)
if prepare_lock is None:
prepare_lock = Lock()
_prepared_local_asgi_database_prepare_locks[app] = prepare_lock
return prepare_lock


@asynccontextmanager
async def _resolve_local_asgi_database(app: FastAPI) -> AsyncIterator[LocalDatabaseState]:
"""Resolve database state for a local ASGI request."""
from fastapi.dependencies.utils import get_dependant, solve_dependencies

from basic_memory.deps import get_engine_factory

async def resolve_database_state(
database_state: Annotated[LocalDatabaseState, Depends(get_engine_factory)],
) -> LocalDatabaseState:
return database_state

scope: dict[str, Any] = {
"type": "http",
"asgi": {"version": "3.0"},
"method": "GET",
"scheme": "http",
"path": "/",
"raw_path": b"/",
"root_path": "",
"query_string": b"",
"headers": [],
"client": ("testclient", 50000),
"server": ("testserver", 80),
"app": app,
"path_params": {},
}

async with AsyncExitStack() as request_stack, AsyncExitStack() as function_stack:
scope["fastapi_inner_astack"] = request_stack
scope["fastapi_function_astack"] = function_stack
request = Request(scope)
dependant = get_dependant(path="/", call=resolve_database_state)
solved = await solve_dependencies(
request=request,
dependant=dependant,
dependency_overrides_provider=app,
async_exit_stack=request_stack,
embed_body_fields=False,
)
if solved.errors:
raise RuntimeError(f"Failed to resolve local ASGI database dependency: {solved.errors}")

yield await resolve_database_state(**solved.values)


def _retain_prepared_local_asgi_database(app: FastAPI) -> bool:
"""Retain an active local ASGI database preparation if one exists."""
with _prepared_local_asgi_database_lock:
active = _prepared_local_asgi_databases.get(app)
if active is None:
return False

active.active_count += 1
return True


def _install_prepared_local_asgi_database(
app: FastAPI,
database_state: LocalDatabaseState,
dependency_context: AbstractAsyncContextManager[LocalDatabaseState],
) -> None:
"""Install local ASGI database state after dependency resolution."""
with _prepared_local_asgi_database_lock:
active = _prepared_local_asgi_databases.get(app)
if active is not None:
raise RuntimeError("Local ASGI database state installed while another state is active")

previous_engine = getattr(app.state, "engine", _MISSING_STATE_VALUE)
previous_session_maker = getattr(app.state, "session_maker", _MISSING_STATE_VALUE)
engine, session_maker = database_state

app.state.engine = engine
app.state.session_maker = session_maker
_prepared_local_asgi_databases[app] = _PreparedLocalAsgiDatabase(
active_count=1,
previous_engine=previous_engine,
previous_session_maker=previous_session_maker,
dependency_context=dependency_context,
)


def _restore_local_asgi_state_attribute(app: FastAPI, name: str, previous_value: object) -> None:
"""Restore a FastAPI app.state attribute captured before local ASGI preparation."""
if previous_value is _MISSING_STATE_VALUE:
if hasattr(app.state, name):
delattr(app.state, name)
else:
setattr(app.state, name, previous_value)


def _release_prepared_local_asgi_database(
app: FastAPI,
) -> AbstractAsyncContextManager[LocalDatabaseState] | None:
"""Release local ASGI database state after a client context exits."""
with _prepared_local_asgi_database_lock:
active = _prepared_local_asgi_databases.get(app)
if active is None:
raise RuntimeError("Local ASGI database state released without a matching retain")

active.active_count -= 1
if active.active_count > 0:
return None

del _prepared_local_asgi_databases[app]
_restore_local_asgi_state_attribute(app, "engine", active.previous_engine)
_restore_local_asgi_state_attribute(
app,
"session_maker",
active.previous_session_maker,
)
return active.dependency_context


@asynccontextmanager
async def _prepared_local_asgi_database(app: FastAPI) -> AsyncIterator[None]:
"""Initialize local ASGI database state before the first request."""
prepare_lock = _get_prepared_local_asgi_database_prepare_lock(app)
async with prepare_lock:
if not _retain_prepared_local_asgi_database(app):
database_context = _resolve_local_asgi_database(app)
database_state = await database_context.__aenter__()
try:
_install_prepared_local_asgi_database(app, database_state, database_context)
except Exception:
await database_context.__aexit__(None, None, None)
raise

try:
yield
finally:
database_context = _release_prepared_local_asgi_database(app)
if database_context is not None:
await database_context.__aexit__(None, None, None)


@asynccontextmanager
async def _asgi_client(timeout: Timeout) -> AsyncIterator[AsyncClient]:
"""Create a local ASGI client."""
# Import on first local-client use so CLI help/version paths can import
# routing helpers without constructing the full FastAPI router graph.
from basic_memory.api.app import app as fastapi_app

# Trigger: local ASGITransport does not execute FastAPI lifespan startup.
# Why: letting request dependencies initialize Postgres can run asyncpg DDL
# under Starlette's request loop and trigger CPython's empty-ready-queue race.
# Outcome: request handling sees the same app.state database objects as API
# lifespan startup would have provided.
async with _prepared_local_asgi_database(fastapi_app):
async with _build_asgi_client(fastapi_app, timeout) as client:
yield client


async def _resolve_cloud_token(config) -> str:
"""Resolve cloud token with API key preferred, OAuth fallback."""
with logfire.span(
Expand Down Expand Up @@ -260,7 +443,12 @@ def create_client() -> AsyncClient:

if _force_local_mode() or not _force_cloud_mode():
logger.info("Creating ASGI client for local Basic Memory API")
return _asgi_client(timeout)
# Deprecated sync path: create_client() cannot await the local ASGI
# pre-initialization used by get_client(), so callers that need proper
# resource setup should use the async context manager instead.
from basic_memory.api.app import app as fastapi_app

return _build_asgi_client(fastapi_app, timeout)

logger.info("Creating HTTP client for cloud proxy (legacy create_client path)")
config = ConfigManager().config
Expand Down
Loading
Loading