diff --git a/src/basic_memory/config.py b/src/basic_memory/config.py
index fef79c4c0..6f957a6ee 100644
--- a/src/basic_memory/config.py
+++ b/src/basic_memory/config.py
@@ -132,6 +132,12 @@ class BasicMemoryConfig(BaseSettings):
         gt=0,
     )
 
+    sync_batch_size: int = Field(
+        default=100,
+        description="Number of files to process in a single database transaction during sync. Higher values improve performance with remote databases (Postgres) but increase memory usage. Typical values: 100 (conservative), 500 (balanced), 1000 (aggressive).",
+        gt=0,
+    )
+
     kebab_filenames: bool = Field(
         default=False,
         description="Format for generated filenames. False preserves spaces and special chars, True converts them to hyphens for consistency with permalinks",
diff --git a/src/basic_memory/repository/entity_repository.py b/src/basic_memory/repository/entity_repository.py
index 149a75156..2377cb2ab 100644
--- a/src/basic_memory/repository/entity_repository.py
+++ b/src/basic_memory/repository/entity_repository.py
@@ -63,6 +63,38 @@ async def get_by_file_path(self, file_path: Union[Path, str]) -> Optional[Entity
         )
         return await self.find_one(query)
 
+    async def get_by_file_paths_batch(
+        self, file_paths: Sequence[Union[Path, str]]
+    ) -> dict[str, Entity]:
+        """Batch fetch entities by file paths with eager-loaded relationships.
+
+        Optimized for scan operations - reduces N queries to 1 batched query.
+        Returns entities with relationships already loaded via selectinload.
+
+        Args:
+            file_paths: List of file paths to fetch entities for
+
+        Returns:
+            Dict mapping file_path (as posix string) -> Entity
+            Only includes entities that exist; missing files are not in dict
+        """
+        if not file_paths:
+            return {}
+
+        # Convert all paths to posix strings
+        posix_paths = [Path(p).as_posix() for p in file_paths]
+
+        # Batch query with eager loading
+        query = (
+            self.select().where(Entity.file_path.in_(posix_paths)).options(*self.get_load_options())
+        )
+
+        result = await self.execute_query(query)
+        entities = list(result.scalars().all())
+
+        # Return as dict for O(1) lookup
+        return {e.file_path: e for e in entities}
+
     async def find_by_checksum(self, checksum: str) -> Sequence[Entity]:
         """Find entities with the given checksum.
 
@@ -338,3 +370,80 @@ async def _handle_permalink_conflict(self, entity: Entity, session: AsyncSession
             # Re-raise if not a foreign key error
             raise
         return entity
+
+    async def upsert_entities(self, entities: List[Entity]) -> List[Entity]:
+        """Bulk insert or update multiple entities in a single transaction.
+
+        Optimized for batch operations with remote databases (Postgres).
+        Handles conflicts the same way as upsert_entity() but processes
+        all entities in one transaction.
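+
+        Example (illustrative sketch only; entity construction is elided):
+
+            entities = [...]  # Entity models built from parsed markdown files
+            saved = await entity_repository.upsert_entities(entities)
+            by_path = {e.file_path: e for e in saved}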
+
+        Args:
+            entities: List of entities to upsert
+
+        Returns:
+            List of upserted entities with relationships loaded
+
+        Raises:
+            SyncFatalError: If any entity references a non-existent project_id
+        """
+        if not entities:
+            return []
+
+        async with db.scoped_session(self.session_maker) as session:
+            # Set project_id on all entities if needed
+            for entity in entities:
+                self._set_project_id_if_needed(entity)
+
+            # Try to add all entities
+            for entity in entities:
+                session.add(entity)
+
+            try:
+                await session.flush()
+
+                # Fetch all entities with relationships loaded
+                file_paths = [e.file_path for e in entities]
+                query = (
+                    self.select()
+                    .where(Entity.file_path.in_(file_paths))
+                    .options(*self.get_load_options())
+                )
+                result = await session.execute(query)
+                return list(result.scalars().all())
+
+            except IntegrityError as e:
+                # Check for foreign key constraint failures
+                error_str = str(e)
+                if (
+                    "FOREIGN KEY constraint failed" in error_str
+                    or "violates foreign key constraint" in error_str
+                ):
+                    from basic_memory.services.exceptions import SyncFatalError
+
+                    raise SyncFatalError(
+                        "Cannot sync entities: project_id does not exist in database. "
+                        "The project may have been deleted. This sync will be terminated."
+                    ) from e
+
+                # For other integrity errors (file_path or permalink conflicts),
+                # rollback and fall back to individual processing
+                await session.rollback()
+
+                # Process each entity individually to handle conflicts properly
+                logger.debug(
+                    f"Batch upsert failed with IntegrityError, falling back to individual upserts for {len(entities)} entities"
+                )
+
+                result_entities = []
+                for entity in entities:
+                    try:
+                        upserted = await self.upsert_entity(entity)
+                        result_entities.append(upserted)
+                    except Exception as individual_error:
+                        logger.error(
+                            f"Failed to upsert entity {entity.file_path}: {individual_error}"
+                        )
+                        # Continue with other entities
+
+                return result_entities
diff --git a/src/basic_memory/repository/observation_repository.py b/src/basic_memory/repository/observation_repository.py
index 5fb91595d..d6974202c 100644
--- a/src/basic_memory/repository/observation_repository.py
+++ b/src/basic_memory/repository/observation_repository.py
@@ -70,3 +70,33 @@ async def find_by_entities(self, entity_ids: List[int]) -> Dict[int, List[Observ
             observations_by_entity[obs.entity_id].append(obs)
 
         return observations_by_entity
+
+    async def delete_by_entity_ids(self, entity_ids: List[int]) -> int:
+        """Delete all observations for multiple entities in a single transaction.
+
+        Optimized for batch operations - loads the matching observations with one
+        IN-clause query, then deletes them in the same transaction.
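+
+        Example (illustrative sketch; IDs are placeholders):
+
+            deleted_count = await observation_repository.delete_by_entity_ids([1, 2, 3])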
+
+        Args:
+            entity_ids: List of entity IDs whose observations should be deleted
+
+        Returns:
+            Number of observations deleted
+        """
+        if not entity_ids:
+            return 0
+
+        from basic_memory import db
+
+        async with db.scoped_session(self.session_maker) as session:
+            # Fetch all matching observations with one IN-clause query
+            query = select(Observation).where(Observation.entity_id.in_(entity_ids))
+            result = await session.execute(query)
+            observations_to_delete = result.scalars().all()
+
+            # Delete through the ORM so session bookkeeping stays consistent
+            for obs in observations_to_delete:
+                await session.delete(obs)
+
+            await session.flush()
+            return len(observations_to_delete)
diff --git a/src/basic_memory/repository/relation_repository.py b/src/basic_memory/repository/relation_repository.py
index 913adb275..cd40622ff 100644
--- a/src/basic_memory/repository/relation_repository.py
+++ b/src/basic_memory/repository/relation_repository.py
@@ -67,6 +67,27 @@ async def delete_outgoing_relations_from_entity(self, entity_id: int) -> None:
         async with db.scoped_session(self.session_maker) as session:
             await session.execute(delete(Relation).where(Relation.from_id == entity_id))
 
+    async def delete_outgoing_relations_from_entities(self, entity_ids: List[int]) -> int:
+        """Delete outgoing relations for multiple entities in a single query.
+
+        Optimized for batch operations - deletes relations for many entities
+        in one database transaction. Only deletes relations where these entities
+        are the source (from_id).
+
+        Args:
+            entity_ids: List of entity IDs whose outgoing relations should be deleted
+
+        Returns:
+            Number of relations deleted
+        """
+        if not entity_ids:
+            return 0
+
+        async with db.scoped_session(self.session_maker) as session:
+            # Use bulk delete with IN clause
+            result = await session.execute(delete(Relation).where(Relation.from_id.in_(entity_ids)))
+            return result.rowcount or 0
+
     async def find_unresolved_relations(self) -> Sequence[Relation]:
         """Find all unresolved relations, where to_id is null."""
         query = select(Relation).filter(Relation.to_id.is_(None))
diff --git a/src/basic_memory/sync/sync_service.py b/src/basic_memory/sync/sync_service.py
index 9b4c78cbf..85808a522 100644
--- a/src/basic_memory/sync/sync_service.py
+++ b/src/basic_memory/sync/sync_service.py
@@ -31,6 +31,7 @@
 from basic_memory.services.exceptions import SyncFatalError
 from basic_memory.services.link_resolver import LinkResolver
 from basic_memory.services.search_service import SearchService
+from basic_memory.sync.utils import chunks
 
 # Circuit breaker configuration
 MAX_CONSECUTIVE_FAILURES = 3
@@ -299,38 +300,139 @@ async def sync(
         for path in report.deleted:
             await self.handle_delete(path)
 
-        # then new and modified
+        # then new and modified - process in batches for better performance
+        batch_size = self.app_config.sync_batch_size
+        logger.debug(f"Using batch size of {batch_size} for file processing")
+
         with logfire.span("process_new_files", new_count=len(report.new)):
-            for path in report.new:
-                entity, _ = await self.sync_file(path, new=True)
-
-                # Track if file was skipped
-                if entity is None and await self._should_skip_file(path):
-                    failure_info = self._file_failures[path]
-                    report.skipped_files.append(
-                        SkippedFile(
-                            path=path,
-                            reason=failure_info.last_error,
-                            failure_count=failure_info.count,
-                            first_failed=failure_info.first_failure,
+            # Convert set to list for batching
+            new_files_list = list(report.new)
+
+            for batch in chunks(new_files_list, batch_size):
+                logger.debug(f"Processing batch of {len(batch)} new files")
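+                # e.g. 250 new files with batch_size=100 -> batches of 100, 100, 50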
+
+                # Separate markdown and non-markdown files
+                markdown_files = [p for p in batch if self.file_service.is_markdown(p)]
+                regular_files = [p for p in batch if not self.file_service.is_markdown(p)]
+
+                # Batch process markdown files
+                if markdown_files:
+                    try:
+                        batch_results = await self.sync_markdown_batch(markdown_files, new=True)
+
+                        # Track skipped files
+                        for path, (entity, _) in zip(markdown_files, batch_results):
+                            if entity is None and await self._should_skip_file(path):
+                                failure_info = self._file_failures[path]
+                                report.skipped_files.append(
+                                    SkippedFile(
+                                        path=path,
+                                        reason=failure_info.last_error,
+                                        failure_count=failure_info.count,
+                                        first_failed=failure_info.first_failure,
+                                    )
+                                )
+
+                    except SyncFatalError:
+                        # Re-raise fatal errors immediately
+                        raise
+                    except Exception as e:
+                        # Batch method raised an exception - record failure for all files in batch
+                        logger.error(f"Batch sync failed for {len(markdown_files)} files: {e}")
+                        for path in markdown_files:
+                            await self._record_failure(path, str(e))
+                            # Track skipped files
+                            if await self._should_skip_file(path):
+                                failure_info = self._file_failures[path]
+                                report.skipped_files.append(
+                                    SkippedFile(
+                                        path=path,
+                                        reason=failure_info.last_error,
+                                        failure_count=failure_info.count,
+                                        first_failed=failure_info.first_failure,
+                                    )
+                                )
+
+                # Process regular files individually (they're already fast)
+                for path in regular_files:
+                    entity, _ = await self.sync_file(path, new=True)
+
+                    # Track if file was skipped
+                    if entity is None and await self._should_skip_file(path):
+                        failure_info = self._file_failures[path]
+                        report.skipped_files.append(
+                            SkippedFile(
+                                path=path,
+                                reason=failure_info.last_error,
+                                failure_count=failure_info.count,
+                                first_failed=failure_info.first_failure,
+                            )
                         )
-                    )
 
         with logfire.span("process_modified_files", modified_count=len(report.modified)):
-            for path in report.modified:
-                entity, _ = await self.sync_file(path, new=False)
-
-                # Track if file was skipped
-                if entity is None and await self._should_skip_file(path):
-                    failure_info = self._file_failures[path]
-                    report.skipped_files.append(
-                        SkippedFile(
-                            path=path,
-                            reason=failure_info.last_error,
-                            failure_count=failure_info.count,
-                            first_failed=failure_info.first_failure,
+            # Convert set to list for batching
+            modified_files_list = list(report.modified)
+
+            for batch in chunks(modified_files_list, batch_size):
+                logger.debug(f"Processing batch of {len(batch)} modified files")
+
+                # Separate markdown and non-markdown files
+                markdown_files = [p for p in batch if self.file_service.is_markdown(p)]
+                regular_files = [p for p in batch if not self.file_service.is_markdown(p)]
+
+                # Batch process markdown files
+                if markdown_files:
+                    try:
+                        batch_results = await self.sync_markdown_batch(markdown_files, new=False)
+
+                        # Track skipped files
+                        for path, (entity, _) in zip(markdown_files, batch_results):
+                            if entity is None and await self._should_skip_file(path):
+                                failure_info = self._file_failures[path]
+                                report.skipped_files.append(
+                                    SkippedFile(
+                                        path=path,
+                                        reason=failure_info.last_error,
+                                        failure_count=failure_info.count,
+                                        first_failed=failure_info.first_failure,
+                                    )
+                                )
+
+                    except SyncFatalError:
+                        # Re-raise fatal errors immediately
+                        raise
+                    except Exception as e:
+                        # Batch method raised an exception - record failure for all files in batch
+                        logger.error(f"Batch sync failed for {len(markdown_files)} files: {e}")
+                        for path in markdown_files:
+                            await self._record_failure(path, str(e))
+                            # Track skipped files
+                            if await self._should_skip_file(path):
+                                failure_info = self._file_failures[path]
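+                                # Same bookkeeping as the new-file branch above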
+                                report.skipped_files.append(
+                                    SkippedFile(
+                                        path=path,
+                                        reason=failure_info.last_error,
+                                        failure_count=failure_info.count,
+                                        first_failed=failure_info.first_failure,
+                                    )
+                                )
+
+                # Process regular files individually (they're already fast)
+                for path in regular_files:
+                    entity, _ = await self.sync_file(path, new=False)
+
+                    # Track if file was skipped
+                    if entity is None and await self._should_skip_file(path):
+                        failure_info = self._file_failures[path]
+                        report.skipped_files.append(
+                            SkippedFile(
+                                path=path,
+                                reason=failure_info.last_error,
+                                failure_count=failure_info.count,
+                                first_failed=failure_info.first_failure,
+                            )
                         )
-                    )
 
         # Only resolve relations if there were actual changes
         # If no files changed, no new unresolved relations could have been created
@@ -484,6 +586,12 @@ async def scan(self, directory, force_full: bool = False):
 
         logger.debug(f"Processing {len(file_paths_to_scan)} files with mtime-based comparison")
 
+        # Optimization: Batch fetch all entities for files being scanned
+        # This reduces N queries to 1 batch query (massive performance win for remote DBs)
+        logger.debug(f"Batch fetching entities for {len(file_paths_to_scan)} files")
+        entities_by_path = await self.entity_repository.get_by_file_paths_batch(file_paths_to_scan)
+        logger.debug(f"Found {len(entities_by_path)} existing entities in database")
+
         for rel_path in file_paths_to_scan:
             scanned_paths.add(rel_path)
 
@@ -495,8 +603,8 @@ async def scan(self, directory, force_full: bool = False):
 
             stat_info = abs_path.stat()
 
-            # Indexed lookup - single file query (not full table scan)
-            db_entity = await self.entity_repository.get_by_file_path(rel_path)
+            # O(1) dict lookup instead of database query
+            db_entity = entities_by_path.get(rel_path)
 
             if db_entity is None:
                 # New file - need checksum for move detection
@@ -737,6 +845,206 @@ async def sync_markdown_file(self, path: str, new: bool = True) -> Tuple[Optiona
 
         # Return the final checksum to ensure everything is consistent
         return entity, final_checksum
+
+    @logfire.instrument()
+    async def sync_markdown_batch(
+        self, paths: List[str], new: bool = True
+    ) -> List[Tuple[Optional[Entity], str]]:
+        """Sync multiple markdown files in a single batch operation.
+
+        Optimized for remote databases (Postgres) - collapses per-file queries into
+        a few batched ones. Parses all files first, batches the entity writes, then
+        post-processes each file (relations, checksum, search index).
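+
+        Example (illustrative sketch; paths are placeholders):
+
+            results = await sync_service.sync_markdown_batch(["a.md", "b.md"], new=True)
+            for entity, checksum in results:
+                if entity is None:
+                    ...  # this file failed; the circuit breaker recorded it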
+
+        Args:
+            paths: List of paths to markdown files
+            new: Whether these are new files
+
+        Returns:
+            List of tuples (entity, checksum) for each file
+        """
+        from basic_memory.markdown.utils import entity_model_from_markdown
+
+        if not paths:
+            return []
+
+        logger.debug(f"Batch syncing {len(paths)} markdown files (new={new})")
+
+        # Phase 1: Parse all files (no DB operations)
+        parsed_files = []
+        for path in paths:
+            # Check if file should be skipped due to repeated failures (circuit breaker)
+            if await self._should_skip_file(path):
+                logger.warning(f"Skipping file in batch due to repeated failures: {path}")
+                parsed_files.append(None)
+                continue
+
+            try:
+                file_content = await self.file_service.read_file_content(path)
+                file_contains_frontmatter = has_frontmatter(file_content)
+
+                # Get file timestamps for tracking modification times
+                file_stats = self.file_service.file_stats(path)
+                created = datetime.fromtimestamp(file_stats.st_ctime).astimezone()
+                modified = datetime.fromtimestamp(file_stats.st_mtime).astimezone()
+
+                # Parse markdown to get entity structure
+                entity_markdown = await self.entity_parser.parse_file(path)
+
+                # Resolve permalink if needed (skip conflict checks during batch)
+                permalink = entity_markdown.frontmatter.permalink
+                if file_contains_frontmatter and not self.app_config.disable_permalinks:
+                    permalink = await self.entity_service.resolve_permalink(
+                        path, markdown=entity_markdown, skip_conflict_check=True
+                    )
+
+                    # If permalink changed, update the file
+                    if permalink != entity_markdown.frontmatter.permalink:
+                        logger.info(
+                            f"Updating permalink for path: {path}, "
+                            f"old_permalink: {entity_markdown.frontmatter.permalink}, "
+                            f"new_permalink: {permalink}"
+                        )
+                        entity_markdown.frontmatter.metadata["permalink"] = permalink
+                        await self.file_service.update_frontmatter(path, {"permalink": permalink})
+
+                # Convert to entity model (without saving to DB yet)
+                entity_model = entity_model_from_markdown(Path(path), entity_markdown)
+                entity_model.checksum = None  # Will be set after relations are resolved
+
+                parsed_files.append(
+                    {
+                        "path": path,
+                        "entity_model": entity_model,
+                        "entity_markdown": entity_markdown,
+                        "created": created,
+                        "modified": modified,
+                        "mtime": file_stats.st_mtime,
+                        "size": file_stats.st_size,
+                    }
+                )
+
+            except Exception as e:
+                # Check if this is a fatal error (or caused by one)
+                # Fatal errors like project deletion should terminate sync immediately
+                if isinstance(e, SyncFatalError) or isinstance(e.__cause__, SyncFatalError):
+                    logger.error(
+                        f"Fatal sync error encountered during batch parse, terminating sync: path={path}"
+                    )
+                    raise
+
+                # Otherwise treat as recoverable file-level error
+                logger.error(f"Failed to parse file in batch: path={path}, error={e}")
+                # Track failure for circuit breaker
+                await self._record_failure(path, str(e))
+                parsed_files.append(None)  # Mark as failed
+
+        # Phase 2: Batch database operations
+        # Filter out failed parses
+        valid_files = [f for f in parsed_files if f is not None]
+        if not valid_files:
+            logger.warning("No valid files to sync in batch")
+            return [(None, "") for _ in paths]
+
+        # If this is a new batch, upsert all entities at once
+        if new:
+            entities_to_upsert = [f["entity_model"] for f in valid_files]
+            logger.debug(f"Batch upserting {len(entities_to_upsert)} new entities")
+            upserted_entities = await self.entity_repository.upsert_entities(entities_to_upsert)
+
+            # Create lookup by file_path for O(1) access
+            entities_by_path = {e.file_path: e for e in upserted_entities}
+
+        # If updating existing entities, we need to handle observations/relations differently
+        else:
+            logger.debug(f"Batch updating {len(valid_files)} existing entities")
+            # For updates, we need to:
+            # 1. Get existing entities
+            # 2. Delete old observations/relations
+            # 3. Upsert updated entities
+
+            file_paths = [f["path"] for f in valid_files]
+            existing_entities = await self.entity_repository.get_by_file_paths_batch(file_paths)
+
+            # Delete old observations and relations in batch
+            entity_ids = [e.id for e in existing_entities.values() if e.id]
+            if entity_ids:
+                await self.entity_service.observation_repository.delete_by_entity_ids(entity_ids)
+                await self.relation_repository.delete_outgoing_relations_from_entities(entity_ids)
+
+            # Upsert all updated entities
+            entities_to_upsert = [f["entity_model"] for f in valid_files]
+            upserted_entities = await self.entity_repository.upsert_entities(entities_to_upsert)
+
+            # Create lookup by file_path
+            entities_by_path = {e.file_path: e for e in upserted_entities}
+
+        # Phase 3: Post-processing (relations, checksums, search index)
+        results = []
+        for i, path in enumerate(paths):
+            parsed_file = parsed_files[i]
+
+            # Skip failed files
+            if parsed_file is None:
+                results.append((None, ""))
+                continue
+
+            entity = entities_by_path.get(parsed_file["path"])
+            if entity is None:
+                logger.error(f"Entity not found after upsert: {parsed_file['path']}")
+                results.append((None, ""))
+                continue
+
+            try:
+                # Update relations for this entity
+                entity_with_relations = await self.entity_service.update_entity_relations(
+                    parsed_file["path"], parsed_file["entity_markdown"]
+                )
+
+                # Compute final checksum after relations are resolved
+                final_checksum = await self.file_service.compute_checksum(parsed_file["path"])
+
+                # Update checksum, timestamps, and file metadata
+                await self.entity_repository.update(
+                    entity.id,
+                    {
+                        "checksum": final_checksum,
+                        "created_at": parsed_file["created"],
+                        "updated_at": parsed_file["modified"],
+                        "mtime": parsed_file["mtime"],
+                        "size": parsed_file["size"],
+                    },
+                )
+
+                # Index for search
+                await self.search_service.index_entity(entity_with_relations)
+
+                # Clear failure tracking on successful sync
+                self._clear_failure(parsed_file["path"])
+
+                results.append((entity_with_relations, final_checksum))
+
+                logger.debug(
+                    f"Batch sync completed for file: path={parsed_file['path']}, "
+                    f"entity_id={entity.id}, checksum={final_checksum[:8]}"
+                )
+
+            except Exception as e:
+                # Check if this is a fatal error
+                if isinstance(e, SyncFatalError) or isinstance(e.__cause__, SyncFatalError):
+                    logger.error(
+                        f"Fatal sync error during post-processing, terminating sync: path={parsed_file['path']}"
+                    )
+                    raise
+
+                # Otherwise treat as recoverable file-level error
+                logger.error(
+                    f"Failed to complete post-processing for file: path={parsed_file['path']}, error={e}"
+                )
+                await self._record_failure(parsed_file["path"], str(e))
+                results.append((None, ""))
+                continue
+
+        return results
+
     @logfire.instrument()
     async def sync_regular_file(self, path: str, new: bool = True) -> Tuple[Optional[Entity], str]:
         """Sync a non-markdown file with basic tracking.
@@ -1134,13 +1442,13 @@ async def _scan_directory_modified_since(
     async def _scan_directory_full(self, directory: Path) -> List[str]:
         """Full directory scan returning all file paths.
 
-        Uses scan_directory() which respects .bmignore patterns.
+        Uses scan_directory() which respects .bmignore patterns.
 
-        Args:
-            directory: Directory to scan
+        Args:
+            directory: Directory to scan
 
         Returns:
-            List of relative file paths (respects .bmignore)
+            List of relative file paths (respects .bmignore)
         """
         file_paths = []
         async for file_path_str, _ in self.scan_directory(directory):
diff --git a/src/basic_memory/sync/utils.py b/src/basic_memory/sync/utils.py
new file mode 100644
index 000000000..b3b5172df
--- /dev/null
+++ b/src/basic_memory/sync/utils.py
@@ -0,0 +1,23 @@
+"""Utilities for sync operations."""
+
+from typing import Iterator, List, TypeVar
+
+T = TypeVar("T")
+
+
+def chunks(items: List[T], size: int) -> Iterator[List[T]]:
+    """Split a list into chunks of specified size.
+
+    Args:
+        items: List of items to chunk
+        size: Size of each chunk
+
+    Yields:
+        Lists of items, each of specified size (last chunk may be smaller)
+
+    Example:
+        >>> list(chunks([1, 2, 3, 4, 5], 2))
+        [[1, 2], [3, 4], [5]]
+    """
+    for i in range(0, len(items), size):
+        yield items[i : i + size]
diff --git a/tests/sync/test_sync_service.py b/tests/sync/test_sync_service.py
index e54b1413c..e64cf05e9 100644
--- a/tests/sync/test_sync_service.py
+++ b/tests/sync/test_sync_service.py
@@ -1486,11 +1486,12 @@ async def test_circuit_breaker_skips_after_three_failures(
     # Create a file with malformed content that will fail to parse
     await create_test_file(test_file, "invalid markdown content")
 
-    # Mock sync_markdown_file to always fail
-    async def mock_sync_markdown_file(*args, **kwargs):
+    # Mock sync_markdown_batch to always fail for all files in batch
+    async def mock_sync_markdown_batch(paths, new=True):
+        # Simulate a batch-level failure affecting every path in the batch
         raise ValueError("Simulated sync failure")
 
-    with patch.object(sync_service, "sync_markdown_file", side_effect=mock_sync_markdown_file):
+    with patch.object(sync_service, "sync_markdown_batch", side_effect=mock_sync_markdown_batch):
         # First sync - should fail and record (1/3)
         report1 = await sync_service.sync(project_dir)
         assert len(report1.skipped_files) == 0  # Not skipped yet
@@ -1545,15 +1546,15 @@ async def test_circuit_breaker_resets_on_file_change(
     # Create initial failing content
     await create_test_file(test_file, "initial bad content")
 
-    # Mock sync_markdown_file to fail
+    # Mock sync_markdown_batch to fail
     call_count = 0
 
-    async def mock_sync_markdown_file(*args, **kwargs):
+    async def mock_sync_markdown_batch(paths, new=True):
         nonlocal call_count
-        call_count += 1
+        call_count += len(paths)
         raise ValueError("Simulated sync failure")
 
-    with patch.object(sync_service, "sync_markdown_file", side_effect=mock_sync_markdown_file):
+    with patch.object(sync_service, "sync_markdown_batch", side_effect=mock_sync_markdown_batch):
         # Fail 3 times to hit circuit breaker threshold
         await sync_service.sync(project_dir)  # Fail 1
         await touch_file(test_file)  # Touch to trigger incremental scan
@@ -1670,8 +1671,8 @@ async def test_circuit_breaker_handles_checksum_computation_failure(
     test_file = project_dir / "checksum_fail.md"
     await create_test_file(test_file, "content")
 
-    # Mock sync_markdown_file to fail
-    async def mock_sync_markdown_file(*args, **kwargs):
+    # Mock sync_markdown_batch to fail
+    async def mock_sync_markdown_batch(paths, new=True):
         raise ValueError("Sync failure")
 
     # Mock checksum computation to fail only during _record_failure (not during scan)
@@ -1688,7 +1689,7 @@ async def mock_compute_checksum(path):
         raise IOError("Cannot read file")
 
     with (
-        patch.object(sync_service, "sync_markdown_file", side_effect=mock_sync_markdown_file),
+        patch.object(sync_service, "sync_markdown_batch", side_effect=mock_sync_markdown_batch),
         patch.object(
             sync_service.file_service,
             "compute_checksum",
@@ -1757,16 +1758,16 @@ async def test_sync_fatal_error_terminates_sync_immediately(
         ),
     )
 
-    # Mock entity_service.create_entity_from_markdown to raise SyncFatalError on first file
-    # This simulates project being deleted during sync
-    async def mock_create_entity_from_markdown(*args, **kwargs):
+    # Mock entity_repository.upsert_entities to raise SyncFatalError
+    # This simulates project being deleted during batch sync
+    async def mock_upsert_entities(entities):
         raise SyncFatalError(
-            "Cannot sync file 'file1.md': project_id=99999 does not exist in database. "
+            "Cannot sync entities: project_id=99999 does not exist in database. "
            "The project may have been deleted. This sync will be terminated."
         )
 
     with patch.object(
-        entity_service, "create_entity_from_markdown", side_effect=mock_create_entity_from_markdown
+        sync_service.entity_repository, "upsert_entities", side_effect=mock_upsert_entities
     ):
         # Sync should raise SyncFatalError and terminate immediately
         with pytest.raises(SyncFatalError, match="project_id=99999 does not exist"):