-
Notifications
You must be signed in to change notification settings - Fork 196
Expand file tree
/
Copy pathentity_repository.py
More file actions
449 lines (363 loc) · 18.1 KB
/
entity_repository.py
File metadata and controls
449 lines (363 loc) · 18.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
"""Repository for managing entities in the knowledge graph."""
from pathlib import Path
from typing import List, Optional, Sequence, Union
from loguru import logger
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import selectinload
from sqlalchemy.orm.interfaces import LoaderOption
from basic_memory import db
from basic_memory.models.knowledge import Entity, Observation, Relation
from basic_memory.repository.repository import Repository
class EntityRepository(Repository[Entity]):
    """Repository for Entity model.

    Note: All file paths are stored as strings in the database. Convert Path objects
    to strings before passing to repository methods.
    """

    def __init__(self, session_maker: async_sessionmaker[AsyncSession], project_id: int):
        """Initialize with session maker and project_id filter.

        Args:
            session_maker: SQLAlchemy session maker
            project_id: Project ID to filter all operations by
        """
        # The base Repository scopes every query/mutation to this project_id.
        super().__init__(session_maker, Entity, project_id=project_id)
async def get_by_permalink(self, permalink: str) -> Optional[Entity]:
"""Get entity by permalink.
Args:
permalink: Unique identifier for the entity
"""
query = self.select().where(Entity.permalink == permalink).options(*self.get_load_options())
return await self.find_one(query)
async def get_by_title(self, title: str) -> Sequence[Entity]:
"""Get entity by title.
Args:
title: Title of the entity to find
"""
query = self.select().where(Entity.title == title).options(*self.get_load_options())
result = await self.execute_query(query)
return list(result.scalars().all())
async def get_by_file_path(self, file_path: Union[Path, str]) -> Optional[Entity]:
"""Get entity by file_path.
Args:
file_path: Path to the entity file (will be converted to string internally)
"""
query = (
self.select()
.where(Entity.file_path == Path(file_path).as_posix())
.options(*self.get_load_options())
)
return await self.find_one(query)
async def get_by_file_paths_batch(
self, file_paths: Sequence[Union[Path, str]]
) -> dict[str, Entity]:
"""Batch fetch entities by file paths with eager-loaded relationships.
Optimized for scan operations - reduces N queries to 1 batched query.
Returns entities with relationships already loaded via selectinload.
Args:
file_paths: List of file paths to fetch entities for
Returns:
Dict mapping file_path (as posix string) -> Entity
Only includes entities that exist; missing files are not in dict
"""
if not file_paths:
return {}
# Convert all paths to posix strings
posix_paths = [Path(p).as_posix() for p in file_paths]
# Batch query with eager loading
query = (
self.select().where(Entity.file_path.in_(posix_paths)).options(*self.get_load_options())
)
result = await self.execute_query(query)
entities = list(result.scalars().all())
# Return as dict for O(1) lookup
return {e.file_path: e for e in entities}
async def find_by_checksum(self, checksum: str) -> Sequence[Entity]:
"""Find entities with the given checksum.
Used for move detection - finds entities that may have been moved to a new path.
Multiple entities may have the same checksum if files were copied.
Args:
checksum: File content checksum to search for
Returns:
Sequence of entities with matching checksum (may be empty)
"""
query = self.select().where(Entity.checksum == checksum)
# Don't load relationships for move detection - we only need file_path and checksum
result = await self.execute_query(query, use_query_options=False)
return list(result.scalars().all())
async def delete_by_file_path(self, file_path: Union[Path, str]) -> bool:
"""Delete entity with the provided file_path.
Args:
file_path: Path to the entity file (will be converted to string internally)
"""
return await self.delete_by_fields(file_path=Path(file_path).as_posix())
def get_load_options(self) -> List[LoaderOption]:
"""Get SQLAlchemy loader options for eager loading relationships."""
return [
selectinload(Entity.observations).selectinload(Observation.entity),
# Load from_relations and both entities for each relation
selectinload(Entity.outgoing_relations).selectinload(Relation.from_entity),
selectinload(Entity.outgoing_relations).selectinload(Relation.to_entity),
# Load to_relations and both entities for each relation
selectinload(Entity.incoming_relations).selectinload(Relation.from_entity),
selectinload(Entity.incoming_relations).selectinload(Relation.to_entity),
]
async def find_by_permalinks(self, permalinks: List[str]) -> Sequence[Entity]:
"""Find multiple entities by their permalink.
Args:
permalinks: List of permalink strings to find
"""
# Handle empty input explicitly
if not permalinks:
return []
# Use existing select pattern
query = (
self.select().options(*self.get_load_options()).where(Entity.permalink.in_(permalinks))
)
result = await self.execute_query(query)
return list(result.scalars().all())
    async def upsert_entity(self, entity: Entity) -> Entity:
        """Insert or update entity using simple try/catch with database-level conflict resolution.

        Handles file_path race conditions by checking for existing entity on IntegrityError.
        For permalink conflicts, generates a unique permalink with numeric suffix.

        Args:
            entity: The entity to insert or update

        Returns:
            The inserted or updated entity, with relationships eagerly loaded.

        Raises:
            SyncFatalError: If the entity's project_id does not exist (FK failure).
            RuntimeError: If the entity cannot be re-read immediately after insert.
        """
        async with db.scoped_session(self.session_maker) as session:
            # Set project_id if applicable and not already set
            self._set_project_id_if_needed(entity)

            # Try simple insert first; conflicts surface as IntegrityError on flush.
            try:
                session.add(entity)
                await session.flush()

                # Return with relationships loaded (re-query rather than refresh
                # so the eager-load options apply).
                query = (
                    self.select()
                    .where(Entity.file_path == entity.file_path)
                    .options(*self.get_load_options())
                )
                result = await session.execute(query)
                found = result.scalar_one_or_none()
                if not found:  # pragma: no cover
                    raise RuntimeError(
                        f"Failed to retrieve entity after insert: {entity.file_path}"
                    )
                return found
            except IntegrityError as e:
                # Check if this is a FOREIGN KEY constraint failure.
                # SQLite: "FOREIGN KEY constraint failed"
                # Postgres: "violates foreign key constraint"
                error_str = str(e)
                if (
                    "FOREIGN KEY constraint failed" in error_str
                    or "violates foreign key constraint" in error_str
                ):
                    # Import locally to avoid circular dependency (repository -> services -> repository)
                    from basic_memory.services.exceptions import SyncFatalError

                    # Project doesn't exist in database - this is a fatal sync error
                    raise SyncFatalError(
                        f"Cannot sync file '{entity.file_path}': "
                        f"project_id={entity.project_id} does not exist in database. "
                        f"The project may have been deleted. This sync will be terminated."
                    ) from e

                # Unique-constraint conflict: discard the failed insert before re-querying.
                await session.rollback()

                # Re-query after rollback to get a fresh, attached entity
                existing_result = await session.execute(
                    select(Entity)
                    .where(
                        Entity.file_path == entity.file_path, Entity.project_id == entity.project_id
                    )
                    .options(*self.get_load_options())
                )
                existing_entity = existing_result.scalar_one_or_none()

                if existing_entity:
                    # File path conflict - update the existing entity
                    logger.debug(
                        f"Resolving file_path conflict for {entity.file_path}, "
                        f"entity_id={existing_entity.id}, observations={len(entity.observations)}"
                    )
                    # Use merge to avoid session state conflicts.
                    # Set the ID to update existing entity.
                    entity.id = existing_entity.id
                    # Ensure observations reference the correct entity_id
                    for obs in entity.observations:
                        obs.entity_id = existing_entity.id
                        # Clear any existing ID to force INSERT as new observation
                        obs.id = None
                    # Merge the entity which will update the existing one
                    merged_entity = await session.merge(entity)
                    await session.commit()

                    # Re-query to get proper relationships loaded
                    final_result = await session.execute(
                        select(Entity)
                        .where(Entity.id == merged_entity.id)
                        .options(*self.get_load_options())
                    )
                    return final_result.scalar_one()
                else:
                    # No file_path conflict - must be permalink conflict.
                    # Generate unique permalink and retry.
                    entity = await self._handle_permalink_conflict(entity, session)
                    return entity
