Skip to content

Commit 11ea048

Browse files
dpage and claude committed
Address CodeRabbit review feedback for chat context and compaction.
- Track tool-use turns as groups instead of one-to-one pairs, so multi-tool assistant messages don't leave orphaned results.
- Add fallback to shrink the recent window when protected messages alone exceed the token budget, preventing compaction no-ops.
- Fix low-value test fixtures to keep transient messages short so they actually classify as low-importance.
- Guard Clear button against in-flight stream race conditions by adding a clearedRef flag and cancelling active streams.
- Assert that conversation history is actually passed through to chat_with_database in the "With History" test.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent c30e633 commit 11ea048

File tree

4 files changed

+103
-41
lines changed

4 files changed

+103
-41
lines changed

web/pgadmin/llm/compaction.py

Lines changed: 49 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -182,34 +182,41 @@ def _classify_message(message: Message) -> float:
182182
return CLASS_CONTEXTUAL
183183

184184

185-
def _find_tool_pair_indices(messages: list[Message]) -> dict[int, int]:
186-
"""Find indices of tool_call/tool_result pairs that must stay together.
185+
def _find_tool_pair_indices(
186+
messages: list[Message]
187+
) -> dict[int, frozenset[int]]:
188+
"""Find indices of tool_call/tool_result groups that must stay together.
187189
188-
Returns a mapping where both the assistant message index and the
189-
tool result message index map to each other, so removing one
190-
implies removing both.
190+
An assistant message may contain multiple tool_calls, each with a
191+
corresponding tool result message. All messages in such a group
192+
must be dropped or kept together.
193+
194+
Returns a mapping where every index in a group maps to the full
195+
set of indices in that group.
191196
192197
Args:
193198
messages: The message list.
194199
195200
Returns:
196-
Dict mapping index -> paired index.
201+
Dict mapping index -> frozenset of all indices in the group.
197202
"""
198-
pairs = {}
203+
groups: dict[int, frozenset[int]] = {}
199204

200205
for i, msg in enumerate(messages):
201206
if msg.role == Role.ASSISTANT and msg.tool_calls:
202-
# Find the corresponding tool result(s)
203207
tool_call_ids = {tc.id for tc in msg.tool_calls}
208+
group_indices = {i}
204209
for j in range(i + 1, len(messages)):
205210
if messages[j].role == Role.TOOL:
206211
for tr in messages[j].tool_results:
207212
if tr.tool_call_id in tool_call_ids:
208-
pairs[i] = j
209-
pairs[j] = i
213+
group_indices.add(j)
210214
break
215+
group = frozenset(group_indices)
216+
for idx in group:
217+
groups[idx] = group
211218

212-
return pairs
219+
return groups
213220

214221

215222
def compact_history(
@@ -257,8 +264,21 @@ def compact_history(
257264
for i in range(recent_start, total):
258265
protected.add(i)
259266

260-
# Find tool pairs
261-
tool_pairs = _find_tool_pair_indices(messages)
267+
# If protected messages alone exceed the budget, shrink the
268+
# recent window until we have room for compaction candidates.
269+
while recent_window > 0:
270+
protected_tokens = sum(
271+
estimate_message_tokens(messages[i], provider)
272+
for i in protected
273+
)
274+
if protected_tokens <= max_tokens:
275+
break
276+
recent_window -= 1
277+
recent_start = max(1, total - recent_window)
278+
protected = {0} | set(range(recent_start, total))
279+
280+
# Find tool groups
281+
tool_groups = _find_tool_pair_indices(messages)
262282

263283
# Classify and score all non-protected messages
264284
candidates = []
@@ -276,7 +296,7 @@ def compact_history(
276296
if current_tokens <= max_tokens:
277297
break
278298

279-
# Skip if already dropped (as part of a pair)
299+
# Skip if already dropped (as part of a group)
280300
if idx in dropped:
281301
continue
282302

@@ -288,12 +308,14 @@ def compact_history(
288308
saved = estimate_message_tokens(messages[idx], provider)
289309
dropped.add(idx)
290310

291-
# If this is part of a tool pair, drop the partner too
292-
if idx in tool_pairs:
293-
partner = tool_pairs[idx]
294-
if partner not in protected:
295-
saved += estimate_message_tokens(messages[partner], provider)
296-
dropped.add(partner)
311+
# If this is part of a tool group, drop all partners too
312+
if idx in tool_groups:
313+
for partner in tool_groups[idx]:
314+
if partner != idx and partner not in protected:
315+
saved += estimate_message_tokens(
316+
messages[partner], provider
317+
)
318+
dropped.add(partner)
297319

298320
current_tokens -= saved
299321

@@ -308,13 +330,13 @@ def compact_history(
308330
saved = estimate_message_tokens(messages[idx], provider)
309331
dropped.add(idx)
310332

311-
if idx in tool_pairs:
312-
partner = tool_pairs[idx]
313-
if partner not in protected:
314-
saved += estimate_message_tokens(
315-
messages[partner], provider
316-
)
317-
dropped.add(partner)
333+
if idx in tool_groups:
334+
for partner in tool_groups[idx]:
335+
if partner != idx and partner not in protected:
336+
saved += estimate_message_tokens(
337+
messages[partner], provider
338+
)
339+
dropped.add(partner)
318340

319341
current_tokens -= saved
320342

web/pgadmin/llm/tests/test_compaction.py

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -137,15 +137,16 @@ def test_preserves_first_and_recent(self):
137137

138138
def test_drops_low_value(self):
139139
"""Low-value messages should be dropped first."""
140-
# Use longer messages to ensure we exceed the token budget
140+
# Filler only on important messages to inflate token count;
141+
# keep transient messages short so they classify as low-value.
141142
filler = ' This is extra text to increase token count.' * 5
142143
messages = [
143144
Message.user('First important query' + filler),
144-
# Short transient messages (low value)
145-
Message.user('ok' + filler),
146-
Message.assistant('ok' + filler),
147-
Message.user('thanks' + filler),
148-
Message.assistant('sure' + filler),
145+
# Short transient messages (low value) - no filler
146+
Message.user('ok'),
147+
Message.assistant('ok'),
148+
Message.user('thanks'),
149+
Message.assistant('sure'),
149150
# More substantial messages
150151
Message.user('Show me the schema with CREATE TABLE' + filler),
151152
Message.assistant(
@@ -166,6 +167,10 @@ def test_drops_low_value(self):
166167
self.assertIn('First important query', result[0].content)
167168
# Last 2 preserved
168169
self.assertIn('Final answer with details', result[-1].content)
170+
# Transient messages should be dropped
171+
contents = [m.content for m in result]
172+
for short_msg in ['ok', 'thanks', 'sure']:
173+
self.assertNotIn(short_msg, contents)
169174

170175
def test_tool_pairs(self):
171176
"""Tool call/result pairs should be dropped together."""

web/pgadmin/tools/sqleditor/static/js/components/sections/NLQChatPanel.jsx

Lines changed: 21 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -292,6 +292,7 @@ export function NLQChatPanel() {
292292
const abortControllerRef = useRef(null);
293293
const readerRef = useRef(null);
294294
const stoppedRef = useRef(false);
295+
const clearedRef = useRef(false);
295296
const eventBus = useContext(QueryToolEventsContext);
296297
const queryToolCtx = useContext(QueryToolContext);
297298
const editorPrefs = usePreferences().getPreferencesForModule('editor');
@@ -406,9 +407,21 @@ export function NLQChatPanel() {
406407
};
407408

408409
const handleClearConversation = () => {
410+
// Mark as cleared so in-flight stream handlers ignore late events
411+
clearedRef.current = true;
412+
// Cancel any active stream
413+
if (readerRef.current) {
414+
readerRef.current.cancel();
415+
readerRef.current = null;
416+
}
417+
if (abortControllerRef.current) {
418+
abortControllerRef.current.abort();
419+
abortControllerRef.current = null;
420+
}
409421
setMessages([]);
410422
setConversationId(null);
411423
setConversationHistory([]);
424+
setIsLoading(false);
412425
};
413426

414427
// Stop the current request
@@ -446,8 +459,9 @@ export function NLQChatPanel() {
446459
const handleSubmit = async () => {
447460
if (!inputValue.trim() || isLoading) return;
448461

449-
// Reset stopped flag
462+
// Reset stopped and cleared flags
450463
stoppedRef.current = false;
464+
clearedRef.current = false;
451465

452466
// Fetch latest LLM provider/model info before submitting
453467
fetchLlmInfo();
@@ -548,8 +562,8 @@ export function NLQChatPanel() {
548562

549563
readerRef.current = null;
550564

551-
// Check if user manually stopped
552-
if (stoppedRef.current) {
565+
// Check if user manually stopped (but not cleared)
566+
if (stoppedRef.current && !clearedRef.current) {
553567
setMessages((prev) => [
554568
...prev.filter((m) => m.id !== thinkingId),
555569
{
@@ -562,8 +576,10 @@ export function NLQChatPanel() {
562576
clearTimeout(timeoutId);
563577
abortControllerRef.current = null;
564578
readerRef.current = null;
565-
// Show appropriate message based on error type
566-
if (error.name === 'AbortError') {
579+
// If conversation was cleared, ignore all late errors
580+
if (clearedRef.current) {
581+
// Do nothing - conversation was wiped
582+
} else if (error.name === 'AbortError') {
567583
// Check if this was a user-initiated stop or a timeout
568584
if (stoppedRef.current) {
569585
// User manually stopped

web/pgadmin/tools/sqleditor/tests/test_nlq_chat.py

Lines changed: 22 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -109,12 +109,14 @@ def runTest(self):
109109
patches.append(mock_check_trans)
110110

111111
# Mock chat_with_database
112+
mock_chat_patcher = None
113+
mock_chat_obj = None
112114
if hasattr(self, 'mock_response'):
113-
mock_chat = patch(
115+
mock_chat_patcher = patch(
114116
'pgadmin.llm.chat.chat_with_database',
115117
return_value=(self.mock_response, [])
116118
)
117-
patches.append(mock_chat)
119+
patches.append(mock_chat_patcher)
118120

119121
# Mock CSRF protection
120122
mock_csrf = patch(
@@ -124,8 +126,12 @@ def runTest(self):
124126
patches.append(mock_csrf)
125127

126128
# Start all patches
129+
started_mocks = []
127130
for p in patches:
128-
p.start()
131+
m = p.start()
132+
started_mocks.append(m)
133+
if p is mock_chat_patcher:
134+
mock_chat_obj = m
129135

130136
try:
131137
# Make request
@@ -156,6 +162,19 @@ def runTest(self):
156162
self.assertEqual(response.status_code, 200)
157163
self.assertIn('text/event-stream', response.content_type)
158164

165+
# Verify history was passed to chat_with_database
166+
if hasattr(self, 'history') and mock_chat_obj:
167+
mock_chat_obj.assert_called_once()
168+
call_kwargs = mock_chat_obj.call_args.kwargs
169+
conv_hist = call_kwargs.get(
170+
'conversation_history', []
171+
)
172+
self.assertTrue(
173+
len(conv_hist) > 0,
174+
'conversation_history should be non-empty '
175+
'when history is provided'
176+
)
177+
159178
finally:
160179
# Stop all patches
161180
for p in patches:

0 commit comments

Comments
 (0)