-
Notifications
You must be signed in to change notification settings - Fork 201
Expand file tree
/
Copy pathtest_context_service.py
More file actions
320 lines (270 loc) · 12.9 KB
/
test_context_service.py
File metadata and controls
320 lines (270 loc) · 12.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
"""Tests for context service."""
from datetime import datetime, timedelta, UTC
import pytest
import pytest_asyncio
from basic_memory.repository.search_repository import SearchIndexRow
from basic_memory.schemas.memory import memory_url, memory_url_path
from basic_memory.schemas.search import SearchItemType
from basic_memory.services.context_service import ContextService
from basic_memory.models.knowledge import Entity, Relation
from basic_memory.models.project import Project
@pytest_asyncio.fixture
async def context_service(search_repository, entity_repository, observation_repository):
    """Provide a ContextService wired to the per-test repository fixtures.

    The search, entity and observation repositories are handed to the
    ContextService constructor positionally, in that order.
    """
    service = ContextService(
        search_repository,
        entity_repository,
        observation_repository,
    )
    return service
@pytest.mark.asyncio
async def test_find_connected_depth_limit(context_service, test_graph):
    """Test that traversal with a sufficient max_depth reaches the Deep entity.

    Intended traversal from Root:
    - Depth 0: Root
    - Depth 1: Relations + directly connected entities (Connected1, Connected2)
    - Depth 2: Relations + next level entities (Deep)

    Only the positive case is asserted: with max_depth=3 (and a generous
    max_results), the Deep entity must appear among the related entities.
    """
    type_id_pairs = [("entity", test_graph["root"].id)]

    # TODO: add a negative check that a shallow max_depth excludes Deep once the
    # depth counting used by find_related is pinned down (an earlier version of
    # that check was left commented out here).

    deep_results = await context_service.find_related(
        type_id_pairs, max_depth=3, max_results=100
    )
    # Keep only entity results as (id, type) pairs; relation rows are ignored.
    deep_entities = {(r.id, r.type) for r in deep_results if r.type == "entity"}

    # The set is attached as the failure message instead of a debug print().
    assert (test_graph["deep"].id, "entity") in deep_entities, deep_entities
@pytest.mark.asyncio
async def test_find_connected_timeframe(
    context_service, test_graph, search_repository, entity_repository
):
    """Test that a `since` cutoff prunes traversal through out-of-range items.

    An item is only included when it matches the timeframe AND there is a path
    to it through other items that are also inside the timeframe. Root and its
    outgoing relation are dated 10 days ago, while Connected1 is dated 1 day ago.
    With a 7-day cutoff, Connected1 is recent but reachable only through the
    old relation, so no entities should come back.
    """
    now = datetime.now(UTC)
    old_date = now - timedelta(days=10)
    recent_date = now - timedelta(days=1)
    old_iso = old_date.isoformat()
    recent_iso = recent_date.isoformat()

    root = test_graph["root"]
    connected1 = test_graph["connected1"]
    root_relation = test_graph["relations"][0]
    project_id = entity_repository.project_id

    # Backdate Root and make Connected1 recent in the entity table itself.
    await entity_repository.update(root.id, {"created_at": old_date, "updated_at": old_date})
    await entity_repository.update(
        connected1.id, {"created_at": recent_date, "updated_at": recent_date}
    )

    # Mirror the same dates into the search index so both stores agree.
    index_rows = [
        SearchIndexRow(
            project_id=project_id,
            id=root.id,
            title=root.title,
            content_snippet="Root content",
            permalink=root.permalink,
            file_path=root.file_path,
            type=SearchItemType.ENTITY,
            metadata={"created_at": old_iso},
            created_at=old_iso,
            updated_at=old_iso,
        ),
        SearchIndexRow(
            project_id=project_id,
            id=root_relation.id,
            title="Root Entity → Connected Entity 1",
            content_snippet="",
            permalink=f"{root.permalink}/connects_to/{connected1.permalink}",
            file_path=root.file_path,
            type=SearchItemType.RELATION,
            from_id=root.id,
            to_id=connected1.id,
            relation_type="connects_to",
            metadata={"created_at": old_iso},
            created_at=old_iso,
            updated_at=old_iso,
        ),
        SearchIndexRow(
            project_id=project_id,
            id=connected1.id,
            title=connected1.title,
            content_snippet="Connected 1 content",
            permalink=connected1.permalink,
            file_path=connected1.file_path,
            type=SearchItemType.ENTITY,
            metadata={"created_at": recent_iso},
            created_at=recent_iso,
            updated_at=recent_iso,
        ),
    ]
    for row in index_rows:
        await search_repository.index_item(row)

    # Traverse from Root with a 7-day cutoff.
    since_date = now - timedelta(days=7)
    results = await context_service.find_related([("entity", root.id)], since=since_date)

    # Connected1 is inside the window, but the only route to it is the old
    # Root -> Connected1 relation, which the cutoff filters out.
    reachable_entity_ids = {item.id for item in results if item.type == "entity"}
    assert not reachable_entity_ids  # No accessible entities within timeframe
