|
| 1 | +from unittest.mock import MagicMock, patch |
| 2 | + |
1 | 3 | import pytest |
2 | 4 |
|
| 5 | +from strands import tool |
3 | 6 | from strands.agent.agent import Agent |
4 | 7 | from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager |
5 | 8 | from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager |
| 9 | +from strands.hooks.events import BeforeModelCallEvent |
| 10 | +from strands.hooks.registry import HookProvider, HookRegistry |
6 | 11 | from strands.types.exceptions import ContextWindowOverflowException |
| 12 | +from tests.fixtures.mocked_model_provider import MockedModelProvider |
7 | 13 |
|
8 | 14 |
|
9 | 15 | @pytest.fixture |
@@ -246,3 +252,171 @@ def test_null_conversation_does_not_restore_with_incorrect_state(): |
246 | 252 |
|
247 | 253 | with pytest.raises(ValueError): |
248 | 254 | manager.restore_from_session({}) |
| 255 | + |
| 256 | + |
| 257 | +# ============================================================================== |
| 258 | +# Per-Turn Management Tests |
| 259 | +# ============================================================================== |
| 260 | + |
| 261 | + |
| 262 | +def test_per_turn_parameter_validation(): |
| 263 | + """Test per_turn parameter validation.""" |
| 264 | + # Valid values |
| 265 | + assert SlidingWindowConversationManager(per_turn=False).per_turn is False |
| 266 | + assert SlidingWindowConversationManager(per_turn=True).per_turn is True |
| 267 | + assert SlidingWindowConversationManager(per_turn=3).per_turn == 3 |
| 268 | + |
| 269 | + |
| 270 | +def test_conversation_manager_is_hook_provider(): |
| 271 | + """Test that ConversationManager implements HookProvider protocol.""" |
| 272 | + manager = NullConversationManager() |
| 273 | + assert isinstance(manager, HookProvider) |
| 274 | + |
| 275 | + |
| 276 | +def test_derived_class_does_not_need_to_implement_register_hooks(): |
| 277 | + """Test that derived classes don't need to override register_hooks for backwards compatibility.""" |
| 278 | + from strands.agent.conversation_manager.conversation_manager import ConversationManager |
| 279 | + |
| 280 | + class MinimalConversationManager(ConversationManager): |
| 281 | + """A minimal implementation that only implements abstract methods.""" |
| 282 | + |
| 283 | + def apply_management(self, agent, **kwargs): |
| 284 | + pass |
| 285 | + |
| 286 | + def reduce_context(self, agent, e=None, **kwargs): |
| 287 | + pass |
| 288 | + |
| 289 | + # Should be able to instantiate without implementing register_hooks |
| 290 | + manager = MinimalConversationManager() |
| 291 | + registry = HookRegistry() |
| 292 | + |
| 293 | + # Should work without error |
| 294 | + manager.register_hooks(registry) |
| 295 | + assert not registry.has_callbacks() |
| 296 | + |
| 297 | + |
| 298 | +def test_per_turn_hooks_registration(): |
| 299 | + """Test that hooks are registered when conversation_manager implements HookProvider.""" |
| 300 | + manager = SlidingWindowConversationManager(per_turn=True) |
| 301 | + assert isinstance(manager, HookProvider) |
| 302 | + |
| 303 | + registry = HookRegistry() |
| 304 | + manager.register_hooks(registry) |
| 305 | + assert registry.has_callbacks() |
| 306 | + |
| 307 | + |
| 308 | +def test_per_turn_false_no_management_during_loop(): |
| 309 | + """Test that per_turn=False only manages in finally block.""" |
| 310 | + manager = SlidingWindowConversationManager(per_turn=False, window_size=100) |
| 311 | + responses = [{"role": "assistant", "content": [{"text": "Response"}]}] * 3 |
| 312 | + model = MockedModelProvider(responses) |
| 313 | + agent = Agent(model=model, conversation_manager=manager) |
| 314 | + |
| 315 | + with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock: |
| 316 | + agent("Test") |
| 317 | + # Should only be called once in finally block (per_turn disabled) |
| 318 | + assert mock.call_count == 1 |
| 319 | + |
| 320 | + |
| 321 | +def test_per_turn_true_manages_each_model_call(): |
| 322 | + """Test that per_turn=True applies management before each model call.""" |
| 323 | + manager = SlidingWindowConversationManager(per_turn=True, window_size=100) |
| 324 | + responses = [{"role": "assistant", "content": [{"text": "Response"}]}] * 3 |
| 325 | + model = MockedModelProvider(responses) |
| 326 | + agent = Agent(model=model, conversation_manager=manager) |
| 327 | + |
| 328 | + with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock: |
| 329 | + agent("Test") |
| 330 | + # Should be called for each model call + finally block |
| 331 | + # With simple text responses, agent makes 1 model call then stops |
| 332 | + assert mock.call_count >= 1 |
| 333 | + |
| 334 | + |
| 335 | +def test_per_turn_integer_manages_every_n_calls(): |
| 336 | + """Test that per_turn=N applies management every N model calls.""" |
| 337 | + manager = SlidingWindowConversationManager(per_turn=2, window_size=100) |
| 338 | + # Create responses that trigger multiple model calls |
| 339 | + responses = [ |
| 340 | + {"role": "assistant", "content": [{"toolUse": {"toolUseId": f"{i}", "name": "test", "input": {}}}]} |
| 341 | + for i in range(5) |
| 342 | + ] + [{"role": "assistant", "content": [{"text": "Done"}]}] |
| 343 | + model = MockedModelProvider(responses) |
| 344 | + |
| 345 | + @tool(name="test") |
| 346 | + def test_tool(query: str = "") -> str: |
| 347 | + return "result" |
| 348 | + |
| 349 | + agent = Agent(model=model, conversation_manager=manager, tools=[test_tool]) |
| 350 | + |
| 351 | + with patch.object(manager, "apply_management", wraps=manager.apply_management) as mock: |
| 352 | + agent("Test") |
| 353 | + # With 6 model calls and per_turn=2: called on 2nd, 4th, 6th + finally |
| 354 | + assert mock.call_count == 4 |
| 355 | + |
| 356 | + |
| 357 | +def test_per_turn_dynamic_change(): |
| 358 | + """Test that per_turn can be changed dynamically.""" |
| 359 | + manager = SlidingWindowConversationManager(per_turn=False) |
| 360 | + registry = HookRegistry() |
| 361 | + manager.register_hooks(registry) |
| 362 | + |
| 363 | + mock_agent = MagicMock() |
| 364 | + mock_agent.messages = [] |
| 365 | + event = BeforeModelCallEvent(agent=mock_agent) |
| 366 | + |
| 367 | + # Initially disabled |
| 368 | + with patch.object(manager, "apply_management") as mock_apply: |
| 369 | + registry.invoke_callbacks(event) |
| 370 | + assert mock_apply.call_count == 0 |
| 371 | + |
| 372 | + # Enable dynamically |
| 373 | + manager.per_turn = True |
| 374 | + with patch.object(manager, "apply_management") as mock_apply: |
| 375 | + registry.invoke_callbacks(event) |
| 376 | + assert mock_apply.call_count == 1 |
| 377 | + |
| 378 | + |
| 379 | +def test_per_turn_reduces_message_count(): |
| 380 | + """Test that per_turn actually reduces message count during execution.""" |
| 381 | + manager = SlidingWindowConversationManager(per_turn=1, window_size=4) |
| 382 | + responses = [{"role": "assistant", "content": [{"text": f"Response {i}"}]} for i in range(10)] |
| 383 | + model = MockedModelProvider(responses) |
| 384 | + agent = Agent(model=model, conversation_manager=manager) |
| 385 | + |
| 386 | + message_counts = [] |
| 387 | + original_apply = manager.apply_management |
| 388 | + |
| 389 | + def track_apply(agent_instance): |
| 390 | + message_counts.append(len(agent_instance.messages)) |
| 391 | + return original_apply(agent_instance) |
| 392 | + |
| 393 | + with patch.object(manager, "apply_management", side_effect=track_apply): |
| 394 | + agent("Test") |
| 395 | + |
| 396 | + # Verify message count stayed around window_size |
| 397 | + assert any(count <= manager.window_size for count in message_counts) |
| 398 | + |
| 399 | + |
| 400 | +def test_per_turn_state_persistence(): |
| 401 | + """Test that model_call_count is persisted in state.""" |
| 402 | + manager = SlidingWindowConversationManager(per_turn=3) |
| 403 | + manager._model_call_count = 7 |
| 404 | + |
| 405 | + state = manager.get_state() |
| 406 | + assert state["model_call_count"] == 7 |
| 407 | + |
| 408 | + new_manager = SlidingWindowConversationManager(per_turn=3) |
| 409 | + new_manager.restore_from_session(state) |
| 410 | + assert new_manager._model_call_count == 7 |
| 411 | + |
| 412 | + |
| 413 | +def test_per_turn_backward_compatibility(): |
| 414 | + """Test that existing code without per_turn still works.""" |
| 415 | + manager = SlidingWindowConversationManager(window_size=40) |
| 416 | + assert manager.per_turn is False |
| 417 | + |
| 418 | + responses = [{"role": "assistant", "content": [{"text": "Hello"}]}] |
| 419 | + model = MockedModelProvider(responses) |
| 420 | + agent = Agent(model=model, conversation_manager=manager) |
| 421 | + result = agent("Hello") |
| 422 | + assert result is not None |
0 commit comments