Skip to content

Commit 5a24c46

Browse files
committed
fix(core): address note_content review feedback
Signed-off-by: phernandez <paul@basicmachines.co>
1 parent f13742a commit 5a24c46

File tree

3 files changed

+245
-16
lines changed

3 files changed

+245
-16
lines changed

src/basic_memory/alembic/env.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666

6767

6868
# Add this function to tell Alembic what to include/exclude
69-
def include_object(object, name, type_, reflected, compare_to):
69+
def include_object(obj, name, type_, reflected, compare_to):
7070
# Ignore SQLite FTS tables
7171
if type_ == "table" and name.startswith("search_index"):
7272
return False

src/basic_memory/repository/note_content_repository.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,22 @@ def __init__(self, session_maker: async_sessionmaker[AsyncSession], project_id:
3434
"""Initialize with session maker and project-scoped filtering."""
3535
super().__init__(session_maker, NoteContent, project_id=project_id)
3636

37-
def _coerce_note_content(self, data: Mapping[str, Any] | NoteContent) -> NoteContent:
38-
"""Convert input data to a NoteContent model while preserving nullable fields."""
37+
def _coerce_note_content(
38+
self, data: Mapping[str, Any] | NoteContent
39+
) -> tuple[NoteContent, set[str]]:
40+
"""Convert input data to a NoteContent model and track explicit fields."""
3941
if isinstance(data, NoteContent):
40-
return data
42+
model_data = {
43+
key: value for key, value in data.__dict__.items() if key in self.valid_columns
44+
}
45+
else:
46+
model_data = {key: value for key, value in data.items() if key in self.valid_columns}
4147

42-
model_data = {key: value for key, value in data.items() if key in self.valid_columns}
4348
entity_id = model_data.get("entity_id")
4449
if entity_id is None:
4550
raise ValueError("entity_id is required for note_content writes")
4651

47-
return NoteContent(**model_data)
52+
return NoteContent(**model_data), set(model_data)
4853

4954
async def _load_entity_identity(self, session: AsyncSession, entity_id: int) -> Entity:
5055
"""Load the owning entity so duplicated identity fields stay aligned."""
@@ -80,13 +85,32 @@ async def get_by_external_id(self, external_id: str) -> Optional[NoteContent]:
8085
return await self.find_one(query)
8186

8287
async def get_by_file_path(self, file_path: Path | str) -> Optional[NoteContent]:
83-
"""Get note content by the mirrored entity file path."""
84-
query = self.select().where(NoteContent.file_path == Path(file_path).as_posix())
85-
return await self.find_one(query)
88+
"""Get note content by file path, preferring rows whose entity still owns that path."""
89+
normalized_path = Path(file_path).as_posix()
90+
91+
# Trigger: note_content mirrors entity.file_path but does not enforce project-level uniqueness.
92+
# Why: entity renames can leave stale mirrored paths behind until note_content realigns.
93+
# Outcome: prefer the row whose current entity path still matches, then the newest mirror.
94+
query = (
95+
self.select()
96+
.join(Entity, Entity.id == NoteContent.entity_id)
97+
.where(NoteContent.file_path == normalized_path)
98+
.order_by(
99+
(Entity.file_path == normalized_path).desc(),
100+
NoteContent.updated_at.desc(),
101+
NoteContent.entity_id.desc(),
102+
)
103+
.limit(1)
104+
.options(*self.get_load_options())
105+
)
106+
107+
async with db.scoped_session(self.session_maker) as session:
108+
result = await session.execute(query)
109+
return result.scalars().first()
86110

87111
async def create(self, data: Mapping[str, Any] | NoteContent) -> NoteContent:
88112
"""Create a note_content row aligned to its owning entity."""
89-
note_content = self._coerce_note_content(data)
113+
note_content, _ = self._coerce_note_content(data)
90114

91115
async with db.scoped_session(self.session_maker) as session:
92116
await self._align_identity_fields(session, note_content)
@@ -102,7 +126,7 @@ async def create(self, data: Mapping[str, Any] | NoteContent) -> NoteContent:
102126

103127
async def upsert(self, data: Mapping[str, Any] | NoteContent) -> NoteContent:
104128
"""Insert or update note_content while keeping mirrored identity fields in sync."""
105-
note_content = self._coerce_note_content(data)
129+
note_content, provided_fields = self._coerce_note_content(data)
106130

107131
async with db.scoped_session(self.session_maker) as session:
108132
await self._align_identity_fields(session, note_content)
@@ -118,9 +142,12 @@ async def upsert(self, data: Mapping[str, Any] | NoteContent) -> NoteContent:
118142
)
119143
return created
120144

121-
for column_name in NoteContent.__table__.columns.keys():
122-
if column_name == "entity_id":
123-
continue
145+
fields_to_update = (provided_fields - {"entity_id"}) | {
146+
"project_id",
147+
"external_id",
148+
"file_path",
149+
}
150+
for column_name in fields_to_update:
124151
setattr(existing, column_name, getattr(note_content, column_name))
125152

