diff --git a/src/basic_memory/repository/entity_repository.py b/src/basic_memory/repository/entity_repository.py
index 1159e2808..3e15dc7e4 100644
--- a/src/basic_memory/repository/entity_repository.py
+++ b/src/basic_memory/repository/entity_repository.py
@@ -8,6 +8,7 @@
 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm.interfaces import LoaderOption
+from sqlalchemy.dialects.sqlite import insert as sqlite_insert
 
 from basic_memory import db
 from basic_memory.models.knowledge import Entity, Observation, Relation
@@ -101,11 +102,11 @@ async def find_by_permalinks(self, permalinks: List[str]) -> Sequence[Entity]:
         return list(result.scalars().all())
 
     async def upsert_entity(self, entity: Entity) -> Entity:
-        """Insert or update entity using a hybrid approach.
+        """Insert or update entity using SQLite's ON CONFLICT clause.
 
-        This method provides a cleaner alternative to the try/catch approach
-        for handling permalink and file_path conflicts. It first tries direct
-        insertion, then handles conflicts intelligently.
+        A single INSERT ... ON CONFLICT(file_path) DO UPDATE statement handles the file_path
+        race atomically. A permalink collision still raises IntegrityError and is resolved by
+        retrying with a suffixed permalink.
 
         Args:
             entity: The entity to insert or update
