diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 88a68f96..4e27a3ad 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -52,7 +52,6 @@ jobs: test-sqlite-unit: name: Test SQLite Unit (${{ matrix.os }}, Python ${{ matrix.python-version }}) timeout-minutes: 30 - needs: [static-checks] strategy: fail-fast: false matrix: @@ -99,7 +98,6 @@ jobs: test-sqlite-integration: name: Test SQLite Integration (${{ matrix.os }}, Python ${{ matrix.python-version }}) timeout-minutes: 45 - needs: [static-checks] strategy: fail-fast: false matrix: @@ -146,7 +144,7 @@ jobs: test-postgres-unit: name: Test Postgres Unit (Python ${{ matrix.python-version }}) timeout-minutes: 30 - needs: [static-checks] + if: github.event_name != 'pull_request' || matrix.python-version == '3.12' strategy: fail-fast: false matrix: @@ -155,8 +153,22 @@ jobs: - python-version: "3.13" - python-version: "3.14" runs-on: ubuntu-latest - - # Note: No services section needed - testcontainers handles Postgres in Docker + services: + postgres: + image: pgvector/pgvector:pg16 + env: + POSTGRES_USER: basic_memory_user + POSTGRES_PASSWORD: dev_password + POSTGRES_DB: basic_memory_test + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U basic_memory_user -d basic_memory_test" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + env: + BASIC_MEMORY_TEST_POSTGRES_URL: postgresql://basic_memory_user:dev_password@127.0.0.1:5432/basic_memory_test steps: - uses: actions/checkout@v4 @@ -190,7 +202,7 @@ jobs: test-postgres-integration: name: Test Postgres Integration (Python ${{ matrix.python-version }}) timeout-minutes: 45 - needs: [static-checks] + if: github.event_name != 'pull_request' || matrix.python-version == '3.12' strategy: fail-fast: false matrix: @@ -199,8 +211,22 @@ jobs: - python-version: "3.13" - python-version: "3.14" runs-on: ubuntu-latest - - # Note: No services section needed - testcontainers handles Postgres in Docker + services: + postgres: + image: pgvector/pgvector:pg16 + env: + POSTGRES_USER: basic_memory_user + POSTGRES_PASSWORD: dev_password + POSTGRES_DB: basic_memory_test + ports: + - 5432:5432 + options: >- + --health-cmd "pg_isready -U basic_memory_user -d basic_memory_test" + --health-interval 10s + --health-timeout 5s + --health-retries 5 + env: + BASIC_MEMORY_TEST_POSTGRES_URL: postgresql://basic_memory_user:dev_password@127.0.0.1:5432/basic_memory_test steps: - uses: actions/checkout@v4 @@ -234,7 +260,6 @@ jobs: test-semantic: name: Test Semantic (Python 3.12) timeout-minutes: 45 - needs: [static-checks] runs-on: ubuntu-latest steps: diff --git a/AGENTS.md b/AGENTS.md index 8bfa0c76..239bcc12 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -442,5 +442,9 @@ With GitHub integration, the development workflow includes: 3. **Branch management** - Claude can create feature branches for implementations 4. **Documentation maintenance** - Claude can keep documentation updated as the code evolves 5. **Code Commits**: ALWAYS sign off commits with `git commit -s` +6. **Pull Request Titles**: PR titles must follow the semantic format enforced by `.github/workflows/pr-title.yml`: `type(scope): summary` + - Allowed types: `feat`, `fix`, `chore`, `docs`, `style`, `refactor`, `perf`, `test`, `build`, `ci` + - Allowed scopes: `core`, `cli`, `api`, `mcp`, `sync`, `ui`, `deps`, `installer` + - Example: `fix(cli): propagate cloud workspace routing` This level of integration represents a new paradigm in AI-human collaboration, where the AI assistant becomes a full-fledged team member rather than just a tool for generating code snippets. diff --git a/src/basic_memory/cli/commands/cloud/cloud_utils.py b/src/basic_memory/cli/commands/cloud/cloud_utils.py index b4c1158a..1e21178f 100644 --- a/src/basic_memory/cli/commands/cloud/cloud_utils.py +++ b/src/basic_memory/cli/commands/cloud/cloud_utils.py @@ -2,6 +2,7 @@ from basic_memory.cli.commands.cloud.api_client import make_api_request from basic_memory.config import ConfigManager +from basic_memory.mcp.async_client import resolve_configured_workspace from basic_memory.schemas.cloud import ( CloudProjectList, CloudProjectCreateRequest, @@ -16,8 +17,25 @@ class CloudUtilsError(Exception): pass +def _workspace_headers( + *, + project_name: str | None = None, + workspace: str | None = None, +) -> dict[str, str]: + """Build optional workspace headers using the CLI config resolution chain.""" + resolved_workspace = resolve_configured_workspace( + project_name=project_name, + workspace=workspace, + ) + if resolved_workspace is None: + return {} + return {"X-Workspace-ID": resolved_workspace} + + async def fetch_cloud_projects( *, + project_name: str | None = None, + workspace: str | None = None, api_request=make_api_request, ) -> CloudProjectList: """Fetch list of projects from cloud API. @@ -30,7 +48,11 @@ async def fetch_cloud_projects( config = config_manager.config host_url = config.cloud_host.rstrip("/") - response = await api_request(method="GET", url=f"{host_url}/proxy/v2/projects/") + response = await api_request( + method="GET", + url=f"{host_url}/proxy/v2/projects/", + headers=_workspace_headers(project_name=project_name, workspace=workspace), + ) return CloudProjectList.model_validate(response.json()) except Exception as e: @@ -40,12 +62,14 @@ async def fetch_cloud_projects( async def create_cloud_project( project_name: str, *, + workspace: str | None = None, api_request=make_api_request, ) -> CloudProjectCreateResponse: """Create a new project on cloud. Args: project_name: Name of project to create + workspace: Optional workspace override for tenant-scoped project creation Returns: CloudProjectCreateResponse with project details from API @@ -67,7 +91,10 @@ async def create_cloud_project( response = await api_request( method="POST", url=f"{host_url}/proxy/v2/projects/", - headers={"Content-Type": "application/json"}, + headers={ + "Content-Type": "application/json", + **_workspace_headers(project_name=project_name, workspace=workspace), + }, json_data=project_data.model_dump(), ) @@ -91,18 +118,28 @@ async def sync_project(project_name: str, force_full: bool = False) -> None: raise CloudUtilsError(f"Failed to sync project '{project_name}': {e}") from e -async def project_exists(project_name: str, *, api_request=make_api_request) -> bool: +async def project_exists( + project_name: str, + *, + workspace: str | None = None, + api_request=make_api_request, +) -> bool: """Check if a project exists on cloud. Args: project_name: Name of project to check + workspace: Optional workspace override for tenant-scoped project lookup Returns: True if project exists, False otherwise + + Raises: + CloudUtilsError: If the project list cannot be fetched from cloud """ - try: - projects = await fetch_cloud_projects(api_request=api_request) - project_names = {p.name for p in projects.projects} - return project_name in project_names - except Exception: - return False + projects = await fetch_cloud_projects( + project_name=project_name, + workspace=workspace, + api_request=api_request, + ) + project_names = {p.name for p in projects.projects} + return project_name in project_names diff --git a/src/basic_memory/cli/commands/cloud/project_sync.py b/src/basic_memory/cli/commands/cloud/project_sync.py index 08b616be..c963a09b 100644 --- a/src/basic_memory/cli/commands/cloud/project_sync.py +++ b/src/basic_memory/cli/commands/cloud/project_sync.py @@ -54,7 +54,7 @@ def _require_cloud_credentials(config) -> None: async def _get_cloud_project(name: str) -> ProjectItem | None: """Fetch a project by name from the cloud API.""" - async with get_client() as client: + async with get_client(project_name=name) as client: projects_list = await ProjectClient(client).list_projects() for proj in projects_list.projects: if generate_permalink(proj.name) == generate_permalink(name): @@ -129,9 +129,9 @@ def sync_project_command( if not dry_run: async def _trigger_db_sync(): - async with get_client() as client: + async with get_client(project_name=name) as client: return await ProjectClient(client).sync( - project_data.external_id, force_full=True + project_data.external_id, force_full=False ) try: @@ -195,7 +195,10 @@ def bisync_project_command( # Update config — sync_entry is guaranteed non-None because # _get_sync_project validated local_sync_path (which comes from sync_entry) sync_entry = config.projects.get(name) - assert sync_entry is not None + if sync_entry is None: + raise RuntimeError( + f"Sync entry for project '{name}' unexpectedly missing after validation" + ) sync_entry.last_sync = datetime.now() sync_entry.bisync_initialized = True ConfigManager().save_config(config) @@ -204,9 +207,9 @@ def bisync_project_command( if not dry_run: async def _trigger_db_sync(): - async with get_client() as client: + async with get_client(project_name=name) as client: return await ProjectClient(client).sync( - project_data.external_id, force_full=True + project_data.external_id, force_full=False ) try: @@ -320,7 +323,7 @@ def setup_project_sync( async def _verify_project_exists(): """Verify the project exists on cloud by listing all projects.""" - async with get_client() as client: + async with get_client(project_name=name) as client: projects_list = await ProjectClient(client).list_projects() project_names = [p.name for p in projects_list.projects] if name not in project_names: diff --git a/src/basic_memory/cli/commands/cloud/upload_command.py b/src/basic_memory/cli/commands/cloud/upload_command.py index b27c83ff..179a3b27 100644 --- a/src/basic_memory/cli/commands/cloud/upload_command.py +++ b/src/basic_memory/cli/commands/cloud/upload_command.py @@ -1,5 +1,6 @@ """Upload CLI commands for basic-memory projects.""" +from functools import partial from pathlib import Path import typer @@ -8,12 +9,16 @@ from basic_memory.cli.app import cloud_app from basic_memory.cli.commands.command_utils import run_with_cleanup from basic_memory.cli.commands.cloud.cloud_utils import ( + CloudUtilsError, create_cloud_project, project_exists, sync_project, ) from basic_memory.cli.commands.cloud.upload import upload_path -from basic_memory.mcp.async_client import get_cloud_control_plane_client +from basic_memory.mcp.async_client import ( + get_cloud_control_plane_client, + resolve_configured_workspace, +) console = Console() @@ -73,12 +78,20 @@ def upload( """ async def _upload(): + resolved_workspace = resolve_configured_workspace(project_name=project) + + try: + project_already_exists = await project_exists(project, workspace=resolved_workspace) + except CloudUtilsError as e: + console.print(f"[red]Failed to check cloud project '{project}': {e}[/red]") + raise typer.Exit(1) + # Check if project exists - if not await project_exists(project): + if not project_already_exists: if create_project: console.print(f"[blue]Creating cloud project '{project}'...[/blue]") try: - await create_cloud_project(project) + await create_cloud_project(project, workspace=resolved_workspace) console.print(f"[green]Created project '{project}'[/green]") except Exception as e: console.print(f"[red]Failed to create project: {e}[/red]") @@ -106,7 +119,10 @@ async def _upload(): verbose=verbose, use_gitignore=not no_gitignore, dry_run=dry_run, - client_cm_factory=get_cloud_control_plane_client, + client_cm_factory=partial( + get_cloud_control_plane_client, + workspace=resolved_workspace, + ), ) if not success: console.print("[red]Upload failed[/red]") @@ -117,8 +133,10 @@ async def _upload(): else: console.print(f"[green]Successfully uploaded to '{project}'[/green]") - # Sync project if requested (skip on dry run) - # Force full scan after bisync to ensure database is up-to-date with synced files + # Sync project if requested (skip on dry run). + # Trigger: upload adds new files the watcher has not observed locally. + # Why: force_full ensures those freshly uploaded files are indexed immediately. + # Outcome: upload keeps its eager reindex while sync/bisync stay incremental. if sync and not dry_run: console.print(f"[blue]Syncing project '{project}'...[/blue]") try: diff --git a/src/basic_memory/mcp/async_client.py b/src/basic_memory/mcp/async_client.py index 5b794c23..76580f05 100644 --- a/src/basic_memory/mcp/async_client.py +++ b/src/basic_memory/mcp/async_client.py @@ -66,6 +66,27 @@ async def _resolve_cloud_token(config) -> str: ) +def resolve_configured_workspace( + *, + config=None, + project_name: Optional[str] = None, + workspace: Optional[str] = None, +) -> Optional[str]: + """Resolve workspace from explicit input, per-project config, then global default.""" + if workspace is not None: + return workspace + + if config is None: + config = ConfigManager().config + + if project_name is not None: + project_entry = config.projects.get(project_name) + if project_entry and project_entry.workspace_id: + return project_entry.workspace_id + + return config.default_workspace + + @asynccontextmanager async def _cloud_client( config, @@ -88,15 +109,20 @@ async def _cloud_client( @asynccontextmanager -async def get_cloud_control_plane_client() -> AsyncIterator[AsyncClient]: +async def get_cloud_control_plane_client( + workspace: Optional[str] = None, +) -> AsyncIterator[AsyncClient]: """Create a control-plane cloud client for endpoints outside /proxy.""" config = ConfigManager().config timeout = _build_timeout() token = await _resolve_cloud_token(config) + headers = {"Authorization": f"Bearer {token}"} + if workspace: + headers["X-Workspace-ID"] = workspace logger.info(f"Creating HTTP client for cloud control plane at: {config.cloud_host}") async with AsyncClient( base_url=config.cloud_host, - headers={"Authorization": f"Bearer {token}"}, + headers=headers, timeout=timeout, ) as client: yield client @@ -167,7 +193,12 @@ async def get_client( if _force_cloud_mode(): logger.debug("Explicit cloud routing enabled - using cloud proxy client") - async with _cloud_client(config, timeout, workspace=workspace) as client: + effective_workspace = resolve_configured_workspace( + config=config, + project_name=project_name, + workspace=workspace, + ) + async with _cloud_client(config, timeout, workspace=effective_workspace) as client: yield client return @@ -179,8 +210,13 @@ async def get_client( project_mode = config.get_project_mode(project_name) if project_mode == ProjectMode.CLOUD: logger.debug(f"Project '{project_name}' is cloud mode - using cloud proxy client") + effective_workspace = resolve_configured_workspace( + config=config, + project_name=project_name, + workspace=workspace, + ) try: - async with _cloud_client(config, timeout, workspace=workspace) as client: + async with _cloud_client(config, timeout, workspace=effective_workspace) as client: yield client except RuntimeError as exc: raise RuntimeError( diff --git a/test-int/conftest.py b/test-int/conftest.py index 07f5af38..0ee5fc32 100644 --- a/test-int/conftest.py +++ b/test-int/conftest.py @@ -51,7 +51,7 @@ async def test_my_mcp_tool(mcp_server, app): """ import os -from typing import AsyncGenerator, Literal +from typing import AsyncGenerator, Generator, Literal import pytest import pytest_asyncio @@ -63,7 +63,13 @@ async def test_my_mcp_tool(mcp_server, app): from httpx import AsyncClient, ASGITransport -from basic_memory.config import BasicMemoryConfig, ProjectConfig, ConfigManager, DatabaseBackend +from basic_memory.config import ( + BasicMemoryConfig, + ProjectConfig, + ProjectEntry, + ConfigManager, + DatabaseBackend, +) from basic_memory.db import engine_session_factory, DatabaseType from basic_memory.models import Project from basic_memory.models.base import Base @@ -103,7 +109,7 @@ def postgres_container(db_backend): Uses testcontainers to spin up a real Postgres instance. Only starts if db_backend is "postgres". """ - if db_backend != "postgres": + if db_backend != "postgres" or _configured_postgres_sync_url(): yield None return @@ -112,6 +118,73 @@ def postgres_container(db_backend): yield postgres +POSTGRES_EPHEMERAL_TABLES = [ + "search_vector_embeddings", + "search_vector_chunks", + "search_vector_index", +] + + +def _configured_postgres_sync_url() -> str | None: + """Prefer an externally managed Postgres server when CI provides one.""" + configured_url = os.environ.get("BASIC_MEMORY_TEST_POSTGRES_URL") or os.environ.get( + "POSTGRES_TEST_URL" + ) + if not configured_url: + return None + + return ( + configured_url.replace("postgresql+asyncpg://", "postgresql+psycopg2://", 1) + .replace("postgresql://", "postgresql+psycopg2://", 1) + .replace("postgres://", "postgresql+psycopg2://", 1) + ) + + +def _postgres_reset_tables() -> list[str]: + """Resolve the current ORM table set at reset time.""" + return [table.name for table in Base.metadata.sorted_tables] + ["search_index"] + + +def _resolve_postgres_sync_url(postgres_container) -> str: + """Use CI's shared service when configured, otherwise fall back to testcontainers.""" + configured_url = _configured_postgres_sync_url() + if configured_url: + return configured_url + assert postgres_container is not None + return postgres_container.get_connection_url() + + +async def _reset_postgres_integration_schema(engine) -> None: + """Restore the shared Postgres integration schema to a clean baseline.""" + from basic_memory.models.search import ( + CREATE_POSTGRES_SEARCH_INDEX_FTS, + CREATE_POSTGRES_SEARCH_INDEX_METADATA, + CREATE_POSTGRES_SEARCH_INDEX_PERMALINK, + CREATE_POSTGRES_SEARCH_INDEX_TABLE, + ) + + async with engine.begin() as conn: + # Trigger: integration tests may leave behind temporary search/vector tables while + # exercising full-stack recovery paths. + # Why: recreating only the missing schema is much cheaper than dropping every table. + # Outcome: each integration test gets the same baseline without paying repeated full DDL cost. + await conn.run_sync(Base.metadata.create_all) + await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_TABLE) + await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_FTS) + await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_METADATA) + await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_PERMALINK) + + for table_name in POSTGRES_EPHEMERAL_TABLES: + await conn.execute(text(f"DROP TABLE IF EXISTS {table_name} CASCADE")) + + await conn.execute( + text( + f"TRUNCATE TABLE {', '.join(_postgres_reset_tables())} " + "RESTART IDENTITY CASCADE" + ) + ) + + @pytest_asyncio.fixture async def engine_factory( app_config, @@ -121,18 +194,12 @@ async def engine_factory( tmp_path, ) -> AsyncGenerator[tuple, None]: """Create engine and session factory for the configured database backend.""" - from basic_memory.models.search import ( - CREATE_SEARCH_INDEX, - CREATE_POSTGRES_SEARCH_INDEX_TABLE, - CREATE_POSTGRES_SEARCH_INDEX_FTS, - CREATE_POSTGRES_SEARCH_INDEX_METADATA, - CREATE_POSTGRES_SEARCH_INDEX_PERMALINK, - ) + from basic_memory.models.search import CREATE_SEARCH_INDEX from basic_memory import db if db_backend == "postgres": # Postgres mode using testcontainers - sync_url = postgres_container.get_connection_url() + sync_url = _resolve_postgres_sync_url(postgres_container) async_url = sync_url.replace("postgresql+psycopg2", "postgresql+asyncpg") engine = create_async_engine( @@ -153,16 +220,7 @@ async def engine_factory( db._engine = engine db._session_maker = session_maker - # Drop and recreate all tables for test isolation - async with engine.begin() as conn: - await conn.execute(text("DROP TABLE IF EXISTS search_index CASCADE")) - await conn.run_sync(Base.metadata.drop_all) - await conn.run_sync(Base.metadata.create_all) - # asyncpg requires separate execute calls for each statement - await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_TABLE) - await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_FTS) - await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_METADATA) - await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_PERMALINK) + await _reset_postgres_integration_schema(engine) yield engine, session_maker @@ -228,13 +286,15 @@ def app_config( monkeypatch.setenv("BASIC_MEMORY_CLOUD_MODE", "false") # Create a basic config with test-project like unit tests do - projects = {"test-project": str(config_home)} + projects = {"test-project": ProjectEntry(path=str(config_home))} # Configure database backend based on env var if db_backend == "postgres": database_backend = DatabaseBackend.POSTGRES - # Get URL from testcontainer and convert to asyncpg driver - sync_url = postgres_container.get_connection_url() + # Trigger: CI jobs can provide a shared Postgres service instead of per-session containers. + # Why: reusing one pgvector-enabled server avoids Docker startup churn on every job. + # Outcome: local runs keep using testcontainers, while CI injects a stable service URL. + sync_url = _resolve_postgres_sync_url(postgres_container) database_url = sync_url.replace("postgresql+psycopg2", "postgresql+asyncpg") else: database_backend = DatabaseBackend.SQLITE @@ -285,7 +345,9 @@ def project_config(test_project): @pytest.fixture -def app(app_config, project_config, engine_factory, test_project, config_manager) -> FastAPI: +def app( + app_config, project_config, engine_factory, test_project, config_manager +) -> Generator[FastAPI, None, None]: """Create test FastAPI application with single project.""" # Import the FastAPI app AFTER the config_manager has written the test config to disk diff --git a/tests/cli/cloud/test_cloud_api_client_and_utils.py b/tests/cli/cloud/test_cloud_api_client_and_utils.py index 65f2f99e..60aef12d 100644 --- a/tests/cli/cloud/test_cloud_api_client_and_utils.py +++ b/tests/cli/cloud/test_cloud_api_client_and_utils.py @@ -10,10 +10,12 @@ make_api_request, ) from basic_memory.cli.commands.cloud.cloud_utils import ( + CloudUtilsError, create_cloud_project, fetch_cloud_projects, project_exists, ) +from basic_memory.config import ProjectMode @pytest.mark.asyncio @@ -165,6 +167,70 @@ async def api_request(**kwargs): assert seen["create_payload"]["path"] == "my-project" +@pytest.mark.asyncio +async def test_cloud_utils_use_configured_workspace_headers(config_home, config_manager): + """Workspace-aware cloud helpers should prefer project workspace over global default.""" + config = config_manager.load_config() + config.cloud_host = "https://cloud.example.test" + config.default_workspace = "default-workspace" + config.set_project_mode("alpha", ProjectMode.CLOUD) + config.projects["alpha"].workspace_id = "project-workspace" + config_manager.save_config(config) + + seen: list[tuple[str, str | None]] = [] + + async def api_request(**kwargs): + seen.append( + ( + kwargs["method"], + (kwargs.get("headers") or {}).get("X-Workspace-ID"), + ) + ) + + if kwargs["method"] == "GET": + return httpx.Response( + 200, + json={ + "projects": [{"id": 1, "name": "alpha", "path": "alpha", "is_default": True}] + }, + ) + + return httpx.Response( + 200, + json={ + "message": "created", + "status": "success", + "default": False, + "old_project": None, + "new_project": {"name": "alpha", "path": "alpha"}, + }, + ) + + assert await project_exists("alpha", api_request=api_request) is True + await create_cloud_project("alpha", api_request=api_request) + await fetch_cloud_projects(project_name="missing", api_request=api_request) + + assert seen == [ + ("GET", "project-workspace"), + ("POST", "project-workspace"), + ("GET", "default-workspace"), + ] + + +@pytest.mark.asyncio +async def test_project_exists_surfaces_cloud_lookup_failures(config_home, config_manager): + """project_exists should surface lookup failures instead of pretending the project is missing.""" + config = config_manager.load_config() + config.cloud_host = "https://cloud.example.test" + config_manager.save_config(config) + + async def api_request(**_kwargs): + raise httpx.ConnectError("boom") + + with pytest.raises(CloudUtilsError, match="Failed to fetch cloud projects"): + await project_exists("alpha", api_request=api_request) + + @pytest.mark.asyncio async def test_make_api_request_prefers_api_key_over_oauth(config_home, config_manager): """API key in config should be used without needing an OAuth token on disk.""" diff --git a/tests/cli/cloud/test_project_sync_command.py b/tests/cli/cloud/test_project_sync_command.py new file mode 100644 index 00000000..4d5dda94 --- /dev/null +++ b/tests/cli/cloud/test_project_sync_command.py @@ -0,0 +1,114 @@ +"""Tests for cloud sync and bisync command behavior.""" + +import importlib +from contextlib import asynccontextmanager +from types import SimpleNamespace + +import pytest +from typer.testing import CliRunner + +from basic_memory.cli.app import app +from basic_memory.config import ProjectMode + +runner = CliRunner() + + +@pytest.mark.parametrize( + "argv", + [ + ["cloud", "sync", "--name", "research"], + ["cloud", "bisync", "--name", "research"], + ], +) +def test_cloud_sync_commands_use_incremental_db_sync(monkeypatch, argv, config_manager): + """Cloud sync commands should not force a full database re-index after file sync.""" + project_sync_command = importlib.import_module("basic_memory.cli.commands.cloud.project_sync") + + seen: dict[str, object] = {} + config = config_manager.load_config() + config.set_project_mode("research", ProjectMode.CLOUD) + config_manager.save_config(config) + + monkeypatch.setattr(project_sync_command, "_require_cloud_credentials", lambda _config: None) + monkeypatch.setattr( + project_sync_command, + "get_mount_info", + lambda: _async_value(SimpleNamespace(bucket_name="tenant-bucket")), + ) + monkeypatch.setattr( + project_sync_command, + "_get_cloud_project", + lambda _name: _async_value( + SimpleNamespace(name="research", external_id="external-project-id", path="research") + ), + ) + monkeypatch.setattr( + project_sync_command, + "_get_sync_project", + lambda _name, _config, _project_data: (SimpleNamespace(name="research"), "/tmp/research"), + ) + monkeypatch.setattr(project_sync_command, "project_sync", lambda *args, **kwargs: True) + monkeypatch.setattr(project_sync_command, "project_bisync", lambda *args, **kwargs: True) + + @asynccontextmanager + async def fake_get_client(*, project_name=None, workspace=None): + seen["project_name"] = project_name + seen["workspace"] = workspace + yield object() + + class FakeProjectClient: + def __init__(self, _client): + pass + + async def sync(self, external_id: str, force_full: bool = False): + seen["external_id"] = external_id + seen["force_full"] = force_full + return {"message": "queued"} + + monkeypatch.setattr(project_sync_command, "get_client", fake_get_client) + monkeypatch.setattr(project_sync_command, "ProjectClient", FakeProjectClient) + + result = runner.invoke(app, argv) + + assert result.exit_code == 0, result.output + assert seen["project_name"] == "research" + assert seen["external_id"] == "external-project-id" + assert seen["force_full"] is False + + +def test_cloud_bisync_fails_fast_when_sync_entry_disappears(monkeypatch, config_manager): + """Bisync should raise a runtime error when validated sync config vanishes before persistence.""" + project_sync_command = importlib.import_module("basic_memory.cli.commands.cloud.project_sync") + + config = config_manager.load_config() + config.projects.pop("research", None) + config_manager.save_config(config) + + monkeypatch.setattr(project_sync_command, "_require_cloud_credentials", lambda _config: None) + monkeypatch.setattr( + project_sync_command, + "get_mount_info", + lambda: _async_value(SimpleNamespace(bucket_name="tenant-bucket")), + ) + monkeypatch.setattr( + project_sync_command, + "_get_cloud_project", + lambda _name: _async_value( + SimpleNamespace(name="research", external_id="external-project-id", path="research") + ), + ) + monkeypatch.setattr( + project_sync_command, + "_get_sync_project", + lambda _name, _config, _project_data: (SimpleNamespace(name="research"), "/tmp/research"), + ) + monkeypatch.setattr(project_sync_command, "project_bisync", lambda *args, **kwargs: True) + + result = runner.invoke(app, ["cloud", "bisync", "--name", "research"]) + + assert result.exit_code == 1, result.output + assert "unexpectedly missing after validation" in result.output + + +async def _async_value(value): + return value diff --git a/tests/cli/cloud/test_upload_command_routing.py b/tests/cli/cloud/test_upload_command_routing.py index 6bb0115b..9a42e2b5 100644 --- a/tests/cli/cloud/test_upload_command_routing.py +++ b/tests/cli/cloud/test_upload_command_routing.py @@ -6,6 +6,8 @@ from typer.testing import CliRunner from basic_memory.cli.app import app +from basic_memory.cli.commands.cloud.cloud_utils import CloudUtilsError +from basic_memory.config import ProjectMode runner = CliRunner() @@ -20,11 +22,11 @@ def test_cloud_upload_uses_control_plane_client(monkeypatch, tmp_path): seen: dict[str, str] = {} - async def fake_project_exists(_project_name: str) -> bool: + async def fake_project_exists(_project_name: str, workspace: str | None = None) -> bool: return True @asynccontextmanager - async def fake_get_client(): + async def fake_get_client(workspace: str | None = None): async with httpx.AsyncClient(base_url="https://cloud.example.test") as client: yield client @@ -53,3 +55,89 @@ async def fake_upload_path(*args, **kwargs): assert result.exit_code == 0, result.output assert seen["base_url"] == "https://cloud.example.test" + + +def test_cloud_upload_uses_project_workspace_for_api_and_webdav( + monkeypatch, tmp_path, config_manager +): + """Upload command should reuse the configured workspace across API and WebDAV calls.""" + import basic_memory.cli.commands.cloud.upload_command as upload_command + + config = config_manager.load_config() + config.default_workspace = "default-workspace" + config.set_project_mode("routing-test", ProjectMode.CLOUD) + config.projects["routing-test"].workspace_id = "project-workspace" + config_manager.save_config(config) + + upload_dir = tmp_path / "upload" + upload_dir.mkdir() + (upload_dir / "note.md").write_text("hello", encoding="utf-8") + + seen: dict[str, str | None] = {} + + async def fake_project_exists(_project_name: str, workspace: str | None = None) -> bool: + seen["project_exists_workspace"] = workspace + return True + + @asynccontextmanager + async def fake_get_client(workspace: str | None = None): + seen["control_plane_workspace"] = workspace + async with httpx.AsyncClient(base_url="https://cloud.example.test") as client: + yield client + + async def fake_upload_path(*args, **kwargs): + client_cm_factory = kwargs.get("client_cm_factory") + assert client_cm_factory is not None + async with client_cm_factory() as client: + seen["base_url"] = str(client.base_url).rstrip("/") + return True + + monkeypatch.setattr(upload_command, "project_exists", fake_project_exists) + monkeypatch.setattr(upload_command, "get_cloud_control_plane_client", fake_get_client) + monkeypatch.setattr(upload_command, "upload_path", fake_upload_path) + + result = runner.invoke( + app, + [ + "cloud", + "upload", + str(upload_dir), + "--project", + "routing-test", + "--no-sync", + ], + ) + + assert result.exit_code == 0, result.output + assert seen["project_exists_workspace"] == "project-workspace" + assert seen["control_plane_workspace"] == "project-workspace" + assert seen["base_url"] == "https://cloud.example.test" + + +def test_cloud_upload_exits_when_project_lookup_fails(monkeypatch, tmp_path): + """Upload command should fail fast when cloud project lookup cannot reach the API.""" + import basic_memory.cli.commands.cloud.upload_command as upload_command + + upload_dir = tmp_path / "upload" + upload_dir.mkdir() + (upload_dir / "note.md").write_text("hello", encoding="utf-8") + + async def fake_project_exists(_project_name: str, workspace: str | None = None) -> bool: + raise CloudUtilsError("lookup failed") + + monkeypatch.setattr(upload_command, "project_exists", fake_project_exists) + + result = runner.invoke( + app, + [ + "cloud", + "upload", + str(upload_dir), + "--project", + "routing-test", + "--no-sync", + ], + ) + + assert result.exit_code == 1, result.output + assert "Failed to check cloud project 'routing-test'" in result.output diff --git a/tests/conftest.py b/tests/conftest.py index 8f8051c2..cd4b81d8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,7 +22,13 @@ from testcontainers.postgres import PostgresContainer from basic_memory import db -from basic_memory.config import ProjectConfig, BasicMemoryConfig, ConfigManager, DatabaseBackend +from basic_memory.config import ( + ProjectConfig, + ProjectEntry, + BasicMemoryConfig, + ConfigManager, + DatabaseBackend, +) from basic_memory.db import DatabaseType from basic_memory.markdown import EntityParser from basic_memory.markdown.markdown_processor import MarkdownProcessor @@ -74,7 +80,7 @@ def postgres_container(db_backend): The container is started once per test session and shared across all tests. Only starts if db_backend is "postgres". """ - if db_backend != "postgres": + if db_backend != "postgres" or _configured_postgres_sync_url(): yield None return @@ -83,6 +89,103 @@ def postgres_container(db_backend): yield postgres +POSTGRES_EPHEMERAL_TABLES = [ + "search_vector_embeddings", + "search_vector_index", +] + + +def _configured_postgres_sync_url() -> str | None: + """Prefer an externally managed Postgres server when CI provides one.""" + configured_url = os.environ.get("BASIC_MEMORY_TEST_POSTGRES_URL") or os.environ.get( + "POSTGRES_TEST_URL" + ) + if not configured_url: + return None + + return ( + configured_url.replace("postgresql+asyncpg://", "postgresql+psycopg2://", 1) + .replace("postgresql://", "postgresql+psycopg2://", 1) + .replace("postgres://", "postgresql+psycopg2://", 1) + ) + + +def _postgres_alembic_config(async_url: str) -> Config: + """Build Alembic config for stamping the shared Postgres test schema.""" + alembic_dir = Path(db.__file__).parent / "alembic" + cfg = Config() + cfg.set_main_option("script_location", str(alembic_dir)) + cfg.set_main_option( + "file_template", + "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s", + ) + cfg.set_main_option("timezone", "UTC") + cfg.set_main_option("revision_environment", "false") + cfg.set_main_option("sqlalchemy.url", async_url) + return cfg + + +def _postgres_reset_tables() -> list[str]: + """Resolve the current ORM table set at reset time. + + Some tests declare models after conftest import, so the list must stay dynamic. + """ + return [table.name for table in Base.metadata.sorted_tables] + [ + "search_index", + "search_vector_chunks", + ] + + +def _resolve_postgres_sync_url(postgres_container) -> str: + """Use CI's shared service when configured, otherwise fall back to testcontainers.""" + configured_url = _configured_postgres_sync_url() + if configured_url: + return configured_url + assert postgres_container is not None + return postgres_container.get_connection_url() + + +async def _reset_postgres_test_schema(engine: AsyncEngine, async_url: str) -> None: + """Restore the shared Postgres schema to a clean baseline before each test.""" + from basic_memory.models.search import ( + CREATE_POSTGRES_SEARCH_INDEX_FTS, + CREATE_POSTGRES_SEARCH_INDEX_METADATA, + CREATE_POSTGRES_SEARCH_INDEX_PERMALINK, + CREATE_POSTGRES_SEARCH_INDEX_TABLE, + CREATE_POSTGRES_SEARCH_VECTOR_CHUNKS_INDEX, + CREATE_POSTGRES_SEARCH_VECTOR_CHUNKS_TABLE, + ) + + async with engine.begin() as conn: + # Trigger: several tests intentionally drop or stub search tables to exercise recovery code. + # Why: TRUNCATE is much cheaper than drop_all/create_all, but it only works when the schema exists. + # Outcome: we recreate any missing core tables once, then clear rows for deterministic test setup. + await conn.run_sync(Base.metadata.create_all) + await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_TABLE) + await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_FTS) + await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_METADATA) + await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_PERMALINK) + await conn.execute(CREATE_POSTGRES_SEARCH_VECTOR_CHUNKS_TABLE) + await conn.execute(CREATE_POSTGRES_SEARCH_VECTOR_CHUNKS_INDEX) + + for table_name in POSTGRES_EPHEMERAL_TABLES: + await conn.execute(text(f"DROP TABLE IF EXISTS {table_name} CASCADE")) + + await conn.execute( + text( + f"TRUNCATE TABLE {', '.join(_postgres_reset_tables())} " + "RESTART IDENTITY CASCADE" + ) + ) + + alembic_version_exists = ( + await conn.execute(text("SELECT to_regclass('public.alembic_version')")) + ).scalar() is not None + + if not alembic_version_exists: + command.stamp(_postgres_alembic_config(async_url), "head") + + @pytest.fixture def anyio_backend(): return "asyncio" @@ -114,13 +217,15 @@ def config_home(tmp_path, monkeypatch) -> Path: @pytest.fixture(scope="function") def app_config(config_home, db_backend, postgres_container, monkeypatch) -> BasicMemoryConfig: """Create test app configuration for the appropriate backend.""" - projects = {"test-project": str(config_home)} + projects = {"test-project": ProjectEntry(path=str(config_home))} # Set backend based on parameterized db_backend fixture if db_backend == "postgres": backend = DatabaseBackend.POSTGRES - # Get URL from testcontainer and convert to asyncpg driver - sync_url = postgres_container.get_connection_url() + # Trigger: CI jobs can provide a shared Postgres service instead of per-session containers. + # Why: reusing one pgvector-enabled server avoids Docker startup churn on every job. + # Outcome: local runs keep using testcontainers, while CI injects a stable service URL. + sync_url = _resolve_postgres_sync_url(postgres_container) database_url = sync_url.replace("postgresql+psycopg2", "postgresql+asyncpg") else: backend = DatabaseBackend.SQLITE @@ -206,7 +311,7 @@ async def engine_factory( if db_backend == "postgres": # Postgres mode using testcontainers # Get async connection URL (asyncpg driver - same as production) - sync_url = postgres_container.get_connection_url() + sync_url = _resolve_postgres_sync_url(postgres_container) async_url = sync_url.replace("postgresql+psycopg2", "postgresql+asyncpg") engine = create_async_engine( @@ -229,46 +334,7 @@ async def engine_factory( db._engine = engine db._session_maker = session_maker - from basic_memory.models.search import ( - CREATE_POSTGRES_SEARCH_INDEX_TABLE, - CREATE_POSTGRES_SEARCH_INDEX_FTS, - CREATE_POSTGRES_SEARCH_INDEX_METADATA, - CREATE_POSTGRES_SEARCH_INDEX_PERMALINK, - CREATE_POSTGRES_SEARCH_VECTOR_CHUNKS_TABLE, - CREATE_POSTGRES_SEARCH_VECTOR_CHUNKS_INDEX, - ) - - # Drop and recreate all tables for test isolation - async with engine.begin() as conn: - # Must drop search_index first (has FK to project, blocks drop_all) - await conn.execute(text("DROP TABLE IF EXISTS search_index CASCADE")) - await conn.run_sync(Base.metadata.drop_all) - await conn.run_sync(Base.metadata.create_all) - # Create search_index via DDL (not ORM - uses composite PK + tsvector) - # asyncpg requires separate execute calls for each statement - await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_TABLE) - await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_FTS) - await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_METADATA) - await conn.execute(CREATE_POSTGRES_SEARCH_INDEX_PERMALINK) - await conn.execute(CREATE_POSTGRES_SEARCH_VECTOR_CHUNKS_TABLE) - await conn.execute(CREATE_POSTGRES_SEARCH_VECTOR_CHUNKS_INDEX) - - # Mark migrations as already applied for this test-created schema. - # - # Some codepaths (e.g. ensure_initialization()) invoke Alembic migrations. - # If we create tables via ORM directly, alembic_version is missing and migrations - # will try to create tables again, causing DuplicateTableError. - alembic_dir = Path(db.__file__).parent / "alembic" - cfg = Config() - cfg.set_main_option("script_location", str(alembic_dir)) - cfg.set_main_option( - "file_template", - "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s", - ) - cfg.set_main_option("timezone", "UTC") - cfg.set_main_option("revision_environment", "false") - cfg.set_main_option("sqlalchemy.url", async_url) - command.stamp(cfg, "head") + await _reset_postgres_test_schema(engine, async_url) yield engine, session_maker diff --git a/tests/mcp/test_async_client_modes.py b/tests/mcp/test_async_client_modes.py index f65c4f2e..5a0d65f9 100644 --- a/tests/mcp/test_async_client_modes.py +++ b/tests/mcp/test_async_client_modes.py @@ -79,6 +79,35 @@ async def test_get_client_cloud_adds_workspace_header(config_manager): assert client.headers.get("X-Workspace-ID") == "tenant-123" +@pytest.mark.asyncio +async def test_get_client_cloud_uses_project_workspace_when_not_explicit(config_manager): + cfg = config_manager.load_config() + cfg.cloud_host = "https://cloud.example.test" + cfg.cloud_api_key = "bmc_test_key_123" + cfg.default_workspace = "default-tenant" + cfg.set_project_mode("research", ProjectMode.CLOUD) + cfg.projects["research"].workspace_id = "project-tenant" + config_manager.save_config(cfg) + + async with get_client(project_name="research") as client: + assert str(client.base_url).rstrip("/") == "https://cloud.example.test/proxy" + assert client.headers.get("X-Workspace-ID") == "project-tenant" + + +@pytest.mark.asyncio +async def test_get_client_cloud_uses_default_workspace_when_project_has_none(config_manager): + cfg = config_manager.load_config() + cfg.cloud_host = "https://cloud.example.test" + cfg.cloud_api_key = "bmc_test_key_123" + cfg.default_workspace = "default-tenant" + cfg.set_project_mode("research", ProjectMode.CLOUD) + config_manager.save_config(cfg) + + async with get_client(project_name="research") as client: + assert str(client.base_url).rstrip("/") == "https://cloud.example.test/proxy" + assert client.headers.get("X-Workspace-ID") == "default-tenant" + + @pytest.mark.asyncio async def test_get_client_explicit_cloud_raises_without_credentials(config_manager, monkeypatch): cfg = config_manager.load_config() @@ -253,6 +282,19 @@ async def test_get_cloud_control_plane_client_uses_api_key_when_available(config assert client.headers.get("Authorization") == "Bearer bmc_test_key_123" +@pytest.mark.asyncio +async def test_get_cloud_control_plane_client_adds_workspace_header(config_manager): + cfg = config_manager.load_config() + cfg.cloud_host = "https://cloud.example.test" + cfg.cloud_api_key = "bmc_test_key_123" + config_manager.save_config(cfg) + + async with get_cloud_control_plane_client(workspace="tenant-123") as client: + assert str(client.base_url).rstrip("/") == "https://cloud.example.test" + assert client.headers.get("Authorization") == "Bearer bmc_test_key_123" + assert client.headers.get("X-Workspace-ID") == "tenant-123" + + @pytest.mark.asyncio async def test_get_cloud_control_plane_client_uses_oauth_token(config_manager): cfg = config_manager.load_config()