From e844b30071ccd7200fb332d3d4e3e0a36e9b7854 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 5 Nov 2025 14:24:52 -0500 Subject: [PATCH 01/13] fix: Handle "prompt is too long" from Anthropic (#1137) PR#1078 mentioned that context overflows were not handled, but I wasn't able to reproduce using the code changes in it. However, in testing (using @dea's suggested test) I was able to reproduce and consistently got a "prompt is too long:" error Co-authored-by: Mackenzie Zastrow --- src/strands/models/anthropic.py | 1 + tests_integ/models/test_model_anthropic.py | 30 ++++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 48351da19..68b234729 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -39,6 +39,7 @@ class AnthropicModel(Model): } OVERFLOW_MESSAGES = { + "prompt is too long:", "input is too long", "input length exceeds context window", "input and output tokens exceed your context limit", diff --git a/tests_integ/models/test_model_anthropic.py b/tests_integ/models/test_model_anthropic.py index 62a95d06d..9a0d19dff 100644 --- a/tests_integ/models/test_model_anthropic.py +++ b/tests_integ/models/test_model_anthropic.py @@ -5,7 +5,10 @@ import strands from strands import Agent +from strands.agent import NullConversationManager from strands.models.anthropic import AnthropicModel +from strands.types.content import ContentBlock, Message +from strands.types.exceptions import ContextWindowOverflowException """ These tests only run if we have the anthropic api key @@ -152,3 +155,30 @@ def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color): tru_color = agent.structured_output(type(yellow_color), content) exp_color = yellow_color assert tru_color == exp_color + + +@pytest.mark.asyncio +def test_input_and_max_tokens_exceed_context_limit(): + """Test that triggers 'input length and max_tokens exceed context limit' error.""" + + # Note that this test is written specifically in a style that allows us to swap out conversation_manager and + # verify behavior + + model = AnthropicModel( + model_id="claude-sonnet-4-20250514", + max_tokens=64000, + ) + + large_message = "This is a very long text. " * 10000 + + messages = [ + Message(role="user", content=[ContentBlock(text=large_message)]), + Message(role="assistant", content=[ContentBlock(text=large_message)]), + Message(role="user", content=[ContentBlock(text=large_message)]), + ] + + # NullConversationManager will propagate ContextWindowOverflowException directly instead of handling it + agent = Agent(model=model, conversation_manager=NullConversationManager()) + + with pytest.raises(ContextWindowOverflowException): + agent(messages) From 1df45be924226985008814a508fab5d952a06201 Mon Sep 17 00:00:00 2001 From: ratish <114130421+Ratish1@users.noreply.github.com> Date: Fri, 7 Nov 2025 01:41:59 +0400 Subject: [PATCH 02/13] feat(telemetry): Add tool definitions to traces via semconv opt-in (#1113) --- src/strands/agent/agent.py | 1 + src/strands/telemetry/tracer.py | 47 ++++++++++++++++----- tests/strands/agent/test_agent.py | 8 +++- tests/strands/telemetry/test_tracer.py | 57 ++++++++++++++++++++++++++ 4 files changed, 101 insertions(+), 12 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8137f1887..9de5ffd21 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -938,6 +938,7 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: tools=self.tool_names, system_prompt=self.system_prompt, custom_trace_attributes=self.trace_attributes, + tools_config=self.tool_registry.get_all_tools_config(), ) def _end_agent_trace_span( diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 9cefc6911..a68aad8b7 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -79,11 +79,16 @@ class Tracer: When the OTEL_EXPORTER_OTLP_ENDPOINT environment variable is set, traces are sent to the OTLP endpoint. + + Attributes: + use_latest_genai_conventions: If True, uses the latest experimental GenAI semantic conventions. + include_tool_definitions: If True, includes detailed tool definitions in the agent trace span. + + Both attributes are controlled by including "gen_ai_latest_experimental" or "gen_ai_tool_definitions", + respectively, in the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. """ - def __init__( - self, - ) -> None: + def __init__(self) -> None: """Initialize the tracer.""" self.service_name = __name__ self.tracer_provider: Optional[trace_api.TracerProvider] = None @@ -92,17 +97,18 @@ def __init__( ThreadingInstrumentor().instrument() # Read OTEL_SEMCONV_STABILITY_OPT_IN environment variable - self.use_latest_genai_conventions = self._parse_semconv_opt_in() + opt_in_values = self._parse_semconv_opt_in() + self.use_latest_genai_conventions = "gen_ai_latest_experimental" in opt_in_values + self.include_tool_definitions = "gen_ai_tool_definitions" in opt_in_values - def _parse_semconv_opt_in(self) -> bool: + def _parse_semconv_opt_in(self) -> set[str]: """Parse the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. Returns: - Set of opt-in values from the environment variable + A set of opt-in values from the environment variable. """ opt_in_env = os.getenv("OTEL_SEMCONV_STABILITY_OPT_IN", "") - - return "gen_ai_latest_experimental" in opt_in_env + return {value.strip() for value in opt_in_env.split(",")} def _start_span( self, @@ -551,6 +557,7 @@ def start_agent_span( model_id: Optional[str] = None, tools: Optional[list] = None, custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + tools_config: Optional[dict] = None, **kwargs: Any, ) -> Span: """Start a new span for an agent invocation. @@ -561,6 +568,7 @@ def start_agent_span( model_id: Optional model identifier. tools: Optional list of tools being used. custom_trace_attributes: Optional mapping of custom trace attributes to include in the span. + tools_config: Optional dictionary of tool configurations. **kwargs: Additional attributes to add to the span. Returns: @@ -577,8 +585,15 @@ def start_agent_span( attributes["gen_ai.request.model"] = model_id if tools: - tools_json = serialize(tools) - attributes["gen_ai.agent.tools"] = tools_json + attributes["gen_ai.agent.tools"] = serialize(tools) + + if self.include_tool_definitions and tools_config: + try: + tool_definitions = self._construct_tool_definitions(tools_config) + attributes["gen_ai.tool.definitions"] = serialize(tool_definitions) + except Exception: + # A failure in telemetry should not crash the agent + logger.warning("failed to attach tool metadata to agent span", exc_info=True) # Add custom trace attributes if provided if custom_trace_attributes: @@ -649,6 +664,18 @@ def end_agent_span( self._end_span(span, attributes, error) + def _construct_tool_definitions(self, tools_config: dict) -> list[dict[str, Any]]: + """Constructs a list of tool definitions from the provided tools_config.""" + return [ + { + "name": name, + "description": spec.get("description"), + "inputSchema": spec.get("inputSchema"), + "outputSchema": spec.get("outputSchema"), + } + for name, spec in tools_config.items() + ] + def start_multiagent_span( self, task: str | list[ContentBlock], diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 3a0bc2dfb..b96a04b21 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1360,6 +1360,7 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model tools=agent.tool_names, system_prompt=agent.system_prompt, custom_trace_attributes=agent.trace_attributes, + tools_config=unittest.mock.ANY, ) # Verify span was ended with the result @@ -1394,6 +1395,7 @@ async def test_event_loop(*args, **kwargs): tools=agent.tool_names, system_prompt=agent.system_prompt, custom_trace_attributes=agent.trace_attributes, + tools_config=unittest.mock.ANY, ) expected_response = AgentResult( @@ -1432,6 +1434,7 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod tools=agent.tool_names, system_prompt=agent.system_prompt, custom_trace_attributes=agent.trace_attributes, + tools_config=unittest.mock.ANY, ) # Verify span was ended with the exception @@ -1468,6 +1471,7 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr tools=agent.tool_names, system_prompt=agent.system_prompt, custom_trace_attributes=agent.trace_attributes, + tools_config=unittest.mock.ANY, ) # Verify span was ended with the exception @@ -2240,8 +2244,8 @@ def test_agent_backwards_compatibility_single_text_block(): # Should extract text for backwards compatibility assert agent.system_prompt == text - - + + @pytest.mark.parametrize( "content, expected", [ diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 05dbe387f..25d477588 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -1324,3 +1324,60 @@ def test_start_event_loop_cycle_span_with_tool_result_message(mock_tracer): "gen_ai.tool.message", attributes={"content": json.dumps(messages[0]["content"])} ) assert span is not None + + +def test_start_agent_span_does_not_include_tool_definitions_by_default(): + """Verify that start_agent_span does not include tool definitions by default.""" + tracer = Tracer() + tracer.include_tool_definitions = False + tracer._start_span = mock.MagicMock() + + tools_config = { + "my_tool": { + "name": "my_tool", + "description": "A test tool", + "inputSchema": {"json": {}}, + "outputSchema": {"json": {}}, + } + } + + tracer.start_agent_span(messages=[], agent_name="TestAgent", tools_config=tools_config) + + tracer._start_span.assert_called_once() + _, call_kwargs = tracer._start_span.call_args + attributes = call_kwargs.get("attributes", {}) + assert "gen_ai.tool.definitions" not in attributes + + +def test_start_agent_span_includes_tool_definitions_when_enabled(): + """Verify that start_agent_span includes tool definitions when enabled.""" + tracer = Tracer() + tracer.include_tool_definitions = True + tracer._start_span = mock.MagicMock() + + tools_config = { + "my_tool": { + "name": "my_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + "outputSchema": {"json": {"type": "object", "properties": {}}}, + } + } + + tracer.start_agent_span(messages=[], agent_name="TestAgent", tools_config=tools_config) + + tracer._start_span.assert_called_once() + _, call_kwargs = tracer._start_span.call_args + attributes = call_kwargs.get("attributes", {}) + + assert "gen_ai.tool.definitions" in attributes + expected_tool_details = [ + { + "name": "my_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + "outputSchema": {"json": {"type": "object", "properties": {}}}, + } + ] + expected_json = serialize(expected_tool_details) + assert attributes["gen_ai.tool.definitions"] == expected_json From 28fea4112a2bf73156cb8304ecf6417cbfaaffdc Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Fri, 7 Nov 2025 09:30:03 -0500 Subject: [PATCH 03/13] fix: Strip argument sections out of inputSpec top-level description (#1142) Per #1067 including the args in the description is redundant as it's already included in the parameter docs which can increase the token counts. Strip args from the description strings for inputSpecs --------- Co-authored-by: Mackenzie Zastrow --- src/strands/tools/decorator.py | 60 ++++++++- tests/strands/tools/test_decorator.py | 177 ++++++++++++++++++++++++-- 2 files changed, 222 insertions(+), 15 deletions(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 5c49f4b58..0ea328a39 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -164,6 +164,56 @@ def _create_input_model(self) -> Type[BaseModel]: # Handle case with no parameters return create_model(model_name) + def _extract_description_from_docstring(self) -> str: + """Extract the docstring excluding only the Args section. + + This method uses the parsed docstring to extract everything except + the Args/Arguments/Parameters section, preserving Returns, Raises, + Examples, and other sections. + + Returns: + The description text, or the function name if no description is available. + """ + func_name = self.func.__name__ + + # Fallback: try to extract manually from raw docstring + raw_docstring = inspect.getdoc(self.func) + if raw_docstring: + lines = raw_docstring.strip().split("\n") + result_lines = [] + skip_args_section = False + + for line in lines: + stripped_line = line.strip() + + # Check if we're starting the Args section + if stripped_line.lower().startswith(("args:", "arguments:", "parameters:", "param:", "params:")): + skip_args_section = True + continue + + # Check if we're starting a new section (not Args) + elif ( + stripped_line.lower().startswith(("returns:", "return:", "yields:", "yield:")) + or stripped_line.lower().startswith(("raises:", "raise:", "except:", "exceptions:")) + or stripped_line.lower().startswith(("examples:", "example:", "note:", "notes:")) + or stripped_line.lower().startswith(("see also:", "seealso:", "references:", "ref:")) + ): + skip_args_section = False + result_lines.append(line) + continue + + # If we're not in the Args section, include the line + if not skip_args_section: + result_lines.append(line) + + # Join and clean up the description + description = "\n".join(result_lines).strip() + if description: + return description + + # Final fallback: use function name + return func_name + def extract_metadata(self) -> ToolSpec: """Extract metadata from the function to create a tool specification. @@ -173,7 +223,7 @@ def extract_metadata(self) -> ToolSpec: The specification includes: - name: The function name (or custom override) - - description: The function's docstring + - description: The function's docstring description (excluding Args) - inputSchema: A JSON schema describing the expected parameters Returns: @@ -181,12 +231,8 @@ def extract_metadata(self) -> ToolSpec: """ func_name = self.func.__name__ - # Extract function description from docstring, preserving paragraph breaks - description = inspect.getdoc(self.func) - if description: - description = description.strip() - else: - description = func_name + # Extract function description from parsed docstring, excluding Args section and beyond + description = self._extract_description_from_docstring() # Get schema directly from the Pydantic model input_schema = self.input_model.model_json_schema() diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 25f9bc39e..f89f1c945 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -221,14 +221,7 @@ def test_tool(param1: str, param2: int) -> str: # Check basic spec properties assert spec["name"] == "test_tool" - assert ( - spec["description"] - == """Test tool function. - -Args: - param1: First parameter - param2: Second parameter""" - ) + assert spec["description"] == "Test tool function." # Check input schema schema = spec["inputSchema"]["json"] @@ -310,6 +303,174 @@ def test_tool(required: str, optional: Optional[int] = None) -> str: exp_events = [ ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}) ] + assert tru_events == exp_events + + +@pytest.mark.asyncio +async def test_docstring_description_extraction(): + """Test that docstring descriptions are extracted correctly, excluding Args section.""" + + @strands.tool + def tool_with_full_docstring(param1: str, param2: int) -> str: + """This is the main description. + + This is more description text. + + Args: + param1: First parameter + param2: Second parameter + + Returns: + A string result + + Raises: + ValueError: If something goes wrong + """ + return f"{param1} {param2}" + + spec = tool_with_full_docstring.tool_spec + assert ( + spec["description"] + == """This is the main description. + +This is more description text. + +Returns: + A string result + +Raises: + ValueError: If something goes wrong""" + ) + + +def test_docstring_args_variations(): + """Test that various Args section formats are properly excluded.""" + + @strands.tool + def tool_with_args(param: str) -> str: + """Main description. + + Args: + param: Parameter description + """ + return param + + @strands.tool + def tool_with_arguments(param: str) -> str: + """Main description. + + Arguments: + param: Parameter description + """ + return param + + @strands.tool + def tool_with_parameters(param: str) -> str: + """Main description. + + Parameters: + param: Parameter description + """ + return param + + @strands.tool + def tool_with_params(param: str) -> str: + """Main description. + + Params: + param: Parameter description + """ + return param + + for tool in [tool_with_args, tool_with_arguments, tool_with_parameters, tool_with_params]: + spec = tool.tool_spec + assert spec["description"] == "Main description." + + +def test_docstring_no_args_section(): + """Test docstring extraction when there's no Args section.""" + + @strands.tool + def tool_no_args(param: str) -> str: + """This is the complete description. + + Returns: + A string result + """ + return param + + spec = tool_no_args.tool_spec + expected_desc = """This is the complete description. + +Returns: + A string result""" + assert spec["description"] == expected_desc + + +def test_docstring_only_args_section(): + """Test docstring extraction when there's only an Args section.""" + + @strands.tool + def tool_only_args(param: str) -> str: + """Args: + param: Parameter description + """ + return param + + spec = tool_only_args.tool_spec + # Should fall back to function name when no description remains + assert spec["description"] == "tool_only_args" + + +def test_docstring_empty(): + """Test docstring extraction when docstring is empty.""" + + @strands.tool + def tool_empty_docstring(param: str) -> str: + return param + + spec = tool_empty_docstring.tool_spec + # Should fall back to function name + assert spec["description"] == "tool_empty_docstring" + + +def test_docstring_preserves_other_sections(): + """Test that non-Args sections are preserved in the description.""" + + @strands.tool + def tool_multiple_sections(param: str) -> str: + """Main description here. + + Args: + param: This should be excluded + + Returns: + This should be included + + Raises: + ValueError: This should be included + + Examples: + This should be included + + Note: + This should be included + """ + return param + + spec = tool_multiple_sections.tool_spec + description = spec["description"] + + # Should include main description and other sections + assert "Main description here." in description + assert "Returns:" in description + assert "This should be included" in description + assert "Raises:" in description + assert "Examples:" in description + assert "Note:" in description + + # Should exclude Args section + assert "This should be excluded" not in description @pytest.mark.asyncio From c250fc0d4ccfa304f58825d00354ed88f9069884 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 7 Nov 2025 13:52:57 -0500 Subject: [PATCH 04/13] share thread context (#1146) --- src/strands/_async.py | 4 +- tests_integ/tools/__init__.py | 0 tests_integ/tools/test_thread_context.py | 47 ++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 1 deletion(-) create mode 100644 tests_integ/tools/__init__.py create mode 100644 tests_integ/tools/test_thread_context.py diff --git a/src/strands/_async.py b/src/strands/_async.py index 976487c37..141ca71b7 100644 --- a/src/strands/_async.py +++ b/src/strands/_async.py @@ -1,6 +1,7 @@ """Private async execution utilities.""" import asyncio +import contextvars from concurrent.futures import ThreadPoolExecutor from typing import Awaitable, Callable, TypeVar @@ -27,5 +28,6 @@ def execute() -> T: return asyncio.run(execute_async()) with ThreadPoolExecutor() as executor: - future = executor.submit(execute) + context = contextvars.copy_context() + future = executor.submit(context.run, execute) return future.result() diff --git a/tests_integ/tools/__init__.py b/tests_integ/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/tools/test_thread_context.py b/tests_integ/tools/test_thread_context.py new file mode 100644 index 000000000..b86c9b2c0 --- /dev/null +++ b/tests_integ/tools/test_thread_context.py @@ -0,0 +1,47 @@ +import contextvars + +import pytest + +from strands import Agent, tool + + +@pytest.fixture +def result(): + return {} + + +@pytest.fixture +def contextvar(): + return contextvars.ContextVar("agent") + + +@pytest.fixture +def context_tool(result, contextvar): + @tool(name="context_tool") + def tool_(): + result["context_value"] = contextvar.get("local_context") + + return tool_ + + +@pytest.fixture +def agent(context_tool): + return Agent(tools=[context_tool]) + + +def test_agent_invoke_context_sharing(result, contextvar, agent): + contextvar.set("shared_context") + agent("Execute context_tool") + + tru_context = result["context_value"] + exp_context = contextvar.get() + assert tru_context == exp_context + + +def test_tool_call_context_sharing(result, contextvar, agent): + contextvar.set("shared_context") + agent.tool.context_tool() + + tru_context = result["context_value"] + exp_context = contextvar.get() + assert tru_context == exp_context From 2b0c6e662fff059ecbe65f927530c7e7bb9a0d05 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Fri, 7 Nov 2025 13:53:50 -0500 Subject: [PATCH 05/13] async hooks (#1119) --- src/strands/agent/agent.py | 53 ++++--- src/strands/event_loop/event_loop.py | 12 +- src/strands/hooks/registry.py | 67 ++++++++- src/strands/multiagent/graph.py | 10 +- src/strands/multiagent/swarm.py | 14 +- src/strands/tools/executors/_executor.py | 10 +- .../strands/agent/hooks/test_hook_registry.py | 21 +-- tests/strands/event_loop/test_event_loop.py | 3 +- .../test_event_loop_structured_output.py | 6 +- .../experimental/hooks/test_hook_aliases.py | 7 +- tests/strands/hooks/test_registry.py | 27 +++- tests_integ/hooks/__init__.py | 0 tests_integ/hooks/multiagent/__init__.py | 0 tests_integ/hooks/multiagent/test_events.py | 122 ++++++++++++++++ tests_integ/hooks/test_events.py | 138 ++++++++++++++++++ 15 files changed, 419 insertions(+), 71 deletions(-) create mode 100644 tests_integ/hooks/__init__.py create mode 100644 tests_integ/hooks/multiagent/__init__.py create mode 100644 tests_integ/hooks/multiagent/test_events.py create mode 100644 tests_integ/hooks/test_events.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 9de5ffd21..fa4f7051f 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -171,22 +171,21 @@ async def acall() -> ToolResult: self._agent._interrupt_state.deactivate() raise RuntimeError("cannot raise interrupt in direct tool call") - return tool_results[0] + tool_result = tool_results[0] - tool_result = run_async(acall) + if record_direct_tool_call is not None: + should_record_direct_tool_call = record_direct_tool_call + else: + should_record_direct_tool_call = self._agent.record_direct_tool_call - if record_direct_tool_call is not None: - should_record_direct_tool_call = record_direct_tool_call - else: - should_record_direct_tool_call = self._agent.record_direct_tool_call + if should_record_direct_tool_call: + # Create a record of this tool execution in the message history + await self._agent._record_tool_execution(tool_use, tool_result, user_message_override) - if should_record_direct_tool_call: - # Create a record of this tool execution in the message history - self._agent._record_tool_execution(tool_use, tool_result, user_message_override) + return tool_result - # Apply window management + tool_result = run_async(acall) self._agent.conversation_manager.apply_management(self._agent) - return tool_result return caller @@ -534,7 +533,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu category=DeprecationWarning, stacklevel=2, ) - self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self)) with self.tracer.tracer.start_as_current_span( "execute_structured_output", kind=trace_api.SpanKind.CLIENT ) as structured_output_span: @@ -542,7 +541,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu if not self.messages and not prompt: raise ValueError("No conversation history or prompt provided") - temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt) + temp_messages: Messages = self.messages + await self._convert_prompt_to_messages(prompt) structured_output_span.set_attributes( { @@ -575,7 +574,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu return event["output"] finally: - self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self)) def cleanup(self) -> None: """Clean up resources used by the agent. @@ -658,7 +657,7 @@ async def stream_async( callback_handler = kwargs.get("callback_handler", self.callback_handler) # Process input and get message to add (if any) - messages = self._convert_prompt_to_messages(prompt) + messages = await self._convert_prompt_to_messages(prompt) self.trace_span = self._start_agent_trace_span(messages) @@ -732,13 +731,13 @@ async def _run_loop( Yields: Events from the event loop cycle. """ - self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self)) try: yield InitEventLoopEvent() for message in messages: - self._append_message(message) + await self._append_message(message) structured_output_context = StructuredOutputContext( structured_output_model or self._default_structured_output_model @@ -764,7 +763,7 @@ async def _run_loop( finally: self.conversation_manager.apply_management(self) - self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self)) async def _execute_event_loop_cycle( self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None @@ -813,7 +812,7 @@ async def _execute_event_loop_cycle( if structured_output_context: structured_output_context.cleanup(self.tool_registry) - def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: + async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: if self._interrupt_state.activated: return [] @@ -828,7 +827,7 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: tool_use_ids = [ content["toolUse"]["toolUseId"] for content in self.messages[-1]["content"] if "toolUse" in content ] - self._append_message( + await self._append_message( { "role": "user", "content": generate_missing_tool_result_content(tool_use_ids), @@ -859,7 +858,7 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.") return messages - def _record_tool_execution( + async def _record_tool_execution( self, tool: ToolUse, tool_result: ToolResult, @@ -919,10 +918,10 @@ def _record_tool_execution( } # Add to message history - self._append_message(user_msg) - self._append_message(tool_use_msg) - self._append_message(tool_result_msg) - self._append_message(assistant_msg) + await self._append_message(user_msg) + await self._append_message(tool_use_msg) + await self._append_message(tool_result_msg) + await self._append_message(assistant_msg) def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: """Starts a trace span for the agent. @@ -1008,10 +1007,10 @@ def _initialize_system_prompt( else: return None, None - def _append_message(self, message: Message) -> None: + async def _append_message(self, message: Message) -> None: """Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent.""" self.messages.append(message) - self.hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message)) + await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message)) def _redact_user_content(self, content: list[ContentBlock], redact_message: str) -> list[ContentBlock]: """Redact user content preserving toolResult blocks. diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 66174c09f..562de24b8 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -227,7 +227,7 @@ async def event_loop_cycle( ) structured_output_context.set_forced_mode() logger.debug("Forcing structured output tool") - agent._append_message( + await agent._append_message( {"role": "user", "content": [{"text": "You must format the previous response as structured output."}]} ) @@ -322,7 +322,7 @@ async def _handle_model_execution( model_id=model_id, ) with trace_api.use_span(model_invoke_span): - agent.hooks.invoke_callbacks( + await agent.hooks.invoke_callbacks_async( BeforeModelCallEvent( agent=agent, ) @@ -347,7 +347,7 @@ async def _handle_model_execution( stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) - agent.hooks.invoke_callbacks( + await agent.hooks.invoke_callbacks_async( AfterModelCallEvent( agent=agent, stop_response=AfterModelCallEvent.ModelStopResponse( @@ -368,7 +368,7 @@ async def _handle_model_execution( if model_invoke_span: tracer.end_span_with_error(model_invoke_span, str(e), e) - agent.hooks.invoke_callbacks( + await agent.hooks.invoke_callbacks_async( AfterModelCallEvent( agent=agent, exception=e, @@ -402,7 +402,7 @@ async def _handle_model_execution( # Add the response message to the conversation agent.messages.append(message) - agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) + await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=message)) # Update metrics agent.event_loop_metrics.update_usage(usage) @@ -507,7 +507,7 @@ async def _handle_tool_execution( } agent.messages.append(tool_result_message) - agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) + await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=tool_result_message)) yield ToolResultMessageEvent(message=tool_result_message) diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 564be85cb..1efc0bf5b 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -7,9 +7,10 @@ via hook provider objects. """ +import inspect import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar +from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar from ..interrupt import Interrupt, InterruptException @@ -122,10 +123,15 @@ class HookCallback(Protocol, Generic[TEvent]): ```python def my_callback(event: StartRequestEvent) -> None: print(f"Request started for agent: {event.agent.name}") + + # Or + + async def my_callback(event: StartRequestEvent) -> None: + # await an async operation ``` """ - def __call__(self, event: TEvent) -> None: + def __call__(self, event: TEvent) -> None | Awaitable[None]: """Handle a hook event. Args: @@ -164,6 +170,10 @@ def my_handler(event: StartRequestEvent): registry.add_callback(StartRequestEvent, my_handler) ``` """ + # Related issue: https://github.com/strands-agents/sdk-python/issues/330 + if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): + raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") + callbacks = self._registered_callbacks.setdefault(event_type, []) callbacks.append(callback) @@ -189,6 +199,52 @@ def register_hooks(self, registry: HookRegistry): """ hook.register_hooks(self) + async def invoke_callbacks_async(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]: + """Invoke all registered callbacks for the given event. + + This method finds all callbacks registered for the event's type and + invokes them in the appropriate order. For events with should_reverse_callbacks=True, + callbacks are invoked in reverse registration order. Any exceptions raised by callback + functions will propagate to the caller. + + Additionally, this method aggregates interrupts raised by the user to instantiate human-in-the-loop workflows. + + Args: + event: The event to dispatch to registered callbacks. + + Returns: + The event dispatched to registered callbacks and any interrupts raised by the user. + + Raises: + ValueError: If interrupt name is used more than once. + + Example: + ```python + event = StartRequestEvent(agent=my_agent) + await registry.invoke_callbacks_async(event) + ``` + """ + interrupts: dict[str, Interrupt] = {} + + for callback in self.get_callbacks_for(event): + try: + if inspect.iscoroutinefunction(callback): + await callback(event) + else: + callback(event) + + except InterruptException as exception: + interrupt = exception.interrupt + if interrupt.name in interrupts: + message = f"interrupt_name=<{interrupt.name}> | interrupt name used more than once" + logger.error(message) + raise ValueError(message) from exception + + # Each callback is allowed to raise their own interrupt. + interrupts[interrupt.name] = interrupt + + return event, list(interrupts.values()) + def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]: """Invoke all registered callbacks for the given event. @@ -206,6 +262,7 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte The event dispatched to registered callbacks and any interrupts raised by the user. Raises: + RuntimeError: If at least one callback is async. ValueError: If interrupt name is used more than once. Example: @@ -214,9 +271,13 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte registry.invoke_callbacks(event) ``` """ + callbacks = list(self.get_callbacks_for(event)) interrupts: dict[str, Interrupt] = {} - for callback in self.get_callbacks_for(event): + if any(inspect.iscoroutinefunction(callback) for callback in callbacks): + raise RuntimeError(f"event=<{event}> | use invoke_callbacks_async to invoke async callback") + + for callback in callbacks: try: callback(event) except InterruptException as exception: diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index b421b70c1..9f28876bf 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -453,7 +453,7 @@ def __init__( self._resume_from_session = False self.id = id - self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self)) + run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -516,7 +516,7 @@ async def stream_async( if invocation_state is None: invocation_state = {} - self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state)) + await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state)) logger.debug("task=<%s> | starting graph execution", task) @@ -569,7 +569,7 @@ async def stream_async( raise finally: self.state.execution_time = round((time.time() - start_time) * 1000) - self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self)) + await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self)) self._resume_from_session = False self._resume_next_nodes.clear() @@ -776,7 +776,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: """Execute a single node and yield TypedEvent objects.""" - self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, node.node_id, invocation_state)) + await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state)) # Reset the node's state if reset_on_revisit is enabled, and it's being revisited if self.reset_on_revisit and node in self.state.completed_nodes: @@ -920,7 +920,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) raise finally: - self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node.node_id, invocation_state)) + await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state)) def _accumulate_metrics(self, node_result: NodeResult) -> None: """Accumulate metrics from a node result.""" diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index accd56463..cb5b36839 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -273,7 +273,7 @@ def __init__( self._setup_swarm(nodes) self._inject_swarm_tools() - self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self)) + run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -336,7 +336,7 @@ async def stream_async( if invocation_state is None: invocation_state = {} - self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state)) + await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state)) logger.debug("starting swarm execution") @@ -375,7 +375,7 @@ async def stream_async( raise finally: self.state.execution_time = round((time.time() - self.state.start_time) * 1000) - self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self, invocation_state)) + await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self, invocation_state)) self._resume_from_session = False # Yield final result after execution_time is set @@ -687,7 +687,9 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato # TODO: Implement cancellation token to stop _execute_node from continuing try: # Execute with timeout wrapper for async generator streaming - self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, current_node.node_id, invocation_state)) + await self.hooks.invoke_callbacks_async( + BeforeNodeCallEvent(self, current_node.node_id, invocation_state) + ) node_stream = self._stream_with_timeout( self._execute_node(current_node, self.state.task, invocation_state), self.node_timeout, @@ -699,7 +701,9 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato self.state.node_history.append(current_node) # After self.state add current node, swarm state finish updating, we persist here - self.hooks.invoke_callbacks(AfterNodeCallEvent(self, current_node.node_id, invocation_state)) + await self.hooks.invoke_callbacks_async( + AfterNodeCallEvent(self, current_node.node_id, invocation_state) + ) logger.debug("node=<%s> | node execution completed", current_node.node_id) diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index f9a482558..87c38990d 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -85,7 +85,7 @@ async def _stream( } ) - before_event, interrupts = agent.hooks.invoke_callbacks( + before_event, interrupts = await agent.hooks.invoke_callbacks_async( BeforeToolCallEvent( agent=agent, selected_tool=tool_func, @@ -109,7 +109,7 @@ async def _stream( "status": "error", "content": [{"text": cancel_message}], } - after_event, _ = agent.hooks.invoke_callbacks( + after_event, _ = await agent.hooks.invoke_callbacks_async( AfterToolCallEvent( agent=agent, tool_use=tool_use, @@ -147,7 +147,7 @@ async def _stream( "status": "error", "content": [{"text": f"Unknown tool: {tool_name}"}], } - after_event, _ = agent.hooks.invoke_callbacks( + after_event, _ = await agent.hooks.invoke_callbacks_async( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -184,7 +184,7 @@ async def _stream( result = cast(ToolResult, event) - after_event, _ = agent.hooks.invoke_callbacks( + after_event, _ = await agent.hooks.invoke_callbacks_async( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -204,7 +204,7 @@ async def _stream( "status": "error", "content": [{"text": f"Error: {str(e)}"}], } - after_event, _ = agent.hooks.invoke_callbacks( + after_event, _ = await agent.hooks.invoke_callbacks_async( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, diff --git a/tests/strands/agent/hooks/test_hook_registry.py b/tests/strands/agent/hooks/test_hook_registry.py index 680ded682..ad1415f22 100644 --- a/tests/strands/agent/hooks/test_hook_registry.py +++ b/tests/strands/agent/hooks/test_hook_registry.py @@ -113,29 +113,32 @@ def test_get_callbacks_for_after_event(hook_registry, after_event): assert callbacks[1] == callback1 # Reverse order -def test_invoke_callbacks(hook_registry, normal_event): - """Test that invoke_callbacks calls all registered callbacks for an event.""" +@pytest.mark.asyncio +async def test_invoke_callbacks_async(hook_registry, normal_event): + """Test that invoke_callbacks_async calls all registered callbacks for an event.""" callback1 = Mock() callback2 = Mock() hook_registry.add_callback(NormalTestEvent, callback1) hook_registry.add_callback(NormalTestEvent, callback2) - hook_registry.invoke_callbacks(normal_event) + await hook_registry.invoke_callbacks_async(normal_event) callback1.assert_called_once_with(normal_event) callback2.assert_called_once_with(normal_event) -def test_invoke_callbacks_no_registered_callbacks(hook_registry, normal_event): - """Test that invoke_callbacks doesn't fail when there are no registered callbacks.""" +@pytest.mark.asyncio +async def test_invoke_callbacks_async_no_registered_callbacks(hook_registry, normal_event): + """Test that invoke_callbacks_async doesn't fail when there are no registered callbacks.""" # No callbacks registered - hook_registry.invoke_callbacks(normal_event) + await hook_registry.invoke_callbacks_async(normal_event) # Test passes if no exception is raised -def test_invoke_callbacks_after_event(hook_registry, after_event): - """Test that invoke_callbacks calls callbacks in reverse order for after events.""" +@pytest.mark.asyncio +async def test_invoke_callbacks_async_after_event(hook_registry, after_event): + """Test that invoke_callbacks_async calls callbacks in reverse order for after events.""" call_order: List[str] = [] def callback1(_event): @@ -147,7 +150,7 @@ def callback2(_event): hook_registry.add_callback(AfterTestEvent, callback1) hook_registry.add_callback(AfterTestEvent, callback2) - hook_registry.invoke_callbacks(after_event) + await hook_registry.invoke_callbacks_async(after_event) assert call_order == ["callback2", "callback1"] # Reverse order diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 72fe1b4bd..09bacbcb0 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1,6 +1,6 @@ import concurrent import unittest.mock -from unittest.mock import ANY, MagicMock, call, patch +from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest @@ -750,6 +750,7 @@ async def test_request_state_initialization(alist): # not setting this to False results in endless recursion mock_agent._interrupt_state.activated = False mock_agent.event_loop_metrics.start_cycle.return_value = (0, MagicMock()) + mock_agent.hooks.invoke_callbacks_async = AsyncMock() # Call without providing request_state stream = strands.event_loop.event_loop.event_loop_cycle( diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py index 6d3e3a9b5..886da2f0b 100644 --- a/tests/strands/event_loop/test_event_loop_structured_output.py +++ b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -1,6 +1,6 @@ """Tests for structured output integration in the event loop.""" -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from pydantic import BaseModel @@ -38,10 +38,10 @@ def mock_agent(): agent.tool_registry = ToolRegistry() agent.event_loop_metrics = EventLoopMetrics() agent.hooks = Mock() - agent.hooks.invoke_callbacks = Mock() + agent.hooks.invoke_callbacks_async = AsyncMock() agent.trace_span = None agent.tool_executor = Mock() - agent._append_message = Mock() + agent._append_message = AsyncMock() # Set up _interrupt_state properly agent._interrupt_state = Mock() diff --git a/tests/strands/experimental/hooks/test_hook_aliases.py b/tests/strands/experimental/hooks/test_hook_aliases.py index db9cd3783..6744aa00c 100644 --- a/tests/strands/experimental/hooks/test_hook_aliases.py +++ b/tests/strands/experimental/hooks/test_hook_aliases.py @@ -9,6 +9,8 @@ import sys from unittest.mock import Mock +import pytest + from strands.experimental.hooks import ( AfterModelInvocationEvent, AfterToolInvocationEvent, @@ -80,7 +82,8 @@ def test_after_model_call_event_type_equality(): assert isinstance(after_model_event, AfterModelCallEvent) -def test_experimental_aliases_in_hook_registry(): +@pytest.mark.asyncio +async def test_experimental_aliases_in_hook_registry(): """Verify that experimental aliases work with hook registry callbacks.""" hook_registry = HookRegistry() callback_called = False @@ -103,7 +106,7 @@ def experimental_callback(event: BeforeToolInvocationEvent): ) # Invoke callbacks - should work since alias points to same type - hook_registry.invoke_callbacks(test_event) + await hook_registry.invoke_callbacks_async(test_event) assert callback_called assert received_event is test_event diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 6918bd2ee..81c3bf2d3 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -3,7 +3,7 @@ import pytest from strands.agent.interrupt import InterruptState -from strands.hooks import BeforeToolCallEvent, HookRegistry +from strands.hooks import AgentInitializedEvent, BeforeInvocationEvent, BeforeToolCallEvent, HookRegistry from strands.interrupt import Interrupt @@ -19,7 +19,15 @@ def agent(): return instance -def test_hook_registry_invoke_callbacks_interrupt(registry, agent): +def test_hook_registry_add_callback_agent_init_coroutine(registry): + callback = unittest.mock.AsyncMock() + + with pytest.raises(ValueError, match=r"AgentInitializedEvent can only be registered with a synchronous callback"): + registry.add_callback(AgentInitializedEvent, callback) + + +@pytest.mark.asyncio +async def test_hook_registry_invoke_callbacks_async_interrupt(registry, agent): event = BeforeToolCallEvent( agent=agent, selected_tool=None, @@ -35,7 +43,7 @@ def test_hook_registry_invoke_callbacks_interrupt(registry, agent): registry.add_callback(BeforeToolCallEvent, callback2) registry.add_callback(BeforeToolCallEvent, callback3) - _, tru_interrupts = registry.invoke_callbacks(event) + _, tru_interrupts = await registry.invoke_callbacks_async(event) exp_interrupts = [ Interrupt( id="v1:before_tool_call:test_tool_id:da3551f3-154b-5978-827e-50ac387877ee", @@ -55,7 +63,8 @@ def test_hook_registry_invoke_callbacks_interrupt(registry, agent): callback3.assert_called_once_with(event) -def test_hook_registry_invoke_callbacks_interrupt_name_clash(registry, agent): +@pytest.mark.asyncio +async def test_hook_registry_invoke_callbacks_async_interrupt_name_clash(registry, agent): event = BeforeToolCallEvent( agent=agent, selected_tool=None, @@ -70,4 +79,12 @@ def test_hook_registry_invoke_callbacks_interrupt_name_clash(registry, agent): registry.add_callback(BeforeToolCallEvent, callback2) with pytest.raises(ValueError, match="interrupt_name= | interrupt name used more than once"): - registry.invoke_callbacks(event) + await registry.invoke_callbacks_async(event) + + +def test_hook_registry_invoke_callbacks_coroutine(registry, agent): + callback = unittest.mock.AsyncMock() + registry.add_callback(BeforeInvocationEvent, callback) + + with pytest.raises(RuntimeError, match=r"use invoke_callbacks_async to invoke async callback"): + registry.invoke_callbacks(BeforeInvocationEvent(agent=agent)) diff --git a/tests_integ/hooks/__init__.py b/tests_integ/hooks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/hooks/multiagent/__init__.py b/tests_integ/hooks/multiagent/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests_integ/hooks/multiagent/test_events.py b/tests_integ/hooks/multiagent/test_events.py new file mode 100644 index 000000000..e8039444f --- /dev/null +++ b/tests_integ/hooks/multiagent/test_events.py @@ -0,0 +1,122 @@ +import pytest + +from strands import Agent +from strands.experimental.hooks.multiagent import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from strands.hooks import HookProvider +from strands.multiagent import GraphBuilder, Swarm + + +@pytest.fixture +def callback_names(): + return [] + + +@pytest.fixture +def hook_provider(callback_names): + class TestHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(AfterMultiAgentInvocationEvent, self.after_multi_agent_invocation) + registry.add_callback(AfterMultiAgentInvocationEvent, self.after_multi_agent_invocation_async) + registry.add_callback(AfterNodeCallEvent, self.after_node_call) + registry.add_callback(AfterNodeCallEvent, self.after_node_call_async) + registry.add_callback(BeforeMultiAgentInvocationEvent, self.before_multi_agent_invocation) + registry.add_callback(BeforeMultiAgentInvocationEvent, self.before_multi_agent_invocation_async) + registry.add_callback(BeforeNodeCallEvent, self.before_node_call) + registry.add_callback(BeforeNodeCallEvent, self.before_node_call_async) + registry.add_callback(MultiAgentInitializedEvent, self.multi_agent_initialized_event) + registry.add_callback(MultiAgentInitializedEvent, self.multi_agent_initialized_event_async) + + def after_multi_agent_invocation(self, _event): + callback_names.append("after_multi_agent_invocation") + + async def after_multi_agent_invocation_async(self, _event): + callback_names.append("after_multi_agent_invocation_async") + + def after_node_call(self, _event): + callback_names.append("after_node_call") + + async def after_node_call_async(self, _event): + callback_names.append("after_node_call_async") + + def before_multi_agent_invocation(self, _event): + callback_names.append("before_multi_agent_invocation") + + async def before_multi_agent_invocation_async(self, _event): + callback_names.append("before_multi_agent_invocation_async") + + def before_node_call(self, _event): + callback_names.append("before_node_call") + + async def before_node_call_async(self, _event): + callback_names.append("before_node_call_async") + + def multi_agent_initialized_event(self, _event): + callback_names.append("multi_agent_initialized_event") + + async def multi_agent_initialized_event_async(self, _event): + callback_names.append("multi_agent_initialized_event_async") + + return TestHook() + + +@pytest.fixture +def agent(): + return Agent() + + +@pytest.fixture +def graph(agent, hook_provider): + builder = GraphBuilder() + builder.add_node(agent, "agent") + builder.set_entry_point("agent") + builder.set_hook_providers([hook_provider]) + return builder.build() + + +@pytest.fixture +def swarm(agent, hook_provider): + return Swarm([agent], hooks=[hook_provider]) + + +def test_graph_events(graph, callback_names): + graph("Hello") + + tru_callback_names = callback_names + exp_callback_names = [ + "multi_agent_initialized_event", + "multi_agent_initialized_event_async", + "before_multi_agent_invocation", + "before_multi_agent_invocation_async", + "before_node_call", + "before_node_call_async", + "after_node_call_async", + "after_node_call", + "after_multi_agent_invocation_async", + "after_multi_agent_invocation", + ] + assert tru_callback_names == exp_callback_names + + +def test_swarm_events(swarm, callback_names): + swarm("Hello") + + tru_callback_names = callback_names + exp_callback_names = [ + "multi_agent_initialized_event", + "multi_agent_initialized_event_async", + "before_multi_agent_invocation", + "before_multi_agent_invocation_async", + "before_node_call", + "before_node_call_async", + "after_node_call_async", + "after_node_call", + "after_multi_agent_invocation_async", + "after_multi_agent_invocation", + ] + assert tru_callback_names == exp_callback_names diff --git a/tests_integ/hooks/test_events.py b/tests_integ/hooks/test_events.py new file mode 100644 index 000000000..25971ecb0 --- /dev/null +++ b/tests_integ/hooks/test_events.py @@ -0,0 +1,138 @@ +import pytest + +from strands import Agent, tool +from strands.hooks import ( + AfterInvocationEvent, + AfterModelCallEvent, + AfterToolCallEvent, + AgentInitializedEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, + HookProvider, + MessageAddedEvent, +) + + +@pytest.fixture +def callback_names(): + return [] + + +@pytest.fixture +def hook_provider(callback_names): + class TestHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(AfterInvocationEvent, self.after_invocation) + registry.add_callback(AfterInvocationEvent, self.after_invocation_async) + registry.add_callback(AfterModelCallEvent, self.after_model_call) + registry.add_callback(AfterModelCallEvent, self.after_model_call_async) + registry.add_callback(AfterToolCallEvent, self.after_tool_call) + registry.add_callback(AfterToolCallEvent, self.after_tool_call_async) + registry.add_callback(AgentInitializedEvent, self.agent_initialized) + registry.add_callback(BeforeInvocationEvent, self.before_invocation) + registry.add_callback(BeforeInvocationEvent, self.before_invocation_async) + registry.add_callback(BeforeModelCallEvent, self.before_model_call) + registry.add_callback(BeforeModelCallEvent, self.before_model_call_async) + registry.add_callback(BeforeToolCallEvent, self.before_tool_call) + registry.add_callback(BeforeToolCallEvent, self.before_tool_call_async) + registry.add_callback(MessageAddedEvent, self.message_added) + registry.add_callback(MessageAddedEvent, self.message_added_async) + + def after_invocation(self, _event): + callback_names.append("after_invocation") + + async def after_invocation_async(self, _event): + callback_names.append("after_invocation_async") + + def after_model_call(self, _event): + callback_names.append("after_model_call") + + async def after_model_call_async(self, _event): + callback_names.append("after_model_call_async") + + def after_tool_call(self, _event): + callback_names.append("after_tool_call") + + async def after_tool_call_async(self, _event): + callback_names.append("after_tool_call_async") + + def agent_initialized(self, _event): + callback_names.append("agent_initialized") + + async def agent_initialized_async(self, _event): + callback_names.append("agent_initialized_async") + + def before_invocation(self, _event): + callback_names.append("before_invocation") + + async def before_invocation_async(self, _event): + callback_names.append("before_invocation_async") + + def before_model_call(self, _event): + callback_names.append("before_model_call") + + async def before_model_call_async(self, _event): + callback_names.append("before_model_call_async") + + def before_tool_call(self, _event): + callback_names.append("before_tool_call") + + async def before_tool_call_async(self, _event): + callback_names.append("before_tool_call_async") + + def message_added(self, _event): + callback_names.append("message_added") + + async def message_added_async(self, _event): + callback_names.append("message_added_async") + + return TestHook() + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def tool_() -> str: + return "12:00" + + return tool_ + + +@pytest.fixture +def agent(hook_provider, time_tool): + return Agent(hooks=[hook_provider], tools=[time_tool]) + + +def test_events(agent, callback_names): + agent("What time is it?") + + tru_callback_names = callback_names + exp_callback_names = [ + "agent_initialized", + "before_invocation", + "before_invocation_async", + "message_added", + "message_added_async", + "before_model_call", + "before_model_call_async", + "after_model_call_async", + "after_model_call", + "message_added", + "message_added_async", + "before_tool_call", + "before_tool_call_async", + "after_tool_call_async", + "after_tool_call", + "message_added", + "message_added_async", + "before_model_call", + "before_model_call_async", + "after_model_call_async", + "after_model_call", + "message_added", + "message_added_async", + "after_invocation_async", + "after_invocation", + ] + assert tru_callback_names == exp_callback_names From 3061116ebe839c1c8a3182eb736429c3fc4411b0 Mon Sep 17 00:00:00 2001 From: ratish <114130421+Ratish1@users.noreply.github.com> Date: Tue, 11 Nov 2025 01:55:13 +0400 Subject: [PATCH 06/13] feat(tools): Support string descriptions in Annotated parameters (#1089) --------- Co-authored-by: Dean Schmigelski --- src/strands/tools/decorator.py | 76 +++++++-- tests/strands/tools/test_decorator.py | 214 +++++++++++++++++++++++++- 2 files changed, 278 insertions(+), 12 deletions(-) diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 0ea328a39..8dc933f51 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -45,6 +45,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: import inspect import logging from typing import ( + Annotated, Any, Callable, Generic, @@ -54,12 +55,15 @@ def my_tool(param1: str, param2: int = 42) -> dict: TypeVar, Union, cast, + get_args, + get_origin, get_type_hints, overload, ) import docstring_parser from pydantic import BaseModel, Field, create_model +from pydantic.fields import FieldInfo from typing_extensions import override from ..interrupt import InterruptException @@ -105,15 +109,66 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - # Parse the docstring with docstring_parser doc_str = inspect.getdoc(func) or "" self.doc = docstring_parser.parse(doc_str) - - # Get parameter descriptions from parsed docstring - self.param_descriptions = { + self.param_descriptions: dict[str, str] = { param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params } # Create a Pydantic model for validation self.input_model = self._create_input_model() + def _extract_annotated_metadata( + self, annotation: Any, param_name: str, param_default: Any + ) -> tuple[Any, FieldInfo]: + """Extracts type and a simple string description from an Annotated type hint. + + Returns: + A tuple of (actual_type, field_info), where field_info is a new, simple + Pydantic FieldInfo instance created from the extracted metadata. + """ + actual_type = annotation + description: str | None = None + + if get_origin(annotation) is Annotated: + args = get_args(annotation) + actual_type = args[0] + + # Look through metadata for a string description or a FieldInfo object + for meta in args[1:]: + if isinstance(meta, str): + description = meta + elif isinstance(meta, FieldInfo): + # --- Future Contributor Note --- + # We are explicitly blocking the use of `pydantic.Field` within `Annotated` + # because of the complexities of Pydantic v2's immutable Core Schema. + # + # Once a Pydantic model's schema is built, its `FieldInfo` objects are + # effectively frozen. Attempts to mutate a `FieldInfo` object after + # creation (e.g., by copying it and setting `.description` or `.default`) + # are unreliable because the underlying Core Schema does not see these changes. + # + # The correct way to support this would be to reliably extract all + # constraints (ge, le, pattern, etc.) from the original FieldInfo and + # rebuild a new one from scratch. However, these constraints are not + # stored as public attributes, making them difficult to inspect reliably. + # + # Deferring this complexity until there is clear demand and a robust + # pattern for inspecting FieldInfo constraints is established. + raise NotImplementedError( + "Using pydantic.Field within Annotated is not yet supported for tool decorators. " + "Please use a simple string for the description, or define constraints in the function's " + "docstring." + ) + + # Determine the final description with a clear priority order + # Priority: 1. Annotated string -> 2. Docstring -> 3. Fallback + final_description = description + if final_description is None: + final_description = self.param_descriptions.get(param_name) or f"Parameter {param_name}" + # Create FieldInfo object from scratch + final_field = Field(default=param_default, description=final_description) + + return actual_type, final_field + def _validate_signature(self) -> None: """Verify that ToolContext is used correctly in the function signature.""" for param in self.signature.parameters.values(): @@ -146,22 +201,21 @@ def _create_input_model(self) -> Type[BaseModel]: if self._is_special_parameter(name): continue - # Get parameter type and default - param_type = self.type_hints.get(name, Any) + # Use param.annotation directly to get the raw type hint. Using get_type_hints() + # can cause inconsistent behavior across Python versions for complex Annotated types. + param_type = param.annotation + if param_type is inspect.Parameter.empty: + param_type = Any default = ... if param.default is inspect.Parameter.empty else param.default - description = self.param_descriptions.get(name, f"Parameter {name}") - # Create Field with description and default - field_definitions[name] = (param_type, Field(default=default, description=description)) + actual_type, field_info = self._extract_annotated_metadata(param_type, name, default) + field_definitions[name] = (actual_type, field_info) - # Create model name based on function name model_name = f"{self.func.__name__.capitalize()}Tool" - # Create and return the model if field_definitions: return create_model(model_name, **field_definitions) else: - # Handle case with no parameters return create_model(model_name) def _extract_description_from_docstring(self) -> str: diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index f89f1c945..0d5c65689 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -3,10 +3,11 @@ """ from asyncio import Queue -from typing import Any, AsyncGenerator, Dict, Optional, Union +from typing import Annotated, Any, AsyncGenerator, Dict, List, Optional, Union from unittest.mock import MagicMock import pytest +from pydantic import Field import strands from strands import Agent @@ -1611,3 +1612,214 @@ def test_function_tool_metadata_validate_signature_missing_context_config(): @strands.tool def my_tool(tool_context: ToolContext): pass + + +def test_tool_decorator_annotated_string_description(): + """Test tool decorator with Annotated type hints for descriptions.""" + + @strands.tool + def annotated_tool( + name: Annotated[str, "The user's full name"], + age: Annotated[int, "The user's age in years"], + city: str, # No annotation - should use docstring or generic + ) -> str: + """Tool with annotated parameters. + + Args: + city: The user's city (from docstring) + """ + return f"{name}, {age}, {city}" + + spec = annotated_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Check that annotated descriptions are used + assert schema["properties"]["name"]["description"] == "The user's full name" + assert schema["properties"]["age"]["description"] == "The user's age in years" + + # Check that docstring is still used for non-annotated params + assert schema["properties"]["city"]["description"] == "The user's city (from docstring)" + + # Verify all are required + assert set(schema["required"]) == {"name", "age", "city"} + + +def test_tool_decorator_annotated_pydantic_field_constraints(): + """Test that using pydantic.Field in Annotated raises a NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def field_annotated_tool( + email: Annotated[str, Field(description="User's email address", pattern=r"^[\w\.-]+@[\w\.-]+\\.w+$")], + score: Annotated[int, Field(description="Score between 0-100", ge=0, le=100)] = 50, + ) -> str: + """Tool with Pydantic Field annotations.""" + return f"{email}: {score}" + + +def test_tool_decorator_annotated_overrides_docstring(): + """Test that Annotated descriptions override docstring descriptions.""" + + @strands.tool + def override_tool(param: Annotated[str, "Description from annotation"]) -> str: + """Tool with both annotation and docstring. + + Args: + param: Description from docstring (should be overridden) + """ + return param + + spec = override_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Annotated description should win + assert schema["properties"]["param"]["description"] == "Description from annotation" + + +def test_tool_decorator_annotated_optional_type(): + """Test tool with Optional types in Annotated.""" + + @strands.tool + def optional_annotated_tool( + required: Annotated[str, "Required parameter"], optional: Annotated[Optional[str], "Optional parameter"] = None + ) -> str: + """Tool with optional annotated parameter.""" + return f"{required}, {optional}" + + spec = optional_annotated_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Check descriptions + assert schema["properties"]["required"]["description"] == "Required parameter" + assert schema["properties"]["optional"]["description"] == "Optional parameter" + + # Check required list + assert "required" in schema["required"] + assert "optional" not in schema["required"] + + +def test_tool_decorator_annotated_complex_types(): + """Test tool with complex types in Annotated.""" + + @strands.tool + def complex_annotated_tool( + tags: Annotated[List[str], "List of tag strings"], config: Annotated[Dict[str, Any], "Configuration dictionary"] + ) -> str: + """Tool with complex annotated types.""" + return f"Tags: {len(tags)}, Config: {len(config)}" + + spec = complex_annotated_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Check descriptions + assert schema["properties"]["tags"]["description"] == "List of tag strings" + assert schema["properties"]["config"]["description"] == "Configuration dictionary" + + # Check types are preserved + assert schema["properties"]["tags"]["type"] == "array" + assert schema["properties"]["config"]["type"] == "object" + + +def test_tool_decorator_annotated_mixed_styles(): + """Test that using pydantic.Field in a mixed-style annotation raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def mixed_tool( + plain: str, + annotated_str: Annotated[str, "String description"], + annotated_field: Annotated[int, Field(description="Field description", ge=0)], + docstring_only: int, + ) -> str: + """Tool with mixed parameter styles. + + Args: + plain: Plain parameter description + docstring_only: Docstring description for this param + """ + return "mixed" + + +@pytest.mark.asyncio +async def test_tool_decorator_annotated_execution(alist): + """Test that annotated tools execute correctly.""" + + @strands.tool + def execution_test(name: Annotated[str, "User name"], count: Annotated[int, "Number of times"] = 1) -> str: + """Test execution with annotations.""" + return f"Hello {name} " * count + + # Test tool use + tool_use = {"toolUseId": "test-id", "input": {"name": "Alice", "count": 2}} + stream = execution_test.stream(tool_use, {}) + + result = (await alist(stream))[-1] + assert result["tool_result"]["status"] == "success" + assert "Hello Alice Hello Alice" in result["tool_result"]["content"][0]["text"] + + # Test direct call + direct_result = execution_test("Bob", 3) + assert direct_result == "Hello Bob Hello Bob Hello Bob " + + +def test_tool_decorator_annotated_no_description_fallback(): + """Test that Annotated with a Field raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def no_desc_annotated( + param: Annotated[str, Field()], # Field without description + ) -> str: + """Tool with Annotated but no description. + + Args: + param: Docstring description + """ + return param + + +def test_tool_decorator_annotated_empty_string_description(): + """Test handling of empty string descriptions in Annotated.""" + + @strands.tool + def empty_desc_tool( + param: Annotated[str, ""], # Empty string description + ) -> str: + """Tool with empty annotation description. + + Args: + param: Docstring description + """ + return param + + spec = empty_desc_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Empty string is still a valid description, should not fall back + assert schema["properties"]["param"]["description"] == "" + + +@pytest.mark.asyncio +async def test_tool_decorator_annotated_validation_error(alist): + """Test that validation works correctly with annotated parameters.""" + + @strands.tool + def validation_tool(age: Annotated[int, "User age"]) -> str: + """Tool for validation testing.""" + return f"Age: {age}" + + # Test with wrong type + tool_use = {"toolUseId": "test-id", "input": {"age": "not an int"}} + stream = validation_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + assert result["tool_result"]["status"] == "error" + + +def test_tool_decorator_annotated_field_with_inner_default(): + """Test that a default value in an Annotated Field raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def inner_default_tool(name: str, level: Annotated[int, Field(description="A level value", default=10)]) -> str: + return f"{name} is at level {level}" From e930243e549415e7176f6220e5663d1874a8420a Mon Sep 17 00:00:00 2001 From: poshinchen Date: Tue, 11 Nov 2025 10:36:50 -0500 Subject: [PATCH 07/13] chore(telemetry): updated opt-in attributes to internal (#1152) --- src/strands/telemetry/tracer.py | 9 ++---- tests/strands/telemetry/test_tracer.py | 45 +++++++++++++------------- 2 files changed, 25 insertions(+), 29 deletions(-) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index a68aad8b7..c47a10c3f 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -80,10 +80,6 @@ class Tracer: When the OTEL_EXPORTER_OTLP_ENDPOINT environment variable is set, traces are sent to the OTLP endpoint. - Attributes: - use_latest_genai_conventions: If True, uses the latest experimental GenAI semantic conventions. - include_tool_definitions: If True, includes detailed tool definitions in the agent trace span. - Both attributes are controlled by including "gen_ai_latest_experimental" or "gen_ai_tool_definitions", respectively, in the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. """ @@ -98,8 +94,9 @@ def __init__(self) -> None: # Read OTEL_SEMCONV_STABILITY_OPT_IN environment variable opt_in_values = self._parse_semconv_opt_in() + ## To-do: should not set below attributes directly, use env var instead self.use_latest_genai_conventions = "gen_ai_latest_experimental" in opt_in_values - self.include_tool_definitions = "gen_ai_tool_definitions" in opt_in_values + self._include_tool_definitions = "gen_ai_tool_definitions" in opt_in_values def _parse_semconv_opt_in(self) -> set[str]: """Parse the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. @@ -587,7 +584,7 @@ def start_agent_span( if tools: attributes["gen_ai.agent.tools"] = serialize(tools) - if self.include_tool_definitions and tools_config: + if self._include_tool_definitions and tools_config: try: tool_definitions = self._construct_tool_definitions(tools_config) attributes["gen_ai.tool.definitions"] = serialize(tool_definitions) diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 25d477588..98cfb459f 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -163,11 +163,11 @@ def test_start_model_invoke_span(mock_tracer): assert span is not None -def test_start_model_invoke_span_latest_conventions(mock_tracer): +def test_start_model_invoke_span_latest_conventions(mock_tracer, monkeypatch): """Test starting a model invoke span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -244,11 +244,11 @@ def test_end_model_invoke_span(mock_span): mock_span.end.assert_called_once() -def test_end_model_invoke_span_latest_conventions(mock_span): +def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): """Test ending a model invoke span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True message = {"role": "assistant", "content": [{"text": "Response"}]} usage = Usage(inputTokens=10, outputTokens=20, totalTokens=30) metrics = Metrics(latencyMs=20, timeToFirstByteMs=10) @@ -307,11 +307,11 @@ def test_start_tool_call_span(mock_tracer): assert span is not None -def test_start_tool_call_span_latest_conventions(mock_tracer): +def test_start_tool_call_span_latest_conventions(mock_tracer, monkeypatch): """Test starting a tool call span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -396,11 +396,11 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer): assert span is not None -def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer): +def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer, monkeypatch): """Test starting a swarm call span with task as list of contentBlock with latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -439,10 +439,10 @@ def test_end_swarm_span(mock_span): ) -def test_end_swarm_span_latest_conventions(mock_span): +def test_end_swarm_span_latest_conventions(mock_span, monkeypatch): """Test ending a tool call span with latest semantic conventions.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True swarm_final_reuslt = "foo bar bar" tracer.end_swarm_span(mock_span, swarm_final_reuslt) @@ -503,10 +503,10 @@ def test_end_tool_call_span(mock_span): mock_span.end.assert_called_once() -def test_end_tool_call_span_latest_conventions(mock_span): +def test_end_tool_call_span_latest_conventions(mock_span, monkeypatch): """Test ending a tool call span with the latest semantic conventions.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tool_result = {"status": "success", "content": [{"text": "Tool result"}, {"json": {"foo": "bar"}}]} tracer.end_tool_call_span(mock_span, tool_result) @@ -558,11 +558,11 @@ def test_start_event_loop_cycle_span(mock_tracer): assert span is not None -def test_start_event_loop_cycle_span_latest_conventions(mock_tracer): +def test_start_event_loop_cycle_span_latest_conventions(mock_tracer, monkeypatch): """Test starting an event loop cycle span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -609,10 +609,10 @@ def test_end_event_loop_cycle_span(mock_span): mock_span.end.assert_called_once() -def test_end_event_loop_cycle_span_latest_conventions(mock_span): +def test_end_event_loop_cycle_span_latest_conventions(mock_span, monkeypatch): """Test ending an event loop cycle span with the latest semantic conventions.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True message = {"role": "assistant", "content": [{"text": "Response"}]} tool_result_message = { "role": "assistant", @@ -679,11 +679,11 @@ def test_start_agent_span(mock_tracer): assert span is not None -def test_start_agent_span_latest_conventions(mock_tracer): +def test_start_agent_span_latest_conventions(mock_tracer, monkeypatch): """Test starting an agent span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -749,10 +749,10 @@ def test_end_agent_span(mock_span): mock_span.end.assert_called_once() -def test_end_agent_span_latest_conventions(mock_span): +def test_end_agent_span_latest_conventions(mock_span, monkeypatch): """Test ending an agent span with the latest semantic conventions.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True # Mock AgentResult with metrics mock_metrics = mock.MagicMock() @@ -1329,7 +1329,6 @@ def test_start_event_loop_cycle_span_with_tool_result_message(mock_tracer): def test_start_agent_span_does_not_include_tool_definitions_by_default(): """Verify that start_agent_span does not include tool definitions by default.""" tracer = Tracer() - tracer.include_tool_definitions = False tracer._start_span = mock.MagicMock() tools_config = { @@ -1349,10 +1348,10 @@ def test_start_agent_span_does_not_include_tool_definitions_by_default(): assert "gen_ai.tool.definitions" not in attributes -def test_start_agent_span_includes_tool_definitions_when_enabled(): +def test_start_agent_span_includes_tool_definitions_when_enabled(monkeypatch): """Verify that start_agent_span includes tool definitions when enabled.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_tool_definitions") tracer = Tracer() - tracer.include_tool_definitions = True tracer._start_span = mock.MagicMock() tools_config = { From bbe765de9f75dab67963592df4678c3a8a0a49c2 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Tue, 11 Nov 2025 17:46:25 +0200 Subject: [PATCH 08/13] feat(models): allow SystemContentBlocks in LiteLLMModel (#1141) --- src/strands/models/litellm.py | 121 ++++++++++++++++++++++- src/strands/models/openai.py | 92 ++++++++++++++--- src/strands/models/sagemaker.py | 8 +- tests/strands/models/test_litellm.py | 66 +++++++++++++ tests/strands/models/test_openai.py | 42 ++++++++ tests_integ/models/test_model_litellm.py | 22 +++++ tests_integ/models/test_model_openai.py | 26 +++++ 7 files changed, 357 insertions(+), 20 deletions(-) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 7a8c0ae03..f2480c8d8 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -14,9 +14,10 @@ from typing_extensions import Unpack, override from ..tools import convert_pydantic_to_tool_spec -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, Messages, SystemContentBlock +from ..types.event_loop import Usage from ..types.exceptions import ContextWindowOverflowException -from ..types.streaming import StreamEvent +from ..types.streaming import MetadataEvent, StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import validate_config_keys from .openai import OpenAIModel @@ -81,11 +82,12 @@ def get_config(self) -> LiteLLMConfig: @override @classmethod - def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]: """Format a LiteLLM content block. Args: content: Message content. + **kwargs: Additional keyword arguments for future extensibility. Returns: LiteLLM formatted content block. @@ -131,6 +133,113 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> return chunks, data_type + @override + @classmethod + def _format_system_messages( + cls, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format system messages for LiteLLM with cache point support. + + Args: + system_prompt: System prompt to provide context to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + List of formatted system messages. + """ + # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None + if system_prompt and system_prompt_content is None: + system_prompt_content = [{"text": system_prompt}] + + system_content: list[dict[str, Any]] = [] + for block in system_prompt_content or []: + if "text" in block: + system_content.append({"type": "text", "text": block["text"]}) + elif "cachePoint" in block and block["cachePoint"].get("type") == "default": + # Apply cache control to the immediately preceding content block + # for LiteLLM/Anthropic compatibility + if system_content: + system_content[-1]["cache_control"] = {"type": "ephemeral"} + + # Create single system message with content array rather than mulitple system messages + return [{"role": "system", "content": system_content}] if system_content else [] + + @override + @classmethod + def format_request_messages( + cls, + messages: Messages, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format a LiteLLM compatible messages array with cache point support. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model (for legacy compatibility). + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + A LiteLLM compatible messages array. + """ + formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) + formatted_messages.extend(cls._format_regular_messages(messages)) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + + @override + def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: + """Format a LiteLLM response event into a standardized message chunk. + + This method overrides OpenAI's format_chunk to handle the metadata case + with prompt caching support. All other chunk types use the parent implementation. + + Args: + event: A response event from the LiteLLM model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + """ + # Handle metadata case with prompt caching support + if event["chunk_type"] == "metadata": + usage_data: Usage = { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + } + + # Only LiteLLM over Anthropic supports cache write tokens + # Waiting until a more general approach is available to set cacheWriteInputTokens + + if tokens_details := getattr(event["data"], "prompt_tokens_details", None): + if cached := getattr(tokens_details, "cached_tokens", None): + usage_data["cacheReadInputTokens"] = cached + if creation := getattr(tokens_details, "cache_creation_tokens", None): + usage_data["cacheWriteInputTokens"] = creation + + return StreamEvent( + metadata=MetadataEvent( + metrics={ + "latencyMs": 0, # TODO + }, + usage=usage_data, + ) + ) + # For all other cases, use the parent implementation + return super().format_chunk(event) + @override async def stream( self, @@ -139,6 +248,7 @@ async def stream( system_prompt: Optional[str] = None, *, tool_choice: ToolChoice | None = None, + system_prompt_content: Optional[list[SystemContentBlock]] = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the LiteLLM model. @@ -148,13 +258,16 @@ async def stream( tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. + system_prompt_content: System prompt content blocks to provide context to the model. **kwargs: Additional keyword arguments for future extensibility. Yields: Formatted message chunks from the model. """ logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt, tool_choice) + request = self.format_request( + messages, tool_specs, system_prompt, tool_choice, system_prompt_content=system_prompt_content + ) logger.debug("request=<%s>", request) logger.debug("invoking model") diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 1efe641e6..435c82cab 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -14,7 +14,7 @@ from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse @@ -89,11 +89,12 @@ def get_config(self) -> OpenAIConfig: return cast(OpenAIModel.OpenAIConfig, self.config) @classmethod - def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]: """Format an OpenAI compatible content block. Args: content: Message content. + **kwargs: Additional keyword arguments for future extensibility. Returns: OpenAI compatible content block. @@ -131,11 +132,12 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") @classmethod - def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: + def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> dict[str, Any]: """Format an OpenAI compatible tool call. Args: tool_use: Tool use requested by the model. + **kwargs: Additional keyword arguments for future extensibility. Returns: OpenAI compatible tool call. @@ -150,11 +152,12 @@ def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: } @classmethod - def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> dict[str, Any]: """Format an OpenAI compatible tool message. Args: tool_result: Tool result collected from a tool execution. + **kwargs: Additional keyword arguments for future extensibility. Returns: OpenAI compatible tool message. @@ -198,18 +201,46 @@ def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str return {"tool_choice": "auto"} @classmethod - def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: - """Format an OpenAI compatible messages array. + def _format_system_messages( + cls, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format system messages for OpenAI-compatible providers. Args: - messages: List of message objects to be processed by the model. system_prompt: System prompt to provide context to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. Returns: - An OpenAI compatible messages array. + List of formatted system messages. + """ + # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None + if system_prompt and system_prompt_content is None: + system_prompt_content = [{"text": system_prompt}] + + # TODO: Handle caching blocks https://github.com/strands-agents/sdk-python/issues/1140 + return [ + {"role": "system", "content": content["text"]} + for content in system_prompt_content or [] + if "text" in content + ] + + @classmethod + def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dict[str, Any]]: + """Format regular messages for OpenAI-compatible providers. + + Args: + messages: List of message objects to be processed by the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + List of formatted messages. """ - formatted_messages: list[dict[str, Any]] - formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + formatted_messages = [] for message in messages: contents = message["content"] @@ -242,14 +273,42 @@ def format_request_messages(cls, messages: Messages, system_prompt: Optional[str formatted_messages.append(formatted_message) formatted_messages.extend(formatted_tool_messages) + return formatted_messages + + @classmethod + def format_request_messages( + cls, + messages: Messages, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format an OpenAI compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + An OpenAI compatible messages array. + """ + formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) + formatted_messages.extend(cls._format_regular_messages(messages)) + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, + *, + system_prompt_content: list[SystemContentBlock] | None = None, + **kwargs: Any, ) -> dict[str, Any]: """Format an OpenAI compatible chat streaming request. @@ -258,6 +317,8 @@ def format_request( tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. Returns: An OpenAI compatible chat streaming request. @@ -267,7 +328,9 @@ def format_request( format. """ return { - "messages": self.format_request_messages(messages, system_prompt), + "messages": self.format_request_messages( + messages, system_prompt, system_prompt_content=system_prompt_content + ), "model": self.config["model_id"], "stream": True, "stream_options": {"include_usage": True}, @@ -286,11 +349,12 @@ def format_request( **cast(dict[str, Any], self.config.get("params", {})), } - def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: """Format an OpenAI response event into a standardized message chunk. Args: event: A response event from the OpenAI compatible model. + **kwargs: Additional keyword arguments for future extensibility. Returns: The formatted chunk. diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 25b3ca7ce..7f8b8ff51 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -202,6 +202,7 @@ def format_request( tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, tool_choice: ToolChoice | None = None, + **kwargs: Any, ) -> dict[str, Any]: """Format an Amazon SageMaker chat streaming request. @@ -211,6 +212,7 @@ def format_request( system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for interface consistency but is currently ignored for this model provider.** + **kwargs: Additional keyword arguments for future extensibility. Returns: An Amazon SageMaker chat streaming request. @@ -501,11 +503,12 @@ async def stream( @override @classmethod - def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> dict[str, Any]: """Format a SageMaker compatible tool message. Args: tool_result: Tool result collected from a tool execution. + **kwargs: Additional keyword arguments for future extensibility. Returns: SageMaker compatible tool message with content as a string. @@ -531,11 +534,12 @@ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: @override @classmethod - def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]: """Format a content block. Args: content: Message content. + **kwargs: Additional keyword arguments for future extensibility. Returns: Formatted content block. diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 57a8593cd..f56438cf5 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -192,6 +192,8 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, mock_event_7 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_7)]) mock_event_8 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_8)]) mock_event_9 = unittest.mock.Mock() + mock_event_9.usage.prompt_tokens_details.cached_tokens = 10 + mock_event_9.usage.prompt_tokens_details.cache_creation_tokens = 10 litellm_acompletion.side_effect = unittest.mock.AsyncMock( return_value=agenerator( @@ -252,6 +254,8 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, { "metadata": { "usage": { + "cacheReadInputTokens": mock_event_9.usage.prompt_tokens_details.cached_tokens, + "cacheWriteInputTokens": mock_event_9.usage.prompt_tokens_details.cache_creation_tokens, "inputTokens": mock_event_9.usage.prompt_tokens, "outputTokens": mock_event_9.usage.completion_tokens, "totalTokens": mock_event_9.usage.total_tokens, @@ -402,3 +406,65 @@ async def test_context_window_maps_to_typed_exception(litellm_acompletion, model with pytest.raises(ContextWindowOverflowException): async for _ in model.stream([{"role": "user", "content": [{"text": "x"}]}]): pass + + +def test_format_request_messages_with_system_prompt_content(): + """Test format_request_messages with system_prompt_content parameter.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are a helpful assistant."}, {"cachePoint": {"type": "default"}}] + + result = LiteLLMModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + expected = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant.", "cache_control": {"type": "ephemeral"}} + ], + }, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected + + +def test_format_request_messages_backward_compatibility_system_prompt(): + """Test that system_prompt is converted to system_prompt_content when system_prompt_content is None.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt = "You are a helpful assistant." + + result = LiteLLMModel.format_request_messages(messages, system_prompt=system_prompt) + + expected = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected + + +def test_format_request_messages_cache_point_support(): + """Test that cache points are properly applied to preceding content blocks.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [ + {"text": "First instruction."}, + {"text": "Second instruction."}, + {"cachePoint": {"type": "default"}}, + {"text": "Third instruction."}, + ] + + result = LiteLLMModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + expected = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "First instruction."}, + {"type": "text", "text": "Second instruction.", "cache_control": {"type": "ephemeral"}}, + {"type": "text", "text": "Third instruction."}, + ], + }, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index cc30b7420..0de0c4ebc 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -944,3 +944,45 @@ async def test_structured_output_rate_limit_as_throttle(openai_client, model, me # Verify the exception message contains the original error assert "tokens per min" in str(exc_info.value) assert exc_info.value.__cause__ == mock_error + + +def test_format_request_messages_with_system_prompt_content(): + """Test format_request_messages with system_prompt_content parameter.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are a helpful assistant."}] + + result = OpenAIModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + expected = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected + + +def test_format_request_messages_with_none_system_prompt_content(): + """Test format_request_messages with system_prompt_content parameter.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + result = OpenAIModel.format_request_messages(messages) + + expected = [{"role": "user", "content": [{"text": "Hello", "type": "text"}]}] + + assert result == expected + + +def test_format_request_messages_drops_cache_points(): + """Test that cache points are dropped in OpenAI format_request_messages.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are a helpful assistant."}, {"cachePoint": {"type": "default"}}] + + result = OpenAIModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + # Cache points should be dropped, only text content included + expected = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index b348c29f4..f177c08a4 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -211,3 +211,25 @@ def test_structured_output_unsupported_model(model, nested_weather): # Verify that the tool method was called and schema method was not mock_tool.assert_called_once() mock_schema.assert_not_called() + + +@pytest.mark.asyncio +async def test_cache_read_tokens_multi_turn(model): + """Integration test for cache read tokens in multi-turn conversation.""" + from strands.types.content import SystemContentBlock + + system_prompt_content: list[SystemContentBlock] = [ + # Caching only works when prompts are large + {"text": "You are a helpful assistant. Always be concise." * 200}, + {"cachePoint": {"type": "default"}}, + ] + + agent = Agent(model=model, system_prompt=system_prompt_content) + + # First turn - establishes cache + agent("Hello, what's 2+2?") + result = agent("What's 3+3?") + result.metrics.accumulated_usage["cacheReadInputTokens"] + + assert result.metrics.accumulated_usage["cacheReadInputTokens"] > 0 + assert result.metrics.accumulated_usage["cacheWriteInputTokens"] > 0 diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index 7beb3013c..feb591d1a 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -231,3 +231,29 @@ def test_content_blocks_handling(model): result = agent(content) assert "4" in result.message["content"][0]["text"] + + +def test_system_prompt_content_integration(model): + """Integration test for system_prompt_content parameter.""" + from strands.types.content import SystemContentBlock + + system_prompt_content: list[SystemContentBlock] = [ + {"text": "You are a helpful assistant that always responds with 'SYSTEM_TEST_RESPONSE'."} + ] + + agent = Agent(model=model, system_prompt=system_prompt_content) + result = agent("Hello") + + # The response should contain our specific system prompt instruction + assert "SYSTEM_TEST_RESPONSE" in result.message["content"][0]["text"] + + +def test_system_prompt_backward_compatibility_integration(model): + """Integration test for backward compatibility with system_prompt parameter.""" + system_prompt = "You are a helpful assistant that always responds with 'BACKWARD_COMPAT_TEST'." + + agent = Agent(model=model, system_prompt=system_prompt) + result = agent("Hello") + + # The response should contain our specific system prompt instruction + assert "BACKWARD_COMPAT_TEST" in result.message["content"][0]["text"] From ccc3a8b46d71d11531c85277f815049cc1760bb4 Mon Sep 17 00:00:00 2001 From: Patrick Gray Date: Tue, 11 Nov 2025 11:27:56 -0500 Subject: [PATCH 09/13] share interrupt state (#1148) --- src/strands/agent/agent.py | 39 +------ src/strands/agent/interrupt.py | 59 ---------- src/strands/interrupt.py | 94 ++++++++++++++- src/strands/types/session.py | 4 +- tests/strands/agent/test_interrupt.py | 61 ---------- tests/strands/event_loop/test_event_loop.py | 5 +- tests/strands/hooks/test_registry.py | 5 +- .../test_repository_session_manager.py | 4 +- tests/strands/test_interrupt.py | 108 +++++++++++++++++- tests/strands/tools/executors/conftest.py | 4 +- tests/strands/tools/test_decorator.py | 7 +- tests/strands/types/test_interrupt.py | 5 +- tests/strands/types/test_session.py | 6 +- 13 files changed, 220 insertions(+), 181 deletions(-) delete mode 100644 src/strands/agent/interrupt.py delete mode 100644 tests/strands/agent/test_interrupt.py diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index fa4f7051f..b7633d5e8 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -46,6 +46,7 @@ HookRegistry, MessageAddedEvent, ) +from ..interrupt import _InterruptState from ..models.bedrock import BedrockModel from ..models.model import Model from ..session.session_manager import SessionManager @@ -60,7 +61,6 @@ from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException -from ..types.interrupt import InterruptResponseContent from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult @@ -68,7 +68,6 @@ ConversationManager, SlidingWindowConversationManager, ) -from .interrupt import InterruptState from .state import AgentState logger = logging.getLogger(__name__) @@ -352,7 +351,7 @@ def __init__( self.hooks = HookRegistry() - self._interrupt_state = InterruptState() + self._interrupt_state = _InterruptState() # Initialize session management functionality self._session_manager = session_manager @@ -640,7 +639,7 @@ async def stream_async( yield event["data"] ``` """ - self._resume_interrupt(prompt) + self._interrupt_state.resume(prompt) merged_state = {} if kwargs: @@ -683,38 +682,6 @@ async def stream_async( self._end_agent_trace_span(error=e) raise - def _resume_interrupt(self, prompt: AgentInput) -> None: - """Configure the interrupt state if resuming from an interrupt event. - - Args: - prompt: User responses if resuming from interrupt. - - Raises: - TypeError: If in interrupt state but user did not provide responses. - """ - if not self._interrupt_state.activated: - return - - if not isinstance(prompt, list): - raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's") - - invalid_types = [ - content_type for content in prompt for content_type in content if content_type != "interruptResponse" - ] - if invalid_types: - raise TypeError( - f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's" - ) - - for content in cast(list[InterruptResponseContent], prompt): - interrupt_id = content["interruptResponse"]["interruptId"] - interrupt_response = content["interruptResponse"]["response"] - - if interrupt_id not in self._interrupt_state.interrupts: - raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found") - - self._interrupt_state.interrupts[interrupt_id].response = interrupt_response - async def _run_loop( self, messages: Messages, diff --git a/src/strands/agent/interrupt.py b/src/strands/agent/interrupt.py deleted file mode 100644 index 3cec1541b..000000000 --- a/src/strands/agent/interrupt.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Track the state of interrupt events raised by the user for human-in-the-loop workflows.""" - -from dataclasses import asdict, dataclass, field -from typing import Any - -from ..interrupt import Interrupt - - -@dataclass -class InterruptState: - """Track the state of interrupt events raised by the user. - - Note, interrupt state is cleared after resuming. - - Attributes: - interrupts: Interrupts raised by the user. - context: Additional context associated with an interrupt event. - activated: True if agent is in an interrupt state, False otherwise. - """ - - interrupts: dict[str, Interrupt] = field(default_factory=dict) - context: dict[str, Any] = field(default_factory=dict) - activated: bool = False - - def activate(self, context: dict[str, Any] | None = None) -> None: - """Activate the interrupt state. - - Args: - context: Context associated with the interrupt event. - """ - self.context = context or {} - self.activated = True - - def deactivate(self) -> None: - """Deacitvate the interrupt state. - - Interrupts and context are cleared. - """ - self.interrupts = {} - self.context = {} - self.activated = False - - def to_dict(self) -> dict[str, Any]: - """Serialize to dict for session management.""" - return asdict(self) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> "InterruptState": - """Initiailize interrupt state from serialized interrupt state. - - Interrupt state can be serialized with the `to_dict` method. - """ - return cls( - interrupts={ - interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items() - }, - context=data["context"], - activated=data["activated"], - ) diff --git a/src/strands/interrupt.py b/src/strands/interrupt.py index f0ed52389..919927e1a 100644 --- a/src/strands/interrupt.py +++ b/src/strands/interrupt.py @@ -1,7 +1,11 @@ """Human-in-the-loop interrupt system for agent workflows.""" -from dataclasses import asdict, dataclass -from typing import Any +from dataclasses import asdict, dataclass, field +from typing import TYPE_CHECKING, Any, cast + +if TYPE_CHECKING: + from .types.agent import AgentInput + from .types.interrupt import InterruptResponseContent @dataclass @@ -31,3 +35,89 @@ class InterruptException(Exception): def __init__(self, interrupt: Interrupt) -> None: """Set the interrupt.""" self.interrupt = interrupt + + +@dataclass +class _InterruptState: + """Track the state of interrupt events raised by the user. + + Note, interrupt state is cleared after resuming. + + Attributes: + interrupts: Interrupts raised by the user. + context: Additional context associated with an interrupt event. + activated: True if agent is in an interrupt state, False otherwise. + """ + + interrupts: dict[str, Interrupt] = field(default_factory=dict) + context: dict[str, Any] = field(default_factory=dict) + activated: bool = False + + def activate(self, context: dict[str, Any] | None = None) -> None: + """Activate the interrupt state. + + Args: + context: Context associated with the interrupt event. + """ + self.context = context or {} + self.activated = True + + def deactivate(self) -> None: + """Deacitvate the interrupt state. + + Interrupts and context are cleared. + """ + self.interrupts = {} + self.context = {} + self.activated = False + + def resume(self, prompt: "AgentInput") -> None: + """Configure the interrupt state if resuming from an interrupt event. + + Args: + prompt: User responses if resuming from interrupt. + + Raises: + TypeError: If in interrupt state but user did not provide responses. + """ + if not self.activated: + return + + if not isinstance(prompt, list): + raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's") + + invalid_types = [ + content_type for content in prompt for content_type in content if content_type != "interruptResponse" + ] + if invalid_types: + raise TypeError( + f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's" + ) + + contents = cast(list["InterruptResponseContent"], prompt) + for content in contents: + interrupt_id = content["interruptResponse"]["interruptId"] + interrupt_response = content["interruptResponse"]["response"] + + if interrupt_id not in self.interrupts: + raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found") + + self.interrupts[interrupt_id].response = interrupt_response + + def to_dict(self) -> dict[str, Any]: + """Serialize to dict for session management.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "_InterruptState": + """Initiailize interrupt state from serialized interrupt state. + + Interrupt state can be serialized with the `to_dict` method. + """ + return cls( + interrupts={ + interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items() + }, + context=data["context"], + activated=data["activated"], + ) diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 4e72a1468..8b78ab448 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -7,7 +7,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Optional -from ..agent.interrupt import InterruptState +from ..interrupt import _InterruptState from .content import Message if TYPE_CHECKING: @@ -148,7 +148,7 @@ def to_dict(self) -> dict[str, Any]: def initialize_internal_state(self, agent: "Agent") -> None: """Initialize internal state of agent.""" if "interrupt_state" in self._internal_state: - agent._interrupt_state = InterruptState.from_dict(self._internal_state["interrupt_state"]) + agent._interrupt_state = _InterruptState.from_dict(self._internal_state["interrupt_state"]) @dataclass diff --git a/tests/strands/agent/test_interrupt.py b/tests/strands/agent/test_interrupt.py deleted file mode 100644 index e248c29a6..000000000 --- a/tests/strands/agent/test_interrupt.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest - -from strands.agent.interrupt import InterruptState -from strands.interrupt import Interrupt - - -@pytest.fixture -def interrupt(): - return Interrupt(id="test_id", name="test_name", reason="test reason") - - -def test_interrupt_activate(): - interrupt_state = InterruptState() - - interrupt_state.activate(context={"test": "context"}) - - assert interrupt_state.activated - - tru_context = interrupt_state.context - exp_context = {"test": "context"} - assert tru_context == exp_context - - -def test_interrupt_deactivate(): - interrupt_state = InterruptState(context={"test": "context"}, activated=True) - - interrupt_state.deactivate() - - assert not interrupt_state.activated - - tru_context = interrupt_state.context - exp_context = {} - assert tru_context == exp_context - - -def test_interrupt_state_to_dict(interrupt): - interrupt_state = InterruptState(interrupts={"test_id": interrupt}, context={"test": "context"}, activated=True) - - tru_data = interrupt_state.to_dict() - exp_data = { - "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, - "context": {"test": "context"}, - "activated": True, - } - assert tru_data == exp_data - - -def test_interrupt_state_from_dict(): - data = { - "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, - "context": {"test": "context"}, - "activated": True, - } - - tru_state = InterruptState.from_dict(data) - exp_state = InterruptState( - interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, - context={"test": "context"}, - activated=True, - ) - assert tru_state == exp_state diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 09bacbcb0..9335f91a8 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -6,7 +6,6 @@ import strands import strands.telemetry -from strands.agent.interrupt import InterruptState from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, @@ -14,7 +13,7 @@ HookRegistry, MessageAddedEvent, ) -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState from strands.telemetry.metrics import EventLoopMetrics from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry @@ -143,7 +142,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.event_loop_metrics = EventLoopMetrics() mock.hooks = hook_registry mock.tool_executor = tool_executor - mock._interrupt_state = InterruptState() + mock._interrupt_state = _InterruptState() return mock diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 81c3bf2d3..3daf41734 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -2,9 +2,8 @@ import pytest -from strands.agent.interrupt import InterruptState from strands.hooks import AgentInitializedEvent, BeforeInvocationEvent, BeforeToolCallEvent, HookRegistry -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState @pytest.fixture @@ -15,7 +14,7 @@ def registry(): @pytest.fixture def agent(): instance = unittest.mock.Mock() - instance._interrupt_state = InterruptState() + instance._interrupt_state = _InterruptState() return instance diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index ed0ec9072..451d0dd09 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -7,7 +7,7 @@ from strands.agent.agent import Agent from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager -from strands.agent.interrupt import InterruptState +from strands.interrupt import _InterruptState from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import ContentBlock from strands.types.exceptions import SessionException @@ -131,7 +131,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): assert len(agent.messages) == 1 assert agent.messages[0]["role"] == "user" assert agent.messages[0]["content"][0]["text"] == "Hello" - assert agent._interrupt_state == InterruptState(interrupts={}, context={"test": "init"}, activated=False) + assert agent._interrupt_state == _InterruptState(interrupts={}, context={"test": "init"}, activated=False) def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(session_manager): diff --git a/tests/strands/test_interrupt.py b/tests/strands/test_interrupt.py index 8ce972103..a45d524e4 100644 --- a/tests/strands/test_interrupt.py +++ b/tests/strands/test_interrupt.py @@ -1,6 +1,6 @@ import pytest -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState @pytest.fixture @@ -22,3 +22,109 @@ def test_interrupt_to_dict(interrupt): "response": {"response": "test"}, } assert tru_dict == exp_dict + + +def test_interrupt_state_activate(): + interrupt_state = _InterruptState() + + interrupt_state.activate(context={"test": "context"}) + + assert interrupt_state.activated + + tru_context = interrupt_state.context + exp_context = {"test": "context"} + assert tru_context == exp_context + + +def test_interrupt_state_deactivate(): + interrupt_state = _InterruptState(context={"test": "context"}, activated=True) + + interrupt_state.deactivate() + + assert not interrupt_state.activated + + tru_context = interrupt_state.context + exp_context = {} + assert tru_context == exp_context + + +def test_interrupt_state_to_dict(): + interrupt_state = _InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + context={"test": "context"}, + activated=True, + ) + + tru_data = interrupt_state.to_dict() + exp_data = { + "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, + "context": {"test": "context"}, + "activated": True, + } + assert tru_data == exp_data + + +def test_interrupt_state_from_dict(): + data = { + "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, + "context": {"test": "context"}, + "activated": True, + } + + tru_state = _InterruptState.from_dict(data) + exp_state = _InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + context={"test": "context"}, + activated=True, + ) + assert tru_state == exp_state + + +def test_interrupt_state_resume(): + interrupt_state = _InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + activated=True, + ) + + prompt = [ + { + "interruptResponse": { + "interruptId": "test_id", + "response": "test response", + } + } + ] + interrupt_state.resume(prompt) + + tru_response = interrupt_state.interrupts["test_id"].response + exp_response = "test response" + assert tru_response == exp_response + + +def test_interrupt_state_resumse_deactivated(): + interrupt_state = _InterruptState(activated=False) + interrupt_state.resume([]) + + +def test_interrupt_state_resume_invalid_prompt(): + interrupt_state = _InterruptState(activated=True) + + exp_message = r"prompt_type= \| must resume from interrupt with list of interruptResponse's" + with pytest.raises(TypeError, match=exp_message): + interrupt_state.resume("invalid") + + +def test_interrupt_state_resume_invalid_content(): + interrupt_state = _InterruptState(activated=True) + + exp_message = r"content_types=<\['text'\]> \| must resume from interrupt with list of interruptResponse's" + with pytest.raises(TypeError, match=exp_message): + interrupt_state.resume([{"text": "invalid"}]) + + +def test_interrupt_resume_invalid_id(): + interrupt_state = _InterruptState(activated=True) + + exp_message = r"interrupt_id= \| no interrupt found" + with pytest.raises(KeyError, match=exp_message): + interrupt_state.resume([{"interruptResponse": {"interruptId": "invalid", "response": None}}]) diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py index d25cf14bd..4d299a539 100644 --- a/tests/strands/tools/executors/conftest.py +++ b/tests/strands/tools/executors/conftest.py @@ -4,8 +4,8 @@ import pytest import strands -from strands.agent.interrupt import InterruptState from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent, HookRegistry +from strands.interrupt import _InterruptState from strands.tools.registry import ToolRegistry from strands.types.tools import ToolContext @@ -104,7 +104,7 @@ def agent(tool_registry, hook_registry): mock_agent = unittest.mock.Mock() mock_agent.tool_registry = tool_registry mock_agent.hooks = hook_registry - mock_agent._interrupt_state = InterruptState() + mock_agent._interrupt_state = _InterruptState() return mock_agent diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 0d5c65689..a2a4c6213 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -11,8 +11,7 @@ import strands from strands import Agent -from strands.agent.interrupt import InterruptState -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState from strands.types._events import ToolInterruptEvent, ToolResultEvent, ToolStreamEvent from strands.types.tools import AgentTool, ToolContext, ToolUse @@ -152,7 +151,7 @@ async def test_stream_interrupt(alist): tool_use = {"toolUseId": "test_tool_id"} mock_agent = MagicMock() - mock_agent._interrupt_state = InterruptState() + mock_agent._interrupt_state = _InterruptState() invocation_state = {"agent": mock_agent} @@ -179,7 +178,7 @@ async def test_stream_interrupt_resume(alist): tool_use = {"toolUseId": "test_tool_id"} mock_agent = MagicMock() - mock_agent._interrupt_state = InterruptState(interrupts={interrupt.id: interrupt}) + mock_agent._interrupt_state = _InterruptState(interrupts={interrupt.id: interrupt}) invocation_state = {"agent": mock_agent} diff --git a/tests/strands/types/test_interrupt.py b/tests/strands/types/test_interrupt.py index ade0fa5e8..ad31384b6 100644 --- a/tests/strands/types/test_interrupt.py +++ b/tests/strands/types/test_interrupt.py @@ -2,8 +2,7 @@ import pytest -from strands.agent.interrupt import InterruptState -from strands.interrupt import Interrupt, InterruptException +from strands.interrupt import Interrupt, InterruptException, _InterruptState from strands.types.interrupt import _Interruptible @@ -20,7 +19,7 @@ def interrupt(): @pytest.fixture def agent(): instance = unittest.mock.Mock() - instance._interrupt_state = InterruptState() + instance._interrupt_state = _InterruptState() return instance diff --git a/tests/strands/types/test_session.py b/tests/strands/types/test_session.py index 26d4062e4..3e5360742 100644 --- a/tests/strands/types/test_session.py +++ b/tests/strands/types/test_session.py @@ -3,8 +3,8 @@ from uuid import uuid4 from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager -from strands.agent.interrupt import InterruptState from strands.agent.state import AgentState +from strands.interrupt import _InterruptState from strands.types.session import ( Session, SessionAgent, @@ -101,7 +101,7 @@ def test_session_agent_from_agent(): agent.agent_id = "a1" agent.conversation_manager = unittest.mock.Mock(get_state=lambda: {"test": "conversation"}) agent.state = AgentState({"test": "state"}) - agent._interrupt_state = InterruptState(interrupts={}, context={}, activated=False) + agent._interrupt_state = _InterruptState(interrupts={}, context={}, activated=False) tru_session_agent = SessionAgent.from_agent(agent) exp_session_agent = SessionAgent( @@ -127,5 +127,5 @@ def test_session_agent_initialize_internal_state(): session_agent.initialize_internal_state(agent) tru_interrupt_state = agent._interrupt_state - exp_interrupt_state = InterruptState(interrupts={}, context={"test": "init"}, activated=False) + exp_interrupt_state = _InterruptState(interrupts={}, context={"test": "init"}, activated=False) assert tru_interrupt_state == exp_interrupt_state From 57e2081b7bdb9a2fbaa5af11026a67f5357fa025 Mon Sep 17 00:00:00 2001 From: Mackenzie Zastrow <3211021+zastrowm@users.noreply.github.com> Date: Wed, 12 Nov 2025 10:58:17 -0500 Subject: [PATCH 10/13] fix: Don't hang when MCP server returns 5xx (#1169) Fixes #995 where if a MCP tool_call receives a 5XX error from the server, the call hangs and never ends. The root cause is that Anthropic's MCP client - on receiving a 5XX - bubbles up an exception that ends up cancelling all TaskGroup tasks which results in the session/client/asyncio loop being torn down and the tool_call never resolves, thus the hang. The fix is two fold: - Detect that the situation occurs and trigger a close `close_future` future - Update all background_invokes to eagerly bail on `close_future` being triggered --------- Co-authored-by: Mackenzie Zastrow --- src/strands/tools/mcp/mcp_client.py | 71 +++++++++++++++++++++++------ tests_integ/mcp/test_mcp_client.py | 67 +++++++++++++++++++++++++++ 2 files changed, 125 insertions(+), 13 deletions(-) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 2fe006466..b16b9c2b4 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -119,10 +119,12 @@ def __init__( mcp_instrumentation() self._session_id = uuid.uuid4() self._log_debug_with_thread("initializing MCPClient connection") - # Main thread blocks until future completesock + # Main thread blocks until future completes self._init_future: futures.Future[None] = futures.Future() + # Set within the inner loop as it needs the asyncio loop + self._close_future: asyncio.futures.Future[None] | None = None + self._close_exception: None | Exception = None # Do not want to block other threads while close event is false - self._close_event = asyncio.Event() self._transport_callable = transport_callable self._background_thread: threading.Thread | None = None @@ -288,11 +290,12 @@ def stop( - _background_thread: Thread running the async event loop - _background_thread_session: MCP ClientSession (auto-closed by context manager) - _background_thread_event_loop: AsyncIO event loop in background thread - - _close_event: AsyncIO event to signal thread shutdown + - _close_future: AsyncIO future to signal thread shutdown + - _close_exception: Exception that caused the background thread shutdown; None if a normal shutdown occurred. - _init_future: Future for initialization synchronization Cleanup order: - 1. Signal close event to background thread (if session initialized) + 1. Signal close future to background thread (if session initialized) 2. Wait for background thread to complete 3. Reset all state for reuse @@ -303,13 +306,14 @@ def stop( """ self._log_debug_with_thread("exiting MCPClient context") - # Only try to signal close event if we have a background thread + # Only try to signal close future if we have a background thread if self._background_thread is not None: - # Signal close event if event loop exists + # Signal close future if event loop exists if self._background_thread_event_loop is not None: async def _set_close_event() -> None: - self._close_event.set() + if self._close_future and not self._close_future.done(): + self._close_future.set_result(None) # Not calling _invoke_on_background_thread since the session does not need to exist # we only need the thread and event loop to exist. @@ -317,11 +321,11 @@ async def _set_close_event() -> None: self._log_debug_with_thread("waiting for background thread to join") self._background_thread.join() + self._log_debug_with_thread("background thread is closed, MCPClient context exited") # Reset fields to allow instance reuse self._init_future = futures.Future() - self._close_event = asyncio.Event() self._background_thread = None self._background_thread_session = None self._background_thread_event_loop = None @@ -330,6 +334,11 @@ async def _set_close_event() -> None: self._tool_provider_started = False self._consumers = set() + if self._close_exception: + exception = self._close_exception + self._close_exception = None + raise RuntimeError("Connection to the MCP server was closed") from exception + def list_tools_sync( self, pagination_token: str | None = None, @@ -563,6 +572,10 @@ async def _async_background_thread(self) -> None: signals readiness to the main thread, and waits for a close signal. """ self._log_debug_with_thread("starting async background thread for MCP connection") + + # Initialized here so that it has the asyncio loop + self._close_future = asyncio.Future() + try: async with self._transport_callable() as (read_stream, write_stream, *_): self._log_debug_with_thread("transport connection established") @@ -583,8 +596,9 @@ async def _async_background_thread(self) -> None: self._log_debug_with_thread("waiting for close signal") # Keep background thread running until signaled to close. - # Thread is not blocked as this is an asyncio.Event not a threading.Event - await self._close_event.wait() + # Thread is not blocked as this a future + await self._close_future + self._log_debug_with_thread("close signal received") except Exception as e: # If we encounter an exception and the future is still running, @@ -592,6 +606,12 @@ async def _async_background_thread(self) -> None: if not self._init_future.done(): self._init_future.set_exception(e) else: + # _close_future is automatically cancelled by the framework which doesn't provide us with the useful + # exception, so instead we store the exception in a different field where stop() can read it + self._close_exception = e + if self._close_future and not self._close_future.done(): + self._close_future.set_result(None) + self._log_debug_with_thread( "encountered exception on background thread after initialization %s", str(e) ) @@ -601,7 +621,7 @@ def _background_task(self) -> None: This method creates a new event loop for the background thread, sets it as the current event loop, and runs the async_background_thread - coroutine until completion. In this case "until completion" means until the _close_event is set. + coroutine until completion. In this case "until completion" means until the _close_future is resolved. This allows for a long-running event loop. """ self._log_debug_with_thread("setting up background task event loop") @@ -699,9 +719,34 @@ def _log_debug_with_thread(self, msg: str, *args: Any, **kwargs: Any) -> None: ) def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures.Future[T]: - if self._background_thread_session is None or self._background_thread_event_loop is None: + # save a reference to this so that even if it's reset we have the original + close_future = self._close_future + + if ( + self._background_thread_session is None + or self._background_thread_event_loop is None + or close_future is None + ): raise MCPClientInitializationError("the client session was not initialized") - return asyncio.run_coroutine_threadsafe(coro=coro, loop=self._background_thread_event_loop) + + async def run_async() -> T: + # Fix for strands-agents/sdk-python/issues/995 - cancel all pending invocations if/when the session closes + invoke_event = asyncio.create_task(coro) + tasks: list[asyncio.Task | asyncio.Future] = [ + invoke_event, + close_future, + ] + + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + if done.pop() == close_future: + self._log_debug_with_thread("event loop for the server closed before the invoke completed") + raise RuntimeError("Connection to the MCP server was closed") + else: + return await invoke_event + + invoke_future = asyncio.run_coroutine_threadsafe(coro=run_async(), loop=self._background_thread_event_loop) + return invoke_future def _should_include_tool(self, tool: MCPAgentTool) -> bool: """Check if a tool should be included based on constructor filters.""" diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 2c9bb73e1..35cfd7e86 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -420,3 +420,70 @@ def transport_callback() -> MCPTransport: result = await streamable_http_client.call_tool_async(tool_use_id="123", name="timeout_tool") assert result["status"] == "error" assert result["content"][0]["text"] == "Tool execution failed: Connection closed" + + +def start_5xx_proxy_for_tool_calls(target_url: str, proxy_port: int): + """Starts a proxy that throws a 5XX when a tool call is invoked""" + import aiohttp + from aiohttp import web + + async def proxy_handler(request): + url = f"{target_url}{request.path_qs}" + + async with aiohttp.ClientSession() as session: + data = await request.read() + + if "tools/call" in f"{data}": + return web.Response(status=500, text="Internal Server Error") + + async with session.request( + method=request.method, url=url, headers=request.headers, data=data, allow_redirects=False + ) as resp: + print(f"Got request to {url} {data}") + response = web.StreamResponse(status=resp.status, headers=resp.headers) + await response.prepare(request) + + async for chunk in resp.content.iter_chunked(8192): + await response.write(chunk) + + return response + + app = web.Application() + app.router.add_route("*", "/{path:.*}", proxy_handler) + + web.run_app(app, host="127.0.0.1", port=proxy_port) + + +@pytest.mark.asyncio +async def test_streamable_http_mcp_client_with_500_error(): + import asyncio + import multiprocessing + + server_thread = threading.Thread( + target=start_comprehensive_mcp_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True + ) + server_thread.start() + + proxy_process = multiprocessing.Process( + target=start_5xx_proxy_for_tool_calls, kwargs={"target_url": "http://127.0.0.1:8001", "proxy_port": 8002} + ) + proxy_process.start() + + try: + await asyncio.sleep(2) # wait for server to startup completely + + def transport_callback() -> MCPTransport: + return streamablehttp_client(url="http://127.0.0.1:8002/mcp") + + streamable_http_client = MCPClient(transport_callback) + with pytest.raises(RuntimeError, match="Connection to the MCP server was closed"): + with streamable_http_client: + result = await streamable_http_client.call_tool_async( + tool_use_id="123", name="calculator", arguments={"x": 3, "y": 4} + ) + finally: + proxy_process.terminate() + proxy_process.join() + + assert result["status"] == "error" + assert result["content"][0]["text"] == "Tool execution failed: Connection to the MCP server was closed" From 8cae18cdc9a70cd892188485c2df47698a17af55 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Wed, 12 Nov 2025 18:26:20 +0200 Subject: [PATCH 11/13] fix(models): allow setter on system_prompt and system_prompt_content (#1171) --- src/strands/agent/agent.py | 33 +++++++++++++++++++++++++++++-- tests/strands/agent/test_agent.py | 31 +++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 2 deletions(-) diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index b7633d5e8..e13b9f6d8 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -287,8 +287,8 @@ def __init__( """ self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model self.messages = messages if messages is not None else [] - # initializing self.system_prompt for backwards compatibility - self.system_prompt, self._system_prompt_content = self._initialize_system_prompt(system_prompt) + # initializing self._system_prompt for backwards compatibility + self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(system_prompt) self._default_structured_output_model = structured_output_model self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) self.name = name or _DEFAULT_AGENT_NAME @@ -365,6 +365,35 @@ def __init__( self.hooks.add_hook(hook) self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) + @property + def system_prompt(self) -> str | None: + """Get the system prompt as a string for backwards compatibility. + + Returns the system prompt as a concatenated string when it contains text content, + or None if no text content is present. This maintains backwards compatibility + with existing code that expects system_prompt to be a string. + + Returns: + The system prompt as a string, or None if no text content exists. + """ + return self._system_prompt + + @system_prompt.setter + def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None: + """Set the system prompt and update internal content representation. + + Accepts either a string or list of SystemContentBlock objects. + When set, both the backwards-compatible string representation and the internal + content block representation are updated to maintain consistency. + + Args: + value: System prompt as string, list of SystemContentBlock objects, or None. + - str: Simple text prompt (most common use case) + - list[SystemContentBlock]: Content blocks with features like caching + - None: Clear the system prompt + """ + self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value) + @property def tool(self) -> ToolCaller: """Call tool as a function. diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index b96a04b21..d04f57948 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1221,6 +1221,37 @@ async def test_stream_async_multi_modal_input(mock_model, agent, agenerator, ali assert tru_message == exp_message +def test_system_prompt_setter_string(): + """Test that setting system_prompt with string updates both internal fields.""" + agent = Agent(system_prompt="initial prompt") + + agent.system_prompt = "updated prompt" + + assert agent.system_prompt == "updated prompt" + assert agent._system_prompt_content == [{"text": "updated prompt"}] + + +def test_system_prompt_setter_list(): + """Test that setting system_prompt with list updates both internal fields.""" + agent = Agent() + + content_blocks = [{"text": "You are helpful"}, {"cache_control": {"type": "ephemeral"}}] + agent.system_prompt = content_blocks + + assert agent.system_prompt == "You are helpful" + assert agent._system_prompt_content == content_blocks + + +def test_system_prompt_setter_none(): + """Test that setting system_prompt to None clears both internal fields.""" + agent = Agent(system_prompt="initial prompt") + + agent.system_prompt = None + + assert agent.system_prompt is None + assert agent._system_prompt_content is None + + @pytest.mark.asyncio async def test_stream_async_passes_invocation_state(agent, mock_model, mock_event_loop_cycle, agenerator, alist): mock_model.mock_stream.side_effect = [ From cee5145068b7a1fa991452c4dd150f956717060b Mon Sep 17 00:00:00 2001 From: Anirudh Konduru Date: Fri, 14 Nov 2025 15:14:13 -0500 Subject: [PATCH 12/13] feat: allow setting a timeout when creating MCPAgentTool (#1184) --- src/strands/tools/mcp/mcp_agent_tool.py | 12 +++++++- .../strands/tools/mcp/test_mcp_agent_tool.py | 29 ++++++++++++++++++- 2 files changed, 39 insertions(+), 2 deletions(-) diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index af0c069a1..bedd93f24 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -6,6 +6,7 @@ """ import logging +from datetime import timedelta from typing import TYPE_CHECKING, Any from mcp.types import Tool as MCPTool @@ -28,7 +29,13 @@ class MCPAgentTool(AgentTool): seamlessly within the agent framework. """ - def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: str | None = None) -> None: + def __init__( + self, + mcp_tool: MCPTool, + mcp_client: "MCPClient", + name_override: str | None = None, + timeout: timedelta | None = None, + ) -> None: """Initialize a new MCPAgentTool instance. Args: @@ -36,12 +43,14 @@ def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: st mcp_client: The MCP server connection to use for tool invocation name_override: Optional name to use for the agent tool (for disambiguation) If None, uses the original MCP tool name + timeout: Optional timeout duration for tool execution """ super().__init__() logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name) self.mcp_tool = mcp_tool self.mcp_client = mcp_client self._agent_tool_name = name_override or mcp_tool.name + self.timeout = timeout @property def tool_name(self) -> str: @@ -105,5 +114,6 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw tool_use_id=tool_use["toolUseId"], name=self.mcp_tool.name, # Use original MCP name for server communication arguments=tool_use["input"], + read_timeout_seconds=self.timeout, ) yield ToolResultEvent(result) diff --git a/tests/strands/tools/mcp/test_mcp_agent_tool.py b/tests/strands/tools/mcp/test_mcp_agent_tool.py index 442a9919b..81a2d9afb 100644 --- a/tests/strands/tools/mcp/test_mcp_agent_tool.py +++ b/tests/strands/tools/mcp/test_mcp_agent_tool.py @@ -1,3 +1,4 @@ +from datetime import timedelta from unittest.mock import MagicMock import pytest @@ -88,5 +89,31 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist): assert tru_events == exp_events mock_mcp_client.call_tool_async.assert_called_once_with( - tool_use_id="test-123", name="test_tool", arguments={"param": "value"} + tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=None + ) + + +def test_timeout_initialization(mock_mcp_tool, mock_mcp_client): + timeout = timedelta(seconds=30) + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout) + assert agent_tool.timeout == timeout + + +def test_timeout_default_none(mock_mcp_tool, mock_mcp_client): + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client) + assert agent_tool.timeout is None + + +@pytest.mark.asyncio +async def test_stream_with_timeout(mock_mcp_tool, mock_mcp_client, alist): + timeout = timedelta(seconds=45) + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout) + tool_use = {"toolUseId": "test-456", "name": "test_tool", "input": {"param": "value"}} + + tru_events = await alist(agent_tool.stream(tool_use, {})) + exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)] + + assert tru_events == exp_events + mock_mcp_client.call_tool_async.assert_called_once_with( + tool_use_id="test-456", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=timeout ) From ded09346bbf689b0056157316830c32f1e1d3ad0 Mon Sep 17 00:00:00 2001 From: Dean Schmigelski Date: Mon, 17 Nov 2025 16:36:34 +0200 Subject: [PATCH 13/13] fix(litellm): add validation for stream parameter in LiteLLM (#1183) --- src/strands/models/litellm.py | 2 ++ tests/strands/models/test_litellm.py | 10 ++++++++++ 2 files changed, 12 insertions(+) diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index f2480c8d8..17f1bbb94 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -272,6 +272,8 @@ async def stream( logger.debug("invoking model") try: + if kwargs.get("stream") is False: + raise ValueError("stream parameter cannot be explicitly set to False") response = await litellm.acompletion(**self.client_args, **request) except ContextWindowExceededError as e: logger.warning("litellm client raised context window overflow") diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index f56438cf5..aafee1d17 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -408,6 +408,16 @@ async def test_context_window_maps_to_typed_exception(litellm_acompletion, model pass +@pytest.mark.asyncio +async def test_stream_raises_error_when_stream_is_false(model): + """Test that stream raises ValueError when stream parameter is explicitly False.""" + messages = [{"role": "user", "content": [{"text": "test"}]}] + + with pytest.raises(ValueError, match="stream parameter cannot be explicitly set to False"): + async for _ in model.stream(messages, stream=False): + pass + + def test_format_request_messages_with_system_prompt_content(): """Test format_request_messages with system_prompt_content parameter.""" messages = [{"role": "user", "content": [{"text": "Hello"}]}]