Skip to content

Commit 733c4f7

Browse files
groksrcclaude
andauthored
fix: eliminate N+1 query in search hydrate_results (#713)
Signed-off-by: Drew Cain <groksrc@gmail.com> Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent cfa7000 commit 733c4f7

File tree

3 files changed

+331
-14
lines changed

3 files changed

+331
-14
lines changed

src/basic_memory/api/v2/utils.py

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Optional, List
22

33
from basic_memory import telemetry
4+
from basic_memory.models import Entity as EntityModel
45
from basic_memory.repository import EntityRepository
56
from basic_memory.repository.search_repository import SearchIndexRow
67
from basic_memory.schemas.memory import (
@@ -177,20 +178,26 @@ async def to_search_results(entity_service: EntityService, results: List[SearchI
177178
phase="hydrate_results",
178179
result_count=len(results),
179180
):
180-
entity_batches = []
181+
# Collect all unique entity IDs across all results in a single pass
182+
# This avoids N+1 queries — one batch fetch instead of one per result
183+
all_entity_ids: set[int] = set()
184+
for result in results:
185+
for eid in (result.entity_id, result.from_id, result.to_id):
186+
if eid is not None:
187+
all_entity_ids.add(eid)
188+
189+
# Single batch fetch for all entities
190+
entities_by_id: dict[int, EntityModel] = {}
181191
with telemetry.scope(
182192
"search.hydrate_results.fetch_entities",
183193
domain="search",
184194
action="search",
185195
phase="fetch_entities",
186-
result_count=len(results),
196+
result_count=len(all_entity_ids),
187197
):
188-
for result in results:
189-
entity_batches.append(
190-
await entity_service.get_entities_by_id(
191-
[result.entity_id, result.from_id, result.to_id] # pyright: ignore
192-
)
193-
)
198+
if all_entity_ids:
199+
entities = await entity_service.get_entities_by_id(list(all_entity_ids))
200+
entities_by_id = {e.id: e for e in entities}
194201

195202
search_results = []
196203
with telemetry.scope(
@@ -200,7 +207,7 @@ async def to_search_results(entity_service: EntityService, results: List[SearchI
200207
phase="shape_results",
201208
result_count=len(results),
202209
):
203-
for result, entities in zip(results, entity_batches):
210+
for result in results:
204211
entity_id = None
205212
observation_id = None
206213
relation_id = None
@@ -214,13 +221,18 @@ async def to_search_results(entity_service: EntityService, results: List[SearchI
214221
relation_id = result.id
215222
entity_id = result.entity_id
216223

224+
# Look up entities by their specific IDs
225+
parent_entity = entities_by_id.get(result.entity_id) if result.entity_id else None # pyright: ignore
226+
from_entity = entities_by_id.get(result.from_id) if result.from_id else None # pyright: ignore
227+
to_entity = entities_by_id.get(result.to_id) if result.to_id else None
228+
217229
search_results.append(
218230
SearchResult(
219231
title=result.title, # pyright: ignore
220232
type=result.type, # pyright: ignore
221233
permalink=result.permalink,
222234
score=result.score, # pyright: ignore
223-
entity=entities[0].permalink if entities else None,
235+
entity=parent_entity.permalink if parent_entity else None,
224236
content=result.content,
225237
matched_chunk=result.matched_chunk_text,
226238
file_path=result.file_path,
@@ -229,8 +241,8 @@ async def to_search_results(entity_service: EntityService, results: List[SearchI
229241
observation_id=observation_id,
230242
relation_id=relation_id,
231243
category=result.category,
232-
from_entity=entities[0].permalink if entities else None,
233-
to_entity=entities[1].permalink if len(entities) > 1 else None,
244+
from_entity=from_entity.permalink if from_entity else None,
245+
to_entity=to_entity.permalink if to_entity else None,
234246
relation_type=result.relation_type,
235247
)
236248
)
Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
"""Tests for search result hydration in to_search_results().
2+
3+
Proves that the batch fetch eliminates N+1 queries and that
4+
entity ID lookups are correct across all result types.
5+
"""
6+
7+
from __future__ import annotations
8+
9+
from datetime import datetime, timezone
10+
from types import SimpleNamespace
11+
12+
import pytest
13+
14+
from basic_memory.api.v2.utils import to_search_results
15+
from basic_memory.repository.search_index_row import SearchIndexRow
16+
17+
18+
# --- Helpers ---
19+
20+
21+
def _make_entity(id: int, permalink: str) -> SimpleNamespace:
22+
return SimpleNamespace(id=id, permalink=permalink)
23+
24+
25+
def _make_row(*, type: str, id: int, **kwargs) -> SearchIndexRow:
26+
now = datetime.now(timezone.utc)
27+
defaults = dict(
28+
project_id=1,
29+
file_path=f"notes/{id}.md",
30+
created_at=now,
31+
updated_at=now,
32+
score=1.0,
33+
title=f"Item {id}",
34+
permalink=f"notes/{id}",
35+
)
36+
defaults.update(kwargs)
37+
return SearchIndexRow(type=type, id=id, **defaults)
38+
39+
40+
class SpyEntityService:
41+
"""Tracks calls to get_entities_by_id and returns from a preset lookup."""
42+
43+
def __init__(self, entities_by_id: dict[int, SimpleNamespace]):
44+
self.entities_by_id = entities_by_id
45+
self.calls: list[list[int]] = []
46+
47+
async def get_entities_by_id(self, ids: list[int]):
48+
self.calls.append(ids)
49+
return [self.entities_by_id[i] for i in ids if i in self.entities_by_id]
50+
51+
52+
# --- Single batch fetch (N+1 elimination) ---
53+
54+
55+
@pytest.mark.asyncio
56+
async def test_single_db_call_for_multiple_results():
57+
"""Multiple search results must trigger exactly one get_entities_by_id call."""
58+
service = SpyEntityService(
59+
{
60+
1: _make_entity(1, "notes/a"),
61+
2: _make_entity(2, "notes/b"),
62+
3: _make_entity(3, "notes/c"),
63+
}
64+
)
65+
results = [
66+
_make_row(type="entity", id=1, entity_id=1),
67+
_make_row(type="entity", id=2, entity_id=2),
68+
_make_row(type="entity", id=3, entity_id=3),
69+
]
70+
71+
await to_search_results(service, results)
72+
73+
assert len(service.calls) == 1, f"Expected 1 DB call, got {len(service.calls)}"
74+
75+
76+
@pytest.mark.asyncio
77+
async def test_no_db_call_for_empty_results():
78+
"""Empty result list should not make any DB call."""
79+
service = SpyEntityService({})
80+
81+
search_results = await to_search_results(service, [])
82+
83+
assert len(service.calls) == 0
84+
assert search_results == []
85+
86+
87+
# --- ID deduplication ---
88+
89+
90+
@pytest.mark.asyncio
91+
async def test_deduplicates_entity_ids():
92+
"""Shared entity IDs across results should be fetched once, not per-result."""
93+
# entity_id=1 appears in all three results, from_id=1 overlaps with entity_id
94+
service = SpyEntityService(
95+
{
96+
1: _make_entity(1, "notes/shared"),
97+
2: _make_entity(2, "notes/target-a"),
98+
3: _make_entity(3, "notes/target-b"),
99+
}
100+
)
101+
results = [
102+
_make_row(type="relation", id=10, entity_id=1, from_id=1, to_id=2, relation_type="links"),
103+
_make_row(type="relation", id=11, entity_id=1, from_id=1, to_id=3, relation_type="links"),
104+
]
105+
106+
await to_search_results(service, results)
107+
108+
# Single call with deduplicated IDs: {1, 2, 3}
109+
assert len(service.calls) == 1
110+
fetched_ids = set(service.calls[0])
111+
assert fetched_ids == {1, 2, 3}
112+
113+
114+
# --- Correct entity-to-field mapping ---
115+
116+
117+
@pytest.mark.asyncio
118+
async def test_entity_result_maps_permalink():
119+
"""Entity results should populate the 'entity' field with the entity's permalink."""
120+
service = SpyEntityService({5: _make_entity(5, "notes/my-entity")})
121+
results = [_make_row(type="entity", id=5, entity_id=5)]
122+
123+
search_results = await to_search_results(service, results)
124+
125+
assert len(search_results) == 1
126+
r = search_results[0]
127+
assert r.entity == "notes/my-entity"
128+
assert r.entity_id == 5
129+
assert r.from_entity is None
130+
assert r.to_entity is None
131+
132+
133+
@pytest.mark.asyncio
134+
async def test_observation_result_maps_parent_entity():
135+
"""Observation results should populate 'entity' with the parent entity's permalink."""
136+
service = SpyEntityService({10: _make_entity(10, "notes/parent")})
137+
results = [_make_row(type="observation", id=20, entity_id=10)]
138+
139+
search_results = await to_search_results(service, results)
140+
141+
r = search_results[0]
142+
assert r.entity == "notes/parent"
143+
assert r.entity_id == 10
144+
assert r.observation_id == 20
145+
assert r.from_entity is None
146+
assert r.to_entity is None
147+
148+
149+
@pytest.mark.asyncio
150+
async def test_relation_result_maps_from_and_to():
151+
"""Relation results should populate entity, from_entity, and to_entity correctly."""
152+
service = SpyEntityService(
153+
{
154+
1: _make_entity(1, "notes/parent"),
155+
2: _make_entity(2, "notes/source"),
156+
3: _make_entity(3, "notes/target"),
157+
}
158+
)
159+
results = [
160+
_make_row(
161+
type="relation",
162+
id=99,
163+
entity_id=1,
164+
from_id=2,
165+
to_id=3,
166+
relation_type="references",
167+
)
168+
]
169+
170+
search_results = await to_search_results(service, results)
171+
172+
r = search_results[0]
173+
assert r.entity == "notes/parent"
174+
assert r.from_entity == "notes/source"
175+
assert r.to_entity == "notes/target"
176+
assert r.relation_id == 99
177+
assert r.relation_type == "references"
178+
179+
180+
@pytest.mark.asyncio
181+
async def test_relation_with_distinct_entity_and_from_ids():
182+
"""When entity_id != from_id, from_entity must use from_id's permalink, not entity_id's.
183+
184+
This was a bug in the old positional-index code: entities[0] was used for both
185+
'entity' and 'from_entity', which was wrong when entity_id != from_id.
186+
"""
187+
service = SpyEntityService(
188+
{
189+
10: _make_entity(10, "notes/parent-entity"),
190+
20: _make_entity(20, "notes/actual-source"),
191+
30: _make_entity(30, "notes/target"),
192+
}
193+
)
194+
results = [
195+
_make_row(
196+
type="relation",
197+
id=50,
198+
entity_id=10,
199+
from_id=20,
200+
to_id=30,
201+
relation_type="derived_from",
202+
)
203+
]
204+
205+
search_results = await to_search_results(service, results)
206+
207+
r = search_results[0]
208+
# entity should be the parent entity (entity_id=10)
209+
assert r.entity == "notes/parent-entity"
210+
# from_entity must be from_id=20, NOT entity_id=10
211+
assert r.from_entity == "notes/actual-source"
212+
assert r.to_entity == "notes/target"
213+
214+
215+
# --- Mixed result types ---
216+
217+
218+
@pytest.mark.asyncio
219+
async def test_mixed_result_types_single_fetch():
220+
"""A mix of entity, observation, and relation results should all hydrate in one fetch."""
221+
service = SpyEntityService(
222+
{
223+
1: _make_entity(1, "notes/entity-one"),
224+
2: _make_entity(2, "notes/entity-two"),
225+
3: _make_entity(3, "notes/entity-three"),
226+
}
227+
)
228+
results = [
229+
_make_row(type="entity", id=1, entity_id=1),
230+
_make_row(type="observation", id=10, entity_id=2, category="fact"),
231+
_make_row(type="relation", id=20, entity_id=1, from_id=1, to_id=3, relation_type="links"),
232+
]
233+
234+
search_results = await to_search_results(service, results)
235+
236+
# Single DB call
237+
assert len(service.calls) == 1
238+
239+
# Entity result
240+
assert search_results[0].entity == "notes/entity-one"
241+
assert search_results[0].entity_id == 1
242+
243+
# Observation result
244+
assert search_results[1].entity == "notes/entity-two"
245+
assert search_results[1].observation_id == 10
246+
247+
# Relation result
248+
assert search_results[2].from_entity == "notes/entity-one"
249+
assert search_results[2].to_entity == "notes/entity-three"
250+
251+
252+
# --- Graceful handling of missing entities ---
253+
254+
255+
@pytest.mark.asyncio
256+
async def test_missing_entity_returns_none_permalink():
257+
"""If an entity ID isn't found in the DB, permalink fields should be None."""
258+
# Only entity 1 exists; entity 99 (to_id) is missing
259+
service = SpyEntityService({1: _make_entity(1, "notes/source")})
260+
results = [
261+
_make_row(type="relation", id=5, entity_id=1, from_id=1, to_id=99, relation_type="links")
262+
]
263+
264+
search_results = await to_search_results(service, results)
265+
266+
r = search_results[0]
267+
assert r.entity == "notes/source"
268+
assert r.from_entity == "notes/source"
269+
assert r.to_entity is None # entity 99 not found
270+
271+
272+
@pytest.mark.asyncio
273+
async def test_null_ids_handled_gracefully():
274+
"""Results with None entity_id/from_id/to_id should not cause errors."""
275+
service = SpyEntityService({})
276+
# Entity result: entity_id is the row id itself, from_id/to_id are None
277+
results = [_make_row(type="entity", id=1)]
278+
279+
search_results = await to_search_results(service, results)
280+
281+
# No entity_id on the row means no fetch needed, all fields None
282+
r = search_results[0]
283+
assert r.entity is None
284+
assert r.from_entity is None
285+
assert r.to_entity is None
286+
287+
288+
# --- Scaling: prove O(1) DB calls ---
289+
290+
291+
@pytest.mark.asyncio
292+
async def test_single_db_call_scales_to_many_results():
293+
"""Even with many results, only one DB call should be made."""
294+
n = 50
295+
entities = {i: _make_entity(i, f"notes/e-{i}") for i in range(1, n + 1)}
296+
service = SpyEntityService(entities)
297+
results = [_make_row(type="entity", id=i, entity_id=i) for i in range(1, n + 1)]
298+
299+
search_results = await to_search_results(service, results)
300+
301+
assert len(service.calls) == 1, f"Expected 1 DB call for {n} results, got {len(service.calls)}"
302+
assert len(search_results) == n
303+
# Every result got its permalink
304+
for i, r in enumerate(search_results, start=1):
305+
assert r.entity == f"notes/e-{i}"

tests/api/v2/test_utils_telemetry.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ async def test_to_search_results_emits_hydration_spans(monkeypatch) -> None:
3333
class FakeEntityService:
3434
async def get_entities_by_id(self, ids):
3535
return [
36-
SimpleNamespace(permalink="notes/root"),
37-
SimpleNamespace(permalink="notes/child"),
36+
SimpleNamespace(id=1, permalink="notes/root"),
37+
SimpleNamespace(id=2, permalink="notes/child"),
3838
]
3939

4040
now = datetime.now(timezone.utc)

0 commit comments

Comments
 (0)