Skip to content

Commit cdffd87

Browse files
committed
feat(core): support note-level embedding opt-out
Signed-off-by: phernandez <paul@basicmachines.co>
1 parent d986c4d commit cdffd87

File tree

2 files changed

+231
-3
lines changed

2 files changed

+231
-3
lines changed

src/basic_memory/services/search_service.py

Lines changed: 103 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,15 @@ async def index_entity_data(
427427

428428
async def sync_entity_vectors(self, entity_id: int) -> None:
429429
"""Refresh vector chunks for one entity in repositories that support semantic indexing."""
430+
entity = await self.entity_repository.find_by_id(entity_id)
431+
if entity is None:
432+
await self._clear_entity_vectors(entity_id)
433+
return
434+
435+
if not self._entity_embeddings_enabled(entity):
436+
await self._clear_entity_vectors(entity_id)
437+
return
438+
430439
await self.repository.sync_entity_vectors(entity_id)
431440

432441
async def sync_entity_vectors_batch(
@@ -435,10 +444,47 @@ async def sync_entity_vectors_batch(
435444
progress_callback=None,
436445
) -> VectorSyncBatchResult:
437446
"""Refresh vector chunks for a batch of entities."""
438-
return await self.repository.sync_entity_vectors_batch(
439-
entity_ids,
447+
if not entity_ids:
448+
return VectorSyncBatchResult(
449+
entities_total=0,
450+
entities_synced=0,
451+
entities_failed=0,
452+
)
453+
454+
entities_by_id = {
455+
entity.id: entity for entity in await self.entity_repository.find_by_ids(entity_ids)
456+
}
457+
opted_out_ids = [
458+
entity_id
459+
for entity_id in entity_ids
460+
if (
461+
(entity := entities_by_id.get(entity_id)) is not None
462+
and not self._entity_embeddings_enabled(entity)
463+
)
464+
]
465+
for entity_id in opted_out_ids:
466+
await self._clear_entity_vectors(entity_id)
467+
468+
eligible_entity_ids = [
469+
entity_id
470+
for entity_id in entity_ids
471+
if entity_id in entities_by_id and entity_id not in opted_out_ids
472+
]
473+
if not eligible_entity_ids:
474+
return VectorSyncBatchResult(
475+
entities_total=len(entity_ids),
476+
entities_synced=0,
477+
entities_failed=0,
478+
entities_skipped=len(opted_out_ids),
479+
)
480+
481+
batch_result = await self.repository.sync_entity_vectors_batch(
482+
eligible_entity_ids,
440483
progress_callback=progress_callback,
441484
)
485+
batch_result.entities_total = len(entity_ids)
486+
batch_result.entities_skipped += len(opted_out_ids)
487+
return batch_result
442488

443489
async def reindex_vectors(self, progress_callback=None) -> dict:
444490
"""Rebuild vector embeddings for all entities.
@@ -463,7 +509,7 @@ async def reindex_vectors(self, progress_callback=None) -> dict:
463509
stats = {
464510
"total_entities": batch_result.entities_total,
465511
"embedded": batch_result.entities_synced,
466-
"skipped": 0,
512+
"skipped": batch_result.entities_skipped,
467513
"errors": batch_result.entities_failed,
468514
}
469515

@@ -518,6 +564,60 @@ async def _purge_stale_search_rows(self) -> None:
518564

519565
logger.info("Purged stale search rows for deleted entities", project_id=project_id)
520566

567+
@staticmethod
568+
def _entity_embeddings_enabled(entity: Entity) -> bool:
569+
"""Return whether semantic embeddings should be generated for this entity."""
570+
if not entity.entity_metadata:
571+
return True
572+
573+
embed_value = entity.entity_metadata.get("embed")
574+
if embed_value is None:
575+
return True
576+
if isinstance(embed_value, bool):
577+
return embed_value
578+
if isinstance(embed_value, str):
579+
normalized = embed_value.strip().lower()
580+
if normalized in {"false", "0", "no", "off"}:
581+
return False
582+
if normalized in {"true", "1", "yes", "on"}:
583+
return True
584+
if isinstance(embed_value, (int, float)):
585+
return bool(embed_value)
586+
587+
# Default unknown values to enabled so malformed metadata does not silently
588+
# remove notes from semantic search.
589+
return True
590+
591+
async def _clear_entity_vectors(self, entity_id: int) -> None:
592+
"""Delete derived vector rows for one entity."""
593+
from basic_memory.repository.search_repository_base import SearchRepositoryBase
594+
from basic_memory.repository.sqlite_search_repository import SQLiteSearchRepository
595+
596+
# Trigger: semantic indexing is disabled for this repository instance.
597+
# Why: repositories only create vector tables when semantic search is enabled.
598+
# Outcome: skip cleanup because there are no active derived vector rows to maintain.
599+
if isinstance(self.repository, SearchRepositoryBase) and not self.repository._semantic_enabled:
600+
return
601+
602+
params = {"project_id": self.repository.project_id, "entity_id": entity_id}
603+
if isinstance(self.repository, SQLiteSearchRepository):
604+
await self.repository.execute_query(
605+
text(
606+
"DELETE FROM search_vector_embeddings WHERE rowid IN ("
607+
"SELECT id FROM search_vector_chunks "
608+
"WHERE project_id = :project_id AND entity_id = :entity_id)"
609+
),
610+
params,
611+
)
612+
613+
await self.repository.execute_query(
614+
text(
615+
"DELETE FROM search_vector_chunks "
616+
"WHERE project_id = :project_id AND entity_id = :entity_id"
617+
),
618+
params,
619+
)
620+
521621
async def index_entity_file(
522622
self,
523623
entity: Entity,

tests/services/test_semantic_search.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
"""Semantic search service regression tests for local SQLite search."""
22

3+
from datetime import datetime
4+
from types import SimpleNamespace
5+
from unittest.mock import AsyncMock
6+
37
import pytest
48

9+
from basic_memory.repository import EntityRepository
10+
from basic_memory.repository.search_repository_base import VectorSyncBatchResult
511
from basic_memory.repository.semantic_errors import (
612
SemanticDependenciesMissingError,
713
SemanticSearchDisabledError,
@@ -89,3 +95,125 @@ async def test_semantic_fts_mode_still_returns_observations(search_service, test
8995

9096
assert results
9197
assert any(result.type == SearchItemType.OBSERVATION.value for result in results)
98+
99+
100+
@pytest.mark.asyncio
101+
async def test_semantic_vector_sync_skips_embed_opt_out_and_clears_vectors(
102+
search_service, monkeypatch
103+
):
104+
"""Embed opt-out should clear stale vectors instead of regenerating them."""
105+
repository = _sqlite_repo(search_service)
106+
repository._semantic_enabled = True
107+
108+
monkeypatch.setattr(
109+
search_service.entity_repository,
110+
"find_by_id",
111+
AsyncMock(return_value=SimpleNamespace(id=42, entity_metadata={"embed": False})),
112+
)
113+
sync_vectors = AsyncMock()
114+
execute_query = AsyncMock()
115+
monkeypatch.setattr(repository, "sync_entity_vectors", sync_vectors)
116+
monkeypatch.setattr(repository, "execute_query", execute_query)
117+
118+
await search_service.sync_entity_vectors(42)
119+
120+
sync_vectors.assert_not_awaited()
121+
assert execute_query.await_count == 2
122+
123+
124+
@pytest.mark.asyncio
125+
async def test_semantic_vector_sync_resumes_when_embed_opt_out_removed(
126+
search_service, monkeypatch
127+
):
128+
"""Removing the opt-out should restore normal embedding sync."""
129+
repository = _sqlite_repo(search_service)
130+
repository._semantic_enabled = True
131+
132+
monkeypatch.setattr(
133+
search_service.entity_repository,
134+
"find_by_id",
135+
AsyncMock(return_value=SimpleNamespace(id=42, entity_metadata={})),
136+
)
137+
sync_vectors = AsyncMock()
138+
execute_query = AsyncMock()
139+
monkeypatch.setattr(repository, "sync_entity_vectors", sync_vectors)
140+
monkeypatch.setattr(repository, "execute_query", execute_query)
141+
142+
await search_service.sync_entity_vectors(42)
143+
144+
sync_vectors.assert_awaited_once_with(42)
145+
execute_query.assert_not_awaited()
146+
147+
148+
@pytest.mark.asyncio
149+
async def test_semantic_vector_sync_batch_skips_embed_opt_out_and_reports_skips(
150+
search_service, monkeypatch
151+
):
152+
"""Batch vector sync should only embed eligible notes and report skipped opt-outs."""
153+
repository = _sqlite_repo(search_service)
154+
repository._semantic_enabled = True
155+
156+
monkeypatch.setattr(
157+
search_service.entity_repository,
158+
"find_by_ids",
159+
AsyncMock(
160+
return_value=[
161+
SimpleNamespace(id=41, entity_metadata={"embed": False}),
162+
SimpleNamespace(id=42, entity_metadata={}),
163+
]
164+
),
165+
)
166+
sync_batch = AsyncMock(
167+
return_value=VectorSyncBatchResult(
168+
entities_total=1,
169+
entities_synced=1,
170+
entities_failed=0,
171+
)
172+
)
173+
execute_query = AsyncMock()
174+
monkeypatch.setattr(repository, "sync_entity_vectors_batch", sync_batch)
175+
monkeypatch.setattr(repository, "execute_query", execute_query)
176+
177+
result = await search_service.sync_entity_vectors_batch([41, 42])
178+
179+
sync_batch.assert_awaited_once()
180+
assert sync_batch.await_args.args[0] == [42]
181+
assert result.entities_total == 2
182+
assert result.entities_synced == 1
183+
assert result.entities_skipped == 1
184+
assert execute_query.await_count == 2
185+
186+
187+
@pytest.mark.asyncio
188+
async def test_embed_opt_out_note_still_participates_in_fts(
189+
search_service, session_maker, test_project
190+
):
191+
"""Per-note semantic opt-out should not remove the note from FTS search."""
192+
entity_repo = EntityRepository(session_maker, project_id=test_project.id)
193+
entity = await entity_repo.create(
194+
{
195+
"title": "FTS Opt Out",
196+
"note_type": "note",
197+
"entity_metadata": {"embed": False},
198+
"content_type": "text/markdown",
199+
"file_path": "test/fts-opt-out.md",
200+
"permalink": "test/fts-opt-out",
201+
"project_id": test_project.id,
202+
"created_at": datetime.now(),
203+
"updated_at": datetime.now(),
204+
}
205+
)
206+
207+
await search_service.index_entity(
208+
entity,
209+
content="This note should stay searchable through full text indexing.",
210+
)
211+
212+
results = await search_service.search(
213+
SearchQuery(
214+
text="stay searchable",
215+
retrieval_mode=SearchRetrievalMode.FTS,
216+
)
217+
)
218+
219+
assert any(result.entity_id == entity.id for result in results)

0 commit comments

Comments
 (0)