From f13742abf88850c2e5aec5d627d84443bc2ce6d3 Mon Sep 17 00:00:00 2001 From: phernandez Date: Sat, 4 Apr 2026 21:19:24 -0500 Subject: [PATCH 1/2] feat(core): add note_content tenant schema primitive - add the shared note_content model, migration, and repository - cover CRUD, migration, and Postgres-backed repository behavior - fix the Alembic async fallback coroutine warning Signed-off-by: phernandez --- src/basic_memory/alembic/env.py | 76 ++++-- .../l5g6h7i8j9k0_add_note_content_table.py | 65 +++++ src/basic_memory/models/__init__.py | 3 +- src/basic_memory/models/knowledge.py | 76 ++++++ src/basic_memory/repository/__init__.py | 2 + .../repository/note_content_repository.py | 164 +++++++++++++ .../test_note_content_repository.py | 223 ++++++++++++++++++ tests/test_alembic_env.py | 112 +++++++++ tests/test_note_content_migration.py | 75 ++++++ 9 files changed, 771 insertions(+), 25 deletions(-) create mode 100644 src/basic_memory/alembic/versions/l5g6h7i8j9k0_add_note_content_table.py create mode 100644 src/basic_memory/repository/note_content_repository.py create mode 100644 tests/repository/test_note_content_repository.py create mode 100644 tests/test_alembic_env.py create mode 100644 tests/test_note_content_migration.py diff --git a/src/basic_memory/alembic/env.py b/src/basic_memory/alembic/env.py index 94b15bdfb..01ecc4eaf 100644 --- a/src/basic_memory/alembic/env.py +++ b/src/basic_memory/alembic/env.py @@ -118,6 +118,54 @@ async def run_async_migrations(connectable): await connectable.dispose() +def _run_async_migrations_with_asyncio_run(connectable) -> None: + """Run async migrations with asyncio.run while closing failed coroutines. + + Trigger: asyncio.run() may reject execution when another event loop is already active. + Why: Python raises before awaiting the coroutine, which otherwise leaks a + RuntimeWarning about an un-awaited coroutine. + Outcome: close the pending coroutine before bubbling the RuntimeError to the + fallback path. + """ + migration_coro = run_async_migrations(connectable) + try: + asyncio.run(migration_coro) + except RuntimeError: + migration_coro.close() + raise + + +def _run_async_migrations_in_thread(connectable) -> None: + """Run async migrations in a dedicated thread with its own event loop.""" + import concurrent.futures + + def run_in_thread(): + """Run async migrations in a new event loop in a separate thread.""" + new_loop = asyncio.new_event_loop() + asyncio.set_event_loop(new_loop) + try: + new_loop.run_until_complete(run_async_migrations(connectable)) + finally: + new_loop.close() + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_in_thread) + future.result() # Wait for completion and re-raise any exceptions + + +def _run_async_engine_migrations(connectable) -> None: + """Run async-engine migrations with a running-loop fallback.""" + try: + _run_async_migrations_with_asyncio_run(connectable) + except RuntimeError as e: + if "cannot be called from a running event loop" in str(e): + # We're in a running event loop (likely uvloop or Python 3.14+ tests). + # Switch to a dedicated thread so Alembic can finish without nesting loops. + _run_async_migrations_in_thread(connectable) + else: + raise + + def run_migrations_online() -> None: """Run migrations in 'online' mode. @@ -148,30 +196,10 @@ def run_migrations_online() -> None: # Handle async engines (PostgreSQL with asyncpg) if isinstance(connectable, AsyncEngine): - # Try to run async migrations - # nest_asyncio allows asyncio.run() from within event loops, but doesn't work with uvloop - try: - asyncio.run(run_async_migrations(connectable)) - except RuntimeError as e: - if "cannot be called from a running event loop" in str(e): - # We're in a running event loop (likely uvloop) - need to use a different approach - # Create a new thread to run the async migrations - import concurrent.futures - - def run_in_thread(): - """Run async migrations in a new event loop in a separate thread.""" - new_loop = asyncio.new_event_loop() - asyncio.set_event_loop(new_loop) - try: - new_loop.run_until_complete(run_async_migrations(connectable)) - finally: - new_loop.close() - - with concurrent.futures.ThreadPoolExecutor() as executor: - future = executor.submit(run_in_thread) - future.result() # Wait for completion and re-raise any exceptions - else: - raise + # Trigger: async engines need Alembic work to cross the sync/async boundary. + # Why: most callers can use asyncio.run(), but running-loop contexts need a thread fallback. + # Outcome: migrations complete without leaking un-awaited coroutines. + _run_async_engine_migrations(connectable) else: # Handle sync engines (SQLite) or sync connections if hasattr(connectable, "connect"): diff --git a/src/basic_memory/alembic/versions/l5g6h7i8j9k0_add_note_content_table.py b/src/basic_memory/alembic/versions/l5g6h7i8j9k0_add_note_content_table.py new file mode 100644 index 000000000..bd5a8642e --- /dev/null +++ b/src/basic_memory/alembic/versions/l5g6h7i8j9k0_add_note_content_table.py @@ -0,0 +1,65 @@ +"""Add note_content table + +Revision ID: l5g6h7i8j9k0 +Revises: k4e5f6g7h8i9 +Create Date: 2026-04-04 12:00:00.000000 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + + +# revision identifiers, used by Alembic. +revision: str = "l5g6h7i8j9k0" +down_revision: Union[str, None] = "k4e5f6g7h8i9" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + """Create note_content for materialized note content and sync state.""" + op.create_table( + "note_content", + sa.Column("entity_id", sa.Integer(), nullable=False), + sa.Column("project_id", sa.Integer(), nullable=False), + sa.Column("external_id", sa.String(), nullable=False), + sa.Column("file_path", sa.String(), nullable=False), + sa.Column("markdown_content", sa.Text(), nullable=False), + sa.Column("db_version", sa.BigInteger(), nullable=False), + sa.Column("db_checksum", sa.String(), nullable=False), + sa.Column("file_version", sa.BigInteger(), nullable=True), + sa.Column("file_checksum", sa.String(), nullable=True), + sa.Column("file_write_status", sa.String(), nullable=False), + sa.Column("last_source", sa.String(), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("file_updated_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("last_materialization_error", sa.Text(), nullable=True), + sa.Column("last_materialization_attempt_at", sa.DateTime(timezone=True), nullable=True), + sa.CheckConstraint( + "file_write_status IN (" + "'pending', " + "'writing', " + "'synced', " + "'failed', " + "'external_change_detected'" + ")", + name="ck_note_content_file_write_status", + ), + sa.ForeignKeyConstraint(["entity_id"], ["entity.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"), + sa.PrimaryKeyConstraint("entity_id"), + ) + op.create_index("ix_note_content_project_id", "note_content", ["project_id"], unique=False) + op.create_index("ix_note_content_file_path", "note_content", ["file_path"], unique=False) + op.create_index("ix_note_content_external_id", "note_content", ["external_id"], unique=True) + + +def downgrade() -> None: + """Drop note_content and its supporting indexes.""" + op.drop_index("ix_note_content_external_id", table_name="note_content") + op.drop_index("ix_note_content_file_path", table_name="note_content") + op.drop_index("ix_note_content_project_id", table_name="note_content") + op.drop_table("note_content") diff --git a/src/basic_memory/models/__init__.py b/src/basic_memory/models/__init__.py index acdc03b18..7c80476df 100644 --- a/src/basic_memory/models/__init__.py +++ b/src/basic_memory/models/__init__.py @@ -2,12 +2,13 @@ import basic_memory from basic_memory.models.base import Base -from basic_memory.models.knowledge import Entity, Observation, Relation +from basic_memory.models.knowledge import Entity, NoteContent, Observation, Relation from basic_memory.models.project import Project __all__ = [ "Base", "Entity", + "NoteContent", "Observation", "Relation", "Project", diff --git a/src/basic_memory/models/knowledge.py b/src/basic_memory/models/knowledge.py index 05054a2dd..0a8eabf3c 100644 --- a/src/basic_memory/models/knowledge.py +++ b/src/basic_memory/models/knowledge.py @@ -6,6 +6,8 @@ from typing import Optional from sqlalchemy import ( + BigInteger, + CheckConstraint, Integer, String, Text, @@ -116,6 +118,12 @@ class Entity(Base): foreign_keys="[Relation.to_id]", cascade="all, delete-orphan", ) + note_content = relationship( + "NoteContent", + back_populates="entity", + cascade="all, delete-orphan", + uselist=False, + ) @property def relations(self): @@ -141,6 +149,74 @@ def __repr__(self) -> str: return f"Entity(id={self.id}, external_id='{self.external_id}', name='{self.title}', type='{self.note_type}', checksum='{self.checksum}')" +class NoteContent(Base): + """Materialized markdown content and sync state for a note entity.""" + + __tablename__ = "note_content" + __table_args__ = ( + CheckConstraint( + "file_write_status IN (" + "'pending', " + "'writing', " + "'synced', " + "'failed', " + "'external_change_detected'" + ")", + name="ck_note_content_file_write_status", + ), + Index("ix_note_content_project_id", "project_id"), + Index("ix_note_content_file_path", "file_path"), + Index("ix_note_content_external_id", "external_id", unique=True), + ) + + # Core identity mirrored from entity for hot note reads + entity_id: Mapped[int] = mapped_column( + Integer, + ForeignKey("entity.id", ondelete="CASCADE"), + primary_key=True, + ) + project_id: Mapped[int] = mapped_column( + Integer, + ForeignKey("project.id", ondelete="CASCADE"), + nullable=False, + ) + external_id: Mapped[str] = mapped_column(String, nullable=False) + file_path: Mapped[str] = mapped_column(String, nullable=False) + + # Materialized content version tracked in the tenant database + markdown_content: Mapped[str] = mapped_column(Text, nullable=False) + db_version: Mapped[int] = mapped_column(BigInteger, nullable=False) + db_checksum: Mapped[str] = mapped_column(String, nullable=False) + + # File materialization state tracked against the latest write attempts + file_version: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True) + file_checksum: Mapped[Optional[str]] = mapped_column(String, nullable=True) + file_write_status: Mapped[str] = mapped_column(String, nullable=False, default="pending") + last_source: Mapped[Optional[str]] = mapped_column(String, nullable=True) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), + default=lambda: datetime.now().astimezone(), + onupdate=lambda: datetime.now().astimezone(), + ) + file_updated_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), + nullable=True, + ) + last_materialization_error: Mapped[Optional[str]] = mapped_column(Text, nullable=True) + last_materialization_attempt_at: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), + nullable=True, + ) + + entity = relationship("Entity", back_populates="note_content") + + def __repr__(self) -> str: # pragma: no cover + return ( + f"NoteContent(entity_id={self.entity_id}, external_id='{self.external_id}', " + f"file_path='{self.file_path}', file_write_status='{self.file_write_status}')" + ) + + class Observation(Base): """An observation about an entity. diff --git a/src/basic_memory/repository/__init__.py b/src/basic_memory/repository/__init__.py index 37df07668..579313ce9 100644 --- a/src/basic_memory/repository/__init__.py +++ b/src/basic_memory/repository/__init__.py @@ -1,10 +1,12 @@ from .entity_repository import EntityRepository +from .note_content_repository import NoteContentRepository from .observation_repository import ObservationRepository from .project_repository import ProjectRepository from .relation_repository import RelationRepository __all__ = [ "EntityRepository", + "NoteContentRepository", "ObservationRepository", "ProjectRepository", "RelationRepository", diff --git a/src/basic_memory/repository/note_content_repository.py b/src/basic_memory/repository/note_content_repository.py new file mode 100644 index 000000000..441ba6717 --- /dev/null +++ b/src/basic_memory/repository/note_content_repository.py @@ -0,0 +1,164 @@ +"""Repository for managing note materialization state.""" + +from pathlib import Path +from typing import Any, Mapping, Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker + +from basic_memory import db +from basic_memory.models import Entity, NoteContent +from basic_memory.repository.repository import Repository + +NOTE_CONTENT_MUTABLE_FIELDS = frozenset( + { + "markdown_content", + "db_version", + "db_checksum", + "file_version", + "file_checksum", + "file_write_status", + "last_source", + "updated_at", + "file_updated_at", + "last_materialization_error", + "last_materialization_attempt_at", + } +) + + +class NoteContentRepository(Repository[NoteContent]): + """Repository for project-scoped note materialization state.""" + + def __init__(self, session_maker: async_sessionmaker[AsyncSession], project_id: int): + """Initialize with session maker and project-scoped filtering.""" + super().__init__(session_maker, NoteContent, project_id=project_id) + + def _coerce_note_content(self, data: Mapping[str, Any] | NoteContent) -> NoteContent: + """Convert input data to a NoteContent model while preserving nullable fields.""" + if isinstance(data, NoteContent): + return data + + model_data = {key: value for key, value in data.items() if key in self.valid_columns} + entity_id = model_data.get("entity_id") + if entity_id is None: + raise ValueError("entity_id is required for note_content writes") + + return NoteContent(**model_data) + + async def _load_entity_identity(self, session: AsyncSession, entity_id: int) -> Entity: + """Load the owning entity so duplicated identity fields stay aligned.""" + result = await session.execute(select(Entity).where(Entity.id == entity_id)) + entity = result.scalar_one_or_none() + if entity is None: + raise ValueError(f"Entity {entity_id} does not exist") + + if self.project_id is not None and entity.project_id != self.project_id: + raise ValueError( + f"Entity {entity_id} belongs to project {entity.project_id}, " + f"not repository project {self.project_id}" + ) + + return entity + + async def _align_identity_fields( + self, session: AsyncSession, note_content: NoteContent + ) -> None: + """Mirror project identity from entity before persisting note content.""" + entity = await self._load_entity_identity(session, note_content.entity_id) + note_content.project_id = entity.project_id + note_content.external_id = entity.external_id + note_content.file_path = Path(entity.file_path).as_posix() + + async def get_by_entity_id(self, entity_id: int) -> Optional[NoteContent]: + """Get note content by the owning entity identifier.""" + return await self.find_by_id(entity_id) + + async def get_by_external_id(self, external_id: str) -> Optional[NoteContent]: + """Get note content by the mirrored entity external identifier.""" + query = self.select().where(NoteContent.external_id == external_id) + return await self.find_one(query) + + async def get_by_file_path(self, file_path: Path | str) -> Optional[NoteContent]: + """Get note content by the mirrored entity file path.""" + query = self.select().where(NoteContent.file_path == Path(file_path).as_posix()) + return await self.find_one(query) + + async def create(self, data: Mapping[str, Any] | NoteContent) -> NoteContent: + """Create a note_content row aligned to its owning entity.""" + note_content = self._coerce_note_content(data) + + async with db.scoped_session(self.session_maker) as session: + await self._align_identity_fields(session, note_content) + session.add(note_content) + await session.flush() + + created = await self.select_by_id(session, note_content.entity_id) + if created is None: # pragma: no cover + raise ValueError( + f"Can't find NoteContent for entity {note_content.entity_id} after add" + ) + return created + + async def upsert(self, data: Mapping[str, Any] | NoteContent) -> NoteContent: + """Insert or update note_content while keeping mirrored identity fields in sync.""" + note_content = self._coerce_note_content(data) + + async with db.scoped_session(self.session_maker) as session: + await self._align_identity_fields(session, note_content) + existing = await self.select_by_id(session, note_content.entity_id) + + if existing is None: + session.add(note_content) + await session.flush() + created = await self.select_by_id(session, note_content.entity_id) + if created is None: # pragma: no cover + raise ValueError( + f"Can't find NoteContent for entity {note_content.entity_id} after upsert" + ) + return created + + for column_name in NoteContent.__table__.columns.keys(): + if column_name == "entity_id": + continue + setattr(existing, column_name, getattr(note_content, column_name)) + + await session.flush() + updated = await self.select_by_id(session, existing.entity_id) + if updated is None: # pragma: no cover + raise ValueError( + f"Can't find NoteContent for entity {existing.entity_id} after upsert" + ) + return updated + + async def update_state_fields(self, entity_id: int, **updates: Any) -> Optional[NoteContent]: + """Update mutable sync and materialization fields for a note_content row.""" + invalid_fields = set(updates) - NOTE_CONTENT_MUTABLE_FIELDS + if invalid_fields: + invalid_list = ", ".join(sorted(invalid_fields)) + raise ValueError(f"Unsupported note_content update fields: {invalid_list}") + + async with db.scoped_session(self.session_maker) as session: + note_content = await self.select_by_id(session, entity_id) + if note_content is None: + return None + + await self._align_identity_fields(session, note_content) + for field_name, value in updates.items(): + setattr(note_content, field_name, value) + + await session.flush() + updated = await self.select_by_id(session, entity_id) + if updated is None: # pragma: no cover + raise ValueError(f"Can't find NoteContent for entity {entity_id} after update") + return updated + + async def delete_by_entity_id(self, entity_id: int) -> bool: + """Delete note_content by entity identifier.""" + async with db.scoped_session(self.session_maker) as session: + note_content = await self.select_by_id(session, entity_id) + if note_content is None: + return False + + await session.delete(note_content) + return True diff --git a/tests/repository/test_note_content_repository.py b/tests/repository/test_note_content_repository.py new file mode 100644 index 000000000..860c71a6c --- /dev/null +++ b/tests/repository/test_note_content_repository.py @@ -0,0 +1,223 @@ +"""Tests for the NoteContentRepository.""" + +from datetime import datetime, timezone + +import pytest + +from basic_memory.models import NoteContent, Project +from basic_memory.repository.entity_repository import EntityRepository +from basic_memory.repository.note_content_repository import NoteContentRepository +from basic_memory.repository.project_repository import ProjectRepository + + +def build_note_content_payload(entity_id: int) -> dict: + """Build a minimal payload for note_content writes.""" + return { + "entity_id": entity_id, + "project_id": -1, + "external_id": "stale-external-id", + "file_path": "stale/path.md", + "markdown_content": "# Materialized content", + "db_version": 1, + "db_checksum": "db-checksum-1", + "file_version": None, + "file_checksum": None, + "file_write_status": "pending", + "last_source": "api", + "updated_at": datetime.now(timezone.utc), + "file_updated_at": None, + "last_materialization_error": None, + "last_materialization_attempt_at": None, + } + + +@pytest.mark.asyncio +async def test_create_and_lookup_note_content( + session_maker, + test_project: Project, + sample_entity, +): + """Create note_content and read it back through each supported lookup.""" + repository = NoteContentRepository(session_maker, project_id=test_project.id) + + created = await repository.create(build_note_content_payload(sample_entity.id)) + + assert created.entity_id == sample_entity.id + assert created.project_id == sample_entity.project_id + assert created.external_id == sample_entity.external_id + assert created.file_path == sample_entity.file_path + + by_entity = await repository.get_by_entity_id(sample_entity.id) + by_external = await repository.get_by_external_id(sample_entity.external_id) + by_path = await repository.get_by_file_path(sample_entity.file_path) + + assert by_entity is not None + assert by_external is not None + assert by_path is not None + assert by_entity.entity_id == created.entity_id + assert by_external.entity_id == created.entity_id + assert by_path.entity_id == created.entity_id + + +@pytest.mark.asyncio +async def test_upsert_updates_existing_note_content( + session_maker, + test_project: Project, + sample_entity, +): + """Upsert should update the existing row instead of inserting a duplicate.""" + repository = NoteContentRepository(session_maker, project_id=test_project.id) + await repository.create(build_note_content_payload(sample_entity.id)) + + updated_at = datetime.now(timezone.utc) + updated = await repository.upsert( + NoteContent( + entity_id=sample_entity.id, + project_id=test_project.id, + external_id=sample_entity.external_id, + file_path=sample_entity.file_path, + markdown_content="# Updated materialized content", + db_version=2, + db_checksum="db-checksum-2", + file_version=7, + file_checksum="file-checksum-7", + file_write_status="synced", + last_source="reconciler", + updated_at=updated_at, + file_updated_at=updated_at, + last_materialization_error="transient failure", + last_materialization_attempt_at=updated_at, + ) + ) + + assert updated.entity_id == sample_entity.id + assert updated.markdown_content == "# Updated materialized content" + assert updated.db_version == 2 + assert updated.db_checksum == "db-checksum-2" + assert updated.file_version == 7 + assert updated.file_checksum == "file-checksum-7" + assert updated.file_write_status == "synced" + assert updated.last_source == "reconciler" + assert updated.last_materialization_error == "transient failure" + + +@pytest.mark.asyncio +async def test_update_state_fields_realigns_identity_with_entity( + session_maker, + test_project: Project, + sample_entity, + entity_repository: EntityRepository, +): + """Sync-field updates should refresh mirrored identity from the owning entity.""" + repository = NoteContentRepository(session_maker, project_id=test_project.id) + await repository.create(build_note_content_payload(sample_entity.id)) + + renamed_path = "renamed/test_entity.md" + await entity_repository.update(sample_entity.id, {"file_path": renamed_path}) + + updated = await repository.update_state_fields( + sample_entity.id, + file_write_status="failed", + file_version=3, + file_checksum="file-checksum-3", + last_materialization_error=None, + last_materialization_attempt_at=None, + ) + + assert updated is not None + assert updated.file_path == renamed_path + assert updated.external_id == sample_entity.external_id + assert updated.file_write_status == "failed" + assert updated.file_version == 3 + assert updated.file_checksum == "file-checksum-3" + assert updated.last_materialization_error is None + assert updated.last_materialization_attempt_at is None + + +@pytest.mark.asyncio +async def test_delete_by_entity_id(session_maker, test_project: Project, sample_entity): + """Delete note_content directly by entity identifier.""" + repository = NoteContentRepository(session_maker, project_id=test_project.id) + await repository.create(build_note_content_payload(sample_entity.id)) + + deleted = await repository.delete_by_entity_id(sample_entity.id) + + assert deleted is True + assert await repository.get_by_entity_id(sample_entity.id) is None + + +@pytest.mark.asyncio +async def test_note_content_cascades_when_entity_is_deleted( + session_maker, + test_project: Project, + sample_entity, + entity_repository: EntityRepository, +): + """Deleting the owning entity should cascade to note_content.""" + repository = NoteContentRepository(session_maker, project_id=test_project.id) + await repository.create(build_note_content_payload(sample_entity.id)) + + deleted = await entity_repository.delete(sample_entity.id) + + assert deleted is True + assert await repository.get_by_entity_id(sample_entity.id) is None + + +@pytest.mark.asyncio +async def test_note_content_file_path_lookup_is_project_scoped(session_maker, config_home): + """Lookups by file_path should respect the repository project scope.""" + project_repository = ProjectRepository(session_maker) + project_one = await project_repository.create( + { + "name": "project-one", + "path": str(config_home / "project-one"), + "is_active": True, + } + ) + project_two = await project_repository.create( + { + "name": "project-two", + "path": str(config_home / "project-two"), + "is_active": True, + } + ) + + entity_one_repo = EntityRepository(session_maker, project_id=project_one.id) + entity_two_repo = EntityRepository(session_maker, project_id=project_two.id) + + shared_file_path = "shared/note.md" + entity_one = await entity_one_repo.create( + { + "title": "Shared Note", + "note_type": "test", + "permalink": "project-one/shared-note", + "file_path": shared_file_path, + "content_type": "text/markdown", + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + } + ) + entity_two = await entity_two_repo.create( + { + "title": "Shared Note", + "note_type": "test", + "permalink": "project-two/shared-note", + "file_path": shared_file_path, + "content_type": "text/markdown", + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + } + ) + + repository_one = NoteContentRepository(session_maker, project_id=project_one.id) + repository_two = NoteContentRepository(session_maker, project_id=project_two.id) + await repository_one.create(build_note_content_payload(entity_one.id)) + await repository_two.create(build_note_content_payload(entity_two.id)) + + found_one = await repository_one.get_by_file_path(shared_file_path) + found_two = await repository_two.get_by_file_path(shared_file_path) + + assert found_one is not None + assert found_two is not None + assert found_one.entity_id == entity_one.id + assert found_two.entity_id == entity_two.id diff --git a/tests/test_alembic_env.py b/tests/test_alembic_env.py new file mode 100644 index 000000000..2dd3a6ae1 --- /dev/null +++ b/tests/test_alembic_env.py @@ -0,0 +1,112 @@ +"""Regression tests for Alembic env async migration helpers.""" + +import importlib.util +import uuid +from contextlib import nullcontext +from pathlib import Path + +import pytest + + +class FakeAlembicConfig: + """Minimal config object used while importing env.py under test.""" + + def __init__(self): + self.options = {"sqlalchemy.url": "sqlite:///:memory:"} + self.attributes = {} + self.config_file_name = None + self.config_ini_section = "alembic" + + def get_main_option(self, name: str) -> str | None: + return self.options.get(name) + + def set_main_option(self, name: str, value: str) -> None: + self.options[name] = value + + def get_section(self, name: str, default=None): + return default or {} + + +class FakeCoroutine: + """Track whether the migration coroutine gets closed on failure.""" + + def __init__(self): + self.closed = False + + def close(self) -> None: + self.closed = True + + +def load_alembic_env_module(monkeypatch, tmp_path): + """Import env.py with a fake Alembic context and isolated HOME.""" + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setenv("BASIC_MEMORY_HOME", str(tmp_path / "basic-memory")) + + from alembic import context as alembic_context + + fake_config = FakeAlembicConfig() + monkeypatch.setattr(alembic_context, "config", fake_config, raising=False) + monkeypatch.setattr(alembic_context, "configure", lambda *args, **kwargs: None, raising=False) + monkeypatch.setattr(alembic_context, "begin_transaction", lambda: nullcontext(), raising=False) + monkeypatch.setattr(alembic_context, "run_migrations", lambda: None, raising=False) + monkeypatch.setattr(alembic_context, "is_offline_mode", lambda: True, raising=False) + + env_path = Path(__file__).resolve().parents[1] / "src/basic_memory/alembic/env.py" + module_name = f"test_alembic_env_{uuid.uuid4().hex}" + spec = importlib.util.spec_from_file_location(module_name, env_path) + assert spec is not None + assert spec.loader is not None + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +def test_asyncio_run_failure_closes_migration_coroutine(monkeypatch, tmp_path): + """The running-loop fallback should not leak an un-awaited coroutine.""" + env_module = load_alembic_env_module(monkeypatch, tmp_path) + fake_coro = FakeCoroutine() + + monkeypatch.setattr(env_module, "run_async_migrations", lambda connectable: fake_coro) + + def raising_asyncio_run(coro): + raise RuntimeError("asyncio.run() cannot be called from a running event loop") + + monkeypatch.setattr(env_module.asyncio, "run", raising_asyncio_run) + + with pytest.raises(RuntimeError, match="running event loop"): + env_module._run_async_migrations_with_asyncio_run(object()) + + assert fake_coro.closed is True + + +def test_running_loop_error_uses_thread_fallback(monkeypatch, tmp_path): + """Async-engine helper should switch to the thread fallback for running-loop errors.""" + env_module = load_alembic_env_module(monkeypatch, tmp_path) + connectable = object() + fallback_calls: list[object] = [] + + def raising_run(connectable): + raise RuntimeError("asyncio.run() cannot be called from a running event loop") + + def record_fallback(target): + fallback_calls.append(target) + + monkeypatch.setattr(env_module, "_run_async_migrations_with_asyncio_run", raising_run) + monkeypatch.setattr(env_module, "_run_async_migrations_in_thread", record_fallback) + + env_module._run_async_engine_migrations(connectable) + + assert fallback_calls == [connectable] + + +def test_non_loop_runtime_error_is_re_raised(monkeypatch, tmp_path): + """Unexpected RuntimeError values should not be swallowed by the fallback path.""" + env_module = load_alembic_env_module(monkeypatch, tmp_path) + + def raising_run(connectable): + raise RuntimeError("different runtime failure") + + monkeypatch.setattr(env_module, "_run_async_migrations_with_asyncio_run", raising_run) + + with pytest.raises(RuntimeError, match="different runtime failure"): + env_module._run_async_engine_migrations(object()) diff --git a/tests/test_note_content_migration.py b/tests/test_note_content_migration.py new file mode 100644 index 000000000..b8f31f02a --- /dev/null +++ b/tests/test_note_content_migration.py @@ -0,0 +1,75 @@ +"""Migration tests for note_content schema.""" + +import sqlite3 +from pathlib import Path + +from alembic import command +from alembic.config import Config + +from basic_memory import db + + +def sqlite_alembic_config(database_path: Path) -> Config: + """Build an Alembic config that upgrades a temporary SQLite database.""" + alembic_dir = Path(db.__file__).parent / "alembic" + config = Config() + config.set_main_option("script_location", str(alembic_dir)) + config.set_main_option( + "file_template", + "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s", + ) + config.set_main_option("timezone", "UTC") + config.set_main_option("revision_environment", "false") + config.set_main_option("sqlalchemy.url", f"sqlite:///{database_path}") + return config + + +def test_alembic_upgrade_creates_note_content_table(tmp_path, monkeypatch): + """Running Alembic head should create note_content with its expected contract.""" + monkeypatch.setenv("HOME", str(tmp_path)) + monkeypatch.setenv("BASIC_MEMORY_HOME", str(tmp_path / "basic-memory")) + + database_path = tmp_path / "note-content-migration.db" + command.upgrade(sqlite_alembic_config(database_path), "head") + + connection = sqlite3.connect(database_path) + try: + columns = { + row[1] for row in connection.execute("PRAGMA table_info(note_content)").fetchall() + } + assert columns == { + "entity_id", + "project_id", + "external_id", + "file_path", + "markdown_content", + "db_version", + "db_checksum", + "file_version", + "file_checksum", + "file_write_status", + "last_source", + "updated_at", + "file_updated_at", + "last_materialization_error", + "last_materialization_attempt_at", + } + + foreign_keys = connection.execute("PRAGMA foreign_key_list(note_content)").fetchall() + entity_fk = next(row for row in foreign_keys if row[3] == "entity_id") + project_fk = next(row for row in foreign_keys if row[3] == "project_id") + assert entity_fk[2] == "entity" + assert entity_fk[4] == "id" + assert entity_fk[6].upper() == "CASCADE" + assert project_fk[2] == "project" + assert project_fk[4] == "id" + assert project_fk[6].upper() == "CASCADE" + + indexes = { + row[1] for row in connection.execute("PRAGMA index_list(note_content)").fetchall() + } + assert "ix_note_content_project_id" in indexes + assert "ix_note_content_file_path" in indexes + assert "ix_note_content_external_id" in indexes + finally: + connection.close() From 5a24c4682f31eeda5f87098dbc373b297dad316d Mon Sep 17 00:00:00 2001 From: phernandez Date: Sat, 4 Apr 2026 21:32:35 -0500 Subject: [PATCH 2/2] fix(core): address note_content review feedback Signed-off-by: phernandez --- src/basic_memory/alembic/env.py | 2 +- .../repository/note_content_repository.py | 55 +++-- .../test_note_content_repository.py | 204 +++++++++++++++++- 3 files changed, 245 insertions(+), 16 deletions(-) diff --git a/src/basic_memory/alembic/env.py b/src/basic_memory/alembic/env.py index 01ecc4eaf..700ecef50 100644 --- a/src/basic_memory/alembic/env.py +++ b/src/basic_memory/alembic/env.py @@ -66,7 +66,7 @@ # Add this function to tell Alembic what to include/exclude -def include_object(object, name, type_, reflected, compare_to): +def include_object(obj, name, type_, reflected, compare_to): # Ignore SQLite FTS tables if type_ == "table" and name.startswith("search_index"): return False diff --git a/src/basic_memory/repository/note_content_repository.py b/src/basic_memory/repository/note_content_repository.py index 441ba6717..80b34903b 100644 --- a/src/basic_memory/repository/note_content_repository.py +++ b/src/basic_memory/repository/note_content_repository.py @@ -34,17 +34,22 @@ def __init__(self, session_maker: async_sessionmaker[AsyncSession], project_id: """Initialize with session maker and project-scoped filtering.""" super().__init__(session_maker, NoteContent, project_id=project_id) - def _coerce_note_content(self, data: Mapping[str, Any] | NoteContent) -> NoteContent: - """Convert input data to a NoteContent model while preserving nullable fields.""" + def _coerce_note_content( + self, data: Mapping[str, Any] | NoteContent + ) -> tuple[NoteContent, set[str]]: + """Convert input data to a NoteContent model and track explicit fields.""" if isinstance(data, NoteContent): - return data + model_data = { + key: value for key, value in data.__dict__.items() if key in self.valid_columns + } + else: + model_data = {key: value for key, value in data.items() if key in self.valid_columns} - model_data = {key: value for key, value in data.items() if key in self.valid_columns} entity_id = model_data.get("entity_id") if entity_id is None: raise ValueError("entity_id is required for note_content writes") - return NoteContent(**model_data) + return NoteContent(**model_data), set(model_data) async def _load_entity_identity(self, session: AsyncSession, entity_id: int) -> Entity: """Load the owning entity so duplicated identity fields stay aligned.""" @@ -80,13 +85,32 @@ async def get_by_external_id(self, external_id: str) -> Optional[NoteContent]: return await self.find_one(query) async def get_by_file_path(self, file_path: Path | str) -> Optional[NoteContent]: - """Get note content by the mirrored entity file path.""" - query = self.select().where(NoteContent.file_path == Path(file_path).as_posix()) - return await self.find_one(query) + """Get note content by file path, preferring rows whose entity still owns that path.""" + normalized_path = Path(file_path).as_posix() + + # Trigger: note_content mirrors entity.file_path but does not enforce project-level uniqueness. + # Why: entity renames can leave stale mirrored paths behind until note_content realigns. + # Outcome: prefer the row whose current entity path still matches, then the newest mirror. + query = ( + self.select() + .join(Entity, Entity.id == NoteContent.entity_id) + .where(NoteContent.file_path == normalized_path) + .order_by( + (Entity.file_path == normalized_path).desc(), + NoteContent.updated_at.desc(), + NoteContent.entity_id.desc(), + ) + .limit(1) + .options(*self.get_load_options()) + ) + + async with db.scoped_session(self.session_maker) as session: + result = await session.execute(query) + return result.scalars().first() async def create(self, data: Mapping[str, Any] | NoteContent) -> NoteContent: """Create a note_content row aligned to its owning entity.""" - note_content = self._coerce_note_content(data) + note_content, _ = self._coerce_note_content(data) async with db.scoped_session(self.session_maker) as session: await self._align_identity_fields(session, note_content) @@ -102,7 +126,7 @@ async def create(self, data: Mapping[str, Any] | NoteContent) -> NoteContent: async def upsert(self, data: Mapping[str, Any] | NoteContent) -> NoteContent: """Insert or update note_content while keeping mirrored identity fields in sync.""" - note_content = self._coerce_note_content(data) + note_content, provided_fields = self._coerce_note_content(data) async with db.scoped_session(self.session_maker) as session: await self._align_identity_fields(session, note_content) @@ -118,9 +142,12 @@ async def upsert(self, data: Mapping[str, Any] | NoteContent) -> NoteContent: ) return created - for column_name in NoteContent.__table__.columns.keys(): - if column_name == "entity_id": - continue + fields_to_update = (provided_fields - {"entity_id"}) | { + "project_id", + "external_id", + "file_path", + } + for column_name in fields_to_update: setattr(existing, column_name, getattr(note_content, column_name)) await session.flush() @@ -132,7 +159,7 @@ async def upsert(self, data: Mapping[str, Any] | NoteContent) -> NoteContent: return updated async def update_state_fields(self, entity_id: int, **updates: Any) -> Optional[NoteContent]: - """Update mutable sync and materialization fields for a note_content row.""" + """Update sync fields and re-align project_id, external_id, and file_path from entity.""" invalid_fields = set(updates) - NOTE_CONTENT_MUTABLE_FIELDS if invalid_fields: invalid_list = ", ".join(sorted(invalid_fields)) diff --git a/tests/repository/test_note_content_repository.py b/tests/repository/test_note_content_repository.py index 860c71a6c..0032a8e0f 100644 --- a/tests/repository/test_note_content_repository.py +++ b/tests/repository/test_note_content_repository.py @@ -1,9 +1,10 @@ """Tests for the NoteContentRepository.""" -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone import pytest +from basic_memory import db from basic_memory.models import NoteContent, Project from basic_memory.repository.entity_repository import EntityRepository from basic_memory.repository.note_content_repository import NoteContentRepository @@ -101,6 +102,114 @@ async def test_upsert_updates_existing_note_content( assert updated.last_materialization_error == "transient failure" +@pytest.mark.asyncio +async def test_upsert_inserts_when_no_existing_row( + session_maker, + test_project: Project, + sample_entity, +): + """Upsert should insert a new row when the entity has no note_content yet.""" + repository = NoteContentRepository(session_maker, project_id=test_project.id) + + created = await repository.upsert(build_note_content_payload(sample_entity.id)) + + assert created.entity_id == sample_entity.id + assert created.project_id == sample_entity.project_id + assert created.external_id == sample_entity.external_id + assert created.file_path == sample_entity.file_path + assert created.db_version == 1 + + +@pytest.mark.asyncio +async def test_create_requires_entity_id(session_maker, test_project: Project): + """Create should fail fast when note_content identity is missing.""" + repository = NoteContentRepository(session_maker, project_id=test_project.id) + + with pytest.raises(ValueError, match="entity_id is required"): + await repository.create({"markdown_content": "# Missing entity"}) + + +@pytest.mark.asyncio +async def test_upsert_preserves_existing_fields_for_partial_payload( + session_maker, + test_project: Project, + sample_entity, +): + """Partial upserts should only change explicit fields and preserve existing state.""" + repository = NoteContentRepository(session_maker, project_id=test_project.id) + payload = build_note_content_payload(sample_entity.id) + payload["last_materialization_error"] = "stale failure" + created = await repository.create(payload) + + updated_at = datetime.now(timezone.utc) + updated = await repository.upsert( + { + "entity_id": sample_entity.id, + "markdown_content": "# Partially updated content", + "db_version": 2, + "updated_at": updated_at, + "last_materialization_error": None, + } + ) + + assert updated.markdown_content == "# Partially updated content" + assert updated.db_version == 2 + assert updated.db_checksum == created.db_checksum + assert updated.file_write_status == created.file_write_status + assert updated.last_source == created.last_source + assert updated.last_materialization_error is None + assert updated.file_path == sample_entity.file_path + + +@pytest.mark.asyncio +async def test_create_rejects_missing_entity(session_maker, test_project: Project): + """Create should fail when the owning entity does not exist.""" + repository = NoteContentRepository(session_maker, project_id=test_project.id) + + with pytest.raises(ValueError, match="Entity 999999 does not exist"): + await repository.create(build_note_content_payload(999999)) + + +@pytest.mark.asyncio +async def test_create_rejects_entity_from_another_project(session_maker, config_home): + """Create should reject note_content writes across project boundaries.""" + project_repository = ProjectRepository(session_maker) + project_one = await project_repository.create( + { + "name": "project-one-boundary", + "path": str(config_home / "project-one-boundary"), + "is_active": True, + } + ) + project_two = await project_repository.create( + { + "name": "project-two-boundary", + "path": str(config_home / "project-two-boundary"), + "is_active": True, + } + ) + entity_repository = EntityRepository(session_maker, project_id=project_two.id) + other_project_entity = await entity_repository.create( + { + "title": "Other Project Note", + "note_type": "test", + "permalink": "project-two/other-project-note", + "file_path": "notes/other-project-note.md", + "content_type": "text/markdown", + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + } + ) + + repository = NoteContentRepository(session_maker, project_id=project_one.id) + + with pytest.raises( + ValueError, + match=f"Entity {other_project_entity.id} belongs to project {project_two.id}", + ): + await repository.create(build_note_content_payload(other_project_entity.id)) + + @pytest.mark.asyncio async def test_update_state_fields_realigns_identity_with_entity( session_maker, @@ -134,6 +243,31 @@ async def test_update_state_fields_realigns_identity_with_entity( assert updated.last_materialization_attempt_at is None +@pytest.mark.asyncio +async def test_update_state_fields_rejects_invalid_fields( + session_maker, + test_project: Project, + sample_entity, +): + """Only the declared mutable sync fields should be accepted.""" + repository = NoteContentRepository(session_maker, project_id=test_project.id) + await repository.create(build_note_content_payload(sample_entity.id)) + + with pytest.raises(ValueError, match="Unsupported note_content update fields: file_path"): + await repository.update_state_fields(sample_entity.id, file_path="renamed/note.md") + + +@pytest.mark.asyncio +async def test_update_state_fields_returns_none_for_missing_note_content( + session_maker, + test_project: Project, +): + """Missing note_content rows should produce a clean None response.""" + repository = NoteContentRepository(session_maker, project_id=test_project.id) + + assert await repository.update_state_fields(999999, file_write_status="failed") is None + + @pytest.mark.asyncio async def test_delete_by_entity_id(session_maker, test_project: Project, sample_entity): """Delete note_content directly by entity identifier.""" @@ -146,6 +280,17 @@ async def test_delete_by_entity_id(session_maker, test_project: Project, sample_ assert await repository.get_by_entity_id(sample_entity.id) is None +@pytest.mark.asyncio +async def test_delete_by_entity_id_returns_false_when_missing( + session_maker, + test_project: Project, +): + """Delete should report False when the note_content row does not exist.""" + repository = NoteContentRepository(session_maker, project_id=test_project.id) + + assert await repository.delete_by_entity_id(999999) is False + + @pytest.mark.asyncio async def test_note_content_cascades_when_entity_is_deleted( session_maker, @@ -221,3 +366,60 @@ async def test_note_content_file_path_lookup_is_project_scoped(session_maker, co assert found_two is not None assert found_one.entity_id == entity_one.id assert found_two.entity_id == entity_two.id + + +@pytest.mark.asyncio +async def test_note_content_file_path_lookup_prefers_entity_with_current_path( + session_maker, + config_home, +): + """File-path lookup should prefer the entity whose current path still matches.""" + project_repository = ProjectRepository(session_maker) + project = await project_repository.create( + { + "name": "project-path-drift", + "path": str(config_home / "project-path-drift"), + "is_active": True, + } + ) + entity_repository = EntityRepository(session_maker, project_id=project.id) + + stale_entity = await entity_repository.create( + { + "title": "Stale Note", + "note_type": "test", + "permalink": "project/stale-note", + "file_path": "archived/note.md", + "content_type": "text/markdown", + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + } + ) + current_entity = await entity_repository.create( + { + "title": "Current Note", + "note_type": "test", + "permalink": "project/current-note", + "file_path": "shared/note.md", + "content_type": "text/markdown", + "created_at": datetime.now(timezone.utc), + "updated_at": datetime.now(timezone.utc), + } + ) + + repository = NoteContentRepository(session_maker, project_id=project.id) + stale_payload = build_note_content_payload(stale_entity.id) + stale_payload["updated_at"] = datetime.now(timezone.utc) + timedelta(minutes=5) + await repository.create(stale_payload) + await repository.create(build_note_content_payload(current_entity.id)) + + async with db.scoped_session(session_maker) as session: + stale_note_content = await repository.select_by_id(session, stale_entity.id) + assert stale_note_content is not None + stale_note_content.file_path = "shared/note.md" + await session.flush() + + found = await repository.get_by_file_path("shared/note.md") + + assert found is not None + assert found.entity_id == current_entity.id