Skip to content

Commit bf1b7aa

Browse files
authored
feat: add per_turn parameter to SlidingWindowConversationManager (#1374)
Allow conversation managers to act as hook providers and add an option to built-in conversation managers to proactively apply message management during the agent loop execution. Use that functionality to add an option to SlidingWindowConversationManager to allow per_turn management application Fixes #509 --------- Co-authored-by: Mackenzie Zastrow <zastrowm@users.noreply.github.com>
1 parent 033574b commit bf1b7aa

5 files changed

Lines changed: 297 additions & 3 deletions

File tree

src/strands/agent/agent.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@ def __init__(
250250
if self._session_manager:
251251
self.hooks.add_hook(self._session_manager)
252252

253+
# Allow conversation_managers to subscribe to hooks
254+
self.hooks.add_hook(self.conversation_manager)
255+
253256
self.tool_executor = tool_executor or ConcurrentToolExecutor()
254257

255258
if hooks:

src/strands/agent/conversation_manager/conversation_manager.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33
from abc import ABC, abstractmethod
44
from typing import TYPE_CHECKING, Any, Optional
55

6+
from ...hooks.registry import HookProvider, HookRegistry
67
from ...types.content import Message
78

89
if TYPE_CHECKING:
910
from ...agent.agent import Agent
1011

1112

12-
class ConversationManager(ABC):
13+
class ConversationManager(ABC, HookProvider):
1314
"""Abstract base class for managing conversation history.
1415
1516
This class provides an interface for implementing conversation management strategies to control the size of message
@@ -18,6 +19,18 @@ class ConversationManager(ABC):
1819
- Manage memory usage
1920
- Control context length
2021
- Maintain relevant conversation state
22+
23+
ConversationManager implements the HookProvider protocol, allowing derived classes to register hooks for agent
24+
lifecycle events. Derived classes that override register_hooks must call the base implementation to ensure proper
25+
hook registration.
26+
27+
Example:
28+
```python
29+
class MyConversationManager(ConversationManager):
30+
def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
31+
super().register_hooks(registry, **kwargs)
32+
# Register additional hooks here
33+
```
2134
"""
2235

2336
def __init__(self) -> None:
@@ -30,6 +43,25 @@ def __init__(self) -> None:
3043
"""
3144
self.removed_message_count = 0
3245

46+
def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
47+
"""Register hooks for agent lifecycle events.
48+
49+
Derived classes that override this method must call the base implementation to ensure proper hook
50+
registration chain.
51+
52+
Args:
53+
registry: The hook registry to register callbacks with.
54+
**kwargs: Additional keyword arguments for future extensibility.
55+
56+
Example:
57+
```python
58+
def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None:
59+
super().register_hooks(registry, **kwargs)
60+
registry.add_callback(SomeEvent, self.on_some_event)
61+
```
62+
"""
63+
pass
64+
3365
def restore_from_session(self, state: dict[str, Any]) -> Optional[list[Message]]:
3466
"""Restore the Conversation Manager's state from a session.
3567

src/strands/agent/conversation_manager/sliding_window_conversation_manager.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
if TYPE_CHECKING:
77
from ...agent.agent import Agent
88

9+
from ...hooks import BeforeModelCallEvent, HookRegistry
910
from ...types.content import Messages
1011
from ...types.exceptions import ContextWindowOverflowException
1112
from .conversation_manager import ConversationManager
@@ -18,19 +19,102 @@ class SlidingWindowConversationManager(ConversationManager):
1819
1920
This class handles the logic of maintaining a conversation window that preserves tool usage pairs and avoids
2021
invalid window states.
22+
23+
Supports proactive management during agent loop execution via the per_turn parameter.
2124
"""
2225

23-
def __init__(self, window_size: int = 40, should_truncate_results: bool = True):
26+
def __init__(self, window_size: int = 40, should_truncate_results: bool = True, *, per_turn: bool | int = False):
2427
"""Initialize the sliding window conversation manager.
2528
2629
Args:
2730
window_size: Maximum number of messages to keep in the agent's history.
2831
Defaults to 40 messages.
2932
should_truncate_results: Truncate tool results when a message is too large for the model's context window
33+
per_turn: Controls when to apply message management during agent execution.
34+
- False (default): Only apply management at the end (default behavior)
35+
- True: Apply management before every model call
36+
- int (e.g., 3): Apply management before every N model calls
37+
38+
When to use per_turn: If your agent performs many tool operations in loops
39+
(e.g., web browsing with frequent screenshots), enable per_turn to proactively
40+
manage message history and prevent the agent loop from slowing down. Start with
41+
per_turn=True and adjust to a specific frequency (e.g., per_turn=5) if needed
42+
for performance tuning.
43+
44+
Raises:
45+
ValueError: If per_turn is 0 or a negative integer.
3046
"""
3147
super().__init__()
48+
3249
self.window_size = window_size
3350
self.should_truncate_results = should_truncate_results
51+
self.per_turn = per_turn
52+
self._model_call_count = 0
53+
54+
def register_hooks(self, registry: "HookRegistry", **kwargs: Any) -> None:
55+
"""Register hook callbacks for per-turn conversation management.
56+
57+
Args:
58+
registry: The hook registry to register callbacks with.
59+
**kwargs: Additional keyword arguments for future extensibility.
60+
"""
61+
super().register_hooks(registry, **kwargs)
62+
63+
# Always register the callback - per_turn check happens in the callback
64+
registry.add_callback(BeforeModelCallEvent, self._on_before_model_call)
65+
66+
def _on_before_model_call(self, event: BeforeModelCallEvent) -> None:
67+
"""Handle before model call event for per-turn management.
68+
69+
This callback is invoked before each model call. It tracks the model call count and applies message management
70+
based on the per_turn configuration.
71+
72+
Args:
73+
event: The before model call event containing the agent and model execution details.
74+
"""
75+
# Check if per_turn is enabled
76+
if self.per_turn is False:
77+
return
78+
79+
self._model_call_count += 1
80+
81+
# Determine if we should apply management
82+
should_apply = False
83+
if self.per_turn is True:
84+
should_apply = True
85+
elif isinstance(self.per_turn, int) and self.per_turn > 0:
86+
should_apply = self._model_call_count % self.per_turn == 0
87+
88+
if should_apply:
89+
logger.debug(
90+
"model_call_count=<%d>, per_turn=<%s> | applying per-turn conversation management",
91+
self._model_call_count,
92+
self.per_turn,
93+
)
94+
self.apply_management(event.agent)
95+
96+
def get_state(self) -> dict[str, Any]:
97+
"""Get the current state of the conversation manager.
98+
99+
Returns:
100+
Dictionary containing the manager's state, including model call count for per-turn tracking.
101+
"""
102+
state = super().get_state()
103+
state["model_call_count"] = self._model_call_count
104+
return state
105+
106+
def restore_from_session(self, state: dict[str, Any]) -> Optional[list]:
107+
"""Restore the conversation manager's state from a session.
108+
109+
Args:
110+
state: Previous state of the conversation manager
111+
112+
Returns:
113+
Optional list of messages to prepend to the agent's messages.
114+
"""
115+
result = super().restore_from_session(state)
116+
self._model_call_count = state.get("model_call_count", 0)
117+
return result
34118

35119
def apply_management(self, agent: "Agent", **kwargs: Any) -> None:
36120
"""Apply the sliding window to the agent's messages array to maintain a manageable history size.

src/strands/hooks/registry.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import inspect
1111
import logging
1212
from dataclasses import dataclass
13-
from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar
13+
from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar, runtime_checkable
1414

1515
from ..interrupt import Interrupt, InterruptException
1616

@@ -84,6 +84,7 @@ class HookEvent(BaseHookEvent):
8484
"""Generic for invoking events - non-contravariant to enable returning events."""
8585

8686

87+
@runtime_checkable
8788
class HookProvider(Protocol):
8889
"""Protocol for objects that provide hook callbacks to an agent.
8990

tests/strands/agent/test_conversation_manager.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
1+
from unittest.mock import MagicMock, patch
2+
13
import pytest
24

5+
from strands import tool
36
from strands.agent.agent import Agent
47
from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager
58
from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager
9+
from strands.hooks.events import BeforeModelCallEvent
10+
from strands.hooks.registry import HookProvider, HookRegistry
611
from strands.types.exceptions import ContextWindowOverflowException
12+
from tests.fixtures.mocked_model_provider import MockedModelProvider
713

814

915
@pytest.fixture
@@ -246,3 +252,171 @@ def test_null_conversation_does_not_restore_with_incorrect_state():
246252

247253
with pytest.raises(ValueError):
248254
manager.restore_from_session({})
255+
256+
257+
# ==============================================================================
258+
# Per-Turn Management Tests
259+
# ==============================================================================
260+
261+
262+
def test_per_turn_parameter_validation():
263+
"""Test per_turn parameter validation."""
264+
# Valid values
265+
assert SlidingWindowConversationManager(per_turn=False).per_turn is False
266+
assert SlidingWindowConversationManager(per_turn=True).per_turn is True
267+
assert SlidingWindowConversationManager(per_turn=3).per_turn == 3
268+
269+
270+
def test_conversation_manager_is_hook_provider():
271+
"""Test that ConversationManager implements HookProvider protocol."""
272+
manager = NullConversationManager()
273+
assert isinstance(manager, HookProvider)
274+
275+
276+
def test_derived_class_does_not_need_to_implement_register_hooks():
277+
"""Test that derived classes don't need to override register_hooks for backwards compatibility."""
278+
from strands.agent.conversation_manager.conversation_manager import ConversationManager
279+
280+
class MinimalConversationManager(ConversationManager):
281+
"""A minimal implementation that only implements abstract methods."""
282+
283+
def apply_management(self, agent, **kwargs):
284+
pass
285+
286+
def reduce_context(self, agent, e=None, **kwargs):
287+
pass
288+
289+
# Should be able to instantiate without implementing register_hooks
290+
manager = MinimalConversationManager()
291+
registry = HookRegistry()
292+
293+
# Should work without error
294+
manager.register_hooks(registry)
295+
assert not registry.has_callbacks()
296+
297+
298+
def test_per_turn_hooks_registration():
299+
"""Test that hooks are registered when conversation_manager implements HookProvider."""
300+
manager = SlidingWindowConversationManager(per_turn=True)
301+
assert isinstance(manager, HookProvider)
302+
303+
registry = HookRegistry()
304+
manager.register_hooks(registry)
305+
assert registry.has_callbacks()
306+
307+
308+
def test_per_turn_false_no_management_during_loop():
309+
"""Test that per_turn=False only manages in finally block."""
310+
manager = SlidingWindowConversationManager(per_turn=False, window_size=100)
311+
responses = [{"role": "assistant", "content": [{"text": "Response"}]}] * 3
312+
model = MockedModelProvider(responses)
313+
agent = Agent(model=model, conversation_manager=manager)
314+
315+
with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock:
316+
agent("Test")
317+
# Should only be called once in finally block (per_turn disabled)
318+
assert mock.call_count == 1
319+
320+
321+
def test_per_turn_true_manages_each_model_call():
322+
"""Test that per_turn=True applies management before each model call."""
323+
manager = SlidingWindowConversationManager(per_turn=True, window_size=100)
324+
responses = [{"role": "assistant", "content": [{"text": "Response"}]}] * 3
325+
model = MockedModelProvider(responses)
326+
agent = Agent(model=model, conversation_manager=manager)
327+
328+
with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock:
329+
agent("Test")
330+
# Should be called for each model call + finally block
331+
# With simple text responses, agent makes 1 model call then stops
332+
assert mock.call_count >= 1
333+
334+
335+
def test_per_turn_integer_manages_every_n_calls():
336+
"""Test that per_turn=N applies management every N model calls."""
337+
manager = SlidingWindowConversationManager(per_turn=2, window_size=100)
338+
# Create responses that trigger multiple model calls
339+
responses = [
340+
{"role": "assistant", "content": [{"toolUse": {"toolUseId": f"{i}", "name": "test", "input": {}}}]}
341+
for i in range(5)
342+
] + [{"role": "assistant", "content": [{"text": "Done"}]}]
343+
model = MockedModelProvider(responses)
344+
345+
@tool(name="test")
346+
def test_tool(query: str = "") -> str:
347+
return "result"
348+
349+
agent = Agent(model=model, conversation_manager=manager, tools=[test_tool])
350+
351+
with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock:
352+
agent("Test")
353+
# With 6 model calls and per_turn=2: called on 2nd, 4th, 6th + finally
354+
assert mock.call_count == 4
355+
356+
357+
def test_per_turn_dynamic_change():
358+
"""Test that per_turn can be changed dynamically."""
359+
manager = SlidingWindowConversationManager(per_turn=False)
360+
registry = HookRegistry()
361+
manager.register_hooks(registry)
362+
363+
mock_agent = MagicMock()
364+
mock_agent.messages = []
365+
event = BeforeModelCallEvent(agent=mock_agent)
366+
367+
# Initially disabled
368+
with patch.object(manager, "apply_management") as mock_apply:
369+
registry.invoke_callbacks(event)
370+
assert mock_apply.call_count == 0
371+
372+
# Enable dynamically
373+
manager.per_turn = True
374+
with patch.object(manager, "apply_management") as mock_apply:
375+
registry.invoke_callbacks(event)
376+
assert mock_apply.call_count == 1
377+
378+
379+
def test_per_turn_reduces_message_count():
380+
"""Test that per_turn actually reduces message count during execution."""
381+
manager = SlidingWindowConversationManager(per_turn=1, window_size=4)
382+
responses = [{"role": "assistant", "content": [{"text": f"Response {i}"}]} for i in range(10)]
383+
model = MockedModelProvider(responses)
384+
agent = Agent(model=model, conversation_manager=manager)
385+
386+
message_counts = []
387+
original_apply = manager.apply_management
388+
389+
def track_apply(agent_instance):
390+
message_counts.append(len(agent_instance.messages))
391+
return original_apply(agent_instance)
392+
393+
with patch.object(manager, "apply_management", side_effect=track_apply):
394+
agent("Test")
395+
396+
# Verify message count stayed around window_size
397+
assert any(count <= manager.window_size for count in message_counts)
398+
399+
400+
def test_per_turn_state_persistence():
401+
"""Test that model_call_count is persisted in state."""
402+
manager = SlidingWindowConversationManager(per_turn=3)
403+
manager._model_call_count = 7
404+
405+
state = manager.get_state()
406+
assert state["model_call_count"] == 7
407+
408+
new_manager = SlidingWindowConversationManager(per_turn=3)
409+
new_manager.restore_from_session(state)
410+
assert new_manager._model_call_count == 7
411+
412+
413+
def test_per_turn_backward_compatibility():
414+
"""Test that existing code without per_turn still works."""
415+
manager = SlidingWindowConversationManager(window_size=40)
416+
assert manager.per_turn is False
417+
418+
responses = [{"role": "assistant", "content": [{"text": "Hello"}]}]
419+
model = MockedModelProvider(responses)
420+
agent = Agent(model=model, conversation_manager=manager)
421+
result = agent("Hello")
422+
assert result is not None

0 commit comments

Comments
 (0)