@@ -113,114 +114,80 @@ async def upsert_entity(self, entity: Entity) -> Entity:
         Returns:
             The inserted or updated entity
         """
-
         async with db.scoped_session(self.session_maker) as session:
             # Set project_id if applicable and not already set
             self._set_project_id_if_needed(entity)
 
-            # Check for existing entity with same file_path first
-            existing_by_path = await session.execute(
-                select(Entity).where(
-                    Entity.file_path == entity.file_path, Entity.project_id == entity.project_id
-                )
+            # Build entity data dictionary from the entity object
+            entity_data = {
+                "file_path": entity.file_path,
+                "project_id": entity.project_id,
+                "title": entity.title,
+                "entity_type": entity.entity_type,
+                "entity_metadata": entity.entity_metadata,
+                "content_type": entity.content_type,
+                "permalink": entity.permalink,
+                "checksum": entity.checksum,
+                "created_at": entity.created_at,
+                "updated_at": entity.updated_at,
+            }
+
+            # First attempt: Try to upsert with the given permalink
+            stmt = sqlite_insert(Entity).values(entity_data)
+            # NOTE(review): SQLite only accepts an ON CONFLICT target that matches a
+            # PRIMARY KEY or UNIQUE index exactly - confirm one exists on file_path alone
+            stmt = stmt.on_conflict_do_update(
+                index_elements=["file_path"],
+                set_={
+                    "title": stmt.excluded.title,
+                    "entity_type": stmt.excluded.entity_type,
+                    "entity_metadata": stmt.excluded.entity_metadata,
+                    "content_type": stmt.excluded.content_type,
+                    "permalink": stmt.excluded.permalink,
+                    "checksum": stmt.excluded.checksum,
+                    "updated_at": stmt.excluded.updated_at,
+                },
             )
-            existing_path_entity = existing_by_path.scalar_one_or_none()
-
-            if existing_path_entity:
-                # Update existing entity with same file path
-                for key, value in {
-                    "title": entity.title,
-                    "entity_type": entity.entity_type,
-                    "entity_metadata": entity.entity_metadata,
-                    "content_type": entity.content_type,
-                    "permalink": entity.permalink,
-                    "checksum": entity.checksum,
-                    "updated_at": entity.updated_at,
-                }.items():
-                    setattr(existing_path_entity, key, value)
 
-                await session.flush()
-                # Return with relationships loaded
-                query = (
-                    self.select()
-                    .where(Entity.file_path == entity.file_path)
-                    .options(*self.get_load_options())
-                )
-                result = await session.execute(query)
-                found = result.scalar_one_or_none()
-                if not found:  # pragma: no cover
-                    raise RuntimeError(
-                        f"Failed to retrieve entity after update: {entity.file_path}"
-                    )
-                return found
-
-            # No existing entity with same file_path, try insert
             try:
-                # Simple insert for new entity
-                session.add(entity)
+                await session.execute(stmt)
                 await session.flush()
+            except IntegrityError as e:
+                # If we get here, it's likely a permalink conflict
+                if "UNIQUE constraint failed: entity.permalink" in str(e):
+                    await session.rollback()
+                    # Handle permalink conflict by generating a unique permalink
+                    return await self._handle_permalink_conflict_optimistic(entity_data, session)
+                raise
+
+            # Retrieve the entity with relationships loaded
+            query = (
+                self.select()
+                .where(Entity.file_path == entity.file_path)
+                .options(*self.get_load_options())
+            )
+            result = await session.execute(query)
+            found = result.scalar_one_or_none()
 
-                # Return with relationships loaded
-                query = (
-                    self.select()
-                    .where(Entity.file_path == entity.file_path)
-                    .options(*self.get_load_options())
-                )
-                result = await session.execute(query)
-                found = result.scalar_one_or_none()
-                if not found:  # pragma: no cover
-                    raise RuntimeError(
-                        f"Failed to retrieve entity after insert: {entity.file_path}"
-                    )
-                return found
-
-            except IntegrityError:
-                # Could be either file_path or permalink conflict
-                await session.rollback()
-
-                # Check if it's a file_path conflict (race condition)
-                existing_by_path_check = await session.execute(
-                    select(Entity).where(
-                        Entity.file_path == entity.file_path, Entity.project_id == entity.project_id
-                    )
-                )
-                race_condition_entity = existing_by_path_check.scalar_one_or_none()
-
-                if race_condition_entity:
-                    # Race condition: file_path conflict detected after our initial check
-                    # Update the existing entity instead
-                    for key, value in {
-                        "title": entity.title,
-                        "entity_type": entity.entity_type,
-                        "entity_metadata": entity.entity_metadata,
-                        "content_type": entity.content_type,
-                        "permalink": entity.permalink,
-                        "checksum": entity.checksum,
-                        "updated_at": entity.updated_at,
-                    }.items():
-                        setattr(race_condition_entity, key, value)
-
-                    await session.flush()
-                    # Return the updated entity with relationships loaded
-                    query = (
-                        self.select()
-                        .where(Entity.file_path == entity.file_path)
-                        .options(*self.get_load_options())
-                    )
-                    result = await session.execute(query)
-                    found = result.scalar_one_or_none()
-                    if not found:  # pragma: no cover
-                        raise RuntimeError(
-                            f"Failed to retrieve entity after race condition update: {entity.file_path}"
-                        )
-                    return found
-                else:
-                    # Must be permalink conflict - generate unique permalink
-                    return await self._handle_permalink_conflict(entity, session)
-
-    async def _handle_permalink_conflict(self, entity: Entity, session: AsyncSession) -> Entity:
-        """Handle permalink conflicts by generating a unique permalink."""
-        base_permalink = entity.permalink
+            if not found:  # pragma: no cover
+                raise RuntimeError(f"Failed to retrieve entity after upsert: {entity.file_path}")
+
+            return found
+
+    async def _handle_permalink_conflict_optimistic(
+        self, entity_data: dict, session: AsyncSession
+    ) -> Entity:
+        """Handle permalink conflicts by generating a unique permalink.
+
+        Args:
+            entity_data: Dictionary of entity data to insert
+            session: Database session to use
+
+        Returns:
+            The inserted entity with a unique permalink
+        """
+        base_permalink = entity_data["permalink"]
+        project_id = entity_data.get("project_id")
         suffix = 1
 
         # Find a unique permalink
@@ -228,27 +195,110 @@ async def _handle_permalink_conflict(self, entity: Entity, session: AsyncSession
             test_permalink = f"{base_permalink}-{suffix}"
             existing = await session.execute(
                 select(Entity).where(
-                    Entity.permalink == test_permalink, Entity.project_id == entity.project_id
+                    Entity.permalink == test_permalink, Entity.project_id == project_id
                 )
             )
             if existing.scalar_one_or_none() is None:
                 # Found unique permalink
-                entity.permalink = test_permalink
+                entity_data = {**entity_data, "permalink": test_permalink}
                 break
             suffix += 1
 
-        # Insert with unique permalink (no conflict possible now)
-        session.add(entity)
+        # Insert with unique permalink using ON CONFLICT
+        stmt = sqlite_insert(Entity).values(entity_data)
+        stmt = stmt.on_conflict_do_update(
+            index_elements=["file_path"],
+            set_={
+                key: stmt.excluded[key]
+                for key in entity_data.keys()
+                if key not in ["file_path", "id", "project_id", "created_at"]
+            },
+        )
+
+        await session.execute(stmt)
         await session.flush()
 
         # Return the inserted entity with relationships loaded
         query = (
             self.select()
-            .where(Entity.file_path == entity.file_path)
+            .where(Entity.file_path == entity_data["file_path"])
             .options(*self.get_load_options())
         )
         result = await session.execute(query)
         found = result.scalar_one_or_none()
         if not found:  # pragma: no cover
-            raise RuntimeError(f"Failed to retrieve entity after insert: {entity.file_path}")
+            raise RuntimeError(f"Failed to retrieve entity after insert: {entity_data['file_path']}")
         return found
+
+    async def update(self, entity_id: int, entity_data: dict | Entity) -> Optional[Entity]:
+        """Update the entity with the given ID using a single UPDATE statement.
+
+        The row is addressed by primary key and only the supplied columns are written,
+        so changing ``file_path`` renames the existing row instead of inserting a second
+        one, and columns that were not supplied are never overwritten with stale values.
+
+        Args:
+            entity_id: The ID of the entity to update
+            entity_data: Dictionary of fields to update or an Entity object
+
+        Returns:
+            The updated entity or None if no entity exists with the given ID
+        """
+        # Imported locally under an alias so the SQL construct cannot be confused with
+        # this method's own name.
+        from sqlalchemy import update as sql_update
+
+        # Convert Entity object to dict if needed
+        if isinstance(entity_data, Entity):
+            updates = {
+                column.name: getattr(entity_data, column.name)
+                for column in Entity.__table__.columns
+                if hasattr(entity_data, column.name)
+            }
+        else:
+            updates = entity_data
+
+        # Identity, project membership and creation time are never rewritten here
+        immutable = {"id", "project_id", "created_at"}
+        changes = {
+            key: value
+            for key, value in updates.items()
+            if key in self.valid_columns and key not in immutable
+        }
+
+        async with db.scoped_session(self.session_maker) as session:
+            if changes:
+                # WHERE id = :entity_id matches zero rows for an unknown ID, so no
+                # separate existence check is needed before writing
+                await session.execute(
+                    sql_update(Entity).where(Entity.id == entity_id).values(changes)
+                )
+
+            # Return the updated entity with relationships loaded
+            query = (
+                self.select()
+                .where(Entity.id == entity_id)
+                .options(*self.get_load_options())
+            )
+            result = await session.execute(query)
+            return result.scalar_one_or_none()
+
+    async def update_by_file_path(self, file_path: str, updates: dict) -> Optional[Entity]:
+        """Update the entity stored at file_path.
+
+        This is a convenience method for updating entities when you only have the file_path.
+
+        Args:
+            file_path: The file_path of the entity to update
+            updates: Dictionary of fields to update
+
+        Returns:
+            The updated entity or None if no entity exists with the given file_path
+        """
+        # First check if entity exists
+        entity = await self.get_by_file_path(file_path)
+        if not entity:
+            return None
+
+        # Use the regular update method with the entity ID
+        return await self.update(entity.id, updates)
diff --git a/src/basic_memory/sync/sync_service.py b/src/basic_memory/sync/sync_service.py
index 96e6f792f..32c77f892 100644
--- a/src/basic_memory/sync/sync_service.py
+++ b/src/basic_memory/sync/sync_service.py
@@ -389,8 +389,8 @@ async def sync_regular_file(self, path: str, new: bool = True) -> Tuple[Optional
                         logger.error(f"Entity not found after constraint violation, path={path}")
                         raise ValueError(f"Entity not found after constraint violation: {path}")
 
-                    updated = await self.entity_repository.update(
-                        entity.id, {"file_path": path, "checksum": checksum}
+                    updated = await self.entity_repository.update_by_file_path(
+                        path, {"checksum": checksum}
                     )
 
                     if updated is None:  # pragma: no cover
@@ -407,8 +407,8 @@ async def sync_regular_file(self, path: str, new: bool = True) -> Tuple[Optional
                 logger.error(f"Entity not found for existing file, path={path}")
                 raise ValueError(f"Entity not found for existing file: {path}")
 
-            updated = await self.entity_repository.update(
-                entity.id, {"file_path": path, "checksum": checksum}
+            updated = await self.entity_repository.update_by_file_path(
+                path, {"checksum": checksum}
             )
 
             if updated is None:  # pragma: no cover