Skip to content

Commit f916662

Browse files
authored
fix(core): allow cross-project context traversal
1 parent 0c9800c commit f916662

6 files changed

Lines changed: 342 additions & 32 deletions

File tree

src/basic_memory/api/v2/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818

1919

2020
class EntityBatchLookup(Protocol):
21-
async def find_by_ids_for_hydration(self, ids: List[int]) -> Sequence[Any]: ...
21+
async def find_by_ids_for_hydration(
22+
self, ids: List[int], *, include_cross_project: bool = False
23+
) -> Sequence[Any]: ...
2224

2325

2426
class EntityServiceBatchLookup(Protocol):
@@ -88,7 +90,7 @@ async def to_graph_context(
8890
result_count=len(entity_ids_needed),
8991
):
9092
entities = await entity_repository.find_by_ids_for_hydration(
91-
list(entity_ids_needed)
93+
list(entity_ids_needed), include_cross_project=True
9294
)
9395
for e in entities:
9496
entity_title_lookup[e.id] = e.title

src/basic_memory/repository/entity_repository.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,21 +178,31 @@ async def get_all_permalinks(self) -> List[str]:
178178
result = await self.execute_query(query, use_query_options=False)
179179
return list(result.scalars().all())
180180

181-
async def find_by_ids_for_hydration(self, ids: List[int]) -> Sequence[Entity]:
181+
async def find_by_ids_for_hydration(
182+
self, ids: List[int], *, include_cross_project: bool = False
183+
) -> Sequence[Entity]:
182184
"""Fetch minimal entity fields needed for context hydration.
183185
184186
Context hydration only needs an entity's primary key, title, and external
185187
UUID. Keeping this separate from find_by_ids avoids the relationship eager
186188
loads that are useful for full entity reads but expensive for response shaping.
189+
190+
Args:
191+
ids: Entity IDs to hydrate.
192+
include_cross_project: Include IDs outside this repository's project scope.
193+
Use only for IDs already reached through validated graph traversal.
187194
"""
188195
if not ids:
189196
return []
190197

191198
query = (
192-
self.select()
199+
select(Entity)
193200
.where(Entity.id.in_(ids))
194201
.options(load_only(Entity.id, Entity.title, Entity.external_id))
195202
)
203+
if not include_cross_project:
204+
query = self._add_project_filter(query)
205+
196206
result = await self.execute_query(query, use_query_options=False)
197207
return list(result.scalars().all())
198208

src/basic_memory/services/context_service.py

