|
17 | 17 | from src.config import settings |
18 | 18 | from src.crud import create_messages |
19 | 19 | from src.crud import message as message_crud |
20 | | -from src.models import Peer, Workspace |
| 20 | +from src.models import Message, Peer, Workspace |
21 | 21 | from src.schemas import MessageCreate |
22 | 22 | from src.utils.search import search |
23 | 23 |
|
24 | 24 |
|
| 25 | +class _FakeScalarResult: |
| 26 | + def __init__(self, rows: list[models.Message]): |
| 27 | + self._rows: list[Message] = rows |
| 28 | + |
| 29 | + def all(self) -> list[models.Message]: |
| 30 | + return self._rows |
| 31 | + |
| 32 | + |
| 33 | +class _FakeResult: |
| 34 | + def __init__(self, rows: list[models.Message]): |
| 35 | + self._rows: list[Message] = rows |
| 36 | + |
| 37 | + def scalars(self) -> _FakeScalarResult: |
| 38 | + return _FakeScalarResult(self._rows) |
| 39 | + |
| 40 | + |
| 41 | +class _CountingDb: |
| 42 | + def __init__(self, rows: list[models.Message]): |
| 43 | + self._rows: list[Message] = rows |
| 44 | + self.execute_count: int = 0 |
| 45 | + |
| 46 | + async def execute(self, _stmt: Any) -> _FakeResult: |
| 47 | + self.execute_count += 1 |
| 48 | + return _FakeResult(self._rows) |
| 49 | + |
| 50 | + |
| 51 | +def _message(session_name: str, seq_in_session: int) -> models.Message: |
| 52 | + return models.Message( |
| 53 | + workspace_name="workspace", |
| 54 | + session_name=session_name, |
| 55 | + peer_name="peer", |
| 56 | + content=f"{session_name}:{seq_in_session}", |
| 57 | + public_id=generate_nanoid(), |
| 58 | + seq_in_session=seq_in_session, |
| 59 | + token_count=1, |
| 60 | + created_at=datetime.now(timezone.utc), |
| 61 | + ) |
| 62 | + |
| 63 | + |
25 | 64 | @pytest.mark.asyncio |
26 | 65 | async def test_message_embedding_created_when_setting_enabled( |
27 | 66 | db_session: AsyncSession, |
@@ -260,6 +299,46 @@ async def test_semantic_search_when_embeddings_enabled( |
260 | 299 | assert created_message.public_id in found_message_ids |
261 | 300 |
|
262 | 301 |
|
| 302 | +@pytest.mark.asyncio |
| 303 | +async def test_build_merged_snippets_batches_context_query_across_sessions(): |
| 304 | + """Context expansion should not issue one DB query per matched session.""" |
| 305 | + matched_messages = [ |
| 306 | + _message("session_a", 10), |
| 307 | + _message("session_b", 20), |
| 308 | + _message("session_c", 30), |
| 309 | + ] |
| 310 | + context_messages = [ |
| 311 | + _message("session_a", 9), |
| 312 | + _message("session_a", 10), |
| 313 | + _message("session_a", 11), |
| 314 | + _message("session_a", 99), |
| 315 | + _message("session_b", 19), |
| 316 | + _message("session_b", 20), |
| 317 | + _message("session_b", 21), |
| 318 | + _message("session_c", 29), |
| 319 | + _message("session_c", 30), |
| 320 | + _message("session_c", 31), |
| 321 | + ] |
| 322 | + db = _CountingDb(context_messages) |
| 323 | + |
| 324 | + snippets = await message_crud._build_merged_snippets( # pyright: ignore[reportPrivateUsage] |
| 325 | + db, # pyright: ignore[reportArgumentType] |
| 326 | + workspace_name="workspace", |
| 327 | + matched_messages=matched_messages, |
| 328 | + context_window=1, |
| 329 | + ) |
| 330 | + |
| 331 | + assert db.execute_count == 1 |
| 332 | + assert [len(matches) for matches, _ in snippets] == [1, 1, 1] |
| 333 | + assert [ |
| 334 | + [msg.content for msg in context_messages] for _, context_messages in snippets |
| 335 | + ] == [ |
| 336 | + ["session_a:9", "session_a:10", "session_a:11"], |
| 337 | + ["session_b:19", "session_b:20", "session_b:21"], |
| 338 | + ["session_c:29", "session_c:30", "session_c:31"], |
| 339 | + ] |
| 340 | + |
| 341 | + |
263 | 342 | @pytest.mark.asyncio |
264 | 343 | async def test_search_messages_external_lookup_happens_before_tracked_db( |
265 | 344 | monkeypatch: pytest.MonkeyPatch, |
|
0 commit comments