Skip to content

Commit 0471c3d

Browse files
committed
Merge branch 'fix/issue-132-n-plus-one-event-type-filter'
# Conflicts: # storage/search.py
2 parents e097858 + 51549ab commit 0471c3d

2 files changed

Lines changed: 58 additions & 11 deletions

File tree

storage/search.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -313,24 +313,22 @@ async def _load_candidate_sessions(
313313
if tag_conditions:
314314
stmt = stmt.where(or_(*tag_conditions))
315315

316-
# For event_type filtering, use EXISTS to avoid N+1 queries
316+
# Filter by event_type in SQL using a subquery to avoid N+1 queries.
317+
# Previously this was done with a Python loop issuing one query per session.
317318
if event_type:
318-
event_type_str = event_type.value if hasattr(event_type, 'value') else event_type
319-
320-
# Use EXISTS subquery to filter sessions that have at least one event of the specified type
321-
event_exists = exists(
322-
select(EventModel.id).where(
323-
EventModel.session_id == SessionModel.id,
319+
event_type_str = event_type.value if hasattr(event_type, "value") else event_type
320+
event_subq = (
321+
select(EventModel.session_id)
322+
.where(
324323
EventModel.tenant_id == self.tenant_id,
325324
EventModel.event_type == event_type_str,
326325
)
326+
.scalar_subquery()
327327
)
328-
stmt = stmt.where(event_exists)
328+
stmt = stmt.where(SessionModel.id.in_(event_subq))
329329

330330
result = await self.session.execute(stmt)
331-
sessions = list(result.scalars().all())
332-
333-
return sessions
331+
return list(result.scalars().all())
334332

335333
async def _score_sessions(
336334
self,

tests/storage/test_search.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,55 @@ async def test_load_candidate_sessions_without_status(self):
423423
assert isinstance(result, list)
424424
assert mock_session.execute.called
425425

426+
@pytest.mark.asyncio
427+
async def test_load_candidate_sessions_event_type_filter_in_sql(self):
428+
"""Test _load_candidate_sessions pushes event_type filter into SQL subquery.
429+
430+
The event_type predicate must appear in the WHERE clause of the single
431+
SQL statement, not as a Python-level N+1 loop over sessions.
432+
"""
433+
mock_session = _create_mock_async_session()
434+
service = SessionSearchService(mock_session, tenant_id="tenant-1")
435+
436+
mock_result = Mock()
437+
mock_scalars = Mock()
438+
mock_scalars.all = Mock(return_value=[])
439+
mock_result.scalars = Mock(return_value=mock_scalars)
440+
mock_session.execute = AsyncMock(return_value=mock_result)
441+
442+
await service._load_candidate_sessions(event_type="tool_call")
443+
444+
# Exactly one query should have been issued (no N+1 per-session queries).
445+
assert mock_session.execute.call_count == 1
446+
447+
# The WHERE clause must reference the events table for event_type filtering.
448+
stmt = mock_session.execute.await_args.args[0]
449+
where_clause = str(stmt.whereclause)
450+
assert "events.event_type" in where_clause
451+
452+
@pytest.mark.asyncio
453+
async def test_load_candidate_sessions_event_type_excludes_non_matching_sessions(self):
454+
"""Test _load_candidate_sessions returns only sessions returned by the DB.
455+
456+
When the DB returns no sessions (because none have the requested
457+
event_type), the result must be empty — no session should slip through.
458+
"""
459+
mock_session = _create_mock_async_session()
460+
service = SessionSearchService(mock_session, tenant_id="tenant-1")
461+
462+
# DB finds no sessions matching the event_type subquery.
463+
mock_result = Mock()
464+
mock_scalars = Mock()
465+
mock_scalars.all = Mock(return_value=[])
466+
mock_result.scalars = Mock(return_value=mock_scalars)
467+
mock_session.execute = AsyncMock(return_value=mock_result)
468+
469+
result = await service._load_candidate_sessions(event_type="error")
470+
471+
assert result == []
472+
# Still exactly one query — no secondary per-session lookups.
473+
assert mock_session.execute.call_count == 1
474+
426475
@pytest.mark.asyncio
427476
async def test_load_candidate_sessions_with_status(self):
428477
"""Test _load_candidate_sessions applies status filter when provided."""

0 commit comments

Comments
 (0)