Skip to content

Commit 4e10e21

Browse files
committed
perf(core): streamline entity write upserts
Signed-off-by: phernandez <paul@basicmachines.co>
1 parent 69808b2 commit 4e10e21

File tree

5 files changed: +258 additions, -108 deletions

src/basic_memory/repository/entity_repository.py

Lines changed: 33 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,27 @@ async def get_by_external_id(self, external_id: str) -> Optional[Entity]:
5959
)
6060
return await self.find_one(query)
6161

62-
async def get_by_permalink(self, permalink: str) -> Optional[Entity]:
62+
async def _find_one_by_query(self, query, *, load_relations: bool) -> Optional[Entity]:
63+
"""Return one entity row with optional eager loading."""
64+
if load_relations:
65+
query = query.options(*self.get_load_options())
66+
return await self.find_one(query)
67+
68+
result = await self.execute_query(query, use_query_options=False)
69+
return result.scalars().one_or_none()
70+
71+
async def get_by_permalink(
72+
self, permalink: str, *, load_relations: bool = True
73+
) -> Optional[Entity]:
6374
"""Get entity by permalink.
6475
6576
Args:
6677
permalink: Unique identifier for the entity
6778
"""
68-
query = self.select().where(Entity.permalink == permalink).options(*self.get_load_options())
69-
return await self.find_one(query)
79+
query = self.select().where(Entity.permalink == permalink)
80+
return await self._find_one_by_query(query, load_relations=load_relations)
7081

71-
async def get_by_title(self, title: str) -> Sequence[Entity]:
82+
async def get_by_title(self, title: str, *, load_relations: bool = True) -> Sequence[Entity]:
7283
"""Get entities by title, ordered by shortest path first.
7384
7485
When multiple entities share the same title (in different folders),
@@ -82,23 +93,20 @@ async def get_by_title(self, title: str) -> Sequence[Entity]:
8293
self.select()
8394
.where(Entity.title == title)
8495
.order_by(func.length(Entity.file_path), Entity.file_path)
85-
.options(*self.get_load_options())
8696
)
87-
result = await self.execute_query(query)
97+
result = await self.execute_query(query, use_query_options=load_relations)
8898
return list(result.scalars().all())
8999

90-
async def get_by_file_path(self, file_path: Union[Path, str]) -> Optional[Entity]:
100+
async def get_by_file_path(
101+
self, file_path: Union[Path, str], *, load_relations: bool = True
102+
) -> Optional[Entity]:
91103
"""Get entity by file_path.
92104
93105
Args:
94106
file_path: Path to the entity file (will be converted to string internally)
95107
"""
96-
query = (
97-
self.select()
98-
.where(Entity.file_path == Path(file_path).as_posix())
99-
.options(*self.get_load_options())
100-
)
101-
return await self.find_one(query)
108+
query = self.select().where(Entity.file_path == Path(file_path).as_posix())
109+
return await self._find_one_by_query(query, load_relations=load_relations)
102110

103111
# -------------------------------------------------------------------------
104112
# Lightweight methods for permalink resolution (no eager loading)
@@ -306,7 +314,7 @@ async def find_by_permalinks(self, permalinks: List[str]) -> Sequence[Entity]:
306314
result = await self.execute_query(query)
307315
return list(result.scalars().all())
308316

309-
async def upsert_entity(self, entity: Entity) -> Entity:
317+
async def upsert_entity(self, entity: Entity, *, reload: bool = True) -> Entity:
310318
"""Insert or update entity using simple try/catch with database-level conflict resolution.
311319
312320
Handles file_path race conditions by checking for existing entity on IntegrityError.
@@ -327,6 +335,9 @@ async def upsert_entity(self, entity: Entity) -> Entity:
327335
session.add(entity)
328336
await session.flush()
329337

338+
if not reload:
339+
return entity
340+
330341
# Return with relationships loaded
331342
query = (
332343
self.select()
@@ -363,13 +374,12 @@ async def upsert_entity(self, entity: Entity) -> Entity:
363374
await session.rollback()
364375

365376
# Re-query after rollback to get a fresh, attached entity
366-
existing_result = await session.execute(
367-
select(Entity)
368-
.where(
369-
Entity.file_path == entity.file_path, Entity.project_id == entity.project_id
370-
)
371-
.options(*self.get_load_options())
377+
existing_query = select(Entity).where(
378+
Entity.file_path == entity.file_path, Entity.project_id == entity.project_id
372379
)
380+
if reload:
381+
existing_query = existing_query.options(*self.get_load_options())
382+
existing_result = await session.execute(existing_query)
373383
existing_entity = existing_result.scalar_one_or_none()
374384

375385
if existing_entity:
@@ -393,6 +403,9 @@ async def upsert_entity(self, entity: Entity) -> Entity:
393403

394404
await session.commit()
395405

406+
if not reload:
407+
return merged_entity
408+
396409
# Re-query to get proper relationships loaded
397410
final_result = await session.execute(
398411
select(Entity)

src/basic_memory/repository/repository.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -268,8 +268,21 @@ async def create_all(self, data_list: List[dict]) -> Sequence[T]:
268268

269269
return await self.select_by_ids(session, [model.id for model in model_list]) # pyright: ignore [reportAttributeAccessIssue]
270270

271-
async def update(self, entity_id: int, entity_data: dict | T) -> Optional[T]:
272-
"""Update an entity with the given data."""
271+
async def update(
272+
self,
273+
entity_id: int,
274+
entity_data: dict | T,
275+
*,
276+
reload: bool = True,
277+
) -> Optional[T]:
278+
"""Update an entity with the given data.
279+
280+
Args:
281+
entity_id: Primary key to update
282+
entity_data: Column values or a model instance to copy from
283+
reload: When True, re-select the entity with repository load options.
284+
When False, return the attached row after flush/refresh.
285+
"""
273286
logger.debug(f"Updating {self.Model.__name__} {entity_id} with data: {entity_data}")
274287
async with db.scoped_session(self.session_maker) as session:
275288
try:
@@ -291,6 +304,8 @@ async def update(self, entity_id: int, entity_data: dict | T) -> Optional[T]:
291304
await session.refresh(entity) # Refresh
292305

293306
logger.debug(f"Updated {self.Model.__name__}: {entity_id}")
307+
if not reload:
308+
return entity
294309
return await self.select_by_id(session, entity.id) # pyright: ignore [reportAttributeAccessIssue]
295310

296311
except NoResultFound:

0 commit comments

Comments
 (0)