From 250cd557b4ecdea82846af54ad9847f4124f3980 Mon Sep 17 00:00:00 2001 From: Vincent Bai Date: Wed, 29 Apr 2026 16:10:22 -0700 Subject: [PATCH] feat(hooks): emit Before/AfterReduceContextEvent from ConversationManager Move event emission from the single overflow call site in agent.py into ConversationManager.reduce_context as a concrete template method. Subclasses implement _reduce_context; the framework wraps it with Before/After event emission so every call path (reactive overflow, proactive apply_management, per-turn, direct tool calls) gets events automatically. An __init_subclass__ shim detects third-party subclasses that override reduce_context directly, transparently re-wires them to _reduce_context, and emits a DeprecationWarning. Closes #2048 --- .../conversation_manager.py | 63 ++++++++- .../null_conversation_manager.py | 2 +- .../sliding_window_conversation_manager.py | 2 +- .../summarizing_conversation_manager.py | 2 +- src/strands/hooks/__init__.py | 4 + src/strands/hooks/events.py | 52 ++++++++ tests/strands/agent/test_agent.py | 80 +++++++++++- .../agent/test_conversation_manager.py | 122 +++++++++++++++++- .../test_summarizing_conversation_manager.py | 2 + tests/strands/hooks/test_events.py | 61 ++++++++- 10 files changed, 379 insertions(+), 11 deletions(-) diff --git a/src/strands/agent/conversation_manager/conversation_manager.py b/src/strands/agent/conversation_manager/conversation_manager.py index 690ecbde5..c6be932eb 100644 --- a/src/strands/agent/conversation_manager/conversation_manager.py +++ b/src/strands/agent/conversation_manager/conversation_manager.py @@ -1,14 +1,19 @@ """Abstract interface for conversation history management.""" +import logging +import warnings from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any +from ...hooks.events import AfterReduceContextEvent, BeforeReduceContextEvent from ...hooks.registry import HookProvider, HookRegistry from ...types.content import Message if TYPE_CHECKING: from 
...agent.agent import Agent +logger = logging.getLogger(__name__) + class ConversationManager(ABC, HookProvider): """Abstract base class for managing conversation history. @@ -24,6 +29,10 @@ class ConversationManager(ABC, HookProvider): lifecycle events. Derived classes that override register_hooks must call the base implementation to ensure proper hook registration. + Subclasses should override ``_reduce_context`` (not ``reduce_context``) to implement their reduction strategy. + The framework wraps ``_reduce_context`` with ``BeforeReduceContextEvent`` / ``AfterReduceContextEvent`` emission + automatically. + Example: ```python class MyConversationManager(ConversationManager): @@ -33,6 +42,21 @@ def register_hooks(self, registry: HookRegistry, **kwargs: Any) -> None: ``` """ + def __init_subclass__(cls, **kwargs: Any) -> None: + """Detect legacy subclasses that override reduce_context directly and re-wire them.""" + super().__init_subclass__(**kwargs) + if "reduce_context" in cls.__dict__ and "_reduce_context" not in cls.__dict__: + warnings.warn( + f"{cls.__name__} overrides reduce_context() directly. " + f"This still works but the recommended pattern is to override _reduce_context(). " + f"Before/AfterReduceContextEvent will continue to fire because the framework " + f"wraps the override transparently.", + DeprecationWarning, + stacklevel=3, # skip the ABCMeta.__new__ frame so the warning points at the subclass definition + ) + cls._reduce_context = cls.__dict__["reduce_context"] # type: ignore[attr-defined] + del cls.reduce_context + def __init__(self) -> None: """Initialize the ConversationManager. @@ -97,12 +121,43 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: """ pass - @abstractmethod def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: - """Called when the model's context window is exceeded. + """Reduce the conversation context, emitting Before/After hook events automatically. + + This is a concrete template method. Subclasses should override ``_reduce_context`` instead. 
+ + Args: + agent: The agent whose conversation history will be reduced. + e: The exception that triggered the context reduction, if any. + **kwargs: Additional keyword arguments for future extensibility. + """ + message_count_before = len(agent.messages) + agent.hooks.invoke_callbacks( + BeforeReduceContextEvent( + agent=agent, + exception=e, + message_count=message_count_before, + ) + ) + self._reduce_context(agent, e=e, **kwargs) + message_count_after = len(agent.messages) + agent.hooks.invoke_callbacks( + AfterReduceContextEvent( + agent=agent, + exception=e, + messages_removed=message_count_before - message_count_after, + message_count_before=message_count_before, + message_count_after=message_count_after, + ) + ) + + @abstractmethod + def _reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: + """Subclass implementation of context reduction. - This method should implement the specific strategy for reducing the window size when a context overflow occurs. - It is typically called after a ContextWindowOverflowException is caught. + Called by the framework via ``reduce_context()``. Subclasses should not emit + ``BeforeReduceContextEvent`` / ``AfterReduceContextEvent`` themselves — the + framework does that automatically. Implementations might use strategies such as: diff --git a/src/strands/agent/conversation_manager/null_conversation_manager.py b/src/strands/agent/conversation_manager/null_conversation_manager.py index 11632525d..046c79479 100644 --- a/src/strands/agent/conversation_manager/null_conversation_manager.py +++ b/src/strands/agent/conversation_manager/null_conversation_manager.py @@ -28,7 +28,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: """ pass - def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: + def _reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Does not reduce context and raises an exception. 
Args: diff --git a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py index 1b45dd42c..3c01fe3c5 100644 --- a/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py +++ b/src/strands/agent/conversation_manager/sliding_window_conversation_manager.py @@ -155,7 +155,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: return self.reduce_context(agent) - def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: + def _reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Trim the oldest messages to reduce the conversation context size. The method handles special cases where trimming the messages leads to: diff --git a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py index abd4d08b5..0e482d003 100644 --- a/src/strands/agent/conversation_manager/summarizing_conversation_manager.py +++ b/src/strands/agent/conversation_manager/summarizing_conversation_manager.py @@ -123,7 +123,7 @@ def apply_management(self, agent: "Agent", **kwargs: Any) -> None: # No proactive management - summarization only happens on context overflow pass - def reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: + def _reduce_context(self, agent: "Agent", e: Exception | None = None, **kwargs: Any) -> None: """Reduce context using summarization. 
Args: diff --git a/src/strands/hooks/__init__.py b/src/strands/hooks/__init__.py index 96c7f577b..d141f56f4 100644 --- a/src/strands/hooks/__init__.py +++ b/src/strands/hooks/__init__.py @@ -35,12 +35,14 @@ def log_end(self, event: AfterInvocationEvent) -> None: # Multiagent hook events AfterMultiAgentInvocationEvent, AfterNodeCallEvent, + AfterReduceContextEvent, AfterToolCallEvent, AgentInitializedEvent, BeforeInvocationEvent, BeforeModelCallEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, + BeforeReduceContextEvent, BeforeToolCallEvent, MessageAddedEvent, MultiAgentInitializedEvent, @@ -54,6 +56,8 @@ def log_end(self, event: AfterInvocationEvent) -> None: "AfterToolCallEvent", "BeforeModelCallEvent", "AfterModelCallEvent", + "BeforeReduceContextEvent", + "AfterReduceContextEvent", "AfterInvocationEvent", "MessageAddedEvent", "HookEvent", diff --git a/src/strands/hooks/events.py b/src/strands/hooks/events.py index 80b50770a..f0846fd3f 100644 --- a/src/strands/hooks/events.py +++ b/src/strands/hooks/events.py @@ -308,6 +308,58 @@ def should_reverse_callbacks(self) -> bool: return True +@dataclass +class BeforeReduceContextEvent(HookEvent): + """Event triggered before the agent's conversation context is reduced. + + Fired whenever ``reduce_context()`` is about to run, regardless of trigger + (reactive overflow exception, proactive sliding-window overflow, or any + third-party manager's own logic). + + Attributes: + exception: The exception that triggered context reduction, if any. + ``None`` when reduction was triggered proactively (e.g. sliding-window overflow). + message_count: The number of messages in the agent's conversation history + immediately before reduction is applied. + """ + + exception: Exception | None = None + message_count: int = 0 + + +@dataclass +class AfterReduceContextEvent(HookEvent): + """Event triggered after the agent's conversation context has been reduced. + + Fired only on successful completion of ``reduce_context()``. 
If the underlying + reduction raises, no ``AfterReduceContextEvent`` is emitted; the exception + propagates normally. + + Note: This event uses reverse callback ordering, meaning callbacks registered + later will be invoked first during cleanup. + + Attributes: + exception: The exception that triggered context reduction, if any. + ``None`` for proactive reductions. + messages_removed: Number of messages removed during this reduction + (``message_count_before - message_count_after``). + message_count_before: Number of messages in the conversation history before + reduction was applied. + message_count_after: Number of messages in the conversation history after + reduction completed. + """ + + exception: Exception | None = None + messages_removed: int = 0 + message_count_before: int = 0 + message_count_after: int = 0 + + @property + def should_reverse_callbacks(self) -> bool: + """True to invoke callbacks in reverse order.""" + return True + + # Multiagent hook events start here @dataclass class MultiAgentInitializedEvent(BaseHookEvent): diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 680a1d23c..3ca923dc3 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -21,7 +21,13 @@ from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.state import AgentState from strands.handlers.callback_handler import PrintingCallbackHandler, null_callback_handler -from strands.hooks import BeforeInvocationEvent, BeforeModelCallEvent, BeforeToolCallEvent +from strands.hooks import ( + AfterReduceContextEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeReduceContextEvent, + BeforeToolCallEvent, +) from strands.interrupt import Interrupt from strands.models.bedrock import DEFAULT_BEDROCK_MODEL_ID, BedrockModel from strands.session.repository_session_manager import RepositorySessionManager @@ -529,6 +535,78 @@ def 
test_agent__call__retry_with_reduced_context(mock_model, agent, tool, agener assert conversation_manager_spy.apply_management.call_count == 1 +def test_agent__call__emits_reduce_context_events(mock_model, agent, agenerator): + """Verify Before/AfterReduceContextEvent fire with correct metadata when overflow triggers reduction.""" + messages: Messages = [ + {"role": "user", "content": [{"text": "Hello!"}]}, + {"role": "assistant", "content": [{"text": "Hi!"}]}, + {"role": "user", "content": [{"text": "Whats your favorite color?"}]}, + {"role": "assistant", "content": [{"text": "Blue!"}]}, + ] + agent.messages = messages + + before_events: list[BeforeReduceContextEvent] = [] + after_events: list[AfterReduceContextEvent] = [] + + agent.hooks.add_callback(BeforeReduceContextEvent, before_events.append) + agent.hooks.add_callback(AfterReduceContextEvent, after_events.append) + + trigger_exception = ContextWindowOverflowException(RuntimeError("Input is too long for requested model")) + mock_model.mock_stream.side_effect = [ + trigger_exception, + agenerator( + [ + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "Green!"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ), + ] + + agent("And now?") + + assert len(before_events) == 1 + assert len(after_events) == 1 + + before_event = before_events[0] + assert before_event.agent is agent + assert before_event.exception is trigger_exception + # Before reduction runs, the prompt "And now?" has already been appended to messages (5 total). 
+ assert before_event.message_count == 5 + + after_event = after_events[0] + assert after_event.agent is agent + assert after_event.exception is trigger_exception + assert after_event.message_count_before == 5 + assert after_event.message_count_after < after_event.message_count_before + assert after_event.messages_removed == after_event.message_count_before - after_event.message_count_after + assert after_event.messages_removed > 0 + + +def test_agent__call__no_reduce_context_events_on_success(mock_model, agent, agenerator): + """Verify reduce-context events are NOT fired on a normal successful invocation.""" + before_events: list[BeforeReduceContextEvent] = [] + after_events: list[AfterReduceContextEvent] = [] + + agent.hooks.add_callback(BeforeReduceContextEvent, before_events.append) + agent.hooks.add_callback(AfterReduceContextEvent, after_events.append) + + mock_model.mock_stream.return_value = agenerator( + [ + {"contentBlockStart": {"start": {}}}, + {"contentBlockDelta": {"delta": {"text": "ok"}}}, + {"contentBlockStop": {}}, + {"messageStop": {"stopReason": "end_turn"}}, + ] + ) + + agent("Hello?") + + assert before_events == [] + assert after_events == [] + + def test_agent__call__always_sliding_window_conversation_manager_doesnt_infinite_loop(mock_model, agent, tool): conversation_manager = SlidingWindowConversationManager(window_size=500, should_truncate_results=False) conversation_manager_spy = unittest.mock.Mock(wraps=conversation_manager) diff --git a/tests/strands/agent/test_conversation_manager.py b/tests/strands/agent/test_conversation_manager.py index 8679e6fd7..c0d92fd01 100644 --- a/tests/strands/agent/test_conversation_manager.py +++ b/tests/strands/agent/test_conversation_manager.py @@ -323,7 +323,7 @@ def test_null_conversation_manager_reduce_context_raises_context_window_overflow manager.apply_management(test_agent) with pytest.raises(ContextWindowOverflowException): - manager.reduce_context(messages) + manager.reduce_context(test_agent) 
assert messages == original_messages @@ -341,7 +341,7 @@ def test_null_conversation_manager_reduce_context_with_exception_raises_same_exc manager.apply_management(test_agent) with pytest.raises(RuntimeError): - manager.reduce_context(messages, RuntimeError("test")) + manager.reduce_context(test_agent, RuntimeError("test")) assert messages == original_messages @@ -755,3 +755,121 @@ def test_window_size_zero_clears_on_overflow(): manager.reduce_context(test_agent, e=Exception("overflow")) assert messages == [] + + +# --- Template-method and backward-compatibility tests --- + + +def test_template_method_emits_events(): + """Subclass overriding _reduce_context gets Before/After events automatically.""" + from strands.hooks import AfterReduceContextEvent, BeforeReduceContextEvent + from strands.agent.conversation_manager.conversation_manager import ConversationManager + + class CustomManager(ConversationManager): + def apply_management(self, agent, **kwargs): + pass + + def _reduce_context(self, agent, e=None, **kwargs): + agent.messages[:] = agent.messages[-1:] + + manager = CustomManager() + messages = [ + {"role": "user", "content": [{"text": "a"}]}, + {"role": "assistant", "content": [{"text": "b"}]}, + {"role": "user", "content": [{"text": "c"}]}, + ] + test_agent = Agent(messages=messages) + + before_events: list[BeforeReduceContextEvent] = [] + after_events: list[AfterReduceContextEvent] = [] + test_agent.hooks.add_callback(BeforeReduceContextEvent, before_events.append) + test_agent.hooks.add_callback(AfterReduceContextEvent, after_events.append) + + manager.reduce_context(test_agent) + + assert len(before_events) == 1 + assert len(after_events) == 1 + assert before_events[0].message_count == 3 + assert after_events[0].messages_removed == 2 + assert after_events[0].message_count_before == 3 + assert after_events[0].message_count_after == 1 + + +def test_legacy_override_emits_events_with_warning(): + """Third-party subclass overriding reduce_context directly 
still gets events + DeprecationWarning.""" + import warnings + + from strands.hooks import AfterReduceContextEvent, BeforeReduceContextEvent + from strands.agent.conversation_manager.conversation_manager import ConversationManager + + with warnings.catch_warnings(record=True) as caught: + warnings.simplefilter("always") + + class LegacyManager(ConversationManager): + def apply_management(self, agent, **kwargs): + pass + + def reduce_context(self, agent, e=None, **kwargs): + agent.messages[:] = agent.messages[1:] + + assert any(issubclass(w.category, DeprecationWarning) and "LegacyManager" in str(w.message) for w in caught) + + manager = LegacyManager() + messages = [ + {"role": "user", "content": [{"text": "a"}]}, + {"role": "user", "content": [{"text": "b"}]}, + ] + test_agent = Agent(messages=messages) + + after_events: list[AfterReduceContextEvent] = [] + test_agent.hooks.add_callback(AfterReduceContextEvent, after_events.append) + + manager.reduce_context(test_agent) + + assert len(after_events) == 1 + assert after_events[0].messages_removed == 1 + + +def test_null_manager_before_fires_after_does_not(): + """NullConversationManager raises; BeforeReduceContextEvent fires but AfterReduceContextEvent does not.""" + from strands.hooks import AfterReduceContextEvent, BeforeReduceContextEvent + + manager = NullConversationManager() + test_agent = Agent(messages=[{"role": "user", "content": [{"text": "hi"}]}]) + + before_events: list[BeforeReduceContextEvent] = [] + after_events: list[AfterReduceContextEvent] = [] + test_agent.hooks.add_callback(BeforeReduceContextEvent, before_events.append) + test_agent.hooks.add_callback(AfterReduceContextEvent, after_events.append) + + with pytest.raises(ContextWindowOverflowException): + manager.reduce_context(test_agent) + + assert len(before_events) == 1 + assert len(after_events) == 0 + + +def test_proactive_sliding_window_emits_events(): + """SlidingWindowConversationManager.apply_management calls reduce_context which emits 
events.""" + from strands.hooks import AfterReduceContextEvent, BeforeReduceContextEvent + + manager = SlidingWindowConversationManager(window_size=2, should_truncate_results=False) + messages = [ + {"role": "user", "content": [{"text": "1"}]}, + {"role": "assistant", "content": [{"text": "2"}]}, + {"role": "user", "content": [{"text": "3"}]}, + ] + test_agent = Agent(messages=messages) + + before_events: list[BeforeReduceContextEvent] = [] + after_events: list[AfterReduceContextEvent] = [] + test_agent.hooks.add_callback(BeforeReduceContextEvent, before_events.append) + test_agent.hooks.add_callback(AfterReduceContextEvent, after_events.append) + + manager.apply_management(test_agent) + + assert len(before_events) == 1 + assert len(after_events) == 1 + assert before_events[0].exception is None + assert after_events[0].exception is None + assert after_events[0].messages_removed > 0 \ No newline at end of file diff --git a/tests/strands/agent/test_summarizing_conversation_manager.py b/tests/strands/agent/test_summarizing_conversation_manager.py index c49c69de6..d334debcb 100644 --- a/tests/strands/agent/test_summarizing_conversation_manager.py +++ b/tests/strands/agent/test_summarizing_conversation_manager.py @@ -8,6 +8,7 @@ DEFAULT_SUMMARIZATION_PROMPT, SummarizingConversationManager, ) +from strands.hooks.registry import HookRegistry from strands.types.content import Messages from strands.types.exceptions import ContextWindowOverflowException from tests.fixtures.mocked_model_provider import MockedModelProvider @@ -51,6 +52,7 @@ def __init__(self, summary_response="This is a summary of the conversation."): self.tool_registry = Mock() self.tool_names = [] self._default_structured_output_model = None + self.hooks = HookRegistry() def __call__(self, prompt): """Mock agent call that returns a summary.""" diff --git a/tests/strands/hooks/test_events.py b/tests/strands/hooks/test_events.py index 90ab205a9..616d38559 100644 --- a/tests/strands/hooks/test_events.py +++ 
b/tests/strands/hooks/test_events.py @@ -1,4 +1,4 @@ -"""Tests for multi-agent execution lifecycle events.""" +"""Tests for agent and multi-agent execution lifecycle events.""" from unittest.mock import Mock @@ -7,9 +7,12 @@ from strands.hooks import ( AfterMultiAgentInvocationEvent, AfterNodeCallEvent, + AfterReduceContextEvent, BaseHookEvent, BeforeMultiAgentInvocationEvent, BeforeNodeCallEvent, + BeforeReduceContextEvent, + HookEvent, MultiAgentInitializedEvent, ) @@ -105,3 +108,59 @@ def test_after_events_should_reverse_callbacks(orchestrator): assert after_node_event.should_reverse_callbacks is True assert after_invocation_event.should_reverse_callbacks is True + + +@pytest.fixture +def agent(): + """Mock agent for testing.""" + return Mock() + + +def test_before_reduce_context_event_defaults(agent): + """BeforeReduceContextEvent has sensible defaults and inherits from HookEvent.""" + event = BeforeReduceContextEvent(agent=agent) + + assert event.agent is agent + assert event.exception is None + assert event.message_count == 0 + assert event.should_reverse_callbacks is False + assert isinstance(event, HookEvent) + + +def test_before_reduce_context_event_with_fields(agent): + """BeforeReduceContextEvent carries the trigger exception and message count.""" + exc = RuntimeError("overflow") + event = BeforeReduceContextEvent(agent=agent, exception=exc, message_count=12) + + assert event.exception is exc + assert event.message_count == 12 + + +def test_after_reduce_context_event_defaults(agent): + """AfterReduceContextEvent has sensible defaults and runs callbacks in reverse.""" + event = AfterReduceContextEvent(agent=agent) + + assert event.agent is agent + assert event.exception is None + assert event.messages_removed == 0 + assert event.message_count_before == 0 + assert event.message_count_after == 0 + assert event.should_reverse_callbacks is True + assert isinstance(event, HookEvent) + + +def test_after_reduce_context_event_with_fields(agent): + 
"""AfterReduceContextEvent carries before/after counts and the original exception.""" + exc = RuntimeError("overflow") + event = AfterReduceContextEvent( + agent=agent, + exception=exc, + messages_removed=3, + message_count_before=10, + message_count_after=7, + ) + + assert event.exception is exc + assert event.messages_removed == 3 + assert event.message_count_before == 10 + assert event.message_count_after == 7