Skip to content

Commit a4ae372

Browse files
fix: internal N+1 query in dialectic agent calls - DEV-1721 (plastic-labs#652)
* fix: internal N+1 query in dialectic agent calls
* fix: comments
1 parent ad7c1b3 commit a4ae372

2 files changed

Lines changed: 115 additions & 20 deletions

File tree

src/crud/message.py

Lines changed: 35 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -140,13 +140,12 @@ async def _build_merged_snippets(
140140
for msg in matched_messages:
141141
session_matches.setdefault(msg.session_name, []).append(msg)
142142

143-
snippets: list[tuple[list[models.Message], list[models.Message]]] = []
144-
143+
# Build merged ranges per session, then issue a single batched query
144+
session_ranges: dict[str, list[tuple[int, int, list[models.Message]]]] = {}
145145
for sess_name, matches in session_matches.items():
146146
matches.sort(key=lambda m: m.seq_in_session)
147147

148148
merged_ranges: list[tuple[int, int, list[models.Message]]] = []
149-
150149
for match in matches:
151150
start = match.seq_in_session - context_window
152151
end = match.seq_in_session + context_window
@@ -161,25 +160,42 @@ async def _build_merged_snippets(
161160
else:
162161
merged_ranges.append((start, end, [match]))
163162

164-
# Batch all ranges into a single query using OR conditions.
165-
# NOTE: If callers ever pass a very high limit (many disjoint ranges),
166-
# consider chunking to avoid oversized SQL / planner issues.
167-
range_conditions = [
168-
models.Message.seq_in_session.between(start_seq, end_seq)
169-
for start_seq, end_seq, _ in merged_ranges
170-
]
171-
context_stmt = (
172-
select(models.Message)
173-
.where(models.Message.workspace_name == workspace_name)
174-
.where(models.Message.session_name == sess_name)
175-
.where(or_(*range_conditions))
176-
.order_by(models.Message.seq_in_session.asc())
163+
session_ranges[sess_name] = merged_ranges
164+
165+
# One OR-of-ANDs predicate covers every (session, range) pair
166+
session_predicates = [
167+
and_(
168+
models.Message.session_name == sess_name,
169+
or_(
170+
*(
171+
models.Message.seq_in_session.between(start_seq, end_seq)
172+
for start_seq, end_seq, _ in merged_ranges
173+
)
174+
),
175+
)
176+
for sess_name, merged_ranges in session_ranges.items()
177+
]
178+
179+
context_stmt = (
180+
select(models.Message)
181+
.where(models.Message.workspace_name == workspace_name)
182+
.where(or_(*session_predicates))
183+
.order_by(
184+
models.Message.session_name.asc(),
185+
models.Message.seq_in_session.asc(),
177186
)
187+
)
178188

179-
context_result = await db.execute(context_stmt)
180-
all_context_messages = list(context_result.scalars().all())
189+
context_result = await db.execute(context_stmt)
190+
by_session: dict[str, list[models.Message]] = {}
191+
for msg in context_result.scalars().all():
192+
by_session.setdefault(msg.session_name, []).append(msg)
181193

182-
# Partition results back into their respective ranges
194+
snippets: list[
195+
tuple[list[models.Message], list[models.Message]]
196+
] = [] # list of tuples, each containing query matches and context messages
197+
for sess_name, merged_ranges in session_ranges.items():
198+
all_context_messages = by_session.get(sess_name, [])
183199
for start_seq, end_seq, range_matches in merged_ranges:
184200
context_messages = [
185201
msg

tests/integration/test_message_embeddings.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,50 @@
1717
from src.config import settings
1818
from src.crud import create_messages
1919
from src.crud import message as message_crud
20-
from src.models import Peer, Workspace
20+
from src.models import Message, Peer, Workspace
2121
from src.schemas import MessageCreate
2222
from src.utils.search import search
2323

2424

25+
class _FakeScalarResult:
26+
def __init__(self, rows: list[models.Message]):
27+
self._rows: list[Message] = rows
28+
29+
def all(self) -> list[models.Message]:
30+
return self._rows
31+
32+
33+
class _FakeResult:
34+
def __init__(self, rows: list[models.Message]):
35+
self._rows: list[Message] = rows
36+
37+
def scalars(self) -> _FakeScalarResult:
38+
return _FakeScalarResult(self._rows)
39+
40+
41+
class _CountingDb:
42+
def __init__(self, rows: list[models.Message]):
43+
self._rows: list[Message] = rows
44+
self.execute_count: int = 0
45+
46+
async def execute(self, _stmt: Any) -> _FakeResult:
47+
self.execute_count += 1
48+
return _FakeResult(self._rows)
49+
50+
51+
def _message(session_name: str, seq_in_session: int) -> models.Message:
52+
return models.Message(
53+
workspace_name="workspace",
54+
session_name=session_name,
55+
peer_name="peer",
56+
content=f"{session_name}:{seq_in_session}",
57+
public_id=generate_nanoid(),
58+
seq_in_session=seq_in_session,
59+
token_count=1,
60+
created_at=datetime.now(timezone.utc),
61+
)
62+
63+
2564
@pytest.mark.asyncio
2665
async def test_message_embedding_created_when_setting_enabled(
2766
db_session: AsyncSession,
@@ -260,6 +299,46 @@ async def test_semantic_search_when_embeddings_enabled(
260299
assert created_message.public_id in found_message_ids
261300

262301

302+
@pytest.mark.asyncio
303+
async def test_build_merged_snippets_batches_context_query_across_sessions():
304+
"""Context expansion should not issue one DB query per matched session."""
305+
matched_messages = [
306+
_message("session_a", 10),
307+
_message("session_b", 20),
308+
_message("session_c", 30),
309+
]
310+
context_messages = [
311+
_message("session_a", 9),
312+
_message("session_a", 10),
313+
_message("session_a", 11),
314+
_message("session_a", 99),
315+
_message("session_b", 19),
316+
_message("session_b", 20),
317+
_message("session_b", 21),
318+
_message("session_c", 29),
319+
_message("session_c", 30),
320+
_message("session_c", 31),
321+
]
322+
db = _CountingDb(context_messages)
323+
324+
snippets = await message_crud._build_merged_snippets( # pyright: ignore[reportPrivateUsage]
325+
db, # pyright: ignore[reportArgumentType]
326+
workspace_name="workspace",
327+
matched_messages=matched_messages,
328+
context_window=1,
329+
)
330+
331+
assert db.execute_count == 1
332+
assert [len(matches) for matches, _ in snippets] == [1, 1, 1]
333+
assert [
334+
[msg.content for msg in context_messages] for _, context_messages in snippets
335+
] == [
336+
["session_a:9", "session_a:10", "session_a:11"],
337+
["session_b:19", "session_b:20", "session_b:21"],
338+
["session_c:29", "session_c:30", "session_c:31"],
339+
]
340+
341+
263342
@pytest.mark.asyncio
264343
async def test_search_messages_external_lookup_happens_before_tracked_db(
265344
monkeypatch: pytest.MonkeyPatch,

0 commit comments

Comments
 (0)