async def get_all_file_paths(self) -> List[str]:
"""Get all file paths for this project - optimized for deletion detection.
Returns only file_path strings without loading entities or relationships.
Used by streaming sync to detect deleted files efficiently.
Returns:
List of file_path strings for all entities in the project
"""
query = select(Entity.file_path)
query = self._add_project_filter(query)
result = await self.execute_query(query, use_query_options=False)
return list(result.scalars().all())
async def get_distinct_directories(self) -> List[str]:
"""Extract unique directory paths from file_path column.
Optimized method for getting directory structure without loading full entities
or relationships. Returns a sorted list of unique directory paths.
Returns:
List of unique directory paths (e.g., ["notes", "notes/meetings", "specs"])
"""
# Query only file_path column, no entity objects or relationships
query = select(Entity.file_path).distinct()
query = self._add_project_filter(query)
# Execute with use_query_options=False to skip eager loading
result = await self.execute_query(query, use_query_options=False)
file_paths = [row for row in result.scalars().all()]
# Parse file paths to extract unique directories
directories = set()
for file_path in file_paths:
parts = [p for p in file_path.split("/") if p]
# Add all parent directories (exclude filename which is the last part)
for i in range(len(parts) - 1):
dir_path = "/".join(parts[: i + 1])
directories.add(dir_path)
return sorted(directories)
async def find_by_directory_prefix(self, directory_prefix: str) -> Sequence[Entity]:
"""Find entities whose file_path starts with the given directory prefix.
Optimized method for listing directory contents without loading all entities.
Uses SQL LIKE pattern matching to filter entities by directory path.
Args:
directory_prefix: Directory path prefix (e.g., "docs", "docs/guides")
Empty string returns all entities (root directory)
Returns:
Sequence of entities in the specified directory and subdirectories
"""
# Build SQL LIKE pattern
if directory_prefix == "" or directory_prefix == "/":
# Root directory - return all entities
return await self.find_all()
# Remove leading/trailing slashes for consistency
directory_prefix = directory_prefix.strip("/")
# Query entities with file_path starting with prefix
# Pattern matches "prefix/" to ensure we get files IN the directory,
# not just files whose names start with the prefix
pattern = f"{directory_prefix}/%"
query = self.select().where(Entity.file_path.like(pattern))
# Skip eager loading - we only need basic entity fields for directory trees
result = await self.execute_query(query, use_query_options=False)
return list(result.scalars().all())
    async def _handle_permalink_conflict(self, entity: Entity, session: AsyncSession) -> Entity:
        """Handle permalink conflicts by generating a unique permalink.

        Appends "-1", "-2", ... to the base permalink until no entity in the
        same project uses it, then inserts the entity within the caller's session.

        Args:
            entity: The entity whose permalink collided.
            session: The already-open session to query and insert with.

        Returns:
            The entity, inserted with its new unique permalink.

        Raises:
            SyncFatalError: If the flush fails on a foreign-key constraint.
        """
        base_permalink = entity.permalink
        suffix = 1

        # Find a unique permalink by probing suffixed candidates.
        while True:
            test_permalink = f"{base_permalink}-{suffix}"
            existing = await session.execute(
                select(Entity).where(
                    Entity.permalink == test_permalink, Entity.project_id == entity.project_id
                )
            )
            if existing.scalar_one_or_none() is None:
                # Found unique permalink
                entity.permalink = test_permalink
                break
            suffix += 1

        # Insert with unique permalink
        session.add(entity)
        try:
            await session.flush()
        except IntegrityError as e:
            # Check if this is a FOREIGN KEY constraint failure.
            # SQLite: "FOREIGN KEY constraint failed"
            # Postgres: "violates foreign key constraint"
            error_str = str(e)
            if (
                "FOREIGN KEY constraint failed" in error_str
                or "violates foreign key constraint" in error_str
            ):
                # Import locally to avoid circular dependency (repository -> services -> repository)
                from basic_memory.services.exceptions import SyncFatalError

                # Project doesn't exist in database - this is a fatal sync error
                raise SyncFatalError(
                    f"Cannot sync file '{entity.file_path}': "
                    f"project_id={entity.project_id} does not exist in database. "
                    f"The project may have been deleted. This sync will be terminated."
                ) from e
            # Re-raise if not a foreign key error
            raise
        return entity
    async def upsert_entities(self, entities: List[Entity]) -> List[Entity]:
        """Bulk insert or update multiple entities in a single transaction.

        Optimized for batch operations with remote databases (Postgres).
        Handles conflicts the same way as upsert_entity() but processes
        all entities in one transaction.

        Args:
            entities: List of entities to upsert

        Returns:
            List of upserted entities with relationships loaded

        Raises:
            SyncFatalError: If any entity references a non-existent project_id
        """
        if not entities:
            return []

        async with db.scoped_session(self.session_maker) as session:
            # Set project_id on all entities if needed
            for entity in entities:
                self._set_project_id_if_needed(entity)

            # Try to add all entities in one shot (fast path, single flush).
            for entity in entities:
                session.add(entity)

            try:
                await session.flush()

                # Fetch all entities with relationships loaded
                file_paths = [e.file_path for e in entities]
                query = (
                    self.select()
                    .where(Entity.file_path.in_(file_paths))
                    .options(*self.get_load_options())
                )
                result = await session.execute(query)
                return list(result.scalars().all())
            except IntegrityError as e:
                # Check for foreign key constraint failures
                error_str = str(e)
                if (
                    "FOREIGN KEY constraint failed" in error_str
                    or "violates foreign key constraint" in error_str
                ):
                    from basic_memory.services.exceptions import SyncFatalError

                    raise SyncFatalError(
                        "Cannot sync entities: project_id does not exist in database. "
                        "The project may have been deleted. This sync will be terminated."
                    ) from e

                # For other integrity errors (file_path or permalink conflicts),
                # rollback and fall back to individual processing
                await session.rollback()

                # Process each entity individually to handle conflicts properly
                # (upsert_entity opens its own session per call).
                logger.debug(
                    f"Batch upsert failed with IntegrityError, falling back to individual upserts for {len(entities)} entities"
                )
                result_entities = []
                for entity in entities:
                    try:
                        upserted = await self.upsert_entity(entity)
                        result_entities.append(upserted)
                    except Exception as individual_error:
                        logger.error(
                            f"Failed to upsert entity {entity.file_path}: {individual_error}"
                        )
                        # Continue with other entities
                return result_entities