From 1d4b1fdcddced50a784da7231be34f1d1c229320 Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 12:22:58 -0500 Subject: [PATCH 01/14] wip: checkpoint vector sync batching pass Signed-off-by: phernandez --- .../repository/postgres_search_repository.py | 302 ++---- .../repository/search_repository_base.py | 969 +++++++++++------- .../repository/sqlite_search_repository.py | 22 +- src/basic_memory/telemetry.py | 56 + .../test_postgres_search_repository.py | 1 + .../test_postgres_search_repository_unit.py | 137 ++- tests/repository/test_semantic_search_base.py | 260 ++++- .../test_sqlite_vector_search_repository.py | 97 +- tests/test_telemetry.py | 56 + 9 files changed, 1211 insertions(+), 689 deletions(-) diff --git a/src/basic_memory/repository/postgres_search_repository.py b/src/basic_memory/repository/postgres_search_repository.py index 14f56a69..eb2ffa86 100644 --- a/src/basic_memory/repository/postgres_search_repository.py +++ b/src/basic_memory/repository/postgres_search_repository.py @@ -3,9 +3,8 @@ import asyncio import json import re -import time from datetime import datetime -from typing import List, Optional, cast +from typing import List, Optional from loguru import logger from sqlalchemy import text @@ -18,7 +17,7 @@ from basic_memory.repository.search_index_row import SearchIndexRow from basic_memory.repository.search_repository_base import ( SearchRepositoryBase, - _PreparedEntityVectorSync, + VectorChunkState, ) from basic_memory.repository.metadata_filters import parse_metadata_filters from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError @@ -458,247 +457,70 @@ def _vector_prepare_window_size(self) -> int: """Use a bounded config-driven prepare window for Postgres vector sync.""" return self._semantic_postgres_prepare_concurrency - async def _prepare_entity_vector_jobs_window( - self, entity_ids: list[int] - ) -> list[_PreparedEntityVectorSync | BaseException]: - """Prepare one Postgres window concurrently to hide DB round-trip latency.""" - prepared_window = await asyncio.gather( - *(self._prepare_entity_vector_jobs(entity_id) for entity_id in entity_ids), - return_exceptions=True, - ) - return [ - cast(_PreparedEntityVectorSync | BaseException, prepared) - for prepared in prepared_window - ] - - async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVectorSync: - """Prepare chunk mutations with Postgres-specific bulk upserts.""" - sync_start = time.perf_counter() - - async with db.scoped_session(self.session_maker) as session: - await self._prepare_vector_session(session) - - row_result = await session.execute( - text( - "SELECT id, type, title, permalink, content_stems, content_snippet, " - "category, relation_type " - "FROM search_index " - "WHERE entity_id = :entity_id AND project_id = :project_id " - "ORDER BY " - "CASE type " - "WHEN :entity_type THEN 0 " - "WHEN :observation_type THEN 1 " - "WHEN :relation_type_type THEN 2 " - "ELSE 3 END, id ASC" - ), - { - "entity_id": entity_id, - "project_id": self.project_id, - "entity_type": SearchItemType.ENTITY.value, - "observation_type": SearchItemType.OBSERVATION.value, - "relation_type_type": SearchItemType.RELATION.value, - }, - ) - rows = row_result.fetchall() - source_rows_count = len(rows) - - if not rows: - await self._delete_entity_chunks(session, entity_id) - await session.commit() - prepare_seconds = time.perf_counter() - sync_start - return _PreparedEntityVectorSync( - entity_id=entity_id, - sync_start=sync_start, - source_rows_count=source_rows_count, - embedding_jobs=[], - prepare_seconds=prepare_seconds, - ) - - chunk_records = self._build_chunk_records(rows) - built_chunk_records_count = len(chunk_records) - current_entity_fingerprint = self._build_entity_fingerprint(chunk_records) - current_embedding_model = self._embedding_model_key() - if not chunk_records: - await self._delete_entity_chunks(session, entity_id) - await session.commit() - prepare_seconds = time.perf_counter() - sync_start - return _PreparedEntityVectorSync( - entity_id=entity_id, - sync_start=sync_start, - source_rows_count=source_rows_count, - embedding_jobs=[], - prepare_seconds=prepare_seconds, - ) + async def _upsert_scheduled_chunk_records( + self, + session: AsyncSession, + *, + entity_id: int, + scheduled_records: list[dict[str, str]], + existing_by_key: dict[str, VectorChunkState], + entity_fingerprint: str, + embedding_model: str, + ) -> list[tuple[int, str]]: + """Use Postgres UPSERT to rewrite only the scheduled chunk rows.""" + if not scheduled_records: + return [] - existing_rows_result = await session.execute( - text( - "SELECT c.id, c.chunk_key, c.source_hash, c.entity_fingerprint, " - "c.embedding_model, " - "(e.chunk_id IS NOT NULL) AS has_embedding " - "FROM search_vector_chunks c " - "LEFT JOIN search_vector_embeddings e ON e.chunk_id = c.id " - "WHERE c.project_id = :project_id AND c.entity_id = :entity_id" - ), - {"project_id": self.project_id, "entity_id": entity_id}, - ) - existing_rows = existing_rows_result.mappings().all() - existing_by_key = {str(row["chunk_key"]): row for row in existing_rows} - existing_chunks_count = len(existing_by_key) - incoming_chunk_keys = {record["chunk_key"] for record in chunk_records} - - stale_ids = [ - int(row["id"]) - for chunk_key, row in existing_by_key.items() - if chunk_key not in incoming_chunk_keys - ] - stale_chunks_count = len(stale_ids) - if stale_ids: - await self._delete_stale_chunks(session, stale_ids, entity_id) - - orphan_ids = {int(row["id"]) for row in existing_rows if not bool(row["has_embedding"])} - orphan_chunks_count = len(orphan_ids) - - skip_unchanged_entity = ( - existing_chunks_count == built_chunk_records_count - and stale_chunks_count == 0 - and orphan_chunks_count == 0 - and existing_chunks_count > 0 - and all( - row["entity_fingerprint"] == current_entity_fingerprint - and row["embedding_model"] == current_embedding_model - for row in existing_rows - ) + upsert_params: dict[str, object] = { + "project_id": self.project_id, + "entity_id": entity_id, + } + upsert_values: list[str] = [] + # The SQL template is built from integer enumerate() indices only. + # No user-controlled text is interpolated into the statement. + for index, record in enumerate(scheduled_records): + upsert_params[f"chunk_key_{index}"] = record["chunk_key"] + upsert_params[f"chunk_text_{index}"] = record["chunk_text"] + upsert_params[f"source_hash_{index}"] = record["source_hash"] + upsert_params[f"entity_fingerprint_{index}"] = entity_fingerprint + upsert_params[f"embedding_model_{index}"] = embedding_model + upsert_values.append( + "(" + ":entity_id, :project_id, " + f":chunk_key_{index}, :chunk_text_{index}, :source_hash_{index}, " + f":entity_fingerprint_{index}, :embedding_model_{index}, NOW()" + ")" ) - if skip_unchanged_entity: - prepare_seconds = time.perf_counter() - sync_start - return _PreparedEntityVectorSync( - entity_id=entity_id, - sync_start=sync_start, - source_rows_count=source_rows_count, - embedding_jobs=[], - chunks_total=built_chunk_records_count, - chunks_skipped=built_chunk_records_count, - entity_skipped=True, - prepare_seconds=prepare_seconds, - ) - - pending_records: list[dict[str, str]] = [] - skipped_chunks_count = 0 - - for record in chunk_records: - current = existing_by_key.get(record["chunk_key"]) - if current is None: - pending_records.append(record) - continue - - row_id = int(current["id"]) - is_orphan = row_id in orphan_ids - same_source_hash = current["source_hash"] == record["source_hash"] - same_entity_fingerprint = ( - current["entity_fingerprint"] == current_entity_fingerprint - ) - same_embedding_model = current["embedding_model"] == current_embedding_model - - if same_source_hash and not is_orphan and same_embedding_model: - if not same_entity_fingerprint: - await session.execute( - text( - "UPDATE search_vector_chunks " - "SET entity_fingerprint = :entity_fingerprint, " - "embedding_model = :embedding_model, " - "updated_at = NOW() " - "WHERE id = :id" - ), - { - "id": row_id, - "entity_fingerprint": current_entity_fingerprint, - "embedding_model": current_embedding_model, - }, - ) - skipped_chunks_count += 1 - continue - pending_records.append(record) - - shard_plan = self._plan_entity_vector_shard(pending_records) - self._log_vector_shard_plan(entity_id=entity_id, shard_plan=shard_plan) - - scheduled_records = [ - record - for record in sorted(pending_records, key=lambda record: record["chunk_key"]) - if record["chunk_key"] in shard_plan.scheduled_chunk_keys - ] - - embedding_jobs: list[tuple[int, str]] = [] - upsert_records = list(scheduled_records) - - if upsert_records: - upsert_params: dict[str, object] = { - "project_id": self.project_id, - "entity_id": entity_id, - } - upsert_values: list[str] = [] - # The SQL template is built from integer enumerate() indices only. - # No user-controlled text is interpolated into the statement. - for index, record in enumerate(upsert_records): - upsert_params[f"chunk_key_{index}"] = record["chunk_key"] - upsert_params[f"chunk_text_{index}"] = record["chunk_text"] - upsert_params[f"source_hash_{index}"] = record["source_hash"] - upsert_params[f"entity_fingerprint_{index}"] = current_entity_fingerprint - upsert_params[f"embedding_model_{index}"] = current_embedding_model - upsert_values.append( - "(" - ":entity_id, :project_id, " - f":chunk_key_{index}, :chunk_text_{index}, :source_hash_{index}, " - f":entity_fingerprint_{index}, :embedding_model_{index}, NOW()" - ")" - ) - - upsert_result = await session.execute( - text(f""" - INSERT INTO search_vector_chunks ( - entity_id, - project_id, - chunk_key, - chunk_text, - source_hash, - entity_fingerprint, - embedding_model, - updated_at - ) VALUES {", ".join(upsert_values)} - ON CONFLICT (project_id, entity_id, chunk_key) DO UPDATE SET - chunk_text = EXCLUDED.chunk_text, - source_hash = EXCLUDED.source_hash, - entity_fingerprint = EXCLUDED.entity_fingerprint, - embedding_model = EXCLUDED.embedding_model, - updated_at = NOW() - RETURNING id, chunk_key - """), - upsert_params, - ) - upserted_ids_by_key = { - str(row["chunk_key"]): int(row["id"]) for row in upsert_result.mappings().all() - } - for record in upsert_records: - row_id = upserted_ids_by_key[record["chunk_key"]] - embedding_jobs.append((row_id, record["chunk_text"])) - - prepare_seconds = time.perf_counter() - sync_start - return _PreparedEntityVectorSync( - entity_id=entity_id, - sync_start=sync_start, - source_rows_count=source_rows_count, - embedding_jobs=embedding_jobs, - chunks_total=built_chunk_records_count, - chunks_skipped=skipped_chunks_count, - entity_complete=shard_plan.entity_complete, - oversized_entity=shard_plan.oversized_entity, - pending_jobs_total=shard_plan.pending_jobs_total, - shard_index=shard_plan.shard_index, - shard_count=shard_plan.shard_count, - remaining_jobs_after_shard=shard_plan.remaining_jobs_after_shard, - prepare_seconds=prepare_seconds, + upsert_result = await session.execute( + text(f""" + INSERT INTO search_vector_chunks ( + entity_id, + project_id, + chunk_key, + chunk_text, + source_hash, + entity_fingerprint, + embedding_model, + updated_at + ) VALUES {", ".join(upsert_values)} + ON CONFLICT (project_id, entity_id, chunk_key) DO UPDATE SET + chunk_text = EXCLUDED.chunk_text, + source_hash = EXCLUDED.source_hash, + entity_fingerprint = EXCLUDED.entity_fingerprint, + embedding_model = EXCLUDED.embedding_model, + updated_at = NOW() + RETURNING id, chunk_key + """), + upsert_params, ) + upserted_ids_by_key = { + str(row["chunk_key"]): int(row["id"]) for row in upsert_result.mappings().all() + } + return [ + (upserted_ids_by_key[record["chunk_key"]], record["chunk_text"]) + for record in scheduled_records + ] async def _write_embeddings( self, diff --git a/src/basic_memory/repository/search_repository_base.py b/src/basic_memory/repository/search_repository_base.py index f252df05..298767ec 100644 --- a/src/basic_memory/repository/search_repository_base.py +++ b/src/basic_memory/repository/search_repository_base.py @@ -1,11 +1,13 @@ """Abstract base class for search repository implementations.""" +import asyncio import hashlib import json import math import re import time from abc import ABC, abstractmethod +from contextlib import asynccontextmanager from dataclasses import dataclass, field, replace from datetime import datetime from typing import Any, Callable, Dict, List, Optional @@ -14,7 +16,7 @@ from sqlalchemy import Executable, Result, text from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker -from basic_memory import db +from basic_memory import db, telemetry from basic_memory.repository.embedding_provider import EmbeddingProvider from basic_memory.repository.search_index_row import SearchIndexRow from basic_memory.repository.semantic_errors import ( @@ -74,6 +76,7 @@ class _PreparedEntityVectorSync: shard_count: int = 1 remaining_jobs_after_shard: int = 0 prepare_seconds: float = 0.0 + queue_start: float | None = None @dataclass @@ -90,6 +93,7 @@ class _EntitySyncRuntime: """Per-entity runtime counters used while flushes are in flight.""" sync_start: float + queue_start: float source_rows_count: int embedding_jobs_count: int remaining_jobs: int @@ -120,6 +124,18 @@ class _EntityVectorShardPlan: entity_complete: bool +@dataclass(frozen=True) +class VectorChunkState: + """Existing vector chunk state fetched for one prepare window.""" + + id: int + chunk_key: str + source_hash: str + entity_fingerprint: str + embedding_model: str + has_embedding: bool + + class SearchRepositoryBase(ABC): """Abstract base class for backend-specific search repository implementations. @@ -781,6 +797,8 @@ async def _sync_entity_vectors_internal( ) if total_entities == 0: return result + batch_start = time.perf_counter() + backend_name = type(self).__name__.removesuffix("SearchRepository").lower() logger.info( "Vector batch sync start: project_id={project_id} entities_total={entities_total} " @@ -798,59 +816,90 @@ async def _sync_entity_vectors_internal( synced_entity_ids: set[int] = set() prepare_window_size = self._vector_prepare_window_size() - for window_start in range(0, total_entities, prepare_window_size): - window_entity_ids = entity_ids[window_start : window_start + prepare_window_size] - - if progress_callback is not None: - # Trigger: Postgres prepares one bounded entity window concurrently. - # Why: callbacks still need per-entity progress positions before the gather starts. - # Outcome: progress advances in prepare_window_size bursts instead of strict one-by-one. - for offset, entity_id in enumerate(window_entity_ids, start=window_start): - progress_callback(entity_id, offset, total_entities) + with telemetry.started_span( + "basic_memory.vector_sync.batch", + project_id=self.project_id, + backend=backend_name, + entities_total=total_entities, + window_size=prepare_window_size, + ) as batch_span: + for window_start in range(0, total_entities, prepare_window_size): + window_entity_ids = entity_ids[window_start : window_start + prepare_window_size] + + if progress_callback is not None: + # Trigger: prepare runs in bounded windows instead of strict one-by-one order. + # Why: callbacks still need deterministic per-entity positions before the window starts. + # Outcome: progress advances in prepare_window_size bursts. + for offset, entity_id in enumerate(window_entity_ids, start=window_start): + progress_callback(entity_id, offset, total_entities) + + prepared_window = await self._prepare_entity_vector_jobs_window(window_entity_ids) + + for entity_id, prepared in zip(window_entity_ids, prepared_window, strict=True): + if isinstance(prepared, BaseException): + if not continue_on_error: + raise prepared + failed_entity_ids.add(entity_id) + logger.warning( + "Vector batch sync entity prepare failed: project_id={project_id} " + "entity_id={entity_id} error={error}", + project_id=self.project_id, + entity_id=entity_id, + error=str(prepared), + ) + continue - prepared_window = await self._prepare_entity_vector_jobs_window(window_entity_ids) + embedding_jobs_count = len(prepared.embedding_jobs) + result.chunks_total += prepared.chunks_total + result.chunks_skipped += prepared.chunks_skipped + if prepared.entity_skipped: + result.entities_skipped += 1 + result.embedding_jobs_total += embedding_jobs_count + result.prepare_seconds_total += prepared.prepare_seconds + + if embedding_jobs_count == 0: + if prepared.entity_complete: + synced_entity_ids.add(entity_id) + else: + deferred_entity_ids.add(entity_id) + total_seconds = time.perf_counter() - prepared.sync_start + # Trigger: this entity never entered the shared embedding queue. + # Why: queue wait should track real flush contention only. + # Outcome: skip-only and delete-only entities report queue_wait ~= 0. + queue_wait_seconds = 0.0 + self._log_vector_sync_complete( + entity_id=entity_id, + total_seconds=total_seconds, + prepare_seconds=prepared.prepare_seconds, + queue_wait_seconds=queue_wait_seconds, + embed_seconds=0.0, + write_seconds=0.0, + source_rows_count=prepared.source_rows_count, + chunks_total=prepared.chunks_total, + chunks_skipped=prepared.chunks_skipped, + embedding_jobs_count=0, + entity_skipped=prepared.entity_skipped, + entity_complete=prepared.entity_complete, + oversized_entity=prepared.oversized_entity, + pending_jobs_total=prepared.pending_jobs_total, + shard_index=prepared.shard_index, + shard_count=prepared.shard_count, + remaining_jobs_after_shard=prepared.remaining_jobs_after_shard, + ) + continue - for entity_id, prepared in zip(window_entity_ids, prepared_window, strict=True): - if isinstance(prepared, BaseException): - if not continue_on_error: - raise prepared - failed_entity_ids.add(entity_id) - logger.warning( - "Vector batch sync entity prepare failed: project_id={project_id} " - "entity_id={entity_id} error={error}", - project_id=self.project_id, - entity_id=entity_id, - error=str(prepared), - ) - continue - - embedding_jobs_count = len(prepared.embedding_jobs) - result.chunks_total += prepared.chunks_total - result.chunks_skipped += prepared.chunks_skipped - if prepared.entity_skipped: - result.entities_skipped += 1 - result.embedding_jobs_total += embedding_jobs_count - result.prepare_seconds_total += prepared.prepare_seconds - - if embedding_jobs_count == 0: - if prepared.entity_complete: - synced_entity_ids.add(entity_id) - else: - deferred_entity_ids.add(entity_id) - total_seconds = time.perf_counter() - prepared.sync_start - queue_wait_seconds = max(0.0, total_seconds - prepared.prepare_seconds) - result.queue_wait_seconds_total += queue_wait_seconds - self._log_vector_sync_complete( - entity_id=entity_id, - total_seconds=total_seconds, - prepare_seconds=prepared.prepare_seconds, - queue_wait_seconds=queue_wait_seconds, - embed_seconds=0.0, - write_seconds=0.0, + entity_runtime[entity_id] = _EntitySyncRuntime( + sync_start=prepared.sync_start, + queue_start=( + prepared.queue_start + if prepared.queue_start is not None + else prepared.sync_start + prepared.prepare_seconds + ), source_rows_count=prepared.source_rows_count, + embedding_jobs_count=embedding_jobs_count, + remaining_jobs=embedding_jobs_count, chunks_total=prepared.chunks_total, chunks_skipped=prepared.chunks_skipped, - embedding_jobs_count=0, entity_skipped=prepared.entity_skipped, entity_complete=prepared.entity_complete, oversized_entity=prepared.oversized_entity, @@ -858,99 +907,86 @@ async def _sync_entity_vectors_internal( shard_index=prepared.shard_index, shard_count=prepared.shard_count, remaining_jobs_after_shard=prepared.remaining_jobs_after_shard, + prepare_seconds=prepared.prepare_seconds, ) - continue - - entity_runtime[entity_id] = _EntitySyncRuntime( - sync_start=prepared.sync_start, - source_rows_count=prepared.source_rows_count, - embedding_jobs_count=embedding_jobs_count, - remaining_jobs=embedding_jobs_count, - chunks_total=prepared.chunks_total, - chunks_skipped=prepared.chunks_skipped, - entity_skipped=prepared.entity_skipped, - entity_complete=prepared.entity_complete, - oversized_entity=prepared.oversized_entity, - pending_jobs_total=prepared.pending_jobs_total, - shard_index=prepared.shard_index, - shard_count=prepared.shard_count, - remaining_jobs_after_shard=prepared.remaining_jobs_after_shard, - prepare_seconds=prepared.prepare_seconds, - ) - pending_jobs.extend( - _PendingEmbeddingJob( - entity_id=entity_id, chunk_row_id=row_id, chunk_text=chunk_text + pending_jobs.extend( + _PendingEmbeddingJob( + entity_id=entity_id, chunk_row_id=row_id, chunk_text=chunk_text + ) + for row_id, chunk_text in prepared.embedding_jobs ) - for row_id, chunk_text in prepared.embedding_jobs - ) - while len(pending_jobs) >= self._semantic_embedding_sync_batch_size: - flush_jobs = pending_jobs[: self._semantic_embedding_sync_batch_size] - pending_jobs = pending_jobs[self._semantic_embedding_sync_batch_size :] - try: - embed_seconds, write_seconds = await self._flush_embedding_jobs( - flush_jobs=flush_jobs, - entity_runtime=entity_runtime, - synced_entity_ids=synced_entity_ids, - ) - result.embed_seconds_total += embed_seconds - result.write_seconds_total += write_seconds - (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( - entity_runtime=entity_runtime, - synced_entity_ids=synced_entity_ids, - deferred_entity_ids=deferred_entity_ids, - ) - except Exception as exc: - if not continue_on_error: - raise - affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) - failed_entity_ids.update(affected_entity_ids) - synced_entity_ids.difference_update(affected_entity_ids) - deferred_entity_ids.difference_update(affected_entity_ids) - for failed_entity_id in affected_entity_ids: - entity_runtime.pop(failed_entity_id, None) - logger.warning( - "Vector batch sync flush failed: project_id={project_id} " - "affected_entities={affected_entities} chunk_count={chunk_count} error={error}", - project_id=self.project_id, - affected_entities=affected_entity_ids, - chunk_count=len(flush_jobs), - error=str(exc), - ) + while len(pending_jobs) >= self._semantic_embedding_sync_batch_size: + flush_jobs = pending_jobs[: self._semantic_embedding_sync_batch_size] + pending_jobs = pending_jobs[self._semantic_embedding_sync_batch_size :] + try: + embed_seconds, write_seconds = await self._flush_embedding_jobs( + flush_jobs=flush_jobs, + entity_runtime=entity_runtime, + synced_entity_ids=synced_entity_ids, + ) + result.embed_seconds_total += embed_seconds + result.write_seconds_total += write_seconds + ( + result.queue_wait_seconds_total + ) += self._finalize_completed_entity_syncs( + entity_runtime=entity_runtime, + synced_entity_ids=synced_entity_ids, + deferred_entity_ids=deferred_entity_ids, + ) + except Exception as exc: + if not continue_on_error: + raise + affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) + failed_entity_ids.update(affected_entity_ids) + synced_entity_ids.difference_update(affected_entity_ids) + deferred_entity_ids.difference_update(affected_entity_ids) + for failed_entity_id in affected_entity_ids: + entity_runtime.pop(failed_entity_id, None) + logger.warning( + "Vector batch sync flush failed: project_id={project_id} " + "affected_entities={affected_entities} " + "chunk_count={chunk_count} error={error}", + project_id=self.project_id, + affected_entities=affected_entity_ids, + chunk_count=len(flush_jobs), + error=str(exc), + ) - if pending_jobs: - flush_jobs = list(pending_jobs) - pending_jobs = [] - try: - embed_seconds, write_seconds = await self._flush_embedding_jobs( - flush_jobs=flush_jobs, - entity_runtime=entity_runtime, - synced_entity_ids=synced_entity_ids, - ) - result.embed_seconds_total += embed_seconds - result.write_seconds_total += write_seconds - (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( - entity_runtime=entity_runtime, - synced_entity_ids=synced_entity_ids, - deferred_entity_ids=deferred_entity_ids, - ) - except Exception as exc: - if not continue_on_error: - raise - affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) - failed_entity_ids.update(affected_entity_ids) - synced_entity_ids.difference_update(affected_entity_ids) - deferred_entity_ids.difference_update(affected_entity_ids) - for failed_entity_id in affected_entity_ids: - entity_runtime.pop(failed_entity_id, None) - logger.warning( - "Vector batch sync final flush failed: project_id={project_id} " - "affected_entities={affected_entities} chunk_count={chunk_count} error={error}", - project_id=self.project_id, - affected_entities=affected_entity_ids, - chunk_count=len(flush_jobs), - error=str(exc), - ) + if pending_jobs: + flush_jobs = list(pending_jobs) + pending_jobs = [] + try: + embed_seconds, write_seconds = await self._flush_embedding_jobs( + flush_jobs=flush_jobs, + entity_runtime=entity_runtime, + synced_entity_ids=synced_entity_ids, + ) + result.embed_seconds_total += embed_seconds + result.write_seconds_total += write_seconds + (result.queue_wait_seconds_total) += self._finalize_completed_entity_syncs( + entity_runtime=entity_runtime, + synced_entity_ids=synced_entity_ids, + deferred_entity_ids=deferred_entity_ids, + ) + except Exception as exc: + if not continue_on_error: + raise + affected_entity_ids = sorted({job.entity_id for job in flush_jobs}) + failed_entity_ids.update(affected_entity_ids) + synced_entity_ids.difference_update(affected_entity_ids) + deferred_entity_ids.difference_update(affected_entity_ids) + for failed_entity_id in affected_entity_ids: + entity_runtime.pop(failed_entity_id, None) + logger.warning( + "Vector batch sync final flush failed: project_id={project_id} " + "affected_entities={affected_entities} chunk_count={chunk_count} " + "error={error}", + project_id=self.project_id, + affected_entities=affected_entity_ids, + chunk_count=len(flush_jobs), + error=str(exc), + ) # Trigger: this should never happen after all flushes succeed. # Why: remaining jobs mean runtime tracking drifted from queued jobs. @@ -999,263 +1035,363 @@ async def _sync_entity_vectors_internal( embed_seconds_total=result.embed_seconds_total, write_seconds_total=result.write_seconds_total, ) + batch_total_seconds = time.perf_counter() - batch_start + metric_attrs = { + "backend": backend_name, + "skip_only_batch": result.embedding_jobs_total == 0, + } + telemetry.record_histogram( + "vector_sync_batch_total_seconds", + batch_total_seconds, + unit="s", + **metric_attrs, + ) + telemetry.add_counter("vector_sync_entities_total", result.entities_total, **metric_attrs) + telemetry.add_counter( + "vector_sync_entities_skipped", + result.entities_skipped, + **metric_attrs, + ) + telemetry.add_counter( + "vector_sync_entities_deferred", + result.entities_deferred, + **metric_attrs, + ) + telemetry.add_counter( + "vector_sync_embedding_jobs_total", + result.embedding_jobs_total, + **metric_attrs, + ) + telemetry.add_counter("vector_sync_chunks_total", result.chunks_total, **metric_attrs) + telemetry.add_counter( + "vector_sync_chunks_skipped", + result.chunks_skipped, + **metric_attrs, + ) + if batch_span is not None: + batch_span.set_attributes( + { + "backend": backend_name, + "entities_synced": result.entities_synced, + "entities_failed": result.entities_failed, + "entities_deferred": result.entities_deferred, + "entities_skipped": result.entities_skipped, + "embedding_jobs_total": result.embedding_jobs_total, + "chunks_total": result.chunks_total, + "chunks_skipped": result.chunks_skipped, + "batch_total_seconds": batch_total_seconds, + } + ) return result def _vector_prepare_window_size(self) -> int: """Return the number of entities to prepare in one orchestration window.""" - return 1 + # Trigger: the shared window path now batches reads and then fans back out + # into per-entity prepare work. + # Why: SQLite benefits from concurrency too, but letting the default path + # explode to the full embed batch size creates unnecessary write contention. + # Outcome: local backends get a small bounded window, while Postgres keeps + # its explicit higher concurrency override. + return max(1, min(self._semantic_embedding_sync_batch_size, 8)) + + @asynccontextmanager + async def _prepare_entity_write_scope(self): + """Serialize the write-side prepare section when a backend needs it.""" + yield + + def _prepare_window_entity_params(self, entity_ids: list[int]) -> tuple[str, dict[str, object]]: + """Build deterministic bind params for one prepare window.""" + placeholders = ", ".join(f":entity_id_{index}" for index in range(len(entity_ids))) + params: dict[str, object] = {"project_id": self.project_id} + params.update( + {f"entity_id_{index}": entity_id for index, entity_id in enumerate(entity_ids)} + ) + return placeholders, params + + async def _fetch_prepare_window_source_rows( + self, + session: AsyncSession, + entity_ids: list[int], + ) -> dict[int, list[Any]]: + """Fetch all search_index rows needed for one prepare window.""" + grouped_rows: dict[int, list[Any]] = {entity_id: [] for entity_id in entity_ids} + if not entity_ids: + return grouped_rows + + placeholders, params = self._prepare_window_entity_params(entity_ids) + params.update( + { + "entity_type": SearchItemType.ENTITY.value, + "observation_type": SearchItemType.OBSERVATION.value, + "relation_type_type": SearchItemType.RELATION.value, + } + ) + result = await session.execute( + text( + "SELECT entity_id, id, type, title, permalink, content_stems, content_snippet, " + "category, relation_type " + "FROM search_index " + f"WHERE project_id = :project_id AND entity_id IN ({placeholders}) " + "ORDER BY entity_id ASC, " + "CASE type " + "WHEN :entity_type THEN 0 " + "WHEN :observation_type THEN 1 " + "WHEN :relation_type_type THEN 2 " + "ELSE 3 END, id ASC" + ), + params, + ) + for row in result.fetchall(): + grouped_rows.setdefault(int(row.entity_id), []).append(row) + return grouped_rows + + def _prepare_window_existing_rows_sql(self, placeholders: str) -> str: + """SQL for existing chunk/embedding rows in one prepare window.""" + return ( + "SELECT c.entity_id, c.id, c.chunk_key, c.source_hash, c.entity_fingerprint, " + "c.embedding_model, (e.chunk_id IS NOT NULL) AS has_embedding " + "FROM search_vector_chunks c " + "LEFT JOIN search_vector_embeddings e ON e.chunk_id = c.id " + f"WHERE c.project_id = :project_id AND c.entity_id IN ({placeholders}) " + "ORDER BY c.entity_id ASC, c.chunk_key ASC" + ) + + async def _fetch_prepare_window_existing_rows( + self, + session: AsyncSession, + entity_ids: list[int], + ) -> dict[int, list[VectorChunkState]]: + """Fetch all persisted chunk state needed for one prepare window.""" + grouped_rows: dict[int, list[VectorChunkState]] = { + entity_id: [] for entity_id in entity_ids + } + if not entity_ids: + return grouped_rows + + placeholders, params = self._prepare_window_entity_params(entity_ids) + result = await session.execute( + text(self._prepare_window_existing_rows_sql(placeholders)), params + ) + for row in result.mappings().all(): + grouped_rows.setdefault(int(row["entity_id"]), []).append( + VectorChunkState( + id=int(row["id"]), + chunk_key=str(row["chunk_key"]), + source_hash=str(row["source_hash"]), + entity_fingerprint=str(row["entity_fingerprint"]), + embedding_model=str(row["embedding_model"]), + has_embedding=bool(row["has_embedding"]), + ) + ) + return grouped_rows async def _prepare_entity_vector_jobs_window( self, entity_ids: list[int] ) -> list[_PreparedEntityVectorSync | BaseException]: - """Prepare one window of entity vector jobs. + """Prepare one window of entity vector jobs with shared read-side batching.""" + if not entity_ids: + return [] - Default implementation is sequential to preserve backend behavior. - Postgres overrides this to use a bounded concurrent gather. - """ - prepared_window: list[_PreparedEntityVectorSync | BaseException] = [] - for entity_id in entity_ids: - try: - prepared_window.append(await self._prepare_entity_vector_jobs(entity_id)) - except Exception as exc: - prepared_window.append(exc) - return prepared_window + try: + async with db.scoped_session(self.session_maker) as session: + await self._prepare_vector_session(session) + source_rows_by_entity = await self._fetch_prepare_window_source_rows( + session, entity_ids + ) + existing_rows_by_entity = await self._fetch_prepare_window_existing_rows( + session, entity_ids + ) + except Exception as exc: + return [exc for _ in entity_ids] + + # Trigger: prepare now does one shared read pass per window instead of + # paying the same select/join round-trips per entity. + # Why: both SQLite and Postgres were still burning wall clock in read-side + # fingerprint/orphan checks even when every entity ended up skipped. + # Outcome: we batch the reads once, then fan back out over entities while + # preserving input order in the gathered results. + prepared_window = await asyncio.gather( + *( + self._prepare_entity_vector_jobs_prefetched( + entity_id=entity_id, + source_rows=source_rows_by_entity.get(entity_id, []), + existing_rows=existing_rows_by_entity.get(entity_id, []), + ) + for entity_id in entity_ids + ), + return_exceptions=True, + ) + return list(prepared_window) async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVectorSync: """Prepare chunk mutations and embedding jobs for one entity.""" - sync_start = time.perf_counter() - - async with db.scoped_session(self.session_maker) as session: - await self._prepare_vector_session(session) + prepared_window = await self._prepare_entity_vector_jobs_window([entity_id]) + prepared = prepared_window[0] + if isinstance(prepared, BaseException): + raise prepared + return prepared - row_result = await session.execute( - text( - "SELECT id, type, title, permalink, content_stems, content_snippet, " - "category, relation_type " - "FROM search_index " - "WHERE entity_id = :entity_id AND project_id = :project_id " - "ORDER BY " - "CASE type " - "WHEN :entity_type THEN 0 " - "WHEN :observation_type THEN 1 " - "WHEN :relation_type_type THEN 2 " - "ELSE 3 END, id ASC" - ), - { - "entity_id": entity_id, - "project_id": self.project_id, - "entity_type": SearchItemType.ENTITY.value, - "observation_type": SearchItemType.OBSERVATION.value, - "relation_type_type": SearchItemType.RELATION.value, - }, + async def _prepare_entity_vector_jobs_prefetched( + self, + *, + entity_id: int, + source_rows: list[Any], + existing_rows: list[VectorChunkState], + ) -> _PreparedEntityVectorSync: + """Prepare one entity using prefetched window rows.""" + sync_start = time.perf_counter() + prepare_start = sync_start + source_rows_count = len(source_rows) + + if not source_rows: + async with self._prepare_entity_write_scope(): + async with db.scoped_session(self.session_maker) as session: + await self._prepare_vector_session(session) + await self._delete_entity_chunks(session, entity_id) + await session.commit() + prepare_seconds = time.perf_counter() - prepare_start + return _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=sync_start, + source_rows_count=source_rows_count, + embedding_jobs=[], + prepare_seconds=prepare_seconds, ) - rows = row_result.fetchall() - source_rows_count = len(rows) - built_chunk_records_count = 0 - - # No search_index rows → delete all chunk/embedding data for this entity. - if not rows: - await self._delete_entity_chunks(session, entity_id) - await session.commit() - prepare_seconds = time.perf_counter() - sync_start - return _PreparedEntityVectorSync( - entity_id=entity_id, - sync_start=sync_start, - source_rows_count=source_rows_count, - embedding_jobs=[], - prepare_seconds=prepare_seconds, - ) - chunk_records = self._build_chunk_records(rows) - built_chunk_records_count = len(chunk_records) - current_entity_fingerprint = self._build_entity_fingerprint(chunk_records) - current_embedding_model = self._embedding_model_key() - if not chunk_records: - await self._delete_entity_chunks(session, entity_id) - await session.commit() - prepare_seconds = time.perf_counter() - sync_start - return _PreparedEntityVectorSync( - entity_id=entity_id, - sync_start=sync_start, - source_rows_count=source_rows_count, - embedding_jobs=[], - prepare_seconds=prepare_seconds, - ) - - # --- Diff existing chunks against incoming --- - existing_rows_result = await session.execute( - text( - "SELECT id, chunk_key, source_hash, entity_fingerprint, embedding_model " - "FROM search_vector_chunks " - "WHERE project_id = :project_id AND entity_id = :entity_id" - ), - {"project_id": self.project_id, "entity_id": entity_id}, + chunk_records = self._build_chunk_records(source_rows) + built_chunk_records_count = len(chunk_records) + if not chunk_records: + async with self._prepare_entity_write_scope(): + async with db.scoped_session(self.session_maker) as session: + await self._prepare_vector_session(session) + await self._delete_entity_chunks(session, entity_id) + await session.commit() + prepare_seconds = time.perf_counter() - prepare_start + return _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=sync_start, + source_rows_count=source_rows_count, + embedding_jobs=[], + prepare_seconds=prepare_seconds, ) - existing_by_key = {row.chunk_key: row for row in existing_rows_result.fetchall()} - existing_chunks_count = len(existing_by_key) - incoming_hashes = { - record["chunk_key"]: record["source_hash"] for record in chunk_records - } - stale_ids = [ - int(row.id) - for chunk_key, row in existing_by_key.items() - if chunk_key not in incoming_hashes - ] - stale_chunks_count = len(stale_ids) - - if stale_ids: - await self._delete_stale_chunks(session, stale_ids, entity_id) - - # --- Orphan cleanup: chunks without corresponding embeddings --- - # Trigger: a previous sync crashed between chunk insert and embedding write. - # Why: self-healing on next sync prevents permanent data skew. - # Outcome: orphaned chunks are re-embedded instead of silently dropped. - orphan_result = await session.execute( - text(self._orphan_detection_sql()), - {"project_id": self.project_id, "entity_id": entity_id}, + + current_entity_fingerprint = self._build_entity_fingerprint(chunk_records) + current_embedding_model = self._embedding_model_key() + existing_by_key = {row.chunk_key: row for row in existing_rows} + incoming_chunk_keys = {record["chunk_key"] for record in chunk_records} + stale_ids = [ + row.id + for chunk_key, row in existing_by_key.items() + if chunk_key not in incoming_chunk_keys + ] + orphan_ids = {row.id for row in existing_rows if not row.has_embedding} + + # Trigger: all persisted chunk metadata already matches this entity's + # current fingerprint/model and every chunk still has an embedding. + # Why: unchanged entities should stop in prepare instead of paying write + # or queue accounting they never actually used. + # Outcome: skip-only entities return immediately with zero embedding jobs. + skip_unchanged_entity = ( + len(existing_rows) == built_chunk_records_count + and not stale_ids + and not orphan_ids + and bool(existing_rows) + and all( + row.entity_fingerprint == current_entity_fingerprint + and row.embedding_model == current_embedding_model + for row in existing_rows ) - orphan_rows = orphan_result.fetchall() - orphan_ids = {int(row.id) for row in orphan_rows} - orphan_chunks_count = len(orphan_ids) - - # Trigger: the persisted chunk metadata exactly matches the current - # semantic fingerprint/model and every chunk still has an embedding. - # Why: full reindex and embeddings-only runs should avoid reopening - # the expensive chunk diff + embed path for unchanged entities. - # Outcome: return immediately with skip counters and no writes. - skip_unchanged_entity = ( - existing_chunks_count == built_chunk_records_count - and stale_chunks_count == 0 - and orphan_chunks_count == 0 - and existing_chunks_count > 0 - and all( - row.entity_fingerprint == current_entity_fingerprint - and row.embedding_model == current_embedding_model - for row in existing_by_key.values() - ) + ) + if skip_unchanged_entity: + prepare_seconds = time.perf_counter() - prepare_start + return _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=sync_start, + source_rows_count=source_rows_count, + embedding_jobs=[], + chunks_total=built_chunk_records_count, + chunks_skipped=built_chunk_records_count, + entity_skipped=True, + prepare_seconds=prepare_seconds, ) - if skip_unchanged_entity: - prepare_seconds = time.perf_counter() - sync_start - return _PreparedEntityVectorSync( - entity_id=entity_id, - sync_start=sync_start, - source_rows_count=source_rows_count, - embedding_jobs=[], - chunks_total=built_chunk_records_count, - chunks_skipped=built_chunk_records_count, - entity_skipped=True, - prepare_seconds=prepare_seconds, - ) - timestamp_expr = self._timestamp_now_expr() - pending_records: list[dict[str, str]] = [] - skipped_chunks_count = 0 - for record in chunk_records: - current = existing_by_key.get(record["chunk_key"]) + timestamp_expr = self._timestamp_now_expr() + metadata_update_ids: list[int] = [] + pending_records: list[dict[str, str]] = [] + skipped_chunks_count = 0 + for record in chunk_records: + current = existing_by_key.get(record["chunk_key"]) + if current is None: + pending_records.append(record) + continue + + same_source_hash = current.source_hash == record["source_hash"] + same_entity_fingerprint = current.entity_fingerprint == current_entity_fingerprint + same_embedding_model = current.embedding_model == current_embedding_model - # Trigger: chunk exists and hash matches (no content change) - # but chunk has no embedding (orphan from crash). - # Outcome: schedule re-embedding without touching chunk metadata. - is_orphan = current and int(current.id) in orphan_ids - if current: - row_id = int(current.id) - same_source_hash = current.source_hash == record["source_hash"] - same_entity_fingerprint = ( - current.entity_fingerprint == current_entity_fingerprint - ) - same_embedding_model = current.embedding_model == current_embedding_model - - if same_source_hash and not is_orphan and same_embedding_model: - if not same_entity_fingerprint: - await session.execute( - text( - "UPDATE search_vector_chunks " - "SET entity_fingerprint = :entity_fingerprint, " - "embedding_model = :embedding_model, " - f"updated_at = {timestamp_expr} " - "WHERE id = :id" - ), - { - "id": row_id, - "entity_fingerprint": current_entity_fingerprint, - "embedding_model": current_embedding_model, - }, - ) - skipped_chunks_count += 1 - continue + if same_source_hash and current.id not in orphan_ids and same_embedding_model: + if not same_entity_fingerprint: + metadata_update_ids.append(current.id) + skipped_chunks_count += 1 + continue - pending_records.append(record) + pending_records.append(record) - shard_plan = self._plan_entity_vector_shard(pending_records) - self._log_vector_shard_plan(entity_id=entity_id, shard_plan=shard_plan) + shard_plan = self._plan_entity_vector_shard(pending_records) + self._log_vector_shard_plan(entity_id=entity_id, shard_plan=shard_plan) - # Trigger: oversized entities can accumulate thousands of pending chunks. - # Why: scheduling only one deterministic shard bounds memory and wall clock. - # Outcome: future runs resume from the remaining chunk rows without redoing completed work. - scheduled_records = [ - record - for record in sorted(pending_records, key=lambda record: record["chunk_key"]) - if record["chunk_key"] in shard_plan.scheduled_chunk_keys - ] + # Trigger: oversized entities can still produce many changed chunks even + # after the read side is batched. + # Why: we still need the existing shard cap so one entity cannot monopolize + # a sync run. + # Outcome: batching removes read overhead without changing deferred semantics. + scheduled_records = [ + record + for record in sorted(pending_records, key=lambda record: record["chunk_key"]) + if record["chunk_key"] in shard_plan.scheduled_chunk_keys + ] - # --- Upsert scheduled changed / new chunks, collect embedding jobs --- - embedding_jobs: list[tuple[int, str]] = [] - for record in scheduled_records: - current = existing_by_key.get(record["chunk_key"]) - if current: - row_id = int(current.id) - if ( - current.source_hash != record["source_hash"] - or current.entity_fingerprint != current_entity_fingerprint - or current.embedding_model != current_embedding_model - ): + embedding_jobs: list[tuple[int, str]] = [] + if stale_ids or metadata_update_ids or scheduled_records: + # Trigger: prepare needs to mutate chunk rows for this entity. + # Why: Postgres can keep these write-side steps concurrent, while + # SQLite should funnel them through one writer even after the shared + # read window fan-out. + # Outcome: backends share the batched read path without forcing + # SQLite into unnecessary concurrent write transactions. + async with self._prepare_entity_write_scope(): + async with db.scoped_session(self.session_maker) as session: + await self._prepare_vector_session(session) + if stale_ids: + await self._delete_stale_chunks(session, stale_ids, entity_id) + for row_id in metadata_update_ids: await session.execute( text( "UPDATE search_vector_chunks " - "SET chunk_text = :chunk_text, source_hash = :source_hash, " - "entity_fingerprint = :entity_fingerprint, " + "SET entity_fingerprint = :entity_fingerprint, " "embedding_model = :embedding_model, " f"updated_at = {timestamp_expr} " "WHERE id = :id" ), { "id": row_id, - "chunk_text": record["chunk_text"], - "source_hash": record["source_hash"], "entity_fingerprint": current_entity_fingerprint, "embedding_model": current_embedding_model, }, ) - embedding_jobs.append((row_id, record["chunk_text"])) - continue - - inserted = await session.execute( - text( - "INSERT INTO search_vector_chunks (" - "entity_id, project_id, chunk_key, chunk_text, source_hash, " - "entity_fingerprint, embedding_model, updated_at" - ") VALUES (" - f":entity_id, :project_id, :chunk_key, :chunk_text, :source_hash, " - ":entity_fingerprint, :embedding_model, " - f"{timestamp_expr}" - ") RETURNING id" - ), - { - "entity_id": entity_id, - "project_id": self.project_id, - "chunk_key": record["chunk_key"], - "chunk_text": record["chunk_text"], - "source_hash": record["source_hash"], - "entity_fingerprint": current_entity_fingerprint, - "embedding_model": current_embedding_model, - }, - ) - row_id = int(inserted.scalar_one()) - embedding_jobs.append((row_id, record["chunk_text"])) - await session.commit() + if scheduled_records: + embedding_jobs = await self._upsert_scheduled_chunk_records( + session, + entity_id=entity_id, + scheduled_records=scheduled_records, + existing_by_key=existing_by_key, + entity_fingerprint=current_entity_fingerprint, + embedding_model=current_embedding_model, + ) + await session.commit() - prepare_seconds = time.perf_counter() - sync_start + prepare_seconds = time.perf_counter() - prepare_start return _PreparedEntityVectorSync( entity_id=entity_id, sync_start=sync_start, @@ -1270,8 +1406,74 @@ async def _prepare_entity_vector_jobs(self, entity_id: int) -> _PreparedEntityVe shard_count=shard_plan.shard_count, remaining_jobs_after_shard=shard_plan.remaining_jobs_after_shard, prepare_seconds=prepare_seconds, + queue_start=time.perf_counter(), ) + async def _upsert_scheduled_chunk_records( + self, + session: AsyncSession, + *, + entity_id: int, + scheduled_records: list[dict[str, str]], + existing_by_key: dict[str, VectorChunkState], + entity_fingerprint: str, + embedding_model: str, + ) -> list[tuple[int, str]]: + """Upsert scheduled chunk rows and return embedding jobs.""" + timestamp_expr = self._timestamp_now_expr() + embedding_jobs: list[tuple[int, str]] = [] + for record in scheduled_records: + current = existing_by_key.get(record["chunk_key"]) + if current: + if ( + current.source_hash != record["source_hash"] + or current.entity_fingerprint != entity_fingerprint + or current.embedding_model != embedding_model + ): + await session.execute( + text( + "UPDATE search_vector_chunks " + "SET chunk_text = :chunk_text, source_hash = :source_hash, " + "entity_fingerprint = :entity_fingerprint, " + "embedding_model = :embedding_model, " + f"updated_at = {timestamp_expr} " + "WHERE id = :id" + ), + { + "id": current.id, + "chunk_text": record["chunk_text"], + "source_hash": record["source_hash"], + "entity_fingerprint": entity_fingerprint, + "embedding_model": embedding_model, + }, + ) + embedding_jobs.append((current.id, record["chunk_text"])) + continue + + inserted = await session.execute( + text( + "INSERT INTO search_vector_chunks (" + "entity_id, project_id, chunk_key, chunk_text, source_hash, " + "entity_fingerprint, embedding_model, updated_at" + ") VALUES (" + f":entity_id, :project_id, :chunk_key, :chunk_text, :source_hash, " + ":entity_fingerprint, :embedding_model, " + f"{timestamp_expr}" + ") RETURNING id" + ), + { + "entity_id": entity_id, + "project_id": self.project_id, + "chunk_key": record["chunk_key"], + "chunk_text": record["chunk_text"], + "source_hash": record["source_hash"], + "entity_fingerprint": entity_fingerprint, + "embedding_model": embedding_model, + }, + ) + embedding_jobs.append((int(inserted.scalar_one()), record["chunk_text"])) + return embedding_jobs + async def _flush_embedding_jobs( self, flush_jobs: list[_PendingEmbeddingJob], @@ -1336,13 +1538,16 @@ def _finalize_completed_entity_syncs( synced_entity_ids.add(entity_id) else: deferred_entity_ids.add(entity_id) - total_seconds = time.perf_counter() - runtime.sync_start + completed_at = time.perf_counter() + total_seconds = completed_at - runtime.sync_start + # Trigger: queue wait should represent time spent behind shared flush + # work after prepare finished. + # Why: skip-only entities never entered that queue, and mixed batches + # should only charge queue time to entities that actually waited. + # Outcome: skip-only batches stay near zero while real contention remains visible. queue_wait_seconds = max( 0.0, - total_seconds - - runtime.prepare_seconds - - runtime.embed_seconds - - runtime.write_seconds, + completed_at - runtime.queue_start - runtime.embed_seconds - runtime.write_seconds, ) queue_wait_seconds_total += queue_wait_seconds self._log_vector_sync_complete( @@ -1390,6 +1595,35 @@ def _log_vector_sync_complete( remaining_jobs_after_shard: int, ) -> None: """Log completion and slow-entity warnings with a consistent format.""" + backend_name = type(self).__name__.removesuffix("SearchRepository").lower() + metric_attrs = { + "backend": backend_name, + "skip_only_entity": entity_skipped and embedding_jobs_count == 0, + } + telemetry.record_histogram( + "vector_sync_prepare_seconds", + prepare_seconds, + unit="s", + **metric_attrs, + ) + telemetry.record_histogram( + "vector_sync_queue_wait_seconds", + queue_wait_seconds, + unit="s", + **metric_attrs, + ) + telemetry.record_histogram( + "vector_sync_embed_seconds", + embed_seconds, + unit="s", + **metric_attrs, + ) + telemetry.record_histogram( + "vector_sync_write_seconds", + write_seconds, + unit="s", + **metric_attrs, + ) if total_seconds > 10: logger.warning( "Vector sync slow entity: project_id={project_id} entity_id={entity_id} " @@ -1435,19 +1669,6 @@ def _timestamp_now_expr(self) -> str: """ return "CURRENT_TIMESTAMP" - def _orphan_detection_sql(self) -> str: - """SQL to find chunk rows without corresponding embeddings. - - Default implementation works for both backends; SQLite overrides - to reference the rowid-based embedding table layout. - """ - return ( - "SELECT c.id FROM search_vector_chunks c " - "LEFT JOIN search_vector_embeddings e ON e.chunk_id = c.id " - "WHERE c.project_id = :project_id AND c.entity_id = :entity_id " - "AND e.chunk_id IS NULL" - ) - # ------------------------------------------------------------------ # Shared semantic search: retrieval mode dispatch # ------------------------------------------------------------------ diff --git a/src/basic_memory/repository/sqlite_search_repository.py b/src/basic_memory/repository/sqlite_search_repository.py index cc42331b..447db427 100644 --- a/src/basic_memory/repository/sqlite_search_repository.py +++ b/src/basic_memory/repository/sqlite_search_repository.py @@ -1,11 +1,11 @@ """SQLite FTS5-based search repository implementation.""" +import asyncio import json import re +from contextlib import asynccontextmanager from datetime import datetime from typing import List, Optional - -import asyncio from loguru import logger from sqlalchemy import text from sqlalchemy.exc import OperationalError as SAOperationalError @@ -566,13 +566,21 @@ def _distance_to_similarity(self, distance: float) -> float: """ return max(0.0, 1.0 - (distance * distance) / 2.0) - def _orphan_detection_sql(self) -> str: - """SQLite sqlite-vec uses rowid-based embedding table.""" + @asynccontextmanager + async def _prepare_entity_write_scope(self): + """SQLite keeps the shared read window, but funnels prepare writes through one lock.""" + async with self._sqlite_vec_lock: + yield + + def _prepare_window_existing_rows_sql(self, placeholders: str) -> str: + """SQLite sqlite-vec stores embeddings by rowid rather than chunk_id.""" return ( - "SELECT c.id FROM search_vector_chunks c " + "SELECT c.entity_id, c.id, c.chunk_key, c.source_hash, c.entity_fingerprint, " + "c.embedding_model, (e.rowid IS NOT NULL) AS has_embedding " + "FROM search_vector_chunks c " "LEFT JOIN search_vector_embeddings e ON e.rowid = c.id " - "WHERE c.project_id = :project_id AND c.entity_id = :entity_id " - "AND e.rowid IS NULL" + f"WHERE c.project_id = :project_id AND c.entity_id IN ({placeholders}) " + "ORDER BY c.entity_id ASC, c.chunk_key ASC" ) # ------------------------------------------------------------------ diff --git a/src/basic_memory/telemetry.py b/src/basic_memory/telemetry.py index 35f6cfb5..aee27cef 100644 --- a/src/basic_memory/telemetry.py +++ b/src/basic_memory/telemetry.py @@ -40,6 +40,7 @@ class TelemetryState: _STATE = TelemetryState() _LOGFIRE_HANDLER: dict[str, Any] | None = None +_METRICS: dict[tuple[str, str, str, str], Any] = {} def reset_telemetry_state() -> None: @@ -55,6 +56,7 @@ def reset_telemetry_state() -> None: _STATE.send_to_logfire = False _STATE.warnings.clear() _LOGFIRE_HANDLER = None + _METRICS.clear() def _filter_attributes(attrs: dict[str, Any]) -> dict[str, Any]: @@ -136,6 +138,58 @@ def pop_telemetry_warnings() -> list[str]: return warnings +def _get_metric(metric_type: str, name: str, *, unit: str, description: str) -> Any | None: + """Create or reuse a Logfire metric instrument when telemetry is enabled.""" + logfire = _load_logfire() + if logfire is None or not _STATE.configured: # pragma: no cover + return None # pragma: no cover + + metric_key = (metric_type, name, unit, description) + cached_metric = _METRICS.get(metric_key) + if cached_metric is not None: + return cached_metric + + if metric_type == "counter": + metric = logfire.metric_counter(name, unit=unit, description=description) + elif metric_type == "histogram": + metric = logfire.metric_histogram(name, unit=unit, description=description) + else: # pragma: no cover + raise ValueError(f"Unsupported metric type: {metric_type}") # pragma: no cover + + _METRICS[metric_key] = metric + return metric + + +def add_counter( + name: str, + amount: int | float, + *, + unit: str = "1", + description: str = "", + **attrs: Any, +) -> None: + """Record a counter increment when telemetry is enabled.""" + metric = _get_metric("counter", name, unit=unit, description=description) + if metric is None: + return + metric.add(amount, attributes=_filter_attributes(attrs)) + + +def record_histogram( + name: str, + amount: int | float, + *, + unit: str = "", + description: str = "", + **attrs: Any, +) -> None: + """Record one histogram sample when telemetry is enabled.""" + metric = _get_metric("histogram", name, unit=unit, description=description) + if metric is None: + return + metric.record(amount, attributes=_filter_attributes(attrs)) + + @contextmanager def contextualize(**attrs: Any) -> Iterator[None]: """Apply filtered telemetry attributes to Loguru calls in this scope.""" @@ -176,11 +230,13 @@ def started_span(name: str, **attrs: Any) -> Iterator[Any | None]: __all__ = [ + "add_counter", "contextualize", "configure_telemetry", "get_logfire_handler", "operation", "pop_telemetry_warnings", + "record_histogram", "reset_telemetry_state", "scope", "span", diff --git a/tests/repository/test_postgres_search_repository.py b/tests/repository/test_postgres_search_repository.py index 9e9f3e1a..4ac72859 100644 --- a/tests/repository/test_postgres_search_repository.py +++ b/tests/repository/test_postgres_search_repository.py @@ -521,6 +521,7 @@ async def test_postgres_vector_sync_skips_unchanged_and_reembeds_changed_content assert unchanged_result.entities_synced == 1 assert unchanged_result.entities_skipped == 1 assert unchanged_result.embedding_jobs_total == 0 + assert unchanged_result.queue_wait_seconds_total == pytest.approx(0.0, abs=0.01) assert unchanged_result.chunks_skipped == unchanged_result.chunks_total await repo.index_item( diff --git a/tests/repository/test_postgres_search_repository_unit.py b/tests/repository/test_postgres_search_repository_unit.py index fa801ee0..b124a074 100644 --- a/tests/repository/test_postgres_search_repository_unit.py +++ b/tests/repository/test_postgres_search_repository_unit.py @@ -195,12 +195,9 @@ async def fake_scoped_session(session_maker): executed_sql = [str(call.args[0]) for call in session.execute.await_args_list] + assert any("CREATE TABLE IF NOT EXISTS search_vector_chunks" in sql for sql in executed_sql) assert any( - "CREATE TABLE IF NOT EXISTS search_vector_chunks" in sql for sql in executed_sql - ) - assert any( - "CREATE TABLE IF NOT EXISTS search_vector_embeddings" in sql - for sql in executed_sql + "CREATE TABLE IF NOT EXISTS search_vector_embeddings" in sql for sql in executed_sql ) assert not any("ALTER TABLE search_vector_chunks" in sql for sql in executed_sql) session.commit.assert_awaited_once() @@ -287,11 +284,11 @@ async def test_write_embeddings_executes_single_bulk_upsert(self): assert params["embedding_dims_1"] == 4 -class TestBatchPrepareConcurrency: - """Cover the Postgres-specific concurrent prepare window.""" +class TestBatchPrepareWindow: + """Cover the shared batched prepare window used by Postgres.""" @pytest.mark.asyncio - async def test_sync_entity_vectors_batch_prepares_entities_concurrently(self, monkeypatch): + async def test_sync_entity_vectors_batch_uses_shared_prepare_window(self, monkeypatch): repo = _make_repo( semantic_enabled=True, embedding_provider=StubEmbeddingProvider(), @@ -300,30 +297,65 @@ async def test_sync_entity_vectors_batch_prepares_entities_concurrently(self, mo repo._semantic_embedding_sync_batch_size = 8 repo._vector_tables_initialized = True + fetched_windows: list[list[int]] = [] + prepared_windows: list[list[int]] = [] active_prepares = 0 max_active_prepares = 0 - async def _stub_prepare(entity_id: int) -> _PreparedEntityVectorSync: + async def _stub_fetch_source_rows(session, entity_ids: list[int]): + fetched_windows.append(list(entity_ids)) + return {entity_id: [object()] for entity_id in entity_ids} + + async def _stub_fetch_existing_rows(session, entity_ids: list[int]): + return {entity_id: [] for entity_id in entity_ids} + + async def _stub_prepare_prefetched( + *, + entity_id: int, + source_rows, + existing_rows, + ) -> _PreparedEntityVectorSync: nonlocal active_prepares, max_active_prepares + assert len(source_rows) == 1 + assert existing_rows == [] active_prepares += 1 max_active_prepares = max(max_active_prepares, active_prepares) await asyncio.sleep(0) active_prepares -= 1 + prepared_windows.append([entity_id]) return _PreparedEntityVectorSync( entity_id=entity_id, sync_start=float(entity_id), source_rows_count=1, embedding_jobs=[], + entity_skipped=True, + chunks_total=1, + chunks_skipped=1, + prepare_seconds=0.1, ) + @asynccontextmanager + async def fake_scoped_session(session_maker): + yield AsyncMock() + monkeypatch.setattr(repo, "_ensure_vector_tables", AsyncMock()) - monkeypatch.setattr(repo, "_prepare_entity_vector_jobs", _stub_prepare) + monkeypatch.setattr( + "basic_memory.repository.search_repository_base.db.scoped_session", + fake_scoped_session, + ) + monkeypatch.setattr(repo, "_fetch_prepare_window_source_rows", _stub_fetch_source_rows) + monkeypatch.setattr(repo, "_fetch_prepare_window_existing_rows", _stub_fetch_existing_rows) + monkeypatch.setattr( + repo, "_prepare_entity_vector_jobs_prefetched", _stub_prepare_prefetched + ) result = await repo.sync_entity_vectors_batch([1, 2, 3, 4]) assert result.entities_total == 4 assert result.entities_synced == 4 assert result.entities_failed == 0 + assert fetched_windows == [[1, 2], [3, 4]] + assert prepared_windows == [[1], [2], [3], [4]] assert max_active_prepares == 2 @@ -337,14 +369,17 @@ async def test_postgres_batch_sync_tracks_prepare_and_queue_wait(monkeypatch): repo._semantic_embedding_sync_batch_size = 2 repo._vector_tables_initialized = True - async def _stub_prepare(entity_id: int) -> _PreparedEntityVectorSync: - return _PreparedEntityVectorSync( - entity_id=entity_id, - sync_start=0.0, - source_rows_count=1, - embedding_jobs=[(200 + entity_id, f"chunk-{entity_id}")], - prepare_seconds=1.0, - ) + async def _stub_prepare_window(entity_ids: list[int]): + return [ + _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=0.0, + source_rows_count=1, + embedding_jobs=[(200 + entity_id, f"chunk-{entity_id}")], + prepare_seconds=1.0, + ) + for entity_id in entity_ids + ] async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): for job in flush_jobs: @@ -364,10 +399,10 @@ async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): def _capture_log(**kwargs): completion_records.append(kwargs) - perf_counter_values = iter([4.0, 5.0]) + perf_counter_values = iter([0.0, 4.0, 5.0, 6.0]) monkeypatch.setattr(repo, "_ensure_vector_tables", AsyncMock()) - monkeypatch.setattr(repo, "_prepare_entity_vector_jobs", _stub_prepare) + monkeypatch.setattr(repo, "_prepare_entity_vector_jobs_window", _stub_prepare_window) monkeypatch.setattr(repo, "_flush_embedding_jobs", _stub_flush) monkeypatch.setattr(repo, "_log_vector_sync_complete", _capture_log) monkeypatch.setattr( @@ -401,33 +436,41 @@ async def test_postgres_batch_sync_tracks_deferred_oversized_entities(monkeypatc repo._semantic_embedding_sync_batch_size = 8 repo._vector_tables_initialized = True - async def _stub_prepare(entity_id: int) -> _PreparedEntityVectorSync: - if entity_id == 1: - return _PreparedEntityVectorSync( - entity_id=entity_id, - sync_start=0.0, - source_rows_count=1, - embedding_jobs=[(201, "chunk-1a"), (202, "chunk-1b")], - chunks_total=5, - pending_jobs_total=5, - entity_complete=False, - oversized_entity=True, - shard_index=1, - shard_count=3, - remaining_jobs_after_shard=3, + async def _stub_prepare_window(entity_ids: list[int]): + prepared: list[_PreparedEntityVectorSync] = [] + for entity_id in entity_ids: + if entity_id == 1: + prepared.append( + _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=0.0, + source_rows_count=1, + embedding_jobs=[(201, "chunk-1a"), (202, "chunk-1b")], + chunks_total=5, + pending_jobs_total=5, + entity_complete=False, + oversized_entity=True, + shard_index=1, + shard_count=3, + remaining_jobs_after_shard=3, + ) + ) + continue + prepared.append( + _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=0.0, + source_rows_count=1, + embedding_jobs=[(301, "chunk-2a")], + chunks_total=1, + pending_jobs_total=1, + entity_complete=True, + shard_index=1, + shard_count=1, + remaining_jobs_after_shard=0, + ) ) - return _PreparedEntityVectorSync( - entity_id=entity_id, - sync_start=0.0, - source_rows_count=1, - embedding_jobs=[(301, "chunk-2a")], - chunks_total=1, - pending_jobs_total=1, - entity_complete=True, - shard_index=1, - shard_count=1, - remaining_jobs_after_shard=0, - ) + return prepared async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): for job in flush_jobs: @@ -443,7 +486,7 @@ def _capture_log(**kwargs): completion_records.append(kwargs) monkeypatch.setattr(repo, "_ensure_vector_tables", AsyncMock()) - monkeypatch.setattr(repo, "_prepare_entity_vector_jobs", _stub_prepare) + monkeypatch.setattr(repo, "_prepare_entity_vector_jobs_window", _stub_prepare_window) monkeypatch.setattr(repo, "_flush_embedding_jobs", _stub_flush) monkeypatch.setattr(repo, "_log_vector_sync_complete", _capture_log) diff --git a/tests/repository/test_semantic_search_base.py b/tests/repository/test_semantic_search_base.py index 7c66a91f..7d8a60d2 100644 --- a/tests/repository/test_semantic_search_base.py +++ b/tests/repository/test_semantic_search_base.py @@ -4,7 +4,10 @@ _search_hybrid entity_id fusion key, and SemanticSearchDisabledError in SQLite. """ +import asyncio +from contextlib import asynccontextmanager from types import SimpleNamespace +from unittest.mock import AsyncMock import pytest @@ -312,8 +315,8 @@ async def test_sync_entity_vectors_batch_flushes_at_configured_threshold(monkeyp } flush_sizes: list[int] = [] - async def _stub_prepare(entity_id: int) -> _PreparedEntityVectorSync: - return prepared_by_entity[entity_id] + async def _stub_prepare_window(entity_ids: list[int]): + return [prepared_by_entity[entity_id] for entity_id in entity_ids] async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): flush_sizes.append(len(flush_jobs)) @@ -325,7 +328,7 @@ async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): entity_runtime.pop(job.entity_id, None) return (0.1, 0.2) - monkeypatch.setattr(repo, "_prepare_entity_vector_jobs", _stub_prepare) + monkeypatch.setattr(repo, "_prepare_entity_vector_jobs_window", _stub_prepare_window) monkeypatch.setattr(repo, "_flush_embedding_jobs", _stub_flush) result = await repo.sync_entity_vectors_batch([1, 2, 3]) @@ -342,6 +345,39 @@ async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): assert result.write_seconds_total == pytest.approx(0.4) +@pytest.mark.asyncio +async def test_sync_entity_vectors_batch_skip_only_has_zero_queue_wait(monkeypatch): + """Skip-only batches should not accumulate synthetic queue wait.""" + repo = _ConcreteRepo() + repo._semantic_enabled = True + repo._embedding_provider = object() + + async def _stub_prepare_window(entity_ids: list[int]): + return [ + _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=float(entity_id), + source_rows_count=1, + embedding_jobs=[], + chunks_total=2, + chunks_skipped=2, + entity_skipped=True, + prepare_seconds=0.25, + ) + for entity_id in entity_ids + ] + + monkeypatch.setattr(repo, "_prepare_entity_vector_jobs_window", _stub_prepare_window) + + result = await repo.sync_entity_vectors_batch([1, 2]) + + assert result.entities_total == 2 + assert result.entities_synced == 2 + assert result.entities_skipped == 2 + assert result.embedding_jobs_total == 0 + assert result.queue_wait_seconds_total == pytest.approx(0.0) + + @pytest.mark.asyncio async def test_sync_entity_vectors_batch_continue_on_error(monkeypatch): """Batch sync should continue after per-entity and per-flush failures.""" @@ -350,12 +386,18 @@ async def test_sync_entity_vectors_batch_continue_on_error(monkeypatch): repo._embedding_provider = object() repo._semantic_embedding_sync_batch_size = 1 - async def _stub_prepare(entity_id: int) -> _PreparedEntityVectorSync: - if entity_id == 2: - raise RuntimeError("prepare failed") - return _PreparedEntityVectorSync( - entity_id, float(entity_id), 1, [(100 + entity_id, "chunk")] - ) + async def _stub_prepare_window(entity_ids: list[int]): + prepared = [] + for entity_id in entity_ids: + if entity_id == 2: + prepared.append(RuntimeError("prepare failed")) + continue + prepared.append( + _PreparedEntityVectorSync( + entity_id, float(entity_id), 1, [(100 + entity_id, "chunk")] + ) + ) + return prepared async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): entity_id = flush_jobs[0].entity_id @@ -367,7 +409,7 @@ async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): entity_runtime.pop(entity_id, None) return (0.05, 0.05) - monkeypatch.setattr(repo, "_prepare_entity_vector_jobs", _stub_prepare) + monkeypatch.setattr(repo, "_prepare_entity_vector_jobs_window", _stub_prepare_window) monkeypatch.setattr(repo, "_flush_embedding_jobs", _stub_flush) result = await repo.sync_entity_vectors_batch([1, 2, 3]) @@ -378,6 +420,71 @@ async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): assert result.failed_entity_ids == [2, 3] +@pytest.mark.asyncio +async def test_sync_entity_vectors_batch_only_attributes_queue_wait_to_flushed_entities( + monkeypatch, +): + """Mixed batches should only charge queue wait to entities that entered flush work.""" + repo = _ConcreteRepo() + repo._semantic_enabled = True + repo._embedding_provider = object() + repo._semantic_embedding_sync_batch_size = 2 + + async def _stub_prepare_window(entity_ids: list[int]): + prepared: list[_PreparedEntityVectorSync] = [] + for entity_id in entity_ids: + if entity_id == 1: + prepared.append( + _PreparedEntityVectorSync( + entity_id=1, + sync_start=0.0, + source_rows_count=1, + embedding_jobs=[], + chunks_total=2, + chunks_skipped=2, + entity_skipped=True, + prepare_seconds=0.5, + ) + ) + continue + prepared.append( + _PreparedEntityVectorSync( + entity_id=2, + sync_start=0.0, + source_rows_count=1, + embedding_jobs=[(102, "chunk-2")], + prepare_seconds=1.0, + ) + ) + return prepared + + async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): + runtime = entity_runtime[2] + runtime.embed_seconds = 1.0 + runtime.write_seconds = 0.5 + runtime.remaining_jobs = 0 + synced_entity_ids.add(2) + return (1.0, 0.5) + + perf_counter_values = iter([0.0, 2.0, 4.0, 5.0]) + + monkeypatch.setattr(repo, "_prepare_entity_vector_jobs_window", _stub_prepare_window) + monkeypatch.setattr(repo, "_flush_embedding_jobs", _stub_flush) + monkeypatch.setattr( + search_repository_base_module.time, + "perf_counter", + lambda: next(perf_counter_values), + ) + + result = await repo.sync_entity_vectors_batch([1, 2]) + + assert result.entities_total == 2 + assert result.entities_synced == 2 + assert result.entities_skipped == 1 + assert result.embedding_jobs_total == 1 + assert result.queue_wait_seconds_total == pytest.approx(1.5) + + @pytest.mark.asyncio async def test_sync_entity_vectors_batch_tracks_prepare_and_queue_wait_seconds(monkeypatch): """Queue wait should be reported separately from prepare/embed/write timings.""" @@ -386,14 +493,17 @@ async def test_sync_entity_vectors_batch_tracks_prepare_and_queue_wait_seconds(m repo._embedding_provider = object() repo._semantic_embedding_sync_batch_size = 2 - async def _stub_prepare(entity_id: int) -> _PreparedEntityVectorSync: - return _PreparedEntityVectorSync( - entity_id=entity_id, - sync_start=0.0, - source_rows_count=1, - embedding_jobs=[(100 + entity_id, f"chunk-{entity_id}")], - prepare_seconds=1.0, - ) + async def _stub_prepare_window(entity_ids: list[int]): + return [ + _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=0.0, + source_rows_count=1, + embedding_jobs=[(100 + entity_id, f"chunk-{entity_id}")], + prepare_seconds=1.0, + ) + for entity_id in entity_ids + ] async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): assert len(flush_jobs) == 2 @@ -414,9 +524,9 @@ async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): def _capture_log(**kwargs): logged_completion.append(kwargs) - perf_counter_values = iter([4.0, 5.0]) + perf_counter_values = iter([0.0, 4.0, 5.0, 6.0]) - monkeypatch.setattr(repo, "_prepare_entity_vector_jobs", _stub_prepare) + monkeypatch.setattr(repo, "_prepare_entity_vector_jobs_window", _stub_prepare_window) monkeypatch.setattr(repo, "_flush_embedding_jobs", _stub_flush) monkeypatch.setattr(repo, "_log_vector_sync_complete", _capture_log) monkeypatch.setattr( @@ -438,3 +548,113 @@ def _capture_log(**kwargs): for record in logged_completion: assert record["prepare_seconds"] == pytest.approx(1.0) assert record["queue_wait_seconds"] == pytest.approx(1.5) + + +@pytest.mark.asyncio +async def test_prepare_window_uses_entity_local_timing_after_shared_reads(monkeypatch): + """Per-entity prepare timing should start when that entity work actually begins.""" + repo = _ConcreteRepo() + repo._semantic_enabled = True + repo._embedding_provider = SimpleNamespace(model_name="stub", dimensions=4) + + async def _stub_fetch_source_rows(session, entity_ids: list[int]): + search_repository_base_module.time.perf_counter() + return {entity_id: [] for entity_id in entity_ids} + + async def _stub_fetch_existing_rows(session, entity_ids: list[int]): + search_repository_base_module.time.perf_counter() + return {entity_id: [] for entity_id in entity_ids} + + @asynccontextmanager + async def fake_scoped_session(session_maker): + yield AsyncMock() + + @asynccontextmanager + async def _yielding_write_scope(): + await asyncio.sleep(0) + yield + + perf_counter_values = iter([0.0, 5.0, 10.0, 11.0, 12.0, 13.0]) + + monkeypatch.setattr( + "basic_memory.repository.search_repository_base.db.scoped_session", + fake_scoped_session, + ) + monkeypatch.setattr(repo, "_fetch_prepare_window_source_rows", _stub_fetch_source_rows) + monkeypatch.setattr(repo, "_fetch_prepare_window_existing_rows", _stub_fetch_existing_rows) + monkeypatch.setattr(repo, "_prepare_entity_write_scope", _yielding_write_scope) + monkeypatch.setattr(repo, "_prepare_vector_session", AsyncMock()) + monkeypatch.setattr(repo, "_delete_entity_chunks", AsyncMock()) + monkeypatch.setattr( + search_repository_base_module.time, + "perf_counter", + lambda: next(perf_counter_values), + ) + + prepared = await repo._prepare_entity_vector_jobs_window([1, 2]) + + assert [result.sync_start for result in prepared] == [10.0, 11.0] + assert [result.prepare_seconds for result in prepared] == [2.0, 2.0] + + +@pytest.mark.asyncio +async def test_sync_entity_vectors_batch_records_entity_granularity_histograms(monkeypatch): + """Entity timing histograms should emit one sample per finalized entity.""" + repo = _ConcreteRepo() + repo._semantic_enabled = True + repo._embedding_provider = object() + repo._semantic_embedding_sync_batch_size = 2 + + async def _stub_prepare_window(entity_ids: list[int]): + return [ + _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=0.0, + source_rows_count=1, + embedding_jobs=[(100 + entity_id, f"chunk-{entity_id}")], + prepare_seconds=1.0, + ) + for entity_id in entity_ids + ] + + async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): + for job in flush_jobs: + runtime = entity_runtime[job.entity_id] + runtime.embed_seconds = 1.0 + runtime.write_seconds = 0.5 + runtime.remaining_jobs = 0 + synced_entity_ids.add(job.entity_id) + return (2.0, 1.0) + + histogram_calls: list[tuple[str, float, dict]] = [] + counter_calls: list[tuple[str, float, dict]] = [] + perf_counter_values = iter([0.0, 3.0, 4.5, 6.0]) + + monkeypatch.setattr(repo, "_prepare_entity_vector_jobs_window", _stub_prepare_window) + monkeypatch.setattr(repo, "_flush_embedding_jobs", _stub_flush) + monkeypatch.setattr( + search_repository_base_module.telemetry, + "record_histogram", + lambda name, amount, **attrs: histogram_calls.append((name, amount, attrs)), + ) + monkeypatch.setattr( + search_repository_base_module.telemetry, + "add_counter", + lambda name, amount, **attrs: counter_calls.append((name, amount, attrs)), + ) + monkeypatch.setattr( + search_repository_base_module.time, + "perf_counter", + lambda: next(perf_counter_values), + ) + + result = await repo.sync_entity_vectors_batch([1, 2]) + + assert result.entities_synced == 2 + histogram_names = [name for name, _, _ in histogram_calls] + assert histogram_names.count("vector_sync_prepare_seconds") == 2 + assert histogram_names.count("vector_sync_queue_wait_seconds") == 2 + assert histogram_names.count("vector_sync_embed_seconds") == 2 + assert histogram_names.count("vector_sync_write_seconds") == 2 + assert histogram_names.count("vector_sync_batch_total_seconds") == 1 + assert [name for name, _, _ in counter_calls].count("vector_sync_entities_total") == 1 diff --git a/tests/repository/test_sqlite_vector_search_repository.py b/tests/repository/test_sqlite_vector_search_repository.py index d85f17fa..12ef6e01 100644 --- a/tests/repository/test_sqlite_vector_search_repository.py +++ b/tests/repository/test_sqlite_vector_search_repository.py @@ -1,12 +1,15 @@ """SQLite sqlite-vec search repository tests.""" +import asyncio +from contextlib import asynccontextmanager from datetime import datetime, timezone -from unittest.mock import MagicMock +from unittest.mock import AsyncMock, MagicMock import pytest from sqlalchemy import text from basic_memory import db +from basic_memory.config import BasicMemoryConfig, DatabaseBackend from basic_memory.repository.search_index_row import SearchIndexRow from basic_memory.repository.sqlite_search_repository import SQLiteSearchRepository from basic_memory.schemas.search import SearchItemType, SearchRetrievalMode @@ -83,6 +86,27 @@ def _enable_semantic( search_repository._vector_tables_initialized = False +def _make_sqlite_repo_for_unit_tests() -> SQLiteSearchRepository: + """Build a SQLite repository without touching a real sqlite-vec install.""" + session_maker = MagicMock() + app_config = BasicMemoryConfig( + env="test", + projects={"test-project": "/tmp/test"}, + default_project="test-project", + database_backend=DatabaseBackend.SQLITE, + semantic_search_enabled=True, + semantic_embedding_sync_batch_size=8, + ) + repo = SQLiteSearchRepository( + session_maker, + project_id=1, + app_config=app_config, + embedding_provider=StubEmbeddingProvider(), + ) + repo._vector_tables_initialized = True + return repo + + @pytest.mark.asyncio async def test_sqlite_vec_tables_are_created_and_rebuilt(search_repository): """Repository rebuilds vector schema deterministically on mismatch.""" @@ -239,6 +263,7 @@ async def test_sqlite_vector_sync_skips_unchanged_and_reembeds_changed_content(s assert unchanged_result.entities_synced == 1 assert unchanged_result.entities_skipped == 1 assert unchanged_result.embedding_jobs_total == 0 + assert unchanged_result.queue_wait_seconds_total == pytest.approx(0.0, abs=0.01) assert unchanged_result.chunks_skipped == unchanged_result.chunks_total await search_repository.index_item( @@ -266,6 +291,76 @@ async def test_sqlite_vector_sync_skips_unchanged_and_reembeds_changed_content(s assert model_changed_result.embedding_jobs_total == model_changed_result.chunks_total +@pytest.mark.asyncio +async def test_sqlite_prepare_window_uses_shared_reads_and_serialized_write_scope(monkeypatch): + """SQLite should batch read-side prepare work but serialize write-side mutations.""" + repo = _make_sqlite_repo_for_unit_tests() + + fetched_windows: list[list[int]] = [] + active_write_scopes = 0 + max_active_write_scopes = 0 + + async def _stub_fetch_source_rows(session, entity_ids: list[int]): + fetched_windows.append(list(entity_ids)) + return {entity_id: [object()] for entity_id in entity_ids} + + async def _stub_fetch_existing_rows(session, entity_ids: list[int]): + return {entity_id: [] for entity_id in entity_ids} + + def _stub_build_chunk_records(source_rows): + return [ + { + "chunk_key": "entity:1:0", + "chunk_text": "chunk text", + "source_hash": "hash", + } + ] + + @asynccontextmanager + async def _track_write_scope(): + nonlocal active_write_scopes, max_active_write_scopes + async with repo._sqlite_vec_lock: + active_write_scopes += 1 + max_active_write_scopes = max(max_active_write_scopes, active_write_scopes) + try: + yield + finally: + active_write_scopes -= 1 + + async def _stub_upsert( + session, + *, + entity_id: int, + scheduled_records, + existing_by_key, + entity_fingerprint: str, + embedding_model: str, + ): + await asyncio.sleep(0) + return [(entity_id * 100, scheduled_records[0]["chunk_text"])] + + @asynccontextmanager + async def fake_scoped_session(session_maker): + yield AsyncMock() + + monkeypatch.setattr( + "basic_memory.repository.search_repository_base.db.scoped_session", + fake_scoped_session, + ) + monkeypatch.setattr(repo, "_prepare_vector_session", AsyncMock()) + monkeypatch.setattr(repo, "_fetch_prepare_window_source_rows", _stub_fetch_source_rows) + monkeypatch.setattr(repo, "_fetch_prepare_window_existing_rows", _stub_fetch_existing_rows) + monkeypatch.setattr(repo, "_build_chunk_records", _stub_build_chunk_records) + monkeypatch.setattr(repo, "_prepare_entity_write_scope", _track_write_scope) + monkeypatch.setattr(repo, "_upsert_scheduled_chunk_records", _stub_upsert) + + prepared = await repo._prepare_entity_vector_jobs_window([1, 2]) + + assert fetched_windows == [[1, 2]] + assert [result.entity_id for result in prepared] == [1, 2] + assert max_active_write_scopes == 1 + + @pytest.mark.asyncio async def test_sqlite_vector_search_returns_ranked_entities(search_repository): """Vector mode ranks entities using sqlite-vec nearest-neighbor search.""" diff --git a/tests/test_telemetry.py b/tests/test_telemetry.py index cd133df4..b940f6fe 100644 --- a/tests/test_telemetry.py +++ b/tests/test_telemetry.py @@ -13,6 +13,20 @@ class FakeLogfire: """Small fake Logfire surface for bootstrap testing.""" + class FakeCounter: + def __init__(self, calls: list[tuple[float, dict | None]]) -> None: + self.calls = calls + + def add(self, amount: float, *, attributes=None) -> None: + self.calls.append((amount, attributes)) + + class FakeHistogram: + def __init__(self, calls: list[tuple[float, dict | None]]) -> None: + self.calls = calls + + def record(self, amount: float, *, attributes=None) -> None: + self.calls.append((amount, attributes)) + class CodeSource: def __init__(self, **kwargs): self.kwargs = kwargs @@ -21,6 +35,8 @@ def __init__(self, *, fail_on_send_to_logfire: bool = False) -> None: self.fail_on_send_to_logfire = fail_on_send_to_logfire self.configure_calls: list[dict] = [] self.span_calls: list[tuple[str, dict]] = [] + self.counter_calls: dict[str, list[tuple[float, dict | None]]] = {} + self.histogram_calls: dict[str, list[tuple[float, dict | None]]] = {} def configure(self, **kwargs) -> None: self.configure_calls.append(kwargs) @@ -30,6 +46,14 @@ def configure(self, **kwargs) -> None: def loguru_handler(self) -> dict: return {"sink": "fake-logfire", "level": "INFO"} + def metric_counter(self, name: str, *, unit: str = "", description: str = ""): + self.counter_calls.setdefault(name, []) + return self.FakeCounter(self.counter_calls[name]) + + def metric_histogram(self, name: str, *, unit: str = "", description: str = ""): + self.histogram_calls.setdefault(name, []) + return self.FakeHistogram(self.histogram_calls[name]) + @contextmanager def span(self, name: str, **attrs): self.span_calls.append((name, attrs)) @@ -176,6 +200,38 @@ def test_started_span_exposes_mutable_logfire_handle(monkeypatch) -> None: ] +def test_metrics_record_when_telemetry_enabled(monkeypatch) -> None: + fake_logfire = FakeLogfire() + telemetry.reset_telemetry_state() + monkeypatch.setattr(telemetry, "_load_logfire", lambda: fake_logfire) + telemetry.configure_telemetry( + "basic-memory-cli", + environment="dev", + enable_logfire=True, + ) + + telemetry.record_histogram( + "vector_sync_prepare_seconds", + 1.25, + unit="s", + backend="sqlite", + skip_only_batch=True, + ) + telemetry.add_counter( + "vector_sync_entities_skipped", + 2, + backend="sqlite", + skip_only_batch=True, + ) + + assert fake_logfire.histogram_calls["vector_sync_prepare_seconds"] == [ + (1.25, {"backend": "sqlite", "skip_only_batch": True}) + ] + assert fake_logfire.counter_calls["vector_sync_entities_skipped"] == [ + (2, {"backend": "sqlite", "skip_only_batch": True}) + ] + + def test_operation_creates_span_and_log_context(monkeypatch) -> None: fake_logfire = FakeLogfire() records: list[dict] = [] From 37092876ad939fcab665b01a7b927d36c47730b0 Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 13:39:26 -0500 Subject: [PATCH 02/14] wip: checkpoint full reindex parity and sqlite vector fixes Signed-off-by: phernandez --- src/basic_memory/cli/commands/db.py | 57 +++- .../repository/search_repository_base.py | 36 +- .../repository/sqlite_search_repository.py | 88 ++++- src/basic_memory/services/search_service.py | 75 ++-- src/basic_memory/sync/sync_service.py | 26 +- tests/cli/test_db_reindex.py | 322 ++++++++++++++++++ tests/repository/test_semantic_search_base.py | 44 +++ .../test_sqlite_vector_search_repository.py | 64 +++- tests/services/test_search_service.py | 4 +- tests/services/test_semantic_search.py | 64 +++- tests/sync/test_sync_service_incremental.py | 27 ++ 11 files changed, 741 insertions(+), 66 deletions(-) create mode 100644 tests/cli/test_db_reindex.py diff --git a/src/basic_memory/cli/commands/db.py b/src/basic_memory/cli/commands/db.py index dadc94fe..07f40f5c 100644 --- a/src/basic_memory/cli/commands/db.py +++ b/src/basic_memory/cli/commands/db.py @@ -26,7 +26,7 @@ class EmbeddingProgress: """Typed CLI progress payload for embedding backfills.""" entity_id: int - index: int + completed: int total: int @@ -147,20 +147,30 @@ def reindex( False, "--embeddings", "-e", help="Rebuild vector embeddings (requires semantic search)" ), search: bool = typer.Option(False, "--search", "-s", help="Rebuild full-text search index"), + full: bool = typer.Option( + False, + "--full", + help="Force a full filesystem scan and file reindex instead of the default incremental scan", + ), project: str = typer.Option( None, "--project", "-p", help="Reindex a specific project (default: all)" ), ): # pragma: no cover """Rebuild search indexes and/or vector embeddings without dropping the database. - By default rebuilds everything (search + embeddings if semantic is enabled). - Use --search or --embeddings to rebuild only one. + By default runs incremental search + embeddings (if semantic search is enabled). + Use --full to bypass incremental scan optimization, rebuild all file-backed search rows, + and re-embed all eligible notes. + Use --search or --embeddings to rebuild only one side. Examples: - bm reindex # Rebuild everything + bm reindex # Incremental search + embeddings + bm reindex --full # Full search + full re-embed bm reindex --embeddings # Only rebuild vector embeddings bm reindex --search # Only rebuild FTS index - bm reindex -p claw # Reindex only the 'claw' project + bm reindex --full --search # Full search only + bm reindex --full --embeddings # Full re-embed only + bm reindex -p claw --full # Full reindex for only the 'claw' project """ # If neither flag is set, do both if not embeddings and not search: @@ -179,10 +189,19 @@ def reindex( if not search: raise typer.Exit(0) - run_with_cleanup(_reindex(app_config, search=search, embeddings=embeddings, project=project)) + run_with_cleanup( + _reindex(app_config, search=search, embeddings=embeddings, full=full, project=project) + ) -async def _reindex(app_config, search: bool, embeddings: bool, project: str | None): +async def _reindex( + app_config, + *, + search: bool, + embeddings: bool, + full: bool, + project: str | None, +): """Run reindex operations.""" from basic_memory.repository import EntityRepository from basic_memory.repository.search_repository import create_search_repository @@ -220,6 +239,10 @@ async def _reindex(app_config, search: bool, embeddings: bool, project: str | No console.print(f"\n[bold]Project: [cyan]{proj.name}[/cyan][/bold]") if search: + search_mode_label = "full scan" if full else "incremental scan" + console.print( + f" Rebuilding full-text search index ([cyan]{search_mode_label}[/cyan])..." + ) sync_service = await get_sync_service(proj) sync_dir = Path(proj.path) with Progress( @@ -244,6 +267,8 @@ async def on_index_progress(update: IndexProgress) -> None: await sync_service.sync( sync_dir, project_name=proj.name, + force_full=full, + sync_embeddings=False, progress_callback=on_index_progress, ) progress.update(task, completed=progress.tasks[task].total or 1) @@ -251,7 +276,10 @@ async def on_index_progress(update: IndexProgress) -> None: console.print(" [green]✓[/green] Full-text search index rebuilt") if embeddings: - console.print(" Building vector embeddings...") + embedding_mode_label = "full rebuild" if full else "incremental sync" + console.print( + f" Building vector embeddings ([cyan]{embedding_mode_label}[/cyan])..." + ) entity_repository = EntityRepository(session_maker, project_id=proj.id) search_repository = create_search_repository( session_maker, project_id=proj.id, app_config=app_config @@ -274,16 +302,23 @@ async def on_index_progress(update: IndexProgress) -> None: def on_progress(entity_id, index, total): embedding_progress = EmbeddingProgress( entity_id=entity_id, - index=index, + completed=index, total=total, ) + # Trigger: repository progress now reports terminal entity completion. + # Why: operators need to see finished embedding work rather than + # entities merely entering prepare. + # Outcome: the CLI bar advances steadily with real completed work. progress.update( task, total=embedding_progress.total, - completed=embedding_progress.index, + completed=embedding_progress.completed, ) - stats = await search_service.reindex_vectors(progress_callback=on_progress) + stats = await search_service.reindex_vectors( + progress_callback=on_progress, + force_full=full, + ) progress.update(task, completed=stats["total_entities"]) console.print( diff --git a/src/basic_memory/repository/search_repository_base.py b/src/basic_memory/repository/search_repository_base.py index 298767ec..c18a3238 100644 --- a/src/basic_memory/repository/search_repository_base.py +++ b/src/basic_memory/repository/search_repository_base.py @@ -814,6 +814,22 @@ async def _sync_entity_vectors_internal( failed_entity_ids: set[int] = set() deferred_entity_ids: set[int] = set() synced_entity_ids: set[int] = set() + completed_entities = 0 + + def emit_progress(entity_id: int) -> None: + """Report terminal entity progress to callers such as the CLI. + + Trigger: an entity reaches a terminal state in this sync run. + Why: operators need progress based on completed work, not the moment + an entity merely enters prepare. + Outcome: the progress bar advances when an entity is done for this + run, whether it synced, skipped, deferred, or failed. + """ + nonlocal completed_entities + if progress_callback is None: + return + completed_entities += 1 + progress_callback(entity_id, completed_entities, total_entities) prepare_window_size = self._vector_prepare_window_size() with telemetry.started_span( @@ -826,13 +842,6 @@ async def _sync_entity_vectors_internal( for window_start in range(0, total_entities, prepare_window_size): window_entity_ids = entity_ids[window_start : window_start + prepare_window_size] - if progress_callback is not None: - # Trigger: prepare runs in bounded windows instead of strict one-by-one order. - # Why: callbacks still need deterministic per-entity positions before the window starts. - # Outcome: progress advances in prepare_window_size bursts. - for offset, entity_id in enumerate(window_entity_ids, start=window_start): - progress_callback(entity_id, offset, total_entities) - prepared_window = await self._prepare_entity_vector_jobs_window(window_entity_ids) for entity_id, prepared in zip(window_entity_ids, prepared_window, strict=True): @@ -847,6 +856,7 @@ async def _sync_entity_vectors_internal( entity_id=entity_id, error=str(prepared), ) + emit_progress(entity_id) continue embedding_jobs_count = len(prepared.embedding_jobs) @@ -886,6 +896,7 @@ async def _sync_entity_vectors_internal( shard_count=prepared.shard_count, remaining_jobs_after_shard=prepared.remaining_jobs_after_shard, ) + emit_progress(entity_id) continue entity_runtime[entity_id] = _EntitySyncRuntime( @@ -933,6 +944,7 @@ async def _sync_entity_vectors_internal( entity_runtime=entity_runtime, synced_entity_ids=synced_entity_ids, deferred_entity_ids=deferred_entity_ids, + progress_callback=emit_progress, ) except Exception as exc: if not continue_on_error: @@ -952,6 +964,8 @@ async def _sync_entity_vectors_internal( chunk_count=len(flush_jobs), error=str(exc), ) + for failed_entity_id in affected_entity_ids: + emit_progress(failed_entity_id) if pending_jobs: flush_jobs = list(pending_jobs) @@ -968,6 +982,7 @@ async def _sync_entity_vectors_internal( entity_runtime=entity_runtime, synced_entity_ids=synced_entity_ids, deferred_entity_ids=deferred_entity_ids, + progress_callback=emit_progress, ) except Exception as exc: if not continue_on_error: @@ -987,6 +1002,8 @@ async def _sync_entity_vectors_internal( chunk_count=len(flush_jobs), error=str(exc), ) + for failed_entity_id in affected_entity_ids: + emit_progress(failed_entity_id) # Trigger: this should never happen after all flushes succeed. # Why: remaining jobs mean runtime tracking drifted from queued jobs. @@ -1002,6 +1019,8 @@ async def _sync_entity_vectors_internal( project_id=self.project_id, unfinished_entities=orphan_runtime_entities, ) + for failed_entity_id in orphan_runtime_entities: + emit_progress(failed_entity_id) # Keep result counters aligned with successful/failed terminal states. synced_entity_ids.difference_update(failed_entity_ids) @@ -1527,6 +1546,7 @@ def _finalize_completed_entity_syncs( entity_runtime: dict[int, _EntitySyncRuntime], synced_entity_ids: set[int], deferred_entity_ids: set[int], + progress_callback: Callable[[int], None] | None = None, ) -> float: """Finalize completed entities and return cumulative queue wait seconds.""" queue_wait_seconds_total = 0.0 @@ -1570,6 +1590,8 @@ def _finalize_completed_entity_syncs( remaining_jobs_after_shard=runtime.remaining_jobs_after_shard, ) entity_runtime.pop(entity_id, None) + if progress_callback is not None: + progress_callback(entity_id) return queue_wait_seconds_total diff --git a/src/basic_memory/repository/sqlite_search_repository.py b/src/basic_memory/repository/sqlite_search_repository.py index 447db427..1c07ac1e 100644 --- a/src/basic_memory/repository/sqlite_search_repository.py +++ b/src/basic_memory/repository/sqlite_search_repository.py @@ -56,7 +56,8 @@ def __init__( self._app_config.semantic_embedding_sync_batch_size ) self._embedding_provider = embedding_provider - self._sqlite_vec_lock = asyncio.Lock() + self._sqlite_vec_load_lock = asyncio.Lock() + self._sqlite_prepare_write_lock = asyncio.Lock() self._vector_tables_initialized = False self._vector_dimensions = 384 @@ -357,7 +358,13 @@ async def _ensure_sqlite_vec_loaded(self, session) -> None: "pip install -U basic-memory" ) from exc - async with self._sqlite_vec_lock: + # Trigger: sqlite-vec must be loaded on each SQLite connection before + # vec tables and functions are visible. + # Why: extension loading is connection-local, so we need one narrow + # critical section to avoid racing two coroutines on the same step. + # Outcome: connection setup stays serialized without blocking unrelated + # prepare work behind the write-side lock. + async with self._sqlite_vec_load_lock: try: await session.execute(text("SELECT vec_version()")) return @@ -558,6 +565,76 @@ async def _delete_stale_chunks( stale_params, ) + async def delete_entity_vector_rows(self, entity_id: int) -> None: + """Delete one entity's vec rows on a sqlite-vec-enabled connection.""" + await self._ensure_vector_tables() + + async with db.scoped_session(self.session_maker) as session: + await self._ensure_sqlite_vec_loaded(session) + + # Constraint: sqlite-vec virtual tables are only visible after vec0 is + # loaded on this exact connection. + # Why: generic repository sessions can reach search_vector_chunks but still + # fail with "no such module: vec0" when touching embeddings. + # Outcome: service-level cleanup routes vec-table deletes through this helper. + await self._delete_entity_chunks(session, entity_id) + await session.commit() + + async def delete_project_vector_rows(self) -> None: + """Delete all vector rows for this project on a sqlite-vec-enabled connection.""" + await self._ensure_vector_tables() + + async with db.scoped_session(self.session_maker) as session: + await self._ensure_sqlite_vec_loaded(session) + + # Constraint: sqlite-vec stores embeddings separately with no cascade delete. + # Why: full rebuild must clear embeddings before chunk rows or stale vectors remain. + # Outcome: the next sync recreates the project's derived vectors from scratch. + await session.execute( + text( + "DELETE FROM search_vector_embeddings WHERE rowid IN (" + "SELECT id FROM search_vector_chunks WHERE project_id = :project_id)" + ), + {"project_id": self.project_id}, + ) + await session.execute( + text("DELETE FROM search_vector_chunks WHERE project_id = :project_id"), + {"project_id": self.project_id}, + ) + await session.commit() + + async def delete_stale_vector_rows(self) -> None: + """Delete vector rows whose source entities no longer exist.""" + await self._ensure_vector_tables() + + async with db.scoped_session(self.session_maker) as session: + await self._ensure_sqlite_vec_loaded(session) + + stale_entity_filter = ( + "entity_id NOT IN (SELECT id FROM entity WHERE project_id = :project_id)" + ) + params = {"project_id": self.project_id} + + # Trigger: deleted entities left behind derived vector rows. + # Why: sqlite-vec does not provide cascade cleanup from our chunk table. + # Outcome: stale vector state disappears before coverage stats or reindex runs. + await session.execute( + text( + "DELETE FROM search_vector_embeddings WHERE rowid IN (" + "SELECT id FROM search_vector_chunks " + f"WHERE project_id = :project_id AND {stale_entity_filter})" + ), + params, + ) + await session.execute( + text( + "DELETE FROM search_vector_chunks " + f"WHERE project_id = :project_id AND {stale_entity_filter}" + ), + params, + ) + await session.commit() + def _distance_to_similarity(self, distance: float) -> float: """Convert L2 distance to cosine similarity for normalized embeddings. @@ -569,7 +646,12 @@ def _distance_to_similarity(self, distance: float) -> float: @asynccontextmanager async def _prepare_entity_write_scope(self): """SQLite keeps the shared read window, but funnels prepare writes through one lock.""" - async with self._sqlite_vec_lock: + # Trigger: the shared prepare window fans out per entity after batched reads. + # Why: SQLite still benefits from shared reads, but write transactions do + # not get meaningfully faster when we open many at once. + # Outcome: one entity at a time mutates chunk rows, while vec extension + # loading uses its own separate lock and cannot deadlock this path. + async with self._sqlite_prepare_write_lock: yield def _prepare_window_existing_rows_sql(self, placeholders: str) -> str: diff --git a/src/basic_memory/services/search_service.py b/src/basic_memory/services/search_service.py index 102d69be..b1aec10b 100644 --- a/src/basic_memory/services/search_service.py +++ b/src/basic_memory/services/search_service.py @@ -521,7 +521,9 @@ async def sync_entity_vectors_batch( chunks_total=sum(result.chunks_total for result in repository_results), chunks_skipped=sum(result.chunks_skipped for result in repository_results), embedding_jobs_total=sum(result.embedding_jobs_total for result in repository_results), - prepare_seconds_total=sum(result.prepare_seconds_total for result in repository_results), + prepare_seconds_total=sum( + result.prepare_seconds_total for result in repository_results + ), queue_wait_seconds_total=sum( result.queue_wait_seconds_total for result in repository_results ), @@ -530,11 +532,14 @@ async def sync_entity_vectors_batch( ) return batch_result - async def reindex_vectors(self, progress_callback=None) -> dict: + async def reindex_vectors(self, progress_callback=None, force_full: bool = False) -> dict: """Rebuild vector embeddings for all entities. Args: - progress_callback: Optional callable(entity_id, index, total) for progress reporting. + progress_callback: Optional callable(entity_id, completed, total) for progress + reporting when an entity reaches a terminal state in this run. + force_full: When True, clear this project's derived vectors first so every + eligible entity re-embeds from scratch. Returns: dict with stats: total_entities, embedded, skipped, errors @@ -545,6 +550,8 @@ async def reindex_vectors(self, progress_callback=None) -> dict: # Clean up stale rows in search_index and search_vector_chunks # that reference entity_ids no longer in the entity table await self._purge_stale_search_rows() + if force_full: + await self._clear_project_vectors_for_full_reindex() batch_result = await self.sync_entity_vectors_batch( entity_ids, @@ -562,6 +569,31 @@ async def reindex_vectors(self, progress_callback=None) -> dict: return stats + async def _clear_project_vectors_for_full_reindex(self) -> None: + """Remove this project's derived vectors so a full reindex re-embeds everything. + + Trigger: the operator asked for a full embedding rebuild rather than the + default incremental vector sync. + Why: the repository sync path intentionally skips unchanged entities, so + we need to clear the derived vector state first to force fresh embeddings. + Outcome: the next batch sync recreates every eligible entity's vectors. + """ + from basic_memory.repository.sqlite_search_repository import SQLiteSearchRepository + + project_id = self.repository.project_id + params = {"project_id": project_id} + + # Constraint: sqlite-vec stores embeddings in a separate rowid table with + # no cascade delete, so embeddings must be removed before chunk rows. + if isinstance(self.repository, SQLiteSearchRepository): + await self.repository.delete_project_vector_rows() + else: + await self.repository.execute_query( + text("DELETE FROM search_vector_chunks WHERE project_id = :project_id"), + params, + ) + logger.info("Cleared project vectors for full reindex", project_id=project_id) + async def _purge_stale_search_rows(self) -> None: """Remove rows from search_index and search_vector_chunks for deleted entities. @@ -588,24 +620,17 @@ async def _purge_stale_search_rows(self) -> None: # SQLite vec has no CASCADE — must delete embeddings before chunks if isinstance(self.repository, SQLiteSearchRepository): + await self.repository.delete_stale_vector_rows() + else: + # Postgres CASCADE handles embedding deletion automatically await self.repository.execute_query( text( - "DELETE FROM search_vector_embeddings WHERE rowid IN (" - "SELECT id FROM search_vector_chunks " - f"WHERE project_id = :project_id AND {stale_entity_filter})" + f"DELETE FROM search_vector_chunks " + f"WHERE project_id = :project_id AND {stale_entity_filter}" ), params, ) - # Postgres CASCADE handles embedding deletion automatically - await self.repository.execute_query( - text( - f"DELETE FROM search_vector_chunks " - f"WHERE project_id = :project_id AND {stale_entity_filter}" - ), - params, - ) - logger.info("Purged stale search rows for deleted entities", project_id=project_id) @staticmethod @@ -640,28 +665,24 @@ async def _clear_entity_vectors(self, entity_id: int) -> None: # Trigger: semantic indexing is disabled for this repository instance. # Why: repositories only create vector tables when semantic search is enabled. # Outcome: skip cleanup because there are no active derived vector rows to maintain. - if isinstance(self.repository, SearchRepositoryBase) and not self.repository._semantic_enabled: + if ( + isinstance(self.repository, SearchRepositoryBase) + and not self.repository._semantic_enabled + ): return params = {"project_id": self.repository.project_id, "entity_id": entity_id} if isinstance(self.repository, SQLiteSearchRepository): + await self.repository.delete_entity_vector_rows(entity_id) + else: await self.repository.execute_query( text( - "DELETE FROM search_vector_embeddings WHERE rowid IN (" - "SELECT id FROM search_vector_chunks " - "WHERE project_id = :project_id AND entity_id = :entity_id)" + "DELETE FROM search_vector_chunks " + "WHERE project_id = :project_id AND entity_id = :entity_id" ), params, ) - await self.repository.execute_query( - text( - "DELETE FROM search_vector_chunks " - "WHERE project_id = :project_id AND entity_id = :entity_id" - ), - params, - ) - async def index_entity_file( self, entity: Entity, diff --git a/src/basic_memory/sync/sync_service.py b/src/basic_memory/sync/sync_service.py index c912d549..0cfa7302 100644 --- a/src/basic_memory/sync/sync_service.py +++ b/src/basic_memory/sync/sync_service.py @@ -104,6 +104,7 @@ class SyncReport: deleted: Set[str] = field(default_factory=set) moves: Dict[str, str] = field(default_factory=dict) # old_path -> new_path checksums: Dict[str, str] = field(default_factory=dict) # path -> checksum + scanned_paths: Set[str] = field(default_factory=set) skipped_files: List[SkippedFile] = field(default_factory=list) @property @@ -292,6 +293,7 @@ async def sync( directory: Path, project_name: Optional[str] = None, force_full: bool = False, + sync_embeddings: bool = True, progress_callback: Callable[[IndexProgress], Awaitable[None]] | None = None, ) -> SyncReport: """Sync all files with database and update scan watermark. @@ -300,6 +302,7 @@ async def sync( directory: Directory to sync project_name: Optional project name force_full: If True, force a full scan bypassing watermark optimization + sync_embeddings: If True, generate vectors for entities indexed during this sync progress_callback: Optional callback for typed indexing progress updates """ @@ -348,7 +351,16 @@ async def sync( for path in report.deleted: await self.handle_delete(path) - changed_paths = sorted(report.new | report.modified) + # Trigger: the caller requested a full reindex pass through the sync path. + # Why: cloud-style "full" semantics should rebuild every current file-backed + # search row, not only the files that differ from the last watermark. + # Outcome: progress reflects the whole project and unchanged files are + # re-indexed without inflating the change report itself. + changed_paths = ( + sorted(report.scanned_paths) + if force_full + else sorted(report.new | report.modified) + ) indexed_entities, skipped_files = await self._index_changed_files( changed_paths, report.checksums, @@ -357,9 +369,12 @@ async def sync( report.skipped_files.extend(skipped_files) synced_entity_ids = [indexed.entity_id for indexed in indexed_entities] - # Only resolve relations if there were actual changes - # If no files changed, no new unresolved relations could have been created - if report.total > 0: + # Trigger: either the filesystem diff found changes, or the caller forced a + # full reindex and we just reprocessed the current files. + # Why: relation resolution should follow the file-processing work that just ran, + # not only the lightweight diff summary. + # Outcome: full reindex can heal relation state even when the diff report is empty. + if report.total > 0 or (force_full and indexed_entities): with telemetry.scope( "sync.project.resolve_relations", relation_scope="all_pending" ): @@ -369,7 +384,7 @@ async def sync( # Batch-generate vector embeddings for all synced entities synced_entity_ids = list(dict.fromkeys(synced_entity_ids)) - if synced_entity_ids and self.app_config.semantic_search_enabled: + if synced_entity_ids and sync_embeddings and self.app_config.semantic_search_enabled: try: with telemetry.scope( "sync.project.sync_embeddings", @@ -906,6 +921,7 @@ async def scan(self, directory, force_full: bool = False): # Store checksums for files that need syncing report.checksums = changed_checksums + report.scanned_paths = scanned_paths scan_duration_ms = int((time.time() - scan_start_time) * 1000) diff --git a/tests/cli/test_db_reindex.py b/tests/cli/test_db_reindex.py new file mode 100644 index 00000000..16cf838e --- /dev/null +++ b/tests/cli/test_db_reindex.py @@ -0,0 +1,322 @@ +"""Tests for `bm reindex` CLI wiring.""" + +import asyncio +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest +from typer.testing import CliRunner + +from basic_memory.cli.app import app +import basic_memory.cli.commands.db as db_cmd # noqa: F401 + + +runner = CliRunner() + + +def _stub_app_config(*, semantic_search_enabled: bool = True) -> SimpleNamespace: + """Build the minimal config surface the CLI reindex path expects.""" + return SimpleNamespace( + semantic_search_enabled=semantic_search_enabled, + database_path=Path("/tmp/basic-memory.db"), + get_project_mode=lambda project_name: None, + ) + + +def _configure_reindex_cli(monkeypatch, app_config: SimpleNamespace) -> None: + """Keep CLI tests focused on reindex wiring instead of full app startup.""" + monkeypatch.setattr("basic_memory.cli.app.init_cli_logging", lambda: None) + monkeypatch.setattr("basic_memory.cli.app.maybe_show_init_line", lambda *_args: None) + monkeypatch.setattr("basic_memory.cli.app.maybe_show_cloud_promo", lambda *_args: None) + monkeypatch.setattr("basic_memory.cli.app.maybe_run_periodic_auto_update", lambda *_args: None) + monkeypatch.setattr( + "basic_memory.cli.app.CliContainer.create", + lambda: SimpleNamespace(config=app_config, mode=SimpleNamespace(is_cloud=False)), + ) + monkeypatch.setattr( + db_cmd, + "ConfigManager", + lambda: SimpleNamespace(config=app_config), + ) + + +def test_reindex_defaults_to_incremental_search_and_embeddings(monkeypatch): + app_config = _stub_app_config() + _configure_reindex_cli(monkeypatch, app_config) + captured: dict[str, object] = {} + + async def _stub_reindex(app_config, *, search: bool, embeddings: bool, full: bool, project): + captured.update( + { + "app_config": app_config, + "search": search, + "embeddings": embeddings, + "full": full, + "project": project, + } + ) + + monkeypatch.setattr(db_cmd, "_reindex", _stub_reindex) + monkeypatch.setattr(db_cmd, "run_with_cleanup", lambda coro: asyncio.run(coro)) + + result = runner.invoke(app, ["reindex"]) + + assert result.exit_code == 0 + assert captured == { + "app_config": app_config, + "search": True, + "embeddings": True, + "full": False, + "project": None, + } + + +def test_reindex_full_runs_full_search_and_embeddings(monkeypatch): + app_config = _stub_app_config() + _configure_reindex_cli(monkeypatch, app_config) + captured: dict[str, object] = {} + + async def _stub_reindex(app_config, *, search: bool, embeddings: bool, full: bool, project): + captured.update( + { + "search": search, + "embeddings": embeddings, + "full": full, + "project": project, + } + ) + + monkeypatch.setattr(db_cmd, "_reindex", _stub_reindex) + monkeypatch.setattr(db_cmd, "run_with_cleanup", lambda coro: asyncio.run(coro)) + + result = runner.invoke(app, ["reindex", "--full"]) + + assert result.exit_code == 0 + assert captured == { + "search": True, + "embeddings": True, + "full": True, + "project": None, + } + + +def test_reindex_full_search_runs_search_only(monkeypatch): + app_config = _stub_app_config() + _configure_reindex_cli(monkeypatch, app_config) + captured: dict[str, object] = {} + + async def _stub_reindex(app_config, *, search: bool, embeddings: bool, full: bool, project): + captured.update( + { + "search": search, + "embeddings": embeddings, + "full": full, + "project": project, + } + ) + + monkeypatch.setattr(db_cmd, "_reindex", _stub_reindex) + monkeypatch.setattr(db_cmd, "run_with_cleanup", lambda coro: asyncio.run(coro)) + + result = runner.invoke(app, ["reindex", "--full", "--search"]) + + assert result.exit_code == 0 + assert captured == { + "search": True, + "embeddings": False, + "full": True, + "project": None, + } + + +def test_reindex_embeddings_only_preserves_incremental_default(monkeypatch): + app_config = _stub_app_config() + _configure_reindex_cli(monkeypatch, app_config) + captured: dict[str, object] = {} + + async def _stub_reindex(app_config, *, search: bool, embeddings: bool, full: bool, project): + captured.update( + { + "search": search, + "embeddings": embeddings, + "full": full, + "project": project, + } + ) + + monkeypatch.setattr(db_cmd, "_reindex", _stub_reindex) + monkeypatch.setattr(db_cmd, "run_with_cleanup", lambda coro: asyncio.run(coro)) + + result = runner.invoke(app, ["reindex", "--embeddings"]) + + assert result.exit_code == 0 + assert captured == { + "search": False, + "embeddings": True, + "full": False, + "project": None, + } + + +@pytest.mark.asyncio +async def test_reindex_project_full_passes_force_full_to_sync_and_reports_mode(monkeypatch): + app_config = _stub_app_config() + project = SimpleNamespace(id=1, name="foo", path="/tmp/foo") + session_maker = object() + sync_service = SimpleNamespace(sync=AsyncMock()) + printed_lines: list[str] = [] + + class StubProjectRepository: + def __init__(self, _session_maker): + self._session_maker = _session_maker + + async def get_active_projects(self): + return [project] + + class SilentProgress: + def __init__(self, *args, **kwargs): + self.tasks: dict[int, SimpleNamespace] = {} + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def add_task(self, description, total=1): + self.tasks[1] = SimpleNamespace(total=total, description=description) + return 1 + + def update(self, task_id, **kwargs): + if "total" in kwargs: + self.tasks[task_id].total = kwargs["total"] + + monkeypatch.setattr(db_cmd, "reconcile_projects_with_config", AsyncMock()) + monkeypatch.setattr( + db_cmd.db, + "get_or_create_db", + AsyncMock(return_value=(None, session_maker)), + ) + monkeypatch.setattr(db_cmd.db, "shutdown_db", AsyncMock()) + monkeypatch.setattr(db_cmd, "ProjectRepository", StubProjectRepository) + monkeypatch.setattr(db_cmd, "get_sync_service", AsyncMock(return_value=sync_service)) + monkeypatch.setattr(db_cmd, "Progress", SilentProgress) + monkeypatch.setattr( + db_cmd.console, + "print", + lambda message="", *args, **kwargs: printed_lines.append(str(message)), + ) + + await db_cmd._reindex( + app_config, + search=True, + embeddings=False, + full=True, + project="foo", + ) + + sync_service.sync.assert_awaited_once() + sync_call = sync_service.sync.await_args + assert sync_call.args[0] == Path("/tmp/foo") + assert sync_call.kwargs["project_name"] == "foo" + assert sync_call.kwargs["force_full"] is True + assert sync_call.kwargs["sync_embeddings"] is False + assert callable(sync_call.kwargs["progress_callback"]) + assert any("full scan" in line for line in printed_lines) + + +@pytest.mark.asyncio +async def test_reindex_embeddings_only_full_passes_force_full_to_vector_reindex(monkeypatch): + app_config = _stub_app_config() + project = SimpleNamespace(id=1, name="foo", path="/tmp/foo") + session_maker = object() + printed_lines: list[str] = [] + vector_reindex_calls: list[dict[str, object]] = [] + + class StubProjectRepository: + def __init__(self, _session_maker): + self._session_maker = _session_maker + + async def get_active_projects(self): + return [project] + + class StubSearchService: + def __init__(self, search_repository, entity_repository, file_service): + self.search_repository = search_repository + self.entity_repository = entity_repository + self.file_service = file_service + + async def reindex_vectors(self, *, progress_callback=None, force_full: bool = False): + vector_reindex_calls.append( + { + "progress_callback": progress_callback, + "force_full": force_full, + } + ) + return {"total_entities": 2, "embedded": 2, "skipped": 0, "errors": 0} + + class SilentProgress: + def __init__(self, *args, **kwargs): + self.tasks: dict[int, SimpleNamespace] = {} + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def add_task(self, description, total=None): + self.tasks[1] = SimpleNamespace(total=total, description=description) + return 1 + + def update(self, task_id, **kwargs): + if "total" in kwargs: + self.tasks[task_id].total = kwargs["total"] + + monkeypatch.setattr(db_cmd, "reconcile_projects_with_config", AsyncMock()) + monkeypatch.setattr( + db_cmd.db, + "get_or_create_db", + AsyncMock(return_value=(None, session_maker)), + ) + monkeypatch.setattr(db_cmd.db, "shutdown_db", AsyncMock()) + monkeypatch.setattr(db_cmd, "ProjectRepository", StubProjectRepository) + monkeypatch.setattr( + "basic_memory.repository.search_repository.create_search_repository", + lambda *args, **kwargs: object(), + ) + monkeypatch.setattr( + "basic_memory.repository.EntityRepository", lambda *args, **kwargs: object() + ) + monkeypatch.setattr( + "basic_memory.markdown.entity_parser.EntityParser", + lambda *args, **kwargs: object(), + ) + monkeypatch.setattr( + "basic_memory.markdown.markdown_processor.MarkdownProcessor", + lambda *args, **kwargs: object(), + ) + monkeypatch.setattr( + "basic_memory.services.file_service.FileService", lambda *args, **kwargs: object() + ) + monkeypatch.setattr("basic_memory.services.search_service.SearchService", StubSearchService) + monkeypatch.setattr(db_cmd, "Progress", SilentProgress) + monkeypatch.setattr( + db_cmd.console, + "print", + lambda message="", *args, **kwargs: printed_lines.append(str(message)), + ) + + await db_cmd._reindex( + app_config, + search=False, + embeddings=True, + full=True, + project="foo", + ) + + assert len(vector_reindex_calls) == 1 + assert vector_reindex_calls[0]["force_full"] is True + assert callable(vector_reindex_calls[0]["progress_callback"]) + assert any("full rebuild" in line for line in printed_lines) diff --git a/tests/repository/test_semantic_search_base.py b/tests/repository/test_semantic_search_base.py index 7d8a60d2..acc06845 100644 --- a/tests/repository/test_semantic_search_base.py +++ b/tests/repository/test_semantic_search_base.py @@ -378,6 +378,50 @@ async def _stub_prepare_window(entity_ids: list[int]): assert result.queue_wait_seconds_total == pytest.approx(0.0) +@pytest.mark.asyncio +async def test_sync_entity_vectors_batch_progress_tracks_terminal_entities(monkeypatch): + """Progress callback should advance on terminal entity completion, not prepare entry.""" + repo = _ConcreteRepo() + repo._semantic_enabled = True + repo._embedding_provider = object() + repo._semantic_embedding_sync_batch_size = 2 + + prepared_by_entity = { + 1: _PreparedEntityVectorSync(1, 1.0, 1, []), + 2: _PreparedEntityVectorSync(2, 2.0, 1, [(102, "chunk-2")]), + 3: _PreparedEntityVectorSync(3, 3.0, 1, [(103, "chunk-3")]), + } + progress_events: list[tuple[int, int, int]] = [] + + async def _stub_prepare_window(entity_ids: list[int]): + return [prepared_by_entity[entity_id] for entity_id in entity_ids] + + async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): + for job in flush_jobs: + runtime = entity_runtime[job.entity_id] + runtime.remaining_jobs -= 1 + if runtime.remaining_jobs <= 0: + synced_entity_ids.add(job.entity_id) + return (0.1, 0.2) + + monkeypatch.setattr(repo, "_prepare_entity_vector_jobs_window", _stub_prepare_window) + monkeypatch.setattr(repo, "_flush_embedding_jobs", _stub_flush) + + result = await repo.sync_entity_vectors_batch( + [1, 2, 3], + progress_callback=lambda entity_id, completed, total: progress_events.append( + (entity_id, completed, total) + ), + ) + + assert result.entities_synced == 3 + assert progress_events == [ + (1, 1, 3), + (2, 2, 3), + (3, 3, 3), + ] + + @pytest.mark.asyncio async def test_sync_entity_vectors_batch_continue_on_error(monkeypatch): """Batch sync should continue after per-entity and per-flush failures.""" diff --git a/tests/repository/test_sqlite_vector_search_repository.py b/tests/repository/test_sqlite_vector_search_repository.py index 12ef6e01..ea25675c 100644 --- a/tests/repository/test_sqlite_vector_search_repository.py +++ b/tests/repository/test_sqlite_vector_search_repository.py @@ -319,7 +319,7 @@ def _stub_build_chunk_records(source_rows): @asynccontextmanager async def _track_write_scope(): nonlocal active_write_scopes, max_active_write_scopes - async with repo._sqlite_vec_lock: + async with repo._sqlite_prepare_write_lock: active_write_scopes += 1 max_active_write_scopes = max(max_active_write_scopes, active_write_scopes) try: @@ -361,6 +361,68 @@ async def fake_scoped_session(session_maker): assert max_active_write_scopes == 1 +@pytest.mark.asyncio +async def test_sqlite_prepare_window_does_not_deadlock_when_vec_loading_inside_write_scope( + monkeypatch, +): + """SQLite should keep vec loading and prepare writes on separate locks.""" + repo = _make_sqlite_repo_for_unit_tests() + + async def _stub_fetch_source_rows(session, entity_ids: list[int]): + return {entity_id: [object()] for entity_id in entity_ids} + + async def _stub_fetch_existing_rows(session, entity_ids: list[int]): + return {entity_id: [] for entity_id in entity_ids} + + def _stub_build_chunk_records(source_rows): + return [ + { + "chunk_key": "entity:1:0", + "chunk_text": "chunk text", + "source_hash": "hash", + } + ] + + async def _stub_prepare_vector_session(session): + # Trigger: SQLite prepare writes call _prepare_vector_session() after + # entering the write scope. + # Why: vec loading still needs a lock, but reusing the write lock here + # would deadlock before the first entity completes. + # Outcome: this regression test proves the two concerns stay separate. + async with repo._sqlite_vec_load_lock: + await asyncio.sleep(0) + + async def _stub_upsert( + session, + *, + entity_id: int, + scheduled_records, + existing_by_key, + entity_fingerprint: str, + embedding_model: str, + ): + return [(entity_id * 100, scheduled_records[0]["chunk_text"])] + + @asynccontextmanager + async def fake_scoped_session(session_maker): + yield AsyncMock() + + monkeypatch.setattr( + "basic_memory.repository.search_repository_base.db.scoped_session", + fake_scoped_session, + ) + monkeypatch.setattr(repo, "_fetch_prepare_window_source_rows", _stub_fetch_source_rows) + monkeypatch.setattr(repo, "_fetch_prepare_window_existing_rows", _stub_fetch_existing_rows) + monkeypatch.setattr(repo, "_build_chunk_records", _stub_build_chunk_records) + monkeypatch.setattr(repo, "_prepare_vector_session", _stub_prepare_vector_session) + monkeypatch.setattr(repo, "_upsert_scheduled_chunk_records", _stub_upsert) + + prepared = await asyncio.wait_for(repo._prepare_entity_vector_jobs_window([1]), timeout=1.0) + + assert len(prepared) == 1 + assert prepared[0].entity_id == 1 + + @pytest.mark.asyncio async def test_sqlite_vector_search_returns_ranked_entities(search_repository): """Vector mode ranks entities using sqlite-vec nearest-neighbor search.""" diff --git a/tests/services/test_search_service.py b/tests/services/test_search_service.py index a31eadaf..b4040cce 100644 --- a/tests/services/test_search_service.py +++ b/tests/services/test_search_service.py @@ -1203,7 +1203,7 @@ async def _stub_sync_entity_vectors_batch(entity_ids: list[int], progress_callba assert entity_ids == created_entity_ids if progress_callback: for i, entity_id in enumerate(entity_ids): - progress_callback(entity_id, i, len(entity_ids)) + progress_callback(entity_id, i + 1, len(entity_ids)) return VectorSyncBatchResult( entities_total=len(entity_ids), entities_synced=len(entity_ids), @@ -1236,7 +1236,7 @@ def on_progress(entity_id, index, total): assert len(progress_calls) == stats["total_entities"] # Progress indices should be sequential for i, (_, index, total) in enumerate(progress_calls): - assert index == i + assert index == i + 1 assert total == stats["total_entities"] diff --git a/tests/services/test_semantic_search.py b/tests/services/test_semantic_search.py index 19a1f949..7bbf6b2e 100644 --- a/tests/services/test_semantic_search.py +++ b/tests/services/test_semantic_search.py @@ -111,20 +111,18 @@ async def test_semantic_vector_sync_skips_embed_opt_out_and_clears_vectors( AsyncMock(return_value=SimpleNamespace(id=42, entity_metadata={"embed": False})), ) sync_vectors = AsyncMock() - execute_query = AsyncMock() + delete_entity_vectors = AsyncMock() monkeypatch.setattr(repository, "sync_entity_vectors", sync_vectors) - monkeypatch.setattr(repository, "execute_query", execute_query) + monkeypatch.setattr(repository, "delete_entity_vector_rows", delete_entity_vectors) await search_service.sync_entity_vectors(42) sync_vectors.assert_not_awaited() - assert execute_query.await_count == 2 + delete_entity_vectors.assert_awaited_once_with(42) @pytest.mark.asyncio -async def test_semantic_vector_sync_resumes_when_embed_opt_out_removed( - search_service, monkeypatch -): +async def test_semantic_vector_sync_resumes_when_embed_opt_out_removed(search_service, monkeypatch): """Removing the opt-out should restore normal embedding sync.""" repository = _sqlite_repo(search_service) repository._semantic_enabled = True @@ -170,9 +168,9 @@ async def test_semantic_vector_sync_batch_skips_embed_opt_out_and_reports_skips( entities_failed=0, ) ) - execute_query = AsyncMock() + delete_entity_vectors = AsyncMock() monkeypatch.setattr(repository, "sync_entity_vectors_batch", sync_batch) - monkeypatch.setattr(repository, "execute_query", execute_query) + monkeypatch.setattr(repository, "delete_entity_vector_rows", delete_entity_vectors) result = await search_service.sync_entity_vectors_batch([41, 42]) @@ -181,7 +179,7 @@ async def test_semantic_vector_sync_batch_skips_embed_opt_out_and_reports_skips( assert result.entities_total == 2 assert result.entities_synced == 1 assert result.entities_skipped == 1 - assert execute_query.await_count == 2 + delete_entity_vectors.assert_awaited_once_with(41) @pytest.mark.asyncio @@ -256,6 +254,50 @@ async def test_reindex_vectors_respects_embed_opt_out(search_service, monkeypatc } +@pytest.mark.asyncio +async def test_reindex_vectors_force_full_clears_project_vectors_before_resync( + search_service, monkeypatch +): + """Force-full vector reindex should clear derived vectors before batch sync.""" + repository = _sqlite_repo(search_service) + repository._semantic_enabled = True + + monkeypatch.setattr( + search_service.entity_repository, + "find_all", + AsyncMock( + return_value=[ + SimpleNamespace(id=41, entity_metadata={}), + SimpleNamespace(id=42, entity_metadata={}), + ] + ), + ) + purge_stale_rows = AsyncMock() + delete_project_vectors = AsyncMock() + sync_batch = AsyncMock( + return_value=VectorSyncBatchResult( + entities_total=2, + entities_synced=2, + entities_failed=0, + ) + ) + monkeypatch.setattr(search_service, "_purge_stale_search_rows", purge_stale_rows) + monkeypatch.setattr(repository, "delete_project_vector_rows", delete_project_vectors) + monkeypatch.setattr(search_service, "sync_entity_vectors_batch", sync_batch) + + stats = await search_service.reindex_vectors(force_full=True) + + purge_stale_rows.assert_awaited_once() + delete_project_vectors.assert_awaited_once() + sync_batch.assert_awaited_once_with([41, 42], progress_callback=None) + assert stats == { + "total_entities": 2, + "embedded": 2, + "skipped": 0, + "errors": 0, + } + + @pytest.mark.asyncio async def test_semantic_vector_sync_batch_cleans_up_unknown_ids(search_service, monkeypatch): """Deleted entity IDs should still flow through repository cleanup instead of being dropped.""" @@ -291,7 +333,9 @@ async def test_semantic_vector_sync_batch_cleans_up_unknown_ids(search_service, called_entity_ids = {tuple(call.args[0]) for call in sync_batch.await_args_list} assert called_entity_ids == {(41,), (42,)} progress_callback_calls = [ - call for call in sync_batch.await_args_list if call.kwargs.get("progress_callback") is not None + call + for call in sync_batch.await_args_list + if call.kwargs.get("progress_callback") is not None ] assert len(progress_callback_calls) == 1 assert progress_callback_calls[0].args[0] == [42] diff --git a/tests/sync/test_sync_service_incremental.py b/tests/sync/test_sync_service_incremental.py index 854f58e3..beca0db4 100644 --- a/tests/sync/test_sync_service_incremental.py +++ b/tests/sync/test_sync_service_incremental.py @@ -18,6 +18,7 @@ import pytest from basic_memory.config import ProjectConfig +from basic_memory.indexing.models import IndexingBatchResult from basic_memory.sync.sync_service import SyncService @@ -208,6 +209,32 @@ async def test_force_full_bypasses_watermark_optimization( assert project.last_scan_timestamp > initial_timestamp +@pytest.mark.asyncio +async def test_force_full_reindexes_unchanged_files( + sync_service: SyncService, project_config: ProjectConfig, monkeypatch +): + """Test that force_full rewrites search rows even when the diff report is empty.""" + project_dir = project_config.home + await create_test_file(project_dir / "file1.md", "# File 1\nOriginal") + + # First sync establishes the watermark and initial search rows. + await sync_service.sync(project_dir) + await sleep_past_watermark() + + indexed_batches: list[list[str]] = [] + + async def _stub_index_files(loaded_files, **kwargs): + indexed_batches.append(sorted(loaded_files)) + return IndexingBatchResult() + + monkeypatch.setattr(sync_service.batch_indexer, "index_files", _stub_index_files) + + report = await sync_service.sync(project_dir, force_full=True, sync_embeddings=False) + + assert report.total == 0 + assert indexed_batches == [["file1.md"]] + + # ============================================================================== # Incremental Scan Base Cases # ============================================================================== From 36ca070c2e2ef65863fae448738c339a32747aa3 Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 14:26:48 -0500 Subject: [PATCH 03/14] wip: tune fastembed defaults Signed-off-by: phernandez --- .../repository/embedding_provider_factory.py | 51 ++++++++++-- .../repository/search_repository_base.py | 46 +++++++++++ tests/repository/test_openai_provider.py | 77 +++++++++++++++++++ tests/repository/test_semantic_search_base.py | 50 ++++++++++++ 4 files changed, 218 insertions(+), 6 deletions(-) diff --git a/src/basic_memory/repository/embedding_provider_factory.py b/src/basic_memory/repository/embedding_provider_factory.py index e259c62b..2447733c 100644 --- a/src/basic_memory/repository/embedding_provider_factory.py +++ b/src/basic_memory/repository/embedding_provider_factory.py @@ -1,5 +1,6 @@ """Factory for creating configured semantic embedding providers.""" +import os from threading import Lock from basic_memory.config import BasicMemoryConfig @@ -20,8 +21,45 @@ _EMBEDDING_PROVIDER_CACHE_LOCK = Lock() +def _available_cpu_count() -> int | None: + """Return the CPU budget available to this process when the runtime exposes it.""" + process_cpu_count = getattr(os, "process_cpu_count", None) + if callable(process_cpu_count): + cpu_count = process_cpu_count() + if cpu_count is not None and cpu_count > 0: + return cpu_count + + cpu_count = os.cpu_count() + return cpu_count if cpu_count is not None and cpu_count > 0 else None + + +def _resolve_fastembed_runtime_knobs(app_config: BasicMemoryConfig) -> tuple[int | None, int | None]: + """Resolve FastEmbed threads/parallel from explicit config or CPU-aware defaults.""" + configured_threads = app_config.semantic_embedding_threads + configured_parallel = app_config.semantic_embedding_parallel + if configured_threads is not None or configured_parallel is not None: + return configured_threads, configured_parallel + + available_cpus = _available_cpu_count() + if available_cpus is None: + return None, None + + # Trigger: local laptops and cloud workers expose different CPU budgets. + # Why: FastEmbed throughput wants enough ONNX threads to use the machine, + # but the multiprocessing-style ``parallel`` fan-out can add a lot of + # overhead for this workload and make full rebuilds slower instead of faster. + # Outcome: when config leaves the knobs unset, each process uses a bounded + # thread count and keeps FastEmbed on the simpler single-process path. + if available_cpus <= 2: + return available_cpus, 1 + + threads = min(8, available_cpus) + return threads, 1 + + def _provider_cache_key(app_config: BasicMemoryConfig) -> ProviderCacheKey: """Build a stable cache key from provider-relevant semantic embedding config.""" + resolved_threads, resolved_parallel = _resolve_fastembed_runtime_knobs(app_config) return ( app_config.semantic_embedding_provider.strip().lower(), app_config.semantic_embedding_model, @@ -29,8 +67,8 @@ def _provider_cache_key(app_config: BasicMemoryConfig) -> ProviderCacheKey: app_config.semantic_embedding_batch_size, app_config.semantic_embedding_request_concurrency, app_config.semantic_embedding_cache_dir, - app_config.semantic_embedding_threads, - app_config.semantic_embedding_parallel, + resolved_threads, + resolved_parallel, ) @@ -61,12 +99,13 @@ def create_embedding_provider(app_config: BasicMemoryConfig) -> EmbeddingProvide # Deferred import: fastembed (and its onnxruntime dep) may not be installed from basic_memory.repository.fastembed_provider import FastEmbedEmbeddingProvider + resolved_threads, resolved_parallel = _resolve_fastembed_runtime_knobs(app_config) if app_config.semantic_embedding_cache_dir is not None: extra_kwargs["cache_dir"] = app_config.semantic_embedding_cache_dir - if app_config.semantic_embedding_threads is not None: - extra_kwargs["threads"] = app_config.semantic_embedding_threads - if app_config.semantic_embedding_parallel is not None: - extra_kwargs["parallel"] = app_config.semantic_embedding_parallel + if resolved_threads is not None: + extra_kwargs["threads"] = resolved_threads + if resolved_parallel is not None: + extra_kwargs["parallel"] = resolved_parallel provider = FastEmbedEmbeddingProvider( model_name=app_config.semantic_embedding_model, diff --git a/src/basic_memory/repository/search_repository_base.py b/src/basic_memory/repository/search_repository_base.py index c18a3238..827c5414 100644 --- a/src/basic_memory/repository/search_repository_base.py +++ b/src/basic_memory/repository/search_repository_base.py @@ -800,6 +800,7 @@ async def _sync_entity_vectors_internal( batch_start = time.perf_counter() backend_name = type(self).__name__.removesuffix("SearchRepository").lower() + self._log_vector_sync_runtime_settings(backend_name=backend_name, entities_total=total_entities) logger.info( "Vector batch sync start: project_id={project_id} entities_total={entities_total} " "sync_batch_size={sync_batch_size} prepare_window_size={prepare_window_size}", @@ -1595,6 +1596,51 @@ def _finalize_completed_entity_syncs( return queue_wait_seconds_total + def _log_vector_sync_runtime_settings(self, *, backend_name: str, entities_total: int) -> None: + """Log the resolved embedding runtime knobs before the first prepare window. + + Trigger: a vector sync batch is about to start real work. + Why: operators need one place to confirm the provider/runtime settings that + this run will actually use, especially when threads/parallel are auto-tuned. + Outcome: the log shows the resolved values once per batch without changing + the hot-path control flow or adding more telemetry structure. + """ + assert self._embedding_provider is not None + + from basic_memory.repository.fastembed_provider import FastEmbedEmbeddingProvider + + provider = self._embedding_provider + if isinstance(provider, FastEmbedEmbeddingProvider): + logger.info( + "Vector batch runtime settings: project_id={project_id} backend={backend} " + "entities_total={entities_total} provider={provider} model_name={model_name} " + "dimensions={dimensions} provider_batch_size={provider_batch_size} " + "sync_batch_size={sync_batch_size} threads={threads} " + "configured_parallel={configured_parallel} effective_parallel={effective_parallel}", + project_id=self.project_id, + backend=backend_name, + entities_total=entities_total, + provider=type(provider).__name__, + model_name=provider.model_name, + dimensions=provider.dimensions, + provider_batch_size=provider.batch_size, + sync_batch_size=self._semantic_embedding_sync_batch_size, + threads=provider.threads, + configured_parallel=provider.parallel, + effective_parallel=provider._effective_parallel(), + ) + return + + logger.info( + "Vector batch runtime settings: project_id={project_id} backend={backend} " + "entities_total={entities_total} provider={provider} sync_batch_size={sync_batch_size}", + project_id=self.project_id, + backend=backend_name, + entities_total=entities_total, + provider=type(provider).__name__, + sync_batch_size=self._semantic_embedding_sync_batch_size, + ) + def _log_vector_sync_complete( self, *, diff --git a/tests/repository/test_openai_provider.py b/tests/repository/test_openai_provider.py index e9db6bbe..456d320b 100644 --- a/tests/repository/test_openai_provider.py +++ b/tests/repository/test_openai_provider.py @@ -8,6 +8,7 @@ import pytest from basic_memory.config import BasicMemoryConfig +import basic_memory.repository.embedding_provider_factory as embedding_provider_factory_module from basic_memory.repository.embedding_provider_factory import ( create_embedding_provider, reset_embedding_provider_cache, @@ -264,6 +265,52 @@ def test_embedding_provider_factory_forwards_fastembed_runtime_knobs(): assert provider.parallel == 2 +def test_embedding_provider_factory_auto_tunes_fastembed_runtime_knobs_from_cpu_budget(monkeypatch): + """Unset FastEmbed runtime knobs should resolve from available CPU budget.""" + monkeypatch.setattr(embedding_provider_factory_module.os, "process_cpu_count", lambda: 8) + monkeypatch.setattr(embedding_provider_factory_module.os, "cpu_count", lambda: 8) + + config = BasicMemoryConfig( + env="test", + projects={"test-project": "/tmp/basic-memory-test"}, + default_project="test-project", + semantic_search_enabled=True, + semantic_embedding_provider="fastembed", + semantic_embedding_threads=None, + semantic_embedding_parallel=None, + ) + + provider = create_embedding_provider(config) + + assert isinstance(provider, FastEmbedEmbeddingProvider) + assert provider.threads == 8 + assert provider.parallel == 1 + + +def test_embedding_provider_factory_auto_tuning_stays_conservative_on_small_cpu_budget( + monkeypatch, +): + """Small workers should not get an oversized FastEmbed runtime footprint.""" + monkeypatch.setattr(embedding_provider_factory_module.os, "process_cpu_count", lambda: 2) + monkeypatch.setattr(embedding_provider_factory_module.os, "cpu_count", lambda: 2) + + config = BasicMemoryConfig( + env="test", + projects={"test-project": "/tmp/basic-memory-test"}, + default_project="test-project", + semantic_search_enabled=True, + semantic_embedding_provider="fastembed", + semantic_embedding_threads=None, + semantic_embedding_parallel=None, + ) + + provider = create_embedding_provider(config) + + assert isinstance(provider, FastEmbedEmbeddingProvider) + assert provider.threads == 2 + assert provider.parallel == 1 + + def test_embedding_provider_factory_reuses_provider_for_same_cache_key(): """Factory should reuse the same provider instance for identical config values.""" config_a = BasicMemoryConfig( @@ -289,6 +336,36 @@ def test_embedding_provider_factory_reuses_provider_for_same_cache_key(): assert provider_a is provider_b +def test_embedding_provider_factory_reuses_auto_tuned_provider_for_same_cpu_budget(monkeypatch): + """Auto-tuned FastEmbed providers should still reuse the process cache.""" + monkeypatch.setattr(embedding_provider_factory_module.os, "process_cpu_count", lambda: 8) + monkeypatch.setattr(embedding_provider_factory_module.os, "cpu_count", lambda: 8) + + config_a = BasicMemoryConfig( + env="test", + projects={"test-project": "/tmp/basic-memory-test"}, + default_project="test-project", + semantic_search_enabled=True, + semantic_embedding_provider="fastembed", + semantic_embedding_threads=None, + semantic_embedding_parallel=None, + ) + config_b = BasicMemoryConfig( + env="test", + projects={"test-project": "/tmp/basic-memory-test"}, + default_project="test-project", + semantic_search_enabled=True, + semantic_embedding_provider="fastembed", + semantic_embedding_threads=None, + semantic_embedding_parallel=None, + ) + + provider_a = create_embedding_provider(config_a) + provider_b = create_embedding_provider(config_b) + + assert provider_a is provider_b + + @pytest.mark.asyncio async def test_openai_provider_runs_batches_concurrently_and_preserves_output_order(monkeypatch): """Concurrent request fan-out should keep batch order stable.""" diff --git a/tests/repository/test_semantic_search_base.py b/tests/repository/test_semantic_search_base.py index acc06845..f42a32f0 100644 --- a/tests/repository/test_semantic_search_base.py +++ b/tests/repository/test_semantic_search_base.py @@ -12,6 +12,7 @@ import pytest import basic_memory.repository.search_repository_base as search_repository_base_module +from basic_memory.repository.fastembed_provider import FastEmbedEmbeddingProvider from basic_memory.repository.search_repository_base import ( MAX_VECTOR_CHUNK_CHARS, SearchRepositoryBase, @@ -702,3 +703,52 @@ async def _stub_flush(flush_jobs, entity_runtime, synced_entity_ids): assert histogram_names.count("vector_sync_write_seconds") == 2 assert histogram_names.count("vector_sync_batch_total_seconds") == 1 assert [name for name, _, _ in counter_calls].count("vector_sync_entities_total") == 1 + + +@pytest.mark.asyncio +async def test_sync_entity_vectors_batch_logs_resolved_fastembed_runtime_settings(monkeypatch): + """Batch start should log the resolved FastEmbed knobs that shape this run.""" + repo = _ConcreteRepo() + repo._semantic_enabled = True + repo._embedding_provider = FastEmbedEmbeddingProvider( + batch_size=128, + dimensions=384, + threads=4, + parallel=2, + ) + + async def _stub_prepare_window(entity_ids: list[int]): + return [ + _PreparedEntityVectorSync( + entity_id=entity_id, + sync_start=0.0, + source_rows_count=1, + embedding_jobs=[], + entity_skipped=True, + ) + for entity_id in entity_ids + ] + + info_calls: list[tuple[str, dict]] = [] + + def _capture_info(message: str, **kwargs): + info_calls.append((message, kwargs)) + + monkeypatch.setattr(repo, "_prepare_entity_vector_jobs_window", _stub_prepare_window) + monkeypatch.setattr(search_repository_base_module.logger, "info", _capture_info) + + result = await repo.sync_entity_vectors_batch([1]) + + assert result.entities_synced == 1 + runtime_logs = [ + kwargs + for message, kwargs in info_calls + if message.startswith("Vector batch runtime settings:") + ] + assert len(runtime_logs) == 1 + assert runtime_logs[0]["model_name"] == "bge-small-en-v1.5" + assert runtime_logs[0]["provider_batch_size"] == 128 + assert runtime_logs[0]["sync_batch_size"] == 64 + assert runtime_logs[0]["threads"] == 4 + assert runtime_logs[0]["configured_parallel"] == 2 + assert runtime_logs[0]["effective_parallel"] == 2 From a68933aec5a2b1aa615018a1a64a1e68b08de3af Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 15:13:55 -0500 Subject: [PATCH 04/14] perf(core): tune fastembed auto threads Signed-off-by: phernandez --- src/basic_memory/cli/commands/db.py | 4 ++-- .../repository/embedding_provider_factory.py | 18 ++++++++------- tests/repository/test_openai_provider.py | 22 +++++++++++++++++++ 3 files changed, 34 insertions(+), 10 deletions(-) diff --git a/src/basic_memory/cli/commands/db.py b/src/basic_memory/cli/commands/db.py index 07f40f5c..fea4f88d 100644 --- a/src/basic_memory/cli/commands/db.py +++ b/src/basic_memory/cli/commands/db.py @@ -273,7 +273,7 @@ async def on_index_progress(update: IndexProgress) -> None: ) progress.update(task, completed=progress.tasks[task].total or 1) - console.print(" [green]✓[/green] Full-text search index rebuilt") + console.print(" [green]done[/green] Full-text search index rebuilt") if embeddings: embedding_mode_label = "full rebuild" if full else "incremental sync" @@ -322,7 +322,7 @@ def on_progress(entity_id, index, total): progress.update(task, completed=stats["total_entities"]) console.print( - f" [green]✓[/green] Embeddings complete: " + f" [green]done[/green] Embeddings complete: " f"{stats['embedded']} entities embedded, " f"{stats['skipped']} skipped, " f"{stats['errors']} errors" diff --git a/src/basic_memory/repository/embedding_provider_factory.py b/src/basic_memory/repository/embedding_provider_factory.py index 2447733c..7b4988c4 100644 --- a/src/basic_memory/repository/embedding_provider_factory.py +++ b/src/basic_memory/repository/embedding_provider_factory.py @@ -26,14 +26,16 @@ def _available_cpu_count() -> int | None: process_cpu_count = getattr(os, "process_cpu_count", None) if callable(process_cpu_count): cpu_count = process_cpu_count() - if cpu_count is not None and cpu_count > 0: + if isinstance(cpu_count, int) and cpu_count > 0: return cpu_count cpu_count = os.cpu_count() return cpu_count if cpu_count is not None and cpu_count > 0 else None -def _resolve_fastembed_runtime_knobs(app_config: BasicMemoryConfig) -> tuple[int | None, int | None]: +def _resolve_fastembed_runtime_knobs( + app_config: BasicMemoryConfig, +) -> tuple[int | None, int | None]: """Resolve FastEmbed threads/parallel from explicit config or CPU-aware defaults.""" configured_threads = app_config.semantic_embedding_threads configured_parallel = app_config.semantic_embedding_parallel @@ -45,15 +47,15 @@ def _resolve_fastembed_runtime_knobs(app_config: BasicMemoryConfig) -> tuple[int return None, None # Trigger: local laptops and cloud workers expose different CPU budgets. - # Why: FastEmbed throughput wants enough ONNX threads to use the machine, - # but the multiprocessing-style ``parallel`` fan-out can add a lot of - # overhead for this workload and make full rebuilds slower instead of faster. - # Outcome: when config leaves the knobs unset, each process uses a bounded - # thread count and keeps FastEmbed on the simpler single-process path. + # Why: full rebuilds got faster when FastEmbed used most, but not all, of + # the available CPUs. Leaving a little headroom avoids starving the rest of + # the pipeline while still giving ONNX enough threads to stay busy. + # Outcome: when config leaves the knobs unset, each process reserves a small + # CPU cushion and keeps FastEmbed on the simpler single-process path. if available_cpus <= 2: return available_cpus, 1 - threads = min(8, available_cpus) + threads = min(8, max(2, available_cpus - 2)) return threads, 1 diff --git a/tests/repository/test_openai_provider.py b/tests/repository/test_openai_provider.py index 456d320b..7152ff0d 100644 --- a/tests/repository/test_openai_provider.py +++ b/tests/repository/test_openai_provider.py @@ -282,6 +282,28 @@ def test_embedding_provider_factory_auto_tunes_fastembed_runtime_knobs_from_cpu_ provider = create_embedding_provider(config) + assert isinstance(provider, FastEmbedEmbeddingProvider) + assert provider.threads == 6 + assert provider.parallel == 1 + + +def test_embedding_provider_factory_auto_tuning_caps_large_cpu_budgets(monkeypatch): + """Large workers should still leave some headroom and stop at the thread cap.""" + monkeypatch.setattr(embedding_provider_factory_module.os, "process_cpu_count", lambda: 16) + monkeypatch.setattr(embedding_provider_factory_module.os, "cpu_count", lambda: 16) + + config = BasicMemoryConfig( + env="test", + projects={"test-project": "/tmp/basic-memory-test"}, + default_project="test-project", + semantic_search_enabled=True, + semantic_embedding_provider="fastembed", + semantic_embedding_threads=None, + semantic_embedding_parallel=None, + ) + + provider = create_embedding_provider(config) + assert isinstance(provider, FastEmbedEmbeddingProvider) assert provider.threads == 8 assert provider.parallel == 1 From d59b974912797fe4de2624a1cf03f454d8b17695 Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 15:46:09 -0500 Subject: [PATCH 05/14] refactor(core): address review nits Signed-off-by: phernandez --- .../repository/embedding_provider.py | 6 +- .../repository/embedding_provider_factory.py | 3 +- .../repository/fastembed_provider.py | 9 + .../repository/openai_provider.py | 7 + .../repository/search_repository_base.py | 244 +++++++++--------- tests/repository/test_openai_provider.py | 22 ++ 6 files changed, 168 insertions(+), 123 deletions(-) diff --git a/src/basic_memory/repository/embedding_provider.py b/src/basic_memory/repository/embedding_provider.py index 4392acb9..0d8640e7 100644 --- a/src/basic_memory/repository/embedding_provider.py +++ b/src/basic_memory/repository/embedding_provider.py @@ -1,6 +1,6 @@ """Embedding provider protocol for pluggable semantic backends.""" -from typing import Protocol +from typing import Any, Protocol class EmbeddingProvider(Protocol): @@ -16,3 +16,7 @@ async def embed_query(self, text: str) -> list[float]: async def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed a list of document chunks.""" ... + + def runtime_log_attrs(self) -> dict[str, Any]: + """Return provider-specific runtime settings suitable for startup logs.""" + ... diff --git a/src/basic_memory/repository/embedding_provider_factory.py b/src/basic_memory/repository/embedding_provider_factory.py index 7b4988c4..a0e7dbcb 100644 --- a/src/basic_memory/repository/embedding_provider_factory.py +++ b/src/basic_memory/repository/embedding_provider_factory.py @@ -19,6 +19,7 @@ _EMBEDDING_PROVIDER_CACHE: dict[ProviderCacheKey, EmbeddingProvider] = {} _EMBEDDING_PROVIDER_CACHE_LOCK = Lock() +_FASTEMBED_MAX_THREADS = 8 def _available_cpu_count() -> int | None: @@ -55,7 +56,7 @@ def _resolve_fastembed_runtime_knobs( if available_cpus <= 2: return available_cpus, 1 - threads = min(8, max(2, available_cpus - 2)) + threads = min(_FASTEMBED_MAX_THREADS, max(2, available_cpus - 2)) return threads, 1 diff --git a/src/basic_memory/repository/fastembed_provider.py b/src/basic_memory/repository/fastembed_provider.py index 5dc90898..e3635148 100644 --- a/src/basic_memory/repository/fastembed_provider.py +++ b/src/basic_memory/repository/fastembed_provider.py @@ -24,6 +24,15 @@ class FastEmbedEmbeddingProvider(EmbeddingProvider): def _effective_parallel(self) -> int | None: return self.parallel if self.parallel is not None and self.parallel > 1 else None + def runtime_log_attrs(self) -> dict[str, int | str | None]: + """Return the resolved runtime knobs that shape FastEmbed throughput.""" + return { + "provider_batch_size": self.batch_size, + "threads": self.threads, + "configured_parallel": self.parallel, + "effective_parallel": self._effective_parallel(), + } + def __init__( self, model_name: str = "bge-small-en-v1.5", diff --git a/src/basic_memory/repository/openai_provider.py b/src/basic_memory/repository/openai_provider.py index 479bce12..b44e13b7 100644 --- a/src/basic_memory/repository/openai_provider.py +++ b/src/basic_memory/repository/openai_provider.py @@ -34,6 +34,13 @@ def __init__( self._client: Any | None = None self._client_lock = asyncio.Lock() + def runtime_log_attrs(self) -> dict[str, int]: + """Return the request fan-out knobs that shape API embedding batches.""" + return { + "provider_batch_size": self.batch_size, + "request_concurrency": self.request_concurrency, + } + async def _get_client(self) -> Any: if self._client is not None: return self._client diff --git a/src/basic_memory/repository/search_repository_base.py b/src/basic_memory/repository/search_repository_base.py index 827c5414..9e372600 100644 --- a/src/basic_memory/repository/search_repository_base.py +++ b/src/basic_memory/repository/search_repository_base.py @@ -37,6 +37,7 @@ HEADER_LINE_PATTERN = re.compile(r"^\s*#{1,6}\s+") BULLET_PATTERN = re.compile(r"^[\-\*]\s+") OVERSIZED_ENTITY_VECTOR_SHARD_SIZE = 256 +_SQLITE_MAX_PREPARE_WINDOW = 8 @dataclass @@ -800,7 +801,9 @@ async def _sync_entity_vectors_internal( batch_start = time.perf_counter() backend_name = type(self).__name__.removesuffix("SearchRepository").lower() - self._log_vector_sync_runtime_settings(backend_name=backend_name, entities_total=total_entities) + self._log_vector_sync_runtime_settings( + backend_name=backend_name, entities_total=total_entities + ) logger.info( "Vector batch sync start: project_id={project_id} entities_total={entities_total} " "sync_batch_size={sync_batch_size} prepare_window_size={prepare_window_size}", @@ -1006,102 +1009,104 @@ def emit_progress(entity_id: int) -> None: for failed_entity_id in affected_entity_ids: emit_progress(failed_entity_id) - # Trigger: this should never happen after all flushes succeed. - # Why: remaining jobs mean runtime tracking drifted from queued jobs. - # Outcome: fail-safe marks these entities as failed to avoid false positives. - if entity_runtime: - orphan_runtime_entities = sorted(entity_runtime.keys()) - failed_entity_ids.update(orphan_runtime_entities) - synced_entity_ids.difference_update(orphan_runtime_entities) - deferred_entity_ids.difference_update(orphan_runtime_entities) - logger.warning( - "Vector batch sync left unfinished entities after flushes: " - "project_id={project_id} unfinished_entities={unfinished_entities}", + # Trigger: this should never happen after all flushes succeed. + # Why: remaining jobs mean runtime tracking drifted from queued jobs. + # Outcome: fail-safe marks these entities as failed to avoid false positives. + if entity_runtime: + orphan_runtime_entities = sorted(entity_runtime.keys()) + failed_entity_ids.update(orphan_runtime_entities) + synced_entity_ids.difference_update(orphan_runtime_entities) + deferred_entity_ids.difference_update(orphan_runtime_entities) + logger.warning( + "Vector batch sync left unfinished entities after flushes: " + "project_id={project_id} unfinished_entities={unfinished_entities}", + project_id=self.project_id, + unfinished_entities=orphan_runtime_entities, + ) + for failed_entity_id in orphan_runtime_entities: + emit_progress(failed_entity_id) + + # Keep result counters aligned with successful/failed terminal states. + synced_entity_ids.difference_update(failed_entity_ids) + deferred_entity_ids.difference_update(failed_entity_ids) + deferred_entity_ids.difference_update(synced_entity_ids) + result.failed_entity_ids = sorted(failed_entity_ids) + result.entities_failed = len(result.failed_entity_ids) + result.entities_deferred = len(deferred_entity_ids) + result.entities_synced = len(synced_entity_ids) + + logger.info( + "Vector batch sync complete: project_id={project_id} entities_total={entities_total} " + "entities_synced={entities_synced} entities_failed={entities_failed} " + "entities_deferred={entities_deferred} " + "entities_skipped={entities_skipped} chunks_total={chunks_total} " + "chunks_skipped={chunks_skipped} embedding_jobs_total={embedding_jobs_total} " + "prepare_seconds_total={prepare_seconds_total:.3f} " + "queue_wait_seconds_total={queue_wait_seconds_total:.3f} " + "embed_seconds_total={embed_seconds_total:.3f} write_seconds_total={write_seconds_total:.3f}", project_id=self.project_id, - unfinished_entities=orphan_runtime_entities, + entities_total=result.entities_total, + entities_synced=result.entities_synced, + entities_failed=result.entities_failed, + entities_deferred=result.entities_deferred, + entities_skipped=result.entities_skipped, + chunks_total=result.chunks_total, + chunks_skipped=result.chunks_skipped, + embedding_jobs_total=result.embedding_jobs_total, + prepare_seconds_total=result.prepare_seconds_total, + queue_wait_seconds_total=result.queue_wait_seconds_total, + embed_seconds_total=result.embed_seconds_total, + write_seconds_total=result.write_seconds_total, ) - for failed_entity_id in orphan_runtime_entities: - emit_progress(failed_entity_id) - - # Keep result counters aligned with successful/failed terminal states. - synced_entity_ids.difference_update(failed_entity_ids) - deferred_entity_ids.difference_update(failed_entity_ids) - deferred_entity_ids.difference_update(synced_entity_ids) - result.failed_entity_ids = sorted(failed_entity_ids) - result.entities_failed = len(result.failed_entity_ids) - result.entities_deferred = len(deferred_entity_ids) - result.entities_synced = len(synced_entity_ids) - - logger.info( - "Vector batch sync complete: project_id={project_id} entities_total={entities_total} " - "entities_synced={entities_synced} entities_failed={entities_failed} " - "entities_deferred={entities_deferred} " - "entities_skipped={entities_skipped} chunks_total={chunks_total} " - "chunks_skipped={chunks_skipped} embedding_jobs_total={embedding_jobs_total} " - "prepare_seconds_total={prepare_seconds_total:.3f} " - "queue_wait_seconds_total={queue_wait_seconds_total:.3f} " - "embed_seconds_total={embed_seconds_total:.3f} write_seconds_total={write_seconds_total:.3f}", - project_id=self.project_id, - entities_total=result.entities_total, - entities_synced=result.entities_synced, - entities_failed=result.entities_failed, - entities_deferred=result.entities_deferred, - entities_skipped=result.entities_skipped, - chunks_total=result.chunks_total, - chunks_skipped=result.chunks_skipped, - embedding_jobs_total=result.embedding_jobs_total, - prepare_seconds_total=result.prepare_seconds_total, - queue_wait_seconds_total=result.queue_wait_seconds_total, - embed_seconds_total=result.embed_seconds_total, - write_seconds_total=result.write_seconds_total, - ) - batch_total_seconds = time.perf_counter() - batch_start - metric_attrs = { - "backend": backend_name, - "skip_only_batch": result.embedding_jobs_total == 0, - } - telemetry.record_histogram( - "vector_sync_batch_total_seconds", - batch_total_seconds, - unit="s", - **metric_attrs, - ) - telemetry.add_counter("vector_sync_entities_total", result.entities_total, **metric_attrs) - telemetry.add_counter( - "vector_sync_entities_skipped", - result.entities_skipped, - **metric_attrs, - ) - telemetry.add_counter( - "vector_sync_entities_deferred", - result.entities_deferred, - **metric_attrs, - ) - telemetry.add_counter( - "vector_sync_embedding_jobs_total", - result.embedding_jobs_total, - **metric_attrs, - ) - telemetry.add_counter("vector_sync_chunks_total", result.chunks_total, **metric_attrs) - telemetry.add_counter( - "vector_sync_chunks_skipped", - result.chunks_skipped, - **metric_attrs, - ) - if batch_span is not None: - batch_span.set_attributes( - { - "backend": backend_name, - "entities_synced": result.entities_synced, - "entities_failed": result.entities_failed, - "entities_deferred": result.entities_deferred, - "entities_skipped": result.entities_skipped, - "embedding_jobs_total": result.embedding_jobs_total, - "chunks_total": result.chunks_total, - "chunks_skipped": result.chunks_skipped, - "batch_total_seconds": batch_total_seconds, - } + batch_total_seconds = time.perf_counter() - batch_start + metric_attrs = { + "backend": backend_name, + "skip_only_batch": result.embedding_jobs_total == 0, + } + telemetry.record_histogram( + "vector_sync_batch_total_seconds", + batch_total_seconds, + unit="s", + **metric_attrs, + ) + telemetry.add_counter( + "vector_sync_entities_total", result.entities_total, **metric_attrs + ) + telemetry.add_counter( + "vector_sync_entities_skipped", + result.entities_skipped, + **metric_attrs, + ) + telemetry.add_counter( + "vector_sync_entities_deferred", + result.entities_deferred, + **metric_attrs, + ) + telemetry.add_counter( + "vector_sync_embedding_jobs_total", + result.embedding_jobs_total, + **metric_attrs, + ) + telemetry.add_counter("vector_sync_chunks_total", result.chunks_total, **metric_attrs) + telemetry.add_counter( + "vector_sync_chunks_skipped", + result.chunks_skipped, + **metric_attrs, ) + if batch_span is not None: + batch_span.set_attributes( + { + "backend": backend_name, + "entities_synced": result.entities_synced, + "entities_failed": result.entities_failed, + "entities_deferred": result.entities_deferred, + "entities_skipped": result.entities_skipped, + "embedding_jobs_total": result.embedding_jobs_total, + "chunks_total": result.chunks_total, + "chunks_skipped": result.chunks_skipped, + "batch_total_seconds": batch_total_seconds, + } + ) return result @@ -1113,7 +1118,10 @@ def _vector_prepare_window_size(self) -> int: # explode to the full embed batch size creates unnecessary write contention. # Outcome: local backends get a small bounded window, while Postgres keeps # its explicit higher concurrency override. - return max(1, min(self._semantic_embedding_sync_batch_size, 8)) + return max( + 1, + min(self._semantic_embedding_sync_batch_size, _SQLITE_MAX_PREPARE_WINDOW), + ) @asynccontextmanager async def _prepare_entity_write_scope(self): @@ -1223,14 +1231,18 @@ async def _prepare_entity_vector_jobs_window( session, entity_ids ) except Exception as exc: + # Trigger: the shared read pass failed before we had entity-level diffs. + # Why: once the window-level read session breaks, we cannot safely + # distinguish one entity from another inside that window. + # Outcome: every entity in the window gets the same failure object. return [exc for _ in entity_ids] # Trigger: prepare now does one shared read pass per window instead of # paying the same select/join round-trips per entity. # Why: both SQLite and Postgres were still burning wall clock in read-side # fingerprint/orphan checks even when every entity ended up skipped. - # Outcome: we batch the reads once, then fan back out over entities while - # preserving input order in the gathered results. + # Outcome: we batch the reads once, close that shared read session, and + # then fan back out over entities while preserving input order. prepared_window = await asyncio.gather( *( self._prepare_entity_vector_jobs_prefetched( @@ -1264,7 +1276,8 @@ async def _prepare_entity_vector_jobs_prefetched( prepare_start = sync_start source_rows_count = len(source_rows) - if not source_rows: + async def _delete_entity_chunks_and_finish() -> _PreparedEntityVectorSync: + """Delete derived rows and return the empty prepare result.""" async with self._prepare_entity_write_scope(): async with db.scoped_session(self.session_maker) as session: await self._prepare_vector_session(session) @@ -1279,22 +1292,13 @@ async def _prepare_entity_vector_jobs_prefetched( prepare_seconds=prepare_seconds, ) + if not source_rows: + return await _delete_entity_chunks_and_finish() + chunk_records = self._build_chunk_records(source_rows) built_chunk_records_count = len(chunk_records) if not chunk_records: - async with self._prepare_entity_write_scope(): - async with db.scoped_session(self.session_maker) as session: - await self._prepare_vector_session(session) - await self._delete_entity_chunks(session, entity_id) - await session.commit() - prepare_seconds = time.perf_counter() - prepare_start - return _PreparedEntityVectorSync( - entity_id=entity_id, - sync_start=sync_start, - source_rows_count=source_rows_count, - embedding_jobs=[], - prepare_seconds=prepare_seconds, - ) + return await _delete_entity_chunks_and_finish() current_entity_fingerprint = self._build_entity_fingerprint(chunk_records) current_embedding_model = self._embedding_model_key() @@ -1607,27 +1611,25 @@ def _log_vector_sync_runtime_settings(self, *, backend_name: str, entities_total """ assert self._embedding_provider is not None - from basic_memory.repository.fastembed_provider import FastEmbedEmbeddingProvider - provider = self._embedding_provider - if isinstance(provider, FastEmbedEmbeddingProvider): + runtime_attrs = ( + provider.runtime_log_attrs() if hasattr(provider, "runtime_log_attrs") else {} + ) + if runtime_attrs: logger.info( "Vector batch runtime settings: project_id={project_id} backend={backend} " "entities_total={entities_total} provider={provider} model_name={model_name} " - "dimensions={dimensions} provider_batch_size={provider_batch_size} " - "sync_batch_size={sync_batch_size} threads={threads} " - "configured_parallel={configured_parallel} effective_parallel={effective_parallel}", + "dimensions={dimensions} sync_batch_size={sync_batch_size} " + "{runtime_attrs}", project_id=self.project_id, backend=backend_name, entities_total=entities_total, provider=type(provider).__name__, model_name=provider.model_name, dimensions=provider.dimensions, - provider_batch_size=provider.batch_size, sync_batch_size=self._semantic_embedding_sync_batch_size, - threads=provider.threads, - configured_parallel=provider.parallel, - effective_parallel=provider._effective_parallel(), + runtime_attrs=" ".join(f"{key}={value}" for key, value in runtime_attrs.items()), + **runtime_attrs, ) return diff --git a/tests/repository/test_openai_provider.py b/tests/repository/test_openai_provider.py index 7152ff0d..e76882a6 100644 --- a/tests/repository/test_openai_provider.py +++ b/tests/repository/test_openai_provider.py @@ -265,6 +265,28 @@ def test_embedding_provider_factory_forwards_fastembed_runtime_knobs(): assert provider.parallel == 2 +def test_fastembed_provider_reports_runtime_log_attrs(): + """FastEmbed should expose the resolved runtime knobs for batch startup logs.""" + provider = FastEmbedEmbeddingProvider(batch_size=128, threads=4, parallel=2) + + assert provider.runtime_log_attrs() == { + "provider_batch_size": 128, + "threads": 4, + "configured_parallel": 2, + "effective_parallel": 2, + } + + +def test_openai_provider_reports_runtime_log_attrs(): + """OpenAI provider should expose API batch fan-out settings for startup logs.""" + provider = OpenAIEmbeddingProvider(batch_size=32, request_concurrency=6) + + assert provider.runtime_log_attrs() == { + "provider_batch_size": 32, + "request_concurrency": 6, + } + + def test_embedding_provider_factory_auto_tunes_fastembed_runtime_knobs_from_cpu_budget(monkeypatch): """Unset FastEmbed runtime knobs should resolve from available CPU budget.""" monkeypatch.setattr(embedding_provider_factory_module.os, "process_cpu_count", lambda: 8) From 794dcc65d1e5257a6630f37320b7dd051baeb27a Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 20:09:41 -0500 Subject: [PATCH 06/14] perf(core): lower vector sync batch defaults Signed-off-by: phernandez --- src/basic_memory/config.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/basic_memory/config.py b/src/basic_memory/config.py index 62579bee..3fa87580 100644 --- a/src/basic_memory/config.py +++ b/src/basic_memory/config.py @@ -188,8 +188,14 @@ class BasicMemoryConfig(BaseSettings): default=None, description="Embedding vector dimensions. Auto-detected from provider if not set (384 for FastEmbed, 1536 for OpenAI).", ) + # Trigger: full local rebuilds spend most of their time waiting behind shared + # embed flushes, not constructing vectors themselves. + # Why: smaller FastEmbed batches cut queue wait far more than they increase + # write overhead on real-world projects, which makes full reindex materially faster. + # Outcome: default to the smaller local/cloud-safe batch size we benchmarked as + # the current best end-to-end setting in the shared vector sync pipeline. semantic_embedding_batch_size: int = Field( - default=64, + default=2, description="Batch size for embedding generation.", gt=0, ) @@ -199,7 +205,7 @@ class BasicMemoryConfig(BaseSettings): gt=0, ) semantic_embedding_sync_batch_size: int = Field( - default=64, + default=2, description="Batch size for vector sync orchestration flushes.", gt=0, ) From 0434d9c6eee690280aab055169988fecc458487c Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 20:34:18 -0500 Subject: [PATCH 07/14] fix(sync): stabilize windows checksum writes Signed-off-by: phernandez --- src/basic_memory/file_utils.py | 10 +++++--- src/basic_memory/services/file_service.py | 19 +++++++++++---- tests/services/test_file_service.py | 28 +++++++++++++++++++++++ 3 files changed, 49 insertions(+), 8 deletions(-) diff --git a/src/basic_memory/file_utils.py b/src/basic_memory/file_utils.py index 3de80b29..746a91aa 100644 --- a/src/basic_memory/file_utils.py +++ b/src/basic_memory/file_utils.py @@ -114,8 +114,12 @@ async def write_file_atomic(path: FilePath, content: str) -> None: temp_path = path_obj.with_suffix(".tmp") try: - # Use aiofiles for non-blocking write - async with aiofiles.open(temp_path, mode="w", encoding="utf-8") as f: + # Trigger: Basic Memory writes markdown and metadata from normalized Python strings. + # Why: Windows text mode would translate "\n" into "\r\n", which makes the + # persisted bytes diverge from the in-memory content we index and hash. + # Outcome: force LF on every platform so file bytes, checksums, and move detection + # stay deterministic across local and CI environments. + async with aiofiles.open(temp_path, mode="w", encoding="utf-8", newline="\n") as f: await f.write(content) # Atomic rename (this is fast, doesn't need async) @@ -168,7 +172,7 @@ async def format_markdown_builtin(path: Path) -> Optional[str]: # Only write if content changed if formatted_content != content: - async with aiofiles.open(path, mode="w", encoding="utf-8") as f: + async with aiofiles.open(path, mode="w", encoding="utf-8", newline="\n") as f: await f.write(formatted_content) logger.debug( diff --git a/src/basic_memory/services/file_service.py b/src/basic_memory/services/file_service.py index c7081e15..adf23090 100644 --- a/src/basic_memory/services/file_service.py +++ b/src/basic_memory/services/file_service.py @@ -208,15 +208,20 @@ async def write_file(self, path: FilePath, content: str) -> str: await file_utils.write_file_atomic(full_path, content) - final_content = content if self.app_config: formatted_content = await file_utils.format_file( full_path, self.app_config, is_markdown=self.is_markdown(path) ) if formatted_content is not None: - final_content = formatted_content # pragma: no cover - - checksum = await file_utils.compute_checksum(final_content) + pass # pragma: no cover + + # Trigger: formatters and platform-specific text writers can change the + # persisted bytes even when the logical content string is the same. + # Why: sync and move detection compare against on-disk checksums, not + # the pre-write Python string. + # Outcome: return the checksum of the actual stored file so callers do + # not record a hash that immediately disagrees with the file. + checksum = await self.compute_checksum(full_path) logger.debug(f"File write completed path={full_path}, {checksum=}") return checksum @@ -478,8 +483,12 @@ async def update_frontmatter_with_result( if formatted_content is not None: content_for_checksum = formatted_content # pragma: no cover + # Trigger: frontmatter normalization may persist bytes that differ from the + # in-memory string because of formatter output or platform newline handling. + # Why: follow-up scans and checksum-based move detection read raw bytes from disk. + # Outcome: the returned checksum always matches the file that was just written. return FrontmatterUpdateResult( - checksum=await file_utils.compute_checksum(content_for_checksum), + checksum=await self.compute_checksum(full_path), content=content_for_checksum, ) diff --git a/tests/services/test_file_service.py b/tests/services/test_file_service.py index e2e2fdf9..b35cf94a 100644 --- a/tests/services/test_file_service.py +++ b/tests/services/test_file_service.py @@ -4,6 +4,7 @@ import pytest +from basic_memory import file_utils from basic_memory.services.exceptions import FileOperationError from basic_memory.services.file_service import FileService @@ -167,6 +168,33 @@ async def test_write_unicode_content(tmp_path: Path, file_service: FileService): assert content == test_content +@pytest.mark.asyncio +async def test_update_frontmatter_checksum_matches_persisted_bytes( + tmp_path: Path, file_service: FileService, monkeypatch +): + """Frontmatter writes should hash the stored file, not the pre-write string.""" + test_path = tmp_path / "note.md" + test_path.write_text("# Note\nBody\n", encoding="utf-8") + + async def fake_write_file_atomic(path: Path, content: str) -> None: + # Trigger: simulate a writer that persists CRLF bytes like Windows text mode. + # Why: the regression happened when the stored bytes diverged from the LF string + # used to build the checksum. + # Outcome: this test proves FileService returns the checksum for the stored file. + persisted = content.replace("\n", "\r\n").encode("utf-8") + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(persisted) + + monkeypatch.setattr(file_utils, "write_file_atomic", fake_write_file_atomic) + + result = await file_service.update_frontmatter_with_result( + test_path, + {"title": "Note", "type": "note"}, + ) + + assert result.checksum == await file_service.compute_checksum(test_path) + + @pytest.mark.asyncio async def test_read_file_content(tmp_path: Path, file_service: FileService): """Test read_file_content returns just the content without checksum.""" From b54c5f268b5c905b54969249272824fde8b8cce1 Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 21:42:35 -0500 Subject: [PATCH 08/14] fix(sync): preserve platform newlines Signed-off-by: phernandez --- src/basic_memory/file_utils.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/basic_memory/file_utils.py b/src/basic_memory/file_utils.py index 746a91aa..3de80b29 100644 --- a/src/basic_memory/file_utils.py +++ b/src/basic_memory/file_utils.py @@ -114,12 +114,8 @@ async def write_file_atomic(path: FilePath, content: str) -> None: temp_path = path_obj.with_suffix(".tmp") try: - # Trigger: Basic Memory writes markdown and metadata from normalized Python strings. - # Why: Windows text mode would translate "\n" into "\r\n", which makes the - # persisted bytes diverge from the in-memory content we index and hash. - # Outcome: force LF on every platform so file bytes, checksums, and move detection - # stay deterministic across local and CI environments. - async with aiofiles.open(temp_path, mode="w", encoding="utf-8", newline="\n") as f: + # Use aiofiles for non-blocking write + async with aiofiles.open(temp_path, mode="w", encoding="utf-8") as f: await f.write(content) # Atomic rename (this is fast, doesn't need async) @@ -172,7 +168,7 @@ async def format_markdown_builtin(path: Path) -> Optional[str]: # Only write if content changed if formatted_content != content: - async with aiofiles.open(path, mode="w", encoding="utf-8", newline="\n") as f: + async with aiofiles.open(path, mode="w", encoding="utf-8") as f: await f.write(formatted_content) logger.debug( From 3f23b3ebb9fec66affc03481f47942b6974ef3ef Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 21:53:52 -0500 Subject: [PATCH 09/14] docs(core): explain windows newline checksum behavior Signed-off-by: phernandez --- src/basic_memory/file_utils.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/basic_memory/file_utils.py b/src/basic_memory/file_utils.py index 3de80b29..d23626da 100644 --- a/src/basic_memory/file_utils.py +++ b/src/basic_memory/file_utils.py @@ -114,7 +114,13 @@ async def write_file_atomic(path: FilePath, content: str) -> None: temp_path = path_obj.with_suffix(".tmp") try: - # Use aiofiles for non-blocking write + # Trigger: callers hand us normalized Python text, but the final bytes are allowed + # to use the host platform's native newline convention during the write. + # Why: preserving CRLF on Windows keeps local files aligned with editors like + # Obsidian, while FileService now hashes the persisted file bytes instead of + # the pre-write string. + # Outcome: this async write stays editor-friendly across platforms without + # reintroducing checksum drift in sync or move detection. async with aiofiles.open(temp_path, mode="w", encoding="utf-8") as f: await f.write(content) @@ -168,6 +174,13 @@ async def format_markdown_builtin(path: Path) -> Optional[str]: # Only write if content changed if formatted_content != content: + # Trigger: mdformat may rewrite markdown content, then the host platform + # decides the newline bytes for the follow-up async text write. + # Why: we want formatter output to preserve native newlines instead of + # forcing LF, and the authoritative checksum comes from rereading the + # stored file bytes later in FileService. + # Outcome: formatting remains compatible with local editors on Windows while + # checksum-based sync logic stays anchored to on-disk bytes. async with aiofiles.open(path, mode="w", encoding="utf-8") as f: await f.write(formatted_content) From 5069b462efdbcb058f175caae821effe1c1a896c Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 21:55:17 -0500 Subject: [PATCH 10/14] test(core): clarify windows checksum regression Signed-off-by: phernandez --- tests/services/test_file_service.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/services/test_file_service.py b/tests/services/test_file_service.py index b35cf94a..dca5a565 100644 --- a/tests/services/test_file_service.py +++ b/tests/services/test_file_service.py @@ -169,15 +169,16 @@ async def test_write_unicode_content(tmp_path: Path, file_service: FileService): @pytest.mark.asyncio -async def test_update_frontmatter_checksum_matches_persisted_bytes( +async def test_update_frontmatter_checksum_matches_windows_crlf_persisted_bytes( tmp_path: Path, file_service: FileService, monkeypatch ): - """Frontmatter writes should hash the stored file, not the pre-write string.""" + """Windows-style CRLF writes should hash the stored file, not the pre-write string.""" test_path = tmp_path / "note.md" test_path.write_text("# Note\nBody\n", encoding="utf-8") async def fake_write_file_atomic(path: Path, content: str) -> None: - # Trigger: simulate a writer that persists CRLF bytes like Windows text mode. + # Trigger: simulate Windows text-mode persistence, where logical LF strings + # land on disk as CRLF bytes. # Why: the regression happened when the stored bytes diverged from the LF string # used to build the checksum. # Outcome: this test proves FileService returns the checksum for the stored file. From 7421cfccc90cf6e79461319bde2990eb14c81014 Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 22:18:19 -0500 Subject: [PATCH 11/14] fix(core): align checksums with persisted bytes Signed-off-by: phernandez --- src/basic_memory/services/file_service.py | 8 +++++++- tests/indexing/test_batch_indexer.py | 16 ++++++++++++++-- 2 files changed, 21 insertions(+), 3 deletions(-) diff --git a/src/basic_memory/services/file_service.py b/src/basic_memory/services/file_service.py index adf23090..d73e5e1a 100644 --- a/src/basic_memory/services/file_service.py +++ b/src/basic_memory/services/file_service.py @@ -350,7 +350,13 @@ async def read_file(self, path: FilePath) -> Tuple[str, str]: async with aiofiles.open(full_path, mode="r", encoding="utf-8") as f: content = await f.read() - checksum = await file_utils.compute_checksum(content) + # Trigger: text-mode reads normalize line endings on Windows, so the + # decoded string can differ from the bytes we just wrote. + # Why: write_file/update_frontmatter now return the checksum of the + # persisted file, and read_file should report the same authority. + # Outcome: callers get human-readable content plus the checksum for the + # exact bytes stored on disk. + checksum = await self.compute_checksum(full_path) logger.debug( "File read completed", diff --git a/tests/indexing/test_batch_indexer.py b/tests/indexing/test_batch_indexer.py index cacf55aa..1d0ebf92 100644 --- a/tests/indexing/test_batch_indexer.py +++ b/tests/indexing/test_batch_indexer.py @@ -265,9 +265,15 @@ async def test_batch_indexer_returns_original_markdown_content_when_no_frontmatt parse_max_concurrent=1, ) + # Trigger: Windows persists CRLF for text writes even when the test literal uses LF. + # Why: this assertion cares about "no rewrite happened", not about forcing one newline + # convention across platforms. + # Outcome: compare against the exact markdown text stored on disk for this file. + persisted_content = (project_config.home / path).read_text() + assert result.errors == [] assert len(result.indexed) == 1 - assert result.indexed[0].markdown_content == original_content + assert result.indexed[0].markdown_content == persisted_content @pytest.mark.asyncio @@ -514,9 +520,15 @@ async def test_batch_indexer_uses_parsed_markdown_body_for_malformed_frontmatter parse_max_concurrent=1, ) + # Trigger: malformed frontmatter should pass through without normalization. + # Why: Windows can still surface that unchanged file with CRLF line endings. + # Outcome: compare the indexed markdown to the persisted file content, not the LF + # test literal used to create it. + persisted_content = (project_config.home / path).read_text() + assert result.errors == [] assert len(result.indexed) == 1 - assert result.indexed[0].markdown_content == malformed_content + assert result.indexed[0].markdown_content == persisted_content entity = await entity_repository.get_by_file_path(path) assert entity is not None From 7c37e49427dd0d7a64cb0bc9eb28a20ec57a4970 Mon Sep 17 00:00:00 2001 From: phernandez Date: Wed, 8 Apr 2026 23:10:46 -0500 Subject: [PATCH 12/14] test(core): preserve persisted windows newlines Signed-off-by: phernandez --- tests/indexing/test_batch_indexer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/indexing/test_batch_indexer.py b/tests/indexing/test_batch_indexer.py index 1d0ebf92..e1371d92 100644 --- a/tests/indexing/test_batch_indexer.py +++ b/tests/indexing/test_batch_indexer.py @@ -269,7 +269,7 @@ async def test_batch_indexer_returns_original_markdown_content_when_no_frontmatt # Why: this assertion cares about "no rewrite happened", not about forcing one newline # convention across platforms. # Outcome: compare against the exact markdown text stored on disk for this file. - persisted_content = (project_config.home / path).read_text() + persisted_content = (project_config.home / path).read_bytes().decode("utf-8") assert result.errors == [] assert len(result.indexed) == 1 @@ -524,7 +524,7 @@ async def test_batch_indexer_uses_parsed_markdown_body_for_malformed_frontmatter # Why: Windows can still surface that unchanged file with CRLF line endings. # Outcome: compare the indexed markdown to the persisted file content, not the LF # test literal used to create it. - persisted_content = (project_config.home / path).read_text() + persisted_content = (project_config.home / path).read_bytes().decode("utf-8") assert result.errors == [] assert len(result.indexed) == 1 From 3a5371a43cbe15eb9bb07876d8a74f76e8da3322 Mon Sep 17 00:00:00 2001 From: phernandez Date: Thu, 9 Apr 2026 20:29:46 -0500 Subject: [PATCH 13/14] fix(core): clean up delete vectors and cloud sync Signed-off-by: phernandez --- .../cli/commands/cloud/project_sync.py | 32 ---- .../repository/search_repository.py | 4 + .../repository/search_repository_base.py | 9 ++ src/basic_memory/services/search_service.py | 20 +-- tests/cli/cloud/test_project_sync_command.py | 28 +--- tests/services/test_entity_service.py | 151 +++++++++++++++++- tests/services/test_initialization.py | 1 - 7 files changed, 173 insertions(+), 72 deletions(-) diff --git a/src/basic_memory/cli/commands/cloud/project_sync.py b/src/basic_memory/cli/commands/cloud/project_sync.py index c963a09b..e18620da 100644 --- a/src/basic_memory/cli/commands/cloud/project_sync.py +++ b/src/basic_memory/cli/commands/cloud/project_sync.py @@ -124,22 +124,6 @@ def sync_project_command( if success: console.print(f"[green]{name} synced successfully[/green]") - - # Trigger database sync if not a dry run - if not dry_run: - - async def _trigger_db_sync(): - async with get_client(project_name=name) as client: - return await ProjectClient(client).sync( - project_data.external_id, force_full=False - ) - - try: - with force_routing(cloud=True): - result = run_with_cleanup(_trigger_db_sync()) - console.print(f"[dim]Database sync initiated: {result.get('message')}[/dim]") - except Exception as e: - console.print(f"[yellow]Warning: Could not trigger database sync: {e}[/yellow]") else: console.print(f"[red]{name} sync failed[/red]") raise typer.Exit(1) @@ -202,22 +186,6 @@ def bisync_project_command( sync_entry.last_sync = datetime.now() sync_entry.bisync_initialized = True ConfigManager().save_config(config) - - # Trigger database sync if not a dry run - if not dry_run: - - async def _trigger_db_sync(): - async with get_client(project_name=name) as client: - return await ProjectClient(client).sync( - project_data.external_id, force_full=False - ) - - try: - with force_routing(cloud=True): - result = run_with_cleanup(_trigger_db_sync()) - console.print(f"[dim]Database sync initiated: {result.get('message')}[/dim]") - except Exception as e: - console.print(f"[yellow]Warning: Could not trigger database sync: {e}[/yellow]") else: console.print(f"[red]{name} bisync failed[/red]") raise typer.Exit(1) diff --git a/src/basic_memory/repository/search_repository.py b/src/basic_memory/repository/search_repository.py index 005fee1f..d1d2e365 100644 --- a/src/basic_memory/repository/search_repository.py +++ b/src/basic_memory/repository/search_repository.py @@ -70,6 +70,10 @@ async def sync_entity_vectors(self, entity_id: int) -> None: """Sync semantic vector chunks for an entity.""" ... + async def delete_entity_vector_rows(self, entity_id: int) -> None: + """Delete semantic vector chunks and embeddings for one entity.""" + ... + async def sync_entity_vectors_batch( self, entity_ids: list[int], diff --git a/src/basic_memory/repository/search_repository_base.py b/src/basic_memory/repository/search_repository_base.py index 9e372600..7783aa28 100644 --- a/src/basic_memory/repository/search_repository_base.py +++ b/src/basic_memory/repository/search_repository_base.py @@ -454,6 +454,15 @@ async def execute_query( logger.debug(f"Query executed successfully in {elapsed_time:.2f}s.") return result + async def delete_entity_vector_rows(self, entity_id: int) -> None: + """Delete one entity's derived vector rows using the backend's cleanup path.""" + await self._ensure_vector_tables() + + async with db.scoped_session(self.session_maker) as session: + await self._prepare_vector_session(session) + await self._delete_entity_chunks(session, entity_id) + await session.commit() + # ------------------------------------------------------------------ # Shared semantic search: guard, text processing, chunking # ------------------------------------------------------------------ diff --git a/src/basic_memory/services/search_service.py b/src/basic_memory/services/search_service.py index b1aec10b..d43971b8 100644 --- a/src/basic_memory/services/search_service.py +++ b/src/basic_memory/services/search_service.py @@ -660,7 +660,6 @@ def _entity_embeddings_enabled(entity: Entity) -> bool: async def _clear_entity_vectors(self, entity_id: int) -> None: """Delete derived vector rows for one entity.""" from basic_memory.repository.search_repository_base import SearchRepositoryBase - from basic_memory.repository.sqlite_search_repository import SQLiteSearchRepository # Trigger: semantic indexing is disabled for this repository instance. # Why: repositories only create vector tables when semantic search is enabled. @@ -671,17 +670,7 @@ async def _clear_entity_vectors(self, entity_id: int) -> None: ): return - params = {"project_id": self.repository.project_id, "entity_id": entity_id} - if isinstance(self.repository, SQLiteSearchRepository): - await self.repository.delete_entity_vector_rows(entity_id) - else: - await self.repository.execute_query( - text( - "DELETE FROM search_vector_chunks " - "WHERE project_id = :project_id AND entity_id = :entity_id" - ), - params, - ) + await self.repository.delete_entity_vector_rows(entity_id) async def index_entity_file( self, @@ -889,7 +878,7 @@ async def delete_by_entity_id(self, entity_id: int): await self.repository.delete_by_entity_id(entity_id) async def handle_delete(self, entity: Entity): - """Handle complete entity deletion from search index including observations and relations. + """Handle complete entity deletion from search and semantic index state. This replicates the logic from sync_service.handle_delete() to properly clean up all search index entries for an entity and its related data. @@ -916,3 +905,8 @@ async def handle_delete(self, entity: Entity): await self.delete_by_permalink(permalink) else: await self.delete_by_entity_id(entity.id) + + # Trigger: entity deletion removes the source rows for this note. + # Why: semantic chunks/embeddings are stored separately from search_index rows. + # Outcome: deleting an entity clears both full-text and vector-derived search state. + await self._clear_entity_vectors(entity.id) diff --git a/tests/cli/cloud/test_project_sync_command.py b/tests/cli/cloud/test_project_sync_command.py index 4d5dda94..93166138 100644 --- a/tests/cli/cloud/test_project_sync_command.py +++ b/tests/cli/cloud/test_project_sync_command.py @@ -1,7 +1,6 @@ """Tests for cloud sync and bisync command behavior.""" import importlib -from contextlib import asynccontextmanager from types import SimpleNamespace import pytest @@ -20,11 +19,10 @@ ["cloud", "bisync", "--name", "research"], ], ) -def test_cloud_sync_commands_use_incremental_db_sync(monkeypatch, argv, config_manager): - """Cloud sync commands should not force a full database re-index after file sync.""" +def test_cloud_sync_commands_skip_explicit_cloud_project_sync(monkeypatch, argv, config_manager): + """Cloud sync commands should not trigger an extra explicit cloud project sync.""" project_sync_command = importlib.import_module("basic_memory.cli.commands.cloud.project_sync") - seen: dict[str, object] = {} config = config_manager.load_config() config.set_project_mode("research", ProjectMode.CLOUD) config_manager.save_config(config) @@ -50,30 +48,10 @@ def test_cloud_sync_commands_use_incremental_db_sync(monkeypatch, argv, config_m monkeypatch.setattr(project_sync_command, "project_sync", lambda *args, **kwargs: True) monkeypatch.setattr(project_sync_command, "project_bisync", lambda *args, **kwargs: True) - @asynccontextmanager - async def fake_get_client(*, project_name=None, workspace=None): - seen["project_name"] = project_name - seen["workspace"] = workspace - yield object() - - class FakeProjectClient: - def __init__(self, _client): - pass - - async def sync(self, external_id: str, force_full: bool = False): - seen["external_id"] = external_id - seen["force_full"] = force_full - return {"message": "queued"} - - monkeypatch.setattr(project_sync_command, "get_client", fake_get_client) - monkeypatch.setattr(project_sync_command, "ProjectClient", FakeProjectClient) - result = runner.invoke(app, argv) assert result.exit_code == 0, result.output - assert seen["project_name"] == "research" - assert seen["external_id"] == "external-project-id" - assert seen["force_full"] is False + assert "Database sync initiated" not in result.output def test_cloud_bisync_fails_fast_when_sync_entry_disappears(monkeypatch, config_manager): diff --git a/tests/services/test_entity_service.py b/tests/services/test_entity_service.py index dd9d1e6b..df071ce8 100644 --- a/tests/services/test_entity_service.py +++ b/tests/services/test_entity_service.py @@ -6,8 +6,10 @@ import pytest import yaml +from sqlalchemy import text -from basic_memory.config import ProjectConfig, BasicMemoryConfig +from basic_memory import db +from basic_memory.config import ProjectConfig, BasicMemoryConfig, DatabaseBackend from basic_memory.markdown import EntityParser from basic_memory.models import Entity as EntityModel from basic_memory.repository import EntityRepository @@ -19,6 +21,98 @@ from basic_memory.utils import generate_permalink +class _DeleteTestEmbeddingProvider: + """Deterministic embedding provider for entity delete cleanup tests.""" + + model_name = "delete-test" + dimensions = 4 + + async def embed_query(self, text: str) -> list[float]: + return self._vectorize(text) + + async def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [self._vectorize(text) for text in texts] + + @staticmethod + def _vectorize(text: str) -> list[float]: + normalized = text.lower() + if "semantic" in normalized: + return [1.0, 0.0, 0.0, 0.0] + if "cleanup" in normalized: + return [0.0, 1.0, 0.0, 0.0] + return [0.0, 0.0, 1.0, 0.0] + + +async def _count_entity_search_state( + session_maker, + app_config: BasicMemoryConfig, + project_id: int, + entity_id: int, +) -> tuple[int, int, int]: + """Return counts for all derived search rows tied to one entity.""" + embedding_join = ( + "e.chunk_id = c.id" + if app_config.database_backend == DatabaseBackend.POSTGRES + else "e.rowid = c.id" + ) + params = {"project_id": project_id, "entity_id": entity_id} + + async with db.scoped_session(session_maker) as session: + search_index_rows = await session.execute( + text( + "SELECT COUNT(*) FROM search_index " + "WHERE project_id = :project_id AND entity_id = :entity_id" + ), + params, + ) + vector_chunk_rows = await session.execute( + text( + "SELECT COUNT(*) FROM search_vector_chunks " + "WHERE project_id = :project_id AND entity_id = :entity_id" + ), + params, + ) + vector_embedding_rows = await session.execute( + text( + "SELECT COUNT(*) FROM search_vector_embeddings e " + "JOIN search_vector_chunks c ON " + f"{embedding_join} " + "WHERE c.project_id = :project_id AND c.entity_id = :entity_id" + ), + params, + ) + + return ( + int(search_index_rows.scalar_one()), + int(vector_chunk_rows.scalar_one()), + int(vector_embedding_rows.scalar_one()), + ) + + +@pytest.fixture +def entity_service_with_search( + entity_repository: EntityRepository, + observation_repository, + relation_repository, + entity_parser: EntityParser, + file_service: FileService, + link_resolver, + search_service: SearchService, + app_config: BasicMemoryConfig, +) -> EntityService: + """Create EntityService with a real attached search service.""" + return EntityService( + entity_parser=entity_parser, + entity_repository=entity_repository, + observation_repository=observation_repository, + relation_repository=relation_repository, + file_service=file_service, + link_resolver=link_resolver, + search_service=search_service, + app_config=app_config, + ) + + @pytest.mark.asyncio async def test_create_entity( entity_service: EntityService, file_service: FileService, project_config: ProjectConfig @@ -227,6 +321,61 @@ async def test_delete_entity_by_id(entity_service: EntityService): await entity_service.get_by_permalink(entity_data.permalink) +@pytest.mark.asyncio +async def test_delete_entity_removes_search_and_vector_state( + entity_service_with_search: EntityService, + search_service: SearchService, + session_maker, + app_config: BasicMemoryConfig, +): + """Deleting an entity should clear all of its full-text and semantic search state.""" + if app_config.database_backend == DatabaseBackend.SQLITE: + pytest.importorskip("sqlite_vec") + + repository = search_service.repository + repository._semantic_enabled = True + repository._embedding_provider = _DeleteTestEmbeddingProvider() + repository._vector_dimensions = repository._embedding_provider.dimensions + repository._vector_tables_initialized = False + await search_service.init_search_index() + + entity = await entity_service_with_search.create_entity( + EntitySchema( + title="Semantic Delete Target", + directory="test", + note_type="note", + content=dedent(""" + # Semantic Delete Target + + - [note] Semantic cleanup should remove every derived row + - references [[Cleanup Target]] + """).strip(), + ) + ) + + await search_service.index_entity(entity) + await search_service.sync_entity_vectors(entity.id) + + search_rows, chunk_rows, embedding_rows = await _count_entity_search_state( + session_maker, + app_config, + search_service.repository.project_id, + entity.id, + ) + assert search_rows >= 3 + assert chunk_rows > 0 + assert embedding_rows > 0 + + assert await entity_service_with_search.delete_entity(entity.id) is True + + assert await _count_entity_search_state( + session_maker, + app_config, + search_service.repository.project_id, + entity.id, + ) == (0, 0, 0) + + @pytest.mark.asyncio async def test_get_entity_by_permalink_not_found(entity_service: EntityService): """Test handling of non-existent entity retrieval.""" diff --git a/tests/services/test_initialization.py b/tests/services/test_initialization.py index 37c28285..0ebfc84e 100644 --- a/tests/services/test_initialization.py +++ b/tests/services/test_initialization.py @@ -196,4 +196,3 @@ def capture_warning(message: str) -> None: "ensure_frontmatter_on_sync=True overrides disable_permalinks=True" in message for message in warnings ) - From a6a4b22489d9a7036beb2e8ec4eda5e1d07963aa Mon Sep 17 00:00:00 2001 From: phernandez Date: Thu, 9 Apr 2026 20:47:02 -0500 Subject: [PATCH 14/14] refactor(core): remove redundant sqlite vector cleanup override Signed-off-by: phernandez --- .../repository/sqlite_search_repository.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/basic_memory/repository/sqlite_search_repository.py b/src/basic_memory/repository/sqlite_search_repository.py index 1c07ac1e..06eb41c0 100644 --- a/src/basic_memory/repository/sqlite_search_repository.py +++ b/src/basic_memory/repository/sqlite_search_repository.py @@ -565,21 +565,6 @@ async def _delete_stale_chunks( stale_params, ) - async def delete_entity_vector_rows(self, entity_id: int) -> None: - """Delete one entity's vec rows on a sqlite-vec-enabled connection.""" - await self._ensure_vector_tables() - - async with db.scoped_session(self.session_maker) as session: - await self._ensure_sqlite_vec_loaded(session) - - # Constraint: sqlite-vec virtual tables are only visible after vec0 is - # loaded on this exact connection. - # Why: generic repository sessions can reach search_vector_chunks but still - # fail with "no such module: vec0" when touching embeddings. - # Outcome: service-level cleanup routes vec-table deletes through this helper. - await self._delete_entity_chunks(session, entity_id) - await session.commit() - async def delete_project_vector_rows(self) -> None: """Delete all vector rows for this project on a sqlite-vec-enabled connection.""" await self._ensure_vector_tables()