Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 24 additions & 12 deletions src/basic_memory/api/v2/utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional, List

from basic_memory import telemetry
from basic_memory.models import Entity as EntityModel
from basic_memory.repository import EntityRepository
from basic_memory.repository.search_repository import SearchIndexRow
from basic_memory.schemas.memory import (
Expand Down Expand Up @@ -177,20 +178,26 @@ async def to_search_results(entity_service: EntityService, results: List[SearchI
phase="hydrate_results",
result_count=len(results),
):
entity_batches = []
# Collect all unique entity IDs across all results in a single pass
# This avoids N+1 queries — one batch fetch instead of one per result
all_entity_ids: set[int] = set()
for result in results:
for eid in (result.entity_id, result.from_id, result.to_id):
if eid is not None:
all_entity_ids.add(eid)

# Single batch fetch for all entities
entities_by_id: dict[int, EntityModel] = {}
with telemetry.scope(
"search.hydrate_results.fetch_entities",
domain="search",
action="search",
phase="fetch_entities",
result_count=len(results),
result_count=len(all_entity_ids),
):
for result in results:
entity_batches.append(
await entity_service.get_entities_by_id(
[result.entity_id, result.from_id, result.to_id] # pyright: ignore
)
)
if all_entity_ids:
entities = await entity_service.get_entities_by_id(list(all_entity_ids))
entities_by_id = {e.id: e for e in entities}

search_results = []
with telemetry.scope(
Expand All @@ -200,7 +207,7 @@ async def to_search_results(entity_service: EntityService, results: List[SearchI
phase="shape_results",
result_count=len(results),
):
for result, entities in zip(results, entity_batches):
for result in results:
entity_id = None
observation_id = None
relation_id = None
Expand All @@ -214,13 +221,18 @@ async def to_search_results(entity_service: EntityService, results: List[SearchI
relation_id = result.id
entity_id = result.entity_id

# Look up entities by their specific IDs
parent_entity = entities_by_id.get(result.entity_id) if result.entity_id else None # pyright: ignore
from_entity = entities_by_id.get(result.from_id) if result.from_id else None # pyright: ignore
to_entity = entities_by_id.get(result.to_id) if result.to_id else None

search_results.append(
SearchResult(
title=result.title, # pyright: ignore
type=result.type, # pyright: ignore
permalink=result.permalink,
score=result.score, # pyright: ignore
entity=entities[0].permalink if entities else None,
entity=parent_entity.permalink if parent_entity else None,
content=result.content,
matched_chunk=result.matched_chunk_text,
file_path=result.file_path,
Expand All @@ -229,8 +241,8 @@ async def to_search_results(entity_service: EntityService, results: List[SearchI
observation_id=observation_id,
relation_id=relation_id,
category=result.category,
from_entity=entities[0].permalink if entities else None,
to_entity=entities[1].permalink if len(entities) > 1 else None,
from_entity=from_entity.permalink if from_entity else None,
to_entity=to_entity.permalink if to_entity else None,
relation_type=result.relation_type,
)
)
Expand Down
305 changes: 305 additions & 0 deletions tests/api/v2/test_search_hydration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
"""Tests for search result hydration in to_search_results().

Proves that the batch fetch eliminates N+1 queries and that
entity ID lookups are correct across all result types.
"""

from __future__ import annotations

from datetime import datetime, timezone
from types import SimpleNamespace

import pytest

from basic_memory.api.v2.utils import to_search_results
from basic_memory.repository.search_index_row import SearchIndexRow


# --- Helpers ---


def _make_entity(id: int, permalink: str) -> SimpleNamespace:
return SimpleNamespace(id=id, permalink=permalink)


def _make_row(*, type: str, id: int, **kwargs) -> SearchIndexRow:
    """Build a SearchIndexRow with sensible defaults; any default is overridable via kwargs."""
    stamp = datetime.now(timezone.utc)
    fields = {
        "project_id": 1,
        "file_path": f"notes/{id}.md",
        "created_at": stamp,
        "updated_at": stamp,
        "score": 1.0,
        "title": f"Item {id}",
        "permalink": f"notes/{id}",
        # Caller-supplied values win over the defaults above.
        **kwargs,
    }
    return SearchIndexRow(type=type, id=id, **fields)


class SpyEntityService:
    """Test double for EntityService.

    Records every ``get_entities_by_id`` call (so tests can assert call
    counts and argument sets) and resolves entities from a preset
    ``id -> entity`` mapping instead of a database.
    """

    def __init__(self, entities_by_id: dict[int, SimpleNamespace]):
        # Preset lookup table standing in for the database.
        self.entities_by_id = entities_by_id
        # One entry per call; each entry is a snapshot of the requested IDs.
        self.calls: list[list[int]] = []

    async def get_entities_by_id(self, ids: list[int]) -> list[SimpleNamespace]:
        # Record a *copy*: appending the caller's list by reference would let
        # later mutation of that list silently rewrite the recorded history.
        self.calls.append(list(ids))
        # Preserve request order; skip unknown IDs, mirroring a batch fetch
        # that simply finds fewer rows than it was asked for.
        return [self.entities_by_id[i] for i in ids if i in self.entities_by_id]


# --- Single batch fetch (N+1 elimination) ---


@pytest.mark.asyncio
async def test_single_db_call_for_multiple_results():
    """Hydrating several results should hit the entity service exactly once."""
    lookup = {
        eid: _make_entity(eid, permalink)
        for eid, permalink in ((1, "notes/a"), (2, "notes/b"), (3, "notes/c"))
    }
    spy = SpyEntityService(lookup)
    rows = [_make_row(type="entity", id=eid, entity_id=eid) for eid in (1, 2, 3)]

    await to_search_results(spy, rows)

    assert len(spy.calls) == 1, f"Expected 1 DB call, got {len(spy.calls)}"


@pytest.mark.asyncio
async def test_no_db_call_for_empty_results():
    """With nothing to hydrate, the entity service must never be queried."""
    spy = SpyEntityService({})

    hydrated = await to_search_results(spy, [])

    assert spy.calls == []
    assert hydrated == []


# --- ID deduplication ---


@pytest.mark.asyncio
async def test_deduplicates_entity_ids():
    """Overlapping entity_id/from_id/to_id values are fetched exactly once each."""
    # Entity 1 is both the parent entity and the relation source in every row.
    spy = SpyEntityService(
        {
            1: _make_entity(1, "notes/shared"),
            2: _make_entity(2, "notes/target-a"),
            3: _make_entity(3, "notes/target-b"),
        }
    )
    rows = [
        _make_row(type="relation", id=10, entity_id=1, from_id=1, to_id=2, relation_type="links"),
        _make_row(type="relation", id=11, entity_id=1, from_id=1, to_id=3, relation_type="links"),
    ]

    await to_search_results(spy, rows)

    # Exactly one fetch, covering the deduplicated ID set {1, 2, 3}.
    assert len(spy.calls) == 1
    assert set(spy.calls[0]) == {1, 2, 3}


# --- Correct entity-to-field mapping ---


@pytest.mark.asyncio
async def test_entity_result_maps_permalink():
    """An entity row hydrates 'entity' from its own permalink; relation fields stay None."""
    spy = SpyEntityService({5: _make_entity(5, "notes/my-entity")})

    hydrated = await to_search_results(spy, [_make_row(type="entity", id=5, entity_id=5)])

    assert len(hydrated) == 1
    first = hydrated[0]
    assert first.entity == "notes/my-entity"
    assert first.entity_id == 5
    assert first.from_entity is None
    assert first.to_entity is None


@pytest.mark.asyncio
async def test_observation_result_maps_parent_entity():
    """An observation row hydrates 'entity' from its parent entity's permalink."""
    spy = SpyEntityService({10: _make_entity(10, "notes/parent")})
    rows = [_make_row(type="observation", id=20, entity_id=10)]

    hydrated = await to_search_results(spy, rows)

    first = hydrated[0]
    assert first.entity == "notes/parent"
    assert first.entity_id == 10
    assert first.observation_id == 20
    assert first.from_entity is None
    assert first.to_entity is None


@pytest.mark.asyncio
async def test_relation_result_maps_from_and_to():
    """A relation row hydrates entity, from_entity, and to_entity independently."""
    spy = SpyEntityService(
        {
            1: _make_entity(1, "notes/parent"),
            2: _make_entity(2, "notes/source"),
            3: _make_entity(3, "notes/target"),
        }
    )
    row = _make_row(
        type="relation",
        id=99,
        entity_id=1,
        from_id=2,
        to_id=3,
        relation_type="references",
    )

    hydrated = await to_search_results(spy, [row])

    first = hydrated[0]
    assert first.entity == "notes/parent"
    assert first.from_entity == "notes/source"
    assert first.to_entity == "notes/target"
    assert first.relation_id == 99
    assert first.relation_type == "references"


@pytest.mark.asyncio
async def test_relation_with_distinct_entity_and_from_ids():
    """from_entity must track from_id even when it differs from entity_id.

    Regression guard: the old positional-index hydration reused entities[0]
    for both 'entity' and 'from_entity', which broke whenever the relation's
    source was not the parent entity.
    """
    spy = SpyEntityService(
        {
            10: _make_entity(10, "notes/parent-entity"),
            20: _make_entity(20, "notes/actual-source"),
            30: _make_entity(30, "notes/target"),
        }
    )
    row = _make_row(
        type="relation",
        id=50,
        entity_id=10,
        from_id=20,
        to_id=30,
        relation_type="derived_from",
    )

    hydrated = await to_search_results(spy, [row])

    first = hydrated[0]
    # The parent entity resolves from entity_id=10...
    assert first.entity == "notes/parent-entity"
    # ...while from_entity must resolve from_id=20, not the parent.
    assert first.from_entity == "notes/actual-source"
    assert first.to_entity == "notes/target"


# --- Mixed result types ---


@pytest.mark.asyncio
async def test_mixed_result_types_single_fetch():
    """Entity, observation, and relation rows all hydrate from one batch fetch."""
    spy = SpyEntityService(
        {
            1: _make_entity(1, "notes/entity-one"),
            2: _make_entity(2, "notes/entity-two"),
            3: _make_entity(3, "notes/entity-three"),
        }
    )
    rows = [
        _make_row(type="entity", id=1, entity_id=1),
        _make_row(type="observation", id=10, entity_id=2, category="fact"),
        _make_row(type="relation", id=20, entity_id=1, from_id=1, to_id=3, relation_type="links"),
    ]

    hydrated = await to_search_results(spy, rows)

    # One DB round-trip covers the whole mixed batch.
    assert len(spy.calls) == 1

    entity_row, observation_row, relation_row = hydrated

    assert entity_row.entity == "notes/entity-one"
    assert entity_row.entity_id == 1

    assert observation_row.entity == "notes/entity-two"
    assert observation_row.observation_id == 10

    assert relation_row.from_entity == "notes/entity-one"
    assert relation_row.to_entity == "notes/entity-three"


# --- Graceful handling of missing entities ---


@pytest.mark.asyncio
async def test_missing_entity_returns_none_permalink():
    """Unresolvable entity IDs degrade to None permalinks instead of raising."""
    # Only entity 1 is known; to_id=99 has no matching entity.
    spy = SpyEntityService({1: _make_entity(1, "notes/source")})
    rows = [
        _make_row(type="relation", id=5, entity_id=1, from_id=1, to_id=99, relation_type="links")
    ]

    hydrated = await to_search_results(spy, rows)

    first = hydrated[0]
    assert first.entity == "notes/source"
    assert first.from_entity == "notes/source"
    assert first.to_entity is None  # entity 99 not found


@pytest.mark.asyncio
async def test_null_ids_handled_gracefully():
    """Rows with no entity_id/from_id/to_id hydrate to all-None without errors."""
    spy = SpyEntityService({})
    # An entity row carrying no entity_id at all — there is nothing to look up.
    rows = [_make_row(type="entity", id=1)]

    hydrated = await to_search_results(spy, rows)

    first = hydrated[0]
    assert first.entity is None
    assert first.from_entity is None
    assert first.to_entity is None


# --- Scaling: prove O(1) DB calls ---


@pytest.mark.asyncio
async def test_single_db_call_scales_to_many_results():
    """DB call count stays at one no matter how many results are hydrated."""
    n = 50
    lookup = {i: _make_entity(i, f"notes/e-{i}") for i in range(1, n + 1)}
    spy = SpyEntityService(lookup)
    rows = [_make_row(type="entity", id=i, entity_id=i) for i in range(1, n + 1)]

    hydrated = await to_search_results(spy, rows)

    assert len(spy.calls) == 1, f"Expected 1 DB call for {n} results, got {len(spy.calls)}"
    assert len(hydrated) == n
    # Every result must carry its own permalink, in order.
    for idx, item in enumerate(hydrated, start=1):
        assert item.entity == f"notes/e-{idx}"
4 changes: 2 additions & 2 deletions tests/api/v2/test_utils_telemetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ async def test_to_search_results_emits_hydration_spans(monkeypatch) -> None:
class FakeEntityService:
async def get_entities_by_id(self, ids):
return [
SimpleNamespace(permalink="notes/root"),
SimpleNamespace(permalink="notes/child"),
SimpleNamespace(id=1, permalink="notes/root"),
SimpleNamespace(id=2, permalink="notes/child"),
]

now = datetime.now(timezone.utc)
Expand Down
Loading