Lines changed: 55 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,14 @@ async def find_related(
333333
relation_date_filter = ""
334334
timeframe_condition = ""
335335

336-
# Add project filtering for security - ensure all entities and relations belong to the same project
337-
project_filter = "AND e.project_id = :project_id"
338-
relation_project_filter = "AND e_from.project_id = :project_id"
336+
# Trigger: build_context starts from a project-scoped search result.
337+
# Why: the seed entity must belong to the requested project, but an
338+
# explicit relation edge may point at another project.
339+
# Outcome: traversal follows only project-owned edges from reached
340+
# entities, instead of forcing every reached entity into the seed project.
341+
seed_project_filter = "AND e.project_id = :project_id"
342+
connected_entity_project_filter = ""
343+
relation_project_filter = "AND e_from.project_id = r.project_id"
339344

340345
# Use a CTE that operates directly on entity and relation tables
341346
# This avoids the overhead of the search_index virtual table
@@ -351,7 +356,8 @@ async def find_related(
351356
query = self._build_postgres_query(
352357
entity_id_values,
353358
date_filter,
354-
project_filter,
359+
seed_project_filter,
360+
connected_entity_project_filter,
355361
relation_date_filter,
356362
relation_project_filter,
357363
timeframe_condition,
@@ -362,7 +368,8 @@ async def find_related(
362368
query = self._build_sqlite_query(
363369
entity_id_values,
364370
date_filter,
365-
project_filter,
371+
seed_project_filter,
372+
connected_entity_project_filter,
366373
relation_date_filter,
367374
relation_project_filter,
368375
timeframe_condition,
@@ -397,7 +404,8 @@ def _build_postgres_query( # pragma: no cover
397404
self,
398405
entity_id_values: str,
399406
date_filter: str,
400-
project_filter: str,
407+
seed_project_filter: str,
408+
connected_entity_project_filter: str,
401409
relation_date_filter: str,
402410
relation_project_filter: str,
403411
timeframe_condition: str,
@@ -421,11 +429,13 @@ def _build_postgres_query( # pragma: no cover
421429
0 as depth,
422430
e.id as root_id,
423431
e.created_at,
424-
e.created_at as relation_date
432+
e.created_at as relation_date,
433+
e.project_id as project_id,
434+
',' || e.id::text || ',' as entity_path
425435
FROM entity e
426436
WHERE e.id IN ({entity_id_values})
427437
{date_filter}
428-
{project_filter}
438+
{seed_project_filter}
429439
430440
UNION ALL
431441
@@ -477,15 +487,25 @@ def _build_postgres_query( # pragma: no cover
477487
CASE
478488
WHEN step_type = 1 THEN e_from.created_at
479489
ELSE eg.relation_date
480-
END as relation_date
490+
END as relation_date,
491+
CASE
492+
WHEN step_type = 1 THEN eg.project_id
493+
ELSE e.project_id
494+
END as project_id,
495+
CASE
496+
WHEN step_type = 1 THEN eg.entity_path
497+
ELSE eg.entity_path || e.id::text || ','
498+
END as entity_path
481499
FROM entity_graph eg
482500
CROSS JOIN LATERAL (VALUES (1), (2)) AS steps(step_type)
483501
JOIN relation r ON (
484502
eg.type = 'entity' AND
485-
(r.from_id = eg.id OR r.to_id = eg.id)
503+
(r.from_id = eg.id OR r.to_id = eg.id) AND
504+
r.project_id = eg.project_id
486505
)
487506
JOIN entity e_from ON (
488507
r.from_id = e_from.id
508+
{relation_date_filter}
489509
{relation_project_filter}
490510
)
491511
LEFT JOIN entity e ON (
@@ -495,10 +515,17 @@ def _build_postgres_query( # pragma: no cover
495515
ELSE r.from_id
496516
END
497517
{date_filter}
498-
{project_filter}
518+
{connected_entity_project_filter}
499519
)
500520
WHERE eg.depth < :max_depth
501-
AND (step_type = 1 OR (step_type = 2 AND e.id IS NOT NULL AND e.id != eg.id))
521+
AND (
522+
step_type = 1 OR (
523+
step_type = 2
524+
AND e.id IS NOT NULL
525+
AND e.id != eg.id
526+
AND position(',' || e.id::text || ',' in eg.entity_path) = 0
527+
)
528+
)
502529
{timeframe_condition}
503530
)
504531
-- Materialize and filter
@@ -529,7 +556,8 @@ def _build_sqlite_query(
529556
self,
530557
entity_id_values: str,
531558
date_filter: str,
532-
project_filter: str,
559+
seed_project_filter: str,
560+
connected_entity_project_filter: str,
533561
relation_date_filter: str,
534562
relation_project_filter: str,
535563
timeframe_condition: str,
@@ -555,11 +583,13 @@ def _build_sqlite_query(
555583
e.id as root_id,
556584
e.created_at,
557585
e.created_at as relation_date,
558-
0 as is_incoming
586+
0 as is_incoming,
587+
e.project_id as project_id,
588+
',' || e.id || ',' as entity_path
559589
FROM entity e
560590
WHERE e.id IN ({entity_id_values})
561591
{date_filter}
562-
{project_filter}
592+
{seed_project_filter}
563593
564594
UNION ALL
565595
@@ -580,11 +610,14 @@ def _build_sqlite_query(
580610
eg.root_id,
581611
e_from.created_at,
582612
e_from.created_at as relation_date,
583-
CASE WHEN r.from_id = eg.id THEN 0 ELSE 1 END as is_incoming
613+
CASE WHEN r.from_id = eg.id THEN 0 ELSE 1 END as is_incoming,
614+
eg.project_id as project_id,
615+
eg.entity_path as entity_path
584616
FROM entity_graph eg
585617
JOIN relation r ON (
586618
eg.type = 'entity' AND
587-
(r.from_id = eg.id OR r.to_id = eg.id)
619+
(r.from_id = eg.id OR r.to_id = eg.id) AND
620+
r.project_id = eg.project_id
588621
)
589622
JOIN entity e_from ON (
590623
r.from_id = e_from.id
@@ -615,7 +648,9 @@ def _build_sqlite_query(
615648
eg.root_id,
616649
e.created_at,
617650
eg.relation_date,
618-
eg.is_incoming
651+
eg.is_incoming,
652+
e.project_id as project_id,
653+
eg.entity_path || e.id || ',' as entity_path
619654
FROM entity_graph eg
620655
JOIN entity e ON (
621656
eg.type = 'relation' AND
@@ -624,9 +659,10 @@ def _build_sqlite_query(
624659
ELSE eg.from_id
625660
END
626661
{date_filter}
627-
{project_filter}
662+
{connected_entity_project_filter}
628663
)
629664
WHERE eg.depth < :max_depth
665+
AND instr(eg.entity_path, ',' || e.id || ',') = 0
630666
{timeframe_condition}
631667
)
632668
SELECT DISTINCT

tests/api/v2/test_memory_hydration.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,16 @@ class SpyEntityRepository:
4949

5050
def __init__(self, entities_by_id: dict[int, SimpleNamespace]):
5151
self.entities_by_id = entities_by_id
52-
self.calls: list[list[int]] = []
52+
self.calls: list[tuple[list[int], bool]] = []
5353

5454
async def find_by_ids(self, ids: list[int]):
55-
self.calls.append(ids)
55+
self.calls.append((ids, False))
5656
return [self.entities_by_id[i] for i in ids if i in self.entities_by_id]
5757

58-
async def find_by_ids_for_hydration(self, ids: list[int]):
59-
self.calls.append(ids)
58+
async def find_by_ids_for_hydration(
59+
self, ids: list[int], *, include_cross_project: bool = False
60+
):
61+
self.calls.append((ids, include_cross_project))
6062
return [self.entities_by_id[i] for i in ids if i in self.entities_by_id]
6163

6264

@@ -65,13 +67,15 @@ class LightweightOnlyEntityRepository:
6567

6668
def __init__(self, entities_by_id: dict[int, SimpleNamespace]):
6769
self.entities_by_id = entities_by_id
68-
self.hydration_calls: list[list[int]] = []
70+
self.hydration_calls: list[tuple[list[int], bool]] = []
6971

7072
async def find_by_ids(self, ids: list[int]):
7173
raise AssertionError("graph hydration must use the lightweight hydration lookup")
7274

73-
async def find_by_ids_for_hydration(self, ids: list[int]):
74-
self.hydration_calls.append(ids)
75+
async def find_by_ids_for_hydration(
76+
self, ids: list[int], *, include_cross_project: bool = False
77+
):
78+
self.hydration_calls.append((ids, include_cross_project))
7579
return [self.entities_by_id[i] for i in ids if i in self.entities_by_id]
7680

7781

@@ -177,7 +181,8 @@ async def test_to_graph_context_batches_entity_hydration_for_recent_activity():
177181
graph = await to_graph_context(context, entity_repository=repo, page=1, page_size=10)
178182

179183
assert len(repo.calls) == 1, f"Expected 1 entity lookup, got {len(repo.calls)}"
180-
assert set(repo.calls[0]) == {1, 2, 3}
184+
assert set(repo.calls[0][0]) == {1, 2, 3}
185+
assert repo.calls[0][1] is True
181186

182187
first_result = graph.results[0]
183188
first_primary = first_result.primary_result
@@ -272,7 +277,8 @@ async def test_to_graph_context_uses_lightweight_hydration_lookup():
272277
graph = await to_graph_context(context, entity_repository=repo)
273278

274279
assert len(repo.hydration_calls) == 1
275-
assert set(repo.hydration_calls[0]) == {1, 2}
280+
assert set(repo.hydration_calls[0][0]) == {1, 2}
281+
assert repo.hydration_calls[0][1] is True
276282
relation = graph.results[0].related_results[0]
277283
assert isinstance(relation, RelationSummary)
278284
assert relation.from_entity_external_id == "ext-root"

tests/repository/test_entity_repository.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1069,6 +1069,41 @@ def fail_get_load_options():
10691069
assert found[0].external_id == sample_entity.external_id
10701070

10711071

1072+
@pytest.mark.asyncio
1073+
async def test_find_by_ids_for_hydration_can_include_cross_project_entities(
1074+
entity_repository: EntityRepository, sample_entity: Entity, session_maker
1075+
):
1076+
"""Context hydration can opt into IDs reached through explicit graph edges."""
1077+
async with db.scoped_session(session_maker) as session:
1078+
other_project = Project(name="other-project", path="/other")
1079+
session.add(other_project)
1080+
await session.flush()
1081+
1082+
other_entity = Entity(
1083+
project_id=other_project.id,
1084+
title="Other Project Entity",
1085+
note_type="test",
1086+
permalink="other-project/entity",
1087+
file_path="other-project/entity.md",
1088+
content_type="text/markdown",
1089+
created_at=datetime.now(timezone.utc),
1090+
updated_at=datetime.now(timezone.utc),
1091+
)
1092+
session.add(other_entity)
1093+
await session.flush()
1094+
other_entity_id = other_entity.id
1095+
1096+
project_scoped = await entity_repository.find_by_ids_for_hydration(
1097+
[sample_entity.id, other_entity_id]
1098+
)
1099+
cross_project = await entity_repository.find_by_ids_for_hydration(
1100+
[sample_entity.id, other_entity_id], include_cross_project=True
1101+
)
1102+
1103+
assert {entity.id for entity in project_scoped} == {sample_entity.id}
1104+
assert {entity.id for entity in cross_project} == {sample_entity.id, other_entity_id}
1105+
1106+
10721107
@pytest.mark.asyncio
10731108
async def test_get_permalink_to_file_path_map(entity_repository: EntityRepository, session_maker):
10741109
"""Test getting permalink -> file_path mapping for bulk operations."""

0 commit comments

Comments
 (0)