Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/agents/extensions/handoff_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
MCPListToolsItem,
ReasoningItem,
RunItem,
ToolApprovalItem,
ToolCallItem,
ToolCallOutputItem,
ToolSearchCallItem,
Expand Down Expand Up @@ -63,6 +64,7 @@ def _remove_tools_from_items(items: tuple[RunItem, ...]) -> tuple[RunItem, ...]:
or isinstance(item, MCPListToolsItem)
or isinstance(item, MCPApprovalRequestItem)
or isinstance(item, MCPApprovalResponseItem)
or isinstance(item, ToolApprovalItem)
):
continue
filtered_items.append(item)
Expand All @@ -86,6 +88,14 @@ def _remove_tool_types_from_input(
"mcp_approval_request",
"mcp_approval_response",
"reasoning",
"code_interpreter_call",
"image_generation_call",
"local_shell_call",
"local_shell_call_output",
"shell_call",
"shell_call_output",
"apply_patch_call",
"apply_patch_call_output",
]

filtered_items: list[TResponseInputItem] = []
Expand Down
55 changes: 55 additions & 0 deletions tests/test_extension_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
MCPListToolsItem,
MessageOutputItem,
ReasoningItem,
ToolApprovalItem,
ToolCallItem,
ToolCallOutputItem,
ToolSearchCallItem,
Expand Down Expand Up @@ -1015,3 +1016,57 @@ def test_removes_mixed_mcp_and_function_items() -> None:
assert len(filtered_data.input_history) == 2
assert len(filtered_data.pre_handoff_items) == 1
assert len(filtered_data.new_items) == 1


def _get_hosted_tool_input_item(type_name: str) -> TResponseInputItem:
return cast(TResponseInputItem, {"id": "ht1", "type": type_name})


def _get_tool_approval_run_item() -> ToolApprovalItem:
return ToolApprovalItem(
agent=fake_agent(),
raw_item={"type": "function_call", "call_id": "c1", "name": "fn", "arguments": "{}"},
tool_name="fn",
)


def test_removes_hosted_tool_types_from_input_history() -> None:
"""Hosted tool types in raw input history should be removed by remove_all_tools."""
hosted_types = [
"code_interpreter_call",
"image_generation_call",
"local_shell_call",
"local_shell_call_output",
"shell_call",
"shell_call_output",
"apply_patch_call",
"apply_patch_call_output",
]
input_items: list[TResponseInputItem] = [_get_message_input_item("Hello")]
for t in hosted_types:
input_items.append(_get_hosted_tool_input_item(t))
input_items.append(_get_message_input_item("World"))

handoff_input_data = handoff_data(input_history=tuple(input_items))
filtered_data = remove_all_tools(handoff_input_data)
assert len(filtered_data.input_history) == 2
for item in filtered_data.input_history:
assert not isinstance(item, str)
assert item.get("type") not in set(hosted_types)


def test_removes_tool_approval_from_new_items() -> None:
"""ToolApprovalItem should be removed from new_items and pre_handoff_items."""
handoff_input_data = handoff_data(
pre_handoff_items=(
_get_tool_approval_run_item(),
_get_message_output_run_item("kept"),
),
new_items=(
_get_tool_approval_run_item(),
_get_message_output_run_item("also kept"),
),
)
filtered_data = remove_all_tools(handoff_input_data)
assert len(filtered_data.pre_handoff_items) == 1
assert len(filtered_data.new_items) == 1