Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 12 additions & 18 deletions storage/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,28 +315,22 @@ async def _load_candidate_sessions(
if tag_conditions:
stmt = stmt.where(or_(*tag_conditions))

result = await self.session.execute(stmt)
sessions = list(result.scalars().all())

# Filter by event_type if specified (requires checking events)
# Filter by event_type in SQL using a subquery to avoid N+1 queries.
# Previously this was done with a Python loop issuing one query per session.
if event_type:
sessions_with_event_type = []
for session in sessions:
# Check if session has any event of the specified type
# Handle both string and enum values
event_type_str = event_type.value if hasattr(event_type, 'value') else event_type

event_stmt = select(EventModel).where(
EventModel.session_id == session.id,
event_type_str = event_type.value if hasattr(event_type, "value") else event_type
event_subq = (
select(EventModel.session_id)
.where(
EventModel.tenant_id == self.tenant_id,
EventModel.event_type == event_type_str,
).limit(1)
event_result = await self.session.execute(event_stmt)
if event_result.scalar_one_or_none() is not None:
sessions_with_event_type.append(session)
sessions = sessions_with_event_type
)
.scalar_subquery()
)
stmt = stmt.where(SessionModel.id.in_(event_subq))

return sessions
result = await self.session.execute(stmt)
return list(result.scalars().all())
Comment on lines +318 to +333
Copy link

Copilot AI Apr 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are existing unit tests for _load_candidate_sessions (e.g., status filtering), but this change significantly alters the event_type filtering behavior by pushing it into SQL. Please add a test that asserts the generated query includes the event_type predicate (or, preferably, an integration-style test that verifies sessions without matching events are excluded).

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added two unit tests for _load_candidate_sessions with event_type filter in 51549ab: test_load_candidate_sessions_event_type_filter_in_sql asserts exactly one DB query is issued and that events.event_type appears in the WHERE clause (confirming the predicate is in SQL, not a Python loop); test_load_candidate_sessions_event_type_excludes_non_matching_sessions asserts that when the DB returns no rows the result is empty and still only one query was issued.


async def _score_sessions(
self,
Expand Down
49 changes: 49 additions & 0 deletions tests/storage/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,6 +423,55 @@ async def test_load_candidate_sessions_without_status(self):
assert isinstance(result, list)
assert mock_session.execute.called

@pytest.mark.asyncio
async def test_load_candidate_sessions_event_type_filter_in_sql(self):
    """Check that the event_type filter travels inside the one SQL statement.

    The service is called with ``event_type="tool_call"`` against a mocked
    async session whose ``execute`` yields an empty result. The test then
    verifies that ``execute`` was called exactly once and that the string
    form of the statement's WHERE clause mentions ``events.event_type``.
    """
    db = _create_mock_async_session()
    search_service = SessionSearchService(db, tenant_id="tenant-1")

    # Stub the result chain: result.scalars().all() -> [].
    empty_result = Mock()
    empty_result.scalars.return_value.all.return_value = []
    db.execute = AsyncMock(return_value=empty_result)

    await search_service._load_candidate_sessions(event_type="tool_call")

    # A single execute() call means no per-session follow-up queries ran.
    assert db.execute.call_count == 1

    # Inspect the statement that was handed to execute(); its WHERE clause
    # has to reference the events table's event_type column.
    executed_stmt = db.execute.await_args.args[0]
    rendered_where = str(executed_stmt.whereclause)
    assert "events.event_type" in rendered_where

@pytest.mark.asyncio
async def test_load_candidate_sessions_event_type_excludes_non_matching_sessions(self):
    """Check that an empty DB result yields an empty candidate list.

    The mocked ``execute`` returns a result whose ``scalars().all()`` is
    ``[]``, standing in for a database that found no session with the
    requested event_type. The method must return ``[]`` and must have
    called ``execute`` exactly once.
    """
    db = _create_mock_async_session()
    search_service = SessionSearchService(db, tenant_id="tenant-1")

    # Simulate the database matching nothing for the event_type filter:
    # result.scalars().all() -> [].
    empty_result = Mock()
    empty_result.scalars.return_value.all.return_value = []
    db.execute = AsyncMock(return_value=empty_result)

    sessions = await search_service._load_candidate_sessions(event_type="error")

    assert sessions == []
    # One execute() call only: no secondary per-session lookups were made.
    assert db.execute.call_count == 1

@pytest.mark.asyncio
async def test_load_candidate_sessions_with_status(self):
"""Test _load_candidate_sessions applies status filter when provided."""
Expand Down
Loading