126153
await session.flush()
@@ -132,7 +159,7 @@ async def upsert(self, data: Mapping[str, Any] | NoteContent) -> NoteContent:
132159
return updated
133160

134161
async def update_state_fields(self, entity_id: int, **updates: Any) -> Optional[NoteContent]:
135-
"""Update mutable sync and materialization fields for a note_content row."""
162+
"""Update sync fields and re-align project_id, external_id, and file_path from entity."""
136163
invalid_fields = set(updates) - NOTE_CONTENT_MUTABLE_FIELDS
137164
if invalid_fields:
138165
invalid_list = ", ".join(sorted(invalid_fields))

tests/repository/test_note_content_repository.py

Lines changed: 203 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Tests for the NoteContentRepository."""
22

3-
from datetime import datetime, timezone
3+
from datetime import datetime, timedelta, timezone
44

55
import pytest
66

7+
from basic_memory import db
78
from basic_memory.models import NoteContent, Project
89
from basic_memory.repository.entity_repository import EntityRepository
910
from basic_memory.repository.note_content_repository import NoteContentRepository
@@ -101,6 +102,114 @@ async def test_upsert_updates_existing_note_content(
101102
assert updated.last_materialization_error == "transient failure"
102103

103104

105+
@pytest.mark.asyncio
106+
async def test_upsert_inserts_when_no_existing_row(
107+
session_maker,
108+
test_project: Project,
109+
sample_entity,
110+
):
111+
"""Upsert should insert a new row when the entity has no note_content yet."""
112+
repository = NoteContentRepository(session_maker, project_id=test_project.id)
113+
114+
created = await repository.upsert(build_note_content_payload(sample_entity.id))
115+
116+
assert created.entity_id == sample_entity.id
117+
assert created.project_id == sample_entity.project_id
118+
assert created.external_id == sample_entity.external_id
119+
assert created.file_path == sample_entity.file_path
120+
assert created.db_version == 1
121+
122+
123+
@pytest.mark.asyncio
124+
async def test_create_requires_entity_id(session_maker, test_project: Project):
125+
"""Create should fail fast when note_content identity is missing."""
126+
repository = NoteContentRepository(session_maker, project_id=test_project.id)
127+
128+
with pytest.raises(ValueError, match="entity_id is required"):
129+
await repository.create({"markdown_content": "# Missing entity"})
130+
131+
132+
@pytest.mark.asyncio
133+
async def test_upsert_preserves_existing_fields_for_partial_payload(
134+
session_maker,
135+
test_project: Project,
136+
sample_entity,
137+
):
138+
"""Partial upserts should only change explicit fields and preserve existing state."""
139+
repository = NoteContentRepository(session_maker, project_id=test_project.id)
140+
payload = build_note_content_payload(sample_entity.id)
141+
payload["last_materialization_error"] = "stale failure"
142+
created = await repository.create(payload)
143+
144+
updated_at = datetime.now(timezone.utc)
145+
updated = await repository.upsert(
146+
{
147+
"entity_id": sample_entity.id,
148+
"markdown_content": "# Partially updated content",
149+
"db_version": 2,
150+
"updated_at": updated_at,
151+
"last_materialization_error": None,
152+
}
153+
)
154+
155+
assert updated.markdown_content == "# Partially updated content"
156+
assert updated.db_version == 2
157+
assert updated.db_checksum == created.db_checksum
158+
assert updated.file_write_status == created.file_write_status
159+
assert updated.last_source == created.last_source
160+
assert updated.last_materialization_error is None
161+
assert updated.file_path == sample_entity.file_path
162+
163+
164+
@pytest.mark.asyncio
165+
async def test_create_rejects_missing_entity(session_maker, test_project: Project):
166+
"""Create should fail when the owning entity does not exist."""
167+
repository = NoteContentRepository(session_maker, project_id=test_project.id)
168+
169+
with pytest.raises(ValueError, match="Entity 999999 does not exist"):
170+
await repository.create(build_note_content_payload(999999))
171+
172+
173+
@pytest.mark.asyncio
174+
async def test_create_rejects_entity_from_another_project(session_maker, config_home):
175+
"""Create should reject note_content writes across project boundaries."""
176+
project_repository = ProjectRepository(session_maker)
177+
project_one = await project_repository.create(
178+
{
179+
"name": "project-one-boundary",
180+
"path": str(config_home / "project-one-boundary"),
181+
"is_active": True,
182+
}
183+
)
184+
project_two = await project_repository.create(
185+
{
186+
"name": "project-two-boundary",
187+
"path": str(config_home / "project-two-boundary"),
188+
"is_active": True,
189+
}
190+
)
191+
entity_repository = EntityRepository(session_maker, project_id=project_two.id)
192+
other_project_entity = await entity_repository.create(
193+
{
194+
"title": "Other Project Note",
195+
"note_type": "test",
196+
"permalink": "project-two/other-project-note",
197+
"file_path": "notes/other-project-note.md",
198+
"content_type": "text/markdown",
199+
"created_at": datetime.now(timezone.utc),
200+
"updated_at": datetime.now(timezone.utc),
201+
}
202+
)
203+
204+
repository = NoteContentRepository(session_maker, project_id=project_one.id)
205+
206+
with pytest.raises(
207+
ValueError,
208+
match=f"Entity {other_project_entity.id} belongs to project {project_two.id}",
209+
):
210+
await repository.create(build_note_content_payload(other_project_entity.id))
211+
212+
104213
@pytest.mark.asyncio
105214
async def test_update_state_fields_realigns_identity_with_entity(
106215
session_maker,
@@ -134,6 +243,31 @@ async def test_update_state_fields_realigns_identity_with_entity(
134243
assert updated.last_materialization_attempt_at is None
135244

136245

246+
@pytest.mark.asyncio
247+
async def test_update_state_fields_rejects_invalid_fields(
248+
session_maker,
249+
test_project: Project,
250+
sample_entity,
251+
):
252+
"""Only the declared mutable sync fields should be accepted."""
253+
repository = NoteContentRepository(session_maker, project_id=test_project.id)
254+
await repository.create(build_note_content_payload(sample_entity.id))
255+
256+
with pytest.raises(ValueError, match="Unsupported note_content update fields: file_path"):
257+
await repository.update_state_fields(sample_entity.id, file_path="renamed/note.md")
258+
259+
260+
@pytest.mark.asyncio
261+
async def test_update_state_fields_returns_none_for_missing_note_content(
262+
session_maker,
263+
test_project: Project,
264+
):
265+
"""Missing note_content rows should produce a clean None response."""
266+
repository = NoteContentRepository(session_maker, project_id=test_project.id)
267+
268+
assert await repository.update_state_fields(999999, file_write_status="failed") is None
269+
270+
137271
@pytest.mark.asyncio
138272
async def test_delete_by_entity_id(session_maker, test_project: Project, sample_entity):
139273
"""Delete note_content directly by entity identifier."""
@@ -146,6 +280,17 @@ async def test_delete_by_entity_id(session_maker, test_project: Project, sample_
146280
assert await repository.get_by_entity_id(sample_entity.id) is None
147281

148282

283+
@pytest.mark.asyncio
284+
async def test_delete_by_entity_id_returns_false_when_missing(
285+
session_maker,
286+
test_project: Project,
287+
):
288+
"""Delete should report False when the note_content row does not exist."""
289+
repository = NoteContentRepository(session_maker, project_id=test_project.id)
290+
291+
assert await repository.delete_by_entity_id(999999) is False
292+
293+
149294
@pytest.mark.asyncio
150295
async def test_note_content_cascades_when_entity_is_deleted(
151296
session_maker,
@@ -221,3 +366,60 @@ async def test_note_content_file_path_lookup_is_project_scoped(session_maker, co
221366
assert found_two is not None
222367
assert found_one.entity_id == entity_one.id
223368
assert found_two.entity_id == entity_two.id
369+
370+
371+
@pytest.mark.asyncio
372+
async def test_note_content_file_path_lookup_prefers_entity_with_current_path(
373+
session_maker,
374+
config_home,
375+
):
376+
"""File-path lookup should prefer the entity whose current path still matches."""
377+
project_repository = ProjectRepository(session_maker)
378+
project = await project_repository.create(
379+
{
380+
"name": "project-path-drift",
381+
"path": str(config_home / "project-path-drift"),
382+
"is_active": True,
383+
}
384+
)
385+
entity_repository = EntityRepository(session_maker, project_id=project.id)
386+
387+
stale_entity = await entity_repository.create(
388+
{
389+
"title": "Stale Note",
390+
"note_type": "test",
391+
"permalink": "project/stale-note",
392+
"file_path": "archived/note.md",
393+
"content_type": "text/markdown",
394+
"created_at": datetime.now(timezone.utc),
395+
"updated_at": datetime.now(timezone.utc),
396+
}
397+
)
398+
current_entity = await entity_repository.create(
399+
{
400+
"title": "Current Note",
401+
"note_type": "test",
402+
"permalink": "project/current-note",
403+
"file_path": "shared/note.md",
404+
"content_type": "text/markdown",
405+
"created_at": datetime.now(timezone.utc),
406+
"updated_at": datetime.now(timezone.utc),
407+
}
408+
)
409+
410+
repository = NoteContentRepository(session_maker, project_id=project.id)
411+
stale_payload = build_note_content_payload(stale_entity.id)
412+
stale_payload["updated_at"] = datetime.now(timezone.utc) + timedelta(minutes=5)
413+
await repository.create(stale_payload)
414+
await repository.create(build_note_content_payload(current_entity.id))
415+
416+
async with db.scoped_session(session_maker) as session:
417+
stale_note_content = await repository.select_by_id(session, stale_entity.id)
418+
assert stale_note_content is not None
419+
stale_note_content.file_path = "shared/note.md"
420+
await session.flush()
421+
422+
found = await repository.get_by_file_path("shared/note.md")
423+
424+
assert found is not None
425+
assert found.entity_id == current_entity.id

0 commit comments

Comments
 (0)