diff --git a/storage/search.py b/storage/search.py
index 850f5e1..d4a3461 100644
--- a/storage/search.py
+++ b/storage/search.py
@@ -315,28 +315,22 @@ async def _load_candidate_sessions(
         if tag_conditions:
             stmt = stmt.where(or_(*tag_conditions))
 
-        result = await self.session.execute(stmt)
-        sessions = list(result.scalars().all())
-
-        # Filter by event_type if specified (requires checking events)
+        # Filter by event_type in SQL using a subquery to avoid N+1 queries.
+        # Previously this was done with a Python loop issuing one query per session.
         if event_type:
-            sessions_with_event_type = []
-            for session in sessions:
-                # Check if session has any event of the specified type
-                # Handle both string and enum values
-                event_type_str = event_type.value if hasattr(event_type, 'value') else event_type
-
-                event_stmt = select(EventModel).where(
-                    EventModel.session_id == session.id,
+            event_type_str = event_type.value if hasattr(event_type, "value") else event_type
+            event_subq = (
+                select(EventModel.session_id)
+                .where(
                     EventModel.tenant_id == self.tenant_id,
                     EventModel.event_type == event_type_str,
-                ).limit(1)
-                event_result = await self.session.execute(event_stmt)
-                if event_result.scalar_one_or_none() is not None:
-                    sessions_with_event_type.append(session)
-            sessions = sessions_with_event_type
+                )
+                .scalar_subquery()
+            )
+            stmt = stmt.where(SessionModel.id.in_(event_subq))
 
-        return sessions
+        result = await self.session.execute(stmt)
+        return list(result.scalars().all())
 
     async def _score_sessions(
         self,
diff --git a/tests/storage/test_search.py b/tests/storage/test_search.py
index c148b8c..605358f 100644
--- a/tests/storage/test_search.py
+++ b/tests/storage/test_search.py
@@ -423,6 +423,55 @@ async def test_load_candidate_sessions_without_status(self):
         assert isinstance(result, list)
         assert mock_session.execute.called
 
+    @pytest.mark.asyncio
+    async def test_load_candidate_sessions_event_type_filter_in_sql(self):
+        """Test _load_candidate_sessions pushes event_type filter into SQL subquery.
+
+        The event_type predicate must appear in the WHERE clause of the single
+        SQL statement, not as a Python-level N+1 loop over sessions.
+        """
+        mock_session = _create_mock_async_session()
+        service = SessionSearchService(mock_session, tenant_id="tenant-1")
+
+        mock_result = Mock()
+        mock_scalars = Mock()
+        mock_scalars.all = Mock(return_value=[])
+        mock_result.scalars = Mock(return_value=mock_scalars)
+        mock_session.execute = AsyncMock(return_value=mock_result)
+
+        await service._load_candidate_sessions(event_type="tool_call")
+
+        # Exactly one query should have been issued (no N+1 per-session queries).
+        assert mock_session.execute.call_count == 1
+
+        # The WHERE clause must reference the events table for event_type filtering.
+        stmt = mock_session.execute.await_args.args[0]
+        where_clause = str(stmt.whereclause)
+        assert "events.event_type" in where_clause
+
+    @pytest.mark.asyncio
+    async def test_load_candidate_sessions_event_type_excludes_non_matching_sessions(self):
+        """Test _load_candidate_sessions returns only sessions returned by the DB.
+
+        When the DB returns no sessions (because none have the requested
+        event_type), the result must be empty — no session should slip through.
+        """
+        mock_session = _create_mock_async_session()
+        service = SessionSearchService(mock_session, tenant_id="tenant-1")
+
+        # DB finds no sessions matching the event_type subquery.
+        mock_result = Mock()
+        mock_scalars = Mock()
+        mock_scalars.all = Mock(return_value=[])
+        mock_result.scalars = Mock(return_value=mock_scalars)
+        mock_session.execute = AsyncMock(return_value=mock_result)
+
+        result = await service._load_candidate_sessions(event_type="error")
+
+        assert result == []
+        # Still exactly one query — no secondary per-session lookups.
+        assert mock_session.execute.call_count == 1
+
     @pytest.mark.asyncio
     async def test_load_candidate_sessions_with_status(self):
         """Test _load_candidate_sessions applies status filter when provided."""