Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions src/basic_memory/services/context_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,8 +245,8 @@ async def find_related(
# For compatibility with the old query, we still need this for filtering
values = ", ".join([f"('{t}', {i})" for t, i in type_id_pairs])

# Parameters for bindings
params = {"max_depth": max_depth, "max_results": max_results}
# Parameters for bindings - include project_id for security filtering
params = {"max_depth": max_depth, "max_results": max_results, "project_id": self.search_repository.project_id}

# Build date and timeframe filters conditionally based on since parameter
if since:
Expand All @@ -258,6 +258,10 @@ async def find_related(
date_filter = ""
relation_date_filter = ""
timeframe_condition = ""

# Add project filtering for security - ensure all entities and relations belong to the same project
project_filter = "AND e.project_id = :project_id"
relation_project_filter = "AND e_from.project_id = :project_id"

# Use a CTE that operates directly on entity and relation tables
# This avoids the overhead of the search_index virtual table
Expand All @@ -284,6 +288,7 @@ async def find_related(
FROM entity e
WHERE e.id IN ({entity_id_values})
{date_filter}
{project_filter}

UNION ALL

Expand Down Expand Up @@ -314,8 +319,12 @@ async def find_related(
JOIN entity e_from ON (
r.from_id = e_from.id
{relation_date_filter}
{relation_project_filter}
)
LEFT JOIN entity e_to ON (r.to_id = e_to.id)
WHERE eg.depth < :max_depth
-- Ensure to_entity (if exists) also belongs to same project
AND (r.to_id IS NULL OR e_to.project_id = :project_id)

UNION ALL

Expand Down Expand Up @@ -348,6 +357,7 @@ async def find_related(
ELSE eg.from_id
END
{date_filter}
{project_filter}
)
WHERE eg.depth < :max_depth
-- Only include entities connected by relations within timeframe if specified
Expand Down
100 changes: 100 additions & 0 deletions tests/services/test_context_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from basic_memory.schemas.memory import memory_url, memory_url_path
from basic_memory.schemas.search import SearchItemType
from basic_memory.services.context_service import ContextService
from basic_memory.models.knowledge import Entity, Relation
from basic_memory.models.project import Project


@pytest_asyncio.fixture
Expand Down Expand Up @@ -218,3 +220,101 @@ async def test_context_metadata(context_service, test_graph):
assert metadata.depth == 2
assert metadata.generated_at is not None
assert metadata.primary_count > 0


@pytest.mark.asyncio
async def test_project_isolation_in_find_related(session_maker):
"""Test that find_related respects project boundaries and doesn't leak data."""
from basic_memory.repository.entity_repository import EntityRepository
from basic_memory.repository.observation_repository import ObservationRepository
from basic_memory.repository.search_repository import SearchRepository
from basic_memory import db

# Create database session
async with db.scoped_session(session_maker) as db_session:
# Create two separate projects
project1 = Project(name="project1", path="/test1")
project2 = Project(name="project2", path="/test2")
db_session.add(project1)
db_session.add(project2)
await db_session.flush()

# Create entities in project1
entity1_p1 = Entity(
title="Entity1_P1",
entity_type="document",
content_type="text/markdown",
project_id=project1.id,
permalink="project1/entity1",
file_path="project1/entity1.md",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC)
)
entity2_p1 = Entity(
title="Entity2_P1",
entity_type="document",
content_type="text/markdown",
project_id=project1.id,
permalink="project1/entity2",
file_path="project1/entity2.md",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC)
)

# Create entities in project2
entity1_p2 = Entity(
title="Entity1_P2",
entity_type="document",
content_type="text/markdown",
project_id=project2.id,
permalink="project2/entity1",
file_path="project2/entity1.md",
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC)
)

db_session.add_all([entity1_p1, entity2_p1, entity1_p2])
await db_session.flush()

# Create relation in project1 (between entities of project1)
relation_p1 = Relation(
from_id=entity1_p1.id,
to_id=entity2_p1.id,
to_name="Entity2_P1",
relation_type="connects_to"
)
db_session.add(relation_p1)
await db_session.commit()

# Create repositories for project1
search_repo_p1 = SearchRepository(session_maker, project1.id)
entity_repo_p1 = EntityRepository(session_maker, project1.id)
obs_repo_p1 = ObservationRepository(session_maker, project1.id)
context_service_p1 = ContextService(search_repo_p1, entity_repo_p1, obs_repo_p1)

# Create repositories for project2
search_repo_p2 = SearchRepository(session_maker, project2.id)
entity_repo_p2 = EntityRepository(session_maker, project2.id)
obs_repo_p2 = ObservationRepository(session_maker, project2.id)
context_service_p2 = ContextService(search_repo_p2, entity_repo_p2, obs_repo_p2)

# Test: find_related for project1 should only return project1 entities
type_id_pairs_p1 = [("entity", entity1_p1.id)]
related_p1 = await context_service_p1.find_related(type_id_pairs_p1, max_depth=2)

# Verify only project1 entities are returned
related_entity_ids = [r.id for r in related_p1 if r.type == "entity"]
assert entity2_p1.id in related_entity_ids # Should find connected entity2 in project1
assert entity1_p2.id not in related_entity_ids # Should NOT find entity from project2

# Test: find_related for project2 should return empty (no relations)
type_id_pairs_p2 = [("entity", entity1_p2.id)]
related_p2 = await context_service_p2.find_related(type_id_pairs_p2, max_depth=2)

# Project2 has no relations, so should return empty
assert len(related_p2) == 0

# Double-check: verify entities exist in their respective projects
assert entity1_p1.project_id == project1.id
assert entity2_p1.project_id == project1.id
assert entity1_p2.project_id == project2.id
Loading