Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 29 additions & 82 deletions src/basic_memory/repository/entity_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,62 +101,23 @@ async def find_by_permalinks(self, permalinks: List[str]) -> Sequence[Entity]:
return list(result.scalars().all())

async def upsert_entity(self, entity: Entity) -> Entity:
"""Insert or update entity using a hybrid approach.
"""Insert or update entity using simple try/catch with database-level conflict resolution.

This method provides a cleaner alternative to the try/catch approach
for handling permalink and file_path conflicts. It first tries direct
insertion, then handles conflicts intelligently.
Handles file_path race conditions by checking for existing entity on IntegrityError.
For permalink conflicts, generates a unique permalink with numeric suffix.

Args:
entity: The entity to insert or update

Returns:
The inserted or updated entity
"""

async with db.scoped_session(self.session_maker) as session:
# Set project_id if applicable and not already set
self._set_project_id_if_needed(entity)

# Check for existing entity with same file_path first
existing_by_path = await session.execute(
select(Entity).where(
Entity.file_path == entity.file_path, Entity.project_id == entity.project_id
)
)
existing_path_entity = existing_by_path.scalar_one_or_none()

if existing_path_entity:
# Update existing entity with same file path
for key, value in {
"title": entity.title,
"entity_type": entity.entity_type,
"entity_metadata": entity.entity_metadata,
"content_type": entity.content_type,
"permalink": entity.permalink,
"checksum": entity.checksum,
"updated_at": entity.updated_at,
}.items():
setattr(existing_path_entity, key, value)

await session.flush()
# Return with relationships loaded
query = (
self.select()
.where(Entity.file_path == entity.file_path)
.options(*self.get_load_options())
)
result = await session.execute(query)
found = result.scalar_one_or_none()
if not found: # pragma: no cover
raise RuntimeError(
f"Failed to retrieve entity after update: {entity.file_path}"
)
return found

# No existing entity with same file_path, try insert
# Try simple insert first
try:
# Simple insert for new entity
session.add(entity)
await session.flush()

Expand All @@ -175,20 +136,20 @@ async def upsert_entity(self, entity: Entity) -> Entity:
return found

except IntegrityError:
# Could be either file_path or permalink conflict
await session.rollback()

# Check if it's a file_path conflict (race condition)
existing_by_path_check = await session.execute(
select(Entity).where(
# Re-query after rollback to get a fresh, attached entity
existing_result = await session.execute(
select(Entity)
.where(
Entity.file_path == entity.file_path, Entity.project_id == entity.project_id
)
.options(*self.get_load_options())
)
race_condition_entity = existing_by_path_check.scalar_one_or_none()
existing_entity = existing_result.scalar_one_or_none()

if race_condition_entity:
# Race condition: file_path conflict detected after our initial check
# Update the existing entity instead
if existing_entity:
# File path conflict - update the existing entity
for key, value in {
"title": entity.title,
"entity_type": entity.entity_type,
Expand All @@ -198,25 +159,22 @@ async def upsert_entity(self, entity: Entity) -> Entity:
"checksum": entity.checksum,
"updated_at": entity.updated_at,
}.items():
setattr(race_condition_entity, key, value)

await session.flush()
# Return the updated entity with relationships loaded
query = (
self.select()
.where(Entity.file_path == entity.file_path)
.options(*self.get_load_options())
)
result = await session.execute(query)
found = result.scalar_one_or_none()
if not found: # pragma: no cover
raise RuntimeError(
f"Failed to retrieve entity after race condition update: {entity.file_path}"
)
return found
setattr(existing_entity, key, value)

# Clear and re-add observations
existing_entity.observations.clear()
for obs in entity.observations:
obs.entity_id = existing_entity.id
existing_entity.observations.append(obs)

await session.commit()
return existing_entity

else:
# Must be permalink conflict - generate unique permalink
return await self._handle_permalink_conflict(entity, session)
# No file_path conflict - must be permalink conflict
# Generate unique permalink and retry
entity = await self._handle_permalink_conflict(entity, session)
return entity

async def _handle_permalink_conflict(self, entity: Entity, session: AsyncSession) -> Entity:
"""Handle permalink conflicts by generating a unique permalink."""
Expand All @@ -237,18 +195,7 @@ async def _handle_permalink_conflict(self, entity: Entity, session: AsyncSession
break
suffix += 1

# Insert with unique permalink (no conflict possible now)
# Insert with unique permalink
session.add(entity)
await session.flush()

# Return the inserted entity with relationships loaded
query = (
self.select()
.where(Entity.file_path == entity.file_path)
.options(*self.get_load_options())
)
result = await session.execute(query)
found = result.scalar_one_or_none()
if not found: # pragma: no cover
raise RuntimeError(f"Failed to retrieve entity after insert: {entity.file_path}")
return found
return entity
45 changes: 16 additions & 29 deletions tests/repository/test_entity_repository_upsert.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,10 +149,12 @@ async def test_upsert_entity_multiple_permalink_conflicts(entity_repository: Ent

@pytest.mark.asyncio
async def test_upsert_entity_race_condition_file_path(entity_repository: EntityRepository):
"""Test that upsert handles race condition where file_path conflict occurs after initial check."""
from unittest.mock import patch
from sqlalchemy.exc import IntegrityError
"""Test that upsert handles file_path conflicts using ON CONFLICT DO UPDATE.

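Note that the repository does not issue a SQL ON CONFLICT clause: it attempts a
plain insert and, on IntegrityError, rolls back, re-queries the existing row by
file_path and project_id, and updates it. This test verifies that upserting an
entity whose file_path already exists updates that row in place.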
"""
# Create an entity first
entity1 = Entity(
project_id=entity_repository.project_id,
Expand All @@ -168,42 +170,27 @@ async def test_upsert_entity_race_condition_file_path(entity_repository: EntityR
result1 = await entity_repository.upsert_entity(entity1)
original_id = result1.id

# Create another entity with different file_path and permalink
# Create another entity with same file_path but different title and permalink
# This simulates a concurrent update scenario
entity2 = Entity(
project_id=entity_repository.project_id,
title="Race Condition Test",
entity_type="note",
permalink="test/race-entity",
file_path="test/different-file.md", # Different initially
file_path="test/race-file.md", # Same file path as entity1
content_type="text/markdown",
created_at=datetime.now(timezone.utc),
updated_at=datetime.now(timezone.utc),
)

# Now simulate race condition: change file_path to conflict after the initial check
original_add = entity_repository.session_maker().add
call_count = 0

def mock_add(obj):
nonlocal call_count
if isinstance(obj, Entity) and call_count == 0:
call_count += 1
# Simulate race condition by changing file_path to conflict
obj.file_path = "test/race-file.md" # Same as entity1
# This should trigger IntegrityError for file_path constraint
raise IntegrityError("UNIQUE constraint failed: entity.file_path", None, None)
return original_add(obj)

# Mock session.add to simulate the race condition
with patch.object(entity_repository.session_maker().__class__, "add", side_effect=mock_add):
# This should handle the race condition gracefully by updating the existing entity
result2 = await entity_repository.upsert_entity(entity2)

# Should return the updated original entity (same ID)
assert result2.id == original_id
assert result2.title == "Race Condition Test" # Updated title
assert result2.file_path == "test/race-file.md" # Same file path
assert result2.permalink == "test/race-entity" # Updated permalink
# The file_path conflict should cause the existing entity to be updated
result2 = await entity_repository.upsert_entity(entity2)

# Should return the updated original entity (same ID)
assert result2.id == original_id
assert result2.title == "Race Condition Test" # Updated title
assert result2.file_path == "test/race-file.md" # Same file path
assert result2.permalink == "test/race-entity" # Updated permalink


@pytest.mark.asyncio
Expand Down
Loading