-
Notifications
You must be signed in to change notification settings - Fork 846
Expand file tree
/
Copy pathsliding_window_conversation_manager.py
More file actions
394 lines (322 loc) · 17.4 KB
/
sliding_window_conversation_manager.py
File metadata and controls
394 lines (322 loc) · 17.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
"""Sliding window conversation history management."""
import logging
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from ...agent.agent import Agent
from ...hooks import BeforeModelCallEvent, HookRegistry
from ...types.content import ContentBlock, Messages
from ...types.exceptions import ContextWindowOverflowException
from ...types.tools import ToolResultContent
from .conversation_manager import ConversationManager
logger = logging.getLogger(__name__)
_PRESERVE_CHARS = 200
class SlidingWindowConversationManager(ConversationManager):
    """Implements a sliding window strategy for managing conversation history.

    This class handles the logic of maintaining a conversation window that preserves tool usage pairs and avoids
    invalid window states.

    When truncation is enabled (the default), large tool results are partially truncated, preserving the first
    and last 200 characters, and image blocks inside tool results are replaced with descriptive text placeholders.
    A text is only truncated when doing so actually makes it shorter.
    Truncation targets the oldest tool results first so the most relevant recent context is preserved as long
    as possible. Only the oldest message containing tool results is considered on each call; once it cannot be
    shortened any further, the window is trimmed instead.

    Supports proactive management during agent loop execution via the per_turn parameter.
    """

    def __init__(
        self,
        window_size: int = 40,
        should_truncate_results: bool = True,
        *,
        per_turn: bool | int = False,
        protected_messages: int = 0,
    ):
        """Initialize the sliding window conversation manager.

        Args:
            window_size: Maximum number of messages to keep in the agent's history.
                Defaults to 40 messages.
            should_truncate_results: Truncate tool results when a message is too large for the model's context window
            per_turn: Controls when to apply message management during agent execution.
                - False (default): Only apply management at the end (default behavior)
                - True: Apply management before every model call
                - int (e.g., 3): Apply management before every N model calls

                When to use per_turn: If your agent performs many tool operations in loops
                (e.g., web browsing with frequent screenshots), enable per_turn to proactively
                manage message history and prevent the agent loop from slowing down. Start with
                per_turn=True and adjust to a specific frequency (e.g., per_turn=5) if needed
                for performance tuning.
            protected_messages: Number of messages at the start of the conversation that should
                never be removed during trimming. Defaults to 0 (no protection).
                Use this when the first message(s) contain a task prompt or critical context that
                the agent must retain throughout the entire conversation. For example, in batch
                report generation, set ``protected_messages=1`` to ensure the initial user prompt
                is never trimmed away during context overflow recovery. Protected messages count
                toward ``window_size``.

        Raises:
            ValueError: If per_turn is 0 or a negative integer, or if protected_messages is negative.
        """
        # bool is a subclass of int, so exclude it explicitly before the range check.
        if isinstance(per_turn, int) and not isinstance(per_turn, bool) and per_turn <= 0:
            raise ValueError(f"per_turn must be a positive integer, True, or False, got {per_turn}")
        if protected_messages < 0:
            raise ValueError(f"protected_messages must be non-negative, got {protected_messages}")

        super().__init__()

        self.window_size = window_size
        self.should_truncate_results = should_truncate_results
        self.per_turn = per_turn
        self.protected_messages = protected_messages
        self._model_call_count = 0

    def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None:
        """Register hook callbacks for per-turn conversation management.

        Args:
            registry: The hook registry to register callbacks with.
            **kwargs: Additional keyword arguments for future extensibility.
        """
        super().register_hooks(registry, **kwargs)

        # Always register the callback - per_turn check happens in the callback
        registry.add_callback(BeforeModelCallEvent, self._on_before_model_call)

    def _on_before_model_call(self, event: BeforeModelCallEvent) -> None:
        """Handle before model call event for per-turn management.

        This callback is invoked before each model call. It tracks the model call count and applies message management
        based on the per_turn configuration.

        Args:
            event: The before model call event containing the agent and model execution details.
        """
        # Check if per_turn is enabled
        if self.per_turn is False:
            return

        self._model_call_count += 1

        # Determine if we should apply management
        should_apply = False
        if self.per_turn is True:
            should_apply = True
        elif isinstance(self.per_turn, int) and self.per_turn > 0:
            should_apply = self._model_call_count % self.per_turn == 0

        if should_apply:
            logger.debug(
                "model_call_count=<%d>, per_turn=<%s> | applying per-turn conversation management",
                self._model_call_count,
                self.per_turn,
            )
            self.apply_management(event.agent)

    def get_state(self) -> dict[str, Any]:
        """Get the current state of the conversation manager.

        Returns:
            Dictionary containing the manager's state, including model call count for per-turn tracking.
        """
        state = super().get_state()
        state["model_call_count"] = self._model_call_count
        return state

    def restore_from_session(self, state: dict[str, Any]) -> list | None:
        """Restore the conversation manager's state from a session.

        Args:
            state: Previous state of the conversation manager

        Returns:
            Optional list of messages to prepend to the agent's messages.
        """
        result = super().restore_from_session(state)
        # Older saved states may not carry the counter; default to 0.
        self._model_call_count = state.get("model_call_count", 0)
        return result

    def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
        """Apply the sliding window to the agent's messages array to maintain a manageable history size.

        This method is called after every event loop cycle to apply a sliding window if the message count
        exceeds the window size.

        Note: when ``should_truncate_results`` is enabled, ``reduce_context`` may shorten a tool result instead
        of removing messages, so the history can stay above ``window_size`` until a later call trims it.

        Args:
            agent: The agent whose messages will be managed.
                This list is modified in-place.
            **kwargs: Additional keyword arguments for future extensibility.
        """
        messages = agent.messages

        if len(messages) <= self.window_size:
            logger.debug(
                "message_count=<%s>, window_size=<%s> | skipping context reduction", len(messages), self.window_size
            )
            return
        self.reduce_context(agent)

    def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None:
        """Trim the oldest messages to reduce the conversation context size.

        The method handles special cases where trimming the messages leads to:
            - toolResult with no corresponding toolUse
            - toolUse with no corresponding toolResult

        When ``protected_messages`` is set, the first N messages are preserved and
        re-inserted after trimming so that critical context (e.g. the initial task
        prompt) is never lost. Messages are removed from just after the protected
        prefix, and the prefix counts toward ``window_size`` so the resulting history
        (prefix included) fits inside the window whenever the prefix is smaller than it.

        Args:
            agent: The agent whose messages will be reduced.
                This list is modified in-place.
            e: The exception that triggered the context reduction, if any.
            **kwargs: Additional keyword arguments for future extensibility.

        Raises:
            ContextWindowOverflowException: If the context cannot be reduced further and a context overflow
                error was provided (e is not None). When called during routine window management (e is None),
                logs a warning and returns without modification.
        """
        messages = agent.messages

        # Snapshot protected messages before any trimming. The slice holds references to the same
        # message dicts, so in-place tool-result truncation below is reflected in the snapshot.
        protected: list = []
        if self.protected_messages > 0 and len(messages) > self.protected_messages:
            protected = messages[: self.protected_messages]

        # Try to truncate the tool result first
        oldest_message_idx_with_tool_results = self._find_oldest_message_with_tool_results(messages)
        if oldest_message_idx_with_tool_results is not None and self.should_truncate_results:
            logger.debug(
                "message_index=<%s> | found message with tool results at index", oldest_message_idx_with_tool_results
            )
            results_truncated = self._truncate_tool_results(messages, oldest_message_idx_with_tool_results)
            if results_truncated:
                logger.debug("message_index=<%s> | tool results truncated", oldest_message_idx_with_tool_results)
                return

        # Try to trim index id when tool result cannot be truncated anymore.
        # Trimming always starts after the protected prefix (if any) because that prefix is re-inserted below.
        prefix_len = len(protected)
        if len(messages) <= self.window_size:
            # Overflow on a history that already fits the window: drop (at least) the two oldest
            # unprotected messages.
            trim_index = prefix_len + 2
        elif prefix_len:
            # Reserve room for the re-inserted prefix inside the window; keep at least one tail
            # message even if the prefix alone fills the window.
            trim_index = len(messages) - max(self.window_size - prefix_len, 1)
        else:
            trim_index = len(messages) - self.window_size

        # Never trim into the protected region. When a prefix will be re-inserted, require at least one
        # unprotected message to be removed; otherwise the trim would only remove and re-insert the prefix,
        # leaving the history unchanged (and, on the overflow path, neither reducing nor raising).
        min_trim_index = prefix_len + 1 if prefix_len else self.protected_messages
        trim_index = max(trim_index, min_trim_index)

        # Find the next valid trim point that:
        # 1. Starts with a user message (required by most model providers)
        # 2. Does not start with an orphaned toolResult
        # 3. Does not start with a toolUse unless its toolResult immediately follows
        # Falls back to an assistant(toolUse) + user(toolResult) boundary if no plain user message exists.
        # This is acceptable because providers treat a complete toolUse/toolResult pair as a valid
        # conversation continuation, and without this fallback tool-heavy conversations cannot be trimmed.
        fallback_trim_index = None
        while trim_index < len(messages):
            # Prefer starting with a user message
            if messages[trim_index]["role"] != "user":
                # Track first valid assistant(toolUse) + user(toolResult) pair as fallback
                if (
                    fallback_trim_index is None
                    and any("toolUse" in content for content in messages[trim_index]["content"])
                    and trim_index + 1 < len(messages)
                    and messages[trim_index + 1]["role"] == "user"
                    and any("toolResult" in content for content in messages[trim_index + 1]["content"])
                ):
                    fallback_trim_index = trim_index
                trim_index += 1
                continue

            if (
                # Oldest message cannot be a toolResult because it needs a toolUse preceding it
                any("toolResult" in content for content in messages[trim_index]["content"])
                or (
                    # Oldest message can be a toolUse only if a toolResult immediately follows it.
                    # Note: toolUse content normally appears only in assistant messages, but this
                    # check is kept as a defensive safeguard for non-standard message formats.
                    any("toolUse" in content for content in messages[trim_index]["content"])
                    and not (
                        trim_index + 1 < len(messages)
                        and any("toolResult" in content for content in messages[trim_index + 1]["content"])
                    )
                )
            ):
                trim_index += 1
            else:
                break
        else:
            # No plain user message found — use assistant+toolResult fallback if available
            if fallback_trim_index is not None:
                logger.debug(
                    "trim_index=<%s> | no plain user message trim point found, "
                    "falling back to assistant(toolUse) + user(toolResult) boundary",
                    fallback_trim_index,
                )
                trim_index = fallback_trim_index
            elif e is not None:
                raise ContextWindowOverflowException("Unable to trim conversation context!") from e
            else:
                logger.warning(
                    "window_size=<%s>, message_count=<%s> | unable to trim conversation context, "
                    "no valid trim point found",
                    self.window_size,
                    len(messages),
                )
                return

        # trim_index represents the number of messages being sliced off the front of the agent's messages array.
        # NOTE(review): re-inserted protected messages are still counted here as removed; confirm this matches how
        # removed_message_count is consumed elsewhere (e.g. session restore offsets).
        self.removed_message_count += trim_index

        # Overwrite message history
        messages[:] = messages[trim_index:]

        # Re-insert protected messages. trim_index is always past the protected prefix at this point, so every
        # protected message was just sliced off; all of them go back (no equality-based membership test, which
        # could mistake an equal later message for a protected one and drop it).
        # NOTE(review): if the prefix ends with a user message and the kept tail also starts with one, the result
        # has two consecutive user messages; confirm the target model providers accept this.
        if protected:
            messages[:0] = protected
            logger.info(
                "protected_messages=<%d> | re-inserted %d protected message(s) after trim",
                self.protected_messages,
                len(protected),
            )

    def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool:
        """Truncate tool results and replace image blocks in a message to reduce context size.

        For text blocks within tool results, all blocks are partially truncated unless they
        have already been truncated or the truncated form would not be shorter than the original.
        The first and last _PRESERVE_CHARS characters are kept, and the removed middle is replaced
        with a notice indicating how many characters were removed. The tool result status is not
        changed.

        Image blocks nested inside tool result content are replaced with a short descriptive placeholder.

        Args:
            messages: The conversation message history.
            msg_idx: Index of the message containing tool results to truncate.

        Returns:
            True if any changes were made to the message, False otherwise.
        """
        if msg_idx >= len(messages) or msg_idx < 0:
            return False

        def _image_placeholder(image_block: Any) -> str:
            # Describe the dropped image by its declared format and raw byte length.
            source: Any = image_block.get("source", {})
            media_type = image_block.get("format", "unknown")
            data = source.get("bytes", b"")
            return f"[image: {media_type}, {len(data) if data else 0} bytes]"

        message = messages[msg_idx]
        changes_made = False
        new_content: list[ContentBlock] = []
        truncation_marker = "... [truncated:"

        for content in message.get("content", []):
            if "toolResult" not in content:
                new_content.append(content)
                continue

            tool_result: Any = content["toolResult"]
            tool_result_items = tool_result.get("content", [])
            new_items: list[ToolResultContent] = []
            item_changed = False

            for item in tool_result_items:
                # Replace image items nested inside toolResult content
                if "image" in item:
                    new_items.append({"text": _image_placeholder(item["image"])})
                    item_changed = True
                    continue

                # Partially truncate text items that have not already been truncated
                if "text" in item:
                    text = item["text"]
                    if truncation_marker not in text and len(text) > 2 * _PRESERVE_CHARS:
                        prefix = text[:_PRESERVE_CHARS]
                        suffix = text[-_PRESERVE_CHARS:]
                        removed = len(text) - 2 * _PRESERVE_CHARS
                        truncated_text = (
                            f"{prefix}...\n\n... [truncated: {removed} chars removed] ...\n\n...{suffix}"
                        )
                        # For texts only slightly above the threshold the notice is longer than what it
                        # replaces; skip those so a "truncation" never grows the context.
                        if len(truncated_text) < len(text):
                            new_items.append({"text": truncated_text})
                            item_changed = True
                            continue

                new_items.append(item)

            if item_changed:
                # Rebuild the toolResult with every original key except the replaced content list.
                updated_tool_result: Any = {
                    **{k: v for k, v in tool_result.items() if k != "content"},
                    "content": new_items,
                }
                new_content.append({"toolResult": updated_tool_result})
                changes_made = True
            else:
                new_content.append(content)

        if changes_made:
            message["content"] = new_content
        return changes_made

    def _find_oldest_message_with_tool_results(self, messages: Messages) -> int | None:
        """Find the index of the oldest message containing tool results.

        Iterates from oldest to newest so that truncation targets the least-recent
        (and therefore least relevant) tool results first.

        Args:
            messages: The conversation message history.

        Returns:
            Index of the oldest message with tool results, or None if no such message exists.
        """
        for idx, current_message in enumerate(messages):
            if any(
                isinstance(content, dict) and "toolResult" in content
                for content in current_message.get("content", [])
            ):
                return idx
        return None