-
Notifications
You must be signed in to change notification settings - Fork 206
Expand file tree
/
Copy pathobservation_repository.py
More file actions
102 lines (77 loc) · 3.89 KB
/
observation_repository.py
File metadata and controls
102 lines (77 loc) · 3.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Repository for managing Observation objects."""
from typing import Dict, List, Sequence
from sqlalchemy import select
from sqlalchemy.ext.asyncio import async_sessionmaker
from basic_memory.models import Observation
from basic_memory.repository.repository import Repository
class ObservationRepository(Repository[Observation]):
    """Repository for Observation model with memory-specific operations.

    Read queries are routed through ``self.execute_query`` inherited from the
    base ``Repository`` class.
    """

    def __init__(self, session_maker: async_sessionmaker, project_id: int):
        """Initialize with session maker and project_id filter.

        Args:
            session_maker: SQLAlchemy session maker
            project_id: Project ID to filter all operations by
        """
        super().__init__(session_maker, Observation, project_id=project_id)

    async def find_by_entity(self, entity_id: int) -> Sequence[Observation]:
        """Find all observations for a specific entity.

        Args:
            entity_id: ID of the entity whose observations are wanted.

        Returns:
            All observations whose ``entity_id`` equals ``entity_id``.
        """
        query = select(Observation).filter(Observation.entity_id == entity_id)
        result = await self.execute_query(query)
        return result.scalars().all()

    async def find_by_context(self, context: str) -> Sequence[Observation]:
        """Find observations with a specific context.

        Args:
            context: Exact context value to match.

        Returns:
            All observations whose ``context`` equals ``context``.
        """
        query = select(Observation).filter(Observation.context == context)
        result = await self.execute_query(query)
        return result.scalars().all()

    async def find_by_category(self, category: str) -> Sequence[Observation]:
        """Find observations with a specific category.

        Args:
            category: Exact category value to match.

        Returns:
            All observations whose ``category`` equals ``category``.
        """
        query = select(Observation).filter(Observation.category == category)
        result = await self.execute_query(query)
        return result.scalars().all()

    async def observation_categories(self) -> Sequence[str]:
        """Return the distinct observation categories.

        Returns:
            Each distinct ``Observation.category`` value, once.
        """
        # Only the category column is selected, not full Observation rows, so
        # query options are turned off. Presumably those are entity loader
        # options; confirm in Repository.execute_query.
        query = select(Observation.category).distinct()
        result = await self.execute_query(query, use_query_options=False)
        return result.scalars().all()

    async def find_by_entities(self, entity_ids: List[int]) -> Dict[int, List[Observation]]:
        """Find all observations for multiple entities in a single query.

        Args:
            entity_ids: List of entity IDs to fetch observations for

        Returns:
            Dictionary mapping entity_id to list of observations. An entity with
            no observations does not appear as a key, so look entities up with
            ``.get(entity_id, [])``. An empty ``entity_ids`` returns ``{}``
            without querying.
        """
        if not entity_ids:  # pragma: no cover
            return {}

        # Query observations for all entities in the list
        query = select(Observation).filter(Observation.entity_id.in_(entity_ids))
        result = await self.execute_query(query)

        # Group observations by entity_id. Within each group, rows keep the
        # order the query returned them in.
        observations_by_entity: Dict[int, List[Observation]] = {}
        for obs in result.scalars().all():
            observations_by_entity.setdefault(obs.entity_id, []).append(obs)
        return observations_by_entity

    async def delete_by_entity_ids(self, entity_ids: List[int]) -> int:
        """Delete all observations for multiple entities within one session.

        The matching observations are loaded with a single ``IN`` query. Each
        one is then deleted through the ORM session (``session.delete``); this
        is not a single bulk ``DELETE`` statement.

        Args:
            entity_ids: List of entity IDs whose observations should be deleted

        Returns:
            Number of observations deleted. An empty ``entity_ids`` returns 0
            without opening a session.
        """
        if not entity_ids:
            return 0

        # Imported locally, presumably to avoid an import cycle with the db
        # module. Confirm before moving it to module level.
        from basic_memory import db

        async with db.scoped_session(self.session_maker) as session:
            # Load the affected rows with one IN query.
            query = select(Observation).where(Observation.entity_id.in_(entity_ids))
            result = await session.execute(query)
            observations_to_delete = result.scalars().all()

            # Delete through the ORM, one object at a time. This is not a bulk
            # DELETE, presumably so ORM-level delete handling (cascades/events)
            # still runs. Verify before switching to sqlalchemy.delete().
            for obs in observations_to_delete:
                await session.delete(obs)

            # Send the pending DELETEs now. Whether they are committed is up to
            # db.scoped_session.
            await session.flush()

            return len(observations_to_delete)