diff --git a/src/agents/extensions/handoff_filters.py b/src/agents/extensions/handoff_filters.py index 663bda4cc8..3690f262a9 100644 --- a/src/agents/extensions/handoff_filters.py +++ b/src/agents/extensions/handoff_filters.py @@ -13,6 +13,7 @@ MCPListToolsItem, ReasoningItem, RunItem, + ToolApprovalItem, ToolCallItem, ToolCallOutputItem, ToolSearchCallItem, @@ -63,6 +64,7 @@ def _remove_tools_from_items(items: tuple[RunItem, ...]) -> tuple[RunItem, ...]: or isinstance(item, MCPListToolsItem) or isinstance(item, MCPApprovalRequestItem) or isinstance(item, MCPApprovalResponseItem) + or isinstance(item, ToolApprovalItem) ): continue filtered_items.append(item) @@ -86,6 +88,14 @@ def _remove_tool_types_from_input( "mcp_approval_request", "mcp_approval_response", "reasoning", + "code_interpreter_call", + "image_generation_call", + "local_shell_call", + "local_shell_call_output", + "shell_call", + "shell_call_output", + "apply_patch_call", + "apply_patch_call_output", ] filtered_items: list[TResponseInputItem] = [] diff --git a/tests/test_extension_filters.py b/tests/test_extension_filters.py index 30299883a2..97924d2852 100644 --- a/tests/test_extension_filters.py +++ b/tests/test_extension_filters.py @@ -24,6 +24,7 @@ MCPListToolsItem, MessageOutputItem, ReasoningItem, + ToolApprovalItem, ToolCallItem, ToolCallOutputItem, ToolSearchCallItem, @@ -1015,3 +1016,57 @@ def test_removes_mixed_mcp_and_function_items() -> None: assert len(filtered_data.input_history) == 2 assert len(filtered_data.pre_handoff_items) == 1 assert len(filtered_data.new_items) == 1 + + +def _get_hosted_tool_input_item(type_name: str) -> TResponseInputItem: + return cast(TResponseInputItem, {"id": "ht1", "type": type_name}) + + +def _get_tool_approval_run_item() -> ToolApprovalItem: + return ToolApprovalItem( + agent=fake_agent(), + raw_item={"type": "function_call", "call_id": "c1", "name": "fn", "arguments": "{}"}, + tool_name="fn", + ) + + +def test_removes_hosted_tool_types_from_input_history() -> None: + """Hosted tool types in raw input history should be removed by remove_all_tools.""" + hosted_types = [ + "code_interpreter_call", + "image_generation_call", + "local_shell_call", + "local_shell_call_output", + "shell_call", + "shell_call_output", + "apply_patch_call", + "apply_patch_call_output", + ] + input_items: list[TResponseInputItem] = [_get_message_input_item("Hello")] + for t in hosted_types: + input_items.append(_get_hosted_tool_input_item(t)) + input_items.append(_get_message_input_item("World")) + + handoff_input_data = handoff_data(input_history=tuple(input_items)) + filtered_data = remove_all_tools(handoff_input_data) + assert len(filtered_data.input_history) == 2 + for item in filtered_data.input_history: + assert not isinstance(item, str) + assert item.get("type") not in set(hosted_types) + + +def test_removes_tool_approval_from_new_items() -> None: + """ToolApprovalItem should be removed from new_items and pre_handoff_items.""" + handoff_input_data = handoff_data( + pre_handoff_items=( + _get_tool_approval_run_item(), + _get_message_output_run_item("kept"), + ), + new_items=( + _get_tool_approval_run_item(), + _get_message_output_run_item("also kept"), + ), + ) + filtered_data = remove_all_tools(handoff_input_data) + assert len(filtered_data.pre_handoff_items) == 1 + assert len(filtered_data.new_items) == 1