Skip to content

Commit b005bca

Browse files
author
Goutham Chandramouli
committed
feat(conversation-manager): add protected_messages to SlidingWindowConversationManager
The sliding window can trim the first user message during context overflow recovery, causing the agent to lose its task prompt. Add a protected_messages parameter that preserves the first N messages from being removed during trimming. Before trimming, the first protected_messages messages are snapshotted. The trim index is clamped to never fall inside the protected region, and any protected messages removed by aggressive overflow trimming are re-inserted afterward. Default is 0 (no protection), preserving full backward compatibility.
1 parent 513e67d commit b005bca

2 files changed

Lines changed: 142 additions & 1 deletion

File tree

src/strands/agent/conversation_manager/sliding_window_conversation_manager.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
should_truncate_results: bool = True,
3838
*,
3939
per_turn: bool | int = False,
40+
protected_messages: int = 0,
4041
):
4142
"""Initialize the sliding window conversation manager.
4243
@@ -54,18 +55,29 @@ def __init__(
5455
manage message history and prevent the agent loop from slowing down. Start with
5556
per_turn=True and adjust to a specific frequency (e.g., per_turn=5) if needed
5657
for performance tuning.
58+
protected_messages: Number of messages at the start of the conversation that should
59+
never be removed during trimming. Defaults to 0 (no protection).
60+
61+
Use this when the first message(s) contain a task prompt or critical context that
62+
the agent must retain throughout the entire conversation. For example, in batch
63+
report generation, set ``protected_messages=1`` to ensure the initial user prompt
64+
is never trimmed away during context overflow recovery.
5765
5866
Raises:
59-
ValueError: If per_turn is 0 or a negative integer.
67+
ValueError: If per_turn is 0 or a negative integer, or if protected_messages is negative.
6068
"""
6169
if isinstance(per_turn, int) and not isinstance(per_turn, bool) and per_turn <= 0:
6270
raise ValueError(f"per_turn must be a positive integer, True, or False, got {per_turn}")
6371

72+
if protected_messages < 0:
73+
raise ValueError(f"protected_messages must be non-negative, got {protected_messages}")
74+
6475
super().__init__()
6576

6677
self.window_size = window_size
6778
self.should_truncate_results = should_truncate_results
6879
self.per_turn = per_turn
80+
self.protected_messages = protected_messages
6981
self._model_call_count = 0
7082

7183
def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None:
@@ -160,6 +172,10 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A
160172
- toolResult with no corresponding toolUse
161173
- toolUse with no corresponding toolResult
162174
175+
When ``protected_messages`` is set, the first N messages are preserved and
176+
re-inserted after trimming so that critical context (e.g. the initial task
177+
prompt) is never lost.
178+
163179
Args:
164180
agent: The agent whose messages will be reduced.
165181
This list is modified in-place.
@@ -173,6 +189,11 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A
173189
"""
174190
messages = agent.messages
175191

192+
# Snapshot protected messages before any trimming
193+
protected: list = []
194+
if self.protected_messages > 0 and len(messages) > self.protected_messages:
195+
protected = [msg for msg in messages[: self.protected_messages]]
196+
176197
# Try to truncate the tool result first
177198
oldest_message_idx_with_tool_results = self._find_oldest_message_with_tool_results(messages)
178199
if oldest_message_idx_with_tool_results is not None and self.should_truncate_results:
@@ -188,6 +209,10 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A
188209
# If the number of messages is less than or equal to window_size, default the trim index to 2; otherwise, trim down to window_size
189210
trim_index = 2 if len(messages) <= self.window_size else len(messages) - self.window_size
190211

212+
# Never trim into the protected region
213+
if trim_index < self.protected_messages:
214+
trim_index = self.protected_messages
215+
191216
# Find the next valid trim point that:
192217
# 1. Starts with a user message (required by most model providers)
193218
# 2. Does not start with an orphaned toolResult
@@ -256,6 +281,18 @@ def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: A
256281
# Overwrite message history
257282
messages[:] = messages[trim_index:]
258283

284+
# Re-insert protected messages that were trimmed away
285+
if protected:
286+
# Check which protected messages are no longer present (list membership compares by equality, not identity)
287+
reinsert = [msg for msg in protected if msg not in messages]
288+
if reinsert:
289+
messages[:0] = reinsert
290+
logger.info(
291+
"protected_messages=<%d> | re-inserted %d protected message(s) after trim",
292+
self.protected_messages,
293+
len(reinsert),
294+
)
295+
259296
def _truncate_tool_results(self, messages: Messages, msg_idx: int) -> bool:
260297
"""Truncate tool results and replace image blocks in a message to reduce context size.
261298

tests/strands/agent/test_conversation_manager.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,3 +703,107 @@ def test_boundary_text_in_tool_result_not_truncated():
703703