@pytest.mark.asyncio
async def test_build_context(context_service, test_graph):
    """Test that build_context resolves an exact permalink and its neighbours."""
    root = test_graph["root"]
    connected1 = test_graph["connected1"]

    url = memory_url.validate_strings("memory://test/root")
    context_result = await context_service.build_context(url)

    # Metadata: a single primary hit plus at least one related item.
    metadata = context_result.metadata
    assert metadata.uri == memory_url_path(url)
    assert metadata.depth == 1
    assert metadata.primary_count == 1
    assert metadata.related_count > 0
    assert metadata.generated_at is not None

    # Exactly one context item, whose primary result is the Root entity.
    assert len(context_result.results) == 1
    (context_item,) = context_result.results

    primary = context_item.primary_result
    assert primary.id == root.id
    assert primary.type == "entity"
    assert primary.title == "Root"
    assert primary.permalink == "test/root"
    assert primary.file_path == "test/Root.md"
    assert primary.created_at is not None

    related = context_item.related_results
    assert len(related) > 0

    # The first related relation should be Root -connects_to-> Connected1.
    relations = [item for item in related if item.type == "relation"]
    assert relations
    relation = relations[0]
    assert relation.relation_type == "connects_to"
    assert relation.from_id == root.id
    assert relation.to_id == connected1.id

    # The first related entity should be Connected1.
    entities = [item for item in related if item.type == "entity"]
    assert entities
    related_entity = entities[0]
    assert related_entity.id == connected1.id
    assert related_entity.title == connected1.title
    assert related_entity.permalink == connected1.permalink
@pytest.mark.asyncio
async def test_build_context_with_observations(context_service, test_graph):
    """Test that include_observations=True attaches Root's observations."""
    # The observations come from the test_graph fixture; none are created here.
    root_id = test_graph["root"].id

    url = memory_url.validate_strings("memory://test/root")
    context_result = await context_service.build_context(url, include_observations=True)

    assert context_result.metadata.total_observations > 0
    assert len(context_result.results) == 1

    observations = context_result.results[0].observations
    assert len(observations) > 0

    # Every observation belongs to Root and uses a category from the fixture.
    for obs in observations:
        assert obs.type == "observation"
        assert obs.category in ["note", "tech"]
        assert obs.entity_id == root_id

    # The first "note" observation carries the expected text.
    notes = [obs for obs in observations if obs.category == "note"]
    assert notes
    assert "Root note" in notes[0].content
@pytest.mark.asyncio
async def test_build_context_not_found(context_service):
    """Test that an unknown permalink yields an empty context."""
    missing = await context_service.build_context("memory://does/not/exist")

    assert len(missing.results) == 0

    meta = missing.metadata
    assert meta.primary_count == 0
    assert meta.related_count == 0
@pytest.mark.asyncio
async def test_context_metadata(context_service, test_graph):
    """Test that the metadata echoes the requested URI and depth."""
    result = await context_service.build_context("memory://test/root", depth=2)
    meta = result.metadata

    # The reported URI omits the memory:// prefix of the request.
    assert meta.uri == "test/root"
    assert meta.depth == 2
    assert meta.generated_at is not None
    assert meta.primary_count > 0
@pytest.mark.asyncio
async def test_project_isolation_in_find_related(session_maker):
    """Test that find_related stays inside the project its repositories target.

    Two projects are written directly to the database:
    - project1 holds two entities joined by a "connects_to" relation.
    - project2 holds one unconnected entity.

    A ContextService built on project1's repositories must find the project1
    neighbour and never the project2 entity. One built on project2's
    repositories must find nothing.
    """
    from basic_memory.repository.entity_repository import EntityRepository
    from basic_memory.repository.observation_repository import ObservationRepository
    from basic_memory.repository.search_repository import SearchRepository
    from basic_memory import db

    def make_document(title, project, permalink, file_path):
        # Markdown "document" entity owned by `project`, stamped with UTC now.
        return Entity(
            title=title,
            entity_type="document",
            content_type="text/markdown",
            project_id=project.id,
            permalink=permalink,
            file_path=file_path,
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
        )

    def make_context_service(project_id):
        # Scope the search, entity and observation repositories to one project.
        return ContextService(
            SearchRepository(session_maker, project_id),
            EntityRepository(session_maker, project_id),
            ObservationRepository(session_maker, project_id),
        )

    async with db.scoped_session(session_maker) as db_session:
        # Two independent projects; flush so their ids are assigned.
        project1 = Project(name="project1", path="/test1")
        project2 = Project(name="project2", path="/test2")
        db_session.add_all([project1, project2])
        await db_session.flush()

        # project1: two entities to connect; project2: one isolated entity.
        entity1_p1 = make_document(
            "Entity1_P1", project1, "project1/entity1", "project1/entity1.md"
        )
        entity2_p1 = make_document(
            "Entity2_P1", project1, "project1/entity2", "project1/entity2.md"
        )
        entity1_p2 = make_document(
            "Entity1_P2", project2, "project2/entity1", "project2/entity1.md"
        )
        db_session.add_all([entity1_p1, entity2_p1, entity1_p2])
        await db_session.flush()

        # The only relation lives entirely inside project1.
        db_session.add(
            Relation(
                from_id=entity1_p1.id,
                to_id=entity2_p1.id,
                to_name="Entity2_P1",
                relation_type="connects_to",
            )
        )
        await db_session.commit()

        service_p1 = make_context_service(project1.id)
        service_p2 = make_context_service(project2.id)

        # project1 reaches its own neighbour but never project2's entity.
        related_p1 = await service_p1.find_related([("entity", entity1_p1.id)], max_depth=2)
        related_entity_ids = [item.id for item in related_p1 if item.type == "entity"]
        assert entity2_p1.id in related_entity_ids
        assert entity1_p2.id not in related_entity_ids

        # project2 has no relations, so nothing is related to its entity.
        related_p2 = await service_p2.find_related([("entity", entity1_p2.id)], max_depth=2)
        assert len(related_p2) == 0

        # Sanity check: each entity kept the project it was created in.
        assert entity1_p1.project_id == project1.id
        assert entity2_p1.project_id == project1.id
        assert entity1_p2.project_id == project2.id