|
9 | 9 | from basic_memory.schemas.memory import memory_url, memory_url_path |
10 | 10 | from basic_memory.schemas.search import SearchItemType |
11 | 11 | from basic_memory.services.context_service import ContextService |
| 12 | +from basic_memory.models.knowledge import Entity, Relation |
| 13 | +from basic_memory.models.project import Project |
12 | 14 |
|
13 | 15 |
|
14 | 16 | @pytest_asyncio.fixture |
@@ -218,3 +220,101 @@ async def test_context_metadata(context_service, test_graph): |
218 | 220 | assert metadata.depth == 2 |
219 | 221 | assert metadata.generated_at is not None |
220 | 222 | assert metadata.primary_count > 0 |
| 223 | + |
| 224 | + |
| 225 | +@pytest.mark.asyncio |
| 226 | +async def test_project_isolation_in_find_related(session_maker): |
| 227 | + """Test that find_related respects project boundaries and doesn't leak data.""" |
| 228 | + from basic_memory.repository.entity_repository import EntityRepository |
| 229 | + from basic_memory.repository.observation_repository import ObservationRepository |
| 230 | + from basic_memory.repository.search_repository import SearchRepository |
| 231 | + from basic_memory import db |
| 232 | + |
| 233 | + # Create database session |
| 234 | + async with db.scoped_session(session_maker) as db_session: |
| 235 | + # Create two separate projects |
| 236 | + project1 = Project(name="project1", path="/test1") |
| 237 | + project2 = Project(name="project2", path="/test2") |
| 238 | + db_session.add(project1) |
| 239 | + db_session.add(project2) |
| 240 | + await db_session.flush() |
| 241 | + |
| 242 | + # Create entities in project1 |
| 243 | + entity1_p1 = Entity( |
| 244 | + title="Entity1_P1", |
| 245 | + entity_type="document", |
| 246 | + content_type="text/markdown", |
| 247 | + project_id=project1.id, |
| 248 | + permalink="project1/entity1", |
| 249 | + file_path="project1/entity1.md", |
| 250 | + created_at=datetime.now(UTC), |
| 251 | + updated_at=datetime.now(UTC) |
| 252 | + ) |
| 253 | + entity2_p1 = Entity( |
| 254 | + title="Entity2_P1", |
| 255 | + entity_type="document", |
| 256 | + content_type="text/markdown", |
| 257 | + project_id=project1.id, |
| 258 | + permalink="project1/entity2", |
| 259 | + file_path="project1/entity2.md", |
| 260 | + created_at=datetime.now(UTC), |
| 261 | + updated_at=datetime.now(UTC) |
| 262 | + ) |
| 263 | + |
| 264 | + # Create entities in project2 |
| 265 | + entity1_p2 = Entity( |
| 266 | + title="Entity1_P2", |
| 267 | + entity_type="document", |
| 268 | + content_type="text/markdown", |
| 269 | + project_id=project2.id, |
| 270 | + permalink="project2/entity1", |
| 271 | + file_path="project2/entity1.md", |
| 272 | + created_at=datetime.now(UTC), |
| 273 | + updated_at=datetime.now(UTC) |
| 274 | + ) |
| 275 | + |
| 276 | + db_session.add_all([entity1_p1, entity2_p1, entity1_p2]) |
| 277 | + await db_session.flush() |
| 278 | + |
| 279 | + # Create relation in project1 (between entities of project1) |
| 280 | + relation_p1 = Relation( |
| 281 | + from_id=entity1_p1.id, |
| 282 | + to_id=entity2_p1.id, |
| 283 | + to_name="Entity2_P1", |
| 284 | + relation_type="connects_to" |
| 285 | + ) |
| 286 | + db_session.add(relation_p1) |
| 287 | + await db_session.commit() |
| 288 | + |
| 289 | + # Create repositories for project1 |
| 290 | + search_repo_p1 = SearchRepository(session_maker, project1.id) |
| 291 | + entity_repo_p1 = EntityRepository(session_maker, project1.id) |
| 292 | + obs_repo_p1 = ObservationRepository(session_maker, project1.id) |
| 293 | + context_service_p1 = ContextService(search_repo_p1, entity_repo_p1, obs_repo_p1) |
| 294 | + |
| 295 | + # Create repositories for project2 |
| 296 | + search_repo_p2 = SearchRepository(session_maker, project2.id) |
| 297 | + entity_repo_p2 = EntityRepository(session_maker, project2.id) |
| 298 | + obs_repo_p2 = ObservationRepository(session_maker, project2.id) |
| 299 | + context_service_p2 = ContextService(search_repo_p2, entity_repo_p2, obs_repo_p2) |
| 300 | + |
| 301 | + # Test: find_related for project1 should only return project1 entities |
| 302 | + type_id_pairs_p1 = [("entity", entity1_p1.id)] |
| 303 | + related_p1 = await context_service_p1.find_related(type_id_pairs_p1, max_depth=2) |
| 304 | + |
| 305 | + # Verify only project1 entities are returned |
| 306 | + related_entity_ids = [r.id for r in related_p1 if r.type == "entity"] |
| 307 | + assert entity2_p1.id in related_entity_ids # Should find connected entity2 in project1 |
| 308 | + assert entity1_p2.id not in related_entity_ids # Should NOT find entity from project2 |
| 309 | + |
| 310 | + # Test: find_related for project2 should return empty (no relations) |
| 311 | + type_id_pairs_p2 = [("entity", entity1_p2.id)] |
| 312 | + related_p2 = await context_service_p2.find_related(type_id_pairs_p2, max_depth=2) |
| 313 | + |
| 314 | + # Project2 has no relations, so should return empty |
| 315 | + assert len(related_p2) == 0 |
| 316 | + |
| 317 | + # Double-check: verify entities exist in their respective projects |
| 318 | + assert entity1_p1.project_id == project1.id |
| 319 | + assert entity2_p1.project_id == project1.id |
| 320 | + assert entity1_p2.project_id == project2.id |
0 commit comments