|
6 | 6 |
|
7 | 7 | import pytest |
8 | 8 | import yaml |
| 9 | +from sqlalchemy import text |
9 | 10 |
|
10 | | -from basic_memory.config import ProjectConfig, BasicMemoryConfig |
| 11 | +from basic_memory import db |
| 12 | +from basic_memory.config import ProjectConfig, BasicMemoryConfig, DatabaseBackend |
11 | 13 | from basic_memory.markdown import EntityParser |
12 | 14 | from basic_memory.models import Entity as EntityModel |
13 | 15 | from basic_memory.repository import EntityRepository |
|
19 | 21 | from basic_memory.utils import generate_permalink |
20 | 22 |
|
21 | 23 |
|
| 24 | +class _DeleteTestEmbeddingProvider: |
| 25 | + """Deterministic embedding provider for entity delete cleanup tests.""" |
| 26 | + |
| 27 | + model_name = "delete-test" |
| 28 | + dimensions = 4 |
| 29 | + |
| 30 | + async def embed_query(self, text: str) -> list[float]: |
| 31 | + return self._vectorize(text) |
| 32 | + |
| 33 | + async def embed_documents(self, texts: list[str]) -> list[list[float]]: |
| 34 | + return [self._vectorize(text) for text in texts] |
| 35 | + |
| 36 | + @staticmethod |
| 37 | + def _vectorize(text: str) -> list[float]: |
| 38 | + normalized = text.lower() |
| 39 | + if "semantic" in normalized: |
| 40 | + return [1.0, 0.0, 0.0, 0.0] |
| 41 | + if "cleanup" in normalized: |
| 42 | + return [0.0, 1.0, 0.0, 0.0] |
| 43 | + return [0.0, 0.0, 1.0, 0.0] |
| 44 | + |
| 45 | + |
| 46 | +async def _count_entity_search_state( |
| 47 | + session_maker, |
| 48 | + app_config: BasicMemoryConfig, |
| 49 | + project_id: int, |
| 50 | + entity_id: int, |
| 51 | +) -> tuple[int, int, int]: |
| 52 | + """Return counts for all derived search rows tied to one entity.""" |
| 53 | + embedding_join = ( |
| 54 | + "e.chunk_id = c.id" |
| 55 | + if app_config.database_backend == DatabaseBackend.POSTGRES |
| 56 | + else "e.rowid = c.id" |
| 57 | + ) |
| 58 | + params = {"project_id": project_id, "entity_id": entity_id} |
| 59 | + |
| 60 | + async with db.scoped_session(session_maker) as session: |
| 61 | + search_index_rows = await session.execute( |
| 62 | + text( |
| 63 | + "SELECT COUNT(*) FROM search_index " |
| 64 | + "WHERE project_id = :project_id AND entity_id = :entity_id" |
| 65 | + ), |
| 66 | + params, |
| 67 | + ) |
| 68 | + vector_chunk_rows = await session.execute( |
| 69 | + text( |
| 70 | + "SELECT COUNT(*) FROM search_vector_chunks " |
| 71 | + "WHERE project_id = :project_id AND entity_id = :entity_id" |
| 72 | + ), |
| 73 | + params, |
| 74 | + ) |
| 75 | + vector_embedding_rows = await session.execute( |
| 76 | + text( |
| 77 | + "SELECT COUNT(*) FROM search_vector_embeddings e " |
| 78 | + "JOIN search_vector_chunks c ON " |
| 79 | + f"{embedding_join} " |
| 80 | + "WHERE c.project_id = :project_id AND c.entity_id = :entity_id" |
| 81 | + ), |
| 82 | + params, |
| 83 | + ) |
| 84 | + |
| 85 | + return ( |
| 86 | + int(search_index_rows.scalar_one()), |
| 87 | + int(vector_chunk_rows.scalar_one()), |
| 88 | + int(vector_embedding_rows.scalar_one()), |
| 89 | + ) |
| 90 | + |
| 91 | + |
| 92 | +@pytest.fixture |
| 93 | +def entity_service_with_search( |
| 94 | + entity_repository: EntityRepository, |
| 95 | + observation_repository, |
| 96 | + relation_repository, |
| 97 | + entity_parser: EntityParser, |
| 98 | + file_service: FileService, |
| 99 | + link_resolver, |
| 100 | + search_service: SearchService, |
| 101 | + app_config: BasicMemoryConfig, |
| 102 | +) -> EntityService: |
| 103 | + """Create EntityService with a real attached search service.""" |
| 104 | + return EntityService( |
| 105 | + entity_parser=entity_parser, |
| 106 | + entity_repository=entity_repository, |
| 107 | + observation_repository=observation_repository, |
| 108 | + relation_repository=relation_repository, |
| 109 | + file_service=file_service, |
| 110 | + link_resolver=link_resolver, |
| 111 | + search_service=search_service, |
| 112 | + app_config=app_config, |
| 113 | + ) |
| 114 | + |
| 115 | + |
22 | 116 | @pytest.mark.asyncio |
23 | 117 | async def test_create_entity( |
24 | 118 | entity_service: EntityService, file_service: FileService, project_config: ProjectConfig |
@@ -227,6 +321,61 @@ async def test_delete_entity_by_id(entity_service: EntityService): |
227 | 321 | await entity_service.get_by_permalink(entity_data.permalink) |
228 | 322 |
|
229 | 323 |
|
| 324 | +@pytest.mark.asyncio |
| 325 | +async def test_delete_entity_removes_search_and_vector_state( |
| 326 | + entity_service_with_search: EntityService, |
| 327 | + search_service: SearchService, |
| 328 | + session_maker, |
| 329 | + app_config: BasicMemoryConfig, |
| 330 | +): |
| 331 | + """Deleting an entity should clear all of its full-text and semantic search state.""" |
| 332 | + if app_config.database_backend == DatabaseBackend.SQLITE: |
| 333 | + pytest.importorskip("sqlite_vec") |
| 334 | + |
| 335 | + repository = search_service.repository |
| 336 | + repository._semantic_enabled = True |
| 337 | + repository._embedding_provider = _DeleteTestEmbeddingProvider() |
| 338 | + repository._vector_dimensions = repository._embedding_provider.dimensions |
| 339 | + repository._vector_tables_initialized = False |
| 340 | + await search_service.init_search_index() |
| 341 | + |
| 342 | + entity = await entity_service_with_search.create_entity( |
| 343 | + EntitySchema( |
| 344 | + title="Semantic Delete Target", |
| 345 | + directory="test", |
| 346 | + note_type="note", |
| 347 | + content=dedent(""" |
| 348 | + # Semantic Delete Target |
| 349 | +
|
| 350 | + - [note] Semantic cleanup should remove every derived row |
| 351 | + - references [[Cleanup Target]] |
| 352 | + """).strip(), |
| 353 | + ) |
| 354 | + ) |
| 355 | + |
| 356 | + await search_service.index_entity(entity) |
| 357 | + await search_service.sync_entity_vectors(entity.id) |
| 358 | + |
| 359 | + search_rows, chunk_rows, embedding_rows = await _count_entity_search_state( |
| 360 | + session_maker, |
| 361 | + app_config, |
| 362 | + search_service.repository.project_id, |
| 363 | + entity.id, |
| 364 | + ) |
| 365 | + assert search_rows >= 3 |
| 366 | + assert chunk_rows > 0 |
| 367 | + assert embedding_rows > 0 |
| 368 | + |
| 369 | + assert await entity_service_with_search.delete_entity(entity.id) is True |
| 370 | + |
| 371 | + assert await _count_entity_search_state( |
| 372 | + session_maker, |
| 373 | + app_config, |
| 374 | + search_service.repository.project_id, |
| 375 | + entity.id, |
| 376 | + ) == (0, 0, 0) |
| 377 | + |
| 378 | + |
230 | 379 | @pytest.mark.asyncio |
231 | 380 | async def test_get_entity_by_permalink_not_found(entity_service: EntityService): |
232 | 381 | """Test handling of non-existent entity retrieval.""" |
|
0 commit comments