diff --git a/docs/decisions/0001-agent-run-response.md b/docs/decisions/0001-agent-run-response.md index 12724aca3ae..6ffebe7e4f3 100644 --- a/docs/decisions/0001-agent-run-response.md +++ b/docs/decisions/0001-agent-run-response.md @@ -4,8 +4,8 @@ status: accepted contact: westey-m date: 2025-07-10 {YYYY-MM-DD when the decision was last updated} deciders: sergeymenshykh, markwallace, rbarreto, dmytrostruk, westey-m, eavanvalkenburg, stephentoub -consulted: -informed: +consulted: +informed: --- # Agent Run Responses Design @@ -64,7 +64,7 @@ Approaches observed from the compared SDKs: | AutoGen | **Approach 1** Separates messages into Agent-Agent (maps to Primary) and Internal (maps to Secondary) and these are returned as separate properties on the agent response object. See [types of messages](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/messages.html#types-of-messages) and [Response](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.base.html#autogen_agentchat.base.Response) | **Approach 2** Returns a stream of internal events and the last item is a Response object. See [ChatAgent.on_messages_stream](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.base.html#autogen_agentchat.base.ChatAgent.on_messages_stream) | | OpenAI Agent SDK | **Approach 1** Separates new_items (Primary+Secondary) from final output (Primary) as separate properties on the [RunResult](https://github.com/openai/openai-agents-python/blob/main/src/agents/result.py#L39) | **Approach 1** Similar to non-streaming, has a way of streaming updates via a method on the response object which includes all data, and then a separate final output property on the response object which is populated only when the run is complete. See [RunResultStreaming](https://github.com/openai/openai-agents-python/blob/main/src/agents/result.py#L136) | | Google ADK | **Approach 2** [Emits events](https://google.github.io/adk-docs/runtime/#step-by-step-breakdown) with [FinalResponse](https://github.com/google/adk-java/blob/main/core/src/main/java/com/google/adk/events/Event.java#L232) true (Primary) / false (Secondary) and callers have to filter out those with false to get just the final response message | **Approach 2** Similar to non-streaming except [events](https://google.github.io/adk-docs/runtime/#streaming-vs-non-streaming-output-partialtrue) are emitted with [Partial](https://github.com/google/adk-java/blob/main/core/src/main/java/com/google/adk/events/Event.java#L133) true to indicate that they are streaming messages. A final non partial event is also emitted. | -| AWS (Strands) | **Approach 3** Returns an [AgentResult](https://strandsagents.com/docs/api/python/strands.agent.agent_result/#agentresult) (Primary) with messages and a reason for the run's completion. | **Approach 2** [Streams events](https://strandsagents.com/docs/user-guide/concepts/streaming/) (Primary+Secondary) including, response text, current_tool_use, even data from "callbacks" (strands plugins) | +| AWS (Strands) | **Approach 3** Returns an [AgentResult](https://strandsagents.com/docs/api/python/strands.agent.agent_result/) (Primary) with messages and a reason for the run's completion. | **Approach 2** [Streams events](https://strandsagents.com/docs/api/python/strands.agent.agent/) (Primary+Secondary) including, response text, current_tool_use, even data from "callbacks" (strands plugins) | | LangGraph | **Approach 2** A mixed list of all [messages](https://langchain-ai.github.io/langgraph/agents/run_agents/#output-format) | **Approach 2** A mixed list of all [messages](https://langchain-ai.github.io/langgraph/agents/run_agents/#output-format) | | Agno | **Combination of various approaches** Returns a [RunResponse](https://docs.agno.com/reference/agents/run-response) object with text content, messages (essentially chat history including inputs and instructions), reasoning and thinking text properties. Secondary events could potentially be extracted from messages. | **Approach 2** Returns [RunResponseEvent](https://docs.agno.com/reference/agents/run-response#runresponseevent-types-and-attributes) objects including tool call, memory update, etc, information, where the [RunResponseCompletedEvent](https://docs.agno.com/reference/agents/run-response#runresponsecompletedevent) has similar properties to RunResponse| | A2A | **Approach 3** Returns a [Task or Message](https://a2aproject.github.io/A2A/latest/specification/#71-messagesend) where the message is the final result (Primary) and task is a reference to a long running process. | **Approach 2** Returns a [stream](https://a2aproject.github.io/A2A/latest/specification/#72-messagestream) that contains task updates (Secondary) and a final message (Primary) | @@ -496,7 +496,7 @@ We need to decide what AIContent types, each agent response type will be mapped |-|-| | AutoGen | **Approach 1** Supports [configuring an agent](https://microsoft.github.io/autogen/stable/user-guide/agentchat-user-guide/tutorial/agents.html#structured-output) at agent creation. | | Google ADK | **Approach 1** Both [input and output schemas can be specified for LLM Agents](https://google.github.io/adk-docs/agents/llm-agents/#structuring-data-input_schema-output_schema-output_key) at construction time. This option is specific to this agent type and other agent types do not necessarily support | -| AWS (Strands) | **Approach 2** Supports a special invocation method called [structured_output](https://strandsagents.com/docs/user-guide/concepts/agents/structured-output/) | +| AWS (Strands) | **Approach 2** Supports a special invocation method called [structured_output](https://strandsagents.com/docs/api/python/strands.agent.agent/) | | LangGraph | **Approach 1** Supports [configuring an agent](https://langchain-ai.github.io/langgraph/agents/agents/?h=structured#6-configure-structured-output) at agent construction time, and a [structured response](https://langchain-ai.github.io/langgraph/agents/run_agents/#output-format) can be retrieved as a special property on the agent response | | Agno | **Approach 1** Supports [configuring an agent](https://docs.agno.com/input-output/structured-output/agent) at agent construction time | | A2A | **Informal Approach 2** Doesn't formally support schema negotiation, but [hints can be provided via metadata](https://a2a-protocol.org/latest/specification/#97-structured-data-exchange-requesting-and-providing-json) at invocation time | @@ -508,7 +508,7 @@ We need to decide what AIContent types, each agent response type will be mapped |-|-| | AutoGen | Supports a [stop reason](https://microsoft.github.io/autogen/stable/reference/python/autogen_agentchat.base.html#autogen_agentchat.base.TaskResult.stop_reason) which is a freeform text string | | Google ADK | [No equivalent present](https://github.com/google/adk-python/blob/main/src/google/adk/events/event.py) | -| AWS (Strands) | Exposes a `stop_reason` property on the [AgentResult](https://strandsagents.com/docs/api/python/strands.agent.agent_result/#agentresult) class with options that are tied closely to LLM operations. | +| AWS (Strands) | Exposes a [stop_reason](https://strandsagents.com/docs/api/python/strands.types.event_loop/) property on the [AgentResult](https://strandsagents.com/docs/api/python/strands.agent.agent_result/) class with options that are tied closely to LLM operations. | | LangGraph | No equivalent present, output contains only [messages](https://langchain-ai.github.io/langgraph/agents/run_agents/#output-format) | | Agno | [No equivalent present](https://docs.agno.com/reference/agents/run-response) | | A2A | No equivalent present, response only contains a [message](https://a2a-protocol.org/latest/specification/#64-message-object) or [task](https://a2a-protocol.org/latest/specification/#61-task-object). | diff --git a/python/CODING_STANDARD.md b/python/CODING_STANDARD.md index ccb8e058e3c..86115926923 100644 --- a/python/CODING_STANDARD.md +++ b/python/CODING_STANDARD.md @@ -127,7 +127,12 @@ def create_agent(name: str, tool_mode: Literal['auto', 'required', 'none'] | Cha Avoid `**kwargs` unless absolutely necessary. It should only be used as an escape route, not for well-known flows of data: - **Prefer named parameters**: If there are known extra arguments being passed, use explicit named parameters instead of kwargs +- **Prefer purpose-specific buckets over generic kwargs**: If a flexible payload is still needed, use an explicit named parameter such as `additional_properties`, `function_invocation_kwargs`, or `client_kwargs` rather than a blanket `**kwargs` - **Subclassing support**: kwargs is acceptable in methods that are part of classes designed for subclassing, allowing subclass-defined kwargs to pass through without issues. In this case, clearly document that kwargs exists for subclass extensibility and not for passing arbitrary data +- **Make known flows explicit first**: For abstract hooks, move known data flows into explicit parameters before leaving `**kwargs` behind for subclass extensibility (for example, prefer `state=` explicitly instead of passing it through kwargs) +- **Prefer explicit metadata containers**: For constructors that expose metadata, prefer an explicit `additional_properties` parameter. +- **Keep SDK passthroughs narrow and documented**: A kwargs escape hatch may be acceptable for provider helper APIs that pass through to a large or unstable external SDK surface, but it should be documented as SDK passthrough and revisited regularly +- **Do not keep passthrough kwargs on wrappers that do not use them**: Convenience wrappers and session helpers should not accept generic kwargs merely to forward or ignore them - **Remove when possible**: In other cases, removing kwargs is likely better than keeping it - **Separate kwargs by purpose**: When combining kwargs for multiple purposes, use specific parameters like `client_kwargs: dict[str, Any]` instead of mixing everything in `**kwargs` - **Always document**: If kwargs must be used, always document how it's used, either by referencing external documentation or explaining its purpose diff --git a/python/packages/a2a/agent_framework_a2a/_agent.py b/python/packages/a2a/agent_framework_a2a/_agent.py index 31fac386b35..c954c90fc00 100644 --- a/python/packages/a2a/agent_framework_a2a/_agent.py +++ b/python/packages/a2a/agent_framework_a2a/_agent.py @@ -6,7 +6,7 @@ import json import re import uuid -from collections.abc import AsyncIterable, Awaitable, Sequence +from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence from typing import Any, Final, Literal, TypeAlias, overload import httpx @@ -114,9 +114,10 @@ def __init__( """Initialize the A2AAgent. Keyword Args: - name: The name of the agent. + name: The name of the agent. Defaults to agent_card.name if agent_card is provided. id: The unique identifier for the agent, will be created automatically if not provided. - description: A brief description of the agent's purpose. + description: A brief description of the agent's purpose. Defaults to agent_card.description + if agent_card is provided. agent_card: The agent card for the agent. url: The URL for the A2A server. client: The A2A client for the agent. @@ -127,6 +128,13 @@ def __init__( 10.0s write, 5.0s pool - optimized for A2A operations). kwargs: any additional properties, passed to BaseAgent. """ + # Default name/description from agent_card when not explicitly provided + if agent_card is not None: + if name is None: + name = agent_card.name + if description is None: + description = agent_card.description + super().__init__(id=id, name=name, description=description, **kwargs) self._http_client: httpx.AsyncClient | None = http_client self._timeout_config = self._create_timeout_config(timeout) @@ -218,6 +226,8 @@ def run( *, stream: Literal[False] = ..., session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, continuation_token: A2AContinuationToken | None = None, background: bool = False, **kwargs: Any, @@ -230,17 +240,21 @@ def run( *, stream: Literal[True], session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, continuation_token: A2AContinuationToken | None = None, background: bool = False, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... - def run( + def run( # pyright: ignore[reportIncompatibleMethodOverride] self, messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, continuation_token: A2AContinuationToken | None = None, background: bool = False, **kwargs: Any, @@ -253,17 +267,23 @@ def run( Keyword Args: stream: Whether to stream the response. Defaults to False. session: The conversation session associated with the message(s). + function_invocation_kwargs: Present for compatibility with the shared agent interface. + A2AAgent does not use these values directly. + client_kwargs: Present for compatibility with the shared agent interface. + A2AAgent does not use these values directly. + kwargs: Additional compatibility keyword arguments. + A2AAgent does not use these values directly. continuation_token: Optional token to resume a long-running task instead of starting a new one. background: When True, in-progress task updates surface continuation tokens so the caller can poll or resubscribe later. When False (default), the agent internally waits for the task to complete. - kwargs: Additional keyword arguments. Returns: When stream=False: An Awaitable[AgentResponse]. When stream=True: A ResponseStream of AgentResponseUpdate items. """ + del function_invocation_kwargs, client_kwargs, kwargs if continuation_token is not None: a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe( TaskIdParams(id=continuation_token["task_id"]) diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index 61123df5aba..ce7bb42a48a 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -145,6 +145,54 @@ def test_a2a_agent_initialization_with_client(mock_a2a_client: MockA2AClient) -> assert agent.client == mock_a2a_client +def test_a2a_agent_defaults_name_description_from_agent_card(mock_a2a_client: MockA2AClient) -> None: + """Test A2AAgent defaults name and description from agent_card when not explicitly provided.""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "Card Agent Name" + mock_card.description = "Card agent description" + + agent = A2AAgent(agent_card=mock_card, client=mock_a2a_client, http_client=None) + + assert agent.name == "Card Agent Name" + assert agent.description == "Card agent description" + + +def test_a2a_agent_explicit_name_description_overrides_agent_card(mock_a2a_client: MockA2AClient) -> None: + """Test that explicit name/description take precedence over agent_card values.""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "Card Agent Name" + mock_card.description = "Card agent description" + + agent = A2AAgent( + name="Explicit Name", + description="Explicit description", + agent_card=mock_card, + client=mock_a2a_client, + http_client=None, + ) + + assert agent.name == "Explicit Name" + assert agent.description == "Explicit description" + + +def test_a2a_agent_empty_string_name_description_not_overridden(mock_a2a_client: MockA2AClient) -> None: + """Test that explicitly provided empty strings are not overridden by agent_card values.""" + mock_card = MagicMock(spec=AgentCard) + mock_card.name = "Card Agent Name" + mock_card.description = "Card agent description" + + agent = A2AAgent( + name="", + description="", + agent_card=mock_card, + client=mock_a2a_client, + http_client=None, + ) + + assert agent.name == "" + assert agent.description == "" + + def test_a2a_agent_initialization_without_client_raises_error() -> None: """Test A2AAgent initialization without client or URL raises ValueError.""" with raises(ValueError, match="Either agent_card or url must be provided"): @@ -561,6 +609,8 @@ def test_transport_negotiation_both_fail() -> None: # Create a mock agent card mock_agent_card = MagicMock(spec=AgentCard) mock_agent_card.url = "http://test-agent.example.com" + mock_agent_card.name = "Test Agent" + mock_agent_card.description = "A test agent" # Mock the factory to simulate both primary and fallback failures mock_factory = MagicMock() diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index 7188eb739c6..d2fb59bbb6b 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -220,7 +220,6 @@ def __init__( additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, - **kwargs: Any, ) -> None: """Initialize the AG-UI chat client. @@ -231,13 +230,11 @@ def __init__( additional_properties: Additional properties to store middleware: Optional middleware to apply to the client. function_invocation_configuration: Optional function invocation configuration override. - **kwargs: Additional arguments passed to BaseChatClient """ super().__init__( additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) self._http_service = AGUIHttpService( endpoint=endpoint, diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py index 442138649a9..585bcb5c3e7 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py @@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any from agent_framework import BaseChatClient +from agent_framework._tools import _append_unique_tools # pyright: ignore[reportPrivateUsage] if TYPE_CHECKING: from agent_framework import SupportsAgentRun @@ -22,7 +23,7 @@ def _collect_mcp_tool_functions(mcp_tools: list[Any]) -> list[Any]: mcp_tools: List of MCP tool instances. Returns: - List of functions from connected MCP tools. + Functions from connected MCP tools. """ functions: list[Any] = [] for mcp_tool in mcp_tools: @@ -56,7 +57,11 @@ def collect_server_tools(agent: SupportsAgentRun) -> list[Any]: # Include functions from connected MCP tools (only available on Agent) mcp_tools = getattr(agent, "mcp_tools", None) if mcp_tools: - server_tools.extend(_collect_mcp_tool_functions(mcp_tools)) + _append_unique_tools( + server_tools, + _collect_mcp_tool_functions(mcp_tools), + duplicate_error_message="Tool names must be unique. Consider setting `tool_name_prefix` on the MCPTool.", + ) logger.info(f"[TOOLS] Agent has {len(server_tools)} configured tools") for tool in server_tools: @@ -109,26 +114,13 @@ def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list logger.info("[TOOLS] No client tools - not passing tools= parameter (using agent's configured tools)") return None - server_tool_names = {getattr(tool, "name", None) for tool in server_tools} - unique_client_tools = [tool for tool in client_tools if getattr(tool, "name", None) not in server_tool_names] - - if not unique_client_tools: - # Same check: must pass server tools if any require approval - if server_tools and _has_approval_tools(server_tools): - logger.info( - f"[TOOLS] Client tools duplicate server but server has approval tools - " - f"passing {len(server_tools)} server tools for approval mode" - ) - return server_tools - logger.info("[TOOLS] All client tools duplicate server tools - not passing tools= parameter") - return None - - combined_tools: list[Any] = [] - if server_tools: - combined_tools.extend(server_tools) - combined_tools.extend(unique_client_tools) + combined_tools = _append_unique_tools( + list(server_tools), + client_tools, + duplicate_error_message="Tool names must be unique.", + ) logger.info( f"[TOOLS] Passing tools= parameter with {len(combined_tools)} tools " - f"({len(server_tools)} server + {len(unique_client_tools)} unique client)" + f"({len(server_tools)} server + {len(client_tools)} client)" ) return combined_tools diff --git a/python/packages/ag-ui/tests/ag_ui/conftest.py b/python/packages/ag-ui/tests/ag_ui/conftest.py index b73eddb8ad3..42a6967371d 100644 --- a/python/packages/ag-ui/tests/ag_ui/conftest.py +++ b/python/packages/ag-ui/tests/ag_ui/conftest.py @@ -98,7 +98,11 @@ def get_response( options: OptionsCoT | ChatOptions[Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: - self.last_session = kwargs.get("session") + client_kwargs = kwargs.get("client_kwargs") + if isinstance(client_kwargs, Mapping): + self.last_session = cast(AgentSession | None, client_kwargs.get("session")) + else: + self.last_session = None self.last_service_session_id = self.last_session.service_session_id if self.last_session else None return cast( Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]], diff --git a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py index e98eb9c9c4e..e6f58ef0fd8 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/ag_ui/test_agent_wrapper_comprehensive.py @@ -702,14 +702,9 @@ async def test_agent_with_use_service_session_is_true(streaming_chat_client_stub """Test that when use_service_session is True, the AgentSession used to run the agent is set to the service session ID.""" from agent_framework.ag_ui import AgentFrameworkAgent - request_service_session_id: str | None = None - async def stream_fn( messages: MutableSequence[Message], chat_options: ChatOptions, **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - nonlocal request_service_session_id - session = kwargs.get("session") - request_service_session_id = session.service_session_id if session else None yield ChatResponseUpdate( contents=[Content.from_text(text="Response")], response_id="resp_67890", conversation_id="conv_12345" ) @@ -719,11 +714,22 @@ async def stream_fn( input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"} + # Spy on agent.run to capture the session kwarg at call time (before streaming mutates it) + captured_service_session_id: str | None = None + original_run = agent.run + + def capturing_run(*args: Any, **kwargs: Any) -> Any: + nonlocal captured_service_session_id + session = kwargs.get("session") + captured_service_session_id = session.service_session_id if session else None + return original_run(*args, **kwargs) + + agent.run = capturing_run # type: ignore[assignment, method-assign] + events: list[Any] = [] async for event in wrapper.run(input_data): events.append(event) - request_service_session_id = agent.client.last_service_session_id - assert request_service_session_id == "conv_123456" # type: ignore[attr-defined] (service_session_id should be set) + assert captured_service_session_id == "conv_123456" async def test_function_approval_mode_executes_tool(streaming_chat_client_stub): diff --git a/python/packages/ag-ui/tests/ag_ui/test_event_converters.py b/python/packages/ag-ui/tests/ag_ui/test_event_converters.py index a51d1364276..70bd4a0f04b 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_event_converters.py +++ b/python/packages/ag-ui/tests/ag_ui/test_event_converters.py @@ -185,7 +185,7 @@ def test_tool_call_result_event(self) -> None: assert update.role == "tool" assert len(update.contents) == 1 assert update.contents[0].call_id == "call_123" - assert update.contents[0].result == {"temperature": 22, "condition": "sunny"} + assert update.contents[0].result == '{"temperature": 22, "condition": "sunny"}' def test_run_finished_event(self) -> None: """Test conversion of RUN_FINISHED event.""" diff --git a/python/packages/ag-ui/tests/ag_ui/test_tooling.py b/python/packages/ag-ui/tests/ag_ui/test_tooling.py index e8567a586d9..890ae445415 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_tooling.py +++ b/python/packages/ag-ui/tests/ag_ui/test_tooling.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock +import pytest from agent_framework import Agent, tool from agent_framework_ag_ui._orchestration._tooling import ( @@ -20,7 +21,8 @@ def __init__(self, name: str) -> None: class MockMCPTool: """Mock MCP tool that simulates connected MCP tool with functions.""" - def __init__(self, functions: list[DummyTool], is_connected: bool = True) -> None: + def __init__(self, functions: list[DummyTool], is_connected: bool = True, name: str = "mock-mcp") -> None: + self.name = name self.functions = functions self.is_connected = is_connected @@ -45,11 +47,8 @@ def test_merge_tools_filters_duplicates() -> None: server = [DummyTool("a"), DummyTool("b")] client = [DummyTool("b"), DummyTool("c")] - merged = merge_tools(server, client) - - assert merged is not None - names = [getattr(t, "name", None) for t in merged] - assert names == ["a", "b", "c"] + with pytest.raises(ValueError, match="Duplicate tool name 'b'"): + merge_tools(server, client) def test_register_additional_client_tools_assigns_when_configured() -> None: @@ -131,6 +130,17 @@ def test_collect_server_tools_with_mcp_tools_via_public_property() -> None: assert len(tools) == 2 +def test_collect_server_tools_raises_on_duplicate_agent_and_mcp_tool_names() -> None: + duplicate_tool = DummyTool("regular_tool") + mock_mcp = MockMCPTool([duplicate_tool], is_connected=True, name="docs-mcp") + + agent = _create_chat_agent_with_tool("regular_tool") + agent.mcp_tools = [mock_mcp] + + with pytest.raises(ValueError, match="Duplicate tool name 'regular_tool'"): + collect_server_tools(agent) + + # Additional tests for tooling coverage @@ -176,11 +186,11 @@ def test_merge_tools_no_client_tools() -> None: def test_merge_tools_all_duplicates() -> None: - """merge_tools returns None when all client tools duplicate server tools.""" + """merge_tools raises when client and server tools share a name.""" server = [DummyTool("a"), DummyTool("b")] client = [DummyTool("a"), DummyTool("b")] - result = merge_tools(server, client) - assert result is None + with pytest.raises(ValueError, match="Duplicate tool name 'a'"): + merge_tools(server, client) def test_merge_tools_empty_server() -> None: @@ -208,7 +218,7 @@ def __init__(self, name: str): def test_merge_tools_with_approval_tools_all_duplicates() -> None: - """merge_tools returns server tools with approval mode even when client duplicates.""" + """merge_tools raises even when a client tool duplicates an approval-gated server tool.""" class ApprovalTool: def __init__(self, name: str): @@ -217,7 +227,5 @@ def __init__(self, name: str): server = [ApprovalTool("write_doc")] client = [DummyTool("write_doc")] # Same name as server - result = merge_tools(server, client) - assert result is not None - assert len(result) == 1 - assert result[0].approval_mode == "always_require" + with pytest.raises(ValueError, match="Duplicate tool name 'write_doc'"): + merge_tools(server, client) diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index 5cda4991c88..a1915a69fb2 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -228,11 +228,11 @@ def __init__( model_id: str | None = None, anthropic_client: AsyncAnthropic | None = None, additional_beta_flags: list[str] | None = None, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Anthropic Agent client. @@ -244,11 +244,11 @@ def __init__( For instance if you need to set a different base_url for testing or private deployments. additional_beta_flags: Additional beta flags to enable on the client. Default flags are: "mcp-client-2025-04-04", "code-execution-2025-08-25". + additional_properties: Additional properties stored on the client instance. middleware: Optional middleware to apply to the client. function_invocation_configuration: Optional function invocation configuration override. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. - kwargs: Additional keyword arguments passed to the parent class. Examples: .. code-block:: python @@ -319,9 +319,9 @@ class MyOptions(AnthropicChatOptions, total=False): # Initialize parent super().__init__( + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) # Initialize instance variables @@ -716,12 +716,46 @@ def _prepare_message_for_anthropic(self, message: Message) -> dict[str, Any]: "input": content.parse_arguments(), }) case "function_result": - a_content.append({ - "type": "tool_result", - "tool_use_id": content.call_id, - "content": content.result if content.result is not None else "", - "is_error": content.exception is not None, - }) + if content.items: + tool_content: list[dict[str, Any]] = [] + for item in content.items: + if item.type == "text": + tool_content.append({"type": "text", "text": item.text or ""}) + elif item.type == "data" and item.has_top_level_media_type("image"): + tool_content.append({ + "type": "image", + "source": { + "data": _get_data_bytes_as_str(item), # type: ignore[attr-defined] + "media_type": item.media_type, + "type": "base64", + }, + }) + elif item.type == "uri" and item.has_top_level_media_type("image"): + tool_content.append({ + "type": "image", + "source": {"type": "url", "url": item.uri}, + }) + else: + logger.debug( + "Ignoring unsupported rich content media type in tool result: %s", + item.media_type, + ) + tool_result_content = ( + tool_content if tool_content else (content.result if content.result is not None else "") + ) + a_content.append({ + "type": "tool_result", + "tool_use_id": content.call_id, + "content": tool_result_content, + "is_error": content.exception is not None, + }) + else: + a_content.append({ + "type": "tool_result", + "tool_use_id": content.call_id, + "content": content.result if content.result is not None else "", + "is_error": content.exception is not None, + }) case "mcp_server_tool_call": mcp_call: dict[str, Any] = { "type": "mcp_tool_use", diff --git a/python/packages/anthropic/tests/test_anthropic_client.py b/python/packages/anthropic/tests/test_anthropic_client.py index 4f86c3eac21..272239b1d74 100644 --- a/python/packages/anthropic/tests/test_anthropic_client.py +++ b/python/packages/anthropic/tests/test_anthropic_client.py @@ -96,7 +96,9 @@ def test_anthropic_settings_init_with_explicit_values() -> None: @pytest.mark.parametrize("exclude_list", [["ANTHROPIC_API_KEY"]], indirect=True) -def test_anthropic_settings_missing_api_key(anthropic_unit_test_env: dict[str, str]) -> None: +def test_anthropic_settings_missing_api_key( + anthropic_unit_test_env: dict[str, str], +) -> None: """Test AnthropicSettings when API key is missing.""" settings = load_settings(AnthropicSettings, env_prefix="ANTHROPIC_") assert settings["api_key"] is None @@ -115,7 +117,9 @@ def test_anthropic_client_init_with_client(mock_anthropic_client: MagicMock) -> assert isinstance(client, SupportsChatGetResponse) -def test_anthropic_client_init_auto_create_client(anthropic_unit_test_env: dict[str, str]) -> None: +def test_anthropic_client_init_auto_create_client( + anthropic_unit_test_env: dict[str, str], +) -> None: """Test AnthropicClient initialization with auto-created anthropic_client.""" client = AnthropicClient( api_key=anthropic_unit_test_env["ANTHROPIC_API_KEY"], @@ -129,7 +133,10 @@ def test_anthropic_client_init_auto_create_client(anthropic_unit_test_env: dict[ def test_anthropic_client_init_missing_api_key() -> None: """Test AnthropicClient initialization when API key is missing.""" with patch("agent_framework_anthropic._chat_client.load_settings") as mock_load: - mock_load.return_value = {"api_key": None, "chat_model_id": "claude-3-5-sonnet-20241022"} + mock_load.return_value = { + "api_key": None, + "chat_model_id": "claude-3-5-sonnet-20241022", + } with pytest.raises(ValueError, match="Anthropic API key is required"): AnthropicClient() @@ -157,7 +164,9 @@ def test_prepare_message_for_anthropic_text(mock_anthropic_client: MagicMock) -> assert result["content"][0]["text"] == "Hello, world!" -def test_prepare_message_for_anthropic_function_call(mock_anthropic_client: MagicMock) -> None: +def test_prepare_message_for_anthropic_function_call( + mock_anthropic_client: MagicMock, +) -> None: """Test converting function call message to Anthropic format.""" client = create_test_anthropic_client(mock_anthropic_client) message = Message( @@ -181,7 +190,9 @@ def test_prepare_message_for_anthropic_function_call(mock_anthropic_client: Magi assert result["content"][0]["input"] == {"location": "San Francisco"} -def test_prepare_message_for_anthropic_function_result(mock_anthropic_client: MagicMock) -> None: +def test_prepare_message_for_anthropic_function_result( + mock_anthropic_client: MagicMock, +) -> None: """Test converting function result message to Anthropic format.""" client = create_test_anthropic_client(mock_anthropic_client) message = Message( @@ -200,13 +211,124 @@ def test_prepare_message_for_anthropic_function_result(mock_anthropic_client: Ma assert len(result["content"]) == 1 assert result["content"][0]["type"] == "tool_result" assert result["content"][0]["tool_use_id"] == "call_123" - # The degree symbol might be escaped differently depending on JSON encoder - assert "Sunny" in result["content"][0]["content"] - assert "72" in result["content"][0]["content"] + tool_content = result["content"][0]["content"] + assert isinstance(tool_content, list) + assert len(tool_content) == 1 + assert tool_content[0]["type"] == "text" + assert "Sunny" in tool_content[0]["text"] + assert "72" in tool_content[0]["text"] assert result["content"][0]["is_error"] is False -def test_prepare_message_for_anthropic_text_reasoning(mock_anthropic_client: MagicMock) -> None: +def test_prepare_message_for_anthropic_function_result_with_data_image( + mock_anthropic_client: MagicMock, +) -> None: + """Test function result with a data-type image item produces a base64 image block.""" + client = create_test_anthropic_client(mock_anthropic_client) + image_content = Content.from_data(data=b"fake_image_bytes", media_type="image/png") + message = Message( + role="tool", + contents=[ + Content.from_function_result( + call_id="call_img", + result=[Content.from_text("Here is the image"), image_content], + ) + ], + ) + + result = client._prepare_message_for_anthropic(message) + + assert result["role"] == "user" + tool_result = result["content"][0] + assert tool_result["type"] == "tool_result" + assert tool_result["tool_use_id"] == "call_img" + content = tool_result["content"] + assert len(content) == 2 + assert content[0]["type"] == "text" + assert content[0]["text"] == "Here is the image" + assert content[1]["type"] == "image" + assert content[1]["source"]["type"] == "base64" + assert content[1]["source"]["media_type"] == "image/png" + + +def test_prepare_message_for_anthropic_function_result_with_uri_image( + mock_anthropic_client: MagicMock, +) -> None: + """Test function result with a uri-type image item produces a URL image block.""" + client = create_test_anthropic_client(mock_anthropic_client) + uri_content = Content.from_uri(uri="https://example.com/image.png", media_type="image/png") + message = Message( + role="tool", + contents=[ + Content.from_function_result( + call_id="call_uri", + result=[uri_content], + ) + ], + ) + + result = client._prepare_message_for_anthropic(message) + + tool_result = result["content"][0] + content = tool_result["content"] + assert len(content) == 1 + assert content[0]["type"] == "image" + assert content[0]["source"]["type"] == "url" + assert content[0]["source"]["url"] == "https://example.com/image.png" + + +def test_prepare_message_for_anthropic_function_result_with_unsupported_media( + mock_anthropic_client: MagicMock, +) -> None: + """Test function result with unsupported media type skips the item.""" + client = create_test_anthropic_client(mock_anthropic_client) + audio_content = Content.from_data(data=b"audio_bytes", media_type="audio/wav") + message = Message( + role="tool", + contents=[ + Content.from_function_result( + call_id="call_audio", + result=[Content.from_text("Some text"), audio_content], + ) + ], + ) + + result = client._prepare_message_for_anthropic(message) + + tool_result = result["content"][0] + content = tool_result["content"] + # Audio should be skipped, only text remains + assert len(content) == 1 + assert content[0]["type"] == "text" + assert content[0]["text"] == "Some text" + + +def test_prepare_message_for_anthropic_function_result_all_unsupported_media( + mock_anthropic_client: MagicMock, +) -> None: + """Test function result where all items are unsupported falls back to string result.""" + client = create_test_anthropic_client(mock_anthropic_client) + audio_content = Content.from_data(data=b"audio_bytes", media_type="audio/wav") + message = Message( + role="tool", + contents=[ + Content.from_function_result( + call_id="call_all_unsupported", + result=[audio_content], + ) + ], + ) + + result = client._prepare_message_for_anthropic(message) + + tool_result = result["content"][0] + # All items unsupported → tool_content is empty → falls back to string result + assert tool_result["content"] == "" + + +def test_prepare_message_for_anthropic_text_reasoning( + mock_anthropic_client: MagicMock, +) -> None: """Test converting text reasoning message to Anthropic format.""" client = create_test_anthropic_client(mock_anthropic_client) message = Message( @@ -223,7 +345,9 @@ def test_prepare_message_for_anthropic_text_reasoning(mock_anthropic_client: Mag assert "signature" not in result["content"][0] -def test_prepare_message_for_anthropic_text_reasoning_with_signature(mock_anthropic_client: MagicMock) -> None: +def test_prepare_message_for_anthropic_text_reasoning_with_signature( + mock_anthropic_client: MagicMock, +) -> None: """Test converting text reasoning message with signature to Anthropic format.""" client = create_test_anthropic_client(mock_anthropic_client) message = Message( @@ -240,7 +364,9 @@ def test_prepare_message_for_anthropic_text_reasoning_with_signature(mock_anthro assert result["content"][0]["signature"] == "sig_abc123" -def test_prepare_message_for_anthropic_mcp_server_tool_call(mock_anthropic_client: MagicMock) -> None: +def test_prepare_message_for_anthropic_mcp_server_tool_call( + mock_anthropic_client: MagicMock, +) -> None: """Test converting MCP server tool call message to Anthropic format.""" client = create_test_anthropic_client(mock_anthropic_client) message = Message( @@ -266,7 +392,9 @@ def test_prepare_message_for_anthropic_mcp_server_tool_call(mock_anthropic_clien assert result["content"][0]["input"] == {"query": "Azure Functions"} -def test_prepare_message_for_anthropic_mcp_server_tool_call_no_server_name(mock_anthropic_client: MagicMock) -> None: +def test_prepare_message_for_anthropic_mcp_server_tool_call_no_server_name( + mock_anthropic_client: MagicMock, +) -> None: """Test converting MCP server tool call with no server name defaults to empty string.""" client = create_test_anthropic_client(mock_anthropic_client) message = Message( @@ -291,7 +419,9 @@ def test_prepare_message_for_anthropic_mcp_server_tool_call_no_server_name(mock_ assert result["content"][0]["input"] == {} -def test_prepare_message_for_anthropic_mcp_server_tool_result(mock_anthropic_client: MagicMock) -> None: +def test_prepare_message_for_anthropic_mcp_server_tool_result( + mock_anthropic_client: MagicMock, +) -> None: """Test converting MCP server tool result message to Anthropic format.""" client = create_test_anthropic_client(mock_anthropic_client) message = Message( @@ -313,7 +443,9 @@ def test_prepare_message_for_anthropic_mcp_server_tool_result(mock_anthropic_cli assert result["content"][0]["content"] == "Found 3 results for Azure Functions." -def test_prepare_message_for_anthropic_mcp_server_tool_result_none_output(mock_anthropic_client: MagicMock) -> None: +def test_prepare_message_for_anthropic_mcp_server_tool_result_none_output( + mock_anthropic_client: MagicMock, +) -> None: """Test converting MCP server tool result with None output defaults to empty string.""" client = create_test_anthropic_client(mock_anthropic_client) message = Message( @@ -335,7 +467,9 @@ def test_prepare_message_for_anthropic_mcp_server_tool_result_none_output(mock_a assert result["content"][0]["content"] == "" -def test_prepare_messages_for_anthropic_with_system(mock_anthropic_client: MagicMock) -> None: +def test_prepare_messages_for_anthropic_with_system( + mock_anthropic_client: MagicMock, +) -> None: """Test converting messages list with system message.""" client = create_test_anthropic_client(mock_anthropic_client) messages = [ @@ -351,7 +485,9 @@ def test_prepare_messages_for_anthropic_with_system(mock_anthropic_client: Magic assert result[0]["content"][0]["text"] == "Hello!" -def test_prepare_messages_for_anthropic_without_system(mock_anthropic_client: MagicMock) -> None: +def test_prepare_messages_for_anthropic_without_system( + mock_anthropic_client: MagicMock, +) -> None: """Test converting messages list without system message.""" client = create_test_anthropic_client(mock_anthropic_client) messages = [ @@ -374,7 +510,9 @@ def test_prepare_tools_for_anthropic_tool(mock_anthropic_client: MagicMock) -> N client = create_test_anthropic_client(mock_anthropic_client) @tool(approval_mode="never_require") - def get_weather(location: Annotated[str, Field(description="Location to get weather for")]) -> str: + def get_weather( + location: Annotated[str, Field(description="Location to get weather for")], + ) -> str: """Get weather for a location.""" return f"Weather for {location}" @@ -389,7 +527,9 @@ def get_weather(location: Annotated[str, Field(description="Location to get weat assert "Get weather for a location" in result["tools"][0]["description"] -def test_prepare_tools_for_anthropic_web_search(mock_anthropic_client: MagicMock) -> None: +def test_prepare_tools_for_anthropic_web_search( + mock_anthropic_client: MagicMock, +) -> None: """Test converting web_search dict tool to Anthropic format.""" client = create_test_anthropic_client(mock_anthropic_client) chat_options = ChatOptions(tools=[client.get_web_search_tool()]) @@ -403,7 +543,9 @@ def test_prepare_tools_for_anthropic_web_search(mock_anthropic_client: MagicMock assert result["tools"][0]["name"] == "web_search" -def test_prepare_tools_for_anthropic_code_interpreter(mock_anthropic_client: MagicMock) -> None: +def test_prepare_tools_for_anthropic_code_interpreter( + mock_anthropic_client: MagicMock, +) -> None: """Test converting code_interpreter dict tool to Anthropic format.""" client = create_test_anthropic_client(mock_anthropic_client) chat_options = ChatOptions(tools=[client.get_code_interpreter_tool()]) @@ -421,7 +563,9 @@ def _dummy_bash(command: str) -> str: return f"executed: {command}" -def test_prepare_tools_for_anthropic_shell_tool(mock_anthropic_client: MagicMock) -> None: +def test_prepare_tools_for_anthropic_shell_tool( + mock_anthropic_client: MagicMock, +) -> None: """Test converting tool-decorated FunctionTool to Anthropic bash format.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -440,7 +584,9 @@ def run_bash(command: str) -> str: assert result["tools"][0]["name"] == "bash" -def test_prepare_tools_for_anthropic_shell_tool_custom_type(mock_anthropic_client: MagicMock) -> None: +def test_prepare_tools_for_anthropic_shell_tool_custom_type( + mock_anthropic_client: MagicMock, +) -> None: """Test shell tool with custom type via additional_properties.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -458,7 +604,9 @@ def run_bash(command: str) -> str: assert result["tools"][0]["name"] == "bash" -def test_prepare_tools_for_anthropic_shell_tool_does_not_mutate_name(mock_anthropic_client: MagicMock) -> None: +def test_prepare_tools_for_anthropic_shell_tool_does_not_mutate_name( + mock_anthropic_client: MagicMock, +) -> None: """Shell tool API name should be 'bash' without mutating local FunctionTool name.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -478,7 +626,9 @@ def run_local_shell(command: str) -> str: assert run_local_shell.name == "run_local_shell" -def test_get_shell_tool_reuses_function_tool_instance(mock_anthropic_client: MagicMock) -> None: +def test_get_shell_tool_reuses_function_tool_instance( + mock_anthropic_client: MagicMock, +) -> None: """Passing a FunctionTool should update and return the same tool instance.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -513,7 +663,9 @@ def test_prepare_tools_for_anthropic_mcp_tool(mock_anthropic_client: MagicMock) assert result["mcp_servers"][0]["url"] == "https://example.com/mcp" -def test_prepare_tools_for_anthropic_mcp_with_auth(mock_anthropic_client: MagicMock) -> None: +def test_prepare_tools_for_anthropic_mcp_with_auth( + mock_anthropic_client: MagicMock, +) -> None: """Test converting MCP dict tool with authorization token.""" client = create_test_anthropic_client(mock_anthropic_client) # Use the static method with authorization_token @@ -533,7 +685,9 @@ def test_prepare_tools_for_anthropic_mcp_with_auth(mock_anthropic_client: MagicM assert result["mcp_servers"][0]["authorization_token"] == "Bearer token123" -def test_prepare_tools_for_anthropic_dict_tool(mock_anthropic_client: MagicMock) -> None: +def test_prepare_tools_for_anthropic_dict_tool( + mock_anthropic_client: MagicMock, +) -> None: """Test converting dict tool to Anthropic format.""" client = create_test_anthropic_client(mock_anthropic_client) chat_options = ChatOptions(tools=[{"type": "custom", "name": "custom_tool", "description": "A custom tool"}]) @@ -574,7 +728,9 @@ async def test_prepare_options_basic(mock_anthropic_client: MagicMock) -> None: assert "messages" in run_options -async def test_prepare_options_with_system_message(mock_anthropic_client: MagicMock) -> None: +async def test_prepare_options_with_system_message( + mock_anthropic_client: MagicMock, +) -> None: """Test _prepare_options with system message.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -590,7 +746,9 @@ async def test_prepare_options_with_system_message(mock_anthropic_client: MagicM assert len(run_options["messages"]) == 1 # System message not in messages list -async def test_anthropic_shell_tool_is_invoked_in_function_loop(mock_anthropic_client: MagicMock) -> None: +async def test_anthropic_shell_tool_is_invoked_in_function_loop( + mock_anthropic_client: MagicMock, +) -> None: """Function invocation loop should execute shell tool when Anthropic returns bash tool_use.""" client = create_test_anthropic_client(mock_anthropic_client) executed_commands: list[str] = [] @@ -625,7 +783,10 @@ def run_local_shell(command: str) -> str: second_message.model = "claude-test" second_message.stop_reason = "end_turn" - mock_anthropic_client.beta.messages.create.side_effect = [first_message, second_message] + mock_anthropic_client.beta.messages.create.side_effect = [ + first_message, + second_message, + ] await client.get_response( messages=[Message(role="user", text="Run pwd")], @@ -643,10 +804,14 @@ def run_local_shell(command: str) -> str: ] assert len(tool_results) == 1 assert tool_results[0]["tool_use_id"] == "call_bash_loop" - assert "executed: pwd" in tool_results[0]["content"] + tool_content = tool_results[0]["content"] + assert isinstance(tool_content, list) + assert any("executed: pwd" in item.get("text", "") for item in tool_content) -async def test_prepare_options_with_tool_choice_auto(mock_anthropic_client: MagicMock) -> None: +async def test_prepare_options_with_tool_choice_auto( + mock_anthropic_client: MagicMock, +) -> None: """Test _prepare_options with auto tool choice.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -660,7 +825,9 @@ async def test_prepare_options_with_tool_choice_auto(mock_anthropic_client: Magi assert "allow_multiple_tool_calls" not in run_options -async def test_prepare_options_with_tool_choice_required(mock_anthropic_client: MagicMock) -> None: +async def test_prepare_options_with_tool_choice_required( + mock_anthropic_client: MagicMock, +) -> None: """Test _prepare_options with required tool choice.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -674,7 +841,9 @@ async def test_prepare_options_with_tool_choice_required(mock_anthropic_client: assert run_options["tool_choice"]["name"] == "get_weather" -async def test_prepare_options_with_tool_choice_none(mock_anthropic_client: MagicMock) -> None: +async def test_prepare_options_with_tool_choice_none( + mock_anthropic_client: MagicMock, +) -> None: """Test _prepare_options with none tool choice.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -704,7 +873,9 @@ def get_weather(location: str) -> str: assert len(run_options["tools"]) == 1 -async def test_prepare_options_with_stop_sequences(mock_anthropic_client: MagicMock) -> None: +async def test_prepare_options_with_stop_sequences( + mock_anthropic_client: MagicMock, +) -> None: """Test _prepare_options with stop sequences.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -728,7 +899,9 @@ async def test_prepare_options_with_top_p(mock_anthropic_client: MagicMock) -> N assert run_options["top_p"] == 0.9 -async def test_prepare_options_excludes_stream_option(mock_anthropic_client: MagicMock) -> None: +async def test_prepare_options_excludes_stream_option( + mock_anthropic_client: MagicMock, +) -> None: """Test _prepare_options excludes stream when stream is provided in options.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -740,7 +913,9 @@ async def test_prepare_options_excludes_stream_option(mock_anthropic_client: Mag assert "stream" not in run_options -async def test_prepare_options_filters_internal_kwargs(mock_anthropic_client: MagicMock) -> None: +async def test_prepare_options_filters_internal_kwargs( + mock_anthropic_client: MagicMock, +) -> None: """Test _prepare_options filters internal framework kwargs. Internal kwargs like _function_middleware_pipeline, thread, and middleware @@ -859,7 +1034,9 @@ def test_parse_contents_from_anthropic_text(mock_anthropic_client: MagicMock) -> assert result[0].text == "Hello!" -def test_parse_contents_from_anthropic_tool_use(mock_anthropic_client: MagicMock) -> None: +def test_parse_contents_from_anthropic_tool_use( + mock_anthropic_client: MagicMock, +) -> None: """Test _parse_contents_from_anthropic with tool use.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -879,7 +1056,9 @@ def test_parse_contents_from_anthropic_tool_use(mock_anthropic_client: MagicMock assert result[0].name == "get_weather" -def test_parse_contents_from_anthropic_input_json_delta_no_duplicate_name(mock_anthropic_client: MagicMock) -> None: +def test_parse_contents_from_anthropic_input_json_delta_no_duplicate_name( + mock_anthropic_client: MagicMock, +) -> None: """Test that input_json_delta events have empty name to prevent duplicate ToolCallStartEvents. When streaming tool calls, the initial tool_use event provides the name, @@ -969,7 +1148,9 @@ async def test_inner_get_response(mock_anthropic_client: MagicMock) -> None: assert len(response.messages) == 1 -async def test_inner_get_response_ignores_options_stream_non_streaming(mock_anthropic_client: MagicMock) -> None: +async def test_inner_get_response_ignores_options_stream_non_streaming( + mock_anthropic_client: MagicMock, +) -> None: """Test stream option in options does not conflict in non-streaming mode.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -1019,7 +1200,9 @@ async def mock_stream(): assert isinstance(chunks, list) -async def test_inner_get_response_ignores_options_stream_streaming(mock_anthropic_client: MagicMock) -> None: +async def test_inner_get_response_ignores_options_stream_streaming( + mock_anthropic_client: MagicMock, +) -> None: """Test stream option in options does not conflict in streaming mode.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -1368,7 +1551,9 @@ def test_prepare_response_format_openai_style(mock_anthropic_client: MagicMock) assert result["schema"]["properties"]["name"]["type"] == "string" -def test_prepare_response_format_direct_schema(mock_anthropic_client: MagicMock) -> None: +def test_prepare_response_format_direct_schema( + mock_anthropic_client: MagicMock, +) -> None: """Test response_format with direct schema key.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -1402,7 +1587,9 @@ def test_prepare_response_format_raw_schema(mock_anthropic_client: MagicMock) -> assert result["schema"]["properties"]["count"]["type"] == "integer" -def test_prepare_response_format_pydantic_model(mock_anthropic_client: MagicMock) -> None: +def test_prepare_response_format_pydantic_model( + mock_anthropic_client: MagicMock, +) -> None: """Test response_format with Pydantic BaseModel.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -1475,7 +1662,9 @@ def test_prepare_message_with_unsupported_data_type( assert len(result["content"]) == 0 -def test_prepare_message_with_unsupported_uri_type(mock_anthropic_client: MagicMock) -> None: +def test_prepare_message_with_unsupported_uri_type( + mock_anthropic_client: MagicMock, +) -> None: """Test preparing messages with unsupported URI content type.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -1612,7 +1801,9 @@ def test_parse_contents_mcp_tool_result_object_content( assert result[0].type == "mcp_server_tool_result" -def test_parse_contents_web_search_tool_result(mock_anthropic_client: MagicMock) -> None: +def test_parse_contents_web_search_tool_result( + mock_anthropic_client: MagicMock, +) -> None: """Test parsing web search tool result.""" client = create_test_anthropic_client(mock_anthropic_client) client._last_call_id_name = ("call_789", "web_search") @@ -1742,7 +1933,9 @@ def test_func() -> str: assert result["tool_choice"]["type"] == "any" -def test_tool_choice_required_specific_function(mock_anthropic_client: MagicMock) -> None: +def test_tool_choice_required_specific_function( + mock_anthropic_client: MagicMock, +) -> None: """Test tool_choice required mode with specific function.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -1782,7 +1975,9 @@ def test_func() -> str: assert result["tool_choice"]["type"] == "none" -def test_tool_choice_required_allows_parallel_use(mock_anthropic_client: MagicMock) -> None: +def test_tool_choice_required_allows_parallel_use( + mock_anthropic_client: MagicMock, +) -> None: """Test tool choice required mode with allow_multiple=True.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -1902,7 +2097,9 @@ def test_parse_usage_with_cache_tokens(mock_anthropic_client: MagicMock) -> None # Code Execution Result Tests -def test_parse_code_execution_result_with_error(mock_anthropic_client: MagicMock) -> None: +def test_parse_code_execution_result_with_error( + mock_anthropic_client: MagicMock, +) -> None: """Test parsing code execution result with error.""" client = create_test_anthropic_client(mock_anthropic_client) client._last_call_id_name = ("call_code1", "code_execution_tool") @@ -1925,7 +2122,9 @@ def test_parse_code_execution_result_with_error(mock_anthropic_client: MagicMock assert result[0].type == "code_interpreter_tool_result" -def test_parse_code_execution_result_with_stdout(mock_anthropic_client: MagicMock) -> None: +def test_parse_code_execution_result_with_stdout( + mock_anthropic_client: MagicMock, +) -> None: """Test parsing code execution result with stdout.""" client = create_test_anthropic_client(mock_anthropic_client) client._last_call_id_name = ("call_code2", "code_execution_tool") @@ -1947,7 +2146,9 @@ def test_parse_code_execution_result_with_stdout(mock_anthropic_client: MagicMoc assert result[0].type == "code_interpreter_tool_result" -def test_parse_code_execution_result_with_stderr(mock_anthropic_client: MagicMock) -> None: +def test_parse_code_execution_result_with_stderr( + mock_anthropic_client: MagicMock, +) -> None: """Test parsing code execution result with stderr.""" client = create_test_anthropic_client(mock_anthropic_client) client._last_call_id_name = ("call_code3", "code_execution_tool") @@ -1969,7 +2170,9 @@ def test_parse_code_execution_result_with_stderr(mock_anthropic_client: MagicMoc assert result[0].type == "code_interpreter_tool_result" -def test_parse_code_execution_result_with_files(mock_anthropic_client: MagicMock) -> None: +def test_parse_code_execution_result_with_files( + mock_anthropic_client: MagicMock, +) -> None: """Test parsing code execution result with file outputs.""" client = create_test_anthropic_client(mock_anthropic_client) client._last_call_id_name = ("call_code4", "code_execution_tool") @@ -1998,8 +2201,10 @@ def test_parse_code_execution_result_with_files(mock_anthropic_client: MagicMock # Bash Execution Result Tests -def test_parse_bash_execution_result_with_stdout(mock_anthropic_client: MagicMock) -> None: - """Test parsing bash execution result with stdout produces shell_tool_result.""" +def test_parse_bash_execution_result_with_stdout( + mock_anthropic_client: MagicMock, +) -> None: + """Test parsing bash execution result with stdout.""" client = create_test_anthropic_client(mock_anthropic_client) client._last_call_id_name = ("call_bash2", "bash_code_execution") @@ -2028,8 +2233,10 @@ def test_parse_bash_execution_result_with_stdout(mock_anthropic_client: MagicMoc assert result[0].outputs[0].timed_out is False -def test_parse_bash_execution_result_with_stderr(mock_anthropic_client: MagicMock) -> None: - """Test parsing bash execution result with stderr produces shell_tool_result.""" +def test_parse_bash_execution_result_with_stderr( + mock_anthropic_client: MagicMock, +) -> None: + """Test parsing bash execution result with stderr.""" client = create_test_anthropic_client(mock_anthropic_client) client._last_call_id_name = ("call_bash3", "bash_code_execution") @@ -2056,7 +2263,9 @@ def test_parse_bash_execution_result_with_stderr(mock_anthropic_client: MagicMoc assert result[0].outputs[0].exit_code == 1 -def test_parse_bash_execution_result_with_error(mock_anthropic_client: MagicMock) -> None: +def test_parse_bash_execution_result_with_error( + mock_anthropic_client: MagicMock, +) -> None: """Test parsing bash execution error produces shell_tool_result with error info.""" from anthropic.types.beta.beta_bash_code_execution_tool_result_error import ( BetaBashCodeExecutionToolResultError, @@ -2277,7 +2486,9 @@ def test_parse_citations_page_location(mock_anthropic_client: MagicMock) -> None assert len(result) > 0 -def test_parse_citations_content_block_location(mock_anthropic_client: MagicMock) -> None: +def test_parse_citations_content_block_location( + mock_anthropic_client: MagicMock, +) -> None: """Test parsing citations with content_block_location.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -2322,7 +2533,9 @@ def test_parse_citations_web_search_location(mock_anthropic_client: MagicMock) - assert len(result) > 0 -def test_parse_citations_search_result_location(mock_anthropic_client: MagicMock) -> None: +def test_parse_citations_search_result_location( + mock_anthropic_client: MagicMock, +) -> None: """Test parsing citations with search_result_location.""" client = create_test_anthropic_client(mock_anthropic_client) @@ -2344,3 +2557,33 @@ def test_parse_citations_search_result_location(mock_anthropic_client: MagicMock result = client._parse_citations_from_anthropic(mock_block) assert len(result) > 0 + + +@pytest.mark.flaky +@pytest.mark.integration +@skip_if_anthropic_integration_tests_disabled +async def test_anthropic_client_integration_tool_rich_content_image() -> None: + """Integration test: a tool returns an image and the model describes it.""" + image_path = Path(__file__).parent / "assets" / "sample_image.jpg" + image_bytes = image_path.read_bytes() + + @tool(approval_mode="never_require") + def get_test_image() -> Content: + """Return a test image for analysis.""" + return Content.from_data(data=image_bytes, media_type="image/jpeg") + + client = AnthropicClient() + client.function_invocation_configuration["max_iterations"] = 2 + + messages = [Message(role="user", text="Call the get_test_image tool and describe what you see.")] + + response = await client.get_response( + messages=messages, + options={"tools": [get_test_image], "tool_choice": "auto", "max_tokens": 200}, + ) + + assert response is not None + assert response.text is not None + assert len(response.text) > 0 + # sample_image.jpg contains a photo of a house; the model should mention it. + assert "house" in response.text.lower(), f"Model did not describe the house image. Response: {response.text}" diff --git a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py index 4c065174eac..9972f1301d6 100644 --- a/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py +++ b/python/packages/azure-ai-search/tests/test_aisearch_context_provider.py @@ -17,10 +17,15 @@ @pytest.fixture(autouse=True) -def clear_azure_search_environment(monkeypatch: pytest.MonkeyPatch) -> None: - for key in tuple(os.environ): - if key.startswith("AZURE_SEARCH_"): - monkeypatch.delenv(key, raising=False) +def clear_azure_search_env(monkeypatch: pytest.MonkeyPatch) -> None: + """Keep tests isolated from ambient Azure Search environment variables.""" + for key in ( + "AZURE_SEARCH_ENDPOINT", + "AZURE_SEARCH_INDEX_NAME", + "AZURE_SEARCH_KNOWLEDGE_BASE_NAME", + "AZURE_SEARCH_API_KEY", + ): + monkeypatch.delenv(key, raising=False) class MockSearchResults: diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py index 4c0e3a56e72..d349ef32478 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py @@ -444,11 +444,11 @@ def __init__( model_deployment_name: str | None = None, credential: AzureCredentialTypes | None = None, should_cleanup_agent: bool = True, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Azure AI Agent client. @@ -471,11 +471,11 @@ def __init__( should_cleanup_agent: Whether to cleanup (delete) agents created by this client when the client is closed or context is exited. Defaults to True. Only affects agents created by this client instance; existing agents passed via agent_id are never deleted. + additional_properties: Additional properties stored on the client instance. middleware: Optional sequence of middlewares to include. function_invocation_configuration: Optional function invocation configuration. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. - kwargs: Additional keyword arguments passed to the parent class. Examples: .. code-block:: python @@ -548,9 +548,9 @@ class MyOptions(AzureAIAgentOptions, total=False): # Initialize parent super().__init__( + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) # Initialize instance variables @@ -1402,11 +1402,20 @@ def _prepare_tool_outputs_for_azure_ai( call_id = run_and_call_ids[1] if content.type == "function_result": + if content.items: + text_parts = [item.text or "" for item in content.items if item.type == "text"] + rich_items = [item for item in content.items if item.type in ("data", "uri")] + if rich_items: + logger.warning( + "Azure AI Agents does not support rich content (images, audio) in tool results. " + "Rich content items will be omitted." + ) + output_text = "\n".join(text_parts) if text_parts else "" + else: + output_text = content.result if content.result is not None else "" if tool_outputs is None: tool_outputs = [] - tool_outputs.append( - ToolOutput(tool_call_id=call_id, output=content.result if content.result is not None else "") - ) + tool_outputs.append(ToolOutput(tool_call_id=call_id, output=output_text)) elif content.type == "function_approval_response": if tool_approvals is None: tool_approvals = [] diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_client.py index ba5dd8aad72..1fc6c7c1c96 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_client.py @@ -119,9 +119,9 @@ def __init__( credential: AzureCredentialTypes | None = None, use_latest_version: bool | None = None, allow_preview: bool | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a bare Azure AI client. @@ -145,9 +145,9 @@ def __init__( use_latest_version: Boolean flag that indicates whether to use latest agent version if it exists in the service. allow_preview: Enables preview opt-in on internally-created ``AIProjectClient``. + additional_properties: Additional properties stored on the client instance. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. - kwargs: Additional keyword arguments passed to the parent class. Examples: .. code-block:: python @@ -217,7 +217,7 @@ class MyOptions(ChatOptions, total=False): # Initialize parent super().__init__( - **kwargs, + additional_properties=additional_properties, ) # Initialize instance variables @@ -1243,11 +1243,11 @@ def __init__( credential: AzureCredentialTypes | None = None, use_latest_version: bool | None = None, allow_preview: bool | None = None, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Azure AI client with full layer support. @@ -1268,11 +1268,11 @@ def __init__( use_latest_version: Boolean flag that indicates whether to use latest agent version if it exists in the service. allow_preview: Enables preview opt-in on internally-created ``AIProjectClient`` + additional_properties: Additional properties stored on the client instance. middleware: Optional sequence of chat middlewares to include. function_invocation_configuration: Optional function invocation configuration. env_file_path: Path to environment file for loading settings. env_file_encoding: Encoding of the environment file. - kwargs: Additional keyword arguments passed to the parent class. Examples: .. code-block:: python @@ -1319,9 +1319,9 @@ class MyOptions(ChatOptions, total=False): credential=credential, use_latest_version=use_latest_version, allow_preview=allow_preview, + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, env_file_path=env_file_path, env_file_encoding=env_file_encoding, - **kwargs, ) diff --git a/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py b/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py index a243f77a38d..3daa6783334 100644 --- a/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py +++ b/python/packages/azure-ai/agent_framework_azure_ai/_embedding_client.py @@ -124,9 +124,9 @@ def __init__( text_client: EmbeddingsClient | None = None, image_client: ImageEmbeddingsClient | None = None, credential: AzureKeyCredential | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a raw Azure AI Inference embedding client.""" settings = load_settings( @@ -160,7 +160,7 @@ def __init__( credential=credential, # type: ignore[arg-type] ) self._endpoint = resolved_endpoint - super().__init__(**kwargs) + super().__init__(additional_properties=additional_properties) async def close(self) -> None: """Close the underlying SDK clients and release resources.""" @@ -376,9 +376,9 @@ def __init__( image_client: ImageEmbeddingsClient | None = None, credential: AzureKeyCredential | None = None, otel_provider_name: str | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Azure AI Inference embedding client.""" super().__init__( @@ -389,8 +389,8 @@ def __init__( text_client=text_client, image_client=image_client, credential=credential, + additional_properties=additional_properties, otel_provider_name=otel_provider_name, env_file_path=env_file_path, env_file_encoding=env_file_encoding, - **kwargs, ) diff --git a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py index 4d20add20af..afa073c6aba 100644 --- a/python/packages/azure-ai/tests/test_azure_ai_agent_client.py +++ b/python/packages/azure-ai/tests/test_azure_ai_agent_client.py @@ -1208,8 +1208,8 @@ def __init__(self, data: str): assert len(tool_outputs) == 1 assert tool_outputs[0].tool_call_id == "call_456" - # Result is pre-parsed string (already JSON) - assert tool_outputs[0].output == pre_parsed + # Result is the text content extracted from items + assert tool_outputs[0].output == function_result.result async def test_azure_ai_chat_client_convert_required_action_approval_response( diff --git a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py index 35c4243c373..6d205fa378b 100644 --- a/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py +++ b/python/packages/azure-cosmos/agent_framework_azure_cosmos/_history_provider.py @@ -124,7 +124,13 @@ def __init__( self._database_client = self._cosmos_client.get_database_client(self.database_name) - async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]: + async def get_messages( + self, + session_id: str | None, + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> list[Message]: """Retrieve stored messages for this session from Azure Cosmos DB.""" await self._ensure_container_proxy() session_key = self._session_partition_key(session_id) @@ -157,7 +163,14 @@ async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Mess return messages - async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None: + async def save_messages( + self, + session_id: str | None, + messages: Sequence[Message], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: """Persist messages for this session to Azure Cosmos DB.""" if not messages: return diff --git a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py index c108f7739db..1c43264398a 100644 --- a/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py +++ b/python/packages/azurefunctions/agent_framework_azurefunctions/_app.py @@ -14,6 +14,7 @@ import re import uuid from collections.abc import Callable, Mapping +from copy import deepcopy from dataclasses import dataclass from datetime import datetime, timezone from typing import TYPE_CHECKING, Any, TypeVar, cast @@ -58,6 +59,11 @@ HandlerT = TypeVar("HandlerT", bound=Callable[..., Any]) +def _create_state_snapshot(state: dict[str, Any]) -> dict[str, Any]: + """Create a deep copy of the deserialized state for later diffing.""" + return deepcopy(state) + + @dataclass class AgentMetadata: """Metadata for a registered agent. @@ -306,7 +312,7 @@ async def run() -> dict[str, Any]: deserialized_state: dict[str, Any] = { str(k): deserialize_value(v) for k, v in shared_state_snapshot.items() } - original_snapshot: dict[str, Any] = dict(deserialized_state) + original_snapshot = _create_state_snapshot(deserialized_state) shared_state.import_state(deserialized_state) if is_hitl_response: @@ -339,9 +345,10 @@ async def run() -> dict[str, Any]: deletes: set[str] = original_keys - current_keys # Updates = keys in current that are new or have different values - updates = { - k: v for k, v in current_state.items() if k not in original_snapshot or original_snapshot[k] != v - } + updates: dict[str, Any] = {} + for key in current_keys: + if key not in original_keys or current_state[key] != original_snapshot.get(key): + updates[key] = current_state[key] # Drain messages and events from runner context sent_messages = await runner_context.drain_messages() diff --git a/python/packages/azurefunctions/tests/test_app.py b/python/packages/azurefunctions/tests/test_app.py index f4b86ba2d7d..03084d5ada5 100644 --- a/python/packages/azurefunctions/tests/test_app.py +++ b/python/packages/azurefunctions/tests/test_app.py @@ -26,6 +26,7 @@ from agent_framework_azurefunctions import AgentFunctionApp from agent_framework_azurefunctions._entities import create_agent_entity +from agent_framework_azurefunctions._workflow import SOURCE_ORCHESTRATOR FuncT = TypeVar("FuncT", bound=Callable[..., Any]) @@ -1441,5 +1442,286 @@ def test_build_status_url_handles_trailing_slash(self) -> None: assert "instance-456" in url +def _compute_state_updates(original_snapshot: dict[str, Any], current_state: dict[str, Any]) -> dict[str, Any]: + """Compute state updates by comparing current state against the original snapshot. + + This mirrors the inlined logic in ``_app.py``'s ``executor_activity.run()``. + """ + original_keys = set(original_snapshot.keys()) + current_keys = set(current_state.keys()) + updates: dict[str, Any] = {} + for key in current_keys: + if key not in original_keys or current_state[key] != original_snapshot.get(key): + updates[key] = current_state[key] + return updates + + +class TestStateSnapshotDiff: + """Test suite for state snapshot diffing in activity execution. + + The activity executor snapshots state before execution and diffs against the + post-execution state to determine which keys were updated. These tests exercise + the production snapshot helper and the state-update diffing logic to ensure that + in-place mutations to nested objects (dicts, lists) are correctly detected as changes. + """ + + def test_nested_dict_mutation_detected_in_diff(self) -> None: + """Test that mutating values inside a nested dict appears in the diff.""" + from agent_framework._workflows._state import State + + from agent_framework_azurefunctions._app import _create_state_snapshot + + deserialized_state: dict[str, Any] = { + "Local.config": {"code": "", "enabled": False}, + "simple_key": "simple_value", + } + + original_snapshot = _create_state_snapshot(deserialized_state) + + shared_state = State() + shared_state.import_state(deserialized_state) + + config = shared_state.get("Local.config") + config["code"] = "SOMECODEXXX" + config["enabled"] = True + + shared_state.commit() + current_state = shared_state.export_state() + + updates = _compute_state_updates(original_snapshot, current_state) + + assert "Local.config" in updates + assert updates["Local.config"]["code"] == "SOMECODEXXX" + assert updates["Local.config"]["enabled"] is True + + def test_new_key_in_nested_dict_detected_in_diff(self) -> None: + """Test that adding a key to a nested dict appears in the diff.""" + from agent_framework._workflows._state import State + + from agent_framework_azurefunctions._app import _create_state_snapshot + + deserialized_state: dict[str, Any] = { + "Local.data": {"existing": "value"}, + } + + original_snapshot = _create_state_snapshot(deserialized_state) + + shared_state = State() + shared_state.import_state(deserialized_state) + + data = shared_state.get("Local.data") + data["code"] = "NEW_CODE" + + shared_state.commit() + current_state = shared_state.export_state() + + updates = _compute_state_updates(original_snapshot, current_state) + + assert "Local.data" in updates + assert updates["Local.data"]["code"] == "NEW_CODE" + + def test_nested_list_mutation_detected_in_diff(self) -> None: + """Test that appending to a nested list appears in the diff.""" + from agent_framework._workflows._state import State + + from agent_framework_azurefunctions._app import _create_state_snapshot + + deserialized_state: dict[str, Any] = { + "Local.items": [1, 2, 3], + } + + original_snapshot = _create_state_snapshot(deserialized_state) + + shared_state = State() + shared_state.import_state(deserialized_state) + + items = shared_state.get("Local.items") + items.append(4) + + shared_state.commit() + current_state = shared_state.export_state() + + updates = _compute_state_updates(original_snapshot, current_state) + + assert "Local.items" in updates + assert updates["Local.items"] == [1, 2, 3, 4] + + def test_new_top_level_key_detected_in_diff(self) -> None: + """Test that setting a new top-level key appears in the diff.""" + from agent_framework._workflows._state import State + + from agent_framework_azurefunctions._app import _create_state_snapshot + + deserialized_state: dict[str, Any] = { + "existing": "value", + } + + original_snapshot = _create_state_snapshot(deserialized_state) + + shared_state = State() + shared_state.import_state(deserialized_state) + + shared_state.set("Local.code", "SOMECODEXXX") + + shared_state.commit() + current_state = shared_state.export_state() + + updates = _compute_state_updates(original_snapshot, current_state) + + assert "Local.code" in updates + assert updates["Local.code"] == "SOMECODEXXX" + + def test_unchanged_nested_state_produces_empty_diff(self) -> None: + """Test that unmodified nested state produces no updates.""" + from agent_framework._workflows._state import State + + from agent_framework_azurefunctions._app import _create_state_snapshot + + deserialized_state: dict[str, Any] = { + "Local.config": {"code": "existing", "enabled": True}, + "simple_key": "simple_value", + } + + original_snapshot = _create_state_snapshot(deserialized_state) + + shared_state = State() + shared_state.import_state(deserialized_state) + + # No mutations performed + shared_state.commit() + current_state = shared_state.export_state() + + updates = _compute_state_updates(original_snapshot, current_state) + + assert updates == {} + + def test_shallow_copy_would_miss_nested_mutations(self) -> None: + """Regression test: a shallow copy (dict()) shares nested refs, hiding mutations. + + This reproduces the original bug from #4500 where ``dict(deserialized_state)`` + was used instead of ``copy.deepcopy()``. With a shallow copy the snapshot and + the live state share nested objects, so in-place mutations appear in both and + the diff produces an empty update set. + """ + from agent_framework._workflows._state import State + + deserialized_state: dict[str, Any] = { + "Local.config": {"code": "", "enabled": False}, + } + + # Shallow copy (the OLD, buggy behaviour) + shallow_snapshot = dict(deserialized_state) + + shared_state = State() + shared_state.import_state(deserialized_state) + + config = shared_state.get("Local.config") + config["code"] = "SOMECODEXXX" + config["enabled"] = True + + shared_state.commit() + current_state = shared_state.export_state() + + # With a shallow copy the mutation leaks into the snapshot → empty diff + updates_shallow = _compute_state_updates(shallow_snapshot, current_state) + assert updates_shallow == {}, "shallow copy should miss nested mutations (demonstrating the bug)" + + def test_create_state_snapshot_isolates_nested_objects(self) -> None: + """Verify _create_state_snapshot produces a deep copy that is mutation-proof. + + This ensures the production snapshot helper is not equivalent to ``dict()`` + and will correctly isolate nested objects so that later mutations are detected. + """ + from agent_framework_azurefunctions._app import _create_state_snapshot + + original: dict[str, Any] = { + "nested_dict": {"a": 1}, + "nested_list": [1, 2, 3], + } + + snapshot = _create_state_snapshot(original) + + # Mutate the originals in place + original["nested_dict"]["a"] = 999 + original["nested_list"].append(4) + + # Snapshot must be unaffected + assert snapshot["nested_dict"]["a"] == 1 + assert snapshot["nested_list"] == [1, 2, 3] + + def test_executor_activity_detects_nested_state_mutations(self) -> None: + """Integration test: the full activity wrapper detects nested mutations. + + This exercises the actual executor_activity function registered by + _setup_executor_activity to verify the production code path uses + _create_state_snapshot (deep copy) rather than dict() (shallow copy). + If the implementation regressed to using a shallow copy such as + ``dict(deserialized_state)``, this test would fail because in-place + mutations would leak into the snapshot and produce an empty diff. + """ + mock_executor = Mock() + mock_executor.id = "test-exec" + + async def mutate_nested_state( + message: Any, + source_executor_ids: Any, + state: Any, + runner_context: Any, + ) -> None: + config = state.get("Local.config") + config["code"] = "MUTATED" + config["enabled"] = True + state.commit() + + mock_executor.execute = AsyncMock(side_effect=mutate_nested_state) + + mock_workflow = Mock() + mock_workflow.executors = {"test-exec": mock_executor} + + # Capture the activity function by making decorators pass-through + captured_activity: dict[str, Any] = {} + + def passthrough_function_name(name: str) -> Callable[[FuncT], FuncT]: + def decorator(fn: FuncT) -> FuncT: + captured_activity["fn"] = fn + return fn + + return decorator + + def passthrough_activity_trigger(input_name: str) -> Callable[[FuncT], FuncT]: + def decorator(fn: FuncT) -> FuncT: + return fn + + return decorator + + with ( + patch.object(AgentFunctionApp, "function_name", side_effect=passthrough_function_name), + patch.object(AgentFunctionApp, "activity_trigger", side_effect=passthrough_activity_trigger), + patch.object(AgentFunctionApp, "_setup_workflow_orchestration"), + ): + AgentFunctionApp(workflow=mock_workflow) + + assert "fn" in captured_activity, "activity function was not captured" + + # Call the activity with nested state that the executor will mutate + input_data = json.dumps({ + "message": "test", + "shared_state_snapshot": { + "Local.config": {"code": "", "enabled": False}, + }, + "source_executor_ids": [SOURCE_ORCHESTRATOR], + }) + + result = json.loads(captured_activity["fn"](input_data)) + + # The deep copy snapshot must detect the in-place nested mutations + assert "Local.config" in result["shared_state_updates"], ( + "nested mutation not detected — snapshot may be using shallow copy" + ) + updated_config = result["shared_state_updates"]["Local.config"] + assert updated_config["code"] == "MUTATED" + assert updated_config["enabled"] is True + + if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py index 5bc97358464..c546ef5535e 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_chat_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_chat_client.py @@ -236,11 +236,11 @@ def __init__( session_token: str | None = None, client: BaseClient | None = None, boto3_session: Boto3Session | None = None, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Create a Bedrock chat client and load AWS credentials. @@ -252,11 +252,11 @@ def __init__( session_token: Optional AWS session token for temporary credentials. client: Preconfigured Bedrock runtime client; when omitted a boto3 session is created. boto3_session: Custom boto3 session used to build the runtime client if provided. + additional_properties: Additional properties stored on the client instance. middleware: Optional sequence of middlewares to include. function_invocation_configuration: Optional function invocation configuration env_file_path: Optional .env file path used by ``BedrockSettings`` to load defaults. env_file_encoding: Encoding for the optional .env file. - kwargs: Additional arguments forwarded to ``BaseChatClient``. Examples: .. code-block:: python @@ -303,9 +303,9 @@ class MyOptions(BedrockChatOptions, total=False): ) super().__init__( + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) self.model_id = chat_model_id self.region = region @@ -405,11 +405,16 @@ def _prepare_options( tool_config = self._prepare_tools(options.get("tools")) if tool_mode := validate_tool_mode(options.get("tool_choice")): - tool_config = tool_config or {} match tool_mode.get("mode"): - case "auto" | "none": - tool_config["toolChoice"] = {tool_mode.get("mode"): {}} + case "none": + # Bedrock doesn't support toolChoice "none". + # Omit toolConfig entirely so the model won't attempt tool calls. + tool_config = None + case "auto": + tool_config = tool_config or {} + tool_config["toolChoice"] = {"auto": {}} case "required": + tool_config = tool_config or {} if required_name := tool_mode.get("required_function_name"): tool_config["toolChoice"] = {"tool": {"name": required_name}} else: @@ -518,10 +523,22 @@ def _convert_content_to_bedrock_block(self, content: Content) -> dict[str, Any] } } case "function_result": + if content.items: + text_parts = [item.text or "" for item in content.items if item.type == "text"] + rich_items = [item for item in content.items if item.type in ("data", "uri")] + if rich_items: + logger.warning( + "Bedrock does not support rich content (images, audio) in tool results. " + "Rich content items will be omitted." + ) + tool_result_text = "\n".join(text_parts) if text_parts else "" + tool_result_blocks = self._convert_tool_result_to_blocks(tool_result_text) + else: + tool_result_blocks = self._convert_tool_result_to_blocks(content.result) tool_result_block = { "toolResult": { "toolUseId": content.call_id, - "content": self._convert_tool_result_to_blocks(content.result), + "content": tool_result_blocks, "status": "error" if content.exception else "success", } } @@ -542,7 +559,12 @@ def _convert_content_to_bedrock_block(self, content: Content) -> dict[str, Any] return None def _convert_tool_result_to_blocks(self, result: Any) -> list[dict[str, Any]]: - prepared_result = result if isinstance(result, str) else FunctionTool.parse_result(result) + if isinstance(result, str): + prepared_result = result + else: + parsed = FunctionTool.parse_result(result) + text_parts = [c.text or "" for c in parsed if c.type == "text"] + prepared_result = "\n".join(text_parts) if text_parts else str(result) try: parsed_result: object = json.loads(prepared_result) except json.JSONDecodeError: diff --git a/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py b/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py index d07bdee45c1..3161ed4c884 100644 --- a/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py +++ b/python/packages/bedrock/agent_framework_bedrock/_embedding_client.py @@ -104,9 +104,9 @@ def __init__( session_token: str | None = None, client: BaseClient | None = None, boto3_session: Boto3Session | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a raw Bedrock embedding client.""" settings = load_settings( @@ -145,7 +145,7 @@ def __init__( self.model_id: str = settings["embedding_model_id"] # type: ignore[assignment] # pyright: ignore[reportTypedDictNotRequiredAccess] self.region = resolved_region - super().__init__(**kwargs) + super().__init__(additional_properties=additional_properties) def service_url(self) -> str: """Get the URL of the service.""" @@ -274,9 +274,9 @@ def __init__( client: BaseClient | None = None, boto3_session: Boto3Session | None = None, otel_provider_name: str | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a Bedrock embedding client.""" super().__init__( @@ -287,8 +287,8 @@ def __init__( session_token=session_token, client=client, boto3_session=boto3_session, + additional_properties=additional_properties, otel_provider_name=otel_provider_name, env_file_path=env_file_path, env_file_encoding=env_file_encoding, - **kwargs, ) diff --git a/python/packages/bedrock/tests/test_bedrock_client.py b/python/packages/bedrock/tests/test_bedrock_client.py index e2a2f71750d..1566bff234c 100644 --- a/python/packages/bedrock/tests/test_bedrock_client.py +++ b/python/packages/bedrock/tests/test_bedrock_client.py @@ -31,6 +31,15 @@ def converse(self, **kwargs: Any) -> dict[str, Any]: } +def _make_client() -> BedrockChatClient: + """Create a BedrockChatClient with a stub runtime for unit tests.""" + return BedrockChatClient( + model_id="amazon.titan-text", + region="us-west-2", + client=_StubBedrockRuntime(), + ) + + async def test_get_response_invokes_bedrock_runtime() -> None: stub = _StubBedrockRuntime() client = BedrockChatClient( @@ -65,3 +74,66 @@ def test_build_request_requires_non_system_messages() -> None: with pytest.raises(ValueError): client._prepare_options(messages, {}) + + +def test_prepare_options_tool_choice_none_omits_tool_config() -> None: + """When tool_choice='none', toolConfig must be omitted entirely. + + Bedrock's Converse API only accepts 'auto', 'any', or 'tool' as valid + toolChoice keys. Sending {"none": {}} causes a ParamValidationError. + The fix omits toolConfig so the model won't attempt tool calls. + + Fixes #4529. + """ + client = _make_client() + messages = [Message(role="user", contents=[Content.from_text(text="hello")])] + + # Even when tools are provided, tool_choice="none" should strip toolConfig + options: dict[str, Any] = { + "tool_choice": "none", + "tools": [ + {"toolSpec": {"name": "get_weather", "description": "Get weather", "inputSchema": {"json": {}}}}, + ], + } + + request = client._prepare_options(messages, options) + + assert "toolConfig" not in request, ( + f"toolConfig should be omitted when tool_choice='none', got: {request.get('toolConfig')}" + ) + + +def test_prepare_options_tool_choice_auto_includes_tool_config() -> None: + """When tool_choice='auto', toolConfig.toolChoice should be {'auto': {}}.""" + client = _make_client() + messages = [Message(role="user", contents=[Content.from_text(text="hello")])] + + options: dict[str, Any] = { + "tool_choice": "auto", + "tools": [ + {"toolSpec": {"name": "get_weather", "description": "Get weather", "inputSchema": {"json": {}}}}, + ], + } + + request = client._prepare_options(messages, options) + + assert "toolConfig" in request + assert request["toolConfig"]["toolChoice"] == {"auto": {}} + + +def test_prepare_options_tool_choice_required_includes_any() -> None: + """When tool_choice='required' (no specific function), toolChoice should be {'any': {}}.""" + client = _make_client() + messages = [Message(role="user", contents=[Content.from_text(text="hello")])] + + options: dict[str, Any] = { + "tool_choice": "required", + "tools": [ + {"toolSpec": {"name": "get_weather", "description": "Get weather", "inputSchema": {"json": {}}}}, + ], + } + + request = client._prepare_options(messages, options) + + assert "toolConfig" in request + assert request["toolConfig"]["toolChoice"] == {"any": {}} diff --git a/python/packages/bedrock/tests/test_bedrock_settings.py b/python/packages/bedrock/tests/test_bedrock_settings.py index 016ed8ff05f..85e417602ad 100644 --- a/python/packages/bedrock/tests/test_bedrock_settings.py +++ b/python/packages/bedrock/tests/test_bedrock_settings.py @@ -132,4 +132,5 @@ def test_process_response_parses_tool_result() -> None: contents = chat_response.messages[0].contents assert contents[0].type == "function_result" - assert contents[0].result == {"answer": 42} + assert "answer" in str(contents[0].result) + assert contents[0].items is not None diff --git a/python/packages/claude/agent_framework_claude/_agent.py b/python/packages/claude/agent_framework_claude/_agent.py index 127e3647eeb..23703b2c53c 100644 --- a/python/packages/claude/agent_framework_claude/_agent.py +++ b/python/packages/claude/agent_framework_claude/_agent.py @@ -496,7 +496,16 @@ async def handler(args: dict[str, Any]) -> dict[str, Any]: result = await func_tool.invoke(arguments=args_instance) else: result = await func_tool.invoke(arguments=args) - return {"content": [{"type": "text", "text": str(result)}]} + content_blocks: list[dict[str, str]] = [] + for c in result: + if c.type == "text" and c.text: + content_blocks.append({"type": "text", "text": c.text}) + elif c.type in ("data", "uri"): + logger.warning( + "Claude Agent SDK does not support rich content (images, audio) " + "in tool results. Rich content items will be omitted." + ) + return {"content": content_blocks or [{"type": "text", "text": ""}]} except Exception as e: return {"content": [{"type": "text", "text": f"Error: {e}"}]} @@ -581,6 +590,7 @@ def run( *, stream: Literal[False] = ..., session: AgentSession | None = None, + options: OptionsT | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: ... @@ -591,6 +601,7 @@ def run( *, stream: Literal[True], session: AgentSession | None = None, + options: OptionsT | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... @@ -600,7 +611,8 @@ def run( *, stream: bool = False, session: AgentSession | None = None, - **kwargs: Any, + options: OptionsT | None = None, + **kwargs: Any, # type: ignore ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Run the agent with the given messages. @@ -612,16 +624,16 @@ def run( returns an awaitable AgentResponse. session: The conversation session. If session has service_session_id set, the agent will resume that session. - kwargs: Additional keyword arguments including 'options' for runtime options - (model, permission_mode can be changed per-request). + options: Runtime options. Model and permission_mode can be changed per request. + kwargs: Additional keyword arguments for compatibility with the shared agent + interface (e.g. compaction_strategy, tokenizer). Not used by ClaudeAgent. Returns: When stream=True: An ResponseStream for streaming updates. When stream=False: An Awaitable[AgentResponse] with the complete response. """ - options = kwargs.pop("options", None) response = ResponseStream( - self._get_stream(messages, session=session, options=options, **kwargs), + self._get_stream(messages, session=session, options=options), finalizer=self._finalize_response, ) @@ -634,8 +646,7 @@ async def _get_stream( messages: AgentRunInputs | None = None, *, session: AgentSession | None = None, - options: OptionsT | MutableMapping[str, Any] | None = None, - **kwargs: Any, + options: OptionsT | None = None, ) -> AsyncIterable[AgentResponseUpdate]: """Internal streaming implementation.""" session = session or self.create_session() diff --git a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py index edacb614a57..fc2a35c72b8 100644 --- a/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py +++ b/python/packages/copilotstudio/agent_framework_copilotstudio/_agent.py @@ -196,7 +196,6 @@ def run( *, stream: Literal[False] = False, session: AgentSession | None = None, - **kwargs: Any, ) -> Awaitable[AgentResponse]: ... @overload @@ -206,7 +205,6 @@ def run( *, stream: Literal[True], session: AgentSession | None = None, - **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... def run( @@ -215,7 +213,6 @@ def run( *, stream: bool = False, session: AgentSession | None = None, - **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: """Get a response from the agent. @@ -229,22 +226,20 @@ def run( Keyword Args: stream: Whether to stream the response. Defaults to False. session: The conversation session associated with the message(s). - kwargs: Additional keyword arguments. Returns: When stream=False: An Awaitable[AgentResponse]. When stream=True: A ResponseStream of AgentResponseUpdate items. """ if stream: - return self._run_stream_impl(messages=messages, session=session, **kwargs) - return self._run_impl(messages=messages, session=session, **kwargs) + return self._run_stream_impl(messages=messages, session=session) + return self._run_impl(messages=messages, session=session) async def _run_impl( self, messages: AgentRunInputs | None = None, *, session: AgentSession | None = None, - **kwargs: Any, ) -> AgentResponse: """Non-streaming implementation of run.""" if not session: @@ -269,7 +264,6 @@ def _run_stream_impl( messages: AgentRunInputs | None = None, *, session: AgentSession | None = None, - **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: """Streaming implementation of run.""" diff --git a/python/packages/core/agent_framework/__init__.py b/python/packages/core/agent_framework/__init__.py index 95d9b97d644..0f652f23bdc 100644 --- a/python/packages/core/agent_framework/__init__.py +++ b/python/packages/core/agent_framework/__init__.py @@ -215,6 +215,7 @@ ) from .exceptions import ( MiddlewareException, + UserInputRequiredException, WorkflowCheckpointException, WorkflowConvergenceException, WorkflowException, @@ -349,6 +350,7 @@ "TypeCompatibilityError", "UpdateT", "UsageDetails", + "UserInputRequiredException", "ValidationTypeEnum", "Workflow", "WorkflowAgent", diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 2b35b96e589..c2c6e874f18 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -2,10 +2,10 @@ from __future__ import annotations -import inspect import logging import re import sys +import warnings from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack from copy import deepcopy @@ -27,11 +27,13 @@ from mcp import types from mcp.server.lowlevel import Server from mcp.shared.exceptions import McpError -from pydantic import BaseModel, Field, create_model +from pydantic import BaseModel +from . import _tools as _tool_utils # pyright: ignore[reportPrivateUsage] from ._clients import BaseChatClient, SupportsChatGetResponse +from ._docstrings import apply_layered_docstring from ._mcp import LOG_LEVEL_MAPPING, MCPTool -from ._middleware import AgentMiddlewareLayer, MiddlewareTypes +from ._middleware import AgentMiddlewareLayer, FunctionInvocationContext, MiddlewareTypes from ._serialization import SerializationMixin from ._sessions import ( AgentSession, @@ -40,12 +42,7 @@ InMemoryHistoryProvider, SessionContext, ) -from ._tools import ( - FunctionInvocationLayer, - FunctionTool, - ToolTypes, - normalize_tools, -) +from ._tools import FunctionInvocationLayer, FunctionTool, ToolTypes, normalize_tools from ._types import ( AgentResponse, AgentResponseUpdate, @@ -57,7 +54,7 @@ map_chat_to_agent_update, normalize_messages, ) -from .exceptions import AgentInvalidResponseException +from .exceptions import AgentInvalidResponseException, UserInputRequiredException from .observability import AgentTelemetryLayer if sys.version_info >= (3, 13): @@ -79,6 +76,9 @@ logger = logging.getLogger("agent_framework") +_append_unique_tools = _tool_utils._append_unique_tools # pyright: ignore[reportPrivateUsage] +_get_tool_name = _tool_utils._get_tool_name # pyright: ignore[reportPrivateUsage] + ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel) OptionsCoT = TypeVar( "OptionsCoT", @@ -88,19 +88,6 @@ ) -def _get_tool_name(tool: Any) -> str | None: - """Extract a tool's name from either an object with a .name attribute or a dict tool definition.""" - if isinstance(tool, Mapping): - tool_mapping = cast(Mapping[str, Any], tool) - func = tool_mapping.get("function") - if isinstance(func, Mapping): - func_mapping = cast(Mapping[str, Any], func) - name = func_mapping.get("name") - return name if isinstance(name, str) else None - return None - return getattr(tool, "name", None) - - def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]: """Merge two options dicts, with override values taking precedence. @@ -115,11 +102,14 @@ def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, for key, value in override.items(): if value is None: continue - if key == "tools" and result.get("tools"): - # Combine tool lists, avoiding duplicates by name - existing_names = {_get_tool_name(t) for t in result["tools"]} - {None} - unique_new = [t for t in value if _get_tool_name(t) not in existing_names] - result["tools"] = list(result["tools"]) + unique_new + if key == "tools" and (result.get("tools") or value): + base_tools = normalize_tools(result.get("tools")) + override_tools = normalize_tools(value) + result["tools"] = _append_unique_tools( + list(base_tools), + override_tools, + duplicate_error_message="Tool names must be unique.", + ) elif key == "logit_bias" and result.get("logit_bias"): # Merge logit_bias dicts result["logit_bias"] = {**result["logit_bias"], **value} @@ -180,8 +170,8 @@ class _RunContext(TypedDict): chat_options: MutableMapping[str, Any] compaction_strategy: CompactionStrategy | None tokenizer: TokenizerProtocol | None - filtered_kwargs: Mapping[str, Any] - finalize_kwargs: Mapping[str, Any] + client_kwargs: Mapping[str, Any] + function_invocation_kwargs: Mapping[str, Any] # region Agent Protocol @@ -229,15 +219,15 @@ async def _stream(): return AgentResponse(messages=[], response_id="custom-response") - def create_session(self, **kwargs): + def create_session(self, *, session_id: str | None = None): from agent_framework import AgentSession - return AgentSession(**kwargs) + return AgentSession(session_id=session_id) - def get_session(self, *, service_session_id, **kwargs): + def get_session(self, service_session_id: str, *, session_id: str | None = None): from agent_framework import AgentSession - return AgentSession(service_session_id=service_session_id, **kwargs) + return AgentSession(service_session_id=service_session_id, session_id=session_id) # Verify the instance satisfies the protocol @@ -256,6 +246,8 @@ def run( *, stream: Literal[False] = ..., session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: """Get a response from the agent (non-streaming).""" @@ -268,6 +260,8 @@ def run( *, stream: Literal[True], session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Get a streaming response from the agent.""" @@ -279,6 +273,8 @@ def run( *, stream: bool = False, session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Get a response from the agent. @@ -293,6 +289,8 @@ def run( Keyword Args: stream: Whether to stream the response. Defaults to False. session: The conversation session associated with the message(s). + function_invocation_kwargs: Keyword arguments forwarded to tool invocation. + client_kwargs: Additional client-specific keyword arguments. kwargs: Additional keyword arguments. Returns: @@ -302,11 +300,11 @@ def run( """ ... - def create_session(self, **kwargs: Any) -> AgentSession: + def create_session(self, *, session_id: str | None = None) -> AgentSession: """Creates a new conversation session.""" ... - def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: + def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: """Gets or creates a session for a service-managed session ID.""" ... @@ -389,6 +387,13 @@ def __init__( additional_properties: Additional properties set on the agent. kwargs: Additional keyword arguments (merged into additional_properties). """ + if kwargs: + warnings.warn( + "Passing additional properties as direct keyword arguments to BaseAgent is deprecated; " + "pass them via additional_properties instead.", + DeprecationWarning, + stacklevel=3, + ) if id is None: id = str(uuid4()) self.id = id @@ -403,27 +408,40 @@ def __init__( self.additional_properties: dict[str, Any] = cast(dict[str, Any], additional_properties or {}) self.additional_properties.update(kwargs) - def create_session(self, *, session_id: str | None = None, **kwargs: Any) -> AgentSession: + def create_session(self, *, session_id: str | None = None) -> AgentSession: """Create a new lightweight session. + This will be used by an agent to hold the persisted session. + This depends on the service used, in some cases, or with store=True + this will add the ``service_session_id`` based on the response, + which is then fed back to the API on the next call. + + In other cases, if there is a HistoryProvider setup in the agent, + that is used and it can store state in the session. + + If there is no HistoryProvider and store=False or the default of a service is False. + Then a ``InMemoryHistoryProvider`` instance is added to the agent and used with the session automatically. + The ``InMemoryHistoryProvider`` stores the messages as `state` in the session by default. + Keyword Args: session_id: Optional session ID (generated if not provided). - kwargs: Additional keyword arguments. Returns: A new AgentSession instance. """ return AgentSession(session_id=session_id) - def get_session(self, *, service_session_id: str, session_id: str | None = None, **kwargs: Any) -> AgentSession: - """Get or create a session for a service-managed session ID. + def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + """Get a session for a service-managed session ID. + + Only use this to create a session continuing that session id from a service. + Otherwise use ``create_session``. Args: service_session_id: The service-managed session ID. Keyword Args: session_id: Optional local session ID (generated if not provided). - kwargs: Additional keyword arguments. Returns: A new AgentSession instance with service_session_id set. @@ -463,9 +481,8 @@ def as_tool( description: str | None = None, arg_name: str = "task", arg_description: str | None = None, - stream_callback: Callable[[AgentResponseUpdate], None] - | Callable[[AgentResponseUpdate], Awaitable[None]] - | None = None, + approval_mode: Literal["always_require", "never_require"] = "never_require", + stream_callback: Callable[[AgentResponseUpdate], Awaitable[None] | None] | None = None, propagate_session: bool = False, ) -> FunctionTool: """Create a FunctionTool that wraps this agent. @@ -476,21 +493,15 @@ def as_tool( arg_name: The name of the function argument (default: "task"). arg_description: The description for the function argument. If None, defaults to "Task for {tool_name}". + approval_mode: Whether this delegated tool requires approval before execution. stream_callback: Optional callback for streaming responses. If provided, uses run(..., stream=True). - propagate_session: If True, the parent agent's ``AgentSession`` is - forwarded to this sub-agent's ``run()`` call, so both agents - operate within the same logical session (sharing the same - ``session_id`` and provider-managed state, such as any stored - conversation history or metadata). Defaults to False, meaning - the sub-agent runs with a new, independent session. + propagate_session: If True, the parent agent's session is forwarded + to this sub-agent's ``run()`` call so both agents share the + same session. Defaults to False. Returns: A FunctionTool that can be used as a tool by other agents. - Raises: - TypeError: If the agent does not implement SupportsAgentRun. - ValueError: If the agent tool name cannot be determined. - Examples: .. code-block:: python @@ -518,59 +529,46 @@ def as_tool( tool_description = description or self.description or "" argument_description = arg_description or f"Task for {tool_name}" - # Create dynamic input model with the specified argument name - field_info = Field(..., description=argument_description) - model_name = f"{name or _sanitize_agent_name(self.name) or 'agent'}_task" - input_model = create_model(model_name, **{arg_name: (str, field_info)}) # type: ignore[call-overload] - - # Check if callback is async once, outside the wrapper - is_async_callback = stream_callback is not None and inspect.iscoroutinefunction(stream_callback) - - async def agent_wrapper(**kwargs: Any) -> str: - """Wrapper function that calls the agent.""" - # Extract the input from kwargs using the specified arg_name - input_text = kwargs.get(arg_name, "") - - # Extract parent session when propagate_session is enabled - parent_session = kwargs.get("session") if propagate_session else None - - # Forward runtime context kwargs, excluding framework-internal keys. - forwarded_kwargs = { - k: v for k, v in kwargs.items() if k not in (arg_name, "conversation_id", "options", "session") - } - - if stream_callback is None: - # Use non-streaming mode - return ( - await self.run( - input_text, - stream=False, - session=parent_session, - **forwarded_kwargs, - ) - ).text - - # Use streaming mode - accumulate updates and create final response - response_updates: list[AgentResponseUpdate] = [] - async for update in self.run(input_text, stream=True, session=parent_session, **forwarded_kwargs): - response_updates.append(update) - if is_async_callback: - await stream_callback(update) # type: ignore[misc] - else: - stream_callback(update) + input_schema = { + "type": "object", + "properties": { + arg_name: { + "type": "string", + "description": argument_description, + } + }, + "required": [arg_name], + "additionalProperties": False, + } - # Create final text from accumulated updates - return AgentResponse.from_updates(response_updates).text + async def _agent_wrapper(ctx: FunctionInvocationContext, **kwargs: Any) -> str: + """Wrapper function that calls the agent. - agent_tool: FunctionTool = FunctionTool( + Args: + ctx: the function invocation context used + **kwargs: only used to dynamically load the argument that is defined for this tool. + """ + stream = self.run( + str(kwargs.get(arg_name, "")), + stream=True, + session=ctx.session if propagate_session else None, + function_invocation_kwargs=dict(ctx.kwargs), + ) + if stream_callback is not None: + stream.with_transform_hook(stream_callback) + final_response = await stream.get_final_response() + if final_response.user_input_requests: + raise UserInputRequiredException(contents=final_response.user_input_requests) + # TODO(Copilot): update once #4331 merges + return final_response.text + + return FunctionTool( name=tool_name, description=tool_description, - func=agent_wrapper, - input_model=input_model, # type: ignore - approval_mode="never_require", + func=_agent_wrapper, + input_model=input_schema, + approval_mode=approval_mode, ) - agent_tool._forward_runtime_kwargs = True # type: ignore - return agent_tool # region Agent @@ -812,6 +810,8 @@ def run( options: ChatOptions[ResponseModelBoundT], compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ... @@ -826,6 +826,8 @@ def run( options: OptionsCoT | ChatOptions[None] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: ... @@ -840,6 +842,8 @@ def run( options: OptionsCoT | ChatOptions[Any] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... @@ -853,6 +857,8 @@ def run( options: OptionsCoT | ChatOptions[Any] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Run the agent with the given messages and options. @@ -882,14 +888,23 @@ def run( tokenizer: Optional per-run tokenizer override passed to ``client.get_response()``. When omitted, the agent-level override is used, falling back to the client default. - kwargs: Additional keyword arguments for the agent. These are only - passed to functions that are called. + function_invocation_kwargs: Keyword arguments forwarded to tool invocation. + client_kwargs: Additional client-specific keyword arguments for the chat client. + kwargs: Deprecated additional keyword arguments for the agent. + They are forwarded to both tool invocation and the chat client for compatibility. Returns: When stream=False: An Awaitable[AgentResponse] containing the agent's response. When stream=True: A ResponseStream of AgentResponseUpdate items with ``get_final_response()`` for the final AgentResponse. """ + if kwargs: + warnings.warn( + "Passing runtime keyword arguments directly to run() is deprecated; pass tool values via " + "function_invocation_kwargs and client-specific values via client_kwargs instead.", + DeprecationWarning, + stacklevel=2, + ) if not stream: async def _run_non_streaming() -> AgentResponse[Any]: @@ -900,7 +915,9 @@ async def _run_non_streaming() -> AgentResponse[Any]: options=options, compaction_strategy=compaction_strategy, tokenizer=tokenizer, - kwargs=kwargs, + legacy_kwargs=kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ) response = cast( ChatResponse[Any], @@ -910,7 +927,8 @@ async def _run_non_streaming() -> AgentResponse[Any]: options=ctx["chat_options"], # type: ignore[reportArgumentType] compaction_strategy=ctx["compaction_strategy"], tokenizer=ctx["tokenizer"], - **ctx["filtered_kwargs"], + function_invocation_kwargs=ctx["function_invocation_kwargs"], + client_kwargs=ctx["client_kwargs"], ), ) @@ -985,7 +1003,9 @@ async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]] options=options, compaction_strategy=compaction_strategy, tokenizer=tokenizer, - kwargs=kwargs, + legacy_kwargs=kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ) ctx: _RunContext = ctx_holder["ctx"] # type: ignore[assignment] # Safe: we just assigned it return self.client.get_response( # type: ignore[call-overload, no-any-return] @@ -994,7 +1014,8 @@ async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]] options=ctx["chat_options"], # type: ignore[reportArgumentType] compaction_strategy=ctx["compaction_strategy"], tokenizer=ctx["tokenizer"], - **ctx["filtered_kwargs"], + function_invocation_kwargs=ctx["function_invocation_kwargs"], + client_kwargs=ctx["client_kwargs"], ) def _propagate_conversation_id( @@ -1082,9 +1103,12 @@ async def _prepare_run_context( options: Mapping[str, Any] | None, compaction_strategy: CompactionStrategy | None, tokenizer: TokenizerProtocol | None, - kwargs: dict[str, Any], + legacy_kwargs: Mapping[str, Any], + function_invocation_kwargs: Mapping[str, Any] | None, + client_kwargs: Mapping[str, Any] | None, ) -> _RunContext: opts = dict(options) if options else {} + existing_additional_args: dict[str, Any] = opts.pop("additional_function_arguments", None) or {} # Get tools from options or named parameter (named param takes precedence) tools_ = tools if tools is not None else opts.pop("tools", None) @@ -1115,35 +1139,50 @@ async def _prepare_run_context( input_messages=input_messages, options=opts, ) + default_additional_args = chat_options.pop("additional_function_arguments", None) + if isinstance(default_additional_args, Mapping): + existing_additional_args = { + **dict(cast(Mapping[str, Any], default_additional_args)), + **existing_additional_args, + } agent_name = self._get_agent_name() + base_tools = normalize_tools(chat_options.pop("tools", None)) + mcp_duplicate_message = "Tool names must be unique. Consider setting `tool_name_prefix` on the MCPTool." # Normalize tools normalized_tools = normalize_tools(tools_) - # Resolve final tool list (runtime provided tools + local MCP server tools) - final_tools: list[FunctionTool | Callable[..., Any] | dict[str, Any] | Any] = [] + # Resolve final tool list (configured tools + runtime provided tools + local MCP server tools) + final_tools = list(base_tools) for tool in normalized_tools: if isinstance(tool, MCPTool): if not tool.is_connected: await self._async_exit_stack.enter_async_context(tool) - final_tools.extend(tool.functions) # type: ignore + _append_unique_tools( + final_tools, + tool.functions, + duplicate_error_message=mcp_duplicate_message, + ) else: - final_tools.append(tool) # type: ignore + _append_unique_tools(final_tools, [tool]) # type: ignore[list-item] - existing_names = {name for t in final_tools if (name := _get_tool_name(t)) is not None} for mcp_server in self.mcp_tools: if not mcp_server.is_connected: await self._async_exit_stack.enter_async_context(mcp_server) - final_tools.extend(f for f in mcp_server.functions if f.name not in existing_names) + _append_unique_tools( + final_tools, + mcp_server.functions, + duplicate_error_message=mcp_duplicate_message, + ) - # Merge runtime kwargs into additional_function_arguments so they're available - # in function middleware context and tool invocation. - existing_additional_args: dict[str, Any] = opts.pop("additional_function_arguments", None) or {} - additional_function_arguments = {**kwargs, **existing_additional_args} - # Include session so as_tool() wrappers with propagate_session=True can access it. - if active_session is not None: - additional_function_arguments["session"] = active_session + # TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed. + # Legacy compatibility still fans out direct run kwargs into tool runtime kwargs. + effective_function_invocation_kwargs = { + **dict(legacy_kwargs), + **(dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}), + } + additional_function_arguments = {**effective_function_invocation_kwargs, **existing_additional_args} # Build options dict from run() options merged with provided options run_opts: dict[str, Any] = { @@ -1152,7 +1191,6 @@ async def _prepare_run_context( if active_session else opts.pop("conversation_id", None), "allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None), - "additional_function_arguments": additional_function_arguments or None, "frequency_penalty": opts.pop("frequency_penalty", None), "logit_bias": opts.pop("logit_bias", None), "max_tokens": opts.pop("max_tokens", None), @@ -1164,7 +1202,7 @@ async def _prepare_run_context( "store": opts.pop("store", None), "temperature": opts.pop("temperature", None), "tool_choice": opts.pop("tool_choice", None), - "tools": final_tools, + "tools": final_tools or None, "top_p": opts.pop("top_p", None), "user": opts.pop("user", None), **opts, # Remaining options are provider-specific @@ -1176,11 +1214,14 @@ async def _prepare_run_context( # Build session_messages from session context: context messages + input messages session_messages: list[Message] = session_context.get_messages(include_input=True) - # Ensure session is forwarded in kwargs for tool invocation - finalize_kwargs = dict(kwargs) - finalize_kwargs["session"] = active_session - # Filter chat_options from kwargs to prevent duplicate keyword argument - filtered_kwargs = {k: v for k, v in finalize_kwargs.items() if k != "chat_options"} + # TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed. + # Legacy compatibility still fans out direct run kwargs into client kwargs. + effective_client_kwargs = { + **dict(legacy_kwargs), + **(dict(client_kwargs) if client_kwargs is not None else {}), + } + if active_session is not None: + effective_client_kwargs["session"] = active_session return { "session": active_session, @@ -1191,8 +1232,8 @@ async def _prepare_run_context( "chat_options": co, "compaction_strategy": compaction_strategy or self.compaction_strategy, "tokenizer": tokenizer or self.tokenizer, - "filtered_kwargs": filtered_kwargs, - "finalize_kwargs": finalize_kwargs, + "client_kwargs": effective_client_kwargs, + "function_invocation_kwargs": additional_function_arguments, } async def _finalize_response( @@ -1395,11 +1436,19 @@ async def _call_tool( # type: ignore ), ) from e - # Convert result to MCP content - if isinstance(result, str): - return [types.TextContent(type="text", text=result)] # type: ignore[attr-defined] - - return [types.TextContent(type="text", text=str(result))] # type: ignore[attr-defined] + # Convert result to MCP content. + # Currently only text items are forwarded over MCP; rich content + # (images, audio) is not yet supported in the MCP server path. + mcp_content: list[types.TextContent | types.ImageContent | types.EmbeddedResource] = [] # type: ignore[attr-defined] + for c in result: + if c.type == "text" and c.text: + mcp_content.append(types.TextContent(type="text", text=c.text)) # type: ignore[attr-defined] + elif c.type in ("data", "uri"): + logger.warning( + "MCP server does not yet forward rich content (images, audio) " + "in tool results. Rich content items will be omitted." + ) + return mcp_content or [types.TextContent(type="text", text="")] # type: ignore[attr-defined] @server.set_logging_level() # type: ignore async def _set_logging_level(level: types.LoggingLevel) -> None: # type: ignore @@ -1434,6 +1483,58 @@ class Agent( For a minimal implementation without these features, use :class:`RawAgent`. """ + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[False] = ..., + session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + + @overload + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: Literal[True], + session: AgentSession | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: bool = False, + session: AgentSession | None = None, + middleware: Sequence[MiddlewareTypes] | None = None, + options: OptionsCoT | ChatOptions[Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: + """Run the agent.""" + super_run = cast( + "Callable[..., Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]]", + super().run, # type: ignore[misc] + ) + return super_run( # type: ignore[no-any-return] + messages=messages, + stream=stream, + session=session, + middleware=middleware, + options=options, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + **kwargs, + ) + def __init__( self, client: SupportsChatGetResponse[OptionsCoT], @@ -1465,3 +1566,34 @@ def __init__( tokenizer=tokenizer, **kwargs, ) + + +def _apply_agent_docstrings() -> None: + """Align public agent docstrings with the raw implementation.""" + apply_layered_docstring( + AgentMiddlewareLayer.run, + RawAgent.run, + extra_keyword_args={ + "middleware": """ + Optional per-run agent, chat, and function middleware. + Agent middleware wraps the run itself, while chat and function middleware are forwarded to the + underlying chat-client stack for this call. + """, + }, + ) + apply_layered_docstring(AgentTelemetryLayer.run, AgentMiddlewareLayer.run) + apply_layered_docstring( + Agent.run, + RawAgent.run, + extra_keyword_args={ + "middleware": """ + Optional per-run agent, chat, and function middleware. + Agent middleware wraps the run itself, while chat and function middleware are forwarded to the + underlying chat-client stack for this call. + """, + }, + ) + apply_layered_docstring(Agent.__init__, RawAgent.__init__) + + +_apply_agent_docstrings() diff --git a/python/packages/core/agent_framework/_clients.py b/python/packages/core/agent_framework/_clients.py index 5f9c1bb08f8..4fd563d3e03 100644 --- a/python/packages/core/agent_framework/_clients.py +++ b/python/packages/core/agent_framework/_clients.py @@ -4,6 +4,7 @@ import logging import sys +import warnings from abc import ABC, abstractmethod from collections.abc import ( AsyncIterable, @@ -27,6 +28,7 @@ from pydantic import BaseModel +from ._docstrings import apply_layered_docstring from ._serialization import SerializationMixin from ._tools import ( FunctionInvocationConfiguration, @@ -105,7 +107,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]): class CustomChatClient: additional_properties: dict = {} - def get_response(self, messages, *, stream=False, **kwargs): + def get_response(self, messages, *, stream=False, client_kwargs=None, **kwargs): if stream: from agent_framework import ChatResponseUpdate, ResponseStream @@ -149,6 +151,8 @@ def get_response( options: OptionsContraT | ChatOptions[None] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]]: ... @@ -161,6 +165,8 @@ def get_response( options: OptionsContraT | ChatOptions[Any] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ... @@ -172,6 +178,8 @@ def get_response( options: OptionsContraT | ChatOptions[Any] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: """Send input and return the response. @@ -182,7 +190,9 @@ def get_response( options: Chat options as a TypedDict. compaction_strategy: Optional per-call compaction override. tokenizer: Optional per-call tokenizer override. - **kwargs: Additional chat options. + function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers. + client_kwargs: Additional client-specific keyword arguments. + **kwargs: Deprecated additional client-specific keyword arguments. Returns: When stream=False: An awaitable ChatResponse from the client. @@ -283,23 +293,31 @@ async def _stream(): def __init__( self, *, - additional_properties: dict[str, Any] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + additional_properties: dict[str, Any] | None = None, **kwargs: Any, ) -> None: """Initialize a BaseChatClient instance. Keyword Args: - additional_properties: Additional properties for the client. compaction_strategy: Optional compaction strategy to apply before model calls. tokenizer: Optional tokenizer used by token-aware compaction strategies. - kwargs: Additional keyword arguments (merged into additional_properties). + additional_properties: Additional properties for the client. + kwargs: Additional keyword arguments (merged into additional_properties for now). """ self.additional_properties = additional_properties or {} self.compaction_strategy = compaction_strategy self.tokenizer = tokenizer - super().__init__(**kwargs) + if kwargs: + warnings.warn( + "Passing additional properties as direct keyword arguments to BaseChatClient is deprecated; " + "pass them via additional_properties instead.", + DeprecationWarning, + stacklevel=3, + ) + self.additional_properties.update(kwargs) + super().__init__() def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]: """Convert the instance to a dictionary. @@ -486,7 +504,13 @@ def get_response( When omitted, the client-level default is used. tokenizer: Optional per-call tokenizer override. When omitted, the client-level default is used. - **kwargs: Other keyword arguments, can be used to pass function specific parameters. + **kwargs: Additional compatibility keyword arguments. Lower chat-client layers do not + consume ``function_invocation_kwargs`` directly; if present, it is ignored here + because function invocation has already been handled by upper layers. If a + ``client_kwargs`` mapping is present, it is flattened into standard keyword + arguments before forwarding to ``_inner_get_response()`` so client implementations + can leverage those values, while implementations that ignore + extra kwargs remain compatible. Returns: When streaming a response stream of ChatResponseUpdates, otherwise an Awaitable ChatResponse. @@ -495,12 +519,21 @@ def get_response( compaction_strategy=compaction_strategy, tokenizer=tokenizer, ) + compatibility_client_kwargs = kwargs.pop("client_kwargs", None) + kwargs.pop("function_invocation_kwargs", None) + merged_client_kwargs = ( + dict(cast(Mapping[str, Any], compatibility_client_kwargs)) + if isinstance(compatibility_client_kwargs, Mapping) + else {} + ) + merged_client_kwargs.update(kwargs) + if not compaction_overrides: return self._inner_get_response( messages=messages, stream=stream, - options=options or {}, - **kwargs, + options=options or {}, # type: ignore[arg-type] + **merged_client_kwargs, ) if stream: @@ -514,7 +547,7 @@ async def _get_stream() -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]] messages=prepared_messages, stream=True, options=options or {}, - **kwargs, + **merged_client_kwargs, ) if isinstance(stream_response, ResponseStream): return stream_response # type: ignore[reportUnknownVariableType] @@ -534,7 +567,7 @@ async def _get_response() -> ChatResponse[Any]: messages=prepared_messages, stream=False, options=options or {}, - **kwargs, + **merged_client_kwargs, ) return _get_response() @@ -564,7 +597,7 @@ def as_agent( function_invocation_configuration: FunctionInvocationConfiguration | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, - **kwargs: Any, + additional_properties: Mapping[str, Any] | None = None, ) -> Agent[OptionsCoT]: """Create a Agent with this client. @@ -590,7 +623,7 @@ def as_agent( client-level compaction defaults remain in effect for each call. tokenizer: Optional agent-level tokenizer override. When omitted, client-level tokenizer defaults remain in effect for each call. - kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``. + additional_properties: Additional properties stored on the created agent. Returns: A Agent instance configured with this chat client. @@ -615,21 +648,24 @@ def as_agent( """ from ._agents import Agent - return Agent( - client=self, - id=id, - name=name, - description=description, - instructions=instructions, - tools=tools, - default_options=cast(Any, default_options), - context_providers=context_providers, - middleware=middleware, - function_invocation_configuration=function_invocation_configuration, - compaction_strategy=compaction_strategy, - tokenizer=tokenizer, - **kwargs, - ) + agent_kwargs: dict[str, Any] = { + "client": self, + "id": id, + "name": name, + "description": description, + "instructions": instructions, + "tools": tools, + "default_options": cast(Any, default_options), + "context_providers": context_providers, + "middleware": middleware, + "compaction_strategy": compaction_strategy, + "tokenizer": tokenizer, + "additional_properties": dict(additional_properties) if additional_properties is not None else None, + } + if function_invocation_configuration is not None: + agent_kwargs["function_invocation_configuration"] = function_invocation_configuration + + return Agent(**agent_kwargs) # endregion @@ -892,16 +928,14 @@ def __init__( self, *, additional_properties: dict[str, Any] | None = None, - **kwargs: Any, ) -> None: """Initialize a BaseEmbeddingClient instance. Args: additional_properties: Additional properties to pass to the client. - **kwargs: Additional keyword arguments passed to parent classes (for MRO). """ self.additional_properties = additional_properties or {} - super().__init__(**kwargs) + super().__init__() @abstractmethod async def get_embeddings( @@ -923,3 +957,36 @@ async def get_embeddings( # endregion + + +def _apply_get_response_docstrings() -> None: + """Align layered chat-client docstrings with the lowest public implementation.""" + from ._middleware import ChatMiddlewareLayer + from ._tools import FunctionInvocationLayer + from .observability import ChatTelemetryLayer + + apply_layered_docstring(ChatTelemetryLayer.get_response, BaseChatClient.get_response) + apply_layered_docstring( + FunctionInvocationLayer.get_response, + ChatTelemetryLayer.get_response, + extra_keyword_args={ + "function_middleware": """ + Optional per-call function middleware. + When omitted, middleware configured on the client or forwarded from higher layers is used. + """, + }, + ) + apply_layered_docstring( + ChatMiddlewareLayer.get_response, + FunctionInvocationLayer.get_response, + extra_keyword_args={ + "middleware": """ + Optional per-call chat and function middleware. + This compatibility keyword argument is merged with any ``client_kwargs["middleware"]`` value + before the request is executed. + """, + }, + ) + + +_apply_get_response_docstrings() diff --git a/python/packages/core/agent_framework/_compaction.py b/python/packages/core/agent_framework/_compaction.py index 07d18da6957..8a15a6438c4 100644 --- a/python/packages/core/agent_framework/_compaction.py +++ b/python/packages/core/agent_framework/_compaction.py @@ -466,6 +466,9 @@ def annotate_message_groups( def _serialize_content(content: Content) -> dict[str, Any]: payload = content.to_dict(exclude_none=True) payload.pop("raw_representation", None) + # ``items`` mirrors ``result`` for function_result content; exclude it + # to avoid double-counting tokens during estimation. + payload.pop("items", None) return payload diff --git a/python/packages/core/agent_framework/_docstrings.py b/python/packages/core/agent_framework/_docstrings.py new file mode 100644 index 00000000000..44dd7c50a36 --- /dev/null +++ b/python/packages/core/agent_framework/_docstrings.py @@ -0,0 +1,85 @@ +# Copyright (c) Microsoft. All rights reserved. + +from __future__ import annotations + +import inspect +from collections.abc import Callable, Mapping +from typing import Any + +_GOOGLE_SECTION_HEADERS = ( + "Args:", + "Keyword Args:", + "Returns:", + "Raises:", + "Examples:", + "Note:", + "Notes:", + "Warning:", + "Warnings:", +) + + +def _find_section_index(lines: list[str], header: str) -> int | None: + for index, line in enumerate(lines): + if line == header: + return index + return None + + +def _find_next_section_index(lines: list[str], start: int) -> int: + for index in range(start, len(lines)): + if lines[index] in _GOOGLE_SECTION_HEADERS: + return index + return len(lines) + + +def _format_keyword_arg_lines(extra_keyword_args: Mapping[str, str]) -> list[str]: + formatted_lines: list[str] = [] + for name, description in extra_keyword_args.items(): + description_lines = inspect.cleandoc(description).splitlines() + if not description_lines: + formatted_lines.append(f" {name}:") + continue + formatted_lines.append(f" {name}: {description_lines[0]}") + formatted_lines.extend(f" {line}" for line in description_lines[1:]) + return formatted_lines + + +def build_layered_docstring( + source: Callable[..., Any], + *, + extra_keyword_args: Mapping[str, str] | None = None, +) -> str | None: + """Build a Google-style docstring from a lower-layer implementation.""" + docstring = inspect.getdoc(source) + if not docstring: + return None + if not extra_keyword_args: + return docstring + + lines = docstring.splitlines() + formatted_keyword_arg_lines = _format_keyword_arg_lines(extra_keyword_args) + keyword_args_index = _find_section_index(lines, "Keyword Args:") + + if keyword_args_index is None: + args_index = _find_section_index(lines, "Args:") + if args_index is not None: + insert_index = _find_next_section_index(lines, args_index + 1) + else: + insert_index = _find_next_section_index(lines, 0) + lines[insert_index:insert_index] = ["", "Keyword Args:", *formatted_keyword_arg_lines] + return "\n".join(lines).rstrip() + + insert_index = _find_next_section_index(lines, keyword_args_index + 1) + lines[insert_index:insert_index] = formatted_keyword_arg_lines + return "\n".join(lines).rstrip() + + +def apply_layered_docstring( + target: Callable[..., Any], + source: Callable[..., Any], + *, + extra_keyword_args: Mapping[str, str] | None = None, +) -> None: + """Copy a lower-layer docstring onto a wrapper and extend it when needed.""" + target.__doc__ = build_layered_docstring(source, extra_keyword_args=extra_keyword_args) diff --git a/python/packages/core/agent_framework/_mcp.py b/python/packages/core/agent_framework/_mcp.py index b07a8722041..28c5f6db6a9 100644 --- a/python/packages/core/agent_framework/_mcp.py +++ b/python/packages/core/agent_framework/_mcp.py @@ -26,9 +26,7 @@ from mcp.shared.session import RequestResponder from opentelemetry import propagate -from ._tools import ( - FunctionTool, -) +from ._tools import FunctionTool from ._types import ( Content, Message, @@ -59,6 +57,8 @@ class MCPSpecificApproval(TypedDict, total=False): logger = logging.getLogger(__name__) +_MCP_REMOTE_NAME_KEY = "_mcp_remote_name" +_MCP_NORMALIZED_NAME_KEY = "_mcp_normalized_name" # region: Helpers @@ -142,69 +142,60 @@ def _parse_message_from_mcp( def _parse_tool_result_from_mcp( mcp_type: types.CallToolResult, -) -> str: - """Parse an MCP CallToolResult directly into a string representation. +) -> list[Content]: + """Parse an MCP CallToolResult into a list of Content items. - Converts each content item in the MCP result to its string form and combines them. - This skips the intermediate Content object step for tool results. + Converts each content item in the MCP result to its appropriate + Content form. Text items become ``Content(type="text")`` and media + items (images, audio) are preserved as rich Content. Args: mcp_type: The MCP CallToolResult object to convert. Returns: - A string representation of the tool result — either plain text or serialized JSON. + A list of Content items representing the tool result. """ - import json - - parts: list[str] = [] + result: list[Content] = [] for item in mcp_type.content: match item: case types.TextContent(): - parts.append(item.text) + result.append(Content.from_text(item.text)) case types.ImageContent() | types.AudioContent(): - parts.append( - json.dumps( - { - "type": "image" if isinstance(item, types.ImageContent) else "audio", - "data": item.data, - "mimeType": item.mimeType, - }, - default=str, + decoded = base64.b64decode(item.data) + result.append( + Content.from_data( + data=decoded, + media_type=item.mimeType, ) ) case types.ResourceLink(): - parts.append( - json.dumps( - { - "type": "resource_link", - "uri": str(item.uri), - "mimeType": item.mimeType, - }, - default=str, + result.append( + Content.from_uri( + uri=str(item.uri), + media_type=item.mimeType, ) ) case types.EmbeddedResource(): match item.resource: case types.TextResourceContents(): - parts.append(item.resource.text) + result.append(Content.from_text(item.resource.text)) case types.BlobResourceContents(): - parts.append( - json.dumps( - { - "type": "blob", - "data": item.resource.blob, - "mimeType": item.resource.mimeType, - }, - default=str, + blob = item.resource.blob + mime = item.resource.mimeType or "application/octet-stream" + if not blob.startswith("data:"): + blob = f"data:{mime};base64,{blob}" + result.append( + Content.from_uri( + uri=blob, + media_type=mime, ) ) case _: - parts.append(str(item)) - if not parts: - return "" - if len(parts) == 1: - return parts[0] - return json.dumps(parts, default=str) + result.append(Content.from_text(str(item))) + + if not result: + result.append(Content.from_text("")) + return result def _parse_content_from_mcp( @@ -381,6 +372,20 @@ def _normalize_mcp_name(name: str) -> str: return re.sub(r"[^A-Za-z0-9_.-]", "-", name) +def _build_prefixed_mcp_name( + normalized_name: str, + tool_name_prefix: str | None, +) -> str: + """Build the exposed MCP function name from a normalized name and optional prefix.""" + if not tool_name_prefix: + return normalized_name + normalized_prefix = _normalize_mcp_name(tool_name_prefix).rstrip("_.-") + if not normalized_prefix: + return normalized_name + trimmed_name = normalized_name.lstrip("_.-") + return f"{normalized_prefix}_{trimmed_name}" if trimmed_name else normalized_prefix + + def _inject_otel_into_mcp_meta(meta: dict[str, Any] | None = None) -> dict[str, Any] | None: """Inject OpenTelemetry trace context into MCP request _meta via the global propagator(s).""" carrier: dict[str, str] = {} @@ -424,8 +429,9 @@ def __init__( description: str | None = None, approval_mode: (Literal["always_require", "never_require"] | MCPSpecificApproval | None) = None, allowed_tools: Collection[str] | None = None, + tool_name_prefix: str | None = None, load_tools: bool = True, - parse_tool_results: Callable[[types.CallToolResult], str] | None = None, + parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None, load_prompts: bool = True, parse_prompt_results: Callable[[types.GetPromptResult], str] | None = None, session: ClientSession | None = None, @@ -444,6 +450,7 @@ def __init__( description: A description of the MCP tool. approval_mode: Whether approval is required to run tools. allowed_tools: A collection of tool names to allow. + tool_name_prefix: Optional prefix to prepend to exposed MCP function names. load_tools: Whether to load tools from the MCP server. parse_tool_results: An optional callable with signature ``Callable[[types.CallToolResult], str]`` that overrides the default result @@ -467,6 +474,7 @@ def __init__( self.description = description or "" self.approval_mode = approval_mode self.allowed_tools = allowed_tools + self.tool_name_prefix = _normalize_mcp_name(tool_name_prefix).rstrip("_.-") if tool_name_prefix else None self.additional_properties = additional_properties self.load_tools_flag = load_tools self.parse_tool_results = parse_tool_results @@ -489,7 +497,19 @@ def functions(self) -> list[FunctionTool]: """Get the list of functions that are allowed.""" if not self.allowed_tools: return self._functions - return [func for func in self._functions if func.name in self.allowed_tools] + allowed_names = set(self.allowed_tools) + filtered_functions: list[FunctionTool] = [] + for func in self._functions: + additional_properties = func.additional_properties or {} + normalized_name = additional_properties.get(_MCP_NORMALIZED_NAME_KEY) + remote_name = additional_properties.get(_MCP_REMOTE_NAME_KEY) + if ( + func.name in allowed_names + or (isinstance(normalized_name, str) and normalized_name in allowed_names) + or (isinstance(remote_name, str) and remote_name in allowed_names) + ): + filtered_functions.append(func) + return filtered_functions async def _safe_close_exit_stack(self) -> None: """Safely close the exit stack, handling cross-task boundary errors. @@ -715,12 +735,16 @@ async def message_handler( def _determine_approval_mode( self, - local_name: str, + *candidate_names: str, ) -> Literal["always_require", "never_require"] | None: if isinstance(self.approval_mode, dict): - if (always_require := self.approval_mode.get("always_require_approval")) and local_name in always_require: + if (always_require := self.approval_mode.get("always_require_approval")) and any( + name in always_require for name in candidate_names + ): return "always_require" - if (never_require := self.approval_mode.get("never_require_approval")) and local_name in never_require: + if (never_require := self.approval_mode.get("never_require_approval")) and any( + name in never_require for name in candidate_names + ): return "never_require" return None return self.approval_mode # type: ignore[reportReturnType] @@ -745,20 +769,25 @@ async def load_prompts(self) -> None: prompt_list = await self.session.list_prompts(params=params) # type: ignore[union-attr] for prompt in prompt_list.prompts: - local_name = _normalize_mcp_name(prompt.name) + normalized_name = _normalize_mcp_name(prompt.name) + local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix) # Skip if already loaded if local_name in existing_names: continue input_model = _get_input_model_from_mcp_prompt(prompt) - approval_mode = self._determine_approval_mode(local_name) + approval_mode = self._determine_approval_mode(local_name, normalized_name, prompt.name) func: FunctionTool = FunctionTool( func=partial(self.get_prompt, prompt.name), name=local_name, description=prompt.description or "", approval_mode=approval_mode, input_model=input_model, + additional_properties={ + _MCP_REMOTE_NAME_KEY: prompt.name, + _MCP_NORMALIZED_NAME_KEY: normalized_name, + }, ) self._functions.append(func) existing_names.add(local_name) @@ -788,13 +817,14 @@ async def load_tools(self) -> None: tool_list = await self.session.list_tools(params=params) # type: ignore[union-attr] for tool in tool_list.tools: - local_name = _normalize_mcp_name(tool.name) + normalized_name = _normalize_mcp_name(tool.name) + local_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix) # Skip if already loaded if local_name in existing_names: continue - approval_mode = self._determine_approval_mode(local_name) + approval_mode = self._determine_approval_mode(local_name, normalized_name, tool.name) # Create FunctionTools out of each tool func: FunctionTool = FunctionTool( func=partial(self.call_tool, tool.name), @@ -802,6 +832,10 @@ async def load_tools(self) -> None: description=tool.description or "", approval_mode=approval_mode, input_model=tool.inputSchema, + additional_properties={ + _MCP_REMOTE_NAME_KEY: tool.name, + _MCP_NORMALIZED_NAME_KEY: normalized_name, + }, ) self._functions.append(func) existing_names.add(local_name) @@ -850,7 +884,7 @@ async def _ensure_connected(self) -> None: inner_exception=ex, ) from ex - async def call_tool(self, tool_name: str, **kwargs: Any) -> str: + async def call_tool(self, tool_name: str, **kwargs: Any) -> str | list[Content]: """Call a tool with the given arguments. Args: @@ -860,7 +894,9 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str: kwargs: Arguments to pass to the tool. Returns: - A string representation of the tool result — either plain text or serialized JSON. + A list of Content items representing the tool output. The default + ``parse_tool_results`` always returns ``list[Content]``; a custom + callback may return a plain ``str`` which is also accepted. Raises: ToolExecutionException: If the MCP server is not connected, tools are not loaded, @@ -902,7 +938,13 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> str: try: result = await self.session.call_tool(tool_name, arguments=filtered_kwargs, meta=otel_meta) # type: ignore if result.isError: - raise ToolExecutionException(parser(result)) + parsed = parser(result) + text = ( + "\n".join(c.text for c in parsed if c.type == "text" and c.text) + if isinstance(parsed, list) + else str(parsed) + ) + raise ToolExecutionException(text or str(parsed)) return parser(result) except ToolExecutionException: raise @@ -1056,8 +1098,9 @@ def __init__( name: str, command: str, *, + tool_name_prefix: str | None = None, load_tools: bool = True, - parse_tool_results: Callable[[types.CallToolResult], str] | None = None, + parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None, load_prompts: bool = True, parse_prompt_results: Callable[[types.GetPromptResult], str] | None = None, request_timeout: int | None = None, @@ -1084,6 +1127,7 @@ def __init__( command: The command to run the MCP server. Keyword Args: + tool_name_prefix: Optional prefix to prepend to exposed MCP function names. load_tools: Whether to load tools from the MCP server. parse_tool_results: An optional callable with signature ``Callable[[types.CallToolResult], str]`` that overrides the default result @@ -1120,6 +1164,7 @@ def __init__( description=description, approval_mode=approval_mode, allowed_tools=allowed_tools, + tool_name_prefix=tool_name_prefix, additional_properties=additional_properties, session=session, client=client, @@ -1181,8 +1226,9 @@ def __init__( name: str, url: str, *, + tool_name_prefix: str | None = None, load_tools: bool = True, - parse_tool_results: Callable[[types.CallToolResult], str] | None = None, + parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None, load_prompts: bool = True, parse_prompt_results: Callable[[types.GetPromptResult], str] | None = None, request_timeout: int | None = None, @@ -1209,6 +1255,7 @@ def __init__( url: The URL of the MCP server. Keyword Args: + tool_name_prefix: Optional prefix to prepend to exposed MCP function names. load_tools: Whether to load tools from the MCP server. parse_tool_results: An optional callable with signature ``Callable[[types.CallToolResult], str]`` that overrides the default result @@ -1247,6 +1294,7 @@ def __init__( description=description, approval_mode=approval_mode, allowed_tools=allowed_tools, + tool_name_prefix=tool_name_prefix, additional_properties=additional_properties, session=session, client=client, @@ -1300,8 +1348,9 @@ def __init__( name: str, url: str, *, + tool_name_prefix: str | None = None, load_tools: bool = True, - parse_tool_results: Callable[[types.CallToolResult], str] | None = None, + parse_tool_results: Callable[[types.CallToolResult], str | list[Content]] | None = None, load_prompts: bool = True, parse_prompt_results: Callable[[types.GetPromptResult], str] | None = None, request_timeout: int | None = None, @@ -1326,6 +1375,7 @@ def __init__( url: The URL of the MCP server. Keyword Args: + tool_name_prefix: Optional prefix to prepend to exposed MCP function names. load_tools: Whether to load tools from the MCP server. parse_tool_results: An optional callable with signature ``Callable[[types.CallToolResult], str]`` that overrides the default result @@ -1359,6 +1409,7 @@ def __init__( description=description, approval_mode=approval_mode, allowed_tools=allowed_tools, + tool_name_prefix=tool_name_prefix, additional_properties=additional_properties, session=session, client=client, diff --git a/python/packages/core/agent_framework/_middleware.py b/python/packages/core/agent_framework/_middleware.py index ba11355adc9..66845a2e9dc 100644 --- a/python/packages/core/agent_framework/_middleware.py +++ b/python/packages/core/agent_framework/_middleware.py @@ -109,7 +109,9 @@ class AgentContext: to see the actual execution result or can be set to override the execution result. For non-streaming: should be AgentResponse. For streaming: should be ResponseStream[AgentResponseUpdate, AgentResponse]. - kwargs: Additional keyword arguments passed to the agent run method. + kwargs: Legacy runtime keyword arguments visible to agent middleware. + client_kwargs: Client-specific keyword arguments for downstream chat clients. + function_invocation_kwargs: Keyword arguments forwarded to tool invocation. Examples: .. code-block:: python @@ -147,6 +149,8 @@ def __init__( metadata: Mapping[str, Any] | None = None, result: AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None = None, kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, stream_transform_hooks: Sequence[ Callable[[AgentResponseUpdate], AgentResponseUpdate | Awaitable[AgentResponseUpdate]] ] @@ -167,7 +171,9 @@ def __init__( tokenizer: Optional per-run tokenizer override. metadata: Metadata dictionary for sharing data between agent middleware. result: Agent execution result. - kwargs: Additional keyword arguments passed to the agent run method. + kwargs: Legacy runtime keyword arguments visible to agent middleware. + client_kwargs: Client-specific keyword arguments for downstream chat clients. + function_invocation_kwargs: Keyword arguments forwarded to tool invocation. stream_transform_hooks: Hooks to transform streamed updates. stream_result_hooks: Hooks to process the final result after streaming. stream_cleanup_hooks: Hooks to run after streaming completes. @@ -182,6 +188,10 @@ def __init__( self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {} self.result = result self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {} + self.client_kwargs: dict[str, Any] = dict(client_kwargs) if client_kwargs is not None else {} + self.function_invocation_kwargs: dict[str, Any] = ( + dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {} + ) self.stream_transform_hooks = list(stream_transform_hooks or []) self.stream_result_hooks = list(stream_result_hooks or []) self.stream_cleanup_hooks = list(stream_cleanup_hooks or []) @@ -196,11 +206,11 @@ class FunctionInvocationContext: Attributes: function: The function being invoked. arguments: The validated arguments for the function. + session: The agent session for this invocation, if any. metadata: Metadata dictionary for sharing data between function middleware. result: Function execution result. Can be observed after calling ``call_next()`` to see the actual execution result or can be set to override the execution result. - - kwargs: Additional keyword arguments passed to the chat method that invoked this function. + kwargs: Additional runtime keyword arguments forwarded to the function invocation. Examples: .. code-block:: python @@ -225,6 +235,7 @@ def __init__( self, function: FunctionTool, arguments: BaseModel | Mapping[str, Any], + session: AgentSession | None = None, metadata: Mapping[str, Any] | None = None, result: Any = None, kwargs: Mapping[str, Any] | None = None, @@ -234,12 +245,14 @@ def __init__( Args: function: The function being invoked. arguments: The validated arguments for the function. + session: The agent session for this invocation, if any. metadata: Metadata dictionary for sharing data between function middleware. result: Function execution result. - kwargs: Additional keyword arguments passed to the chat method that invoked this function. + kwargs: Additional runtime keyword arguments forwarded to the function invocation. """ self.function = function self.arguments = arguments + self.session = session self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {} self.result = result self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {} @@ -262,6 +275,7 @@ class ChatContext: For non-streaming: should be ChatResponse. For streaming: should be ResponseStream[ChatResponseUpdate, ChatResponse]. kwargs: Additional keyword arguments passed to the chat client. + function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers. stream_transform_hooks: Hooks applied to transform each streamed update. stream_result_hooks: Hooks applied to the finalized response (after finalizer). stream_cleanup_hooks: Hooks executed after stream consumption (before finalizer). @@ -298,6 +312,7 @@ def __init__( metadata: Mapping[str, Any] | None = None, result: ChatResponse | ResponseStream[ChatResponseUpdate, ChatResponse] | None = None, kwargs: Mapping[str, Any] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, stream_transform_hooks: Sequence[ Callable[[ChatResponseUpdate], ChatResponseUpdate | Awaitable[ChatResponseUpdate]] ] @@ -315,6 +330,7 @@ def __init__( metadata: Metadata dictionary for sharing data between chat middleware. result: Chat execution result. kwargs: Additional keyword arguments passed to the chat client. + function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers. stream_transform_hooks: Transform hooks to apply to each streamed update. stream_result_hooks: Result hooks to apply to the finalized streaming response. stream_cleanup_hooks: Cleanup hooks to run after streaming completes. @@ -326,6 +342,9 @@ def __init__( self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {} self.result = result self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {} + self.function_invocation_kwargs: dict[str, Any] = ( + dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {} + ) self.stream_transform_hooks = list(stream_transform_hooks or []) self.stream_result_hooks = list(stream_result_hooks or []) self.stream_cleanup_hooks = list(stream_cleanup_hooks or []) @@ -980,6 +999,7 @@ def get_response( options: ChatOptions[ResponseModelBoundT], compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ... @@ -992,6 +1012,8 @@ def get_response( options: OptionsCoT | ChatOptions[None] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]]: ... @@ -1004,6 +1026,8 @@ def get_response( options: OptionsCoT | ChatOptions[Any] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ... @@ -1015,6 +1039,8 @@ def get_response( options: OptionsCoT | ChatOptions[Any] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: """Execute the chat pipeline if middleware is configured.""" @@ -1025,9 +1051,10 @@ def get_response( if tokenizer is not None: kwargs["tokenizer"] = tokenizer - call_middleware = kwargs.pop("middleware", []) + effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} + call_middleware = kwargs.pop("middleware", effective_client_kwargs.pop("middleware", [])) middleware = categorize_middleware(call_middleware) - kwargs["function_middleware"] = middleware["function"] + effective_client_kwargs["function_middleware"] = middleware["function"] pipeline = ChatMiddlewarePipeline( *self.chat_middleware, @@ -1038,6 +1065,8 @@ def get_response( messages=messages, stream=stream, options=options, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=effective_client_kwargs, **kwargs, ) @@ -1046,7 +1075,8 @@ def get_response( messages=list(messages), options=options, stream=stream, - kwargs=kwargs, + kwargs={**effective_client_kwargs, **kwargs}, + function_invocation_kwargs=function_invocation_kwargs, ) async def _execute() -> ChatResponse | ResponseStream[ChatResponseUpdate, ChatResponse] | None: @@ -1079,11 +1109,17 @@ def _middleware_handler( self, context: ChatContext ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: """Internal middleware handler to adapt to pipeline.""" + handler_kwargs = dict(context.kwargs) + compaction_strategy = handler_kwargs.pop("compaction_strategy", None) + tokenizer = handler_kwargs.pop("tokenizer", None) return super().get_response( # type: ignore[misc, no-any-return] messages=context.messages, stream=context.stream, options=context.options or {}, - **context.kwargs, + compaction_strategy=compaction_strategy, + tokenizer=tokenizer, + function_invocation_kwargs=context.function_invocation_kwargs, + client_kwargs=handler_kwargs, ) @@ -1115,6 +1151,8 @@ def run( options: ChatOptions[ResponseModelBoundT], compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ... @@ -1129,6 +1167,8 @@ def run( options: ChatOptions[None] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: ... @@ -1143,6 +1183,8 @@ def run( options: ChatOptions[Any] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... @@ -1156,6 +1198,8 @@ def run( options: ChatOptions[Any] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """MiddlewareTypes-enabled unified run method.""" @@ -1175,9 +1219,12 @@ def run( + run_middleware_list["function"] + run_middleware_list["chat"] ) - combined_kwargs = dict(kwargs) - combined_kwargs["middleware"] = combined_function_chat_middleware if combined_function_chat_middleware else None - + effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} + if combined_function_chat_middleware: + effective_client_kwargs["middleware"] = combined_function_chat_middleware + effective_function_invocation_kwargs = ( + dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {} + ) # Execute with middleware if available if not pipeline.has_middlewares: return super().run( # type: ignore[misc, no-any-return] @@ -1187,7 +1234,9 @@ def run( options=options, compaction_strategy=compaction_strategy, tokenizer=tokenizer, - **combined_kwargs, + function_invocation_kwargs=effective_function_invocation_kwargs, + client_kwargs=effective_client_kwargs, + **kwargs, ) context = AgentContext( @@ -1198,7 +1247,9 @@ def run( stream=stream, compaction_strategy=compaction_strategy, tokenizer=tokenizer, - kwargs=combined_kwargs, + kwargs=kwargs, + client_kwargs=effective_client_kwargs, + function_invocation_kwargs=effective_function_invocation_kwargs, ) async def _execute() -> AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None: @@ -1230,6 +1281,13 @@ async def _execute_stream() -> ResponseStream[AgentResponseUpdate, AgentResponse def _middleware_handler( self, context: AgentContext ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + # TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed. + client_kwargs = {**context.client_kwargs, **context.kwargs} + # TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed. + function_invocation_kwargs = { + **context.function_invocation_kwargs, + **{k: v for k, v in context.kwargs.items() if k != "middleware"}, + } return super().run( # type: ignore[misc, no-any-return] context.messages, stream=context.stream, @@ -1237,7 +1295,8 @@ def _middleware_handler( options=context.options, compaction_strategy=context.compaction_strategy, tokenizer=context.tokenizer, - **context.kwargs, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, ) diff --git a/python/packages/core/agent_framework/_sessions.py b/python/packages/core/agent_framework/_sessions.py index 434a8d1fd48..84656824aa2 100644 --- a/python/packages/core/agent_framework/_sessions.py +++ b/python/packages/core/agent_framework/_sessions.py @@ -392,12 +392,16 @@ def __init__( self.store_outputs = store_outputs @abstractmethod - async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]: + async def get_messages( + self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any + ) -> list[Message]: """Retrieve stored messages for this session. Args: session_id: The session ID to retrieve messages for. - **kwargs: Additional arguments (e.g., ``state`` for in-memory providers). + state: Optional session state for providers that persist in session state. + Not used by all providers. + **kwargs: Additional subclass-specific extensibility arguments. Returns: List of stored messages. @@ -405,13 +409,22 @@ async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Mess ... @abstractmethod - async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None: + async def save_messages( + self, + session_id: str | None, + messages: Sequence[Message], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: """Persist messages for this session. Args: session_id: The session ID to store messages for. messages: The messages to persist. - **kwargs: Additional arguments (e.g., ``state`` for in-memory providers). + state: Optional session state for providers that persist in session state. + Not used by all providers. + **kwargs: Additional subclass-specific extensibility arguments. """ ... diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index e920800f9ea..4119afec05f 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -7,6 +7,8 @@ import json import logging import sys +import typing +import warnings from collections.abc import ( AsyncIterable, Awaitable, @@ -37,7 +39,7 @@ from pydantic import BaseModel, Field, ValidationError, create_model from ._serialization import SerializationMixin -from .exceptions import ToolException +from .exceptions import ToolException, UserInputRequiredException from .observability import ( OPERATION_DURATION_BUCKET_BOUNDARIES, OtelAttr, @@ -61,7 +63,8 @@ from ._clients import SupportsChatGetResponse from ._compaction import CompactionStrategy, TokenizerProtocol from ._mcp import MCPTool - from ._middleware import FunctionMiddlewarePipeline, FunctionMiddlewareTypes + from ._middleware import FunctionInvocationContext, FunctionMiddlewarePipeline, FunctionMiddlewareTypes + from ._sessions import AgentSession from ._types import ( ChatOptions, ChatResponse, @@ -71,7 +74,6 @@ ResponseStream, ) - ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel) else: MCPTool = Any # type: ignore[assignment,misc] @@ -83,9 +85,23 @@ DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3 SHELL_TOOL_KIND_VALUE: Final[str] = "shell" ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]") +ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel) + # region Helpers +def _get_tool_name(tool: Any) -> str | None: + """Extract a tool name from a tool object or dict tool definition.""" + if isinstance(tool, Mapping): + func = tool.get("function", None) # type: ignore + if func and isinstance(func, Mapping): + name = func.get("name") # type: ignore + return name if isinstance(name, str) else None + return None + name = getattr(tool, "name", None) + return name if isinstance(name, str) else None + + def _parse_inputs( # pyright: ignore[reportUnusedFunction] inputs: Content | dict[str, Any] | str | list[Content | dict[str, Any] | str] | None, ) -> list[Content]: @@ -174,6 +190,16 @@ def _default_histogram() -> Histogram: ) +def _annotation_includes_function_invocation_context(annotation: Any) -> bool: + """Check whether an annotation resolves to FunctionInvocationContext.""" + from ._middleware import FunctionInvocationContext + + candidates = get_args(annotation) or (annotation,) + return any( + candidate is FunctionInvocationContext or candidate == "FunctionInvocationContext" for candidate in candidates + ) + + ClassT = TypeVar("ClassT", bound="SerializationMixin") @@ -246,7 +272,7 @@ def __init__( additional_properties: dict[str, Any] | None = None, func: Callable[..., Any] | None = None, input_model: type[BaseModel] | Mapping[str, Any] | None = None, - result_parser: Callable[[Any], str] | None = None, + result_parser: Callable[[Any], str | list[Content]] | None = None, **kwargs: Any, ) -> None: """Initialize the FunctionTool. @@ -310,6 +336,12 @@ def __init__( # FunctionTool-specific attributes self.func = func self._instance = None # Store the instance for bound methods + self._context_parameter_name: str | None = None + self._input_model_explicitly_provided = input_model is not None + # TODO(Copilot): Delete once legacy ``**kwargs`` runtime injection is removed. + self._forward_runtime_kwargs: bool = False + if self.func: + self._discover_injected_parameters() # Initialize schema cache (will be lazily populated) self._input_schema_cached: dict[str, Any] | None = None @@ -336,13 +368,37 @@ def __init__( self._invocation_duration_histogram = _default_histogram() self.type: Literal["function_tool"] = "function_tool" self.result_parser = result_parser - self._forward_runtime_kwargs: bool = False - if self.func: - sig = inspect.signature(self.func) - for param in sig.parameters.values(): - if param.kind == inspect.Parameter.VAR_KEYWORD: - self._forward_runtime_kwargs = True - break + + def _discover_injected_parameters(self) -> None: + """Inspect the wrapped function for runtime injection parameters.""" + func = self.func.func if isinstance(self.func, FunctionTool) else self.func + if func is None: + return + + signature = inspect.signature(func) + try: + type_hints = typing.get_type_hints(func) + except Exception: + type_hints = {name: param.annotation for name, param in signature.parameters.items()} + + for name, param in signature.parameters.items(): + if name in {"self", "cls"}: + continue + if param.kind == inspect.Parameter.VAR_KEYWORD: + self._forward_runtime_kwargs = True + continue + + annotation = type_hints.get(name, param.annotation) + if self._is_context_parameter(name, annotation): + if self._context_parameter_name is not None: + raise ValueError(f"Function '{self.name}' defines multiple FunctionInvocationContext parameters.") + self._context_parameter_name = name + + def _is_context_parameter(self, name: str, annotation: Any) -> bool: + """Check whether a callable parameter should receive FunctionInvocationContext injection.""" + if _annotation_includes_function_invocation_context(annotation): + return True + return self._input_model_explicitly_provided and name == "ctx" and annotation is inspect.Parameter.empty def __str__(self) -> str: """Return a string representation of the tool.""" @@ -411,6 +467,7 @@ def _resolve_input_model(self, input_model: type[BaseModel] | None) -> type[Base ) for pname, param in sig.parameters.items() if pname not in {"self", "cls"} + and pname != self._context_parameter_name and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD} } return create_model(f"{self.name}_input", **fields) @@ -448,20 +505,23 @@ async def invoke( self, *, arguments: BaseModel | Mapping[str, Any] | None = None, + context: FunctionInvocationContext | None = None, **kwargs: Any, - ) -> str: + ) -> list[Content]: """Run the AI function with the provided arguments as a Pydantic model. - The raw return value of the wrapped function is automatically parsed into a ``str`` - (either plain text or serialized JSON) using :meth:`parse_result` or the custom - ``result_parser`` if one was provided. + The raw return value of the wrapped function is automatically parsed into a + ``list[Content]`` using :meth:`parse_result` or the custom ``result_parser`` + if one was provided. Every result — text, rich media, or serialized objects — + is represented uniformly as Content items. Keyword Args: arguments: A mapping or model instance containing the arguments for the function. - kwargs: Keyword arguments to pass to the function, will not be used if ``arguments`` is provided. + context: Explicit function invocation context carrying runtime kwargs. + kwargs: Deprecated keyword arguments to pass to the function. Use ``context`` instead. Returns: - The parsed result as a string — either plain text or serialized JSON. + A list of Content items representing the tool output. Raises: TypeError: If arguments is not mapping-like or fails schema checks. @@ -469,13 +529,37 @@ async def invoke( if self.declaration_only: raise ToolException(f"Function '{self.name}' is declaration only and cannot be invoked.") global OBSERVABILITY_SETTINGS + from ._middleware import FunctionInvocationContext + from ._types import Content from .observability import OBSERVABILITY_SETTINGS parser = self.result_parser or FunctionTool.parse_result - original_kwargs = dict(kwargs) - tool_call_id = original_kwargs.pop("tool_call_id", None) - if arguments is not None: + parameter_names = set(self.parameters().get("properties", {}).keys()) + direct_argument_kwargs = ( + {key: value for key, value in kwargs.items() if key in parameter_names} if arguments is None else {} + ) + runtime_kwargs = dict(context.kwargs) if context is not None else {} + deprecated_runtime_kwargs = { + key: value for key, value in kwargs.items() if key not in direct_argument_kwargs and key != "tool_call_id" + } + if deprecated_runtime_kwargs: + warnings.warn( + "Passing runtime keyword arguments directly to FunctionTool.invoke() is deprecated; " + "pass them via FunctionInvocationContext instead.", + DeprecationWarning, + stacklevel=2, + ) + runtime_kwargs.update(deprecated_runtime_kwargs) + tool_call_id = kwargs.get("tool_call_id", runtime_kwargs.pop("tool_call_id", None)) + if arguments is None and direct_argument_kwargs: + arguments = direct_argument_kwargs + if arguments is None and context is not None: + arguments = context.arguments + + if arguments is None: + validated_arguments: dict[str, Any] = {} + else: try: if isinstance(arguments, Mapping): parsed_arguments = dict(arguments) @@ -497,34 +581,66 @@ async def invoke( ) except ValidationError as exc: raise TypeError(f"Invalid arguments for '{self.name}': {exc}") from exc - kwargs = _validate_arguments_against_schema( + + validated_arguments = _validate_arguments_against_schema( arguments=parsed_arguments, schema=self.parameters(), tool_name=self.name, ) - if getattr(self, "_forward_runtime_kwargs", False) and original_kwargs: - kwargs.update(original_kwargs) - else: - kwargs = original_kwargs + + effective_context = context + if effective_context is None and self._context_parameter_name is not None: + effective_context = FunctionInvocationContext( + function=self, + arguments=validated_arguments, + kwargs=runtime_kwargs, + ) + if effective_context is not None: + effective_context.function = self + effective_context.arguments = validated_arguments + effective_context.kwargs = dict(runtime_kwargs) + + call_kwargs = dict(validated_arguments) + observable_kwargs = dict(validated_arguments) + + # Legacy runtime kwargs injection path retained for backwards compatibility with tools + # that still declare ``**kwargs``. New tools should consume runtime data via ``ctx``. + legacy_runtime_kwargs = dict(runtime_kwargs) + if self._forward_runtime_kwargs and legacy_runtime_kwargs: + for key, value in legacy_runtime_kwargs.items(): + if key not in call_kwargs: + call_kwargs[key] = value + if key not in observable_kwargs: + observable_kwargs[key] = value + + if self._context_parameter_name is not None and effective_context is not None: + call_kwargs[self._context_parameter_name] = effective_context + if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined] logger.info(f"Function name: {self.name}") - logger.debug(f"Function arguments: {kwargs}") - res = self.__call__(**kwargs) + logger.debug(f"Function arguments: {observable_kwargs}") + res = self.__call__(**call_kwargs) result = await res if inspect.isawaitable(res) else res try: parsed = parser(result) except Exception: logger.warning(f"Function {self.name}: result parser failed, falling back to str().") - parsed = str(result) + parsed = [Content.from_text(str(result))] + if isinstance(parsed, str): + parsed = [Content.from_text(parsed)] logger.info(f"Function {self.name} succeeded.") - logger.debug(f"Function result: {parsed or 'None'}") + if parsed: + types = [item.type for item in parsed] + logger.debug(f"Function result: {len(parsed)} item(s) ({', '.join(types)})") + else: + logger.debug("Function result: None") return parsed attributes = get_function_span_attributes(self, tool_call_id=tool_call_id) # Filter out framework kwargs that are not JSON serializable. serializable_kwargs = { k: v - for k, v in kwargs.items() + for k, v in observable_kwargs.items() if k not in { "chat_options", @@ -550,7 +666,7 @@ async def invoke( start_time_stamp = perf_counter() end_time_stamp: float | None = None try: - res = self.__call__(**kwargs) + res = self.__call__(**call_kwargs) result = await res if inspect.isawaitable(res) else res end_time_stamp = perf_counter() except Exception as exception: @@ -564,11 +680,14 @@ async def invoke( parsed = parser(result) except Exception: logger.warning(f"Function {self.name}: result parser failed, falling back to str().") - parsed = str(result) + parsed = [Content.from_text(str(result))] + if isinstance(parsed, str): + parsed = [Content.from_text(parsed)] logger.info(f"Function {self.name} succeeded.") if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED: # type: ignore[name-defined] - span.set_attribute(OtelAttr.TOOL_RESULT, parsed) - logger.debug(f"Function result: {parsed}") + result_str = "\n".join(c.text or "" for c in parsed if c.type == "text") or str(parsed) + span.set_attribute(OtelAttr.TOOL_RESULT, result_str) + logger.debug(f"Function result: {result_str}") return parsed finally: duration = (end_time_stamp or perf_counter()) - start_time_stamp @@ -622,10 +741,14 @@ def _make_dumpable(value: Any) -> Any: return value @staticmethod - def parse_result(result: Any) -> str: - """Convert a raw function return value to a string representation. + def parse_result(result: Any) -> list[Content]: + """Convert a raw function return value to a list of Content items. + + Every tool result is represented as a uniform ``list[Content]``. Text + results become ``Content(type="text")``, rich media (images, audio, + files) are preserved as-is, and arbitrary objects are serialized to JSON + text. - The return value is always a ``str`` — either plain text or serialized JSON. This is called automatically by :meth:`invoke` before returning the result, ensuring that the result stored in ``Content.from_function_result`` is already in a form that can be passed directly to LLM APIs. @@ -634,16 +757,30 @@ def parse_result(result: Any) -> str: result: The raw return value from the wrapped function. Returns: - A string representation of the result, either plain text or serialized JSON. + A list of Content items representing the tool output. """ + from ._types import Content + if result is None: - return "" + return [Content.from_text("")] if isinstance(result, str): - return result + return [Content.from_text(result)] + if isinstance(result, Content): + return [result] + if isinstance(result, list) and any(isinstance(item, Content) for item in result): # type: ignore[reportUnknownVariableType] + parsed_items: list[Content] = [] + for item in result: # type: ignore[reportUnknownVariableType] + if isinstance(item, Content): + parsed_items.append(item) + else: + dumpable = FunctionTool._make_dumpable(item) # type: ignore[reportUnknownArgumentType] + text = dumpable if isinstance(dumpable, str) else json.dumps(dumpable, default=str) # type: ignore[reportUnknownArgumentType] + parsed_items.append(Content.from_text(text)) + return parsed_items dumpable = FunctionTool._make_dumpable(result) if isinstance(dumpable, str): - return dumpable - return json.dumps(dumpable, default=str) + return [Content.from_text(dumpable)] + return [Content.from_text(json.dumps(dumpable, default=str))] def to_json_schema_spec(self) -> dict[str, Any]: """Convert a FunctionTool to the JSON Schema function specification format. @@ -672,6 +809,51 @@ def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) ToolTypes: TypeAlias = FunctionTool | MCPTool | Mapping[str, Any] | object +def _raise_duplicate_tool_name(tool_name: str, duplicate_error_message: str | None = None) -> None: + message = duplicate_error_message or "Tool names must be unique." + raise ValueError(f"Duplicate tool name '{tool_name}'. {message}") + + +def _append_unique_tools( + existing_tools: list[ToolTypes], + new_tools: Sequence[ToolTypes], + *, + duplicate_error_message: str | None = None, +) -> list[ToolTypes]: + seen_by_name: dict[str, ToolTypes] = {} + for tool_item in existing_tools: + if tool_name := _get_tool_name(tool_item): + seen_by_name[tool_name] = tool_item + + for tool_item in new_tools: + tool_name = _get_tool_name(tool_item) + if tool_name is None: + existing_tools.append(tool_item) + continue + + existing_tool = seen_by_name.get(tool_name) + if existing_tool is None: + seen_by_name[tool_name] = tool_item + existing_tools.append(tool_item) + continue + + if existing_tool is tool_item: + continue + + _raise_duplicate_tool_name(tool_name, duplicate_error_message) + + return existing_tools + + +def _ensure_unique_tool_names( + tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]], + *, + duplicate_error_message: str | None = None, +) -> list[ToolTypes]: + normalized_tools = normalize_tools(tools) + return _append_unique_tools([], normalized_tools, duplicate_error_message=duplicate_error_message) + + def normalize_tools( tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None, ) -> list[ToolTypes]: @@ -860,7 +1042,7 @@ def tool( max_invocations: int | None = None, max_invocation_exceptions: int | None = None, additional_properties: dict[str, Any] | None = None, - result_parser: Callable[[Any], str] | None = None, + result_parser: Callable[[Any], str | list[Content]] | None = None, ) -> FunctionTool: ... @@ -876,7 +1058,7 @@ def tool( max_invocations: int | None = None, max_invocation_exceptions: int | None = None, additional_properties: dict[str, Any] | None = None, - result_parser: Callable[[Any], str] | None = None, + result_parser: Callable[[Any], str | list[Content]] | None = None, ) -> Callable[[Callable[..., Any]], FunctionTool]: ... @@ -891,7 +1073,7 @@ def tool( max_invocations: int | None = None, max_invocation_exceptions: int | None = None, additional_properties: dict[str, Any] | None = None, - result_parser: Callable[[Any], str] | None = None, + result_parser: Callable[[Any], str | list[Content]] | None = None, ) -> FunctionTool | Callable[[Callable[..., Any]], FunctionTool]: """Decorate a function to turn it into a FunctionTool that can be passed to models and executed automatically. @@ -1131,9 +1313,10 @@ async def _auto_invoke_function( *, config: FunctionInvocationConfiguration, tool_map: dict[str, FunctionTool], + invocation_session: AgentSession | None = None, sequence_index: int | None = None, request_index: int | None = None, - middleware_pipeline: FunctionMiddlewarePipeline | None = None, # Optional MiddlewarePipeline + middleware_pipeline: FunctionMiddlewarePipeline | None = None, ) -> Content: """Invoke a function call requested by the agent, applying middleware that is defined. @@ -1144,6 +1327,7 @@ async def _auto_invoke_function( Keyword Args: config: The function invocation configuration. tool_map: A mapping of tool names to FunctionTool instances. + invocation_session: The agent session for this invocation, if any. sequence_index: The index of the function call in the sequence. request_index: The index of the request iteration. middleware_pipeline: Optional middleware pipeline to apply during execution. @@ -1195,6 +1379,8 @@ async def _auto_invoke_function( for key, value in (custom_args or {}).items() if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"} } + if invocation_session is not None: + runtime_kwargs["session"] = invocation_session try: if not cast(bool, getattr(tool, "_schema_supplied", False)) and tool.input_model is not None: args = tool.input_model.model_validate(parsed_args).model_dump(exclude_none=True) @@ -1216,19 +1402,31 @@ async def _auto_invoke_function( additional_properties=function_call_content.additional_properties, ) + from ._middleware import FunctionInvocationContext + if middleware_pipeline is None or not middleware_pipeline.has_middlewares: # No middleware - execute directly try: + direct_context = None + if getattr(tool, "_forward_runtime_kwargs", False) or getattr(tool, "_context_parameter_name", None): + direct_context = FunctionInvocationContext( + function=tool, + arguments=args, + session=invocation_session, + kwargs=runtime_kwargs.copy(), + ) function_result = await tool.invoke( arguments=args, + context=direct_context, tool_call_id=function_call_content.call_id, - **runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {}, ) return Content.from_function_result( call_id=function_call_content.call_id, # type: ignore[arg-type] result=function_result, additional_properties=function_call_content.additional_properties, ) + except UserInputRequiredException: + raise except Exception as exc: message = "Error: Function failed." if config.get("include_detailed_errors", False): @@ -1240,19 +1438,18 @@ async def _auto_invoke_function( additional_properties=function_call_content.additional_properties, ) # Execute through middleware pipeline if available - from ._middleware import FunctionInvocationContext - middleware_context = FunctionInvocationContext( function=tool, arguments=args, + session=invocation_session, kwargs=runtime_kwargs.copy(), ) async def final_function_handler(context_obj: Any) -> Any: return await tool.invoke( arguments=context_obj.arguments, + context=context_obj, tool_call_id=function_call_content.call_id, - **context_obj.kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {}, ) from ._middleware import MiddlewareTermination @@ -1275,6 +1472,8 @@ async def final_function_handler(context_obj: Any) -> Any: additional_properties=function_call_content.additional_properties, ) raise + except UserInputRequiredException: + raise except Exception as exc: message = "Error: Function failed." if config.get("include_detailed_errors", False): @@ -1291,7 +1490,7 @@ def _get_tool_map( tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]], ) -> dict[str, FunctionTool]: tool_list: dict[str, FunctionTool] = {} - for tool_item in normalize_tools(tools): + for tool_item in _ensure_unique_tool_names(tools): if isinstance(tool_item, FunctionTool): tool_list[tool_item.name] = tool_item return tool_list @@ -1303,7 +1502,8 @@ async def _try_execute_function_calls( function_calls: Sequence[Content], tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]], config: FunctionInvocationConfiguration, - middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports + invocation_session: AgentSession | None = None, + middleware_pipeline: Any = None, ) -> tuple[Sequence[Content], bool]: """Execute multiple function calls concurrently. @@ -1313,6 +1513,7 @@ async def _try_execute_function_calls( function_calls: A sequence of FunctionCallContent to execute. tools: The tools available for execution. config: Configuration for function invocation. + invocation_session: The agent session for this invocation, if any. middleware_pipeline: Optional middleware pipeline to apply during execution. Returns: @@ -1382,6 +1583,8 @@ async def _try_execute_function_calls( # Run all function calls concurrently, handling MiddlewareTermination from ._middleware import MiddlewareTermination + extra_user_input_contents: list[Content] = [] + async def invoke_with_termination_handling( function_call: Content, seq_idx: int, @@ -1392,6 +1595,7 @@ async def invoke_with_termination_handling( function_call_content=function_call, # type: ignore[arg-type] custom_args=custom_args, tool_map=tool_map, + invocation_session=invocation_session, sequence_index=seq_idx, request_index=attempt_idx, middleware_pipeline=middleware_pipeline, @@ -1408,6 +1612,26 @@ async def invoke_with_termination_handling( result=exc.result, ) return (result_content, True) + except UserInputRequiredException as exc: + if exc.contents: + propagated: list[Content] = [] + for item in exc.contents: + if isinstance(item, Content): + item.call_id = function_call.call_id # type: ignore[attr-defined] + if not item.id: # type: ignore[attr-defined] + item.id = function_call.call_id # type: ignore[attr-defined] + propagated.append(item) + if propagated: + extra_user_input_contents.extend(propagated[1:]) + return (propagated[0], False) + return ( + Content.from_function_result( + call_id=function_call.call_id, # type: ignore[arg-type] + result="Tool requires user input but no request details were provided.", + exception="UserInputRequiredException", + ), + False, + ) execution_results = await asyncio.gather(*[ invoke_with_termination_handling(function_call, seq_idx) for seq_idx, function_call in enumerate(function_calls) @@ -1415,6 +1639,7 @@ async def invoke_with_termination_handling( # Unpack results - each is (Content, terminate_flag) contents: list[Content] = [result[0] for result in execution_results] + contents.extend(extra_user_input_contents) # If any function requested termination, terminate the loop should_terminate = any(result[1] for result in execution_results) return (contents, should_terminate) @@ -1427,6 +1652,7 @@ async def _execute_function_calls( function_calls: list[Content], tool_options: dict[str, Any] | None, config: FunctionInvocationConfiguration, + invocation_session: AgentSession | None = None, middleware_pipeline: Any = None, ) -> tuple[list[Content], bool, bool]: tools = _extract_tools(tool_options) @@ -1437,6 +1663,7 @@ async def _execute_function_calls( attempt_idx=attempt_idx, function_calls=function_calls, tools=tools, # type: ignore + invocation_session=invocation_session, middleware_pipeline=middleware_pipeline, config=config, ) @@ -1646,7 +1873,10 @@ def _handle_function_call_results( ) -> FunctionRequestResult: from ._types import Message - if any(fccr.type in {"function_approval_request", "function_call"} for fccr in function_call_results): + if any( + fccr.type in {"function_approval_request", "function_call"} or fccr.user_input_request + for fccr in function_call_results + ): # Only add items that aren't already in the message (e.g. function_approval_request wrappers). # Declaration-only function_call items are already present from the LLM response. new_items = [fccr for fccr in function_call_results if fccr.type != "function_call"] @@ -1814,6 +2044,8 @@ def get_response( options: ChatOptions[ResponseModelBoundT], compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ... @@ -1826,6 +2058,8 @@ def get_response( options: OptionsCoT | ChatOptions[None] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]]: ... @@ -1838,6 +2072,8 @@ def get_response( options: OptionsCoT | ChatOptions[Any] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ... @@ -1850,6 +2086,8 @@ def get_response( function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: from ._middleware import FunctionMiddlewarePipeline @@ -1860,28 +2098,45 @@ def get_response( ) super_get_response = super().get_response # type: ignore[misc] + if kwargs: + warnings.warn( + "Passing client-specific keyword arguments directly to get_response() is deprecated; " + "pass them via client_kwargs instead.", + DeprecationWarning, + stacklevel=2, + ) + + effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} + effective_function_middleware = function_middleware + if effective_function_middleware is None: + middleware_from_client_kwargs = effective_client_kwargs.pop("function_middleware", None) + if middleware_from_client_kwargs is not None: + effective_function_middleware = cast(Sequence[Any], middleware_from_client_kwargs) # ChatMiddleware adds this kwarg function_middleware_pipeline = FunctionMiddlewarePipeline( - *(self.function_middleware), *(function_middleware or []) + *(self.function_middleware), *(effective_function_middleware or []) ) max_errors = self.function_invocation_configuration.get( "max_consecutive_errors_per_request", DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST ) - additional_function_arguments: dict[str, Any] = {} + additional_function_arguments = ( + dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {} + ) if options and (additional_opts := options.get("additional_function_arguments")): # type: ignore[attr-defined] - additional_function_arguments = additional_opts # type: ignore + additional_function_arguments.update(cast(Mapping[str, Any], additional_opts)) + from ._sessions import AgentSession as _AgentSession + + raw_session = effective_client_kwargs.get("session") + invocation_session = raw_session if isinstance(raw_session, _AgentSession) else None execute_function_calls = partial( _execute_function_calls, custom_args=additional_function_arguments, config=self.function_invocation_configuration, + invocation_session=invocation_session, middleware_pipeline=function_middleware_pipeline, ) - filtered_kwargs = {k: v for k, v in kwargs.items() if k != "session"} - if compaction_strategy is not None: - filtered_kwargs["compaction_strategy"] = compaction_strategy - if tokenizer is not None: - filtered_kwargs["tokenizer"] = tokenizer + filtered_kwargs = {k: v for k, v in {**effective_client_kwargs, **kwargs}.items() if k != "session"} # Make options mutable so we can update conversation_id during function invocation loop mutable_options: dict[str, Any] = dict(options) if options else {} @@ -1931,7 +2186,9 @@ async def _get_response() -> ChatResponse[Any]: messages=prepped_messages, stream=False, options=mutable_options, - **filtered_kwargs, + compaction_strategy=compaction_strategy, + tokenizer=tokenizer, + client_kwargs=filtered_kwargs, ), ) @@ -2000,7 +2257,9 @@ async def _get_response() -> ChatResponse[Any]: messages=prepped_messages, stream=False, options=mutable_options, - **filtered_kwargs, + compaction_strategy=compaction_strategy, + tokenizer=tokenizer, + client_kwargs=filtered_kwargs, ), ) if fcc_messages: @@ -2050,7 +2309,9 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: messages=prepped_messages, stream=True, options=mutable_options, - **filtered_kwargs, + compaction_strategy=compaction_strategy, + tokenizer=tokenizer, + client_kwargs=filtered_kwargs, ), ) await inner_stream @@ -2142,7 +2403,9 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: messages=prepped_messages, stream=True, options=mutable_options, - **filtered_kwargs, + compaction_strategy=compaction_strategy, + tokenizer=tokenizer, + client_kwargs=filtered_kwargs, ), ) await final_inner_stream diff --git a/python/packages/core/agent_framework/_types.py b/python/packages/core/agent_framework/_types.py index a44baac2dd0..a4e3a573305 100644 --- a/python/packages/core/agent_framework/_types.py +++ b/python/packages/core/agent_framework/_types.py @@ -480,6 +480,7 @@ def __init__( arguments: str | Mapping[str, Any] | None = None, exception: str | None = None, result: Any = None, + items: Sequence[Content] | None = None, # Hosted file/vector store fields file_id: str | None = None, vector_store_id: str | None = None, @@ -539,6 +540,7 @@ def __init__( self.arguments = arguments self.exception = exception self.result = result + self.items = items self.file_id = file_id self.vector_store_id = vector_store_id self.inputs = inputs @@ -813,11 +815,48 @@ def from_function_result( additional_properties: MutableMapping[str, Any] | None = None, raw_representation: Any = None, ) -> ContentT: - """Create function result content.""" + """Create function result content. + + All tool output is represented uniformly as Content items in the + ``items`` field. The ``result`` field is populated with the concatenated + text from text items for backwards compatibility. + + Args: + call_id: The ID of the function call this result corresponds to. + + Keyword Args: + result: The tool output. Accepts a ``list[Content]`` (the canonical + form produced by :meth:`~FunctionTool.parse_result`), a plain + ``str``, or any other value (which is stringified). + exception: The exception message if the function call failed. + annotations: Optional annotations for the content. + additional_properties: Optional additional properties. + raw_representation: Optional raw representation from the provider. + """ + if isinstance(result, list): + if all(isinstance(c, Content) for c in result): # type: ignore[reportUnknownVariableType] + items_list: list[Content] = list(result) # type: ignore[reportUnknownArgumentType] + else: + items_list = [Content.from_text(str(result))] # type: ignore[reportUnknownArgumentType] + elif isinstance(result, str): + items_list = [Content.from_text(result)] + elif result is not None: + try: + text = json.dumps(result, default=str) + except (TypeError, ValueError): + text = str(result) + items_list = [Content.from_text(text)] + else: + items_list = [Content.from_text("")] + + text_parts = [c.text for c in items_list if c.type == "text" and c.text] + text_result = "\n".join(text_parts) if text_parts else "" + return cls( "function_result", call_id=call_id, - result=result, + result=text_result, + items=items_list, exception=exception, annotations=annotations, additional_properties=additional_properties, @@ -1218,6 +1257,7 @@ def to_dict(self, *, exclude_none: bool = True, exclude: set[str] | None = None) "arguments", "exception", "result", + "items", "file_id", "vector_store_id", "inputs", @@ -1299,6 +1339,8 @@ def from_dict(cls: type[ContentT], data: Mapping[str, Any]) -> ContentT: remaining["inputs"] = [cls.from_dict(item) if isinstance(item, dict) else item for item in input_items] # type: ignore[reportUnknownVariableType] if (output_items := remaining.get("outputs")) and isinstance(output_items, list): remaining["outputs"] = [cls.from_dict(item) if isinstance(item, dict) else item for item in output_items] # type: ignore[reportUnknownVariableType] + if (content_items := remaining.get("items")) and isinstance(content_items, list): + remaining["items"] = [cls.from_dict(item) if isinstance(item, dict) else item for item in content_items] # type: ignore[reportUnknownVariableType] return cls( type=content_type, @@ -2656,7 +2698,7 @@ def __init__( stream: AsyncIterable[UpdateT] | Awaitable[AsyncIterable[UpdateT]], *, finalizer: Callable[[Sequence[UpdateT]], FinalT | Awaitable[FinalT]] | None = None, - transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None]] | None = None, + transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None]] | None = None, cleanup_hooks: list[Callable[[], Awaitable[None] | None]] | None = None, result_hooks: list[Callable[[FinalT], FinalT | Awaitable[FinalT | None] | None]] | None = None, ) -> None: @@ -2680,7 +2722,7 @@ def __init__( self._consumed: bool = False self._finalized: bool = False self._final_result: FinalT | None = None - self._transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None]] = ( + self._transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None]] = ( transform_hooks if transform_hooks is not None else [] ) self._result_hooks: list[Callable[[FinalT], FinalT | Awaitable[FinalT | None] | None]] = ( @@ -2953,7 +2995,7 @@ async def get_final_response(self) -> FinalT: def with_transform_hook( self, - hook: Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None], + hook: Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None], ) -> ResponseStream[UpdateT, FinalT]: """Register a transform hook executed for each update during iteration.""" self._transform_hooks.append(hook) diff --git a/python/packages/core/agent_framework/_workflows/_edge.py b/python/packages/core/agent_framework/_workflows/_edge.py index 02544ad3dfa..b9dbd266ec5 100644 --- a/python/packages/core/agent_framework/_workflows/_edge.py +++ b/python/packages/core/agent_framework/_workflows/_edge.py @@ -9,6 +9,7 @@ from dataclasses import dataclass, field from typing import Any, ClassVar, TypeAlias, TypeVar +from .._agents import SupportsAgentRun from ._const import INTERNAL_SOURCE_ID from ._executor import Executor from ._model_utils import DictConvertible, encode_value @@ -264,7 +265,7 @@ def __init__(self) -> None: """ condition: Callable[[Any], bool] - target: Executor | str + target: Executor | SupportsAgentRun @dataclass @@ -287,7 +288,7 @@ def __init__(self) -> None: assert fallback.target.id == "dead_letter" """ - target: Executor | str + target: Executor | SupportsAgentRun @dataclass(init=False) diff --git a/python/packages/core/agent_framework/azure/_chat_client.py b/python/packages/core/agent_framework/azure/_chat_client.py index b57abd6faf6..21c38f6b574 100644 --- a/python/packages/core/agent_framework/azure/_chat_client.py +++ b/python/packages/core/agent_framework/azure/_chat_client.py @@ -172,12 +172,12 @@ def __init__( credential: AzureCredentialTypes | AzureTokenProvider | None = None, default_headers: Mapping[str, str] | None = None, async_client: AsyncAzureOpenAI | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, instruction_role: str | None = None, middleware: Sequence[MiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, - **kwargs: Any, ) -> None: """Initialize an Azure OpenAI Chat completion client. @@ -205,13 +205,13 @@ def __init__( default_headers: The default headers mapping of string keys to string values for HTTP requests. async_client: An existing client to use. + additional_properties: Additional properties stored on the client instance. env_file_path: Use the environment settings file as a fallback to using env vars. env_file_encoding: The encoding of the environment settings file, defaults to 'utf-8'. instruction_role: The role to use for 'instruction' messages, for example, summarization prompts could use `developer` or `system`. middleware: Optional sequence of middleware to apply to requests. function_invocation_configuration: Optional configuration for function invocation behavior. - kwargs: Other keyword parameters. Examples: .. code-block:: python @@ -283,10 +283,10 @@ class MyOptions(AzureOpenAIChatOptions, total=False): credential=credential, default_headers=default_headers, client=async_client, + additional_properties=additional_properties, instruction_role=instruction_role, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) @override diff --git a/python/packages/core/agent_framework/exceptions.py b/python/packages/core/agent_framework/exceptions.py index f38aa38590c..4f56c34b5c6 100644 --- a/python/packages/core/agent_framework/exceptions.py +++ b/python/packages/core/agent_framework/exceptions.py @@ -180,6 +180,34 @@ class ToolExecutionException(ToolException): pass +class UserInputRequiredException(ToolException): + """Raised when a tool wrapping a sub-agent requires user input to proceed. + + This exception carries the ``user_input_request`` Content items emitted by + the sub-agent (e.g., ``oauth_consent_request``, ``function_approval_request``) + so the tool invocation layer can propagate them to the parent agent's response + instead of swallowing them as a generic tool error. + + Args: + contents: The user-input-request Content items from the sub-agent response. + message: Human-readable description of why user input is needed. + """ + + def __init__( + self, + contents: list[Any], + message: str = "Tool requires user input to proceed.", + ) -> None: + """Create a UserInputRequiredException. + + Args: + contents: The user-input-request Content items from the sub-agent response. + message: Human-readable description of why user input is needed. + """ + super().__init__(message, log_level=None) + self.contents = contents + + # endregion # region Middleware Exceptions diff --git a/python/packages/core/agent_framework/observability.py b/python/packages/core/agent_framework/observability.py index 2407074efc7..3307f9f4eb3 100644 --- a/python/packages/core/agent_framework/observability.py +++ b/python/packages/core/agent_framework/observability.py @@ -14,6 +14,7 @@ from __future__ import annotations import contextlib +import contextvars import json import logging import os @@ -93,6 +94,13 @@ logger = logging.getLogger("agent_framework") +INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS: Final[contextvars.ContextVar[set[str] | None]] = contextvars.ContextVar( + "inner_response_telemetry_captured_fields", default=None +) +INNER_RESPONSE_ID_CAPTURED_FIELD: Final[str] = "response_id" +INNER_USAGE_CAPTURED_FIELD: Final[str] = "usage" + + OTEL_METRICS: Final[str] = "__otel_metrics__" TOKEN_USAGE_BUCKET_BOUNDARIES: Final[tuple[float, ...]] = ( 1, @@ -1162,11 +1170,35 @@ def get_response( tokenizer: TokenizerProtocol | None = None, **kwargs: Any, ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: - """Trace chat responses with OpenTelemetry spans and metrics.""" + """Trace chat responses with OpenTelemetry spans and metrics. + + Args: + messages: The message or messages to send to the model. + stream: Whether to stream the response. Defaults to False. + options: Chat options as a TypedDict. + compaction_strategy: Optional compaction strategy to apply before model calls. + tokenizer: Optional tokenizer used by token-aware compaction strategies. + + Keyword Args: + kwargs: Compatibility keyword arguments from higher client layers. This layer does + not consume ``function_invocation_kwargs`` directly; if present, it is ignored + because function invocation has already been processed above. If a ``client_kwargs`` + mapping is present, it is flattened into ordinary keyword arguments for tracing and + forwarding so clients that use those values continue to work while clients that + ignore extra kwargs remain compatible. + """ from ._types import ChatResponse, ChatResponseUpdate, ResponseStream # type: ignore[reportUnusedImport] global OBSERVABILITY_SETTINGS super_get_response = super().get_response # type: ignore[misc] + compatibility_client_kwargs = kwargs.pop("client_kwargs", None) + kwargs.pop("function_invocation_kwargs", None) + merged_client_kwargs = ( + dict(cast(Mapping[str, Any], compatibility_client_kwargs)) + if isinstance(compatibility_client_kwargs, Mapping) + else {} + ) + merged_client_kwargs.update(kwargs) if not OBSERVABILITY_SETTINGS.ENABLED: return super_get_response( # type: ignore[no-any-return] @@ -1175,12 +1207,14 @@ def get_response( options=options, compaction_strategy=compaction_strategy, tokenizer=tokenizer, - **kwargs, + **merged_client_kwargs, ) opts: dict[str, Any] = options or {} # type: ignore[assignment] provider_name = str(getattr(self, "otel_provider_name", "unknown")) - model_id = kwargs.get("model_id") or opts.get("model_id") or getattr(self, "model_id", None) or "unknown" + model_id = ( + merged_client_kwargs.get("model_id") or opts.get("model_id") or getattr(self, "model_id", None) or "unknown" + ) service_url_func = getattr(self, "service_url", None) service_url = str(service_url_func() if callable(service_url_func) else "unknown") attributes = _get_span_attributes( @@ -1188,7 +1222,7 @@ def get_response( provider_name=provider_name, model=model_id, service_url=service_url, - **kwargs, + **merged_client_kwargs, ) if stream: @@ -1200,7 +1234,7 @@ def get_response( options=opts, compaction_strategy=compaction_strategy, tokenizer=tokenizer, - **kwargs, + **merged_client_kwargs, ), ) @@ -1247,6 +1281,7 @@ async def _finalize_stream() -> None: operation_duration_histogram=getattr(self, "duration_histogram", None), duration=duration, ) + _mark_inner_response_telemetry_captured(response) if ( OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and isinstance(response, ChatResponse) @@ -1291,7 +1326,7 @@ async def _get_response() -> ChatResponse: options=opts, compaction_strategy=compaction_strategy, tokenizer=tokenizer, - **kwargs, + **merged_client_kwargs, ), ) except Exception as exception: @@ -1306,6 +1341,7 @@ async def _get_response() -> ChatResponse: operation_duration_histogram=getattr(self, "duration_histogram", None), duration=duration, ) + _mark_inner_response_telemetry_captured(response) if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and response.messages: finish_reason = cast( "FinishReason | None", @@ -1420,6 +1456,8 @@ def run( session: AgentSession | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]]: ... @@ -1432,6 +1470,8 @@ def run( session: AgentSession | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... @@ -1443,6 +1483,8 @@ def run( session: AgentSession | None = None, compaction_strategy: CompactionStrategy | None = None, tokenizer: TokenizerProtocol | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, **kwargs: Any, ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: """Trace agent runs with OpenTelemetry spans and metrics.""" @@ -1454,8 +1496,6 @@ def run( super().run, # type: ignore[misc] ) provider_name = str(self.otel_provider_name) - capture_usage = bool(getattr(self, "_otel_capture_usage", True)) - if not OBSERVABILITY_SETTINGS.ENABLED: return super_run( # type: ignore[no-any-return] messages=messages, @@ -1463,11 +1503,15 @@ def run( session=session, compaction_strategy=compaction_strategy, tokenizer=tokenizer, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, **kwargs, ) default_options = getattr(self, "default_options", {}) options = kwargs.get("options") + merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {} + merged_client_kwargs.update(kwargs) merged_options: dict[str, Any] = merge_chat_options(default_options, options or {}) attributes = _get_span_attributes( operation_name=OtelAttr.AGENT_INVOKE_OPERATION, @@ -1477,24 +1521,35 @@ def run( agent_description=getattr(self, "description", None), thread_id=session.service_session_id if session else None, all_options=merged_options, - **kwargs, + **merged_client_kwargs, + ) + + inner_response_telemetry_captured_fields: set[str] = set() + inner_response_telemetry_captured_fields_token = INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.set( + inner_response_telemetry_captured_fields ) if stream: - run_result: object = super_run( - messages=messages, - stream=True, - session=session, - compaction_strategy=compaction_strategy, - tokenizer=tokenizer, - **kwargs, - ) - if isinstance(run_result, ResponseStream): - result_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = run_result # pyright: ignore[reportUnknownVariableType] - elif isinstance(run_result, Awaitable): - result_stream = ResponseStream.from_awaitable(run_result) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] - else: - raise RuntimeError("Streaming telemetry requires a ResponseStream result.") + try: + run_result: object = super_run( + messages=messages, + stream=True, + session=session, + compaction_strategy=compaction_strategy, + tokenizer=tokenizer, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + **kwargs, + ) + if isinstance(run_result, ResponseStream): + result_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = run_result # pyright: ignore[reportUnknownVariableType] + elif isinstance(run_result, Awaitable): + result_stream = ResponseStream.from_awaitable(run_result) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + else: + raise RuntimeError("Streaming telemetry requires a ResponseStream result.") + except Exception: + INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token) + raise # Create span directly without trace.use_span() context attachment. # Streaming spans are closed asynchronously in cleanup hooks, which run @@ -1534,7 +1589,9 @@ async def _finalize_stream() -> None: response_attributes = _get_response_attributes( attributes, response, - capture_usage=capture_usage, + capture_response_id=INNER_RESPONSE_ID_CAPTURED_FIELD + not in inner_response_telemetry_captured_fields, + capture_usage=INNER_USAGE_CAPTURED_FIELD not in inner_response_telemetry_captured_fields, ) _capture_response(span=span, attributes=response_attributes, duration=duration) if ( @@ -1551,6 +1608,7 @@ async def _finalize_stream() -> None: except Exception as exception: capture_exception(span=span, exception=exception, timestamp=time_ns()) finally: + INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token) _close_span() # Register a weak reference callback to close the span if stream is garbage collected @@ -1562,39 +1620,50 @@ async def _finalize_stream() -> None: return wrapped_stream async def _run() -> AgentResponse: - with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span: - if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages: - _capture_messages( - span=span, - provider_name=provider_name, - messages=messages, - system_instructions=_get_instructions_from_options(merged_options), - ) - start_time_stamp = perf_counter() - try: - response: AgentResponse[Any] = await super_run( - messages=messages, - stream=False, - session=session, - compaction_strategy=compaction_strategy, - tokenizer=tokenizer, - **kwargs, - ) - except Exception as exception: - capture_exception(span=span, exception=exception, timestamp=time_ns()) - raise - duration = perf_counter() - start_time_stamp - if response: - response_attributes = _get_response_attributes(attributes, response, capture_usage=capture_usage) - _capture_response(span=span, attributes=response_attributes, duration=duration) - if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and response.messages: + try: + with _get_span(attributes=attributes, span_name_attribute=OtelAttr.AGENT_NAME) as span: + if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and messages: _capture_messages( span=span, provider_name=provider_name, - messages=response.messages, - output=True, + messages=messages, + system_instructions=_get_instructions_from_options(merged_options), ) - return response # type: ignore[return-value,no-any-return] + start_time_stamp = perf_counter() + try: + response: AgentResponse[Any] = await super_run( + messages=messages, + stream=False, + session=session, + compaction_strategy=compaction_strategy, + tokenizer=tokenizer, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + **kwargs, + ) + except Exception as exception: + capture_exception(span=span, exception=exception, timestamp=time_ns()) + raise + duration = perf_counter() - start_time_stamp + if response: + response_attributes = _get_response_attributes( + attributes, + response, + capture_response_id=INNER_RESPONSE_ID_CAPTURED_FIELD + not in inner_response_telemetry_captured_fields, + capture_usage=INNER_USAGE_CAPTURED_FIELD not in inner_response_telemetry_captured_fields, + ) + _capture_response(span=span, attributes=response_attributes, duration=duration) + if OBSERVABILITY_SETTINGS.SENSITIVE_DATA_ENABLED and response.messages: + _capture_messages( + span=span, + provider_name=provider_name, + messages=response.messages, + output=True, + ) + return response # type: ignore[return-value,no-any-return] + finally: + INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.reset(inner_response_telemetry_captured_fields_token) return _run() @@ -1850,14 +1919,26 @@ def _to_otel_part(content: Content) -> dict[str, Any] | None: return None +def _mark_inner_response_telemetry_captured(response: ChatResponse | AgentResponse) -> None: + """Record when an inner chat telemetry span already captured response metadata.""" + captured_fields = INNER_RESPONSE_TELEMETRY_CAPTURED_FIELDS.get() + if captured_fields is None: + return + if response.response_id: + captured_fields.add(INNER_RESPONSE_ID_CAPTURED_FIELD) + if response.usage_details: + captured_fields.add(INNER_USAGE_CAPTURED_FIELD) + + def _get_response_attributes( attributes: dict[str, Any], response: ChatResponse | AgentResponse, *, + capture_response_id: bool = True, capture_usage: bool = True, ) -> dict[str, Any]: """Get the response attributes from a response.""" - if response.response_id: + if capture_response_id and response.response_id: attributes[OtelAttr.RESPONSE_ID] = response.response_id finish_reason = getattr(response, "finish_reason", None) if not finish_reason: diff --git a/python/packages/core/agent_framework/openai/_chat_client.py b/python/packages/core/agent_framework/openai/_chat_client.py index 0562e68f3e6..6df57fe428f 100644 --- a/python/packages/core/agent_framework/openai/_chat_client.py +++ b/python/packages/core/agent_framework/openai/_chat_client.py @@ -15,7 +15,7 @@ ) from datetime import datetime, timezone from itertools import chain -from typing import Any, Generic, Literal, cast +from typing import Any, Generic, Literal, cast, overload from openai import AsyncOpenAI, BadRequestError from openai.lib._parsing._completions import type_to_response_format_param @@ -30,7 +30,8 @@ from pydantic import BaseModel from .._clients import BaseChatClient -from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer +from .._docstrings import apply_layered_docstring +from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer, FunctionMiddlewareTypes from .._settings import load_settings from .._tools import ( FunctionInvocationConfiguration, @@ -72,6 +73,7 @@ logger = logging.getLogger("agent_framework.openai") +ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel) ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None) @@ -213,6 +215,57 @@ def get_web_search_tool( # endregion + @overload + def get_response( + self, + messages: Sequence[Message], + *, + stream: Literal[False] = ..., + options: ChatOptions[ResponseModelBoundT], + **kwargs: Any, + ) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ... + + @overload + def get_response( + self, + messages: Sequence[Message], + *, + stream: Literal[False] = ..., + options: OpenAIChatOptionsT | ChatOptions[None] | None = None, + **kwargs: Any, + ) -> Awaitable[ChatResponse[Any]]: ... + + @overload + def get_response( + self, + messages: Sequence[Message], + *, + stream: Literal[True], + options: OpenAIChatOptionsT | ChatOptions[Any] | None = None, + **kwargs: Any, + ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ... + + @override + def get_response( + self, + messages: Sequence[Message], + *, + stream: bool = False, + options: OpenAIChatOptionsT | ChatOptions[Any] | None = None, + **kwargs: Any, + ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: + """Get a response from the raw OpenAI chat client.""" + super_get_response = cast( + "Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]", + super().get_response, # type: ignore[misc] + ) + return super_get_response( # type: ignore[no-any-return] + messages=messages, + stream=stream, + options=options, + **kwargs, + ) + @override def _inner_get_response( self, @@ -579,9 +632,20 @@ def _prepare_message_for_openai(self, message: Message) -> list[dict[str, Any]]: args["tool_calls"] = [self._prepare_content_for_openai(content)] # type: ignore case "function_result": args["tool_call_id"] = content.call_id - # Always include content for tool results - API requires it even if empty - # Functions returning None should still have a tool result message - args["content"] = content.result if content.result is not None else "" + if content.items: + text_parts = [item.text or "" for item in content.items if item.type == "text"] + rich_items = [item for item in content.items if item.type in ("data", "uri")] + if rich_items: + logger.warning( + "OpenAI Chat Completions API does not support rich content (images, audio) " + "in tool results. Rich content items will be omitted. " + "Use the Responses API client for rich tool results." + ) + args["content"] = "\n".join(text_parts) if text_parts else "" + else: + args["content"] = content.result if content.result is not None else "" + all_messages.append(args) + continue case "text_reasoning" if (protected_data := content.protected_data) is not None: # Buffer reasoning to attach to the next message with content/tool_calls pending_reasoning = json.loads(protected_data) @@ -646,7 +710,7 @@ def _prepare_content_for_openai(self, content: Content) -> dict[str, Any]: case "function_result": return { "tool_call_id": content.call_id, - "content": content.result, + "content": content.result if content.result is not None else "", } case "data" | "uri" if content.has_top_level_media_type("image"): return { @@ -716,6 +780,77 @@ class OpenAIChatClient( # type: ignore[misc] ): """OpenAI Chat completion class with middleware, telemetry, and function invocation support.""" + @overload + def get_response( + self, + messages: Sequence[Message], + *, + stream: Literal[False] = ..., + options: ChatOptions[ResponseModelBoundT], + function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, + **kwargs: Any, + ) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ... + + @overload + def get_response( + self, + messages: Sequence[Message], + *, + stream: Literal[False] = ..., + options: OpenAIChatOptionsT | ChatOptions[None] | None = None, + function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, + **kwargs: Any, + ) -> Awaitable[ChatResponse[Any]]: ... + + @overload + def get_response( + self, + messages: Sequence[Message], + *, + stream: Literal[True], + options: OpenAIChatOptionsT | ChatOptions[Any] | None = None, + function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, + **kwargs: Any, + ) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ... + + @override + def get_response( + self, + messages: Sequence[Message], + *, + stream: bool = False, + options: OpenAIChatOptionsT | ChatOptions[Any] | None = None, + function_middleware: Sequence[FunctionMiddlewareTypes] | None = None, + function_invocation_kwargs: Mapping[str, Any] | None = None, + client_kwargs: Mapping[str, Any] | None = None, + middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, + **kwargs: Any, + ) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: + """Get a response from the OpenAI chat client with all standard layers enabled.""" + super_get_response = cast( + "Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]", + super().get_response, # type: ignore[misc] + ) + return super_get_response( # type: ignore[no-any-return] + messages=messages, + stream=stream, + options=options, + function_middleware=function_middleware, + function_invocation_kwargs=function_invocation_kwargs, + client_kwargs=client_kwargs, + middleware=middleware, + **kwargs, + ) + def __init__( self, *, @@ -819,3 +954,25 @@ class MyOptions(OpenAIChatOptions, total=False): middleware=middleware, function_invocation_configuration=function_invocation_configuration, ) + + +def _apply_openai_chat_client_docstrings() -> None: + """Align OpenAI chat-client docstrings with the raw implementation.""" + apply_layered_docstring(RawOpenAIChatClient.get_response, BaseChatClient.get_response) + apply_layered_docstring( + OpenAIChatClient.get_response, + RawOpenAIChatClient.get_response, + extra_keyword_args={ + "function_middleware": """ + Optional per-call function middleware. + When omitted, middleware configured on the client or forwarded from higher layers is used. + """, + "middleware": """ + Optional per-call chat and function middleware. + This is merged with any middleware configured on the client for the current request. + """, + }, + ) + + +_apply_openai_chat_client_docstrings() diff --git a/python/packages/core/agent_framework/openai/_responses_client.py b/python/packages/core/agent_framework/openai/_responses_client.py index 03dc1cd5ed7..145986fb9a0 100644 --- a/python/packages/core/agent_framework/openai/_responses_client.py +++ b/python/packages/core/agent_framework/openai/_responses_client.py @@ -16,7 +16,16 @@ ) from datetime import datetime, timezone from itertools import chain -from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, NoReturn, TypedDict, cast +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + Literal, + NoReturn, + TypedDict, + cast, +) from openai import AsyncOpenAI, BadRequestError from openai.types.responses import FunctionShellTool @@ -309,23 +318,33 @@ async def _stream() -> AsyncIterable[ChatResponseUpdate]: ) async for chunk in stream_response: yield self._parse_chunk_from_openai( - chunk, options=validated_options, function_call_ids=function_call_ids + chunk, + options=validated_options, + function_call_ids=function_call_ids, ) except Exception as ex: self._handle_request_error(ex) else: - client, run_options, validated_options = await self._prepare_request(messages, options, **kwargs) + ( + client, + run_options, + validated_options, + ) = await self._prepare_request(messages, options, **kwargs) try: if "text_format" in run_options: async with client.responses.stream(**run_options) as response: async for chunk in response: yield self._parse_chunk_from_openai( - chunk, options=validated_options, function_call_ids=function_call_ids + chunk, + options=validated_options, + function_call_ids=function_call_ids, ) else: async for chunk in await client.responses.create(stream=True, **run_options): yield self._parse_chunk_from_openai( - chunk, options=validated_options, function_call_ids=function_call_ids + chunk, + options=validated_options, + function_call_ids=function_call_ids, ) except Exception as ex: self._handle_request_error(ex) @@ -439,7 +458,8 @@ def _get_conversation_id( # region Prep methods def _prepare_tools_for_openai( - self, tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None + self, + tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None, ) -> list[Any]: """Prepare tools for the OpenAI Responses API. @@ -1194,10 +1214,22 @@ def _prepare_content_for_openai( "output": self._to_local_shell_output_payload(content), } # call_id for the result needs to be the same as the call_id for the function call + output: str | list[dict[str, Any]] = content.result or "" + if content.items and any(item.type in ("data", "uri") for item in content.items): + output_parts: list[dict[str, Any]] = [] + for item in content.items: + if item.type == "text": + output_parts.append({"type": "input_text", "text": item.text or ""}) + else: + part = self._prepare_content_for_openai("user", item, call_id_to_id) # type: ignore[arg-type] + if part: + output_parts.append(part) + if output_parts: + output = output_parts return { "call_id": content.call_id, "type": "function_call_output", - "output": content.result if content.result is not None else "", + "output": output, } case "function_approval_request": return { @@ -1825,7 +1857,10 @@ def _parse_chunk_from_openai( case "response.created": response_id = event.response.id conversation_id = self._get_conversation_id(event.response, options.get("store")) - if event.response.status and event.response.status in ("in_progress", "queued"): + if event.response.status and event.response.status in ( + "in_progress", + "queued", + ): continuation_token = OpenAIContinuationToken(response_id=event.response.id) case "response.in_progress": response_id = event.response.id @@ -2003,7 +2038,11 @@ def _parse_chunk_from_openai( Content.from_shell_tool_call( call_id=local_call_id, commands=[local_command] if local_command else [], - timeout_ms=getattr(getattr(event_item, "action", None), "timeout_ms", None), + timeout_ms=getattr( + getattr(event_item, "action", None), + "timeout_ms", + None, + ), status=getattr(event_item, "status", None), raw_representation=event_item, ) diff --git a/python/packages/core/tests/assets/sample_image.jpg b/python/packages/core/tests/assets/sample_image.jpg new file mode 100644 index 00000000000..ea6486656fd Binary files /dev/null and b/python/packages/core/tests/assets/sample_image.jpg differ diff --git a/python/packages/core/tests/azure/test_azure_chat_client.py b/python/packages/core/tests/azure/test_azure_chat_client.py index b6809d097d2..dd50c48db49 100644 --- a/python/packages/core/tests/azure/test_azure_chat_client.py +++ b/python/packages/core/tests/azure/test_azure_chat_client.py @@ -89,18 +89,26 @@ def test_init_endpoint(azure_openai_unit_test_env: dict[str, str]) -> None: @pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"]], indirect=True) -def test_init_with_empty_deployment_name(azure_openai_unit_test_env: dict[str, str]) -> None: +def test_init_with_empty_deployment_name( + azure_openai_unit_test_env: dict[str, str], +) -> None: with pytest.raises(ValueError): AzureOpenAIChatClient() @pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_BASE_URL"]], indirect=True) -def test_init_with_empty_endpoint_and_base_url(azure_openai_unit_test_env: dict[str, str]) -> None: +def test_init_with_empty_endpoint_and_base_url( + azure_openai_unit_test_env: dict[str, str], +) -> None: with pytest.raises(ValueError): AzureOpenAIChatClient() -@pytest.mark.parametrize("override_env_param_dict", [{"AZURE_OPENAI_ENDPOINT": "http://test.com"}], indirect=True) +@pytest.mark.parametrize( + "override_env_param_dict", + [{"AZURE_OPENAI_ENDPOINT": "http://test.com"}], + indirect=True, +) def test_init_with_invalid_endpoint(azure_openai_unit_test_env: dict[str, str]) -> None: # Note: URL scheme validation was previously handled by pydantic's HTTPsUrl type. # After migrating to load_settings with TypedDict, endpoint is a plain string and no longer @@ -147,7 +155,11 @@ def mock_chat_completion_response() -> ChatCompletion: return ChatCompletion( id="test_id", choices=[ - Choice(index=0, message=ChatCompletionMessage(content="test", role="assistant"), finish_reason="stop") + Choice( + index=0, + message=ChatCompletionMessage(content="test", role="assistant"), + finish_reason="stop", + ) ], created=0, model="test", @@ -159,7 +171,13 @@ def mock_chat_completion_response() -> ChatCompletion: def mock_streaming_chat_completion_response() -> AsyncStream[ChatCompletionChunk]: content = ChatCompletionChunk( id="test_id", - choices=[ChunkChoice(index=0, delta=ChunkChoiceDelta(content="test", role="assistant"), finish_reason="stop")], + choices=[ + ChunkChoice( + index=0, + delta=ChunkChoiceDelta(content="test", role="assistant"), + finish_reason="stop", + ) + ], created=0, model="test", object="chat.completion.chunk", @@ -546,7 +564,9 @@ async def test_bad_request_non_content_filter( test_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") assert test_endpoint is not None mock_create.side_effect = openai.BadRequestError( - "The request was bad.", response=Response(400, request=Request("POST", test_endpoint)), body={} + "The request was bad.", + response=Response(400, request=Request("POST", test_endpoint)), + body={}, ) azure_chat_client = AzureOpenAIChatClient() @@ -605,7 +625,13 @@ async def test_streaming_with_none_delta( # Second chunk has actual content chunk_with_content = ChatCompletionChunk( id="test_id", - choices=[ChunkChoice(index=0, delta=ChunkChoiceDelta(content="test", role="assistant"), finish_reason="stop")], + choices=[ + ChunkChoice( + index=0, + delta=ChunkChoiceDelta(content="test", role="assistant"), + finish_reason="stop", + ) + ], created=0, model="test", object="chat.completion.chunk", @@ -854,7 +880,10 @@ async def test_azure_openai_chat_client_agent_basic_run_streaming(): ) as agent: # Test streaming run full_text = "" - async for chunk in agent.run("Please respond with exactly: 'This is a streaming response test.'", stream=True): + async for chunk in agent.run( + "Please respond with exactly: 'This is a streaming response test.'", + stream=True, + ): assert isinstance(chunk, AgentResponseUpdate) if chunk.text: full_text += chunk.text diff --git a/python/packages/core/tests/azure/test_azure_responses_client.py b/python/packages/core/tests/azure/test_azure_responses_client.py index 68ee0661582..35eaa2b4072 100644 --- a/python/packages/core/tests/azure/test_azure_responses_client.py +++ b/python/packages/core/tests/azure/test_azure_responses_client.py @@ -3,6 +3,7 @@ import json import logging import os +from pathlib import Path from typing import Annotated, Any from unittest.mock import MagicMock @@ -44,10 +45,13 @@ async def get_weather(location: Annotated[str, "The location as a city name"]) - return f"The weather in {location} is sunny and 72°F." -async def create_vector_store(client: AzureOpenAIResponsesClient) -> tuple[str, Content]: +async def create_vector_store( + client: AzureOpenAIResponsesClient, +) -> tuple[str, Content]: """Create a vector store with sample documents for testing.""" file = await client.client.files.create( - file=("todays_weather.txt", b"The weather today is sunny with a high of 75F."), purpose="assistants" + file=("todays_weather.txt", b"The weather today is sunny with a high of 75F."), + purpose="assistants", ) vector_store = await client.client.vector_stores.create( name="knowledge_base", @@ -98,7 +102,9 @@ def test_init_model_id_kwarg(azure_openai_unit_test_env: dict[str, str]) -> None assert isinstance(azure_responses_client, SupportsChatGetResponse) -def test_init_model_id_kwarg_does_not_override_deployment_name(azure_openai_unit_test_env: dict[str, str]) -> None: +def test_init_model_id_kwarg_does_not_override_deployment_name( + azure_openai_unit_test_env: dict[str, str], +) -> None: """Test that deployment_name takes precedence over model_id kwarg (issue #4299).""" azure_responses_client = AzureOpenAIResponsesClient(deployment_name="my-deployment", model_id="gpt-4o") @@ -323,7 +329,12 @@ def test_serialize(azure_openai_unit_test_env: dict[str, str]) -> None: "temperature_c": {"type": "number"}, "advisory": {"type": "string"}, }, - "required": ["location", "conditions", "temperature_c", "advisory"], + "required": [ + "location", + "conditions", + "temperature_c", + "advisory", + ], "additionalProperties": False, }, }, @@ -445,7 +456,12 @@ async def test_integration_web_search() -> None: # Test that the client will use the web search tool with location content = { - "messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")], + "messages": [ + Message( + role="user", + text="What is the current weather? Do not ask for my current location.", + ) + ], "options": { "tool_choice": "auto", "tools": [ @@ -556,7 +572,12 @@ async def test_integration_client_agent_hosted_code_interpreter_tool(): client = AzureOpenAIResponsesClient(credential=AzureCliCredential()) response = await client.get_response( - messages=[Message(role="user", text="Calculate the sum of numbers from 1 to 10 using Python code.")], + messages=[ + Message( + role="user", + text="Calculate the sum of numbers from 1 to 10 using Python code.", + ) + ], options={ "tools": [AzureOpenAIResponsesClient.get_code_interpreter_tool()], }, @@ -604,6 +625,44 @@ async def test_integration_client_agent_existing_session(): assert "photography" in second_response.text.lower() +@pytest.mark.flaky +@pytest.mark.integration +@skip_if_azure_integration_tests_disabled +async def test_azure_openai_responses_client_tool_rich_content_image() -> None: + """Test that Azure OpenAI Responses client can handle tool results containing images.""" + image_path = Path(__file__).parent.parent / "assets" / "sample_image.jpg" + image_bytes = image_path.read_bytes() + + @tool(approval_mode="never_require") + def get_test_image() -> Content: + """Return a test image for analysis.""" + return Content.from_data(data=image_bytes, media_type="image/jpeg") + + client = AzureOpenAIResponsesClient(credential=AzureCliCredential()) + client.function_invocation_configuration["max_iterations"] = 2 + + for streaming in [False, True]: + messages = [ + Message( + role="user", + text="Call the get_test_image tool and describe what you see.", + ) + ] + options: dict[str, Any] = {"tools": [get_test_image], "tool_choice": "auto"} + + if streaming: + response = await client.get_response(messages=messages, stream=True, options=options).get_final_response() + else: + response = await client.get_response(messages=messages, options=options) + + assert response is not None + assert isinstance(response, ChatResponse) + assert response.text is not None + assert len(response.text) > 0 + # sample_image.jpg contains a photo of a house; the model should mention it. + assert "house" in response.text.lower(), f"Model did not describe the house image. Response: {response.text}" + + # region Integration with Foundry V2 diff --git a/python/packages/core/tests/core/test_agents.py b/python/packages/core/tests/core/test_agents.py index d804d07c555..8e6faa37c4c 100644 --- a/python/packages/core/tests/core/test_agents.py +++ b/python/packages/core/tests/core/test_agents.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. import contextlib +import inspect from collections.abc import AsyncIterable, MutableSequence from typing import Any from unittest.mock import AsyncMock, MagicMock @@ -30,7 +31,8 @@ tool, ) from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name -from agent_framework._mcp import MCPTool +from agent_framework._mcp import MCPTool, _build_prefixed_mcp_name, _normalize_mcp_name +from agent_framework._middleware import FunctionInvocationContext class _FixedTokenizer: @@ -41,6 +43,30 @@ def count_tokens(self, text: str) -> int: return self.token_count +class _ConnectedMCPTool(MCPTool): + def __init__(self, name: str, function_names: list[str], *, tool_name_prefix: str | None = None) -> None: + super().__init__(name=name, tool_name_prefix=tool_name_prefix) + self.is_connected = True + self._functions = [] + for function_name in function_names: + normalized_name = _normalize_mcp_name(function_name) + exposed_name = _build_prefixed_mcp_name(normalized_name, self.tool_name_prefix) + self._functions.append( + FunctionTool( + func=lambda value=function_name: value, + name=exposed_name, + description=f"{function_name} from {name}", + additional_properties={ + "_mcp_remote_name": function_name, + "_mcp_normalized_name": normalized_name, + }, + ) + ) + + def get_mcp_client(self) -> contextlib.AbstractAsyncContextManager[Any]: + raise NotImplementedError + + def test_agent_session_type(agent_session: AgentSession) -> None: assert isinstance(agent_session, AgentSession) @@ -77,6 +103,30 @@ def test_chat_client_agent_type(client: SupportsChatGetResponse) -> None: assert isinstance(chat_client_agent, SupportsAgentRun) +def test_agent_init_docstring_surfaces_raw_agent_constructor_docs() -> None: + docstring = inspect.getdoc(Agent.__init__) + + assert docstring is not None + assert "client: The chat client to use for the agent." in docstring + assert "middleware: List of middleware to intercept agent and function invocations." in docstring + + +def test_agent_run_docstring_surfaces_raw_agent_runtime_docs() -> None: + docstring = inspect.getdoc(Agent.run) + + assert docstring is not None + assert "Run the agent with the given messages and options." in docstring + assert "function_invocation_kwargs: Keyword arguments forwarded to tool invocation." in docstring + assert "middleware: Optional per-run agent, chat, and function middleware." in docstring + + +def test_agent_run_is_defined_on_agent_class() -> None: + signature = inspect.signature(Agent.run) + + assert Agent.run.__qualname__ == "Agent.run" + assert "middleware" in signature.parameters + + async def test_chat_client_agent_init(client: SupportsChatGetResponse) -> None: agent_id = str(uuid4()) agent = Agent(client=client, id=agent_id, description="Test") @@ -97,6 +147,13 @@ async def test_chat_client_agent_init_with_name( assert agent.description == "Test" +def test_agent_init_warns_for_direct_additional_properties(client: SupportsChatGetResponse) -> None: + with pytest.warns(DeprecationWarning, match="additional_properties"): + agent = Agent(client=client, legacy_key="legacy-value") + + assert agent.additional_properties["legacy_key"] == "legacy-value" + + async def test_chat_client_agent_run(client: SupportsChatGetResponse) -> None: agent = Agent(client=client) @@ -229,33 +286,38 @@ async def test_prepare_session_does_not_mutate_agent_chat_options( assert len(agent.default_options["tools"]) == 1 -async def test_prepare_run_context_keeps_compaction_overrides_out_of_kwargs( +async def test_prepare_run_context_handles_function_kwargs( chat_client_base: SupportsChatGetResponse, ) -> None: - strategy = SlidingWindowStrategy(keep_last_groups=2) - tokenizer = _FixedTokenizer(13) agent = Agent(client=chat_client_base) + session = agent.create_session() ctx = await agent._prepare_run_context( # type: ignore[reportPrivateUsage] - messages=[Message(role="user", text="Hello")], - session=None, + messages="Hello", + session=session, tools=None, - options=None, - compaction_strategy=strategy, - tokenizer=tokenizer, - kwargs={"custom_flag": True}, + options={ + "temperature": 0.4, + "additional_function_arguments": {"from_options": "options-value"}, + }, + compaction_strategy=None, + tokenizer=None, + legacy_kwargs={"legacy_key": "legacy-value"}, + function_invocation_kwargs={"runtime_key": "runtime-value"}, + client_kwargs={"client_key": "client-value"}, ) - assert ctx["compaction_strategy"] is strategy - assert ctx["tokenizer"] is tokenizer - assert ctx["filtered_kwargs"].get("custom_flag") is True - assert "compaction_strategy" not in ctx["filtered_kwargs"] - assert "tokenizer" not in ctx["filtered_kwargs"] + assert ctx["chat_options"]["temperature"] == 0.4 + assert "additional_function_arguments" not in ctx["chat_options"] + assert ctx["function_invocation_kwargs"]["from_options"] == "options-value" + assert ctx["function_invocation_kwargs"]["legacy_key"] == "legacy-value" + assert ctx["function_invocation_kwargs"]["runtime_key"] == "runtime-value" + assert "session" not in ctx["function_invocation_kwargs"] + assert ctx["client_kwargs"]["client_key"] == "client-value" + assert ctx["client_kwargs"]["session"] is session -async def test_chat_client_agent_run_with_session( - chat_client_base: SupportsChatGetResponse, -) -> None: +async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChatGetResponse) -> None: mock_response = ChatResponse( messages=[Message(role="assistant", contents=[Content.from_text("test response")])], conversation_id="123", @@ -696,8 +758,9 @@ async def test_chat_agent_as_tool_basic(client: SupportsChatGetResponse) -> None assert tool.name == "TestAgent" assert tool.description == "Test agent for as_tool" + assert tool.approval_mode == "never_require" assert hasattr(tool, "func") - assert hasattr(tool, "input_model") + assert tool.input_model is None async def test_chat_agent_as_tool_custom_parameters( @@ -711,13 +774,15 @@ async def test_chat_agent_as_tool_custom_parameters( description="Custom description", arg_name="query", arg_description="Custom input description", + approval_mode="always_require", ) assert tool.name == "CustomTool" assert tool.description == "Custom description" + assert tool.approval_mode == "always_require" # Check that the input model has the custom field name - schema = tool.input_model.model_json_schema() + schema = tool.parameters() assert "query" in schema["properties"] assert schema["properties"]["query"]["description"] == "Custom input description" @@ -736,7 +801,7 @@ async def test_chat_agent_as_tool_defaults(client: SupportsChatGetResponse) -> N assert tool.description == "" # Should default to empty string # Check default input field - schema = tool.input_model.model_json_schema() + schema = tool.parameters() assert "task" in schema["properties"] assert "Task for TestAgent" in schema["properties"]["task"]["description"] @@ -759,11 +824,12 @@ async def test_chat_agent_as_tool_function_execution( tool = agent.as_tool() # Test function execution - result = await tool.invoke(arguments=tool.input_model(task="Hello")) + result = await tool.invoke(arguments={"task": "Hello"}) - # Should return the agent's response text - assert isinstance(result, str) - assert result == "test response" # From mock chat client + # Should return the agent's response text as a list of Content items + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].text == "test streaming response another update" # From mock streaming client async def test_chat_agent_as_tool_with_stream_callback( @@ -781,14 +847,15 @@ def stream_callback(update: AgentResponseUpdate) -> None: tool = agent.as_tool(stream_callback=stream_callback) # Execute the tool - result = await tool.invoke(arguments=tool.input_model(task="Hello")) + result = await tool.invoke(arguments={"task": "Hello"}) # Should have collected streaming updates assert len(collected_updates) > 0 - assert isinstance(result, str) + assert isinstance(result, list) + result_text = result[0].text # Result should be concatenation of all streaming updates expected_text = "".join(update.text for update in collected_updates) - assert result == expected_text + assert result_text == expected_text async def test_chat_agent_as_tool_with_custom_arg_name( @@ -800,8 +867,9 @@ async def test_chat_agent_as_tool_with_custom_arg_name( tool = agent.as_tool(arg_name="prompt", arg_description="Custom prompt input") # Test that the custom argument name works - result = await tool.invoke(arguments=tool.input_model(prompt="Test prompt")) - assert result == "test response" + result = await tool.invoke(arguments={"prompt": "Test prompt"}) + assert isinstance(result, list) + assert result[0].text == "test streaming response another update" async def test_chat_agent_as_tool_with_async_stream_callback( @@ -819,14 +887,15 @@ async def async_stream_callback(update: AgentResponseUpdate) -> None: tool = agent.as_tool(stream_callback=async_stream_callback) # Execute the tool - result = await tool.invoke(arguments=tool.input_model(task="Hello")) + result = await tool.invoke(arguments={"task": "Hello"}) # Should have collected streaming updates assert len(collected_updates) > 0 - assert isinstance(result, str) + assert isinstance(result, list) + result_text = result[0].text # Result should be concatenation of all streaming updates expected_text = "".join(update.text for update in collected_updates) - assert result == expected_text + assert result_text == expected_text async def test_chat_agent_as_tool_name_sanitization( @@ -849,17 +918,14 @@ async def test_chat_agent_as_tool_name_sanitization( assert tool.name == expected_tool_name, f"Expected {expected_tool_name}, got {tool.name} for input {agent_name}" -async def test_chat_agent_as_tool_propagate_session_true( - client: SupportsChatGetResponse, -) -> None: - """Test that propagate_session=True forwards the parent's session to the sub-agent.""" +async def test_chat_agent_as_tool_propagate_session_true(client: SupportsChatGetResponse) -> None: + """Test that propagate_session=True forwards the session to the sub-agent.""" agent = Agent(client=client, name="SubAgent", description="Sub agent") tool = agent.as_tool(propagate_session=True) parent_session = AgentSession(session_id="parent-session-123") parent_session.state["shared_key"] = "shared_value" - # Spy on the agent's run method to capture the session argument original_run = agent.run captured_session = None @@ -870,16 +936,20 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: agent.run = capturing_run # type: ignore[assignment, method-assign] - await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session) + await tool.invoke( + context=FunctionInvocationContext( + function=tool, + arguments={"task": "Hello"}, + session=parent_session, + ) + ) assert captured_session is parent_session assert captured_session.session_id == "parent-session-123" assert captured_session.state["shared_key"] == "shared_value" -async def test_chat_agent_as_tool_propagate_session_false_by_default( - client: SupportsChatGetResponse, -) -> None: +async def test_chat_agent_as_tool_propagate_session_false_by_default(client: SupportsChatGetResponse) -> None: """Test that propagate_session defaults to False and does not forward the session.""" agent = Agent(client=client, name="SubAgent", description="Sub agent") tool = agent.as_tool() # default: propagate_session=False @@ -896,22 +966,25 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: agent.run = capturing_run # type: ignore[assignment, method-assign] - await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session) + await tool.invoke( + context=FunctionInvocationContext( + function=tool, + arguments={"task": "Hello"}, + session=parent_session, + ) + ) assert captured_session is None -async def test_chat_agent_as_tool_propagate_session_shares_state( - client: SupportsChatGetResponse, -) -> None: - """Test that shared session allows the sub-agent to read and write parent's state.""" +async def test_chat_agent_as_tool_propagate_session_shares_state(client: SupportsChatGetResponse) -> None: + """Test that a propagated session allows the sub-agent to read and write parent state.""" agent = Agent(client=client, name="SubAgent", description="Sub agent") tool = agent.as_tool(propagate_session=True) parent_session = AgentSession(session_id="shared-session") parent_session.state["counter"] = 0 - # The sub-agent receives the same session object, so mutations are shared original_run = agent.run captured_session = None @@ -924,9 +997,14 @@ def capturing_run(*args: Any, **kwargs: Any) -> Any: agent.run = capturing_run # type: ignore[assignment, method-assign] - await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session) + await tool.invoke( + context=FunctionInvocationContext( + function=tool, + arguments={"task": "Hello"}, + session=parent_session, + ) + ) - # The parent's state should reflect the sub-agent's mutation assert parent_session.state["counter"] == 1 @@ -949,6 +1027,7 @@ async def test_chat_agent_run_with_mcp_tools(client: SupportsChatGetResponse) -> # Create a mock MCP tool mock_mcp_tool = MagicMock(spec=MCPTool) + mock_mcp_tool.name = "mock-mcp" mock_mcp_tool.is_connected = False mock_mcp_tool.functions = [MagicMock()] @@ -966,6 +1045,7 @@ async def test_chat_agent_with_local_mcp_tools(client: SupportsChatGetResponse) """Test agent initialization with local MCP tools.""" # Create a mock MCP tool mock_mcp_tool = MagicMock(spec=MCPTool) + mock_mcp_tool.name = "mock-mcp" mock_mcp_tool.is_connected = False mock_mcp_tool.__aenter__ = AsyncMock(return_value=mock_mcp_tool) mock_mcp_tool.__aexit__ = AsyncMock(return_value=None) @@ -1005,6 +1085,7 @@ async def capturing_inner( # Create a mock MCP tool that is already connected (simulates turn 2) mock_mcp_tool = MagicMock(spec=MCPTool) + mock_mcp_tool.name = "mock-mcp" mock_mcp_tool.is_connected = True mock_mcp_tool.functions = [mcp_func_a, mcp_func_b] mock_mcp_tool.__aenter__ = AsyncMock(return_value=mock_mcp_tool) @@ -1028,8 +1109,79 @@ async def capturing_inner( assert len(tool_names) == 3 +async def test_agent_run_raises_on_local_and_agent_mcp_name_conflict(chat_client_base: Any) -> None: + local_tool = FunctionTool( + func=lambda: "local", + name="delete_all_data", + description="Local protected tool", + approval_mode="always_require", + ) + agent = Agent( + client=chat_client_base, + name="TestAgent", + tools=[_ConnectedMCPTool(name="dangerous-mcp", function_names=["delete_all_data"])], + ) + + with raises(ValueError, match="tool_name_prefix"): + await agent.run("hello", tools=[local_tool]) + + +async def test_agent_run_raises_on_runtime_local_and_runtime_mcp_name_conflict(chat_client_base: Any) -> None: + local_tool = FunctionTool( + func=lambda: "local", + name="delete_all_data", + description="Local protected tool", + approval_mode="always_require", + ) + runtime_mcp = _ConnectedMCPTool(name="dangerous-mcp", function_names=["delete_all_data"]) + agent = Agent(client=chat_client_base, name="TestAgent") + + with raises(ValueError, match="tool_name_prefix"): + await agent.run("hello", tools=[local_tool, runtime_mcp]) + + +async def test_agent_run_raises_on_duplicate_agent_mcp_names(chat_client_base: Any) -> None: + agent = Agent( + client=chat_client_base, + name="TestAgent", + tools=[ + _ConnectedMCPTool(name="docs-mcp", function_names=["search"]), + _ConnectedMCPTool(name="github-mcp", function_names=["search"]), + ], + ) + + with raises(ValueError, match="tool_name_prefix"): + await agent.run("hello") + + +async def test_agent_run_accepts_prefixed_mcp_tools(chat_client_base: Any) -> None: + captured_options: list[dict[str, Any]] = [] + + original_inner = chat_client_base._inner_get_response + + async def capturing_inner( + *, messages: MutableSequence[Message], options: dict[str, Any], **kwargs: Any + ) -> ChatResponse: + captured_options.append(dict(options)) + return await original_inner(messages=messages, options=options, **kwargs) + + chat_client_base._inner_get_response = capturing_inner + + local_tool = FunctionTool(func=lambda: "local", name="search", description="Local search tool") + agent = Agent( + client=chat_client_base, + name="TestAgent", + tools=[_ConnectedMCPTool(name="docs-mcp", function_names=["search"], tool_name_prefix="docs")], + ) + + await agent.run("hello", tools=[local_tool]) + + tool_names = [tool.name for tool in captured_options[0]["tools"]] + assert tool_names == ["search", "docs_search"] + + async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> None: - """Verify tool execution receives 'session' inside **kwargs when function is called by client.""" + """Verify legacy **kwargs tools receive the session when agent.run() is called with one.""" captured: dict[str, Any] = {} @@ -1040,7 +1192,6 @@ def echo_session_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUn captured["has_state"] = session.state is not None if isinstance(session, AgentSession) else False return f"echo: {text}" - # Make the base client emit a function call for our tool chat_client_base.run_responses = [ ChatResponse( messages=Message( @@ -1060,17 +1211,52 @@ def echo_session_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUn agent = Agent(client=chat_client_base, tools=[echo_session_info]) session = agent.create_session() - result = await agent.run( - "hello", - session=session, - options={"additional_function_arguments": {"session": session}}, - ) + result = await agent.run("hello", session=session) assert result.text == "done" assert captured.get("has_session") is True assert captured.get("has_state") is True +async def test_agent_tool_receives_explicit_session_via_function_invocation_context_kwargs( + chat_client_base: Any, +) -> None: + """Verify ctx-based tools receive the session via FunctionInvocationContext.session.""" + + captured: dict[str, Any] = {} + + @tool(name="capture_session_context", approval_mode="never_require") + def capture_session_context(text: str, ctx: FunctionInvocationContext) -> str: + captured["session"] = ctx.session + captured["has_state"] = ctx.session.state is not None if isinstance(ctx.session, AgentSession) else False + return f"echo: {text}" + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="1", + name="capture_session_context", + arguments='{"text": "hello"}', + ) + ], + ) + ), + ChatResponse(messages=Message(role="assistant", text="done")), + ] + + agent = Agent(client=chat_client_base, tools=[capture_session_context]) + session = agent.create_session() + + result = await agent.run("hello", session=session) + + assert result.text == "done" + assert captured["session"] is session + assert captured["has_state"] is True + + async def test_chat_agent_tool_choice_run_level_overrides_agent_level(chat_client_base: Any, tool_tool: Any) -> None: """Verify that tool_choice passed to run() overrides agent-level tool_choice.""" @@ -1287,7 +1473,7 @@ def test_merge_options_none_values_ignored(): def test_merge_options_tools_combined(): - """Test _merge_options combines tool lists without duplicates.""" + """Test _merge_options raises when distinct tools share the same name.""" class MockTool: def __init__(self, name): @@ -1300,13 +1486,8 @@ def __init__(self, name): base = {"tools": [tool1]} override = {"tools": [tool2, tool3]} - result = _merge_options(base, override) - - # Should have tool1 and tool2, but not duplicate tool3 - assert len(result["tools"]) == 2 - tool_names = [t.name for t in result["tools"]] - assert "tool1" in tool_names - assert "tool2" in tool_names + with raises(ValueError, match="Duplicate tool name 'tool1'"): + _merge_options(base, override) def test_merge_options_dict_tools_combined(): @@ -1331,7 +1512,7 @@ def test_merge_options_dict_tools_combined(): def test_merge_options_dict_tools_deduplicates(): - """Test _merge_options deduplicates dict-defined tools by function name.""" + """Test _merge_options raises on duplicate dict-defined tool names.""" base = { "tools": [ {"type": "function", "function": {"name": "tool_a"}}, @@ -1344,12 +1525,8 @@ def test_merge_options_dict_tools_deduplicates(): ] } - result = _merge_options(base, override) - - assert len(result["tools"]) == 2 - names = [_get_tool_name(t) for t in result["tools"]] - assert names.count("tool_a") == 1 - assert "tool_b" in names + with raises(ValueError, match="Duplicate tool name 'tool_a'"): + _merge_options(base, override) def test_merge_options_mixed_tools_combined(): @@ -1375,7 +1552,7 @@ def __init__(self, name): def test_merge_options_mixed_tools_deduplicates(): - """Test _merge_options deduplicates when a dict tool and object tool share the same name.""" + """Test _merge_options raises when a dict tool and object tool share the same name.""" class MockTool: def __init__(self, name): @@ -1388,10 +1565,8 @@ def __init__(self, name): ] } - result = _merge_options(base, override) - - assert len(result["tools"]) == 1 - assert _get_tool_name(result["tools"][0]) == "tool_a" + with raises(ValueError, match="Duplicate tool name 'tool_a'"): + _merge_options(base, override) def test_merge_options_nameless_tools_not_deduplicated(): @@ -1413,6 +1588,20 @@ def test_merge_options_nameless_tools_not_deduplicated(): assert len(result["tools"]) == 2 +def test_merge_options_same_tool_object_kept_once(): + """Test _merge_options silently keeps a repeated reference to the same tool object once.""" + + class MockTool: + def __init__(self, name): + self.name = name + + tool_a = MockTool("tool_a") + + result = _merge_options({"tools": [tool_a]}, {"tools": [tool_a]}) + + assert result["tools"] == [tool_a] + + def test_get_tool_name_dict_no_function_key(): """_get_tool_name returns None for a dict without a 'function' key.""" assert _get_tool_name({"type": "function"}) is None @@ -1754,4 +1943,26 @@ async def test_stores_by_default_with_store_false_in_default_options_injects_inm assert any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers) -# endregion +# region as_tool user_input_request propagation + + +async def test_as_tool_raises_on_user_input_request(client: SupportsChatGetResponse) -> None: + """Test that as_tool raises when the wrapped sub-agent requests user input.""" + from agent_framework.exceptions import UserInputRequiredException + + consent_content = Content.from_oauth_consent_request( + consent_link="https://login.microsoftonline.com/consent", + ) + client.streaming_responses = [ # type: ignore[attr-defined] + [ChatResponseUpdate(contents=[consent_content], role="assistant")], + ] + + agent = Agent(client=client, name="OAuthAgent", description="Agent requiring consent") + agent_tool = agent.as_tool() + + with raises(UserInputRequiredException) as exc_info: + await agent_tool.invoke(arguments={"task": "Do something"}) + + assert len(exc_info.value.contents) == 1 + assert exc_info.value.contents[0].type == "oauth_consent_request" + assert exc_info.value.contents[0].consent_link == "https://login.microsoftonline.com/consent" diff --git a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py index da8e907c40f..8aa71a45822 100644 --- a/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py +++ b/python/packages/core/tests/core/test_as_tool_kwargs_propagation.py @@ -6,7 +6,7 @@ from typing import Any from agent_framework import Agent, ChatResponse, Content, Message, agent_middleware -from agent_framework._middleware import AgentContext +from agent_framework._middleware import AgentContext, FunctionInvocationContext from .conftest import MockChatClient @@ -14,14 +14,28 @@ class TestAsToolKwargsPropagation: """Test cases for kwargs propagation through as_tool() delegation.""" + @staticmethod + def _build_context( + tool: Any, + *, + task: str, + runtime_kwargs: dict[str, Any] | None = None, + ) -> FunctionInvocationContext: + return FunctionInvocationContext( + function=tool, + arguments={"task": task}, + kwargs=runtime_kwargs, + ) + async def test_as_tool_forwards_runtime_kwargs(self, client: MockChatClient) -> None: - """Test that runtime kwargs are forwarded through as_tool() to sub-agent.""" + """Test that runtime kwargs are forwarded through as_tool() to sub-agent tools.""" captured_kwargs: dict[str, Any] = {} + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: - # Capture kwargs passed to the sub-agent captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock response @@ -39,29 +53,31 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Create tool from sub-agent tool = sub_agent.as_tool(name="delegate", arg_name="task") - # Directly invoke the tool with kwargs (simulating what happens during agent execution) + # Directly invoke the tool with explicit runtime context (simulating agent execution). _ = await tool.invoke( - arguments=tool.input_model(task="Test delegation"), - api_token="secret-xyz-123", - user_id="user-456", - session_id="session-789", + context=self._build_context( + tool, + task="Test delegation", + runtime_kwargs={ + "api_token": "secret-xyz-123", + "user_id": "user-456", + "session_id": "session-789", + }, + ), ) - # Verify kwargs were forwarded to sub-agent - assert "api_token" in captured_kwargs, f"Expected 'api_token' in {captured_kwargs}" - assert captured_kwargs["api_token"] == "secret-xyz-123" - assert "user_id" in captured_kwargs - assert captured_kwargs["user_id"] == "user-456" - assert "session_id" in captured_kwargs - assert captured_kwargs["session_id"] == "session-789" + assert captured_kwargs == {} + assert captured_function_invocation_kwargs["api_token"] == "secret-xyz-123" + assert captured_function_invocation_kwargs["user_id"] == "user-456" + assert captured_function_invocation_kwargs["session_id"] == "session-789" - async def test_as_tool_excludes_arg_name_from_forwarded_kwargs(self, client: MockChatClient) -> None: - """Test that the arg_name parameter is not forwarded as a kwarg.""" - captured_kwargs: dict[str, Any] = {} + async def test_as_tool_forwards_context_kwargs_verbatim(self, client: MockChatClient) -> None: + """Test that runtime kwargs are forwarded exactly from FunctionInvocationContext.kwargs.""" + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: - captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock response @@ -79,25 +95,26 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Invoke tool with both the arg_name field and additional kwargs await tool.invoke( - arguments=tool.input_model(custom_task="Test task"), - api_token="token-123", - custom_task="should_be_excluded", # This should be filtered out + context=FunctionInvocationContext( + function=tool, + arguments={"custom_task": "Test task"}, + kwargs={ + "api_token": "token-123", + "custom_task": "should_be_excluded", + }, + ) ) - # The arg_name ("custom_task") should NOT be in the forwarded kwargs - assert "custom_task" not in captured_kwargs - # But other kwargs should be present - assert "api_token" in captured_kwargs - assert captured_kwargs["api_token"] == "token-123" + assert captured_function_invocation_kwargs["custom_task"] == "should_be_excluded" + assert captured_function_invocation_kwargs["api_token"] == "token-123" async def test_as_tool_nested_delegation_propagates_kwargs(self, client: MockChatClient) -> None: - """Test that kwargs propagate through multiple levels of delegation (A → B → C).""" - captured_kwargs_list: list[dict[str, Any]] = [] + """Test that runtime kwargs propagate through multiple levels of delegation (A -> B -> C).""" + captured_function_invocation_kwargs_list: list[dict[str, Any]] = [] @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: - # Capture kwargs at each level - captured_kwargs_list.append(dict(context.kwargs)) + captured_function_invocation_kwargs_list.append(dict(context.function_invocation_kwargs)) await call_next() # Setup mock responses to trigger nested tool invocation: B calls tool C, then completes. @@ -140,24 +157,29 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Invoke tool B with kwargs - should propagate to both B and C await tool_b.invoke( - arguments=tool_b.input_model(task="Test cascade"), - trace_id="trace-abc-123", - tenant_id="tenant-xyz", - options={"additional_function_arguments": {"trace_id": "trace-abc-123", "tenant_id": "tenant-xyz"}}, + context=self._build_context( + tool_b, + task="Test cascade", + runtime_kwargs={ + "trace_id": "trace-abc-123", + "tenant_id": "tenant-xyz", + }, + ), ) - # Verify kwargs were forwarded to the first agent invocation. - assert len(captured_kwargs_list) >= 1 - assert captured_kwargs_list[0].get("trace_id") == "trace-abc-123" - assert captured_kwargs_list[0].get("tenant_id") == "tenant-xyz" + assert len(captured_function_invocation_kwargs_list) >= 1 + assert captured_function_invocation_kwargs_list[0].get("trace_id") == "trace-abc-123" + assert captured_function_invocation_kwargs_list[0].get("tenant_id") == "tenant-xyz" async def test_as_tool_streaming_mode_forwards_kwargs(self, client: MockChatClient) -> None: - """Test that kwargs are forwarded in streaming mode.""" + """Test that runtime kwargs are forwarded in streaming mode.""" captured_kwargs: dict[str, Any] = {} + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock streaming responses @@ -182,13 +204,15 @@ async def stream_callback(update: Any) -> None: # Invoke tool with kwargs while streaming callback is active await tool.invoke( - arguments=tool.input_model(task="Test streaming"), - api_key="streaming-key-999", + context=self._build_context( + tool, + task="Test streaming", + runtime_kwargs={"api_key": "streaming-key-999"}, + ), ) - # Verify kwargs were forwarded even in streaming mode - assert "api_key" in captured_kwargs - assert captured_kwargs["api_key"] == "streaming-key-999" + assert captured_kwargs == {} + assert captured_function_invocation_kwargs["api_key"] == "streaming-key-999" assert len(captured_updates) == 1 async def test_as_tool_empty_kwargs_still_works(self, client: MockChatClient) -> None: @@ -206,18 +230,20 @@ async def test_as_tool_empty_kwargs_still_works(self, client: MockChatClient) -> tool = sub_agent.as_tool() # Invoke without any extra kwargs - should work without errors - result = await tool.invoke(arguments=tool.input_model(task="Simple task")) + result = await tool.invoke(arguments={"task": "Simple task"}) # Verify tool executed successfully assert result is not None async def test_as_tool_kwargs_with_chat_options(self, client: MockChatClient) -> None: - """Test that kwargs including chat_options are properly forwarded.""" + """Test that runtime kwargs are forwarded only via function_invocation_kwargs.""" captured_kwargs: dict[str, Any] = {} + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock response @@ -235,24 +261,26 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Invoke with various kwargs await tool.invoke( - arguments=tool.input_model(task="Test with options"), - temperature=0.8, - max_tokens=500, - custom_param="custom_value", + context=self._build_context( + tool, + task="Test with options", + runtime_kwargs={ + "temperature": 0.8, + "max_tokens": 500, + "custom_param": "custom_value", + }, + ), ) - # Verify all kwargs were forwarded - assert "temperature" in captured_kwargs - assert captured_kwargs["temperature"] == 0.8 - assert "max_tokens" in captured_kwargs - assert captured_kwargs["max_tokens"] == 500 - assert "custom_param" in captured_kwargs - assert captured_kwargs["custom_param"] == "custom_value" + assert captured_kwargs == {} + assert captured_function_invocation_kwargs["temperature"] == 0.8 + assert captured_function_invocation_kwargs["max_tokens"] == 500 + assert captured_function_invocation_kwargs["custom_param"] == "custom_value" async def test_as_tool_kwargs_isolated_per_invocation(self, client: MockChatClient) -> None: - """Test that kwargs are isolated per invocation and don't leak between calls.""" - first_call_kwargs: dict[str, Any] = {} - second_call_kwargs: dict[str, Any] = {} + """Test that runtime kwargs are isolated per invocation and don't leak between calls.""" + first_call_function_invocation_kwargs: dict[str, Any] = {} + second_call_function_invocation_kwargs: dict[str, Any] = {} call_count = 0 @agent_middleware @@ -260,9 +288,9 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai nonlocal call_count call_count += 1 if call_count == 1: - first_call_kwargs.update(context.kwargs) + first_call_function_invocation_kwargs.update(context.function_invocation_kwargs) elif call_count == 2: - second_call_kwargs.update(context.kwargs) + second_call_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock responses for both calls @@ -281,33 +309,35 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # First call with specific kwargs await tool.invoke( - arguments=tool.input_model(task="First task"), - session_id="session-1", - api_token="token-1", + context=self._build_context( + tool, + task="First task", + runtime_kwargs={"session_id": "session-1", "api_token": "token-1"}, + ), ) # Second call with different kwargs await tool.invoke( - arguments=tool.input_model(task="Second task"), - session_id="session-2", - api_token="token-2", + context=self._build_context( + tool, + task="Second task", + runtime_kwargs={"session_id": "session-2", "api_token": "token-2"}, + ), ) - # Verify first call had its own kwargs - assert first_call_kwargs.get("session_id") == "session-1" - assert first_call_kwargs.get("api_token") == "token-1" + assert first_call_function_invocation_kwargs.get("session_id") == "session-1" + assert first_call_function_invocation_kwargs.get("api_token") == "token-1" - # Verify second call had its own kwargs (not leaked from first) - assert second_call_kwargs.get("session_id") == "session-2" - assert second_call_kwargs.get("api_token") == "token-2" + assert second_call_function_invocation_kwargs.get("session_id") == "session-2" + assert second_call_function_invocation_kwargs.get("api_token") == "token-2" - async def test_as_tool_excludes_conversation_id_from_forwarded_kwargs(self, client: MockChatClient) -> None: - """Test that conversation_id is not forwarded to sub-agent.""" - captured_kwargs: dict[str, Any] = {} + async def test_as_tool_forwards_conversation_id_from_context_kwargs(self, client: MockChatClient) -> None: + """Test that conversation_id is forwarded when explicitly present in runtime context kwargs.""" + captured_function_invocation_kwargs: dict[str, Any] = {} @agent_middleware async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None: - captured_kwargs.update(context.kwargs) + captured_function_invocation_kwargs.update(context.function_invocation_kwargs) await call_next() # Setup mock response @@ -325,17 +355,17 @@ async def capture_middleware(context: AgentContext, call_next: Callable[[], Awai # Invoke tool with conversation_id in kwargs (simulating parent's conversation state) await tool.invoke( - arguments=tool.input_model(task="Test delegation"), - conversation_id="conv-parent-456", - api_token="secret-xyz-123", - user_id="user-456", - ) - - # Verify conversation_id was NOT forwarded to sub-agent - assert "conversation_id" not in captured_kwargs, ( - f"conversation_id should not be forwarded, but got: {captured_kwargs}" + context=self._build_context( + tool, + task="Test delegation", + runtime_kwargs={ + "conversation_id": "conv-parent-456", + "api_token": "secret-xyz-123", + "user_id": "user-456", + }, + ), ) - # Verify other kwargs were still forwarded - assert captured_kwargs.get("api_token") == "secret-xyz-123" - assert captured_kwargs.get("user_id") == "user-456" + assert captured_function_invocation_kwargs.get("conversation_id") == "conv-parent-456" + assert captured_function_invocation_kwargs.get("api_token") == "secret-xyz-123" + assert captured_function_invocation_kwargs.get("user_id") == "user-456" diff --git a/python/packages/core/tests/core/test_clients.py b/python/packages/core/tests/core/test_clients.py index b060b183fb4..7e150c47c68 100644 --- a/python/packages/core/tests/core/test_clients.py +++ b/python/packages/core/tests/core/test_clients.py @@ -1,9 +1,12 @@ # Copyright (c) Microsoft. All rights reserved. +import inspect from typing import Any from unittest.mock import patch +import pytest + from agent_framework import ( GROUP_ANNOTATION_KEY, GROUP_TOKEN_COUNT_KEY, @@ -50,6 +53,60 @@ def test_base_client(chat_client_base: SupportsChatGetResponse): assert isinstance(chat_client_base, SupportsChatGetResponse) +def test_base_client_warns_for_direct_additional_properties(chat_client_base: SupportsChatGetResponse) -> None: + with pytest.warns(DeprecationWarning, match="additional_properties"): + client = type(chat_client_base)(legacy_key="legacy-value") + + assert client.additional_properties["legacy_key"] == "legacy-value" + + +def test_base_client_as_agent_uses_explicit_additional_properties(chat_client_base: SupportsChatGetResponse) -> None: + agent = chat_client_base.as_agent(additional_properties={"team": "core"}) + + assert agent.additional_properties == {"team": "core"} + + +def test_openai_chat_client_get_response_docstring_surfaces_layered_runtime_docs() -> None: + from agent_framework.openai import OpenAIChatClient + + docstring = inspect.getdoc(OpenAIChatClient.get_response) + + assert docstring is not None + assert "Get a response from a chat client." in docstring + assert "function_invocation_kwargs" in docstring + assert "function_middleware: Optional per-call function middleware." in docstring + assert "middleware: Optional per-call chat and function middleware." in docstring + + +def test_openai_chat_client_get_response_is_defined_on_openai_class() -> None: + from agent_framework.openai import OpenAIChatClient + + signature = inspect.signature(OpenAIChatClient.get_response) + + assert OpenAIChatClient.get_response.__qualname__ == "OpenAIChatClient.get_response" + assert "function_middleware" in signature.parameters + assert "middleware" in signature.parameters + + +async def test_base_client_get_response_uses_explicit_client_kwargs(chat_client_base: SupportsChatGetResponse) -> None: + async def fake_inner_get_response(**kwargs): + assert kwargs["trace_id"] == "trace-123" + assert "function_invocation_kwargs" not in kwargs + return ChatResponse(messages=[Message(role="assistant", text="ok")]) + + with patch.object( + chat_client_base, + "_inner_get_response", + side_effect=fake_inner_get_response, + ) as mock_inner_get_response: + await chat_client_base.get_response( + [Message(role="user", text="hello")], + function_invocation_kwargs={"tool_request_id": "tool-123"}, + client_kwargs={"trace_id": "trace-123"}, + ) + mock_inner_get_response.assert_called_once() + + async def test_base_client_get_response(chat_client_base: SupportsChatGetResponse): response = await chat_client_base.get_response([Message(role="user", text="Hello")]) assert response.messages[0].role == "assistant" diff --git a/python/packages/core/tests/core/test_docstrings.py b/python/packages/core/tests/core/test_docstrings.py new file mode 100644 index 00000000000..ab4b1164220 --- /dev/null +++ b/python/packages/core/tests/core/test_docstrings.py @@ -0,0 +1,175 @@ +# Copyright (c) Microsoft. All rights reserved. + +from agent_framework._docstrings import apply_layered_docstring, build_layered_docstring + +# -- Helpers: stub functions with various docstring shapes -- + + +def _source_with_full_docstring(x: int) -> int: + """Do something useful. + + Args: + x: The input value. + + Keyword Args: + timeout: Max seconds to wait. + + Returns: + The computed result. + """ + return x + + +def _source_with_args_only(x: int) -> int: + """Do something useful. + + Args: + x: The input value. + + Returns: + The computed result. + """ + return x + + +def _source_no_sections() -> None: + """A plain summary with no Google-style sections.""" + + +def _source_no_docstring() -> None: + pass + + +def _target_stub() -> None: + pass + + +# -- build_layered_docstring tests -- + + +def test_build_returns_none_when_source_has_no_docstring() -> None: + result = build_layered_docstring(_source_no_docstring) + assert result is None + + +def test_build_returns_original_when_no_extra_kwargs() -> None: + result = build_layered_docstring(_source_with_full_docstring) + assert result is not None + assert "Do something useful." in result + assert "Keyword Args:" in result + + +def test_build_returns_original_when_extra_kwargs_empty() -> None: + result = build_layered_docstring(_source_with_full_docstring, extra_keyword_args={}) + assert result is not None + assert result == build_layered_docstring(_source_with_full_docstring) + + +def test_build_appends_to_existing_keyword_args_section() -> None: + result = build_layered_docstring( + _source_with_full_docstring, + extra_keyword_args={"retries": "Number of retries."}, + ) + assert result is not None + assert "timeout: Max seconds to wait." in result + assert "retries: Number of retries." in result + # Both should be under Keyword Args + lines = result.splitlines() + kw_index = next(i for i, line in enumerate(lines) if line == "Keyword Args:") + ret_index = next(i for i, line in enumerate(lines) if line == "Returns:") + retries_index = next(i for i, line in enumerate(lines) if "retries:" in line) + assert kw_index < retries_index < ret_index + + +def test_build_inserts_keyword_args_after_args_section() -> None: + result = build_layered_docstring( + _source_with_args_only, + extra_keyword_args={"verbose": "Enable verbose output."}, + ) + assert result is not None + assert "Keyword Args:" in result + assert "verbose: Enable verbose output." in result + lines = result.splitlines() + args_index = next(i for i, line in enumerate(lines) if line == "Args:") + kw_index = next(i for i, line in enumerate(lines) if line == "Keyword Args:") + ret_index = next(i for i, line in enumerate(lines) if line == "Returns:") + assert args_index < kw_index < ret_index + + +def test_build_inserts_keyword_args_in_docstring_with_no_sections() -> None: + result = build_layered_docstring( + _source_no_sections, + extra_keyword_args={"debug": "Enable debug mode."}, + ) + assert result is not None + assert "A plain summary" in result + assert "Keyword Args:" in result + assert "debug: Enable debug mode." in result + + +def test_build_handles_multiline_descriptions() -> None: + result = build_layered_docstring( + _source_with_args_only, + extra_keyword_args={ + "config": "The configuration object.\nMust be a valid mapping.\nDefaults to empty.", + }, + ) + assert result is not None + lines = result.splitlines() + config_line = next(line for line in lines if "config:" in line) + assert "The configuration object." in config_line + # Continuation lines should be indented + config_idx = lines.index(config_line) + assert "Must be a valid mapping." in lines[config_idx + 1] + assert "Defaults to empty." in lines[config_idx + 2] + + +def test_build_preserves_multiple_extra_kwargs_order() -> None: + result = build_layered_docstring( + _source_with_args_only, + extra_keyword_args={ + "alpha": "First.", + "beta": "Second.", + "gamma": "Third.", + }, + ) + assert result is not None + lines = result.splitlines() + alpha_idx = next(i for i, line in enumerate(lines) if "alpha:" in line) + beta_idx = next(i for i, line in enumerate(lines) if "beta:" in line) + gamma_idx = next(i for i, line in enumerate(lines) if "gamma:" in line) + assert alpha_idx < beta_idx < gamma_idx + + +# -- apply_layered_docstring tests -- + + +def test_apply_sets_docstring_on_target() -> None: + def target() -> None: + pass + + apply_layered_docstring(target, _source_with_full_docstring) + assert target.__doc__ is not None + assert "Do something useful." in target.__doc__ + + +def test_apply_with_extra_kwargs() -> None: + def target() -> None: + pass + + apply_layered_docstring( + target, + _source_with_args_only, + extra_keyword_args={"flag": "A boolean flag."}, + ) + assert target.__doc__ is not None + assert "flag: A boolean flag." in target.__doc__ + assert "Keyword Args:" in target.__doc__ + + +def test_apply_sets_none_when_source_has_no_docstring() -> None: + def target() -> None: + """Original.""" + + apply_layered_docstring(target, _source_no_docstring) + assert target.__doc__ is None diff --git a/python/packages/core/tests/core/test_embedding_client.py b/python/packages/core/tests/core/test_embedding_client.py index 71d2bcfd70e..1c49c1d0129 100644 --- a/python/packages/core/tests/core/test_embedding_client.py +++ b/python/packages/core/tests/core/test_embedding_client.py @@ -4,6 +4,8 @@ from collections.abc import Sequence +import pytest + from agent_framework import ( BaseEmbeddingClient, Embedding, @@ -63,6 +65,11 @@ def test_base_additional_properties_custom() -> None: assert client.additional_properties == {"key": "value"} +def test_base_embedding_client_rejects_unknown_kwargs() -> None: + with pytest.raises(TypeError): + MockEmbeddingClient(legacy_key="value") # type: ignore[call-arg] + + # --- SupportsGetEmbeddings protocol tests --- diff --git a/python/packages/core/tests/core/test_function_invocation_logic.py b/python/packages/core/tests/core/test_function_invocation_logic.py index 59c932f9465..3c610402891 100644 --- a/python/packages/core/tests/core/test_function_invocation_logic.py +++ b/python/packages/core/tests/core/test_function_invocation_logic.py @@ -3651,3 +3651,131 @@ def test_dict_overwrites_existing_conversation_id(self): # endregion +async def test_user_input_request_propagates_through_as_tool(chat_client_base: SupportsChatGetResponse): + """Test that user_input_request content from a sub-agent wrapped as a tool propagates to the parent response.""" + from agent_framework.exceptions import UserInputRequiredException + + @tool(name="delegate_agent", approval_mode="never_require") + def delegate_tool(task: str) -> str: + del task + raise UserInputRequiredException( + contents=[ + Content.from_oauth_consent_request( + consent_link="https://login.microsoftonline.com/consent", + ) + ] + ) + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="1", name="delegate_agent", arguments='{"task": "do it"}'), + ], + ) + ) + ] + + response = await chat_client_base.get_response( + [Message(role="user", text="delegate this")], + options={"tool_choice": "auto", "tools": [delegate_tool]}, + ) + + user_requests = [ + content + for msg in response.messages + for content in msg.contents + if isinstance(content, Content) and content.user_input_request + ] + assert len(user_requests) == 1 + assert user_requests[0].type == "oauth_consent_request" + assert user_requests[0].consent_link == "https://login.microsoftonline.com/consent" + assert user_requests[0].user_input_request is True + + +async def test_user_input_request_multiple_contents_propagate(chat_client_base: SupportsChatGetResponse): + """Test that multiple user_input_request items in a single exception all propagate to the parent response.""" + from agent_framework.exceptions import UserInputRequiredException + + @tool(name="multi_request_tool", approval_mode="never_require") + def multi_request(task: str) -> str: + del task + raise UserInputRequiredException( + contents=[ + Content.from_oauth_consent_request( + consent_link="https://example.com/consent1", + ), + Content.from_oauth_consent_request( + consent_link="https://example.com/consent2", + ), + Content.from_oauth_consent_request( + consent_link="https://example.com/consent3", + ), + ] + ) + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="1", name="multi_request_tool", arguments='{"task": "do it"}'), + ], + ) + ) + ] + + response = await chat_client_base.get_response( + [Message(role="user", text="do something")], + options={"tool_choice": "auto", "tools": [multi_request]}, + ) + + user_requests = [ + content + for msg in response.messages + for content in msg.contents + if isinstance(content, Content) and content.user_input_request + ] + assert len(user_requests) == 3 + consent_links = {r.consent_link for r in user_requests} + assert consent_links == { + "https://example.com/consent1", + "https://example.com/consent2", + "https://example.com/consent3", + } + + +async def test_user_input_request_empty_contents_returns_fallback(chat_client_base: SupportsChatGetResponse): + """Test that UserInputRequiredException with empty contents produces a fallback function_result.""" + from agent_framework.exceptions import UserInputRequiredException + + @tool(name="empty_request_tool", approval_mode="never_require") + def empty_request(task: str) -> str: + del task + raise UserInputRequiredException(contents=[]) + + chat_client_base.run_responses = [ + ChatResponse( + messages=Message( + role="assistant", + contents=[ + Content.from_function_call(call_id="1", name="empty_request_tool", arguments='{"task": "do it"}'), + ], + ) + ), + ChatResponse(messages=Message(role="assistant", text="handled")), + ] + + response = await chat_client_base.get_response( + [Message(role="user", text="do something")], + options={"tool_choice": "auto", "tools": [empty_request]}, + ) + + # With empty contents, the handler returns a function_result with an error message + # and the loop continues to the next chat response. + function_results = [ + content for msg in response.messages for content in msg.contents if content.type == "function_result" + ] + assert len(function_results) >= 1 + assert any("user input" in (fr.result or "").lower() for fr in function_results) diff --git a/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py b/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py index cecd466d864..160ea0fcc40 100644 --- a/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py +++ b/python/packages/core/tests/core/test_kwargs_propagation_to_ai_function.py @@ -6,11 +6,13 @@ from typing import Any from agent_framework import ( + Agent, BaseChatClient, ChatMiddlewareLayer, ChatResponse, ChatResponseUpdate, Content, + FunctionInvocationContext, FunctionInvocationLayer, Message, ResponseStream, @@ -97,6 +99,7 @@ class TestKwargsPropagationToFunctionTool: async def test_kwargs_propagate_to_tool_with_kwargs(self) -> None: """Test that kwargs passed to get_response() are available in @tool **kwargs.""" + # TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed. captured_kwargs: dict[str, Any] = {} @tool(approval_mode="never_require") @@ -149,6 +152,7 @@ def capture_kwargs_tool(x: int, **kwargs: Any) -> str: async def test_kwargs_not_forwarded_to_tool_without_kwargs(self) -> None: """Test that kwargs are NOT forwarded to @tool that doesn't accept **kwargs.""" + # TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed. @tool(approval_mode="never_require") def simple_tool(x: int) -> str: @@ -185,6 +189,7 @@ def simple_tool(x: int) -> str: async def test_kwargs_isolated_between_function_calls(self) -> None: """Test that kwargs are consistent across multiple function call invocations.""" + # TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed. invocation_kwargs: list[dict[str, Any]] = [] @tool(approval_mode="never_require") @@ -235,6 +240,7 @@ def tracking_tool(name: str, **kwargs: Any) -> str: async def test_streaming_response_kwargs_propagation(self) -> None: """Test that kwargs propagate to @tool in streaming mode.""" + # TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed. captured_kwargs: dict[str, Any] = {} @tool(approval_mode="never_require") @@ -287,3 +293,59 @@ def streaming_capture_tool(value: str, **kwargs: Any) -> str: assert "streaming_session" in captured_kwargs, f"Expected 'streaming_session' in {captured_kwargs}" assert captured_kwargs["streaming_session"] == "session-xyz" assert captured_kwargs["correlation_id"] == "corr-123" + + async def test_agent_run_injects_function_invocation_context(self) -> None: + """Test that Agent.run injects FunctionInvocationContext for ctx-based tools.""" + captured_context_kwargs: dict[str, Any] = {} + captured_client_kwargs: dict[str, Any] = {} + captured_options: dict[str, Any] = {} + + @tool(approval_mode="never_require") + def capture_context_tool(x: int, ctx: FunctionInvocationContext) -> str: + captured_context_kwargs.update(ctx.kwargs) + return f"result: x={x}" + + class CapturingFunctionInvokingMockClient(FunctionInvokingMockClient): + async def _get_non_streaming_response( + self, + *, + messages: MutableSequence[Message], + options: dict[str, Any], + **kwargs: Any, + ) -> ChatResponse: + captured_options.update(options) + captured_client_kwargs.update(kwargs) + return await super()._get_non_streaming_response(messages=messages, options=options, **kwargs) + + client = CapturingFunctionInvokingMockClient() + client.run_responses = [ + ChatResponse( + messages=[ + Message( + role="assistant", + contents=[ + Content.from_function_call( + call_id="call_1", + name="capture_context_tool", + arguments='{"x": 42}', + ) + ], + ) + ] + ), + ChatResponse(messages=[Message(role="assistant", text="Done!")]), + ] + + agent = Agent(client=client, tools=[capture_context_tool]) + result = await agent.run( + [Message(role="user", text="Test")], + function_invocation_kwargs={"tool_request_id": "tool-123"}, + client_kwargs={"client_request_id": "client-456"}, + ) + + assert captured_context_kwargs["tool_request_id"] == "tool-123" + assert "client_request_id" not in captured_context_kwargs + assert captured_client_kwargs["client_request_id"] == "client-456" + assert "tool_request_id" not in captured_client_kwargs + assert "additional_function_arguments" not in captured_options + assert result.messages[-1].text == "Done!" diff --git a/python/packages/core/tests/core/test_mcp.py b/python/packages/core/tests/core/test_mcp.py index 867e7183cfb..df3187673a3 100644 --- a/python/packages/core/tests/core/test_mcp.py +++ b/python/packages/core/tests/core/test_mcp.py @@ -53,6 +53,81 @@ def test_normalize_mcp_name(): assert _normalize_mcp_name("name/with\\slashes") == "name-with-slashes" +def test_mcp_transport_subclasses_accept_tool_name_prefix() -> None: + assert MCPStdioTool(name="stdio", command="python", tool_name_prefix="stdio").tool_name_prefix == "stdio" + assert ( + MCPStreamableHTTPTool( + name="http", + url="https://example.com/mcp", + tool_name_prefix="http", + ).tool_name_prefix + == "http" + ) + assert ( + MCPWebsocketTool( + name="ws", + url="wss://example.com/mcp", + tool_name_prefix="ws", + ).tool_name_prefix + == "ws" + ) + + +async def test_load_tools_with_tool_name_prefix_preserves_matching_configuration(): + """Prefixed MCP tool names should still honor unprefixed allow/approval configuration.""" + tool = MCPTool( + name="docs", + tool_name_prefix="docs", + allowed_tools=["search_docs"], + approval_mode={"always_require_approval": ["search_docs"]}, + ) + + mock_session = AsyncMock() + tool.session = mock_session + tool.load_tools_flag = True + + page = Mock() + page.tools = [ + types.Tool( + name="search_docs", + description="Search docs", + inputSchema={"type": "object", "properties": {"query": {"type": "string"}}}, + ), + ] + page.nextCursor = None + mock_session.list_tools = AsyncMock(return_value=page) + + await tool.load_tools() + + assert [function.name for function in tool._functions] == ["docs_search_docs"] + assert [function.name for function in tool.functions] == ["docs_search_docs"] + assert tool.functions[0].approval_mode == "always_require" + + +async def test_load_prompts_with_tool_name_prefix() -> None: + """Prefixed MCP prompt names should be exposed with the configured prefix.""" + tool = MCPTool(name="docs", tool_name_prefix="docs") + + mock_session = AsyncMock() + tool.session = mock_session + tool.load_prompts_flag = True + + page = Mock() + page.prompts = [ + types.Prompt( + name="summarize docs", + description="Summarize docs", + arguments=[types.PromptArgument(name="topic", description="Topic", required=True)], + ), + ] + page.nextCursor = None + mock_session.list_prompts = AsyncMock(return_value=page) + + await tool.load_prompts() + + assert [function.name for function in tool._functions] == ["docs_summarize-docs"] + + def test_mcp_prompt_message_to_ai_content(): """Test conversion from MCP prompt message to AI content.""" mcp_message = types.PromptMessage(role="user", content=types.TextContent(type="text", text="Hello, world!")) @@ -67,30 +142,31 @@ def test_mcp_prompt_message_to_ai_content(): def test_parse_tool_result_from_mcp(): - """Test conversion from MCP tool result to string representation.""" + """Test conversion from MCP tool result with images preserves original order.""" mcp_result = types.CallToolResult( content=[ types.TextContent(type="text", text="Result text"), types.ImageContent(type="image", data="eHl6", mimeType="image/png"), + types.TextContent(type="text", text="After image"), types.ImageContent(type="image", data="YWJj", mimeType="image/webp"), ] ) result = _parse_tool_result_from_mcp(mcp_result) - # Multiple items produce a JSON array of strings - assert isinstance(result, str) - import json - - parsed = json.loads(result) - assert len(parsed) == 3 - assert parsed[0] == "Result text" - # Image items are JSON-encoded strings within the array - img1 = json.loads(parsed[1]) - assert img1["type"] == "image" - assert img1["data"] == "eHl6" - img2 = json.loads(parsed[2]) - assert img2["type"] == "image" - assert img2["data"] == "YWJj" + # Results with images return a list of Content objects in original order + assert isinstance(result, list) + assert len(result) == 4 + # Order is preserved: text, image, text, image + assert result[0].type == "text" + assert result[0].text == "Result text" + assert result[1].type == "data" + assert result[1].media_type == "image/png" + assert "eHl6" in result[1].uri + assert result[2].type == "text" + assert result[2].text == "After image" + assert result[3].type == "data" + assert result[3].media_type == "image/webp" + assert "YWJj" in result[3].uri def test_parse_tool_result_from_mcp_single_text(): @@ -98,26 +174,73 @@ def test_parse_tool_result_from_mcp_single_text(): mcp_result = types.CallToolResult(content=[types.TextContent(type="text", text="Simple result")]) result = _parse_tool_result_from_mcp(mcp_result) - # Single text item returns just the text - assert result == "Simple result" + # Single text item returns list with one text Content + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].type == "text" + assert result[0].text == "Simple result" def test_parse_tool_result_from_mcp_meta_not_in_string(): - """Test that _meta data is not included in the string result (it's tool-level, not content-level).""" + """Test that _meta data is not included in the result (it's tool-level, not content-level).""" mcp_result = types.CallToolResult( content=[types.TextContent(type="text", text="Error occurred")], _meta={"isError": True, "errorCode": "TOOL_ERROR"}, ) result = _parse_tool_result_from_mcp(mcp_result) - assert result == "Error occurred" + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].text == "Error occurred" def test_parse_tool_result_from_mcp_empty_content(): - """Test that empty content produces empty string.""" + """Test that empty content produces list with empty text Content.""" mcp_result = types.CallToolResult(content=[]) result = _parse_tool_result_from_mcp(mcp_result) - assert result == "" + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].type == "text" + assert result[0].text == "" + + +def test_parse_tool_result_from_mcp_audio_content(): + """Test conversion from MCP tool result with audio returns rich content list.""" + mcp_result = types.CallToolResult( + content=[ + types.AudioContent(type="audio", data="YXVkaW8=", mimeType="audio/wav"), + ] + ) + result = _parse_tool_result_from_mcp(mcp_result) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].type == "data" + assert result[0].media_type == "audio/wav" + assert "YXVkaW8=" in result[0].uri + + +def test_parse_tool_result_from_mcp_blob_plain_base64(): + """Test that plain base64 blob (without data: prefix) is wrapped into a data URI.""" + mcp_result = types.CallToolResult( + content=[ + types.EmbeddedResource( + type="resource", + resource=types.BlobResourceContents( + uri=AnyUrl("file://test.bin"), + mimeType="application/pdf", + blob="dGVzdCBkYXRh", + ), + ), + ] + ) + result = _parse_tool_result_from_mcp(mcp_result) + + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].type == "data" + assert result[0].media_type == "application/pdf" + assert "dGVzdCBkYXRh" in result[0].uri def test_mcp_content_types_to_ai_content_text(): @@ -769,7 +892,10 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: func = server.functions[0] result = await func.invoke(param="test_value") - assert result == "Tool executed with metadata" + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].type == "text" + assert result[0].text == "Tool executed with metadata" async def test_local_mcp_server_function_execution(): @@ -808,7 +934,8 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: func = server.functions[0] result = await func.invoke(param="test_value") - assert result == "Tool executed successfully" + assert isinstance(result, list) + assert result[0].text == "Tool executed successfully" async def test_local_mcp_server_function_execution_with_nested_object(): @@ -855,7 +982,8 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: # Call with nested object result = await func.invoke(params={"customer_id": 251}) - assert result == '{"name": "John Doe", "id": 251}' + assert isinstance(result, list) + assert result[0].text == '{"name": "John Doe", "id": 251}' # Verify the session.call_tool was called with the correct nested structure server.session.call_tool.assert_called_once() @@ -977,7 +1105,8 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: await server.load_tools() func = server.functions[0] result = await func.invoke(param="test_value") - assert result == "Success" + assert isinstance(result, list) + assert result[0].text == "Success" async def test_mcp_tool_is_error_propagates_through_function_middleware(): @@ -1080,7 +1209,8 @@ def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]: prompt = server.functions[0] result = await prompt.invoke(arg="test_value") - assert result == "Test message" + assert isinstance(result, list) + assert result[0].text == "Test message" @pytest.mark.parametrize( diff --git a/python/packages/core/tests/core/test_observability.py b/python/packages/core/tests/core/test_observability.py index 838efcb0e98..7f82457c3d3 100644 --- a/python/packages/core/tests/core/test_observability.py +++ b/python/packages/core/tests/core/test_observability.py @@ -11,6 +11,7 @@ from agent_framework import ( AGENT_FRAMEWORK_USER_AGENT, + Agent, AgentResponse, BaseChatClient, ChatResponse, @@ -472,10 +473,10 @@ class MockChatClientAgent(AgentTelemetryLayer, _MockChatClientAgent): @pytest.mark.parametrize("enable_sensitive_data", [True, False], indirect=True) -async def test_agent_instrumentation_enabled( +async def test_agent_span_captures_response_telemetry_without_inner_chat_span( mock_chat_agent: SupportsAgentRun, span_exporter: InMemorySpanExporter, enable_sensitive_data ): - """Test that when agent diagnostics are enabled, telemetry is applied.""" + """Agent spans should retain response telemetry when no inner chat span owns it.""" agent = mock_chat_agent() @@ -491,6 +492,7 @@ async def test_agent_instrumentation_enabled( assert span.attributes[OtelAttr.AGENT_NAME] == "test_agent" assert span.attributes[OtelAttr.AGENT_DESCRIPTION] == "Test agent description" assert span.attributes[OtelAttr.REQUEST_MODEL] == "TestModel" + assert span.attributes[OtelAttr.RESPONSE_ID] == "test_response_id" assert span.attributes[OtelAttr.INPUT_TOKENS] == 15 assert span.attributes[OtelAttr.OUTPUT_TOKENS] == 25 if enable_sensitive_data: @@ -1433,6 +1435,24 @@ def test_get_response_attributes_capture_usage_false(): assert OtelAttr.OUTPUT_TOKENS not in result +def test_get_response_attributes_capture_response_id_false(): + """Test _get_response_attributes skips response_id when capture_response_id is False.""" + from unittest.mock import Mock + + from agent_framework.observability import OtelAttr, _get_response_attributes + + response = Mock() + response.response_id = "resp_123" + response.finish_reason = None + response.raw_representation = None + response.usage_details = None + + attrs = {} + result = _get_response_attributes(attrs, response, capture_response_id=False) + + assert OtelAttr.RESPONSE_ID not in result + + # region Test _get_exporters_from_env @@ -2263,6 +2283,81 @@ async def _get() -> ChatResponse: assert sorted_spans[2].name.startswith("chat"), f"Third span should be 'chat', got '{sorted_spans[2].name}'" +@pytest.mark.parametrize("stream", [False, True]) +async def test_agent_and_chat_spans_do_not_duplicate_response_telemetry( + span_exporter: InMemorySpanExporter, stream: bool +): + """Only the inner chat span should own response-id and usage telemetry.""" + + class NestedTelemetryChatClient(ChatTelemetryLayer, BaseChatClient[Any]): + def service_url(self): + return "https://test.example.com" + + def _inner_get_response( + self, *, messages: MutableSequence[Message], stream: bool, options: dict[str, Any], **kwargs: Any + ) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]: + if stream: + + async def _stream() -> AsyncIterable[ChatResponseUpdate]: + yield ChatResponseUpdate(contents=[Content.from_text("Nested")], role="assistant") + yield ChatResponseUpdate(contents=[Content.from_text(" response")], role="assistant") + + def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse: + return ChatResponse( + messages=[Message(role="assistant", text="Nested response")], + response_id="nested_resp_123", + usage_details=UsageDetails(input_token_count=11, output_token_count=22), + finish_reason="stop", + ) + + return ResponseStream(_stream(), finalizer=_finalize) + + async def _get() -> ChatResponse: + return ChatResponse( + messages=[Message(role="assistant", text="Nested response")], + response_id="nested_resp_123", + usage_details=UsageDetails(input_token_count=11, output_token_count=22), + finish_reason="stop", + ) + + return _get() + + agent = Agent( + client=NestedTelemetryChatClient(), + id="nested_agent_id", + name="nested_agent", + description="Nested telemetry agent", + default_options={"model_id": "NestedModel"}, + ) + + span_exporter.clear() + + if stream: + result_stream = agent.run("Test message", stream=True) + async for _ in result_stream: + pass + response = await result_stream.get_final_response() + else: + response = await agent.run("Test message") + + assert response is not None + + spans = span_exporter.get_finished_spans() + assert len(spans) == 2 + + span_by_operation = {span.attributes[OtelAttr.OPERATION.value]: span for span in spans} + agent_span = span_by_operation[OtelAttr.AGENT_INVOKE_OPERATION] + chat_span = span_by_operation[OtelAttr.CHAT_COMPLETION_OPERATION] + + assert chat_span.attributes[OtelAttr.RESPONSE_ID] == "nested_resp_123" + assert chat_span.attributes[OtelAttr.INPUT_TOKENS] == 11 + assert chat_span.attributes[OtelAttr.OUTPUT_TOKENS] == 22 + + assert OtelAttr.RESPONSE_ID not in agent_span.attributes + assert OtelAttr.INPUT_TOKENS not in agent_span.attributes + assert OtelAttr.OUTPUT_TOKENS not in agent_span.attributes + + # region Test non-ASCII character handling in JSON serialization @@ -2385,7 +2480,8 @@ def echo(text: str) -> str: span_exporter.clear() result = await echo.invoke(text=arabic_text) - assert result == arabic_text + assert isinstance(result, list) + assert result[0].text == arabic_text spans = span_exporter.get_finished_spans() assert len(spans) == 1 span = spans[0] diff --git a/python/packages/core/tests/core/test_sessions.py b/python/packages/core/tests/core/test_sessions.py index 4d2e6032743..bd2cb8155ef 100644 --- a/python/packages/core/tests/core/test_sessions.py +++ b/python/packages/core/tests/core/test_sessions.py @@ -192,10 +192,10 @@ def __init__(self, source_id: str, stored_messages: list[Message] | None = None, self.stored: list[Message] = [] self._stored_messages = stored_messages or [] - async def get_messages(self, session_id: str | None, **kwargs) -> list[Message]: + async def get_messages(self, session_id: str | None, *, state=None, **kwargs) -> list[Message]: return list(self._stored_messages) - async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs) -> None: + async def save_messages(self, session_id: str | None, messages: Sequence[Message], *, state=None, **kwargs) -> None: self.stored.extend(messages) diff --git a/python/packages/core/tests/core/test_tools.py b/python/packages/core/tests/core/test_tools.py index f7674edc9be..859a012e1de 100644 --- a/python/packages/core/tests/core/test_tools.py +++ b/python/packages/core/tests/core/test_tools.py @@ -12,6 +12,7 @@ FunctionTool, tool, ) +from agent_framework._middleware import FunctionInvocationContext from agent_framework._tools import ( _parse_annotation, _parse_inputs, @@ -124,7 +125,8 @@ def search(query: str, max_results: int = 10) -> str: return f"{query}:{max_results}" result = await search.invoke(arguments={"query": "hello", "max_results": 3}) - assert result == "hello:3" + assert isinstance(result, list) + assert result[0].text == "hello:3" async def test_tool_decorator_with_json_schema_invoke_missing_required(): @@ -221,7 +223,8 @@ def calculate(a: int, b: int) -> int: return a + b result = await calculate.invoke(arguments=CalcInput(a=3, b=7)) - assert result == "10" + assert isinstance(result, list) + assert result[0].text == "10" def test_tool_decorator_with_schema_overrides_annotations(): @@ -492,11 +495,13 @@ def multiply(self, factor: int) -> str: # Test with invoke method as well (simulating agent execution) result6 = await increment_tool.invoke(amount=5) - assert result6 == "Counter incremented by 5. New value: 60" + assert isinstance(result6, list) + assert result6[0].text == "Counter incremented by 5. New value: 60" assert counter_instance.counter == 60 result7 = await get_value_tool.invoke() - assert result7 == "Current counter value: 60" + assert isinstance(result7, list) + assert result7[0].text == "Current counter value: 60" assert counter_instance.counter == 60 @@ -519,7 +524,8 @@ def telemetry_test_tool(x: int, y: int) -> int: result = await telemetry_test_tool.invoke(x=1, y=2, tool_call_id="test_call_id") # Verify result - assert result == "3" + assert isinstance(result, list) + assert result[0].text == "3" # Verify telemetry calls spans = span_exporter.get_finished_spans() @@ -563,7 +569,8 @@ def telemetry_test_tool(x: int, y: int) -> int: result = await telemetry_test_tool.invoke(x=1, y=2, tool_call_id="test_call_id") # Verify result - assert result == "3" + assert isinstance(result, list) + assert result[0].text == "3" # Verify telemetry calls spans = span_exporter.get_finished_spans() @@ -604,7 +611,8 @@ async def simple_tool(message: str) -> str: options={"model_id": "dummy"}, ) - assert result == "HELLO WORLD" + assert isinstance(result, list) + assert result[0].text == "HELLO WORLD" async def test_tool_invoke_telemetry_with_pydantic_args(span_exporter: InMemorySpanExporter): @@ -628,7 +636,8 @@ def pydantic_test_tool(x: int, y: int) -> int: result = await pydantic_test_tool.invoke(arguments=args_model, tool_call_id="pydantic_call") # Verify result - assert result == "15" + assert isinstance(result, list) + assert result[0].text == "15" spans = span_exporter.get_finished_spans() assert len(spans) == 1 span = spans[0] @@ -696,7 +705,8 @@ async def async_telemetry_test(x: int, y: int) -> int: result = await async_telemetry_test.invoke(x=3, y=4, tool_call_id="async_call") # Verify result - assert result == "12" + assert isinstance(result, list) + assert result[0].text == "12" spans = span_exporter.get_finished_spans() assert len(spans) == 1 span = spans[0] @@ -932,13 +942,137 @@ def tool_with_kwargs(x: int, **kwargs: Any) -> str: arguments=tool_with_kwargs.input_model(x=5), user_id="user2", ) - assert result == "x=5, user=user2" + assert isinstance(result, list) + assert result[0].text == "x=5, user=user2" # Verify invoke works without injected args (uses default) result_default = await tool_with_kwargs.invoke( arguments=tool_with_kwargs.input_model(x=10), ) - assert result_default == "x=10, user=unknown" + assert isinstance(result_default, list) + assert result_default[0].text == "x=10, user=unknown" + + +async def test_ai_function_with_explicit_invocation_context(): + """Test that invoke() can receive runtime kwargs via FunctionInvocationContext.""" + + @tool + def tool_with_context(x: int, ctx: FunctionInvocationContext) -> str: + """A tool that accepts runtime context injection.""" + user_id = ctx.kwargs.get("user_id", "unknown") + return f"x={x}, user={user_id}" + + assert tool_with_context.parameters() == { + "properties": {"x": {"title": "X", "type": "integer"}}, + "required": ["x"], + "title": "tool_with_context_input", + "type": "object", + } + + context = FunctionInvocationContext( + function=tool_with_context, + arguments=tool_with_context.input_model(x=7), + kwargs={"user_id": "ctx-user"}, + ) + + result = await tool_with_context.invoke(context=context) + + assert result[0].text == "x=7, user=ctx-user" + + +async def test_ai_function_with_typed_context_parameter_using_custom_name(): + """Test that typed context injection works for names other than ctx.""" + + @tool + def tool_with_runtime_context(x: int, runtime: FunctionInvocationContext) -> str: + """A tool that uses a custom context parameter name.""" + user_id = runtime.kwargs.get("user_id", "unknown") + return f"x={x}, user={user_id}" + + assert tool_with_runtime_context.parameters() == { + "properties": {"x": {"title": "X", "type": "integer"}}, + "required": ["x"], + "title": "tool_with_runtime_context_input", + "type": "object", + } + + context = FunctionInvocationContext( + function=tool_with_runtime_context, + arguments=tool_with_runtime_context.input_model(x=8), + kwargs={"user_id": "runtime-user"}, + ) + + result = await tool_with_runtime_context.invoke(context=context) + + assert result[0].text == "x=8, user=runtime-user" + + +async def test_ai_function_with_explicit_schema_and_untyped_ctx(): + """Test that explicit schemas allow an untyped ctx parameter.""" + + class ToolInput(BaseModel): + x: int + + @tool(schema=ToolInput) + def tool_with_schema(x, ctx) -> str: + """A tool with explicit schema and implicit ctx injection.""" + return f"x={x}, user={ctx.kwargs.get('user_id', 'unknown')}" + + context = FunctionInvocationContext( + function=tool_with_schema, + arguments=ToolInput(x=9), + kwargs={"user_id": "schema-user"}, + ) + + result = await tool_with_schema.invoke(context=context) + + assert result[0].text == "x=9, user=schema-user" + + +async def test_ai_function_with_explicit_schema_and_typed_ctx(): + """Test that explicit schemas also work with typed context injection.""" + + class ToolInput(BaseModel): + x: int + + @tool(schema=ToolInput) + def tool_with_schema(x: int, runtime: FunctionInvocationContext) -> str: + """A tool with explicit schema and typed context injection.""" + return f"x={x}, user={runtime.kwargs.get('user_id', 'unknown')}" + + context = FunctionInvocationContext( + function=tool_with_schema, + arguments=ToolInput(x=11), + kwargs={"user_id": "typed-schema-user"}, + ) + + result = await tool_with_schema.invoke(context=context) + + assert tool_with_schema.parameters() == ToolInput.model_json_schema() + assert result[0].text == "x=11, user=typed-schema-user" + + +def test_ai_function_with_multiple_typed_context_parameters_fails(): + """Test that tools reject multiple typed FunctionInvocationContext parameters.""" + + with pytest.raises(ValueError, match="multiple FunctionInvocationContext parameters"): + + @tool + def invalid_tool(ctx_one: FunctionInvocationContext, ctx_two: FunctionInvocationContext) -> str: + return f"{ctx_one.kwargs}-{ctx_two.kwargs}" + + +def test_ai_function_with_ctx_and_typed_context_parameter_fails(): + """Test that explicit-schema tools reject both implicit ctx and typed context parameters.""" + + class ToolInput(BaseModel): + x: int + + with pytest.raises(ValueError, match="multiple FunctionInvocationContext parameters"): + + @tool(schema=ToolInput) + def invalid_tool(x, ctx, runtime: FunctionInvocationContext) -> str: + return f"{x}-{ctx.kwargs}-{runtime.kwargs}" # region _parse_annotation tests diff --git a/python/packages/core/tests/core/test_types.py b/python/packages/core/tests/core/test_types.py index 2609cb29bdf..5e9469c8bd2 100644 --- a/python/packages/core/tests/core/test_types.py +++ b/python/packages/core/tests/core/test_types.py @@ -542,7 +542,12 @@ def test_function_result_content(): # Check the type and content assert content.type == "function_result" - assert content.result == {"param1": "value1"} + # Dict results are stringified and stored as text items + assert "param1" in content.result + assert "value1" in content.result + assert content.items is not None + assert len(content.items) == 1 + assert content.items[0].type == "text" # Ensure the instance is of type BaseContent assert isinstance(content, Content) @@ -2455,12 +2460,13 @@ class NestedModel(BaseModel): def test_parse_result_pydantic_model(): """Test that Pydantic BaseModel subclasses are properly serialized using model_dump().""" result = WeatherResult(temperature=22.5, condition="sunny") - json_result = FunctionTool.parse_result(result) + parsed = FunctionTool.parse_result(result) - # The result should be a valid JSON string - assert isinstance(json_result, str) - assert '"temperature": 22.5' in json_result or '"temperature":22.5' in json_result - assert '"condition": "sunny"' in json_result or '"condition":"sunny"' in json_result + assert isinstance(parsed, list) + assert len(parsed) == 1 + assert parsed[0].type == "text" + assert '"temperature": 22.5' in parsed[0].text or '"temperature":22.5' in parsed[0].text + assert '"condition": "sunny"' in parsed[0].text or '"condition":"sunny"' in parsed[0].text def test_parse_result_pydantic_model_in_list(): @@ -2469,14 +2475,14 @@ def test_parse_result_pydantic_model_in_list(): WeatherResult(temperature=20.0, condition="cloudy"), WeatherResult(temperature=25.0, condition="sunny"), ] - json_result = FunctionTool.parse_result(results) + parsed = FunctionTool.parse_result(results) - # The result should be a valid JSON string representing a list - assert isinstance(json_result, str) - assert json_result.startswith("[") - assert json_result.endswith("]") - assert "cloudy" in json_result - assert "sunny" in json_result + assert isinstance(parsed, list) + assert len(parsed) == 1 + assert parsed[0].type == "text" + assert parsed[0].text.startswith("[") + assert "cloudy" in parsed[0].text + assert "sunny" in parsed[0].text def test_parse_result_pydantic_model_in_dict(): @@ -2485,26 +2491,28 @@ def test_parse_result_pydantic_model_in_dict(): "current": WeatherResult(temperature=22.0, condition="partly cloudy"), "forecast": WeatherResult(temperature=24.0, condition="sunny"), } - json_result = FunctionTool.parse_result(results) + parsed = FunctionTool.parse_result(results) - # The result should be a valid JSON string representing a dict - assert isinstance(json_result, str) - assert "current" in json_result - assert "forecast" in json_result - assert "partly cloudy" in json_result - assert "sunny" in json_result + assert isinstance(parsed, list) + assert len(parsed) == 1 + assert parsed[0].type == "text" + assert "current" in parsed[0].text + assert "forecast" in parsed[0].text + assert "partly cloudy" in parsed[0].text + assert "sunny" in parsed[0].text def test_parse_result_nested_pydantic_model(): """Test that nested Pydantic models are properly serialized.""" result = NestedModel(name="Seattle", weather=WeatherResult(temperature=18.0, condition="rainy")) - json_result = FunctionTool.parse_result(result) + parsed = FunctionTool.parse_result(result) - # The result should be a valid JSON string - assert isinstance(json_result, str) - assert "Seattle" in json_result - assert "rainy" in json_result - assert "18.0" in json_result or "18" in json_result + assert isinstance(parsed, list) + assert len(parsed) == 1 + assert parsed[0].type == "text" + assert "Seattle" in parsed[0].text + assert "rainy" in parsed[0].text + assert "18.0" in parsed[0].text or "18" in parsed[0].text # region FunctionTool.parse_result with MCP TextContent-like objects @@ -2518,11 +2526,12 @@ class MockTextContent: text: str result = [MockTextContent("Hello from MCP tool!")] - json_result = FunctionTool.parse_result(result) + parsed = FunctionTool.parse_result(result) - # Should extract text and serialize as JSON array of strings - assert isinstance(json_result, str) - assert json_result == '["Hello from MCP tool!"]' + # Non-Content list items are serialized via _make_dumpable + assert isinstance(parsed, list) + assert len(parsed) == 1 + assert parsed[0].type == "text" def test_parse_result_text_content_multiple(): @@ -2533,11 +2542,12 @@ class MockTextContent: text: str result = [MockTextContent("First result"), MockTextContent("Second result")] - json_result = FunctionTool.parse_result(result) + parsed = FunctionTool.parse_result(result) - # Should extract text from each and serialize as JSON array - assert isinstance(json_result, str) - assert json_result == '["First result", "Second result"]' + # Non-Content list items are serialized via _make_dumpable + assert isinstance(parsed, list) + assert len(parsed) == 1 + assert parsed[0].type == "text" def test_parse_result_text_content_with_non_string_text(): @@ -2548,38 +2558,174 @@ def __init__(self): self.text = 12345 # Not a string! result = [BadTextContent()] - json_result = FunctionTool.parse_result(result) + parsed = FunctionTool.parse_result(result) # Should not extract text since it's not a string, will serialize the object - assert isinstance(json_result, str) + assert isinstance(parsed, list) + assert len(parsed) == 1 + assert parsed[0].type == "text" def test_parse_result_none_returns_empty_string(): - """Test that None returns an empty string.""" - assert FunctionTool.parse_result(None) == "" + """Test that None returns a list with empty text Content.""" + parsed = FunctionTool.parse_result(None) + assert isinstance(parsed, list) + assert len(parsed) == 1 + assert parsed[0].type == "text" + assert parsed[0].text == "" def test_parse_result_string_passthrough(): - """Test that strings are returned as-is.""" - assert FunctionTool.parse_result("hello world") == "hello world" - assert FunctionTool.parse_result('{"key": "value"}') == '{"key": "value"}' + """Test that strings are wrapped in Content.""" + parsed = FunctionTool.parse_result("hello world") + assert isinstance(parsed, list) + assert len(parsed) == 1 + assert parsed[0].text == "hello world" + + parsed2 = FunctionTool.parse_result('{"key": "value"}') + assert isinstance(parsed2, list) + assert len(parsed2) == 1 + assert parsed2[0].text == '{"key": "value"}' def test_parse_result_content_object(): - """Test that Content objects are serialized via to_dict.""" + """Test that text Content objects are wrapped in a list.""" content = Content.from_text("hello") result = FunctionTool.parse_result(content) - assert isinstance(result, str) - assert "hello" in result + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].type == "text" + assert result[0].text == "hello" def test_parse_result_list_of_content(): - """Test that list[Content] is serialized to JSON.""" + """Test that list[Content] with text-only items is returned as list[Content].""" contents = [Content.from_text("hello"), Content.from_text("world")] result = FunctionTool.parse_result(contents) - assert isinstance(result, str) - assert "hello" in result - assert "world" in result + assert isinstance(result, list) + assert len(result) == 2 + assert result[0].text == "hello" + assert result[1].text == "world" + + +def test_parse_result_single_image_content(): + """Test that a single image Content is preserved as list[Content].""" + image_content = Content.from_data(data=b"fake_png_bytes", media_type="image/png") + result = FunctionTool.parse_result(image_content) + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].type == "data" + assert result[0].media_type == "image/png" + + +def test_parse_result_single_text_content(): + """Test that a single text Content returns a list with one text Content.""" + text_content = Content.from_text("just text") + result = FunctionTool.parse_result(text_content) + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].type == "text" + assert result[0].text == "just text" + + +def test_parse_result_mixed_content_list(): + """Test that list with text and image Content is preserved.""" + contents = [ + Content.from_text("Chart rendered."), + Content.from_data(data=b"image_bytes", media_type="image/png"), + ] + result = FunctionTool.parse_result(contents) + assert isinstance(result, list) + assert len(result) == 2 + assert result[0].type == "text" + assert result[1].type == "data" + + +def test_from_function_result_with_content_list(): + """Test Content.from_function_result stores all items uniformly.""" + content_list = [ + Content.from_text("Chart rendered."), + Content.from_data(data=b"image_bytes", media_type="image/png"), + ] + result = Content.from_function_result(call_id="test-123", result=content_list) + assert result.type == "function_result" + assert result.call_id == "test-123" + assert result.result == "Chart rendered." + assert result.items is not None + assert len(result.items) == 2 + assert result.items[0].type == "text" + assert result.items[0].text == "Chart rendered." + assert result.items[1].type == "data" + assert result.items[1].media_type == "image/png" + + +def test_from_function_result_with_string(): + """Test Content.from_function_result with plain string result.""" + result = Content.from_function_result(call_id="test-123", result="just text") + assert result.type == "function_result" + assert result.call_id == "test-123" + assert result.result == "just text" + assert result.items is not None + assert len(result.items) == 1 + assert result.items[0].type == "text" + assert result.items[0].text == "just text" + + +def test_content_from_function_result_items_in_to_dict(): + """Test that items are included in to_dict serialization.""" + content_list = [ + Content.from_text("done"), + Content.from_data(data=b"png_data", media_type="image/png"), + ] + result = Content.from_function_result( + call_id="call-1", + result=content_list, + ) + d = result.to_dict() + assert "items" in d + assert len(d["items"]) == 2 + assert d["items"][0]["type"] == "text" + assert d["items"][1]["type"] == "data" + + +def test_from_function_result_with_only_rich_content_list(): + """Test Content.from_function_result with only image items and no text.""" + content_list = [ + Content.from_data(data=b"image_bytes", media_type="image/png"), + ] + result = Content.from_function_result(call_id="test-456", result=content_list) + assert result.type == "function_result" + assert result.result == "" + assert result.items is not None + assert len(result.items) == 1 + assert result.items[0].type == "data" + + +def test_function_result_items_roundtrip_via_dict(): + """Test that items survive a to_dict/from_dict round-trip as Content objects.""" + content_list = [ + Content.from_text("done"), + Content.from_data(data=b"png_data", media_type="image/png"), + ] + original = Content.from_function_result(call_id="call-rt", result=content_list) + restored = Content.from_dict(original.to_dict()) + assert restored.items is not None + assert len(restored.items) == 2 + assert isinstance(restored.items[0], Content) + assert restored.items[0].type == "text" + assert restored.items[0].text == "done" + assert isinstance(restored.items[1], Content) + assert restored.items[1].type == "data" + + +def test_from_function_result_with_non_content_list(): + """Test Content.from_function_result with a list of non-Content objects falls back to str.""" + result = Content.from_function_result(call_id="test-789", result=["hello", "world"]) + assert result.type == "function_result" + assert result.result == "['hello', 'world']" + assert result.items is not None + assert len(result.items) == 1 + assert result.items[0].type == "text" # endregion diff --git a/python/packages/core/tests/openai/test_openai_chat_client.py b/python/packages/core/tests/openai/test_openai_chat_client.py index 04321b08836..3dc4c23c6d5 100644 --- a/python/packages/core/tests/openai/test_openai_chat_client.py +++ b/python/packages/core/tests/openai/test_openai_chat_client.py @@ -142,7 +142,9 @@ def test_serialize_with_org_id(openai_unit_test_env: dict[str, str]) -> None: assert "User-Agent" not in dumped_settings.get("default_headers", {}) -async def test_content_filter_exception_handling(openai_unit_test_env: dict[str, str]) -> None: +async def test_content_filter_exception_handling( + openai_unit_test_env: dict[str, str], +) -> None: """Test that content filter errors are properly handled.""" client = OpenAIChatClient() messages = [Message(role="user", text="test message")] @@ -150,7 +152,9 @@ async def test_content_filter_exception_handling(openai_unit_test_env: dict[str, # Create a mock BadRequestError with content_filter code mock_response = MagicMock() mock_error = BadRequestError( - message="Content filter error", response=mock_response, body={"error": {"code": "content_filter"}} + message="Content filter error", + response=mock_response, + body={"error": {"code": "content_filter"}}, ) mock_error.code = "content_filter" @@ -184,7 +188,9 @@ class UnsupportedTool: assert result["tools"] == [dict_tool] -def test_prepare_tools_with_single_function_tool(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_tools_with_single_function_tool( + openai_unit_test_env: dict[str, str], +) -> None: """Test that a single FunctionTool is accepted for tool preparation.""" client = OpenAIChatClient() @@ -241,12 +247,17 @@ async def test_exception_message_includes_original_error_details() -> None: assert original_error_message in exception_message -def test_chat_response_content_order_text_before_tool_calls(openai_unit_test_env: dict[str, str]): +def test_chat_response_content_order_text_before_tool_calls( + openai_unit_test_env: dict[str, str], +): """Test that text content appears before tool calls in ChatResponse contents.""" # Import locally to avoid break other tests when the import changes from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion_message import ChatCompletionMessage - from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall, Function + from openai.types.chat.chat_completion_message_tool_call import ( + ChatCompletionMessageToolCall, + Function, + ) # Create a mock OpenAI response with both text and tool calls mock_response = ChatCompletion( @@ -296,9 +307,10 @@ def test_function_result_falsy_values_handling(openai_unit_test_env: dict[str, s """ client = OpenAIChatClient() - # Test with empty list serialized as JSON string (as FunctionTool.invoke would produce) + # Test with empty list serialized as JSON string (pre-serialized result passed to from_function_result) message_with_empty_list = Message( - role="tool", contents=[Content.from_function_result(call_id="call-123", result="[]")] + role="tool", + contents=[Content.from_function_result(call_id="call-123", result="[]")], ) openai_messages = client._prepare_message_for_openai(message_with_empty_list) @@ -307,16 +319,18 @@ def test_function_result_falsy_values_handling(openai_unit_test_env: dict[str, s # Test with empty string (falsy but not None) message_with_empty_string = Message( - role="tool", contents=[Content.from_function_result(call_id="call-456", result="")] + role="tool", + contents=[Content.from_function_result(call_id="call-456", result="")], ) openai_messages = client._prepare_message_for_openai(message_with_empty_string) assert len(openai_messages) == 1 assert openai_messages[0]["content"] == "" # Empty string should be preserved - # Test with False serialized as JSON string (as FunctionTool.invoke would produce) + # Test with False serialized as JSON string (pre-serialized result passed to from_function_result) message_with_false = Message( - role="tool", contents=[Content.from_function_result(call_id="call-789", result="false")] + role="tool", + contents=[Content.from_function_result(call_id="call-789", result="false")], ) openai_messages = client._prepare_message_for_openai(message_with_false) @@ -336,7 +350,11 @@ def test_function_result_exception_handling(openai_unit_test_env: dict[str, str] message_with_exception = Message( role="tool", contents=[ - Content.from_function_result(call_id="call-123", result="Error: Function failed.", exception=test_exception) + Content.from_function_result( + call_id="call-123", + result="Error: Function failed.", + exception=test_exception, + ) ], ) @@ -346,16 +364,50 @@ def test_function_result_exception_handling(openai_unit_test_env: dict[str, str] assert openai_messages[0]["tool_call_id"] == "call-123" +def test_function_result_with_rich_items_warns_and_omits( + openai_unit_test_env: dict[str, str], +) -> None: + """Test that function_result with items logs a warning and omits rich items.""" + + client = OpenAIChatClient() + image_content = Content.from_data(data=b"image_bytes", media_type="image/png") + message = Message( + role="tool", + contents=[ + Content.from_function_result( + call_id="call_rich", + result=[Content.from_text("Result text"), image_content], + ) + ], + ) + + with patch("agent_framework.openai._chat_client.logger") as mock_logger: + openai_messages = client._prepare_message_for_openai(message) + + # Warning should be logged + mock_logger.warning.assert_called_once() + assert "does not support rich content" in mock_logger.warning.call_args[0][0] + + # Tool message should still be emitted with text result + assert len(openai_messages) == 1 + assert openai_messages[0]["role"] == "tool" + assert openai_messages[0]["tool_call_id"] == "call_rich" + assert openai_messages[0]["content"] == "Result text" + + def test_parse_result_string_passthrough(): - """Test that string values are passed through directly without JSON encoding.""" + """Test that string values are wrapped in Content.""" from agent_framework import FunctionTool result = FunctionTool.parse_result("simple string") - assert result == "simple string" - assert isinstance(result, str) + assert isinstance(result, list) + assert len(result) == 1 + assert result[0].text == "simple string" -def test_prepare_content_for_openai_data_content_image(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_content_for_openai_data_content_image( + openai_unit_test_env: dict[str, str], +) -> None: """Test _prepare_content_for_openai converts DataContent with image media type to OpenAI format.""" client = OpenAIChatClient() @@ -397,7 +449,8 @@ def test_prepare_content_for_openai_data_content_image(openai_unit_test_env: dic # Test DataContent with MP3 audio mp3_data_content = Content.from_uri( - uri="data:audio/mp3;base64,//uQAAAAWGluZwAAAA8AAAACAAACcQ==", media_type="audio/mp3" + uri="data:audio/mp3;base64,//uQAAAAWGluZwAAAA8AAAACAAACcQ==", + media_type="audio/mp3", ) result = client._prepare_content_for_openai(mp3_data_content) # type: ignore @@ -409,7 +462,9 @@ def test_prepare_content_for_openai_data_content_image(openai_unit_test_env: dic assert result["input_audio"]["format"] == "mp3" -def test_prepare_content_for_openai_document_file_mapping(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_content_for_openai_document_file_mapping( + openai_unit_test_env: dict[str, str], +) -> None: """Test _prepare_content_for_openai converts document files (PDF, DOCX, etc.) to OpenAI file format.""" client = OpenAIChatClient() @@ -515,7 +570,9 @@ def test_prepare_content_for_openai_document_file_mapping(openai_unit_test_env: assert "filename" not in result["file"] # None filename should be omitted -def test_parse_text_reasoning_content_from_response(openai_unit_test_env: dict[str, str]) -> None: +def test_parse_text_reasoning_content_from_response( + openai_unit_test_env: dict[str, str], +) -> None: """Test that TextReasoningContent is correctly parsed from OpenAI response with reasoning_details.""" client = OpenAIChatClient() @@ -563,7 +620,9 @@ def test_parse_text_reasoning_content_from_response(openai_unit_test_env: dict[s assert parsed_details == mock_reasoning_details -def test_parse_text_reasoning_content_from_streaming_chunk(openai_unit_test_env: dict[str, str]) -> None: +def test_parse_text_reasoning_content_from_streaming_chunk( + openai_unit_test_env: dict[str, str], +) -> None: """Test that TextReasoningContent is correctly parsed from streaming OpenAI chunk with reasoning_details.""" from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice @@ -611,7 +670,9 @@ def test_parse_text_reasoning_content_from_streaming_chunk(openai_unit_test_env: assert parsed_details == mock_reasoning_details -def test_prepare_message_with_text_reasoning_content(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_message_with_text_reasoning_content( + openai_unit_test_env: dict[str, str], +) -> None: """Test that TextReasoningContent with protected_data is correctly prepared for OpenAI.""" client = OpenAIChatClient() @@ -643,7 +704,9 @@ def test_prepare_message_with_text_reasoning_content(openai_unit_test_env: dict[ assert prepared[0]["content"] == "The answer is 42." -def test_prepare_message_with_only_text_reasoning_content(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_message_with_only_text_reasoning_content( + openai_unit_test_env: dict[str, str], +) -> None: """Test that a message with only text_reasoning content does not raise IndexError. Regression test for https://github.com/microsoft/agent-framework/issues/4384 @@ -677,7 +740,9 @@ def test_prepare_message_with_only_text_reasoning_content(openai_unit_test_env: assert prepared[0]["content"] == "" -def test_prepare_message_with_text_reasoning_before_text(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_message_with_text_reasoning_before_text( + openai_unit_test_env: dict[str, str], +) -> None: """Test that text_reasoning content appearing before text content is handled correctly. Regression test for https://github.com/microsoft/agent-framework/issues/4384 @@ -711,7 +776,9 @@ def test_prepare_message_with_text_reasoning_before_text(openai_unit_test_env: d assert prepared[0]["content"] == "The answer is 42." -def test_prepare_message_with_text_reasoning_before_function_call(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_message_with_text_reasoning_before_function_call( + openai_unit_test_env: dict[str, str], +) -> None: """Test that text_reasoning content appearing before a function call is handled correctly. Regression test for https://github.com/microsoft/agent-framework/issues/4384 @@ -747,7 +814,9 @@ def test_prepare_message_with_text_reasoning_before_function_call(openai_unit_te assert prepared[0]["role"] == "assistant" -def test_function_approval_content_is_skipped_in_preparation(openai_unit_test_env: dict[str, str]) -> None: +def test_function_approval_content_is_skipped_in_preparation( + openai_unit_test_env: dict[str, str], +) -> None: """Test that function approval request and response content are skipped.""" client = OpenAIChatClient() @@ -793,7 +862,9 @@ def test_function_approval_content_is_skipped_in_preparation(openai_unit_test_en assert prepared_mixed[0]["content"] == "I need approval for this action." -def test_usage_content_in_streaming_response(openai_unit_test_env: dict[str, str]) -> None: +def test_usage_content_in_streaming_response( + openai_unit_test_env: dict[str, str], +) -> None: """Test that UsageContent is correctly parsed from streaming response with usage data.""" from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.completion_usage import CompletionUsage @@ -829,13 +900,19 @@ def test_usage_content_in_streaming_response(openai_unit_test_env: dict[str, str assert usage_content.usage_details["total_token_count"] == 150 -def test_streaming_chunk_with_usage_and_text(openai_unit_test_env: dict[str, str]) -> None: +def test_streaming_chunk_with_usage_and_text( + openai_unit_test_env: dict[str, str], +) -> None: """Test that text content is not lost when usage data is in the same chunk. Some providers (e.g. Gemini) include both usage and text content in the same streaming chunk. See https://github.com/microsoft/agent-framework/issues/3434 """ - from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta + from openai.types.chat.chat_completion_chunk import ( + ChatCompletionChunk, + Choice, + ChoiceDelta, + ) from openai.types.completion_usage import CompletionUsage client = OpenAIChatClient() @@ -923,7 +1000,9 @@ def test_prepare_options_without_messages(openai_unit_test_env: dict[str, str]) client._prepare_options([], {}) -def test_prepare_tools_with_web_search_no_location(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_tools_with_web_search_no_location( + openai_unit_test_env: dict[str, str], +) -> None: """Test preparing web search tool without user location.""" client = OpenAIChatClient() @@ -937,7 +1016,9 @@ def test_prepare_tools_with_web_search_no_location(openai_unit_test_env: dict[st assert result["web_search_options"] == {} -def test_prepare_options_with_instructions(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_options_with_instructions( + openai_unit_test_env: dict[str, str], +) -> None: """Test that instructions are prepended as system message.""" client = OpenAIChatClient() @@ -969,7 +1050,9 @@ def test_prepare_message_with_author_name(openai_unit_test_env: dict[str, str]) assert prepared[0]["name"] == "TestUser" -def test_prepare_message_with_tool_result_author_name(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_message_with_tool_result_author_name( + openai_unit_test_env: dict[str, str], +) -> None: """Test that author_name is not included for TOOL role messages.""" client = OpenAIChatClient() @@ -987,7 +1070,9 @@ def test_prepare_message_with_tool_result_author_name(openai_unit_test_env: dict assert "name" not in prepared[0] -def test_prepare_system_message_content_is_string(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_system_message_content_is_string( + openai_unit_test_env: dict[str, str], +) -> None: """Test that system message content is a plain string, not a list. Some OpenAI-compatible endpoints (e.g. NVIDIA NIM) reject system messages @@ -1005,7 +1090,9 @@ def test_prepare_system_message_content_is_string(openai_unit_test_env: dict[str assert prepared[0]["content"] == "You are a helpful assistant." -def test_prepare_developer_message_content_is_string(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_developer_message_content_is_string( + openai_unit_test_env: dict[str, str], +) -> None: """Test that developer message content is a plain string, not a list.""" client = OpenAIChatClient() @@ -1019,7 +1106,9 @@ def test_prepare_developer_message_content_is_string(openai_unit_test_env: dict[ assert prepared[0]["content"] == "Follow these rules." -def test_prepare_system_message_multiple_text_contents_joined(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_system_message_multiple_text_contents_joined( + openai_unit_test_env: dict[str, str], +) -> None: """Test that system messages with multiple text contents are joined into a single string.""" client = OpenAIChatClient() @@ -1039,7 +1128,9 @@ def test_prepare_system_message_multiple_text_contents_joined(openai_unit_test_e assert prepared[0]["content"] == "You are a helpful assistant.\nBe concise." -def test_prepare_user_message_text_content_is_string(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_user_message_text_content_is_string( + openai_unit_test_env: dict[str, str], +) -> None: """Test that text-only user message content is flattened to a plain string. Some OpenAI-compatible endpoints (e.g. Foundry Local) cannot deserialize @@ -1057,7 +1148,9 @@ def test_prepare_user_message_text_content_is_string(openai_unit_test_env: dict[ assert prepared[0]["content"] == "Hello" -def test_prepare_user_message_multimodal_content_remains_list(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_user_message_multimodal_content_remains_list( + openai_unit_test_env: dict[str, str], +) -> None: """Test that multimodal user message content remains a list.""" client = OpenAIChatClient() @@ -1076,7 +1169,9 @@ def test_prepare_user_message_multimodal_content_remains_list(openai_unit_test_e assert has_list_content -def test_prepare_assistant_message_text_content_is_string(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_assistant_message_text_content_is_string( + openai_unit_test_env: dict[str, str], +) -> None: """Test that text-only assistant message content is flattened to a plain string.""" client = OpenAIChatClient() @@ -1090,7 +1185,9 @@ def test_prepare_assistant_message_text_content_is_string(openai_unit_test_env: assert prepared[0]["content"] == "Sure, I can help." -def test_tool_choice_required_with_function_name(openai_unit_test_env: dict[str, str]) -> None: +def test_tool_choice_required_with_function_name( + openai_unit_test_env: dict[str, str], +) -> None: """Test that tool_choice with required mode and function name is correctly prepared.""" client = OpenAIChatClient() @@ -1125,7 +1222,9 @@ def test_response_format_dict_passthrough(openai_unit_test_env: dict[str, str]) assert prepared_options["response_format"] == custom_format -def test_multiple_function_calls_in_single_message(openai_unit_test_env: dict[str, str]) -> None: +def test_multiple_function_calls_in_single_message( + openai_unit_test_env: dict[str, str], +) -> None: """Test that multiple function calls in a message are correctly prepared.""" client = OpenAIChatClient() @@ -1148,7 +1247,9 @@ def test_multiple_function_calls_in_single_message(openai_unit_test_env: dict[st assert prepared[0]["tool_calls"][1]["id"] == "call_2" -def test_prepare_options_removes_parallel_tool_calls_when_no_tools(openai_unit_test_env: dict[str, str]) -> None: +def test_prepare_options_removes_parallel_tool_calls_when_no_tools( + openai_unit_test_env: dict[str, str], +) -> None: """Test that parallel_tool_calls is removed when no tools are present.""" client = OpenAIChatClient() @@ -1176,7 +1277,9 @@ def test_prepare_options_excludes_conversation_id(openai_unit_test_env: dict[str assert prepared_options["temperature"] == 0.7 -async def test_streaming_exception_handling(openai_unit_test_env: dict[str, str]) -> None: +async def test_streaming_exception_handling( + openai_unit_test_env: dict[str, str], +) -> None: """Test that streaming errors are properly handled.""" client = OpenAIChatClient() messages = [Message(role="user", text="test")] @@ -1220,7 +1323,12 @@ class OutputStruct(BaseModel): param("allow_multiple_tool_calls", True, False, id="allow_multiple_tool_calls"), # OpenAIChatOptions - just verify they don't fail param("logit_bias", {"50256": -1}, False, id="logit_bias"), - param("prediction", {"type": "content", "content": "hello world"}, False, id="prediction"), + param( + "prediction", + {"type": "content", "content": "hello world"}, + False, + id="prediction", + ), # Complex options requiring output validation param("tools", [get_weather], True, id="tools_function"), param("tool_choice", "auto", True, id="tool_choice_auto"), @@ -1249,7 +1357,12 @@ class OutputStruct(BaseModel): "temperature_c": {"type": "number"}, "advisory": {"type": "string"}, }, - "required": ["location", "conditions", "temperature_c", "advisory"], + "required": [ + "location", + "conditions", + "temperature_c", + "advisory", + ], "additionalProperties": False, }, }, @@ -1383,7 +1496,12 @@ async def test_integration_web_search() -> None: } ) content = { - "messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")], + "messages": [ + Message( + role="user", + text="What is the current weather? Do not ask for my current location.", + ) + ], "options": { "tool_choice": "auto", "tools": [web_search_tool_with_location], diff --git a/python/packages/core/tests/openai/test_openai_responses_client.py b/python/packages/core/tests/openai/test_openai_responses_client.py index 78ff6ec17db..696dd777722 100644 --- a/python/packages/core/tests/openai/test_openai_responses_client.py +++ b/python/packages/core/tests/openai/test_openai_responses_client.py @@ -4,6 +4,7 @@ import json import os from datetime import datetime, timezone +from pathlib import Path from typing import Annotated, Any from unittest.mock import MagicMock, patch @@ -36,7 +37,10 @@ SupportsChatGetResponse, tool, ) -from agent_framework.exceptions import ChatClientException, ChatClientInvalidRequestException +from agent_framework.exceptions import ( + ChatClientException, + ChatClientInvalidRequestException, +) from agent_framework.openai import OpenAIResponsesClient from agent_framework.openai._exceptions import OpenAIContentFilterException from agent_framework.openai._responses_client import OPENAI_LOCAL_SHELL_CALL_ITEM_ID_KEY @@ -1313,7 +1317,10 @@ def test_prepare_messages_for_openai_full_conversation_with_reasoning() -> None: ), ], ), - Message(role="assistant", contents=[Content.from_text(text="I found hotels for you")]), + Message( + role="assistant", + contents=[Content.from_text(text="I found hotels for you")], + ), ] result = client._prepare_messages_for_openai(messages) @@ -1422,10 +1429,16 @@ def test_response_format_with_conflicting_definitions() -> None: client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") # Mock response_format and text_config that conflict - response_format = {"type": "json_schema", "format": {"type": "json_schema", "name": "Test", "schema": {}}} + response_format = { + "type": "json_schema", + "format": {"type": "json_schema", "name": "Test", "schema": {}}, + } text_config = {"format": {"type": "json_object"}} - with pytest.raises(ChatClientInvalidRequestException, match="Conflicting response_format definitions"): + with pytest.raises( + ChatClientInvalidRequestException, + match="Conflicting response_format definitions", + ): client._prepare_response_and_text_format(response_format=response_format, text_config=text_config) @@ -1457,7 +1470,13 @@ def test_response_format_with_format_key() -> None: """Test response_format that already has a format key.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") - response_format = {"format": {"type": "json_schema", "name": "MySchema", "schema": {"type": "object"}}} + response_format = { + "format": { + "type": "json_schema", + "name": "MySchema", + "schema": {"type": "object"}, + } + } _, text_config = client._prepare_response_and_text_format(response_format=response_format, text_config=None) @@ -1487,7 +1506,11 @@ def test_response_format_json_schema_with_strict() -> None: response_format = { "type": "json_schema", - "json_schema": {"name": "StrictSchema", "schema": {"type": "object"}, "strict": True}, + "json_schema": { + "name": "StrictSchema", + "schema": {"type": "object"}, + "strict": True, + }, } _, text_config = client._prepare_response_and_text_format(response_format=response_format, text_config=None) @@ -1521,7 +1544,10 @@ def test_response_format_json_schema_missing_schema() -> None: response_format = {"type": "json_schema", "json_schema": {"name": "NoSchema"}} - with pytest.raises(ChatClientInvalidRequestException, match="json_schema response_format requires a schema"): + with pytest.raises( + ChatClientInvalidRequestException, + match="json_schema response_format requires a schema", + ): client._prepare_response_and_text_format(response_format=response_format, text_config=None) @@ -1541,7 +1567,10 @@ def test_response_format_invalid_type() -> None: response_format = "invalid" # Not a Pydantic model or mapping - with pytest.raises(ChatClientInvalidRequestException, match="response_format must be a Pydantic model or mapping"): + with pytest.raises( + ChatClientInvalidRequestException, + match="response_format must be a Pydantic model or mapping", + ): client._prepare_response_and_text_format(response_format=response_format, text_config=None) # type: ignore @@ -2198,7 +2227,9 @@ async def test_get_response_streaming_with_response_format() -> None: async def run_streaming(): async for _ in client.get_response( - stream=True, messages=messages, options={"response_format": OutputStruct} + stream=True, + messages=messages, + options={"response_format": OutputStruct}, ): pass @@ -2262,6 +2293,45 @@ def test_prepare_content_for_openai_unsupported_content() -> None: assert result == {} +def test_prepare_content_for_openai_function_result_with_rich_items() -> None: + """Test _prepare_content_for_openai with function_result containing rich items.""" + client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") + + image_content = Content.from_data(data=b"image_bytes", media_type="image/png") + content = Content.from_function_result( + call_id="call_rich", + result=[Content.from_text("Result text"), image_content], + ) + + result = client._prepare_content_for_openai("user", content, {}) # type: ignore + + assert result["type"] == "function_call_output" + assert result["call_id"] == "call_rich" + # Output should be a list with text and image parts + output = result["output"] + assert isinstance(output, list) + assert len(output) == 2 + assert output[0]["type"] == "input_text" + assert output[0]["text"] == "Result text" + assert output[1]["type"] == "input_image" + + +def test_prepare_content_for_openai_function_result_without_items() -> None: + """Test _prepare_content_for_openai with plain string function_result.""" + client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") + + content = Content.from_function_result( + call_id="call_plain", + result="Simple result", + ) + + result = client._prepare_content_for_openai("user", content, {}) # type: ignore + + assert result["type"] == "function_call_output" + assert result["call_id"] == "call_plain" + assert result["output"] == "Simple result" + + def test_parse_chunk_from_openai_code_interpreter() -> None: """Test _parse_chunk_from_openai with code_interpreter_call.""" client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") @@ -2778,7 +2848,10 @@ async def test_instructions_sent_first_turn_then_skipped_for_continuation() -> N await client.get_response( messages=[Message(role="user", text="Tell me a joke")], - options={"instructions": "Reply in uppercase.", "conversation_id": "resp_123"}, + options={ + "instructions": "Reply in uppercase.", + "conversation_id": "resp_123", + }, ) second_input_messages = mock_create.call_args.kwargs["input"] @@ -2788,7 +2861,9 @@ async def test_instructions_sent_first_turn_then_skipped_for_continuation() -> N @pytest.mark.parametrize("conversation_id", ["resp_456", "conv_abc123"]) -async def test_instructions_not_repeated_for_continuation_ids(conversation_id: str) -> None: +async def test_instructions_not_repeated_for_continuation_ids( + conversation_id: str, +) -> None: client = OpenAIResponsesClient(model_id="test-model", api_key="test-key") mock_response = _create_mock_responses_text_response(response_id="resp_456") @@ -2889,7 +2964,12 @@ async def get_api_key() -> str: "temperature_c": {"type": "number"}, "advisory": {"type": "string"}, }, - "required": ["location", "conditions", "temperature_c", "advisory"], + "required": [ + "location", + "conditions", + "temperature_c", + "advisory", + ], "additionalProperties": False, }, }, @@ -3014,7 +3094,12 @@ async def test_integration_web_search() -> None: user_location={"country": "US", "city": "Seattle"}, ) content = { - "messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")], + "messages": [ + Message( + role="user", + text="What is the current weather? Do not ask for my current location.", + ) + ], "options": { "tool_choice": "auto", "tools": [web_search_tool_with_location], @@ -3105,7 +3190,42 @@ async def test_integration_streaming_file_search() -> None: assert "75" in full_message -# region Background Response / ContinuationToken Tests +@pytest.mark.flaky +@pytest.mark.integration +@skip_if_openai_integration_tests_disabled +async def test_integration_tool_rich_content_image() -> None: + """Integration test: a tool returns an image and the model describes it.""" + image_path = Path(__file__).parent.parent / "assets" / "sample_image.jpg" + image_bytes = image_path.read_bytes() + + @tool(approval_mode="never_require") + def get_test_image() -> Content: + """Return a test image for analysis.""" + return Content.from_data(data=image_bytes, media_type="image/jpeg") + + client = OpenAIResponsesClient() + client.function_invocation_configuration["max_iterations"] = 2 + + for streaming in [False, True]: + messages = [ + Message( + role="user", + text="Call the get_test_image tool and describe what you see.", + ) + ] + options: dict[str, Any] = {"tools": [get_test_image], "tool_choice": "auto"} + + if streaming: + response = await client.get_response(messages=messages, stream=True, options=options).get_final_response() + else: + response = await client.get_response(messages=messages, options=options) + + assert response is not None + assert isinstance(response, ChatResponse) + assert response.text is not None + assert len(response.text) > 0 + # sample_image.jpg contains a photo of a house; the model should mention it. + assert "house" in response.text.lower(), f"Model did not describe the house image. Response: {response.text}" def test_continuation_token_json_serializable() -> None: diff --git a/python/packages/core/tests/workflow/test_workflow_builder.py b/python/packages/core/tests/workflow/test_workflow_builder.py index 3a7b7195300..29a4bf02923 100644 --- a/python/packages/core/tests/workflow/test_workflow_builder.py +++ b/python/packages/core/tests/workflow/test_workflow_builder.py @@ -13,6 +13,8 @@ AgentRunInputs, AgentSession, BaseAgent, + Case, + Default, Executor, Message, ResponseStream, @@ -223,6 +225,29 @@ def condition_func(msg: MockMessage) -> bool: assert "Target" in workflow.executors +def test_switch_case_with_agents(): + """Test add_switch_case_edge_group with Case and Default edges using agents.""" + router = DummyAgent(id="router_agent", name="router") + handler = DummyAgent(id="handler", name="handler") + fallback = DummyAgent(id="fallback_agent", name="fallback") + + workflow = ( + WorkflowBuilder(start_executor=router) + .add_switch_case_edge_group( + router, + [ + Case(condition=lambda _: True, target=handler), + Default(target=fallback), + ], + ) + .build() + ) + + # All three agents should be AgentExecutor wrappers + agent_executors = [e for e in workflow.executors.values() if isinstance(e, AgentExecutor)] + assert len(agent_executors) == 3 + + # region with_output_from tests diff --git a/python/packages/durabletask/agent_framework_durabletask/_executors.py b/python/packages/durabletask/agent_framework_durabletask/_executors.py index 0a7cf50b0ba..713c1b4e69f 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_executors.py +++ b/python/packages/durabletask/agent_framework_durabletask/_executors.py @@ -124,10 +124,20 @@ def run_durable_agent( """ raise NotImplementedError - def get_new_session(self, agent_name: str, **kwargs: Any) -> DurableAgentSession: + def get_new_session( + self, + agent_name: str, + *, + session_id: str | None = None, + service_session_id: str | None = None, + ) -> DurableAgentSession: """Create a new DurableAgentSession with random session ID.""" - session_id = self._create_session_id(agent_name) - return DurableAgentSession.from_session_id(session_id, **kwargs) + durable_session_id = self._create_session_id(agent_name) + return DurableAgentSession( + durable_session_id=durable_session_id, + session_id=session_id, + service_session_id=service_session_id, + ) def _create_session_id( self, diff --git a/python/packages/durabletask/agent_framework_durabletask/_models.py b/python/packages/durabletask/agent_framework_durabletask/_models.py index 1c5484afbf6..19d5804bc29 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_models.py +++ b/python/packages/durabletask/agent_framework_durabletask/_models.py @@ -284,46 +284,48 @@ def __init__( durable_session_id: AgentSessionId | None = None, session_id: str | None = None, service_session_id: str | None = None, - **kwargs: Any, ) -> None: - super().__init__(session_id=session_id, service_session_id=service_session_id, **kwargs) - self._session_id_value: AgentSessionId | None = durable_session_id + super().__init__(session_id=session_id, service_session_id=service_session_id) + self.durable_session_id: AgentSessionId | None = durable_session_id - @property - def durable_session_id(self) -> AgentSessionId | None: - return self._session_id_value - - @durable_session_id.setter - def durable_session_id(self, value: AgentSessionId | None) -> None: - self._session_id_value = value + def to_dict(self) -> dict[str, Any]: + state = super().to_dict() + if self.durable_session_id is not None: + state[self._SERIALIZED_SESSION_ID_KEY] = str(self.durable_session_id) + return state @classmethod def from_session_id( cls, - session_id: AgentSessionId, - **kwargs: Any, + durable_session_id: AgentSessionId, + *, + session_id: str | None = None, + service_session_id: str | None = None, ) -> DurableAgentSession: - return cls(durable_session_id=session_id, **kwargs) - - def to_dict(self) -> dict[str, Any]: - state = super().to_dict() - if self._session_id_value is not None: - state[self._SERIALIZED_SESSION_ID_KEY] = str(self._session_id_value) - return state + """Create a DurableAgentSession from an AgentSessionId.""" + return cls( + durable_session_id=durable_session_id, + session_id=session_id, + service_session_id=service_session_id, + ) @classmethod def from_dict(cls, data: dict[str, Any]) -> DurableAgentSession: - state_payload = dict(data) - session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None) - session = super().from_dict(state_payload) + """Create a DurableAgentSession from a state dict.""" + data = dict(data) # defensive copy — avoid mutating caller's dict + session_id_value = data.pop(cls._SERIALIZED_SESSION_ID_KEY, None) + session = super().from_dict(data) + durable_session_id: AgentSessionId | None = None # We need to create a DurableAgentSession from the base AgentSession + if session_id_value is not None: + if not isinstance(session_id_value, str): + raise ValueError("durable_session_id must be a string when present in serialized state") + durable_session_id = AgentSessionId.parse(session_id_value) + durable_session = cls( + durable_session_id=durable_session_id, session_id=session.session_id, service_session_id=session.service_session_id, ) durable_session.state.update(session.state) - if session_id_value is not None: - if not isinstance(session_id_value, str): - raise ValueError("durable_session_id must be a string when present in serialized state") - durable_session._session_id_value = AgentSessionId.parse(session_id_value) return durable_session diff --git a/python/packages/durabletask/agent_framework_durabletask/_shim.py b/python/packages/durabletask/agent_framework_durabletask/_shim.py index 5693876ad79..b21cac68318 100644 --- a/python/packages/durabletask/agent_framework_durabletask/_shim.py +++ b/python/packages/durabletask/agent_framework_durabletask/_shim.py @@ -133,16 +133,13 @@ def run( # type: ignore[override] session=session, ) - def create_session(self, **kwargs: Any) -> DurableAgentSession: + def create_session(self, *, session_id: str | None = None) -> DurableAgentSession: """Create a new agent session via the provider.""" - return self._executor.get_new_session(self.name, **kwargs) + return self._executor.get_new_session(self.name) - def get_session(self, **kwargs: Any) -> AgentSession: - """Retrieve an existing session via the provider. - - For durable agents, sessions do not use `service_session_id` so this is not used. - """ - return self._executor.get_new_session(self.name, **kwargs) + def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession: + """Retrieve an existing session via the provider.""" + return self._executor.get_new_session(self.name, service_session_id=service_session_id, session_id=session_id) def _normalize_messages(self, messages: AgentRunInputs | None) -> str: """Convert supported message inputs to a single string. diff --git a/python/packages/durabletask/tests/test_agent_session_id.py b/python/packages/durabletask/tests/test_agent_session_id.py index 571212f145e..3902acd22ff 100644 --- a/python/packages/durabletask/tests/test_agent_session_id.py +++ b/python/packages/durabletask/tests/test_agent_session_id.py @@ -2,6 +2,8 @@ """Unit tests for AgentSessionId and DurableAgentSession.""" +from typing import Any + import pytest from agent_framework import AgentSession @@ -153,7 +155,7 @@ def test_durable_session_id_setter(self) -> None: def test_from_session_id(self) -> None: """Test creating DurableAgentSession from session ID.""" session_id = AgentSessionId(name="TestAgent", key="test-key") - session = DurableAgentSession.from_session_id(session_id) + session = DurableAgentSession(durable_session_id=session_id) assert isinstance(session, DurableAgentSession) assert session.durable_session_id is not None @@ -161,10 +163,10 @@ def test_from_session_id(self) -> None: assert session.durable_session_id.name == "TestAgent" assert session.durable_session_id.key == "test-key" - def test_from_session_id_with_service_session_id(self) -> None: - """Test creating DurableAgentSession with service session ID.""" + def test_init_with_service_session_id(self) -> None: + """Test creating DurableAgentSession with explicit service session ID.""" session_id = AgentSessionId(name="TestAgent", key="test-key") - session = DurableAgentSession.from_session_id(session_id, service_session_id="service-123") + session = DurableAgentSession(durable_session_id=session_id, service_session_id="service-123") assert session.durable_session_id is not None assert session.durable_session_id == session_id @@ -192,7 +194,7 @@ def test_to_dict_without_durable_session_id(self) -> None: def test_from_dict_with_durable_session_id(self) -> None: """Test deserialization restores durable session ID.""" - serialized = { + serialized: dict[str, Any] = { "type": "session", "session_id": "session-123", "service_session_id": "service-123", @@ -210,7 +212,7 @@ def test_from_dict_with_durable_session_id(self) -> None: def test_from_dict_without_durable_session_id(self) -> None: """Test deserialization without durable session ID.""" - serialized = { + serialized: dict[str, Any] = { "type": "session", "session_id": "session-456", "service_session_id": "service-456", diff --git a/python/packages/durabletask/tests/test_client.py b/python/packages/durabletask/tests/test_client.py index 0acdfb2f9c6..a056d4e2549 100644 --- a/python/packages/durabletask/tests/test_client.py +++ b/python/packages/durabletask/tests/test_client.py @@ -88,15 +88,6 @@ def test_client_agent_can_create_sessions(self, agent_client: DurableAIAgentClie assert isinstance(session, DurableAgentSession) - def test_client_agent_session_with_parameters(self, agent_client: DurableAIAgentClient) -> None: - """Verify agent can create sessions with custom parameters.""" - agent = agent_client.get_agent("assistant") - - session = agent.create_session(service_session_id="client-session-123") - - assert isinstance(session, DurableAgentSession) - assert session.service_session_id == "client-session-123" - class TestDurableAIAgentClientPollingConfiguration: """Test polling configuration parameters for DurableAIAgentClient.""" diff --git a/python/packages/durabletask/tests/test_orchestration_context.py b/python/packages/durabletask/tests/test_orchestration_context.py index 033c274c88b..9f7cde156ca 100644 --- a/python/packages/durabletask/tests/test_orchestration_context.py +++ b/python/packages/durabletask/tests/test_orchestration_context.py @@ -82,17 +82,6 @@ def test_orchestration_agent_can_create_sessions(self, agent_context: DurableAIA assert isinstance(session, DurableAgentSession) - def test_orchestration_agent_session_with_parameters( - self, agent_context: DurableAIAgentOrchestrationContext - ) -> None: - """Verify agent can create sessions with custom parameters.""" - agent = agent_context.get_agent("assistant") - - session = agent.create_session(service_session_id="orch-session-456") - - assert isinstance(session, DurableAgentSession) - assert session.service_session_id == "orch-session-456" - if __name__ == "__main__": pytest.main([__file__, "-v", "--tb=short"]) diff --git a/python/packages/durabletask/tests/test_shim.py b/python/packages/durabletask/tests/test_shim.py index 423f587871a..687a0746a70 100644 --- a/python/packages/durabletask/tests/test_shim.py +++ b/python/packages/durabletask/tests/test_shim.py @@ -184,16 +184,31 @@ def test_create_session_delegates_to_executor(self, test_agent: DurableAIAgent[A mock_executor.get_new_session.assert_called_once_with("test_agent") assert session == mock_session - def test_create_session_forwards_kwargs(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: - """Verify create_session forwards kwargs to executor.""" - mock_session = DurableAgentSession(service_session_id="session-123") + def test_get_session_forwards_service_session_id( + self, test_agent: DurableAIAgent[Any], mock_executor: Mock + ) -> None: + """Verify get_session forwards service_session_id and session_id to executor.""" + mock_session = DurableAgentSession(service_session_id="svc-123") mock_executor.get_new_session.return_value = mock_session - test_agent.create_session(service_session_id="session-123") + session = test_agent.get_session("svc-123", session_id="local-456") - mock_executor.get_new_session.assert_called_once() - _, kwargs = mock_executor.get_new_session.call_args - assert kwargs["service_session_id"] == "session-123" + mock_executor.get_new_session.assert_called_once_with( + "test_agent", service_session_id="svc-123", session_id="local-456" + ) + assert session.service_session_id == "svc-123" + + def test_get_session_without_session_id(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None: + """Verify get_session works with only service_session_id (session_id defaults to None).""" + mock_session = DurableAgentSession(service_session_id="svc-789") + mock_executor.get_new_session.return_value = mock_session + + session = test_agent.get_session("svc-789") + + mock_executor.get_new_session.assert_called_once_with( + "test_agent", service_session_id="svc-789", session_id=None + ) + assert session.service_session_id == "svc-789" class TestDurableAgentProviderInterface: diff --git a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py index 16451ae85a2..4c1e64cd7cf 100644 --- a/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py +++ b/python/packages/foundry_local/agent_framework_foundry_local/_foundry_local_client.py @@ -146,11 +146,11 @@ def __init__( timeout: float | None = None, prepare_model: bool = True, device: DeviceType | None = None, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str = "utf-8", - **kwargs: Any, ) -> None: """Initialize a FoundryLocalClient. @@ -169,12 +169,11 @@ def __init__( The device is used to select the appropriate model variant. If not provided, the default device for your system will be used. The values are in the foundry_local.models.DeviceType enum. + additional_properties: Additional properties stored on the client instance. middleware: Optional sequence of ChatAndFunctionMiddlewareTypes to apply to requests. function_invocation_configuration: Optional configuration for function invocation support. env_file_path: If provided, the .env settings are read from this file path location. env_file_encoding: The encoding of the .env file, defaults to 'utf-8'. - kwargs: Additional keyword arguments, are passed to the RawOpenAIChatClient. - This can include middleware and additional properties. Examples: @@ -271,8 +270,8 @@ class MyOptions(FoundryLocalChatOptions, total=False): super().__init__( model_id=model_info.id, client=AsyncOpenAI(base_url=manager.endpoint, api_key=manager.api_key), + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) self.manager = manager diff --git a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py index 7fa7d0dce41..f8340b1bce0 100644 --- a/python/packages/github_copilot/agent_framework_github_copilot/_agent.py +++ b/python/packages/github_copilot/agent_framework_github_copilot/_agent.py @@ -303,7 +303,6 @@ def run( stream: Literal[False] = False, session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> Awaitable[AgentResponse]: ... @overload @@ -314,7 +313,6 @@ def run( stream: Literal[True], session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ... def run( @@ -324,7 +322,6 @@ def run( stream: bool = False, session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: """Get a response from the agent. @@ -339,7 +336,6 @@ def run( stream: Whether to stream the response. Defaults to False. session: The conversation session associated with the message(s). options: Runtime options (model, timeout, etc.). - kwargs: Additional keyword arguments. Returns: When stream=False: An Awaitable[AgentResponse]. @@ -354,10 +350,10 @@ def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse: return AgentResponse.from_updates(updates) return ResponseStream( - self._stream_updates(messages=messages, session=session, options=options, **kwargs), + self._stream_updates(messages=messages, session=session, options=options), finalizer=_finalize, ) - return self._run_impl(messages=messages, session=session, options=options, **kwargs) + return self._run_impl(messages=messages, session=session, options=options) async def _run_impl( self, @@ -365,7 +361,6 @@ async def _run_impl( *, session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> AgentResponse: """Non-streaming implementation of run.""" if not self._started: @@ -414,7 +409,6 @@ async def _stream_updates( *, session: AgentSession | None = None, options: OptionsT | None = None, - **kwargs: Any, ) -> AsyncIterable[AgentResponseUpdate]: """Internal method to stream updates from GitHub Copilot. @@ -424,7 +418,6 @@ async def _stream_updates( Keyword Args: session: The conversation session associated with the message(s). options: Runtime options (model, timeout, etc.). - kwargs: Additional keyword arguments. Yields: AgentResponseUpdate items. @@ -535,8 +528,15 @@ async def handler(invocation: ToolInvocation) -> ToolResult: result = await ai_func.invoke(arguments=args_instance) else: result = await ai_func.invoke(arguments=args) + rich = [c for c in result if c.type in ("data", "uri")] + if rich: + logger.warning( + "GitHub Copilot does not support rich tool content; " + f"dropping {len(rich)} non-text item(s) from '{ai_func.name}'." + ) + text = "\n".join(c.text for c in result if c.type == "text" and c.text) return ToolResult( - text_result_for_llm=str(result), + text_result_for_llm=text or str(result), result_type="success", ) except Exception as e: diff --git a/python/packages/ollama/agent_framework_ollama/_chat_client.py b/python/packages/ollama/agent_framework_ollama/_chat_client.py index e31c1971da3..b931c894998 100644 --- a/python/packages/ollama/agent_framework_ollama/_chat_client.py +++ b/python/packages/ollama/agent_framework_ollama/_chat_client.py @@ -300,11 +300,11 @@ def __init__( host: str | None = None, client: AsyncClient | None = None, model_id: str | None = None, + additional_properties: dict[str, Any] | None = None, middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None, function_invocation_configuration: FunctionInvocationConfiguration | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Ollama Chat client. @@ -313,11 +313,11 @@ def __init__( Can be set via the OLLAMA_HOST env variable. client: An optional Ollama Client instance. If not provided, a new instance will be created. model_id: The Ollama chat model ID to use. Can be set via the OLLAMA_MODEL_ID env variable. + additional_properties: Additional properties stored on the client instance. middleware: Optional middleware to apply to the client. function_invocation_configuration: Optional function invocation configuration override. env_file_path: An optional path to a dotenv (.env) file to load environment variables from. env_file_encoding: The encoding to use when reading the dotenv (.env) file. Defaults to 'utf-8'. - **kwargs: Additional keyword arguments passed to BaseChatClient. """ ollama_settings = load_settings( OllamaSettings, @@ -336,9 +336,9 @@ def __init__( self.host = str(self.client._client.base_url) # type: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType] super().__init__( + additional_properties=additional_properties, middleware=middleware, function_invocation_configuration=function_invocation_configuration, - **kwargs, ) self.middleware = list(self.chat_middleware) @@ -500,11 +500,22 @@ def _format_assistant_message(self, message: Message) -> list[OllamaMessage]: def _format_tool_message(self, message: Message) -> list[OllamaMessage]: # Ollama does not support multiple tool results in a single message, so we create a separate - return [ - OllamaMessage(role="tool", content=str(item.result), tool_name=item.call_id) - for item in message.contents - if item.type == "function_result" - ] + messages: list[OllamaMessage] = [] + for item in message.contents: + if item.type == "function_result": + if item.items: + text_parts = [c.text or "" for c in item.items if c.type == "text"] + rich_items = [c for c in item.items if c.type in ("data", "uri")] + if rich_items: + logger.warning( + "Ollama does not support rich content (images, audio) in tool results. " + "Rich content items will be omitted." + ) + tool_text = "\n".join(text_parts) if text_parts else "" + else: + tool_text = str(item.result) if item.result is not None else "" + messages.append(OllamaMessage(role="tool", content=tool_text, tool_name=item.call_id)) + return messages def _parse_contents_from_ollama(self, response: OllamaChatResponse) -> list[Content]: contents: list[Content] = [] diff --git a/python/packages/ollama/agent_framework_ollama/_embedding_client.py b/python/packages/ollama/agent_framework_ollama/_embedding_client.py index 5cd35fc9f31..8e0508c708c 100644 --- a/python/packages/ollama/agent_framework_ollama/_embedding_client.py +++ b/python/packages/ollama/agent_framework_ollama/_embedding_client.py @@ -92,9 +92,9 @@ def __init__( model_id: str | None = None, host: str | None = None, client: AsyncClient | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize a raw Ollama embedding client.""" ollama_settings = load_settings( @@ -110,7 +110,7 @@ def __init__( self.model_id = ollama_settings["embedding_model_id"] # type: ignore[assignment,reportTypedDictNotRequiredAccess] self.client = client or AsyncClient(host=ollama_settings.get("host")) self.host = str(self.client._client.base_url) # type: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType] - super().__init__(**kwargs) + super().__init__(additional_properties=additional_properties) def service_url(self) -> str: """Get the URL of the service.""" @@ -214,17 +214,17 @@ def __init__( host: str | None = None, client: AsyncClient | None = None, otel_provider_name: str | None = None, + additional_properties: dict[str, Any] | None = None, env_file_path: str | None = None, env_file_encoding: str | None = None, - **kwargs: Any, ) -> None: """Initialize an Ollama embedding client.""" super().__init__( model_id=model_id, host=host, client=client, + additional_properties=additional_properties, otel_provider_name=otel_provider_name, env_file_path=env_file_path, env_file_encoding=env_file_encoding, - **kwargs, ) diff --git a/python/packages/redis/agent_framework_redis/_history_provider.py b/python/packages/redis/agent_framework_redis/_history_provider.py index e1a20b62180..be2db098b83 100644 --- a/python/packages/redis/agent_framework_redis/_history_provider.py +++ b/python/packages/redis/agent_framework_redis/_history_provider.py @@ -107,11 +107,18 @@ def _redis_key(self, session_id: str | None) -> str: """Get the Redis key for a given session's messages.""" return f"{self.key_prefix}:{session_id or 'default'}" - async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]: + async def get_messages( + self, + session_id: str | None, + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> list[Message]: """Retrieve stored messages for this session from Redis. Args: session_id: The session ID to retrieve messages for. + state: Optional session state. Unused for Redis-backed history. **kwargs: Additional arguments (unused). Returns: @@ -125,12 +132,20 @@ async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Mess messages.append(Message.from_dict(self._deserialize_json(serialized))) # type: ignore[union-attr] return messages - async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None: + async def save_messages( + self, + session_id: str | None, + messages: Sequence[Message], + *, + state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> None: """Persist messages for this session to Redis. Args: session_id: The session ID to store messages for. messages: The messages to persist. + state: Optional session state. Unused for Redis-backed history. **kwargs: Additional arguments (unused). """ if not messages: diff --git a/python/samples/02-agents/skills/script_approval/script_approval.py b/python/samples/02-agents/skills/script_approval/script_approval.py index 701d88de065..b1613ef28f0 100644 --- a/python/samples/02-agents/skills/script_approval/script_approval.py +++ b/python/samples/02-agents/skills/script_approval/script_approval.py @@ -90,7 +90,7 @@ async def main() -> None: # maintained automatically — just send the approval response) while result.user_input_requests: for request in result.user_input_requests: - print(f"\nApproval needed:") + print("\nApproval needed:") print(f" Function: {request.function_call.name}") # type: ignore[union-attr] print(f" Arguments: {request.function_call.arguments}") # type: ignore[union-attr] diff --git a/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py b/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py index 33748437e0d..fa78a9ede5c 100644 --- a/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py +++ b/python/samples/02-agents/tools/agent_as_tool_with_session_propagation.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import Awaitable, Callable -from agent_framework import AgentContext, AgentSession +from agent_framework import AgentContext, AgentSession, FunctionInvocationContext, tool from agent_framework.openai import OpenAIResponsesClient from dotenv import load_dotenv @@ -18,9 +18,6 @@ When session propagation is enabled, both agents share the same session object, including session_id and the mutable state dict. This allows correlated conversation tracking and shared state across the agent hierarchy. - -The middleware functions below are purely for observability — they are NOT -required for session propagation to work. """ @@ -28,65 +25,83 @@ async def log_session( context: AgentContext, call_next: Callable[[], Awaitable[None]], ) -> None: - """Agent middleware that logs the session received by each agent. - - NOT required for session propagation — only used to observe the flow. - If propagation is working, both agents will show the same session_id. - """ + """Agent middleware that logs the session received by each agent.""" session: AgentSession | None = context.session + if not session: + print("No session found.") + await call_next() + return agent_name = context.agent.name or "unknown" - session_id = session.session_id if session else None - state = dict(session.state) if session else {} - print(f" [{agent_name}] session_id={session_id}, state={state}") + print( + f" [{agent_name}] session_id={session.session_id}, " + f"service_session_id={session.service_session_id} state={session.state}" + ) await call_next() +@tool(description="Use this tool to store the findings so that other agents can reason over them.") +def store_findings(findings: str, ctx: FunctionInvocationContext) -> None: + if ctx.session is None: + return + current_findings = ctx.session.state.get("findings") + if current_findings is None: + ctx.session.state["findings"] = findings + else: + ctx.session.state["findings"] = f"{current_findings}\n{findings}" + + +@tool(description="Use this tool to gather the current findings from other agents.") +def recall_findings(ctx: FunctionInvocationContext) -> str: + if ctx.session is None: + return "No session available" + current_findings = ctx.session.state.get("findings") + if current_findings is None: + return "Nothing yet" + return current_findings + + async def main() -> None: print("=== Agent-as-Tool: Session Propagation ===\n") client = OpenAIResponsesClient() - # --- Sub-agent: a research specialist --- - # The sub-agent has the same log_session middleware to prove it receives the session. research_agent = client.as_agent( name="ResearchAgent", - instructions="You are a research assistant. Provide concise answers.", + instructions="You are a research assistant. Provide concise answers and store your findings.", middleware=[log_session], + tools=[store_findings, recall_findings], ) - # propagate_session=True: the coordinator's session will be forwarded research_tool = research_agent.as_tool( name="research", - description="Research a topic and return findings", + description="Research a topic and store your findings.", arg_name="query", arg_description="The research query", propagate_session=True, ) - # --- Coordinator agent --- coordinator = client.as_agent( name="CoordinatorAgent", - instructions="You coordinate research. Use the 'research' tool to look up information.", - tools=[research_tool], + instructions=( + "You coordinate research. Use the 'research' tool to start research " + "and then use the recall findings tool to gather up everything." + ), + tools=[research_tool, store_findings, recall_findings], middleware=[log_session], ) - # Create a shared session and put some state in it session = coordinator.create_session() - session.state["request_source"] = "demo" + session.state["findings"] = None print(f"Session ID: {session.session_id}") print(f"Session state before run: {session.state}\n") - query = "What are the latest developments in quantum computing?" + query = "What are the latest developments in quantum computing and in AI?" print(f"User: {query}\n") result = await coordinator.run(query, session=session) print(f"\nCoordinator: {result}\n") print(f"Session state after run: {session.state}") - print( - "\nIf both agents show the same session_id above, session propagation is working." - ) if __name__ == "__main__": diff --git a/python/samples/02-agents/tools/function_tool_with_kwargs.py b/python/samples/02-agents/tools/function_tool_with_kwargs.py index 249ebc4a338..61db84eb170 100644 --- a/python/samples/02-agents/tools/function_tool_with_kwargs.py +++ b/python/samples/02-agents/tools/function_tool_with_kwargs.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from typing import Annotated, Any +from typing import Annotated -from agent_framework import tool +from agent_framework import FunctionInvocationContext, tool from agent_framework.openai import OpenAIResponsesClient from dotenv import load_dotenv from pydantic import Field @@ -14,27 +14,27 @@ """ AI Function with kwargs Example -This example demonstrates how to inject custom keyword arguments (kwargs) into an AI function -from the agent's run method, without exposing them to the AI model. +This example demonstrates how to inject runtime context into an AI function +from the agent's run method, without exposing it to the AI model. This is useful for passing runtime information like access tokens, user IDs, or request-specific context that the tool needs but the model shouldn't know about -or provide. +or provide. The injected context parameter can be typed as +``FunctionInvocationContext`` as shown here, or left untyped as ``ctx`` when you +prefer a lighter-weight sample setup. """ -# Define the function tool with **kwargs to accept injected arguments -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; -# see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. +# Define the function tool with explicit invocation context. +# The context parameter can also be declared as an untyped ``ctx`` parameter. @tool(approval_mode="never_require") def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], - **kwargs: Any, + ctx: FunctionInvocationContext, ) -> str: """Get the weather for a given location.""" - # Extract the injected argument from kwargs - user_id = kwargs.get("user_id", "unknown") + # Extract the injected argument from the explicit context + user_id = ctx.kwargs.get("user_id", "unknown") # Simulate using the user_id for logging or personalization print(f"Getting weather for user: {user_id}") @@ -49,9 +49,11 @@ async def main() -> None: tools=[get_weather], ) - # Pass the injected argument when running the agent - # The 'user_id' kwarg will be passed down to the tool execution via **kwargs - response = await agent.run("What is the weather like in Amsterdam?", user_id="user_123") + # Pass the runtime context explicitly when running the agent. + response = await agent.run( + "What is the weather like in Amsterdam?", + function_invocation_kwargs={"user_id": "user_123"}, + ) print(f"Agent: {response.text}") diff --git a/python/samples/02-agents/tools/function_tool_with_session_injection.py b/python/samples/02-agents/tools/function_tool_with_session_injection.py index 2689ff5f9c5..53cc63c2c0e 100644 --- a/python/samples/02-agents/tools/function_tool_with_session_injection.py +++ b/python/samples/02-agents/tools/function_tool_with_session_injection.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. import asyncio -from typing import Annotated, Any +from typing import Annotated -from agent_framework import AgentSession, tool +from agent_framework import AgentSession, FunctionInvocationContext, tool from agent_framework.openai import OpenAIResponsesClient from dotenv import load_dotenv from pydantic import Field @@ -14,23 +14,21 @@ """ AI Function with Session Injection Example -This example demonstrates the behavior when passing 'session' to agent.run() -and accessing that session in AI function. +This example demonstrates accessing the agent session inside a tool function +via ``FunctionInvocationContext.session``. The session is automatically +available when the agent is invoked with a session. """ -# Define the function tool with **kwargs -# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; -# see samples/02-agents/tools/function_tool_with_approval.py -# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py. +# Define the function tool with explicit invocation context. +# The context parameter can also be declared as an untyped parameter with the name: ``ctx``. @tool(approval_mode="never_require") async def get_weather( location: Annotated[str, Field(description="The location to get the weather for.")], - **kwargs: Any, + ctx: FunctionInvocationContext, ) -> str: """Get the weather for a given location.""" - # Get session object from kwargs - session = kwargs.get("session") + session = ctx.session if session and isinstance(session, AgentSession) and session.service_session_id: print(f"Session ID: {session.service_session_id}.") @@ -42,17 +40,19 @@ async def main() -> None: name="WeatherAgent", instructions="You are a helpful weather assistant.", tools=[get_weather], - options={"store": True}, + default_options={"store": True}, ) # Create a session session = agent.create_session() - # Run the agent with the session - # Pass session via additional_function_arguments so tools can access it via **kwargs - opts = {"additional_function_arguments": {"session": session}} - print(f"Agent: {await agent.run('What is the weather in London?', session=session, options=opts)}") - print(f"Agent: {await agent.run('What is the weather in Amsterdam?', session=session, options=opts)}") + # Run the agent with the session; tools receive it via ctx.session. + print( + f"Agent: {await agent.run('What is the weather in London?', session=session)}" + ) + print( + f"Agent: {await agent.run('What is the weather in Amsterdam?', session=session)}" + ) print(f"Agent: {await agent.run('What cities did I ask about?', session=session)}")