diff --git a/src/strands/_async.py b/src/strands/_async.py index 976487c374..141ca71b7e 100644 --- a/src/strands/_async.py +++ b/src/strands/_async.py @@ -1,6 +1,7 @@ """Private async execution utilities.""" import asyncio +import contextvars from concurrent.futures import ThreadPoolExecutor from typing import Awaitable, Callable, TypeVar @@ -27,5 +28,6 @@ def execute() -> T: return asyncio.run(execute_async()) with ThreadPoolExecutor() as executor: - future = executor.submit(execute) + context = contextvars.copy_context() + future = executor.submit(context.run, execute) return future.result() diff --git a/src/strands/agent/agent.py b/src/strands/agent/agent.py index 8137f1887a..e13b9f6d8b 100644 --- a/src/strands/agent/agent.py +++ b/src/strands/agent/agent.py @@ -46,6 +46,7 @@ HookRegistry, MessageAddedEvent, ) +from ..interrupt import _InterruptState from ..models.bedrock import BedrockModel from ..models.model import Model from ..session.session_manager import SessionManager @@ -60,7 +61,6 @@ from ..types.agent import AgentInput from ..types.content import ContentBlock, Message, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException -from ..types.interrupt import InterruptResponseContent from ..types.tools import ToolResult, ToolUse from ..types.traces import AttributeValue from .agent_result import AgentResult @@ -68,7 +68,6 @@ ConversationManager, SlidingWindowConversationManager, ) -from .interrupt import InterruptState from .state import AgentState logger = logging.getLogger(__name__) @@ -171,22 +170,21 @@ async def acall() -> ToolResult: self._agent._interrupt_state.deactivate() raise RuntimeError("cannot raise interrupt in direct tool call") - return tool_results[0] + tool_result = tool_results[0] - tool_result = run_async(acall) + if record_direct_tool_call is not None: + should_record_direct_tool_call = record_direct_tool_call + else: + should_record_direct_tool_call = self._agent.record_direct_tool_call - if record_direct_tool_call is not None: - should_record_direct_tool_call = record_direct_tool_call - else: - should_record_direct_tool_call = self._agent.record_direct_tool_call + if should_record_direct_tool_call: + # Create a record of this tool execution in the message history + await self._agent._record_tool_execution(tool_use, tool_result, user_message_override) - if should_record_direct_tool_call: - # Create a record of this tool execution in the message history - self._agent._record_tool_execution(tool_use, tool_result, user_message_override) + return tool_result - # Apply window management + tool_result = run_async(acall) self._agent.conversation_manager.apply_management(self._agent) - return tool_result return caller @@ -289,8 +287,8 @@ def __init__( """ self.model = BedrockModel() if not model else BedrockModel(model_id=model) if isinstance(model, str) else model self.messages = messages if messages is not None else [] - # initializing self.system_prompt for backwards compatibility - self.system_prompt, self._system_prompt_content = self._initialize_system_prompt(system_prompt) + # initializing self._system_prompt for backwards compatibility + self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(system_prompt) self._default_structured_output_model = structured_output_model self.agent_id = _identifier.validate(agent_id or _DEFAULT_AGENT_ID, _identifier.Identifier.AGENT) self.name = name or _DEFAULT_AGENT_NAME @@ -353,7 +351,7 @@ def __init__( self.hooks = HookRegistry() - self._interrupt_state = InterruptState() + self._interrupt_state = _InterruptState() # Initialize session management functionality self._session_manager = session_manager @@ -367,6 +365,35 @@ def __init__( self.hooks.add_hook(hook) self.hooks.invoke_callbacks(AgentInitializedEvent(agent=self)) + @property + def system_prompt(self) -> str | None: + """Get the system prompt as a string for backwards compatibility. + + Returns the system prompt as a concatenated string when it contains text content, + or None if no text content is present. This maintains backwards compatibility + with existing code that expects system_prompt to be a string. + + Returns: + The system prompt as a string, or None if no text content exists. + """ + return self._system_prompt + + @system_prompt.setter + def system_prompt(self, value: str | list[SystemContentBlock] | None) -> None: + """Set the system prompt and update internal content representation. + + Accepts either a string or list of SystemContentBlock objects. + When set, both the backwards-compatible string representation and the internal + content block representation are updated to maintain consistency. + + Args: + value: System prompt as string, list of SystemContentBlock objects, or None. + - str: Simple text prompt (most common use case) + - list[SystemContentBlock]: Content blocks with features like caching + - None: Clear the system prompt + """ + self._system_prompt, self._system_prompt_content = self._initialize_system_prompt(value) + @property def tool(self) -> ToolCaller: """Call tool as a function. @@ -534,7 +561,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu category=DeprecationWarning, stacklevel=2, ) - self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self)) with self.tracer.tracer.start_as_current_span( "execute_structured_output", kind=trace_api.SpanKind.CLIENT ) as structured_output_span: @@ -542,7 +569,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu if not self.messages and not prompt: raise ValueError("No conversation history or prompt provided") - temp_messages: Messages = self.messages + self._convert_prompt_to_messages(prompt) + temp_messages: Messages = self.messages + await self._convert_prompt_to_messages(prompt) structured_output_span.set_attributes( { @@ -575,7 +602,7 @@ async def structured_output_async(self, output_model: Type[T], prompt: AgentInpu return event["output"] finally: - self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self)) def cleanup(self) -> None: """Clean up resources used by the agent. @@ -641,7 +668,7 @@ async def stream_async( yield event["data"] ``` """ - self._resume_interrupt(prompt) + self._interrupt_state.resume(prompt) merged_state = {} if kwargs: @@ -658,7 +685,7 @@ async def stream_async( callback_handler = kwargs.get("callback_handler", self.callback_handler) # Process input and get message to add (if any) - messages = self._convert_prompt_to_messages(prompt) + messages = await self._convert_prompt_to_messages(prompt) self.trace_span = self._start_agent_trace_span(messages) @@ -684,38 +711,6 @@ async def stream_async( self._end_agent_trace_span(error=e) raise - def _resume_interrupt(self, prompt: AgentInput) -> None: - """Configure the interrupt state if resuming from an interrupt event. - - Args: - prompt: User responses if resuming from interrupt. - - Raises: - TypeError: If in interrupt state but user did not provide responses. - """ - if not self._interrupt_state.activated: - return - - if not isinstance(prompt, list): - raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's") - - invalid_types = [ - content_type for content in prompt for content_type in content if content_type != "interruptResponse" - ] - if invalid_types: - raise TypeError( - f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's" - ) - - for content in cast(list[InterruptResponseContent], prompt): - interrupt_id = content["interruptResponse"]["interruptId"] - interrupt_response = content["interruptResponse"]["response"] - - if interrupt_id not in self._interrupt_state.interrupts: - raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found") - - self._interrupt_state.interrupts[interrupt_id].response = interrupt_response - async def _run_loop( self, messages: Messages, @@ -732,13 +727,13 @@ async def _run_loop( Yields: Events from the event loop cycle. """ - self.hooks.invoke_callbacks(BeforeInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(BeforeInvocationEvent(agent=self)) try: yield InitEventLoopEvent() for message in messages: - self._append_message(message) + await self._append_message(message) structured_output_context = StructuredOutputContext( structured_output_model or self._default_structured_output_model @@ -764,7 +759,7 @@ async def _run_loop( finally: self.conversation_manager.apply_management(self) - self.hooks.invoke_callbacks(AfterInvocationEvent(agent=self)) + await self.hooks.invoke_callbacks_async(AfterInvocationEvent(agent=self)) async def _execute_event_loop_cycle( self, invocation_state: dict[str, Any], structured_output_context: StructuredOutputContext | None = None @@ -813,7 +808,7 @@ async def _execute_event_loop_cycle( if structured_output_context: structured_output_context.cleanup(self.tool_registry) - def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: + async def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: if self._interrupt_state.activated: return [] @@ -828,7 +823,7 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: tool_use_ids = [ content["toolUse"]["toolUseId"] for content in self.messages[-1]["content"] if "toolUse" in content ] - self._append_message( + await self._append_message( { "role": "user", "content": generate_missing_tool_result_content(tool_use_ids), @@ -859,7 +854,7 @@ def _convert_prompt_to_messages(self, prompt: AgentInput) -> Messages: raise ValueError("Input prompt must be of type: `str | list[Contentblock] | Messages | None`.") return messages - def _record_tool_execution( + async def _record_tool_execution( self, tool: ToolUse, tool_result: ToolResult, @@ -919,10 +914,10 @@ def _record_tool_execution( } # Add to message history - self._append_message(user_msg) - self._append_message(tool_use_msg) - self._append_message(tool_result_msg) - self._append_message(assistant_msg) + await self._append_message(user_msg) + await self._append_message(tool_use_msg) + await self._append_message(tool_result_msg) + await self._append_message(assistant_msg) def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: """Starts a trace span for the agent. @@ -938,6 +933,7 @@ def _start_agent_trace_span(self, messages: Messages) -> trace_api.Span: tools=self.tool_names, system_prompt=self.system_prompt, custom_trace_attributes=self.trace_attributes, + tools_config=self.tool_registry.get_all_tools_config(), ) def _end_agent_trace_span( @@ -1007,10 +1003,10 @@ def _initialize_system_prompt( else: return None, None - def _append_message(self, message: Message) -> None: + async def _append_message(self, message: Message) -> None: """Appends a message to the agent's list of messages and invokes the callbacks for the MessageCreatedEvent.""" self.messages.append(message) - self.hooks.invoke_callbacks(MessageAddedEvent(agent=self, message=message)) + await self.hooks.invoke_callbacks_async(MessageAddedEvent(agent=self, message=message)) def _redact_user_content(self, content: list[ContentBlock], redact_message: str) -> list[ContentBlock]: """Redact user content preserving toolResult blocks. diff --git a/src/strands/agent/interrupt.py b/src/strands/agent/interrupt.py deleted file mode 100644 index 3cec1541be..0000000000 --- a/src/strands/agent/interrupt.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Track the state of interrupt events raised by the user for human-in-the-loop workflows.""" - -from dataclasses import asdict, dataclass, field -from typing import Any - -from ..interrupt import Interrupt - - -@dataclass -class InterruptState: - """Track the state of interrupt events raised by the user. - - Note, interrupt state is cleared after resuming. - - Attributes: - interrupts: Interrupts raised by the user. - context: Additional context associated with an interrupt event. - activated: True if agent is in an interrupt state, False otherwise. - """ - - interrupts: dict[str, Interrupt] = field(default_factory=dict) - context: dict[str, Any] = field(default_factory=dict) - activated: bool = False - - def activate(self, context: dict[str, Any] | None = None) -> None: - """Activate the interrupt state. - - Args: - context: Context associated with the interrupt event. - """ - self.context = context or {} - self.activated = True - - def deactivate(self) -> None: - """Deacitvate the interrupt state. - - Interrupts and context are cleared. - """ - self.interrupts = {} - self.context = {} - self.activated = False - - def to_dict(self) -> dict[str, Any]: - """Serialize to dict for session management.""" - return asdict(self) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> "InterruptState": - """Initiailize interrupt state from serialized interrupt state. - - Interrupt state can be serialized with the `to_dict` method. - """ - return cls( - interrupts={ - interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items() - }, - context=data["context"], - activated=data["activated"], - ) diff --git a/src/strands/event_loop/event_loop.py b/src/strands/event_loop/event_loop.py index 66174c09fc..562de24b87 100644 --- a/src/strands/event_loop/event_loop.py +++ b/src/strands/event_loop/event_loop.py @@ -227,7 +227,7 @@ async def event_loop_cycle( ) structured_output_context.set_forced_mode() logger.debug("Forcing structured output tool") - agent._append_message( + await agent._append_message( {"role": "user", "content": [{"text": "You must format the previous response as structured output."}]} ) @@ -322,7 +322,7 @@ async def _handle_model_execution( model_id=model_id, ) with trace_api.use_span(model_invoke_span): - agent.hooks.invoke_callbacks( + await agent.hooks.invoke_callbacks_async( BeforeModelCallEvent( agent=agent, ) @@ -347,7 +347,7 @@ async def _handle_model_execution( stop_reason, message, usage, metrics = event["stop"] invocation_state.setdefault("request_state", {}) - agent.hooks.invoke_callbacks( + await agent.hooks.invoke_callbacks_async( AfterModelCallEvent( agent=agent, stop_response=AfterModelCallEvent.ModelStopResponse( @@ -368,7 +368,7 @@ async def _handle_model_execution( if model_invoke_span: tracer.end_span_with_error(model_invoke_span, str(e), e) - agent.hooks.invoke_callbacks( + await agent.hooks.invoke_callbacks_async( AfterModelCallEvent( agent=agent, exception=e, @@ -402,7 +402,7 @@ async def _handle_model_execution( # Add the response message to the conversation agent.messages.append(message) - agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=message)) + await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=message)) # Update metrics agent.event_loop_metrics.update_usage(usage) @@ -507,7 +507,7 @@ async def _handle_tool_execution( } agent.messages.append(tool_result_message) - agent.hooks.invoke_callbacks(MessageAddedEvent(agent=agent, message=tool_result_message)) + await agent.hooks.invoke_callbacks_async(MessageAddedEvent(agent=agent, message=tool_result_message)) yield ToolResultMessageEvent(message=tool_result_message) diff --git a/src/strands/hooks/registry.py b/src/strands/hooks/registry.py index 564be85cb9..1efc0bf5b0 100644 --- a/src/strands/hooks/registry.py +++ b/src/strands/hooks/registry.py @@ -7,9 +7,10 @@ via hook provider objects. """ +import inspect import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Generator, Generic, Protocol, Type, TypeVar +from typing import TYPE_CHECKING, Any, Awaitable, Generator, Generic, Protocol, Type, TypeVar from ..interrupt import Interrupt, InterruptException @@ -122,10 +123,15 @@ class HookCallback(Protocol, Generic[TEvent]): ```python def my_callback(event: StartRequestEvent) -> None: print(f"Request started for agent: {event.agent.name}") + + # Or + + async def my_callback(event: StartRequestEvent) -> None: + # await an async operation ``` """ - def __call__(self, event: TEvent) -> None: + def __call__(self, event: TEvent) -> None | Awaitable[None]: """Handle a hook event. Args: @@ -164,6 +170,10 @@ def my_handler(event: StartRequestEvent): registry.add_callback(StartRequestEvent, my_handler) ``` """ + # Related issue: https://github.com/strands-agents/sdk-python/issues/330 + if event_type.__name__ == "AgentInitializedEvent" and inspect.iscoroutinefunction(callback): + raise ValueError("AgentInitializedEvent can only be registered with a synchronous callback") + callbacks = self._registered_callbacks.setdefault(event_type, []) callbacks.append(callback) @@ -189,6 +199,52 @@ def register_hooks(self, registry: HookRegistry): """ hook.register_hooks(self) + async def invoke_callbacks_async(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]: + """Invoke all registered callbacks for the given event. + + This method finds all callbacks registered for the event's type and + invokes them in the appropriate order. For events with should_reverse_callbacks=True, + callbacks are invoked in reverse registration order. Any exceptions raised by callback + functions will propagate to the caller. + + Additionally, this method aggregates interrupts raised by the user to instantiate human-in-the-loop workflows. + + Args: + event: The event to dispatch to registered callbacks. + + Returns: + The event dispatched to registered callbacks and any interrupts raised by the user. + + Raises: + ValueError: If interrupt name is used more than once. + + Example: + ```python + event = StartRequestEvent(agent=my_agent) + await registry.invoke_callbacks_async(event) + ``` + """ + interrupts: dict[str, Interrupt] = {} + + for callback in self.get_callbacks_for(event): + try: + if inspect.iscoroutinefunction(callback): + await callback(event) + else: + callback(event) + + except InterruptException as exception: + interrupt = exception.interrupt + if interrupt.name in interrupts: + message = f"interrupt_name=<{interrupt.name}> | interrupt name used more than once" + logger.error(message) + raise ValueError(message) from exception + + # Each callback is allowed to raise their own interrupt. + interrupts[interrupt.name] = interrupt + + return event, list(interrupts.values()) + def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Interrupt]]: """Invoke all registered callbacks for the given event. @@ -206,6 +262,7 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte The event dispatched to registered callbacks and any interrupts raised by the user. Raises: + RuntimeError: If at least one callback is async. ValueError: If interrupt name is used more than once. Example: @@ -214,9 +271,13 @@ def invoke_callbacks(self, event: TInvokeEvent) -> tuple[TInvokeEvent, list[Inte registry.invoke_callbacks(event) ``` """ + callbacks = list(self.get_callbacks_for(event)) interrupts: dict[str, Interrupt] = {} - for callback in self.get_callbacks_for(event): + if any(inspect.iscoroutinefunction(callback) for callback in callbacks): + raise RuntimeError(f"event=<{event}> | use invoke_callbacks_async to invoke async callback") + + for callback in callbacks: try: callback(event) except InterruptException as exception: diff --git a/src/strands/interrupt.py b/src/strands/interrupt.py index f0ed52389c..919927e1af 100644 --- a/src/strands/interrupt.py +++ b/src/strands/interrupt.py @@ -1,7 +1,11 @@ """Human-in-the-loop interrupt system for agent workflows.""" -from dataclasses import asdict, dataclass -from typing import Any +from dataclasses import asdict, dataclass, field +from typing import TYPE_CHECKING, Any, cast + +if TYPE_CHECKING: + from .types.agent import AgentInput + from .types.interrupt import InterruptResponseContent @dataclass @@ -31,3 +35,89 @@ class InterruptException(Exception): def __init__(self, interrupt: Interrupt) -> None: """Set the interrupt.""" self.interrupt = interrupt + + +@dataclass +class _InterruptState: + """Track the state of interrupt events raised by the user. + + Note, interrupt state is cleared after resuming. + + Attributes: + interrupts: Interrupts raised by the user. + context: Additional context associated with an interrupt event. + activated: True if agent is in an interrupt state, False otherwise. + """ + + interrupts: dict[str, Interrupt] = field(default_factory=dict) + context: dict[str, Any] = field(default_factory=dict) + activated: bool = False + + def activate(self, context: dict[str, Any] | None = None) -> None: + """Activate the interrupt state. + + Args: + context: Context associated with the interrupt event. + """ + self.context = context or {} + self.activated = True + + def deactivate(self) -> None: + """Deacitvate the interrupt state. + + Interrupts and context are cleared. + """ + self.interrupts = {} + self.context = {} + self.activated = False + + def resume(self, prompt: "AgentInput") -> None: + """Configure the interrupt state if resuming from an interrupt event. + + Args: + prompt: User responses if resuming from interrupt. + + Raises: + TypeError: If in interrupt state but user did not provide responses. + """ + if not self.activated: + return + + if not isinstance(prompt, list): + raise TypeError(f"prompt_type={type(prompt)} | must resume from interrupt with list of interruptResponse's") + + invalid_types = [ + content_type for content in prompt for content_type in content if content_type != "interruptResponse" + ] + if invalid_types: + raise TypeError( + f"content_types=<{invalid_types}> | must resume from interrupt with list of interruptResponse's" + ) + + contents = cast(list["InterruptResponseContent"], prompt) + for content in contents: + interrupt_id = content["interruptResponse"]["interruptId"] + interrupt_response = content["interruptResponse"]["response"] + + if interrupt_id not in self.interrupts: + raise KeyError(f"interrupt_id=<{interrupt_id}> | no interrupt found") + + self.interrupts[interrupt_id].response = interrupt_response + + def to_dict(self) -> dict[str, Any]: + """Serialize to dict for session management.""" + return asdict(self) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "_InterruptState": + """Initiailize interrupt state from serialized interrupt state. + + Interrupt state can be serialized with the `to_dict` method. + """ + return cls( + interrupts={ + interrupt_id: Interrupt(**interrupt_data) for interrupt_id, interrupt_data in data["interrupts"].items() + }, + context=data["context"], + activated=data["activated"], + ) diff --git a/src/strands/models/anthropic.py b/src/strands/models/anthropic.py index 48351da19c..68b2347296 100644 --- a/src/strands/models/anthropic.py +++ b/src/strands/models/anthropic.py @@ -39,6 +39,7 @@ class AnthropicModel(Model): } OVERFLOW_MESSAGES = { + "prompt is too long:", "input is too long", "input length exceeds context window", "input and output tokens exceed your context limit", diff --git a/src/strands/models/litellm.py b/src/strands/models/litellm.py index 7a8c0ae037..17f1bbb948 100644 --- a/src/strands/models/litellm.py +++ b/src/strands/models/litellm.py @@ -14,9 +14,10 @@ from typing_extensions import Unpack, override from ..tools import convert_pydantic_to_tool_spec -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, Messages, SystemContentBlock +from ..types.event_loop import Usage from ..types.exceptions import ContextWindowOverflowException -from ..types.streaming import StreamEvent +from ..types.streaming import MetadataEvent, StreamEvent from ..types.tools import ToolChoice, ToolSpec from ._validation import validate_config_keys from .openai import OpenAIModel @@ -81,11 +82,12 @@ def get_config(self) -> LiteLLMConfig: @override @classmethod - def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]: """Format a LiteLLM content block. Args: content: Message content. + **kwargs: Additional keyword arguments for future extensibility. Returns: LiteLLM formatted content block. @@ -131,6 +133,113 @@ def _stream_switch_content(self, data_type: str, prev_data_type: str | None) -> return chunks, data_type + @override + @classmethod + def _format_system_messages( + cls, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format system messages for LiteLLM with cache point support. + + Args: + system_prompt: System prompt to provide context to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + List of formatted system messages. + """ + # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None + if system_prompt and system_prompt_content is None: + system_prompt_content = [{"text": system_prompt}] + + system_content: list[dict[str, Any]] = [] + for block in system_prompt_content or []: + if "text" in block: + system_content.append({"type": "text", "text": block["text"]}) + elif "cachePoint" in block and block["cachePoint"].get("type") == "default": + # Apply cache control to the immediately preceding content block + # for LiteLLM/Anthropic compatibility + if system_content: + system_content[-1]["cache_control"] = {"type": "ephemeral"} + + # Create single system message with content array rather than mulitple system messages + return [{"role": "system", "content": system_content}] if system_content else [] + + @override + @classmethod + def format_request_messages( + cls, + messages: Messages, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format a LiteLLM compatible messages array with cache point support. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model (for legacy compatibility). + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + A LiteLLM compatible messages array. + """ + formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) + formatted_messages.extend(cls._format_regular_messages(messages)) + + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] + + @override + def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: + """Format a LiteLLM response event into a standardized message chunk. + + This method overrides OpenAI's format_chunk to handle the metadata case + with prompt caching support. All other chunk types use the parent implementation. + + Args: + event: A response event from the LiteLLM model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + The formatted chunk. + + Raises: + RuntimeError: If chunk_type is not recognized. + """ + # Handle metadata case with prompt caching support + if event["chunk_type"] == "metadata": + usage_data: Usage = { + "inputTokens": event["data"].prompt_tokens, + "outputTokens": event["data"].completion_tokens, + "totalTokens": event["data"].total_tokens, + } + + # Only LiteLLM over Anthropic supports cache write tokens + # Waiting until a more general approach is available to set cacheWriteInputTokens + + if tokens_details := getattr(event["data"], "prompt_tokens_details", None): + if cached := getattr(tokens_details, "cached_tokens", None): + usage_data["cacheReadInputTokens"] = cached + if creation := getattr(tokens_details, "cache_creation_tokens", None): + usage_data["cacheWriteInputTokens"] = creation + + return StreamEvent( + metadata=MetadataEvent( + metrics={ + "latencyMs": 0, # TODO + }, + usage=usage_data, + ) + ) + # For all other cases, use the parent implementation + return super().format_chunk(event) + @override async def stream( self, @@ -139,6 +248,7 @@ async def stream( system_prompt: Optional[str] = None, *, tool_choice: ToolChoice | None = None, + system_prompt_content: Optional[list[SystemContentBlock]] = None, **kwargs: Any, ) -> AsyncGenerator[StreamEvent, None]: """Stream conversation with the LiteLLM model. @@ -148,17 +258,22 @@ async def stream( tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. + system_prompt_content: System prompt content blocks to provide context to the model. **kwargs: Additional keyword arguments for future extensibility. Yields: Formatted message chunks from the model. """ logger.debug("formatting request") - request = self.format_request(messages, tool_specs, system_prompt, tool_choice) + request = self.format_request( + messages, tool_specs, system_prompt, tool_choice, system_prompt_content=system_prompt_content + ) logger.debug("request=<%s>", request) logger.debug("invoking model") try: + if kwargs.get("stream") is False: + raise ValueError("stream parameter cannot be explicitly set to False") response = await litellm.acompletion(**self.client_args, **request) except ContextWindowExceededError as e: logger.warning("litellm client raised context window overflow") diff --git a/src/strands/models/openai.py b/src/strands/models/openai.py index 1efe641e6a..435c82cabf 100644 --- a/src/strands/models/openai.py +++ b/src/strands/models/openai.py @@ -14,7 +14,7 @@ from pydantic import BaseModel from typing_extensions import Unpack, override -from ..types.content import ContentBlock, Messages +from ..types.content import ContentBlock, Messages, SystemContentBlock from ..types.exceptions import ContextWindowOverflowException, ModelThrottledException from ..types.streaming import StreamEvent from ..types.tools import ToolChoice, ToolResult, ToolSpec, ToolUse @@ -89,11 +89,12 @@ def get_config(self) -> OpenAIConfig: return cast(OpenAIModel.OpenAIConfig, self.config) @classmethod - def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]: """Format an OpenAI compatible content block. Args: content: Message content. + **kwargs: Additional keyword arguments for future extensibility. Returns: OpenAI compatible content block. @@ -131,11 +132,12 @@ def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any] raise TypeError(f"content_type=<{next(iter(content))}> | unsupported type") @classmethod - def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: + def format_request_message_tool_call(cls, tool_use: ToolUse, **kwargs: Any) -> dict[str, Any]: """Format an OpenAI compatible tool call. Args: tool_use: Tool use requested by the model. + **kwargs: Additional keyword arguments for future extensibility. Returns: OpenAI compatible tool call. @@ -150,11 +152,12 @@ def format_request_message_tool_call(cls, tool_use: ToolUse) -> dict[str, Any]: } @classmethod - def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> dict[str, Any]: """Format an OpenAI compatible tool message. Args: tool_result: Tool result collected from a tool execution. + **kwargs: Additional keyword arguments for future extensibility. Returns: OpenAI compatible tool message. @@ -198,18 +201,46 @@ def _format_request_tool_choice(cls, tool_choice: ToolChoice | None) -> dict[str return {"tool_choice": "auto"} @classmethod - def format_request_messages(cls, messages: Messages, system_prompt: Optional[str] = None) -> list[dict[str, Any]]: - """Format an OpenAI compatible messages array. + def _format_system_messages( + cls, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format system messages for OpenAI-compatible providers. Args: - messages: List of message objects to be processed by the model. system_prompt: System prompt to provide context to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. Returns: - An OpenAI compatible messages array. + List of formatted system messages. + """ + # Handle backward compatibility: if system_prompt is provided but system_prompt_content is None + if system_prompt and system_prompt_content is None: + system_prompt_content = [{"text": system_prompt}] + + # TODO: Handle caching blocks https://github.com/strands-agents/sdk-python/issues/1140 + return [ + {"role": "system", "content": content["text"]} + for content in system_prompt_content or [] + if "text" in content + ] + + @classmethod + def _format_regular_messages(cls, messages: Messages, **kwargs: Any) -> list[dict[str, Any]]: + """Format regular messages for OpenAI-compatible providers. + + Args: + messages: List of message objects to be processed by the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + List of formatted messages. """ - formatted_messages: list[dict[str, Any]] - formatted_messages = [{"role": "system", "content": system_prompt}] if system_prompt else [] + formatted_messages = [] for message in messages: contents = message["content"] @@ -242,14 +273,42 @@ def format_request_messages(cls, messages: Messages, system_prompt: Optional[str formatted_messages.append(formatted_message) formatted_messages.extend(formatted_tool_messages) + return formatted_messages + + @classmethod + def format_request_messages( + cls, + messages: Messages, + system_prompt: Optional[str] = None, + *, + system_prompt_content: Optional[list[SystemContentBlock]] = None, + **kwargs: Any, + ) -> list[dict[str, Any]]: + """Format an OpenAI compatible messages array. + + Args: + messages: List of message objects to be processed by the model. + system_prompt: System prompt to provide context to the model. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. + + Returns: + An OpenAI compatible messages array. + """ + formatted_messages = cls._format_system_messages(system_prompt, system_prompt_content=system_prompt_content) + formatted_messages.extend(cls._format_regular_messages(messages)) + return [message for message in formatted_messages if message["content"] or "tool_calls" in message] def format_request( self, messages: Messages, - tool_specs: Optional[list[ToolSpec]] = None, - system_prompt: Optional[str] = None, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, tool_choice: ToolChoice | None = None, + *, + system_prompt_content: list[SystemContentBlock] | None = None, + **kwargs: Any, ) -> dict[str, Any]: """Format an OpenAI compatible chat streaming request. @@ -258,6 +317,8 @@ def format_request( tool_specs: List of tool specifications to make available to the model. system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. + system_prompt_content: System prompt content blocks to provide context to the model. + **kwargs: Additional keyword arguments for future extensibility. Returns: An OpenAI compatible chat streaming request. @@ -267,7 +328,9 @@ def format_request( format. """ return { - "messages": self.format_request_messages(messages, system_prompt), + "messages": self.format_request_messages( + messages, system_prompt, system_prompt_content=system_prompt_content + ), "model": self.config["model_id"], "stream": True, "stream_options": {"include_usage": True}, @@ -286,11 +349,12 @@ def format_request( **cast(dict[str, Any], self.config.get("params", {})), } - def format_chunk(self, event: dict[str, Any]) -> StreamEvent: + def format_chunk(self, event: dict[str, Any], **kwargs: Any) -> StreamEvent: """Format an OpenAI response event into a standardized message chunk. Args: event: A response event from the OpenAI compatible model. + **kwargs: Additional keyword arguments for future extensibility. Returns: The formatted chunk. diff --git a/src/strands/models/sagemaker.py b/src/strands/models/sagemaker.py index 25b3ca7ced..7f8b8ff51a 100644 --- a/src/strands/models/sagemaker.py +++ b/src/strands/models/sagemaker.py @@ -202,6 +202,7 @@ def format_request( tool_specs: Optional[list[ToolSpec]] = None, system_prompt: Optional[str] = None, tool_choice: ToolChoice | None = None, + **kwargs: Any, ) -> dict[str, Any]: """Format an Amazon SageMaker chat streaming request. @@ -211,6 +212,7 @@ def format_request( system_prompt: System prompt to provide context to the model. tool_choice: Selection strategy for tool invocation. **Note: This parameter is accepted for interface consistency but is currently ignored for this model provider.** + **kwargs: Additional keyword arguments for future extensibility. Returns: An Amazon SageMaker chat streaming request. @@ -501,11 +503,12 @@ async def stream( @override @classmethod - def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: + def format_request_tool_message(cls, tool_result: ToolResult, **kwargs: Any) -> dict[str, Any]: """Format a SageMaker compatible tool message. Args: tool_result: Tool result collected from a tool execution. + **kwargs: Additional keyword arguments for future extensibility. Returns: SageMaker compatible tool message with content as a string. @@ -531,11 +534,12 @@ def format_request_tool_message(cls, tool_result: ToolResult) -> dict[str, Any]: @override @classmethod - def format_request_message_content(cls, content: ContentBlock) -> dict[str, Any]: + def format_request_message_content(cls, content: ContentBlock, **kwargs: Any) -> dict[str, Any]: """Format a content block. Args: content: Message content. + **kwargs: Additional keyword arguments for future extensibility. Returns: Formatted content block. diff --git a/src/strands/multiagent/graph.py b/src/strands/multiagent/graph.py index b421b70c1a..9f28876bf7 100644 --- a/src/strands/multiagent/graph.py +++ b/src/strands/multiagent/graph.py @@ -453,7 +453,7 @@ def __init__( self._resume_from_session = False self.id = id - self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self)) + run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -516,7 +516,7 @@ async def stream_async( if invocation_state is None: invocation_state = {} - self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state)) + await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state)) logger.debug("task=<%s> | starting graph execution", task) @@ -569,7 +569,7 @@ async def stream_async( raise finally: self.state.execution_time = round((time.time() - start_time) * 1000) - self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self)) + await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self)) self._resume_from_session = False self._resume_next_nodes.clear() @@ -776,7 +776,7 @@ def _is_node_ready_with_conditions(self, node: GraphNode, completed_batch: list[ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) -> AsyncIterator[Any]: """Execute a single node and yield TypedEvent objects.""" - self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, node.node_id, invocation_state)) + await self.hooks.invoke_callbacks_async(BeforeNodeCallEvent(self, node.node_id, invocation_state)) # Reset the node's state if reset_on_revisit is enabled, and it's being revisited if self.reset_on_revisit and node in self.state.completed_nodes: @@ -920,7 +920,7 @@ async def _execute_node(self, node: GraphNode, invocation_state: dict[str, Any]) raise finally: - self.hooks.invoke_callbacks(AfterNodeCallEvent(self, node.node_id, invocation_state)) + await self.hooks.invoke_callbacks_async(AfterNodeCallEvent(self, node.node_id, invocation_state)) def _accumulate_metrics(self, node_result: NodeResult) -> None: """Accumulate metrics from a node result.""" diff --git a/src/strands/multiagent/swarm.py b/src/strands/multiagent/swarm.py index accd564630..cb5b36839e 100644 --- a/src/strands/multiagent/swarm.py +++ b/src/strands/multiagent/swarm.py @@ -273,7 +273,7 @@ def __init__( self._setup_swarm(nodes) self._inject_swarm_tools() - self.hooks.invoke_callbacks(MultiAgentInitializedEvent(self)) + run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self))) def __call__( self, task: str | list[ContentBlock], invocation_state: dict[str, Any] | None = None, **kwargs: Any @@ -336,7 +336,7 @@ async def stream_async( if invocation_state is None: invocation_state = {} - self.hooks.invoke_callbacks(BeforeMultiAgentInvocationEvent(self, invocation_state)) + await self.hooks.invoke_callbacks_async(BeforeMultiAgentInvocationEvent(self, invocation_state)) logger.debug("starting swarm execution") @@ -375,7 +375,7 @@ async def stream_async( raise finally: self.state.execution_time = round((time.time() - self.state.start_time) * 1000) - self.hooks.invoke_callbacks(AfterMultiAgentInvocationEvent(self, invocation_state)) + await self.hooks.invoke_callbacks_async(AfterMultiAgentInvocationEvent(self, invocation_state)) self._resume_from_session = False # Yield final result after execution_time is set @@ -687,7 +687,9 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato # TODO: Implement cancellation token to stop _execute_node from continuing try: # Execute with timeout wrapper for async generator streaming - self.hooks.invoke_callbacks(BeforeNodeCallEvent(self, current_node.node_id, invocation_state)) + await self.hooks.invoke_callbacks_async( + BeforeNodeCallEvent(self, current_node.node_id, invocation_state) + ) node_stream = self._stream_with_timeout( self._execute_node(current_node, self.state.task, invocation_state), self.node_timeout, @@ -699,7 +701,9 @@ async def _execute_swarm(self, invocation_state: dict[str, Any]) -> AsyncIterato self.state.node_history.append(current_node) # After self.state add current node, swarm state finish updating, we persist here - self.hooks.invoke_callbacks(AfterNodeCallEvent(self, current_node.node_id, invocation_state)) + await self.hooks.invoke_callbacks_async( + AfterNodeCallEvent(self, current_node.node_id, invocation_state) + ) logger.debug("node=<%s> | node execution completed", current_node.node_id) diff --git a/src/strands/telemetry/tracer.py b/src/strands/telemetry/tracer.py index 9cefc69115..c47a10c3fe 100644 --- a/src/strands/telemetry/tracer.py +++ b/src/strands/telemetry/tracer.py @@ -79,11 +79,12 @@ class Tracer: When the OTEL_EXPORTER_OTLP_ENDPOINT environment variable is set, traces are sent to the OTLP endpoint. + + Both attributes are controlled by including "gen_ai_latest_experimental" or "gen_ai_tool_definitions", + respectively, in the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. """ - def __init__( - self, - ) -> None: + def __init__(self) -> None: """Initialize the tracer.""" self.service_name = __name__ self.tracer_provider: Optional[trace_api.TracerProvider] = None @@ -92,17 +93,19 @@ def __init__( ThreadingInstrumentor().instrument() # Read OTEL_SEMCONV_STABILITY_OPT_IN environment variable - self.use_latest_genai_conventions = self._parse_semconv_opt_in() + opt_in_values = self._parse_semconv_opt_in() + ## To-do: should not set below attributes directly, use env var instead + self.use_latest_genai_conventions = "gen_ai_latest_experimental" in opt_in_values + self._include_tool_definitions = "gen_ai_tool_definitions" in opt_in_values - def _parse_semconv_opt_in(self) -> bool: + def _parse_semconv_opt_in(self) -> set[str]: """Parse the OTEL_SEMCONV_STABILITY_OPT_IN environment variable. Returns: - Set of opt-in values from the environment variable + A set of opt-in values from the environment variable. """ opt_in_env = os.getenv("OTEL_SEMCONV_STABILITY_OPT_IN", "") - - return "gen_ai_latest_experimental" in opt_in_env + return {value.strip() for value in opt_in_env.split(",")} def _start_span( self, @@ -551,6 +554,7 @@ def start_agent_span( model_id: Optional[str] = None, tools: Optional[list] = None, custom_trace_attributes: Optional[Mapping[str, AttributeValue]] = None, + tools_config: Optional[dict] = None, **kwargs: Any, ) -> Span: """Start a new span for an agent invocation. @@ -561,6 +565,7 @@ def start_agent_span( model_id: Optional model identifier. tools: Optional list of tools being used. custom_trace_attributes: Optional mapping of custom trace attributes to include in the span. + tools_config: Optional dictionary of tool configurations. **kwargs: Additional attributes to add to the span. Returns: @@ -577,8 +582,15 @@ def start_agent_span( attributes["gen_ai.request.model"] = model_id if tools: - tools_json = serialize(tools) - attributes["gen_ai.agent.tools"] = tools_json + attributes["gen_ai.agent.tools"] = serialize(tools) + + if self._include_tool_definitions and tools_config: + try: + tool_definitions = self._construct_tool_definitions(tools_config) + attributes["gen_ai.tool.definitions"] = serialize(tool_definitions) + except Exception: + # A failure in telemetry should not crash the agent + logger.warning("failed to attach tool metadata to agent span", exc_info=True) # Add custom trace attributes if provided if custom_trace_attributes: @@ -649,6 +661,18 @@ def end_agent_span( self._end_span(span, attributes, error) + def _construct_tool_definitions(self, tools_config: dict) -> list[dict[str, Any]]: + """Constructs a list of tool definitions from the provided tools_config.""" + return [ + { + "name": name, + "description": spec.get("description"), + "inputSchema": spec.get("inputSchema"), + "outputSchema": spec.get("outputSchema"), + } + for name, spec in tools_config.items() + ] + def start_multiagent_span( self, task: str | list[ContentBlock], diff --git a/src/strands/tools/decorator.py b/src/strands/tools/decorator.py index 5c49f4b581..8dc933f517 100644 --- a/src/strands/tools/decorator.py +++ b/src/strands/tools/decorator.py @@ -45,6 +45,7 @@ def my_tool(param1: str, param2: int = 42) -> dict: import inspect import logging from typing import ( + Annotated, Any, Callable, Generic, @@ -54,12 +55,15 @@ def my_tool(param1: str, param2: int = 42) -> dict: TypeVar, Union, cast, + get_args, + get_origin, get_type_hints, overload, ) import docstring_parser from pydantic import BaseModel, Field, create_model +from pydantic.fields import FieldInfo from typing_extensions import override from ..interrupt import InterruptException @@ -105,15 +109,66 @@ def __init__(self, func: Callable[..., Any], context_param: str | None = None) - # Parse the docstring with docstring_parser doc_str = inspect.getdoc(func) or "" self.doc = docstring_parser.parse(doc_str) - - # Get parameter descriptions from parsed docstring - self.param_descriptions = { + self.param_descriptions: dict[str, str] = { param.arg_name: param.description or f"Parameter {param.arg_name}" for param in self.doc.params } # Create a Pydantic model for validation self.input_model = self._create_input_model() + def _extract_annotated_metadata( + self, annotation: Any, param_name: str, param_default: Any + ) -> tuple[Any, FieldInfo]: + """Extracts type and a simple string description from an Annotated type hint. + + Returns: + A tuple of (actual_type, field_info), where field_info is a new, simple + Pydantic FieldInfo instance created from the extracted metadata. + """ + actual_type = annotation + description: str | None = None + + if get_origin(annotation) is Annotated: + args = get_args(annotation) + actual_type = args[0] + + # Look through metadata for a string description or a FieldInfo object + for meta in args[1:]: + if isinstance(meta, str): + description = meta + elif isinstance(meta, FieldInfo): + # --- Future Contributor Note --- + # We are explicitly blocking the use of `pydantic.Field` within `Annotated` + # because of the complexities of Pydantic v2's immutable Core Schema. + # + # Once a Pydantic model's schema is built, its `FieldInfo` objects are + # effectively frozen. Attempts to mutate a `FieldInfo` object after + # creation (e.g., by copying it and setting `.description` or `.default`) + # are unreliable because the underlying Core Schema does not see these changes. + # + # The correct way to support this would be to reliably extract all + # constraints (ge, le, pattern, etc.) from the original FieldInfo and + # rebuild a new one from scratch. However, these constraints are not + # stored as public attributes, making them difficult to inspect reliably. + # + # Deferring this complexity until there is clear demand and a robust + # pattern for inspecting FieldInfo constraints is established. + raise NotImplementedError( + "Using pydantic.Field within Annotated is not yet supported for tool decorators. " + "Please use a simple string for the description, or define constraints in the function's " + "docstring." + ) + + # Determine the final description with a clear priority order + # Priority: 1. Annotated string -> 2. Docstring -> 3. Fallback + final_description = description + if final_description is None: + final_description = self.param_descriptions.get(param_name) or f"Parameter {param_name}" + # Create FieldInfo object from scratch + final_field = Field(default=param_default, description=final_description) + + return actual_type, final_field + def _validate_signature(self) -> None: """Verify that ToolContext is used correctly in the function signature.""" for param in self.signature.parameters.values(): @@ -146,24 +201,73 @@ def _create_input_model(self) -> Type[BaseModel]: if self._is_special_parameter(name): continue - # Get parameter type and default - param_type = self.type_hints.get(name, Any) + # Use param.annotation directly to get the raw type hint. Using get_type_hints() + # can cause inconsistent behavior across Python versions for complex Annotated types. + param_type = param.annotation + if param_type is inspect.Parameter.empty: + param_type = Any default = ... if param.default is inspect.Parameter.empty else param.default - description = self.param_descriptions.get(name, f"Parameter {name}") - # Create Field with description and default - field_definitions[name] = (param_type, Field(default=default, description=description)) + actual_type, field_info = self._extract_annotated_metadata(param_type, name, default) + field_definitions[name] = (actual_type, field_info) - # Create model name based on function name model_name = f"{self.func.__name__.capitalize()}Tool" - # Create and return the model if field_definitions: return create_model(model_name, **field_definitions) else: - # Handle case with no parameters return create_model(model_name) + def _extract_description_from_docstring(self) -> str: + """Extract the docstring excluding only the Args section. + + This method uses the parsed docstring to extract everything except + the Args/Arguments/Parameters section, preserving Returns, Raises, + Examples, and other sections. + + Returns: + The description text, or the function name if no description is available. + """ + func_name = self.func.__name__ + + # Fallback: try to extract manually from raw docstring + raw_docstring = inspect.getdoc(self.func) + if raw_docstring: + lines = raw_docstring.strip().split("\n") + result_lines = [] + skip_args_section = False + + for line in lines: + stripped_line = line.strip() + + # Check if we're starting the Args section + if stripped_line.lower().startswith(("args:", "arguments:", "parameters:", "param:", "params:")): + skip_args_section = True + continue + + # Check if we're starting a new section (not Args) + elif ( + stripped_line.lower().startswith(("returns:", "return:", "yields:", "yield:")) + or stripped_line.lower().startswith(("raises:", "raise:", "except:", "exceptions:")) + or stripped_line.lower().startswith(("examples:", "example:", "note:", "notes:")) + or stripped_line.lower().startswith(("see also:", "seealso:", "references:", "ref:")) + ): + skip_args_section = False + result_lines.append(line) + continue + + # If we're not in the Args section, include the line + if not skip_args_section: + result_lines.append(line) + + # Join and clean up the description + description = "\n".join(result_lines).strip() + if description: + return description + + # Final fallback: use function name + return func_name + def extract_metadata(self) -> ToolSpec: """Extract metadata from the function to create a tool specification. @@ -173,7 +277,7 @@ def extract_metadata(self) -> ToolSpec: The specification includes: - name: The function name (or custom override) - - description: The function's docstring + - description: The function's docstring description (excluding Args) - inputSchema: A JSON schema describing the expected parameters Returns: @@ -181,12 +285,8 @@ def extract_metadata(self) -> ToolSpec: """ func_name = self.func.__name__ - # Extract function description from docstring, preserving paragraph breaks - description = inspect.getdoc(self.func) - if description: - description = description.strip() - else: - description = func_name + # Extract function description from parsed docstring, excluding Args section and beyond + description = self._extract_description_from_docstring() # Get schema directly from the Pydantic model input_schema = self.input_model.model_json_schema() diff --git a/src/strands/tools/executors/_executor.py b/src/strands/tools/executors/_executor.py index f9a482558f..87c38990db 100644 --- a/src/strands/tools/executors/_executor.py +++ b/src/strands/tools/executors/_executor.py @@ -85,7 +85,7 @@ async def _stream( } ) - before_event, interrupts = agent.hooks.invoke_callbacks( + before_event, interrupts = await agent.hooks.invoke_callbacks_async( BeforeToolCallEvent( agent=agent, selected_tool=tool_func, @@ -109,7 +109,7 @@ async def _stream( "status": "error", "content": [{"text": cancel_message}], } - after_event, _ = agent.hooks.invoke_callbacks( + after_event, _ = await agent.hooks.invoke_callbacks_async( AfterToolCallEvent( agent=agent, tool_use=tool_use, @@ -147,7 +147,7 @@ async def _stream( "status": "error", "content": [{"text": f"Unknown tool: {tool_name}"}], } - after_event, _ = agent.hooks.invoke_callbacks( + after_event, _ = await agent.hooks.invoke_callbacks_async( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -184,7 +184,7 @@ async def _stream( result = cast(ToolResult, event) - after_event, _ = agent.hooks.invoke_callbacks( + after_event, _ = await agent.hooks.invoke_callbacks_async( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, @@ -204,7 +204,7 @@ async def _stream( "status": "error", "content": [{"text": f"Error: {str(e)}"}], } - after_event, _ = agent.hooks.invoke_callbacks( + after_event, _ = await agent.hooks.invoke_callbacks_async( AfterToolCallEvent( agent=agent, selected_tool=selected_tool, diff --git a/src/strands/tools/mcp/mcp_agent_tool.py b/src/strands/tools/mcp/mcp_agent_tool.py index af0c069a17..bedd93f242 100644 --- a/src/strands/tools/mcp/mcp_agent_tool.py +++ b/src/strands/tools/mcp/mcp_agent_tool.py @@ -6,6 +6,7 @@ """ import logging +from datetime import timedelta from typing import TYPE_CHECKING, Any from mcp.types import Tool as MCPTool @@ -28,7 +29,13 @@ class MCPAgentTool(AgentTool): seamlessly within the agent framework. """ - def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: str | None = None) -> None: + def __init__( + self, + mcp_tool: MCPTool, + mcp_client: "MCPClient", + name_override: str | None = None, + timeout: timedelta | None = None, + ) -> None: """Initialize a new MCPAgentTool instance. Args: @@ -36,12 +43,14 @@ def __init__(self, mcp_tool: MCPTool, mcp_client: "MCPClient", name_override: st mcp_client: The MCP server connection to use for tool invocation name_override: Optional name to use for the agent tool (for disambiguation) If None, uses the original MCP tool name + timeout: Optional timeout duration for tool execution """ super().__init__() logger.debug("tool_name=<%s> | creating mcp agent tool", mcp_tool.name) self.mcp_tool = mcp_tool self.mcp_client = mcp_client self._agent_tool_name = name_override or mcp_tool.name + self.timeout = timeout @property def tool_name(self) -> str: @@ -105,5 +114,6 @@ async def stream(self, tool_use: ToolUse, invocation_state: dict[str, Any], **kw tool_use_id=tool_use["toolUseId"], name=self.mcp_tool.name, # Use original MCP name for server communication arguments=tool_use["input"], + read_timeout_seconds=self.timeout, ) yield ToolResultEvent(result) diff --git a/src/strands/tools/mcp/mcp_client.py b/src/strands/tools/mcp/mcp_client.py index 2fe006466c..b16b9c2b4c 100644 --- a/src/strands/tools/mcp/mcp_client.py +++ b/src/strands/tools/mcp/mcp_client.py @@ -119,10 +119,12 @@ def __init__( mcp_instrumentation() self._session_id = uuid.uuid4() self._log_debug_with_thread("initializing MCPClient connection") - # Main thread blocks until future completesock + # Main thread blocks until future completes self._init_future: futures.Future[None] = futures.Future() + # Set within the inner loop as it needs the asyncio loop + self._close_future: asyncio.futures.Future[None] | None = None + self._close_exception: None | Exception = None # Do not want to block other threads while close event is false - self._close_event = asyncio.Event() self._transport_callable = transport_callable self._background_thread: threading.Thread | None = None @@ -288,11 +290,12 @@ def stop( - _background_thread: Thread running the async event loop - _background_thread_session: MCP ClientSession (auto-closed by context manager) - _background_thread_event_loop: AsyncIO event loop in background thread - - _close_event: AsyncIO event to signal thread shutdown + - _close_future: AsyncIO future to signal thread shutdown + - _close_exception: Exception that caused the background thread shutdown; None if a normal shutdown occurred. - _init_future: Future for initialization synchronization Cleanup order: - 1. Signal close event to background thread (if session initialized) + 1. Signal close future to background thread (if session initialized) 2. Wait for background thread to complete 3. Reset all state for reuse @@ -303,13 +306,14 @@ def stop( """ self._log_debug_with_thread("exiting MCPClient context") - # Only try to signal close event if we have a background thread + # Only try to signal close future if we have a background thread if self._background_thread is not None: - # Signal close event if event loop exists + # Signal close future if event loop exists if self._background_thread_event_loop is not None: async def _set_close_event() -> None: - self._close_event.set() + if self._close_future and not self._close_future.done(): + self._close_future.set_result(None) # Not calling _invoke_on_background_thread since the session does not need to exist # we only need the thread and event loop to exist. @@ -317,11 +321,11 @@ async def _set_close_event() -> None: self._log_debug_with_thread("waiting for background thread to join") self._background_thread.join() + self._log_debug_with_thread("background thread is closed, MCPClient context exited") # Reset fields to allow instance reuse self._init_future = futures.Future() - self._close_event = asyncio.Event() self._background_thread = None self._background_thread_session = None self._background_thread_event_loop = None @@ -330,6 +334,11 @@ async def _set_close_event() -> None: self._tool_provider_started = False self._consumers = set() + if self._close_exception: + exception = self._close_exception + self._close_exception = None + raise RuntimeError("Connection to the MCP server was closed") from exception + def list_tools_sync( self, pagination_token: str | None = None, @@ -563,6 +572,10 @@ async def _async_background_thread(self) -> None: signals readiness to the main thread, and waits for a close signal. """ self._log_debug_with_thread("starting async background thread for MCP connection") + + # Initialized here so that it has the asyncio loop + self._close_future = asyncio.Future() + try: async with self._transport_callable() as (read_stream, write_stream, *_): self._log_debug_with_thread("transport connection established") @@ -583,8 +596,9 @@ async def _async_background_thread(self) -> None: self._log_debug_with_thread("waiting for close signal") # Keep background thread running until signaled to close. - # Thread is not blocked as this is an asyncio.Event not a threading.Event - await self._close_event.wait() + # Thread is not blocked as this a future + await self._close_future + self._log_debug_with_thread("close signal received") except Exception as e: # If we encounter an exception and the future is still running, @@ -592,6 +606,12 @@ async def _async_background_thread(self) -> None: if not self._init_future.done(): self._init_future.set_exception(e) else: + # _close_future is automatically cancelled by the framework which doesn't provide us with the useful + # exception, so instead we store the exception in a different field where stop() can read it + self._close_exception = e + if self._close_future and not self._close_future.done(): + self._close_future.set_result(None) + self._log_debug_with_thread( "encountered exception on background thread after initialization %s", str(e) ) @@ -601,7 +621,7 @@ def _background_task(self) -> None: This method creates a new event loop for the background thread, sets it as the current event loop, and runs the async_background_thread - coroutine until completion. In this case "until completion" means until the _close_event is set. + coroutine until completion. In this case "until completion" means until the _close_future is resolved. This allows for a long-running event loop. """ self._log_debug_with_thread("setting up background task event loop") @@ -699,9 +719,34 @@ def _log_debug_with_thread(self, msg: str, *args: Any, **kwargs: Any) -> None: ) def _invoke_on_background_thread(self, coro: Coroutine[Any, Any, T]) -> futures.Future[T]: - if self._background_thread_session is None or self._background_thread_event_loop is None: + # save a reference to this so that even if it's reset we have the original + close_future = self._close_future + + if ( + self._background_thread_session is None + or self._background_thread_event_loop is None + or close_future is None + ): raise MCPClientInitializationError("the client session was not initialized") - return asyncio.run_coroutine_threadsafe(coro=coro, loop=self._background_thread_event_loop) + + async def run_async() -> T: + # Fix for strands-agents/sdk-python/issues/995 - cancel all pending invocations if/when the session closes + invoke_event = asyncio.create_task(coro) + tasks: list[asyncio.Task | asyncio.Future] = [ + invoke_event, + close_future, + ] + + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + + if done.pop() == close_future: + self._log_debug_with_thread("event loop for the server closed before the invoke completed") + raise RuntimeError("Connection to the MCP server was closed") + else: + return await invoke_event + + invoke_future = asyncio.run_coroutine_threadsafe(coro=run_async(), loop=self._background_thread_event_loop) + return invoke_future def _should_include_tool(self, tool: MCPAgentTool) -> bool: """Check if a tool should be included based on constructor filters.""" diff --git a/src/strands/types/session.py b/src/strands/types/session.py index 4e72a14689..8b78ab4485 100644 --- a/src/strands/types/session.py +++ b/src/strands/types/session.py @@ -7,7 +7,7 @@ from enum import Enum from typing import TYPE_CHECKING, Any, Optional -from ..agent.interrupt import InterruptState +from ..interrupt import _InterruptState from .content import Message if TYPE_CHECKING: @@ -148,7 +148,7 @@ def to_dict(self) -> dict[str, Any]: def initialize_internal_state(self, agent: "Agent") -> None: """Initialize internal state of agent.""" if "interrupt_state" in self._internal_state: - agent._interrupt_state = InterruptState.from_dict(self._internal_state["interrupt_state"]) + agent._interrupt_state = _InterruptState.from_dict(self._internal_state["interrupt_state"]) @dataclass diff --git a/tests/strands/agent/hooks/test_hook_registry.py b/tests/strands/agent/hooks/test_hook_registry.py index 680ded682b..ad1415f222 100644 --- a/tests/strands/agent/hooks/test_hook_registry.py +++ b/tests/strands/agent/hooks/test_hook_registry.py @@ -113,29 +113,32 @@ def test_get_callbacks_for_after_event(hook_registry, after_event): assert callbacks[1] == callback1 # Reverse order -def test_invoke_callbacks(hook_registry, normal_event): - """Test that invoke_callbacks calls all registered callbacks for an event.""" +@pytest.mark.asyncio +async def test_invoke_callbacks_async(hook_registry, normal_event): + """Test that invoke_callbacks_async calls all registered callbacks for an event.""" callback1 = Mock() callback2 = Mock() hook_registry.add_callback(NormalTestEvent, callback1) hook_registry.add_callback(NormalTestEvent, callback2) - hook_registry.invoke_callbacks(normal_event) + await hook_registry.invoke_callbacks_async(normal_event) callback1.assert_called_once_with(normal_event) callback2.assert_called_once_with(normal_event) -def test_invoke_callbacks_no_registered_callbacks(hook_registry, normal_event): - """Test that invoke_callbacks doesn't fail when there are no registered callbacks.""" +@pytest.mark.asyncio +async def test_invoke_callbacks_async_no_registered_callbacks(hook_registry, normal_event): + """Test that invoke_callbacks_async doesn't fail when there are no registered callbacks.""" # No callbacks registered - hook_registry.invoke_callbacks(normal_event) + await hook_registry.invoke_callbacks_async(normal_event) # Test passes if no exception is raised -def test_invoke_callbacks_after_event(hook_registry, after_event): - """Test that invoke_callbacks calls callbacks in reverse order for after events.""" +@pytest.mark.asyncio +async def test_invoke_callbacks_async_after_event(hook_registry, after_event): + """Test that invoke_callbacks_async calls callbacks in reverse order for after events.""" call_order: List[str] = [] def callback1(_event): @@ -147,7 +150,7 @@ def callback2(_event): hook_registry.add_callback(AfterTestEvent, callback1) hook_registry.add_callback(AfterTestEvent, callback2) - hook_registry.invoke_callbacks(after_event) + await hook_registry.invoke_callbacks_async(after_event) assert call_order == ["callback2", "callback1"] # Reverse order diff --git a/tests/strands/agent/test_agent.py b/tests/strands/agent/test_agent.py index 3a0bc2dfbf..d04f579481 100644 --- a/tests/strands/agent/test_agent.py +++ b/tests/strands/agent/test_agent.py @@ -1221,6 +1221,37 @@ async def test_stream_async_multi_modal_input(mock_model, agent, agenerator, ali assert tru_message == exp_message +def test_system_prompt_setter_string(): + """Test that setting system_prompt with string updates both internal fields.""" + agent = Agent(system_prompt="initial prompt") + + agent.system_prompt = "updated prompt" + + assert agent.system_prompt == "updated prompt" + assert agent._system_prompt_content == [{"text": "updated prompt"}] + + +def test_system_prompt_setter_list(): + """Test that setting system_prompt with list updates both internal fields.""" + agent = Agent() + + content_blocks = [{"text": "You are helpful"}, {"cache_control": {"type": "ephemeral"}}] + agent.system_prompt = content_blocks + + assert agent.system_prompt == "You are helpful" + assert agent._system_prompt_content == content_blocks + + +def test_system_prompt_setter_none(): + """Test that setting system_prompt to None clears both internal fields.""" + agent = Agent(system_prompt="initial prompt") + + agent.system_prompt = None + + assert agent.system_prompt is None + assert agent._system_prompt_content is None + + @pytest.mark.asyncio async def test_stream_async_passes_invocation_state(agent, mock_model, mock_event_loop_cycle, agenerator, alist): mock_model.mock_stream.side_effect = [ @@ -1360,6 +1391,7 @@ def test_agent_call_creates_and_ends_span_on_success(mock_get_tracer, mock_model tools=agent.tool_names, system_prompt=agent.system_prompt, custom_trace_attributes=agent.trace_attributes, + tools_config=unittest.mock.ANY, ) # Verify span was ended with the result @@ -1394,6 +1426,7 @@ async def test_event_loop(*args, **kwargs): tools=agent.tool_names, system_prompt=agent.system_prompt, custom_trace_attributes=agent.trace_attributes, + tools_config=unittest.mock.ANY, ) expected_response = AgentResult( @@ -1432,6 +1465,7 @@ def test_agent_call_creates_and_ends_span_on_exception(mock_get_tracer, mock_mod tools=agent.tool_names, system_prompt=agent.system_prompt, custom_trace_attributes=agent.trace_attributes, + tools_config=unittest.mock.ANY, ) # Verify span was ended with the exception @@ -1468,6 +1502,7 @@ async def test_agent_stream_async_creates_and_ends_span_on_exception(mock_get_tr tools=agent.tool_names, system_prompt=agent.system_prompt, custom_trace_attributes=agent.trace_attributes, + tools_config=unittest.mock.ANY, ) # Verify span was ended with the exception @@ -2240,8 +2275,8 @@ def test_agent_backwards_compatibility_single_text_block(): # Should extract text for backwards compatibility assert agent.system_prompt == text - - + + @pytest.mark.parametrize( "content, expected", [ diff --git a/tests/strands/agent/test_interrupt.py b/tests/strands/agent/test_interrupt.py deleted file mode 100644 index e248c29a68..0000000000 --- a/tests/strands/agent/test_interrupt.py +++ /dev/null @@ -1,61 +0,0 @@ -import pytest - -from strands.agent.interrupt import InterruptState -from strands.interrupt import Interrupt - - -@pytest.fixture -def interrupt(): - return Interrupt(id="test_id", name="test_name", reason="test reason") - - -def test_interrupt_activate(): - interrupt_state = InterruptState() - - interrupt_state.activate(context={"test": "context"}) - - assert interrupt_state.activated - - tru_context = interrupt_state.context - exp_context = {"test": "context"} - assert tru_context == exp_context - - -def test_interrupt_deactivate(): - interrupt_state = InterruptState(context={"test": "context"}, activated=True) - - interrupt_state.deactivate() - - assert not interrupt_state.activated - - tru_context = interrupt_state.context - exp_context = {} - assert tru_context == exp_context - - -def test_interrupt_state_to_dict(interrupt): - interrupt_state = InterruptState(interrupts={"test_id": interrupt}, context={"test": "context"}, activated=True) - - tru_data = interrupt_state.to_dict() - exp_data = { - "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, - "context": {"test": "context"}, - "activated": True, - } - assert tru_data == exp_data - - -def test_interrupt_state_from_dict(): - data = { - "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, - "context": {"test": "context"}, - "activated": True, - } - - tru_state = InterruptState.from_dict(data) - exp_state = InterruptState( - interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, - context={"test": "context"}, - activated=True, - ) - assert tru_state == exp_state diff --git a/tests/strands/event_loop/test_event_loop.py b/tests/strands/event_loop/test_event_loop.py index 72fe1b4bda..9335f91a84 100644 --- a/tests/strands/event_loop/test_event_loop.py +++ b/tests/strands/event_loop/test_event_loop.py @@ -1,12 +1,11 @@ import concurrent import unittest.mock -from unittest.mock import ANY, MagicMock, call, patch +from unittest.mock import ANY, AsyncMock, MagicMock, call, patch import pytest import strands import strands.telemetry -from strands.agent.interrupt import InterruptState from strands.hooks import ( AfterModelCallEvent, BeforeModelCallEvent, @@ -14,7 +13,7 @@ HookRegistry, MessageAddedEvent, ) -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState from strands.telemetry.metrics import EventLoopMetrics from strands.tools.executors import SequentialToolExecutor from strands.tools.registry import ToolRegistry @@ -143,7 +142,7 @@ def agent(model, system_prompt, messages, tool_registry, thread_pool, hook_regis mock.event_loop_metrics = EventLoopMetrics() mock.hooks = hook_registry mock.tool_executor = tool_executor - mock._interrupt_state = InterruptState() + mock._interrupt_state = _InterruptState() return mock @@ -750,6 +749,7 @@ async def test_request_state_initialization(alist): # not setting this to False results in endless recursion mock_agent._interrupt_state.activated = False mock_agent.event_loop_metrics.start_cycle.return_value = (0, MagicMock()) + mock_agent.hooks.invoke_callbacks_async = AsyncMock() # Call without providing request_state stream = strands.event_loop.event_loop.event_loop_cycle( diff --git a/tests/strands/event_loop/test_event_loop_structured_output.py b/tests/strands/event_loop/test_event_loop_structured_output.py index 6d3e3a9b51..886da2f0bd 100644 --- a/tests/strands/event_loop/test_event_loop_structured_output.py +++ b/tests/strands/event_loop/test_event_loop_structured_output.py @@ -1,6 +1,6 @@ """Tests for structured output integration in the event loop.""" -from unittest.mock import Mock, patch +from unittest.mock import AsyncMock, Mock, patch import pytest from pydantic import BaseModel @@ -38,10 +38,10 @@ def mock_agent(): agent.tool_registry = ToolRegistry() agent.event_loop_metrics = EventLoopMetrics() agent.hooks = Mock() - agent.hooks.invoke_callbacks = Mock() + agent.hooks.invoke_callbacks_async = AsyncMock() agent.trace_span = None agent.tool_executor = Mock() - agent._append_message = Mock() + agent._append_message = AsyncMock() # Set up _interrupt_state properly agent._interrupt_state = Mock() diff --git a/tests/strands/experimental/hooks/test_hook_aliases.py b/tests/strands/experimental/hooks/test_hook_aliases.py index db9cd3783f..6744aa00c0 100644 --- a/tests/strands/experimental/hooks/test_hook_aliases.py +++ b/tests/strands/experimental/hooks/test_hook_aliases.py @@ -9,6 +9,8 @@ import sys from unittest.mock import Mock +import pytest + from strands.experimental.hooks import ( AfterModelInvocationEvent, AfterToolInvocationEvent, @@ -80,7 +82,8 @@ def test_after_model_call_event_type_equality(): assert isinstance(after_model_event, AfterModelCallEvent) -def test_experimental_aliases_in_hook_registry(): +@pytest.mark.asyncio +async def test_experimental_aliases_in_hook_registry(): """Verify that experimental aliases work with hook registry callbacks.""" hook_registry = HookRegistry() callback_called = False @@ -103,7 +106,7 @@ def experimental_callback(event: BeforeToolInvocationEvent): ) # Invoke callbacks - should work since alias points to same type - hook_registry.invoke_callbacks(test_event) + await hook_registry.invoke_callbacks_async(test_event) assert callback_called assert received_event is test_event diff --git a/tests/strands/hooks/test_registry.py b/tests/strands/hooks/test_registry.py index 6918bd2eee..3daf417347 100644 --- a/tests/strands/hooks/test_registry.py +++ b/tests/strands/hooks/test_registry.py @@ -2,9 +2,8 @@ import pytest -from strands.agent.interrupt import InterruptState -from strands.hooks import BeforeToolCallEvent, HookRegistry -from strands.interrupt import Interrupt +from strands.hooks import AgentInitializedEvent, BeforeInvocationEvent, BeforeToolCallEvent, HookRegistry +from strands.interrupt import Interrupt, _InterruptState @pytest.fixture @@ -15,11 +14,19 @@ def registry(): @pytest.fixture def agent(): instance = unittest.mock.Mock() - instance._interrupt_state = InterruptState() + instance._interrupt_state = _InterruptState() return instance -def test_hook_registry_invoke_callbacks_interrupt(registry, agent): +def test_hook_registry_add_callback_agent_init_coroutine(registry): + callback = unittest.mock.AsyncMock() + + with pytest.raises(ValueError, match=r"AgentInitializedEvent can only be registered with a synchronous callback"): + registry.add_callback(AgentInitializedEvent, callback) + + +@pytest.mark.asyncio +async def test_hook_registry_invoke_callbacks_async_interrupt(registry, agent): event = BeforeToolCallEvent( agent=agent, selected_tool=None, @@ -35,7 +42,7 @@ def test_hook_registry_invoke_callbacks_interrupt(registry, agent): registry.add_callback(BeforeToolCallEvent, callback2) registry.add_callback(BeforeToolCallEvent, callback3) - _, tru_interrupts = registry.invoke_callbacks(event) + _, tru_interrupts = await registry.invoke_callbacks_async(event) exp_interrupts = [ Interrupt( id="v1:before_tool_call:test_tool_id:da3551f3-154b-5978-827e-50ac387877ee", @@ -55,7 +62,8 @@ def test_hook_registry_invoke_callbacks_interrupt(registry, agent): callback3.assert_called_once_with(event) -def test_hook_registry_invoke_callbacks_interrupt_name_clash(registry, agent): +@pytest.mark.asyncio +async def test_hook_registry_invoke_callbacks_async_interrupt_name_clash(registry, agent): event = BeforeToolCallEvent( agent=agent, selected_tool=None, @@ -70,4 +78,12 @@ def test_hook_registry_invoke_callbacks_interrupt_name_clash(registry, agent): registry.add_callback(BeforeToolCallEvent, callback2) with pytest.raises(ValueError, match="interrupt_name= | interrupt name used more than once"): - registry.invoke_callbacks(event) + await registry.invoke_callbacks_async(event) + + +def test_hook_registry_invoke_callbacks_coroutine(registry, agent): + callback = unittest.mock.AsyncMock() + registry.add_callback(BeforeInvocationEvent, callback) + + with pytest.raises(RuntimeError, match=r"use invoke_callbacks_async to invoke async callback"): + registry.invoke_callbacks(BeforeInvocationEvent(agent=agent)) diff --git a/tests/strands/models/test_litellm.py b/tests/strands/models/test_litellm.py index 57a8593cd4..aafee1d17f 100644 --- a/tests/strands/models/test_litellm.py +++ b/tests/strands/models/test_litellm.py @@ -192,6 +192,8 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, mock_event_7 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason=None, delta=mock_delta_7)]) mock_event_8 = unittest.mock.Mock(choices=[unittest.mock.Mock(finish_reason="tool_calls", delta=mock_delta_8)]) mock_event_9 = unittest.mock.Mock() + mock_event_9.usage.prompt_tokens_details.cached_tokens = 10 + mock_event_9.usage.prompt_tokens_details.cache_creation_tokens = 10 litellm_acompletion.side_effect = unittest.mock.AsyncMock( return_value=agenerator( @@ -252,6 +254,8 @@ async def test_stream(litellm_acompletion, api_key, model_id, model, agenerator, { "metadata": { "usage": { + "cacheReadInputTokens": mock_event_9.usage.prompt_tokens_details.cached_tokens, + "cacheWriteInputTokens": mock_event_9.usage.prompt_tokens_details.cache_creation_tokens, "inputTokens": mock_event_9.usage.prompt_tokens, "outputTokens": mock_event_9.usage.completion_tokens, "totalTokens": mock_event_9.usage.total_tokens, @@ -402,3 +406,75 @@ async def test_context_window_maps_to_typed_exception(litellm_acompletion, model with pytest.raises(ContextWindowOverflowException): async for _ in model.stream([{"role": "user", "content": [{"text": "x"}]}]): pass + + +@pytest.mark.asyncio +async def test_stream_raises_error_when_stream_is_false(model): + """Test that stream raises ValueError when stream parameter is explicitly False.""" + messages = [{"role": "user", "content": [{"text": "test"}]}] + + with pytest.raises(ValueError, match="stream parameter cannot be explicitly set to False"): + async for _ in model.stream(messages, stream=False): + pass + + +def test_format_request_messages_with_system_prompt_content(): + """Test format_request_messages with system_prompt_content parameter.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are a helpful assistant."}, {"cachePoint": {"type": "default"}}] + + result = LiteLLMModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + expected = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "You are a helpful assistant.", "cache_control": {"type": "ephemeral"}} + ], + }, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected + + +def test_format_request_messages_backward_compatibility_system_prompt(): + """Test that system_prompt is converted to system_prompt_content when system_prompt_content is None.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt = "You are a helpful assistant." + + result = LiteLLMModel.format_request_messages(messages, system_prompt=system_prompt) + + expected = [ + {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]}, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected + + +def test_format_request_messages_cache_point_support(): + """Test that cache points are properly applied to preceding content blocks.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [ + {"text": "First instruction."}, + {"text": "Second instruction."}, + {"cachePoint": {"type": "default"}}, + {"text": "Third instruction."}, + ] + + result = LiteLLMModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + expected = [ + { + "role": "system", + "content": [ + {"type": "text", "text": "First instruction."}, + {"type": "text", "text": "Second instruction.", "cache_control": {"type": "ephemeral"}}, + {"type": "text", "text": "Third instruction."}, + ], + }, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected diff --git a/tests/strands/models/test_openai.py b/tests/strands/models/test_openai.py index cc30b7420c..0de0c4ebcc 100644 --- a/tests/strands/models/test_openai.py +++ b/tests/strands/models/test_openai.py @@ -944,3 +944,45 @@ async def test_structured_output_rate_limit_as_throttle(openai_client, model, me # Verify the exception message contains the original error assert "tokens per min" in str(exc_info.value) assert exc_info.value.__cause__ == mock_error + + +def test_format_request_messages_with_system_prompt_content(): + """Test format_request_messages with system_prompt_content parameter.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are a helpful assistant."}] + + result = OpenAIModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + expected = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected + + +def test_format_request_messages_with_none_system_prompt_content(): + """Test format_request_messages with system_prompt_content parameter.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + + result = OpenAIModel.format_request_messages(messages) + + expected = [{"role": "user", "content": [{"text": "Hello", "type": "text"}]}] + + assert result == expected + + +def test_format_request_messages_drops_cache_points(): + """Test that cache points are dropped in OpenAI format_request_messages.""" + messages = [{"role": "user", "content": [{"text": "Hello"}]}] + system_prompt_content = [{"text": "You are a helpful assistant."}, {"cachePoint": {"type": "default"}}] + + result = OpenAIModel.format_request_messages(messages, system_prompt_content=system_prompt_content) + + # Cache points should be dropped, only text content included + expected = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [{"text": "Hello", "type": "text"}]}, + ] + + assert result == expected diff --git a/tests/strands/session/test_repository_session_manager.py b/tests/strands/session/test_repository_session_manager.py index ed0ec9072e..451d0dd091 100644 --- a/tests/strands/session/test_repository_session_manager.py +++ b/tests/strands/session/test_repository_session_manager.py @@ -7,7 +7,7 @@ from strands.agent.agent import Agent from strands.agent.conversation_manager.sliding_window_conversation_manager import SlidingWindowConversationManager from strands.agent.conversation_manager.summarizing_conversation_manager import SummarizingConversationManager -from strands.agent.interrupt import InterruptState +from strands.interrupt import _InterruptState from strands.session.repository_session_manager import RepositorySessionManager from strands.types.content import ContentBlock from strands.types.exceptions import SessionException @@ -131,7 +131,7 @@ def test_initialize_restores_existing_agent(session_manager, agent): assert len(agent.messages) == 1 assert agent.messages[0]["role"] == "user" assert agent.messages[0]["content"][0]["text"] == "Hello" - assert agent._interrupt_state == InterruptState(interrupts={}, context={"test": "init"}, activated=False) + assert agent._interrupt_state == _InterruptState(interrupts={}, context={"test": "init"}, activated=False) def test_initialize_restores_existing_agent_with_summarizing_conversation_manager(session_manager): diff --git a/tests/strands/telemetry/test_tracer.py b/tests/strands/telemetry/test_tracer.py index 05dbe387f4..98cfb459f3 100644 --- a/tests/strands/telemetry/test_tracer.py +++ b/tests/strands/telemetry/test_tracer.py @@ -163,11 +163,11 @@ def test_start_model_invoke_span(mock_tracer): assert span is not None -def test_start_model_invoke_span_latest_conventions(mock_tracer): +def test_start_model_invoke_span_latest_conventions(mock_tracer, monkeypatch): """Test starting a model invoke span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -244,11 +244,11 @@ def test_end_model_invoke_span(mock_span): mock_span.end.assert_called_once() -def test_end_model_invoke_span_latest_conventions(mock_span): +def test_end_model_invoke_span_latest_conventions(mock_span, monkeypatch): """Test ending a model invoke span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True message = {"role": "assistant", "content": [{"text": "Response"}]} usage = Usage(inputTokens=10, outputTokens=20, totalTokens=30) metrics = Metrics(latencyMs=20, timeToFirstByteMs=10) @@ -307,11 +307,11 @@ def test_start_tool_call_span(mock_tracer): assert span is not None -def test_start_tool_call_span_latest_conventions(mock_tracer): +def test_start_tool_call_span_latest_conventions(mock_tracer, monkeypatch): """Test starting a tool call span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -396,11 +396,11 @@ def test_start_swarm_span_with_contentblock_task(mock_tracer): assert span is not None -def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer): +def test_start_swarm_span_with_contentblock_task_latest_conventions(mock_tracer, monkeypatch): """Test starting a swarm call span with task as list of contentBlock with latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -439,10 +439,10 @@ def test_end_swarm_span(mock_span): ) -def test_end_swarm_span_latest_conventions(mock_span): +def test_end_swarm_span_latest_conventions(mock_span, monkeypatch): """Test ending a tool call span with latest semantic conventions.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True swarm_final_reuslt = "foo bar bar" tracer.end_swarm_span(mock_span, swarm_final_reuslt) @@ -503,10 +503,10 @@ def test_end_tool_call_span(mock_span): mock_span.end.assert_called_once() -def test_end_tool_call_span_latest_conventions(mock_span): +def test_end_tool_call_span_latest_conventions(mock_span, monkeypatch): """Test ending a tool call span with the latest semantic conventions.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tool_result = {"status": "success", "content": [{"text": "Tool result"}, {"json": {"foo": "bar"}}]} tracer.end_tool_call_span(mock_span, tool_result) @@ -558,11 +558,11 @@ def test_start_event_loop_cycle_span(mock_tracer): assert span is not None -def test_start_event_loop_cycle_span_latest_conventions(mock_tracer): +def test_start_event_loop_cycle_span_latest_conventions(mock_tracer, monkeypatch): """Test starting an event loop cycle span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -609,10 +609,10 @@ def test_end_event_loop_cycle_span(mock_span): mock_span.end.assert_called_once() -def test_end_event_loop_cycle_span_latest_conventions(mock_span): +def test_end_event_loop_cycle_span_latest_conventions(mock_span, monkeypatch): """Test ending an event loop cycle span with the latest semantic conventions.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True message = {"role": "assistant", "content": [{"text": "Response"}]} tool_result_message = { "role": "assistant", @@ -679,11 +679,11 @@ def test_start_agent_span(mock_tracer): assert span is not None -def test_start_agent_span_latest_conventions(mock_tracer): +def test_start_agent_span_latest_conventions(mock_tracer, monkeypatch): """Test starting an agent span with the latest semantic conventions.""" with mock.patch("strands.telemetry.tracer.trace_api.get_tracer", return_value=mock_tracer): + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True tracer.tracer = mock_tracer mock_span = mock.MagicMock() @@ -749,10 +749,10 @@ def test_end_agent_span(mock_span): mock_span.end.assert_called_once() -def test_end_agent_span_latest_conventions(mock_span): +def test_end_agent_span_latest_conventions(mock_span, monkeypatch): """Test ending an agent span with the latest semantic conventions.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_latest_experimental") tracer = Tracer() - tracer.use_latest_genai_conventions = True # Mock AgentResult with metrics mock_metrics = mock.MagicMock() @@ -1324,3 +1324,59 @@ def test_start_event_loop_cycle_span_with_tool_result_message(mock_tracer): "gen_ai.tool.message", attributes={"content": json.dumps(messages[0]["content"])} ) assert span is not None + + +def test_start_agent_span_does_not_include_tool_definitions_by_default(): + """Verify that start_agent_span does not include tool definitions by default.""" + tracer = Tracer() + tracer._start_span = mock.MagicMock() + + tools_config = { + "my_tool": { + "name": "my_tool", + "description": "A test tool", + "inputSchema": {"json": {}}, + "outputSchema": {"json": {}}, + } + } + + tracer.start_agent_span(messages=[], agent_name="TestAgent", tools_config=tools_config) + + tracer._start_span.assert_called_once() + _, call_kwargs = tracer._start_span.call_args + attributes = call_kwargs.get("attributes", {}) + assert "gen_ai.tool.definitions" not in attributes + + +def test_start_agent_span_includes_tool_definitions_when_enabled(monkeypatch): + """Verify that start_agent_span includes tool definitions when enabled.""" + monkeypatch.setenv("OTEL_SEMCONV_STABILITY_OPT_IN", "gen_ai_tool_definitions") + tracer = Tracer() + tracer._start_span = mock.MagicMock() + + tools_config = { + "my_tool": { + "name": "my_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + "outputSchema": {"json": {"type": "object", "properties": {}}}, + } + } + + tracer.start_agent_span(messages=[], agent_name="TestAgent", tools_config=tools_config) + + tracer._start_span.assert_called_once() + _, call_kwargs = tracer._start_span.call_args + attributes = call_kwargs.get("attributes", {}) + + assert "gen_ai.tool.definitions" in attributes + expected_tool_details = [ + { + "name": "my_tool", + "description": "A test tool", + "inputSchema": {"json": {"type": "object", "properties": {}}}, + "outputSchema": {"json": {"type": "object", "properties": {}}}, + } + ] + expected_json = serialize(expected_tool_details) + assert attributes["gen_ai.tool.definitions"] == expected_json diff --git a/tests/strands/test_interrupt.py b/tests/strands/test_interrupt.py index 8ce9721037..a45d524e46 100644 --- a/tests/strands/test_interrupt.py +++ b/tests/strands/test_interrupt.py @@ -1,6 +1,6 @@ import pytest -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState @pytest.fixture @@ -22,3 +22,109 @@ def test_interrupt_to_dict(interrupt): "response": {"response": "test"}, } assert tru_dict == exp_dict + + +def test_interrupt_state_activate(): + interrupt_state = _InterruptState() + + interrupt_state.activate(context={"test": "context"}) + + assert interrupt_state.activated + + tru_context = interrupt_state.context + exp_context = {"test": "context"} + assert tru_context == exp_context + + +def test_interrupt_state_deactivate(): + interrupt_state = _InterruptState(context={"test": "context"}, activated=True) + + interrupt_state.deactivate() + + assert not interrupt_state.activated + + tru_context = interrupt_state.context + exp_context = {} + assert tru_context == exp_context + + +def test_interrupt_state_to_dict(): + interrupt_state = _InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + context={"test": "context"}, + activated=True, + ) + + tru_data = interrupt_state.to_dict() + exp_data = { + "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, + "context": {"test": "context"}, + "activated": True, + } + assert tru_data == exp_data + + +def test_interrupt_state_from_dict(): + data = { + "interrupts": {"test_id": {"id": "test_id", "name": "test_name", "reason": "test reason", "response": None}}, + "context": {"test": "context"}, + "activated": True, + } + + tru_state = _InterruptState.from_dict(data) + exp_state = _InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + context={"test": "context"}, + activated=True, + ) + assert tru_state == exp_state + + +def test_interrupt_state_resume(): + interrupt_state = _InterruptState( + interrupts={"test_id": Interrupt(id="test_id", name="test_name", reason="test reason")}, + activated=True, + ) + + prompt = [ + { + "interruptResponse": { + "interruptId": "test_id", + "response": "test response", + } + } + ] + interrupt_state.resume(prompt) + + tru_response = interrupt_state.interrupts["test_id"].response + exp_response = "test response" + assert tru_response == exp_response + + +def test_interrupt_state_resumse_deactivated(): + interrupt_state = _InterruptState(activated=False) + interrupt_state.resume([]) + + +def test_interrupt_state_resume_invalid_prompt(): + interrupt_state = _InterruptState(activated=True) + + exp_message = r"prompt_type= \| must resume from interrupt with list of interruptResponse's" + with pytest.raises(TypeError, match=exp_message): + interrupt_state.resume("invalid") + + +def test_interrupt_state_resume_invalid_content(): + interrupt_state = _InterruptState(activated=True) + + exp_message = r"content_types=<\['text'\]> \| must resume from interrupt with list of interruptResponse's" + with pytest.raises(TypeError, match=exp_message): + interrupt_state.resume([{"text": "invalid"}]) + + +def test_interrupt_resume_invalid_id(): + interrupt_state = _InterruptState(activated=True) + + exp_message = r"interrupt_id= \| no interrupt found" + with pytest.raises(KeyError, match=exp_message): + interrupt_state.resume([{"interruptResponse": {"interruptId": "invalid", "response": None}}]) diff --git a/tests/strands/tools/executors/conftest.py b/tests/strands/tools/executors/conftest.py index d25cf14bdf..4d299a5395 100644 --- a/tests/strands/tools/executors/conftest.py +++ b/tests/strands/tools/executors/conftest.py @@ -4,8 +4,8 @@ import pytest import strands -from strands.agent.interrupt import InterruptState from strands.hooks import AfterToolCallEvent, BeforeToolCallEvent, HookRegistry +from strands.interrupt import _InterruptState from strands.tools.registry import ToolRegistry from strands.types.tools import ToolContext @@ -104,7 +104,7 @@ def agent(tool_registry, hook_registry): mock_agent = unittest.mock.Mock() mock_agent.tool_registry = tool_registry mock_agent.hooks = hook_registry - mock_agent._interrupt_state = InterruptState() + mock_agent._interrupt_state = _InterruptState() return mock_agent diff --git a/tests/strands/tools/mcp/test_mcp_agent_tool.py b/tests/strands/tools/mcp/test_mcp_agent_tool.py index 442a9919ba..81a2d9afb2 100644 --- a/tests/strands/tools/mcp/test_mcp_agent_tool.py +++ b/tests/strands/tools/mcp/test_mcp_agent_tool.py @@ -1,3 +1,4 @@ +from datetime import timedelta from unittest.mock import MagicMock import pytest @@ -88,5 +89,31 @@ async def test_stream(mcp_agent_tool, mock_mcp_client, alist): assert tru_events == exp_events mock_mcp_client.call_tool_async.assert_called_once_with( - tool_use_id="test-123", name="test_tool", arguments={"param": "value"} + tool_use_id="test-123", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=None + ) + + +def test_timeout_initialization(mock_mcp_tool, mock_mcp_client): + timeout = timedelta(seconds=30) + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout) + assert agent_tool.timeout == timeout + + +def test_timeout_default_none(mock_mcp_tool, mock_mcp_client): + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client) + assert agent_tool.timeout is None + + +@pytest.mark.asyncio +async def test_stream_with_timeout(mock_mcp_tool, mock_mcp_client, alist): + timeout = timedelta(seconds=45) + agent_tool = MCPAgentTool(mock_mcp_tool, mock_mcp_client, timeout=timeout) + tool_use = {"toolUseId": "test-456", "name": "test_tool", "input": {"param": "value"}} + + tru_events = await alist(agent_tool.stream(tool_use, {})) + exp_events = [ToolResultEvent(mock_mcp_client.call_tool_async.return_value)] + + assert tru_events == exp_events + mock_mcp_client.call_tool_async.assert_called_once_with( + tool_use_id="test-456", name="test_tool", arguments={"param": "value"}, read_timeout_seconds=timeout ) diff --git a/tests/strands/tools/test_decorator.py b/tests/strands/tools/test_decorator.py index 25f9bc39e5..a2a4c6213b 100644 --- a/tests/strands/tools/test_decorator.py +++ b/tests/strands/tools/test_decorator.py @@ -3,15 +3,15 @@ """ from asyncio import Queue -from typing import Any, AsyncGenerator, Dict, Optional, Union +from typing import Annotated, Any, AsyncGenerator, Dict, List, Optional, Union from unittest.mock import MagicMock import pytest +from pydantic import Field import strands from strands import Agent -from strands.agent.interrupt import InterruptState -from strands.interrupt import Interrupt +from strands.interrupt import Interrupt, _InterruptState from strands.types._events import ToolInterruptEvent, ToolResultEvent, ToolStreamEvent from strands.types.tools import AgentTool, ToolContext, ToolUse @@ -151,7 +151,7 @@ async def test_stream_interrupt(alist): tool_use = {"toolUseId": "test_tool_id"} mock_agent = MagicMock() - mock_agent._interrupt_state = InterruptState() + mock_agent._interrupt_state = _InterruptState() invocation_state = {"agent": mock_agent} @@ -178,7 +178,7 @@ async def test_stream_interrupt_resume(alist): tool_use = {"toolUseId": "test_tool_id"} mock_agent = MagicMock() - mock_agent._interrupt_state = InterruptState(interrupts={interrupt.id: interrupt}) + mock_agent._interrupt_state = _InterruptState(interrupts={interrupt.id: interrupt}) invocation_state = {"agent": mock_agent} @@ -221,14 +221,7 @@ def test_tool(param1: str, param2: int) -> str: # Check basic spec properties assert spec["name"] == "test_tool" - assert ( - spec["description"] - == """Test tool function. - -Args: - param1: First parameter - param2: Second parameter""" - ) + assert spec["description"] == "Test tool function." # Check input schema schema = spec["inputSchema"]["json"] @@ -310,6 +303,174 @@ def test_tool(required: str, optional: Optional[int] = None) -> str: exp_events = [ ToolResultEvent({"toolUseId": "test-id", "status": "success", "content": [{"text": "Result: hello 42"}]}) ] + assert tru_events == exp_events + + +@pytest.mark.asyncio +async def test_docstring_description_extraction(): + """Test that docstring descriptions are extracted correctly, excluding Args section.""" + + @strands.tool + def tool_with_full_docstring(param1: str, param2: int) -> str: + """This is the main description. + + This is more description text. + + Args: + param1: First parameter + param2: Second parameter + + Returns: + A string result + + Raises: + ValueError: If something goes wrong + """ + return f"{param1} {param2}" + + spec = tool_with_full_docstring.tool_spec + assert ( + spec["description"] + == """This is the main description. + +This is more description text. + +Returns: + A string result + +Raises: + ValueError: If something goes wrong""" + ) + + +def test_docstring_args_variations(): + """Test that various Args section formats are properly excluded.""" + + @strands.tool + def tool_with_args(param: str) -> str: + """Main description. + + Args: + param: Parameter description + """ + return param + + @strands.tool + def tool_with_arguments(param: str) -> str: + """Main description. + + Arguments: + param: Parameter description + """ + return param + + @strands.tool + def tool_with_parameters(param: str) -> str: + """Main description. + + Parameters: + param: Parameter description + """ + return param + + @strands.tool + def tool_with_params(param: str) -> str: + """Main description. + + Params: + param: Parameter description + """ + return param + + for tool in [tool_with_args, tool_with_arguments, tool_with_parameters, tool_with_params]: + spec = tool.tool_spec + assert spec["description"] == "Main description." + + +def test_docstring_no_args_section(): + """Test docstring extraction when there's no Args section.""" + + @strands.tool + def tool_no_args(param: str) -> str: + """This is the complete description. + + Returns: + A string result + """ + return param + + spec = tool_no_args.tool_spec + expected_desc = """This is the complete description. + +Returns: + A string result""" + assert spec["description"] == expected_desc + + +def test_docstring_only_args_section(): + """Test docstring extraction when there's only an Args section.""" + + @strands.tool + def tool_only_args(param: str) -> str: + """Args: + param: Parameter description + """ + return param + + spec = tool_only_args.tool_spec + # Should fall back to function name when no description remains + assert spec["description"] == "tool_only_args" + + +def test_docstring_empty(): + """Test docstring extraction when docstring is empty.""" + + @strands.tool + def tool_empty_docstring(param: str) -> str: + return param + + spec = tool_empty_docstring.tool_spec + # Should fall back to function name + assert spec["description"] == "tool_empty_docstring" + + +def test_docstring_preserves_other_sections(): + """Test that non-Args sections are preserved in the description.""" + + @strands.tool + def tool_multiple_sections(param: str) -> str: + """Main description here. + + Args: + param: This should be excluded + + Returns: + This should be included + + Raises: + ValueError: This should be included + + Examples: + This should be included + + Note: + This should be included + """ + return param + + spec = tool_multiple_sections.tool_spec + description = spec["description"] + + # Should include main description and other sections + assert "Main description here." in description + assert "Returns:" in description + assert "This should be included" in description + assert "Raises:" in description + assert "Examples:" in description + assert "Note:" in description + + # Should exclude Args section + assert "This should be excluded" not in description @pytest.mark.asyncio @@ -1450,3 +1611,214 @@ def test_function_tool_metadata_validate_signature_missing_context_config(): @strands.tool def my_tool(tool_context: ToolContext): pass + + +def test_tool_decorator_annotated_string_description(): + """Test tool decorator with Annotated type hints for descriptions.""" + + @strands.tool + def annotated_tool( + name: Annotated[str, "The user's full name"], + age: Annotated[int, "The user's age in years"], + city: str, # No annotation - should use docstring or generic + ) -> str: + """Tool with annotated parameters. + + Args: + city: The user's city (from docstring) + """ + return f"{name}, {age}, {city}" + + spec = annotated_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Check that annotated descriptions are used + assert schema["properties"]["name"]["description"] == "The user's full name" + assert schema["properties"]["age"]["description"] == "The user's age in years" + + # Check that docstring is still used for non-annotated params + assert schema["properties"]["city"]["description"] == "The user's city (from docstring)" + + # Verify all are required + assert set(schema["required"]) == {"name", "age", "city"} + + +def test_tool_decorator_annotated_pydantic_field_constraints(): + """Test that using pydantic.Field in Annotated raises a NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def field_annotated_tool( + email: Annotated[str, Field(description="User's email address", pattern=r"^[\w\.-]+@[\w\.-]+\\.w+$")], + score: Annotated[int, Field(description="Score between 0-100", ge=0, le=100)] = 50, + ) -> str: + """Tool with Pydantic Field annotations.""" + return f"{email}: {score}" + + +def test_tool_decorator_annotated_overrides_docstring(): + """Test that Annotated descriptions override docstring descriptions.""" + + @strands.tool + def override_tool(param: Annotated[str, "Description from annotation"]) -> str: + """Tool with both annotation and docstring. + + Args: + param: Description from docstring (should be overridden) + """ + return param + + spec = override_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Annotated description should win + assert schema["properties"]["param"]["description"] == "Description from annotation" + + +def test_tool_decorator_annotated_optional_type(): + """Test tool with Optional types in Annotated.""" + + @strands.tool + def optional_annotated_tool( + required: Annotated[str, "Required parameter"], optional: Annotated[Optional[str], "Optional parameter"] = None + ) -> str: + """Tool with optional annotated parameter.""" + return f"{required}, {optional}" + + spec = optional_annotated_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Check descriptions + assert schema["properties"]["required"]["description"] == "Required parameter" + assert schema["properties"]["optional"]["description"] == "Optional parameter" + + # Check required list + assert "required" in schema["required"] + assert "optional" not in schema["required"] + + +def test_tool_decorator_annotated_complex_types(): + """Test tool with complex types in Annotated.""" + + @strands.tool + def complex_annotated_tool( + tags: Annotated[List[str], "List of tag strings"], config: Annotated[Dict[str, Any], "Configuration dictionary"] + ) -> str: + """Tool with complex annotated types.""" + return f"Tags: {len(tags)}, Config: {len(config)}" + + spec = complex_annotated_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Check descriptions + assert schema["properties"]["tags"]["description"] == "List of tag strings" + assert schema["properties"]["config"]["description"] == "Configuration dictionary" + + # Check types are preserved + assert schema["properties"]["tags"]["type"] == "array" + assert schema["properties"]["config"]["type"] == "object" + + +def test_tool_decorator_annotated_mixed_styles(): + """Test that using pydantic.Field in a mixed-style annotation raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def mixed_tool( + plain: str, + annotated_str: Annotated[str, "String description"], + annotated_field: Annotated[int, Field(description="Field description", ge=0)], + docstring_only: int, + ) -> str: + """Tool with mixed parameter styles. + + Args: + plain: Plain parameter description + docstring_only: Docstring description for this param + """ + return "mixed" + + +@pytest.mark.asyncio +async def test_tool_decorator_annotated_execution(alist): + """Test that annotated tools execute correctly.""" + + @strands.tool + def execution_test(name: Annotated[str, "User name"], count: Annotated[int, "Number of times"] = 1) -> str: + """Test execution with annotations.""" + return f"Hello {name} " * count + + # Test tool use + tool_use = {"toolUseId": "test-id", "input": {"name": "Alice", "count": 2}} + stream = execution_test.stream(tool_use, {}) + + result = (await alist(stream))[-1] + assert result["tool_result"]["status"] == "success" + assert "Hello Alice Hello Alice" in result["tool_result"]["content"][0]["text"] + + # Test direct call + direct_result = execution_test("Bob", 3) + assert direct_result == "Hello Bob Hello Bob Hello Bob " + + +def test_tool_decorator_annotated_no_description_fallback(): + """Test that Annotated with a Field raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def no_desc_annotated( + param: Annotated[str, Field()], # Field without description + ) -> str: + """Tool with Annotated but no description. + + Args: + param: Docstring description + """ + return param + + +def test_tool_decorator_annotated_empty_string_description(): + """Test handling of empty string descriptions in Annotated.""" + + @strands.tool + def empty_desc_tool( + param: Annotated[str, ""], # Empty string description + ) -> str: + """Tool with empty annotation description. + + Args: + param: Docstring description + """ + return param + + spec = empty_desc_tool.tool_spec + schema = spec["inputSchema"]["json"] + + # Empty string is still a valid description, should not fall back + assert schema["properties"]["param"]["description"] == "" + + +@pytest.mark.asyncio +async def test_tool_decorator_annotated_validation_error(alist): + """Test that validation works correctly with annotated parameters.""" + + @strands.tool + def validation_tool(age: Annotated[int, "User age"]) -> str: + """Tool for validation testing.""" + return f"Age: {age}" + + # Test with wrong type + tool_use = {"toolUseId": "test-id", "input": {"age": "not an int"}} + stream = validation_tool.stream(tool_use, {}) + + result = (await alist(stream))[-1] + assert result["tool_result"]["status"] == "error" + + +def test_tool_decorator_annotated_field_with_inner_default(): + """Test that a default value in an Annotated Field raises NotImplementedError.""" + with pytest.raises(NotImplementedError, match="Using pydantic.Field within Annotated is not yet supported"): + + @strands.tool + def inner_default_tool(name: str, level: Annotated[int, Field(description="A level value", default=10)]) -> str: + return f"{name} is at level {level}" diff --git a/tests/strands/types/test_interrupt.py b/tests/strands/types/test_interrupt.py index ade0fa5e88..ad31384b63 100644 --- a/tests/strands/types/test_interrupt.py +++ b/tests/strands/types/test_interrupt.py @@ -2,8 +2,7 @@ import pytest -from strands.agent.interrupt import InterruptState -from strands.interrupt import Interrupt, InterruptException +from strands.interrupt import Interrupt, InterruptException, _InterruptState from strands.types.interrupt import _Interruptible @@ -20,7 +19,7 @@ def interrupt(): @pytest.fixture def agent(): instance = unittest.mock.Mock() - instance._interrupt_state = InterruptState() + instance._interrupt_state = _InterruptState() return instance diff --git a/tests/strands/types/test_session.py b/tests/strands/types/test_session.py index 26d4062e40..3e53607429 100644 --- a/tests/strands/types/test_session.py +++ b/tests/strands/types/test_session.py @@ -3,8 +3,8 @@ from uuid import uuid4 from strands.agent.conversation_manager.null_conversation_manager import NullConversationManager -from strands.agent.interrupt import InterruptState from strands.agent.state import AgentState +from strands.interrupt import _InterruptState from strands.types.session import ( Session, SessionAgent, @@ -101,7 +101,7 @@ def test_session_agent_from_agent(): agent.agent_id = "a1" agent.conversation_manager = unittest.mock.Mock(get_state=lambda: {"test": "conversation"}) agent.state = AgentState({"test": "state"}) - agent._interrupt_state = InterruptState(interrupts={}, context={}, activated=False) + agent._interrupt_state = _InterruptState(interrupts={}, context={}, activated=False) tru_session_agent = SessionAgent.from_agent(agent) exp_session_agent = SessionAgent( @@ -127,5 +127,5 @@ def test_session_agent_initialize_internal_state(): session_agent.initialize_internal_state(agent) tru_interrupt_state = agent._interrupt_state - exp_interrupt_state = InterruptState(interrupts={}, context={"test": "init"}, activated=False) + exp_interrupt_state = _InterruptState(interrupts={}, context={"test": "init"}, activated=False) assert tru_interrupt_state == exp_interrupt_state diff --git a/tests_integ/hooks/__init__.py b/tests_integ/hooks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests_integ/hooks/multiagent/__init__.py b/tests_integ/hooks/multiagent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests_integ/hooks/multiagent/test_events.py b/tests_integ/hooks/multiagent/test_events.py new file mode 100644 index 0000000000..e8039444f1 --- /dev/null +++ b/tests_integ/hooks/multiagent/test_events.py @@ -0,0 +1,122 @@ +import pytest + +from strands import Agent +from strands.experimental.hooks.multiagent import ( + AfterMultiAgentInvocationEvent, + AfterNodeCallEvent, + BeforeMultiAgentInvocationEvent, + BeforeNodeCallEvent, + MultiAgentInitializedEvent, +) +from strands.hooks import HookProvider +from strands.multiagent import GraphBuilder, Swarm + + +@pytest.fixture +def callback_names(): + return [] + + +@pytest.fixture +def hook_provider(callback_names): + class TestHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(AfterMultiAgentInvocationEvent, self.after_multi_agent_invocation) + registry.add_callback(AfterMultiAgentInvocationEvent, self.after_multi_agent_invocation_async) + registry.add_callback(AfterNodeCallEvent, self.after_node_call) + registry.add_callback(AfterNodeCallEvent, self.after_node_call_async) + registry.add_callback(BeforeMultiAgentInvocationEvent, self.before_multi_agent_invocation) + registry.add_callback(BeforeMultiAgentInvocationEvent, self.before_multi_agent_invocation_async) + registry.add_callback(BeforeNodeCallEvent, self.before_node_call) + registry.add_callback(BeforeNodeCallEvent, self.before_node_call_async) + registry.add_callback(MultiAgentInitializedEvent, self.multi_agent_initialized_event) + registry.add_callback(MultiAgentInitializedEvent, self.multi_agent_initialized_event_async) + + def after_multi_agent_invocation(self, _event): + callback_names.append("after_multi_agent_invocation") + + async def after_multi_agent_invocation_async(self, _event): + callback_names.append("after_multi_agent_invocation_async") + + def after_node_call(self, _event): + callback_names.append("after_node_call") + + async def after_node_call_async(self, _event): + callback_names.append("after_node_call_async") + + def before_multi_agent_invocation(self, _event): + callback_names.append("before_multi_agent_invocation") + + async def before_multi_agent_invocation_async(self, _event): + callback_names.append("before_multi_agent_invocation_async") + + def before_node_call(self, _event): + callback_names.append("before_node_call") + + async def before_node_call_async(self, _event): + callback_names.append("before_node_call_async") + + def multi_agent_initialized_event(self, _event): + callback_names.append("multi_agent_initialized_event") + + async def multi_agent_initialized_event_async(self, _event): + callback_names.append("multi_agent_initialized_event_async") + + return TestHook() + + +@pytest.fixture +def agent(): + return Agent() + + +@pytest.fixture +def graph(agent, hook_provider): + builder = GraphBuilder() + builder.add_node(agent, "agent") + builder.set_entry_point("agent") + builder.set_hook_providers([hook_provider]) + return builder.build() + + +@pytest.fixture +def swarm(agent, hook_provider): + return Swarm([agent], hooks=[hook_provider]) + + +def test_graph_events(graph, callback_names): + graph("Hello") + + tru_callback_names = callback_names + exp_callback_names = [ + "multi_agent_initialized_event", + "multi_agent_initialized_event_async", + "before_multi_agent_invocation", + "before_multi_agent_invocation_async", + "before_node_call", + "before_node_call_async", + "after_node_call_async", + "after_node_call", + "after_multi_agent_invocation_async", + "after_multi_agent_invocation", + ] + assert tru_callback_names == exp_callback_names + + +def test_swarm_events(swarm, callback_names): + swarm("Hello") + + tru_callback_names = callback_names + exp_callback_names = [ + "multi_agent_initialized_event", + "multi_agent_initialized_event_async", + "before_multi_agent_invocation", + "before_multi_agent_invocation_async", + "before_node_call", + "before_node_call_async", + "after_node_call_async", + "after_node_call", + "after_multi_agent_invocation_async", + "after_multi_agent_invocation", + ] + assert tru_callback_names == exp_callback_names diff --git a/tests_integ/hooks/test_events.py b/tests_integ/hooks/test_events.py new file mode 100644 index 0000000000..25971ecb00 --- /dev/null +++ b/tests_integ/hooks/test_events.py @@ -0,0 +1,138 @@ +import pytest + +from strands import Agent, tool +from strands.hooks import ( + AfterInvocationEvent, + AfterModelCallEvent, + AfterToolCallEvent, + AgentInitializedEvent, + BeforeInvocationEvent, + BeforeModelCallEvent, + BeforeToolCallEvent, + HookProvider, + MessageAddedEvent, +) + + +@pytest.fixture +def callback_names(): + return [] + + +@pytest.fixture +def hook_provider(callback_names): + class TestHook(HookProvider): + def register_hooks(self, registry): + registry.add_callback(AfterInvocationEvent, self.after_invocation) + registry.add_callback(AfterInvocationEvent, self.after_invocation_async) + registry.add_callback(AfterModelCallEvent, self.after_model_call) + registry.add_callback(AfterModelCallEvent, self.after_model_call_async) + registry.add_callback(AfterToolCallEvent, self.after_tool_call) + registry.add_callback(AfterToolCallEvent, self.after_tool_call_async) + registry.add_callback(AgentInitializedEvent, self.agent_initialized) + registry.add_callback(BeforeInvocationEvent, self.before_invocation) + registry.add_callback(BeforeInvocationEvent, self.before_invocation_async) + registry.add_callback(BeforeModelCallEvent, self.before_model_call) + registry.add_callback(BeforeModelCallEvent, self.before_model_call_async) + registry.add_callback(BeforeToolCallEvent, self.before_tool_call) + registry.add_callback(BeforeToolCallEvent, self.before_tool_call_async) + registry.add_callback(MessageAddedEvent, self.message_added) + registry.add_callback(MessageAddedEvent, self.message_added_async) + + def after_invocation(self, _event): + callback_names.append("after_invocation") + + async def after_invocation_async(self, _event): + callback_names.append("after_invocation_async") + + def after_model_call(self, _event): + callback_names.append("after_model_call") + + async def after_model_call_async(self, _event): + callback_names.append("after_model_call_async") + + def after_tool_call(self, _event): + callback_names.append("after_tool_call") + + async def after_tool_call_async(self, _event): + callback_names.append("after_tool_call_async") + + def agent_initialized(self, _event): + callback_names.append("agent_initialized") + + async def agent_initialized_async(self, _event): + callback_names.append("agent_initialized_async") + + def before_invocation(self, _event): + callback_names.append("before_invocation") + + async def before_invocation_async(self, _event): + callback_names.append("before_invocation_async") + + def before_model_call(self, _event): + callback_names.append("before_model_call") + + async def before_model_call_async(self, _event): + callback_names.append("before_model_call_async") + + def before_tool_call(self, _event): + callback_names.append("before_tool_call") + + async def before_tool_call_async(self, _event): + callback_names.append("before_tool_call_async") + + def message_added(self, _event): + callback_names.append("message_added") + + async def message_added_async(self, _event): + callback_names.append("message_added_async") + + return TestHook() + + +@pytest.fixture +def time_tool(): + @tool(name="time_tool") + def tool_() -> str: + return "12:00" + + return tool_ + + +@pytest.fixture +def agent(hook_provider, time_tool): + return Agent(hooks=[hook_provider], tools=[time_tool]) + + +def test_events(agent, callback_names): + agent("What time is it?") + + tru_callback_names = callback_names + exp_callback_names = [ + "agent_initialized", + "before_invocation", + "before_invocation_async", + "message_added", + "message_added_async", + "before_model_call", + "before_model_call_async", + "after_model_call_async", + "after_model_call", + "message_added", + "message_added_async", + "before_tool_call", + "before_tool_call_async", + "after_tool_call_async", + "after_tool_call", + "message_added", + "message_added_async", + "before_model_call", + "before_model_call_async", + "after_model_call_async", + "after_model_call", + "message_added", + "message_added_async", + "after_invocation_async", + "after_invocation", + ] + assert tru_callback_names == exp_callback_names diff --git a/tests_integ/mcp/test_mcp_client.py b/tests_integ/mcp/test_mcp_client.py index 2c9bb73e17..35cfd7e86e 100644 --- a/tests_integ/mcp/test_mcp_client.py +++ b/tests_integ/mcp/test_mcp_client.py @@ -420,3 +420,70 @@ def transport_callback() -> MCPTransport: result = await streamable_http_client.call_tool_async(tool_use_id="123", name="timeout_tool") assert result["status"] == "error" assert result["content"][0]["text"] == "Tool execution failed: Connection closed" + + +def start_5xx_proxy_for_tool_calls(target_url: str, proxy_port: int): + """Starts a proxy that throws a 5XX when a tool call is invoked""" + import aiohttp + from aiohttp import web + + async def proxy_handler(request): + url = f"{target_url}{request.path_qs}" + + async with aiohttp.ClientSession() as session: + data = await request.read() + + if "tools/call" in f"{data}": + return web.Response(status=500, text="Internal Server Error") + + async with session.request( + method=request.method, url=url, headers=request.headers, data=data, allow_redirects=False + ) as resp: + print(f"Got request to {url} {data}") + response = web.StreamResponse(status=resp.status, headers=resp.headers) + await response.prepare(request) + + async for chunk in resp.content.iter_chunked(8192): + await response.write(chunk) + + return response + + app = web.Application() + app.router.add_route("*", "/{path:.*}", proxy_handler) + + web.run_app(app, host="127.0.0.1", port=proxy_port) + + +@pytest.mark.asyncio +async def test_streamable_http_mcp_client_with_500_error(): + import asyncio + import multiprocessing + + server_thread = threading.Thread( + target=start_comprehensive_mcp_server, kwargs={"transport": "streamable-http", "port": 8001}, daemon=True + ) + server_thread.start() + + proxy_process = multiprocessing.Process( + target=start_5xx_proxy_for_tool_calls, kwargs={"target_url": "http://127.0.0.1:8001", "proxy_port": 8002} + ) + proxy_process.start() + + try: + await asyncio.sleep(2) # wait for server to startup completely + + def transport_callback() -> MCPTransport: + return streamablehttp_client(url="http://127.0.0.1:8002/mcp") + + streamable_http_client = MCPClient(transport_callback) + with pytest.raises(RuntimeError, match="Connection to the MCP server was closed"): + with streamable_http_client: + result = await streamable_http_client.call_tool_async( + tool_use_id="123", name="calculator", arguments={"x": 3, "y": 4} + ) + finally: + proxy_process.terminate() + proxy_process.join() + + assert result["status"] == "error" + assert result["content"][0]["text"] == "Tool execution failed: Connection to the MCP server was closed" diff --git a/tests_integ/models/test_model_anthropic.py b/tests_integ/models/test_model_anthropic.py index 62a95d06d8..9a0d19dff6 100644 --- a/tests_integ/models/test_model_anthropic.py +++ b/tests_integ/models/test_model_anthropic.py @@ -5,7 +5,10 @@ import strands from strands import Agent +from strands.agent import NullConversationManager from strands.models.anthropic import AnthropicModel +from strands.types.content import ContentBlock, Message +from strands.types.exceptions import ContextWindowOverflowException """ These tests only run if we have the anthropic api key @@ -152,3 +155,30 @@ def test_structured_output_multi_modal_input(agent, yellow_img, yellow_color): tru_color = agent.structured_output(type(yellow_color), content) exp_color = yellow_color assert tru_color == exp_color + + +@pytest.mark.asyncio +def test_input_and_max_tokens_exceed_context_limit(): + """Test that triggers 'input length and max_tokens exceed context limit' error.""" + + # Note that this test is written specifically in a style that allows us to swap out conversation_manager and + # verify behavior + + model = AnthropicModel( + model_id="claude-sonnet-4-20250514", + max_tokens=64000, + ) + + large_message = "This is a very long text. " * 10000 + + messages = [ + Message(role="user", content=[ContentBlock(text=large_message)]), + Message(role="assistant", content=[ContentBlock(text=large_message)]), + Message(role="user", content=[ContentBlock(text=large_message)]), + ] + + # NullConversationManager will propagate ContextWindowOverflowException directly instead of handling it + agent = Agent(model=model, conversation_manager=NullConversationManager()) + + with pytest.raises(ContextWindowOverflowException): + agent(messages) diff --git a/tests_integ/models/test_model_litellm.py b/tests_integ/models/test_model_litellm.py index b348c29f46..f177c08a49 100644 --- a/tests_integ/models/test_model_litellm.py +++ b/tests_integ/models/test_model_litellm.py @@ -211,3 +211,25 @@ def test_structured_output_unsupported_model(model, nested_weather): # Verify that the tool method was called and schema method was not mock_tool.assert_called_once() mock_schema.assert_not_called() + + +@pytest.mark.asyncio +async def test_cache_read_tokens_multi_turn(model): + """Integration test for cache read tokens in multi-turn conversation.""" + from strands.types.content import SystemContentBlock + + system_prompt_content: list[SystemContentBlock] = [ + # Caching only works when prompts are large + {"text": "You are a helpful assistant. Always be concise." * 200}, + {"cachePoint": {"type": "default"}}, + ] + + agent = Agent(model=model, system_prompt=system_prompt_content) + + # First turn - establishes cache + agent("Hello, what's 2+2?") + result = agent("What's 3+3?") + result.metrics.accumulated_usage["cacheReadInputTokens"] + + assert result.metrics.accumulated_usage["cacheReadInputTokens"] > 0 + assert result.metrics.accumulated_usage["cacheWriteInputTokens"] > 0 diff --git a/tests_integ/models/test_model_openai.py b/tests_integ/models/test_model_openai.py index 7beb3013cd..feb591d1ad 100644 --- a/tests_integ/models/test_model_openai.py +++ b/tests_integ/models/test_model_openai.py @@ -231,3 +231,29 @@ def test_content_blocks_handling(model): result = agent(content) assert "4" in result.message["content"][0]["text"] + + +def test_system_prompt_content_integration(model): + """Integration test for system_prompt_content parameter.""" + from strands.types.content import SystemContentBlock + + system_prompt_content: list[SystemContentBlock] = [ + {"text": "You are a helpful assistant that always responds with 'SYSTEM_TEST_RESPONSE'."} + ] + + agent = Agent(model=model, system_prompt=system_prompt_content) + result = agent("Hello") + + # The response should contain our specific system prompt instruction + assert "SYSTEM_TEST_RESPONSE" in result.message["content"][0]["text"] + + +def test_system_prompt_backward_compatibility_integration(model): + """Integration test for backward compatibility with system_prompt parameter.""" + system_prompt = "You are a helpful assistant that always responds with 'BACKWARD_COMPAT_TEST'." + + agent = Agent(model=model, system_prompt=system_prompt) + result = agent("Hello") + + # The response should contain our specific system prompt instruction + assert "BACKWARD_COMPAT_TEST" in result.message["content"][0]["text"] diff --git a/tests_integ/tools/__init__.py b/tests_integ/tools/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests_integ/tools/test_thread_context.py b/tests_integ/tools/test_thread_context.py new file mode 100644 index 0000000000..b86c9b2c0c --- /dev/null +++ b/tests_integ/tools/test_thread_context.py @@ -0,0 +1,47 @@ +import contextvars + +import pytest + +from strands import Agent, tool + + +@pytest.fixture +def result(): + return {} + + +@pytest.fixture +def contextvar(): + return contextvars.ContextVar("agent") + + +@pytest.fixture +def context_tool(result, contextvar): + @tool(name="context_tool") + def tool_(): + result["context_value"] = contextvar.get("local_context") + + return tool_ + + +@pytest.fixture +def agent(context_tool): + return Agent(tools=[context_tool]) + + +def test_agent_invoke_context_sharing(result, contextvar, agent): + contextvar.set("shared_context") + agent("Execute context_tool") + + tru_context = result["context_value"] + exp_context = contextvar.get() + assert tru_context == exp_context + + +def test_tool_call_context_sharing(result, contextvar, agent): + contextvar.set("shared_context") + agent.tool.context_tool() + + tru_context = result["context_value"] + exp_context = contextvar.get() + assert tru_context == exp_context