diff --git a/src/basic_memory/repository/entity_repository.py b/src/basic_memory/repository/entity_repository.py index 56271800..244bad8c 100644 --- a/src/basic_memory/repository/entity_repository.py +++ b/src/basic_memory/repository/entity_repository.py @@ -45,7 +45,17 @@ async def get_by_id(self, entity_id: int) -> Optional[Entity]: # pragma: no cov async with db.scoped_session(self.session_maker) as session: return await self.select_by_id(session, entity_id) - async def get_by_external_id(self, external_id: str) -> Optional[Entity]: + async def _find_one_by_query(self, query, *, load_relations: bool) -> Optional[Entity]: + """Return one entity row with optional eager loading.""" + if load_relations: + return await self.find_one(query) + + result = await self.execute_query(query, use_query_options=False) + return result.scalars().one_or_none() + + async def get_by_external_id( + self, external_id: str, *, load_relations: bool = True + ) -> Optional[Entity]: """Get entity by external UUID. Args: @@ -54,21 +64,21 @@ async def get_by_external_id(self, external_id: str) -> Optional[Entity]: Returns: Entity if found, None otherwise """ - query = ( - self.select().where(Entity.external_id == external_id).options(*self.get_load_options()) - ) - return await self.find_one(query) + query = self.select().where(Entity.external_id == external_id) + return await self._find_one_by_query(query, load_relations=load_relations) - async def get_by_permalink(self, permalink: str) -> Optional[Entity]: + async def get_by_permalink( + self, permalink: str, *, load_relations: bool = True + ) -> Optional[Entity]: """Get entity by permalink. Args: permalink: Unique identifier for the entity """ - query = self.select().where(Entity.permalink == permalink).options(*self.get_load_options()) - return await self.find_one(query) + query = self.select().where(Entity.permalink == permalink) + return await self._find_one_by_query(query, load_relations=load_relations) - async def get_by_title(self, title: str) -> Sequence[Entity]: + async def get_by_title(self, title: str, *, load_relations: bool = True) -> Sequence[Entity]: """Get entities by title, ordered by shortest path first. When multiple entities share the same title (in different folders), @@ -82,23 +92,20 @@ async def get_by_title(self, title: str) -> Sequence[Entity]: self.select() .where(Entity.title == title) .order_by(func.length(Entity.file_path), Entity.file_path) - .options(*self.get_load_options()) ) - result = await self.execute_query(query) + result = await self.execute_query(query, use_query_options=load_relations) return list(result.scalars().all()) - async def get_by_file_path(self, file_path: Union[Path, str]) -> Optional[Entity]: + async def get_by_file_path( + self, file_path: Union[Path, str], *, load_relations: bool = True + ) -> Optional[Entity]: """Get entity by file_path. Args: file_path: Path to the entity file (will be converted to string internally) """ - query = ( - self.select() - .where(Entity.file_path == Path(file_path).as_posix()) - .options(*self.get_load_options()) - ) - return await self.find_one(query) + query = self.select().where(Entity.file_path == Path(file_path).as_posix()) + return await self._find_one_by_query(query, load_relations=load_relations) # ------------------------------------------------------------------------- # Lightweight methods for permalink resolution (no eager loading) diff --git a/src/basic_memory/services/entity_service.py b/src/basic_memory/services/entity_service.py index 931cbae3..f387c016 100644 --- a/src/basic_memory/services/entity_service.py +++ b/src/basic_memory/services/entity_service.py @@ -242,9 +242,17 @@ async def create_or_update_entity(self, schema: EntitySchema) -> Tuple[EntityMod # Try to find existing entity using strict resolution (no fuzzy search) # This prevents incorrectly matching similar file paths like "Node A.md" and "Node C.md" - existing = await self.link_resolver.resolve_link(schema.file_path, strict=True) + existing = await self.link_resolver.resolve_link( + schema.file_path, + strict=True, + load_relations=False, + ) if not existing and schema.permalink: - existing = await self.link_resolver.resolve_link(schema.permalink, strict=True) + existing = await self.link_resolver.resolve_link( + schema.permalink, + strict=True, + load_relations=False, + ) if existing: logger.debug(f"Found existing entity: {existing.file_path}") @@ -840,10 +848,22 @@ async def update_entity_and_observations( """ logger.debug(f"Updating entity and observations: {file_path}") - db_entity = await self.repository.get_by_file_path(file_path.as_posix()) + with telemetry.scope( + "upsert.update.fetch_entity", + domain="entity_service", + action="upsert", + phase="fetch_entity", + ): + db_entity = await self.repository.get_by_file_path(file_path.as_posix()) # Clear observations for entity - await self.observation_repository.delete_by_fields(entity_id=db_entity.id) + with telemetry.scope( + "upsert.update.delete_observations", + domain="entity_service", + action="upsert", + phase="delete_observations", + ): + await self.observation_repository.delete_by_fields(entity_id=db_entity.id) # add new observations observations = [ @@ -857,7 +877,14 @@ async def update_entity_and_observations( ) for obs in markdown.observations ] - await self.observation_repository.add_all(observations) + with telemetry.scope( + "upsert.update.insert_observations", + domain="entity_service", + action="upsert", + phase="insert_observations", + count=len(observations), + ): + await self.observation_repository.add_all(observations) # update values from markdown db_entity = entity_model_from_markdown(file_path, markdown, db_entity) @@ -871,10 +898,16 @@ async def update_entity_and_observations( db_entity.last_updated_by = user_id # update entity - return await self.repository.update( - db_entity.id, - db_entity, - ) + with telemetry.scope( + "upsert.update.save_entity", + domain="entity_service", + action="upsert", + phase="save_entity", + ): + return await self.repository.update( + db_entity.id, + db_entity, + ) async def upsert_entity_from_markdown( self, @@ -888,20 +921,30 @@ async def upsert_entity_from_markdown( created = await self.create_entity_from_markdown(file_path, markdown) else: created = await self.update_entity_and_observations(file_path, markdown) - return await self.update_entity_relations(created.file_path, markdown) + # Pass entity directly — avoids redundant get_by_file_path inside update_entity_relations + return await self.update_entity_relations(created, markdown) async def update_entity_relations( self, - path: str, + entity: EntityModel, markdown: EntityMarkdown, ) -> EntityModel: - """Update relations for entity""" - logger.debug(f"Updating relations for entity: {path}") + """Update relations for entity. - db_entity = await self.repository.get_by_file_path(path) + Accepts the entity object directly to avoid a redundant DB fetch. + Only entity.id and entity.permalink are used from the passed-in object. + """ + entity_id = entity.id + logger.debug(f"Updating relations for entity: {entity.file_path}") # Clear existing relations first - await self.relation_repository.delete_outgoing_relations_from_entity(db_entity.id) + with telemetry.scope( + "upsert.relations.delete_existing", + domain="entity_service", + action="upsert", + phase="delete_relations", + ): + await self.relation_repository.delete_outgoing_relations_from_entity(entity_id) # Batch resolve all relation targets in parallel if markdown.relations: @@ -911,12 +954,23 @@ async def update_entity_relations( # Use strict=True to disable fuzzy search - only exact matches should create resolved relations # This ensures forward references (links to non-existent entities) remain unresolved (to_id=NULL) lookup_tasks = [ - self.link_resolver.resolve_link(rel.target, strict=True) + self.link_resolver.resolve_link( + rel.target, + strict=True, + load_relations=False, + ) for rel in markdown.relations ] # Execute all lookups in parallel - resolved_entities = await asyncio.gather(*lookup_tasks, return_exceptions=True) + with telemetry.scope( + "upsert.relations.resolve_links", + domain="entity_service", + action="upsert", + phase="resolve_links", + count=len(lookup_tasks), + ): + resolved_entities = await asyncio.gather(*lookup_tasks, return_exceptions=True) # Process results and create relation records relations_to_add = [] @@ -935,7 +989,7 @@ async def update_entity_relations( # Create the relation relation = Relation( project_id=self.relation_repository.project_id, - from_id=db_entity.id, + from_id=entity_id, to_id=target_id, to_name=target_name, relation_type=rel.type, @@ -945,22 +999,37 @@ async def update_entity_relations( # Batch insert all relations if relations_to_add: - try: - await self.relation_repository.add_all(relations_to_add) - except IntegrityError: - # Some relations might be duplicates - fall back to individual inserts - logger.debug("Batch relation insert failed, trying individual inserts") - for relation in relations_to_add: - try: - await self.relation_repository.add(relation) - except IntegrityError: - # Unique constraint violation - relation already exists - logger.debug( - f"Skipping duplicate relation {relation.relation_type} from {db_entity.permalink}" - ) - continue - - return await self.repository.get_by_file_path(path) + with telemetry.scope( + "upsert.relations.insert_relations", + domain="entity_service", + action="upsert", + phase="insert_relations", + count=len(relations_to_add), + ): + try: + await self.relation_repository.add_all(relations_to_add) + except IntegrityError: + # Some relations might be duplicates - fall back to individual inserts + logger.debug("Batch relation insert failed, trying individual inserts") + for relation in relations_to_add: + try: + await self.relation_repository.add(relation) + except IntegrityError: + # Unique constraint violation - relation already exists + logger.debug( + f"Skipping duplicate relation {relation.relation_type} from {entity.permalink}" + ) + continue + + # Reload entity with relations via PK lookup (faster than get_by_file_path string match) + with telemetry.scope( + "upsert.relations.reload_entity", + domain="entity_service", + action="upsert", + phase="reload_entity", + ): + reloaded = await self.repository.find_by_ids([entity_id]) + return reloaded[0] async def edit_entity( self, @@ -996,7 +1065,11 @@ async def edit_entity( action="edit", phase="resolve_entity", ): - entity = await self.link_resolver.resolve_link(identifier, strict=True) + entity = await self.link_resolver.resolve_link( + identifier, + strict=True, + load_relations=False, + ) if not entity: raise EntityNotFoundError(f"Entity not found: {identifier}") diff --git a/src/basic_memory/services/link_resolver.py b/src/basic_memory/services/link_resolver.py index 739e0087..89848ebf 100644 --- a/src/basic_memory/services/link_resolver.py +++ b/src/basic_memory/services/link_resolver.py @@ -47,6 +47,7 @@ async def resolve_link( use_search: bool = True, strict: bool = False, source_path: Optional[str] = None, + load_relations: bool = True, ) -> Optional[Entity]: """Resolve a markdown link to a permalink. @@ -56,6 +57,7 @@ async def resolve_link( strict: If True, only exact matches are allowed (no fuzzy search fallback) source_path: Optional path of the source file containing the link. Used to prefer notes closer to the source (context-aware resolution). + load_relations: When False, skip eager loading and return a lightweight entity row. """ logger.trace(f"Resolving link: {link_text} (source: {source_path})") @@ -70,7 +72,10 @@ async def resolve_link( # UUIDs also match the stored external_id values. try: canonical_id = str(uuid_mod.UUID(clean_text)) - entity = await self.entity_repository.get_by_external_id(canonical_id) + entity = await self.entity_repository.get_by_external_id( + canonical_id, + load_relations=load_relations, + ) if entity: logger.debug(f"Found entity by external_id: {entity.permalink}") return entity @@ -98,6 +103,7 @@ async def resolve_link( strict=strict, source_path=None, project_permalink=project.permalink, + load_relations=load_relations, ) current_project_permalink = await self._get_current_project_permalink() @@ -109,6 +115,7 @@ async def resolve_link( strict=strict, source_path=source_path, project_permalink=current_project_permalink, + load_relations=load_relations, ) if resolved: return resolved @@ -136,6 +143,7 @@ async def resolve_link( strict=strict, source_path=None, project_permalink=project.permalink, + load_relations=load_relations, ) def _normalize_link_text(self, link_text: str) -> Tuple[str, Optional[str]]: @@ -176,6 +184,7 @@ async def _resolve_in_project( strict: bool, source_path: Optional[str], project_permalink: Optional[str], + load_relations: bool, ) -> Optional[Entity]: """Resolve a link within a specific project scope.""" clean_text = link_text @@ -223,12 +232,18 @@ async def _resolve_in_project( # Try with .md extension if not relative_path.endswith(".md"): relative_path_md = f"{relative_path}.md" - entity = await entity_repository.get_by_file_path(relative_path_md) + entity = await entity_repository.get_by_file_path( + relative_path_md, + load_relations=load_relations, + ) if entity: return entity # Try as-is (already has extension or is a permalink) - entity = await entity_repository.get_by_file_path(relative_path) + entity = await entity_repository.get_by_file_path( + relative_path, + load_relations=load_relations, + ) if entity: return entity @@ -242,12 +257,18 @@ async def _resolve_in_project( # Check permalink match for candidate_permalink in permalink_candidates: - permalink_entity = await entity_repository.get_by_permalink(candidate_permalink) + permalink_entity = await entity_repository.get_by_permalink( + candidate_permalink, + load_relations=load_relations, + ) if permalink_entity and permalink_entity.id not in [c.id for c in candidates]: candidates.append(permalink_entity) # Check title matches - title_entities = await entity_repository.get_by_title(clean_text) + title_entities = await entity_repository.get_by_title( + clean_text, + load_relations=load_relations, + ) for entity in title_entities: # Avoid duplicates (permalink match might also be in title matches) if entity.id not in [c.id for c in candidates]: @@ -263,13 +284,19 @@ async def _resolve_in_project( # Standard resolution (no source context): permalink first, then title # 1. Try exact permalink match first (most efficient) for candidate_permalink in permalink_candidates: - entity = await entity_repository.get_by_permalink(candidate_permalink) + entity = await entity_repository.get_by_permalink( + candidate_permalink, + load_relations=load_relations, + ) if entity: logger.debug(f"Found exact permalink match: {entity.permalink}") return entity # 2. Try exact title match - found = await entity_repository.get_by_title(clean_text) + found = await entity_repository.get_by_title( + clean_text, + load_relations=load_relations, + ) if found: # Return first match (shortest path) if no source context entity = found[0] @@ -277,7 +304,10 @@ async def _resolve_in_project( return entity # 3. Try file path - found_path = await entity_repository.get_by_file_path(clean_text) + found_path = await entity_repository.get_by_file_path( + clean_text, + load_relations=load_relations, + ) if found_path: logger.debug(f"Found entity with path: {found_path.file_path}") return found_path @@ -285,7 +315,10 @@ async def _resolve_in_project( # 4. Try file path with .md extension if not already present if not clean_text.endswith(".md") and "/" in clean_text: file_path_with_md = f"{clean_text}.md" - found_path_md = await entity_repository.get_by_file_path(file_path_with_md) + found_path_md = await entity_repository.get_by_file_path( + file_path_with_md, + load_relations=load_relations, + ) if found_path_md: logger.debug(f"Found entity with path (with .md): {found_path_md.file_path}") return found_path_md @@ -309,7 +342,10 @@ async def _resolve_in_project( f"Selected best match from {len(results)} results: {best_match.permalink}" ) if best_match.permalink: - return await entity_repository.get_by_permalink(best_match.permalink) + return await entity_repository.get_by_permalink( + best_match.permalink, + load_relations=load_relations, + ) # if we couldn't find anything then return None return None diff --git a/tests/services/test_upsert_entity_optimization.py b/tests/services/test_upsert_entity_optimization.py new file mode 100644 index 00000000..b6e0499d --- /dev/null +++ b/tests/services/test_upsert_entity_optimization.py @@ -0,0 +1,426 @@ +"""Tests proving upsert_entity_from_markdown optimizations. + +Verifies that: +1. Redundant get_by_file_path call is eliminated (entity passed directly) +2. Final reload uses find_by_ids (PK lookup) instead of get_by_file_path (string lookup) +3. Telemetry sub-spans are emitted for each DB phase +4. Correctness is preserved for create, update, and edit flows +""" + +from __future__ import annotations + +import importlib +from contextlib import contextmanager +from datetime import datetime, timezone +from pathlib import Path +from types import SimpleNamespace +from unittest.mock import AsyncMock + +import pytest + +from basic_memory.markdown.schemas import ( + EntityFrontmatter, + EntityMarkdown, + Observation as MarkdownObservation, + Relation as MarkdownRelation, +) +from basic_memory.schemas import Entity as EntitySchema +from basic_memory.services.entity_service import EntityService + +entity_service_module = importlib.import_module("basic_memory.services.entity_service") + + +# --- Helpers --- + + +def _make_markdown( + title: str = "Test Entity", + observations: list | None = None, + relations: list | None = None, +) -> EntityMarkdown: + frontmatter = EntityFrontmatter(metadata={"title": title, "type": "note"}) + return EntityMarkdown( + frontmatter=frontmatter, + observations=observations or [], + relations=relations or [], + created=datetime.now(timezone.utc), + modified=datetime.now(timezone.utc), + ) + + +def _capture_spans(): + spans: list[tuple[str, dict]] = [] + + @contextmanager + def fake_span(name: str, **attrs): + spans.append((name, attrs)) + yield + + return spans, fake_span + + +# --- Optimization 1: No redundant get_by_file_path in update_entity_relations --- + + +@pytest.mark.asyncio +async def test_upsert_update_does_not_refetch_entity(entity_service: EntityService, monkeypatch): + """update_entity_relations should NOT call get_by_file_path — entity is passed directly.""" + # Create an entity first + entity = await entity_service.create_entity( + EntitySchema( + title="Refetch Test", + directory="notes", + note_type="note", + content="# Refetch Test\n\n## Observations\n- [fact] some fact", + ) + ) + + # Spy on get_by_file_path calls + original_get_by_file_path = entity_service.repository.get_by_file_path + call_count = 0 + + async def spy_get_by_file_path(*args, **kwargs): + nonlocal call_count + call_count += 1 + return await original_get_by_file_path(*args, **kwargs) + + monkeypatch.setattr(entity_service.repository, "get_by_file_path", spy_get_by_file_path) + + # Run upsert with is_new=False — this calls update_entity_and_observations + update_entity_relations + markdown = _make_markdown( + title="Refetch Test", + observations=[MarkdownObservation(content="updated fact", category="fact")], + ) + await entity_service.upsert_entity_from_markdown(Path(entity.file_path), markdown, is_new=False) + + # update_entity_and_observations calls get_by_file_path once (to load the entity) + # update_entity_relations should NOT call it at all (entity passed directly) + assert call_count == 1, ( + f"Expected 1 get_by_file_path call (in update_entity_and_observations only), " + f"got {call_count}. update_entity_relations should not re-fetch." + ) + + +# --- Optimization 2: Final reload uses find_by_ids (PK) not get_by_file_path --- + + +@pytest.mark.asyncio +async def test_update_entity_relations_uses_pk_reload(entity_service: EntityService, monkeypatch): + """update_entity_relations should use find_by_ids for the final reload, not get_by_file_path.""" + entity = await entity_service.create_entity( + EntitySchema( + title="PK Reload Test", + directory="notes", + note_type="note", + content="# PK Reload Test", + ) + ) + + # Spy on find_by_ids calls + original_find_by_ids = entity_service.repository.find_by_ids + find_by_ids_calls = [] + + async def spy_find_by_ids(ids): + find_by_ids_calls.append(ids) + return await original_find_by_ids(ids) + + monkeypatch.setattr(entity_service.repository, "find_by_ids", spy_find_by_ids) + + markdown = _make_markdown(title="PK Reload Test") + await entity_service.upsert_entity_from_markdown(Path(entity.file_path), markdown, is_new=False) + + # update_entity_relations should call find_by_ids once with the entity's PK + assert len(find_by_ids_calls) == 1 + assert find_by_ids_calls[0] == [entity.id] + + +@pytest.mark.asyncio +async def test_create_or_update_entity_uses_lightweight_exact_resolution( + entity_service: EntityService, monkeypatch +): + """create_or_update_entity should use strict lookups without eager relation loading.""" + schema = EntitySchema( + title="Create Or Update", + directory="notes", + note_type="note", + content="# Create Or Update", + ) + sentinel_entity = SimpleNamespace(file_path="notes/existing.md") + resolve_calls: list[tuple[str, dict]] = [] + + async def fake_resolve_link(link_text: str, **kwargs): + resolve_calls.append((link_text, kwargs)) + if link_text == schema.file_path: + return None + return sentinel_entity + + monkeypatch.setattr(entity_service.link_resolver, "resolve_link", fake_resolve_link) + monkeypatch.setattr(entity_service, "update_entity", AsyncMock(return_value=sentinel_entity)) + + entity, is_new = await entity_service.create_or_update_entity(schema) + + assert entity is sentinel_entity + assert is_new is False + assert resolve_calls == [ + (schema.file_path, {"strict": True, "load_relations": False}), + (schema.permalink, {"strict": True, "load_relations": False}), + ] + + +# --- Telemetry sub-spans --- + + +@pytest.mark.asyncio +async def test_upsert_update_emits_sub_spans(entity_service: EntityService, monkeypatch): + """upsert_entity_from_markdown (update path) should emit sub-spans for each DB phase.""" + entity = await entity_service.create_entity( + EntitySchema( + title="Span Test", + directory="notes", + note_type="note", + content="# Span Test\n\n## Observations\n- [fact] original", + ) + ) + + spans, fake_span = _capture_spans() + monkeypatch.setattr(entity_service_module.telemetry, "span", fake_span) + + markdown = _make_markdown( + title="Span Test", + observations=[MarkdownObservation(content="updated", category="fact")], + ) + await entity_service.upsert_entity_from_markdown(Path(entity.file_path), markdown, is_new=False) + + span_names = [name for name, _ in spans] + + # update_entity_and_observations sub-spans + assert "upsert.update.fetch_entity" in span_names + assert "upsert.update.delete_observations" in span_names + assert "upsert.update.insert_observations" in span_names + assert "upsert.update.save_entity" in span_names + + # update_entity_relations sub-spans + assert "upsert.relations.delete_existing" in span_names + assert "upsert.relations.reload_entity" in span_names + + +@pytest.mark.asyncio +async def test_upsert_with_relations_emits_resolve_and_insert_spans( + entity_service: EntityService, monkeypatch +): + """When relations exist, resolve_links and insert_relations spans should be emitted.""" + # Create two entities so the relation can resolve + await entity_service.create_entity( + EntitySchema( + title="Target Entity", + directory="notes", + note_type="note", + content="# Target Entity", + ) + ) + source = await entity_service.create_entity( + EntitySchema( + title="Source Entity", + directory="notes", + note_type="note", + content="# Source Entity", + ) + ) + + spans, fake_span = _capture_spans() + monkeypatch.setattr(entity_service_module.telemetry, "span", fake_span) + + markdown = _make_markdown( + title="Source Entity", + relations=[MarkdownRelation(type="links_to", target="Target Entity")], + ) + await entity_service.upsert_entity_from_markdown(Path(source.file_path), markdown, is_new=False) + + span_names = [name for name, _ in spans] + assert "upsert.relations.resolve_links" in span_names + assert "upsert.relations.insert_relations" in span_names + + +@pytest.mark.asyncio +async def test_upsert_with_relations_uses_lightweight_exact_resolution( + entity_service: EntityService, monkeypatch +): + """Relation target resolution should skip eager loading during upsert.""" + target = await entity_service.create_entity( + EntitySchema( + title="Lightweight Target", + directory="notes", + note_type="note", + content="# Lightweight Target", + ) + ) + source = await entity_service.create_entity( + EntitySchema( + title="Lightweight Source", + directory="notes", + note_type="note", + content="# Lightweight Source", + ) + ) + resolve_calls: list[tuple[str, dict]] = [] + + async def fake_resolve_link(link_text: str, **kwargs): + resolve_calls.append((link_text, kwargs)) + return target + + monkeypatch.setattr(entity_service.link_resolver, "resolve_link", fake_resolve_link) + + markdown = _make_markdown( + title="Lightweight Source", + relations=[MarkdownRelation(type="links_to", target="Lightweight Target")], + ) + await entity_service.upsert_entity_from_markdown(Path(source.file_path), markdown, is_new=False) + + assert resolve_calls == [ + ("Lightweight Target", {"strict": True, "load_relations": False}), + ] + + +# --- Correctness: full round-trip --- + + +@pytest.mark.asyncio +async def test_upsert_update_preserves_observations(entity_service: EntityService): + """After upsert (update path), observations should be correctly replaced.""" + entity = await entity_service.create_entity( + EntitySchema( + title="Obs Test", + directory="notes", + note_type="note", + content="# Obs Test\n\n## Observations\n- [fact] original fact", + ) + ) + assert len(entity.observations) == 1 + + markdown = _make_markdown( + title="Obs Test", + observations=[ + MarkdownObservation(content="new fact 1", category="fact"), + MarkdownObservation(content="new fact 2", category="idea"), + ], + ) + updated = await entity_service.upsert_entity_from_markdown( + Path(entity.file_path), markdown, is_new=False + ) + + assert updated.id == entity.id + assert len(updated.observations) == 2 + obs_contents = {o.content for o in updated.observations} + assert obs_contents == {"new fact 1", "new fact 2"} + + +@pytest.mark.asyncio +async def test_upsert_update_preserves_relations(entity_service: EntityService): + """After upsert (update path), relations should be correctly replaced.""" + target = await entity_service.create_entity( + EntitySchema( + title="Relation Target", + directory="notes", + note_type="note", + content="# Relation Target", + ) + ) + source = await entity_service.create_entity( + EntitySchema( + title="Relation Source", + directory="notes", + note_type="note", + content="# Relation Source\n\n## Relations\n- links_to [[Relation Target]]", + ) + ) + assert len(source.relations) == 1 + + markdown = _make_markdown( + title="Relation Source", + relations=[MarkdownRelation(type="references", target="Relation Target")], + ) + updated = await entity_service.upsert_entity_from_markdown( + Path(source.file_path), markdown, is_new=False + ) + + assert updated.id == source.id + # Old relation replaced with new one + outgoing = [r for r in updated.relations if r.from_id == source.id] + assert len(outgoing) == 1 + assert outgoing[0].relation_type == "references" + assert outgoing[0].to_id == target.id + + +@pytest.mark.asyncio +async def test_upsert_create_path_works(entity_service: EntityService): + """The is_new=True path should still work correctly.""" + markdown = _make_markdown( + title="Create Path Test", + observations=[MarkdownObservation(content="a fact", category="fact")], + ) + result = await entity_service.upsert_entity_from_markdown( + Path("notes/create-path-test.md"), markdown, is_new=True + ) + + assert result.title == "Create Path Test" + assert len(result.observations) == 1 + assert result.observations[0].content == "a fact" + + +@pytest.mark.asyncio +async def test_edit_entity_end_to_end(entity_service: EntityService): + """Full edit_entity flow uses optimized upsert and returns correct entity.""" + entity = await entity_service.create_entity( + EntitySchema( + title="Edit E2E", + directory="notes", + note_type="note", + content="# Edit E2E\n\nOriginal content.", + ) + ) + + updated = await entity_service.edit_entity( + entity.file_path, + operation="append", + content="\n\n## Observations\n- [fact] appended fact", + ) + + assert updated.id == entity.id + assert len(updated.observations) == 1 + assert updated.observations[0].content == "appended fact" + # Checksum should be set (not None) after edit completes + assert updated.checksum is not None + + +@pytest.mark.asyncio +async def test_edit_entity_uses_lightweight_identifier_resolution( + entity_service: EntityService, monkeypatch +): + """edit_entity should resolve the target note without eager relation loading.""" + entity = await entity_service.create_entity( + EntitySchema( + title="Edit Lightweight", + directory="notes", + note_type="note", + content="# Edit Lightweight\n\nOriginal content.", + ) + ) + original_resolve_link = entity_service.link_resolver.resolve_link + resolve_calls: list[tuple[str, dict]] = [] + + async def spy_resolve_link(link_text: str, **kwargs): + resolve_calls.append((link_text, kwargs)) + return await original_resolve_link(link_text, **kwargs) + + monkeypatch.setattr(entity_service.link_resolver, "resolve_link", spy_resolve_link) + + await entity_service.edit_entity( + entity.file_path, + operation="append", + content="\n\nNo relation changes here.", + ) + + assert resolve_calls[0] == ( + entity.file_path, + {"strict": True, "load_relations": False}, + )