704704
assert not changed
705705
assert messages[0]["content"][0]["toolResult"]["content"][0]["text"] == boundary_text
706+
707+
708+
# ── protected_messages tests ──────────────────────────────────────────────
709+
710+
711+
def test_protected_messages_negative_raises():
712+
"""protected_messages must be non-negative."""
713+
with pytest.raises(ValueError, match="non-negative"):
714+
SlidingWindowConversationManager(protected_messages=-1)
715+
716+
717+
def test_protected_messages_zero_is_default():
718+
"""protected_messages defaults to 0 (no protection) when not specified."""
719+
manager = SlidingWindowConversationManager(window_size=2, should_truncate_results=False)
720+
assert manager.protected_messages == 0
721+
722+
723+
def test_protected_messages_preserves_first_message_on_trim():
724+
"""When protected_messages=1, the first user message survives trimming."""
725+
manager = SlidingWindowConversationManager(window_size=2, should_truncate_results=False, protected_messages=1)
726+
agent = MagicMock()
727+
agent.messages = [
728+
{"role": "user", "content": [{"text": "Generate the report"}]},
729+
{"role": "assistant", "content": [{"text": "Step 1"}]},
730+
{"role": "user", "content": [{"text": "Follow-up"}]},
731+
{"role": "assistant", "content": [{"text": "Step 2"}]},
732+
{"role": "user", "content": [{"text": "Another question"}]},
733+
]
734+
735+
manager.apply_management(agent)
736+
737+
# The first message must still be present
738+
assert agent.messages[0]["content"][0]["text"] == "Generate the report"
739+
# And the conversation should end with the most recent messages
740+
assert agent.messages[-1]["content"][0]["text"] == "Another question"
741+
742+
743+
def test_protected_messages_preserves_first_message_on_overflow():
744+
"""protected_messages=1 preserves the prompt even during context overflow (reduce_context with e)."""
745+
manager = SlidingWindowConversationManager(window_size=2, should_truncate_results=False, protected_messages=1)
746+
agent = MagicMock()
747+
agent.messages = [
748+
{"role": "user", "content": [{"text": "Task prompt"}]},
749+
{"role": "assistant", "content": [{"text": "Calling tools"}]},
750+
{"role": "user", "content": [{"text": "Tool results"}]},
751+
{"role": "assistant", "content": [{"text": "More work"}]},
752+
{"role": "user", "content": [{"text": "More results"}]},
753+
]
754+
755+
manager.reduce_context(agent, e=RuntimeError("context overflow"))
756+
757+
assert agent.messages[0]["content"][0]["text"] == "Task prompt"
758+
759+
760+
def test_protected_messages_multiple():
761+
"""protected_messages=2 preserves the first two messages."""
762+
manager = SlidingWindowConversationManager(window_size=2, should_truncate_results=False, protected_messages=2)
763+
agent = MagicMock()
764+
agent.messages = [
765+
{"role": "user", "content": [{"text": "System context"}]},
766+
{"role": "assistant", "content": [{"text": "Acknowledged"}]},
767+
{"role": "user", "content": [{"text": "Question 1"}]},
768+
{"role": "assistant", "content": [{"text": "Answer 1"}]},
769+
{"role": "user", "content": [{"text": "Question 2"}]},
770+
]
771+
772+
manager.apply_management(agent)
773+
774+
assert agent.messages[0]["content"][0]["text"] == "System context"
775+
assert agent.messages[1]["content"][0]["text"] == "Acknowledged"
776+
777+
778+
def test_protected_messages_no_trim_needed():
779+
"""When messages fit in the window, protected_messages has no effect."""
780+
manager = SlidingWindowConversationManager(window_size=10, should_truncate_results=False, protected_messages=1)
781+
agent = MagicMock()
782+
agent.messages = [
783+
{"role": "user", "content": [{"text": "Hello"}]},
784+
{"role": "assistant", "content": [{"text": "Hi"}]},
785+
]
786+
787+
manager.apply_management(agent)
788+
789+
assert len(agent.messages) == 2
790+
791+
792+
def test_protected_messages_trim_index_skips_protected_region():
793+
"""The trim index must never fall within the protected region."""
794+
manager = SlidingWindowConversationManager(window_size=3, should_truncate_results=False, protected_messages=1)
795+
agent = MagicMock()
796+
# 5 messages, window_size=3 → trim_index starts at 2
797+
# But protected_messages=1 means index 0 is protected
798+
agent.messages = [
799+
{"role": "user", "content": [{"text": "Important prompt"}]},
800+
{"role": "assistant", "content": [{"text": "Response 1"}]},
801+
{"role": "user", "content": [{"text": "Q2"}]},
802+
{"role": "assistant", "content": [{"text": "Response 2"}]},
803+
{"role": "user", "content": [{"text": "Q3"}]},
804+
]
805+
806+
manager.apply_management(agent)
807+
808+
# First message must survive
809+
assert agent.messages[0]["content"][0]["text"] == "Important prompt"

0 commit comments

Comments
 (0)