|
| 1 | +import asyncio |
1 | 2 | from typing import Any, cast |
2 | 3 |
|
3 | 4 | import pytest |
4 | 5 | from openai.types.responses.response_output_item import McpCall, McpListTools, McpListToolsTool |
5 | 6 |
|
6 | 7 | from agents import Agent, HostedMCPTool |
| 8 | +from agents.handoffs import handoff |
7 | 9 | from agents.items import MCPListToolsItem, ModelResponse, RunItem, ToolCallItem, TResponseInputItem |
8 | 10 | from agents.lifecycle import RunHooks |
9 | 11 | from agents.models.fake_id import FAKE_RESPONSES_ID |
|
12 | 14 | from agents.run_context import RunContextWrapper |
13 | 15 | from agents.run_internal.oai_conversation import OpenAIServerConversationTracker |
14 | 16 | from agents.run_internal.run_loop import get_new_response, run_single_turn_streamed |
| 17 | +from agents.run_internal.run_steps import NextStepHandoff |
15 | 18 | from agents.run_internal.tool_use_tracker import AgentToolUseTracker |
16 | 19 | from agents.stream_events import RunItemStreamEvent |
17 | 20 | from agents.usage import Usage |
18 | 21 |
|
19 | 22 | from .fake_model import FakeModel |
20 | | -from .test_responses import get_text_message |
| 23 | +from .test_responses import get_handoff_tool_call, get_text_message |
21 | 24 |
|
22 | 25 |
|
23 | 26 | class DummyRunItem: |
@@ -805,3 +808,102 @@ def _filter_input(payload: Any) -> ModelInputData: |
805 | 808 | assert len(tool_call_events) == 1 |
806 | 809 | assert tool_call_events[0].description == "Search the docs." |
807 | 810 | assert tool_call_events[0].title == "Search Docs" |
| 811 | + |
| 812 | + |
| 813 | +@pytest.mark.asyncio |
| 814 | +async def test_run_single_turn_streamed_recovers_cancelled_queue_for_handoff( |
| 815 | + monkeypatch: pytest.MonkeyPatch, |
| 816 | +) -> None: |
| 817 | + model = FakeModel() |
| 818 | + target = Agent(name="target", model=FakeModel()) |
| 819 | + model.set_next_output([get_handoff_tool_call(target)]) |
| 820 | + agent = Agent(name="source", model=model, handoffs=[handoff(target)]) |
| 821 | + context_wrapper: RunContextWrapper[dict[str, Any]] = RunContextWrapper(context={}) |
| 822 | + tool_use_tracker = AgentToolUseTracker() |
| 823 | + streamed_result = RunResultStreaming( |
| 824 | + input=[cast(TResponseInputItem, {"role": "user", "content": "first"})], |
| 825 | + new_items=[], |
| 826 | + raw_responses=[], |
| 827 | + final_output=None, |
| 828 | + input_guardrail_results=[], |
| 829 | + output_guardrail_results=[], |
| 830 | + tool_input_guardrail_results=[], |
| 831 | + tool_output_guardrail_results=[], |
| 832 | + context_wrapper=context_wrapper, |
| 833 | + current_agent=agent, |
| 834 | + current_turn=1, |
| 835 | + max_turns=2, |
| 836 | + _current_agent_output_schema=None, |
| 837 | + trace=None, |
| 838 | + interruptions=[], |
| 839 | + ) |
| 840 | + |
| 841 | + def _raise_cancelled(*args: Any, **kwargs: Any) -> None: |
| 842 | + raise asyncio.CancelledError |
| 843 | + |
| 844 | + monkeypatch.setattr( |
| 845 | + "agents.run_internal.run_loop.stream_step_result_to_queue", |
| 846 | + _raise_cancelled, |
| 847 | + ) |
| 848 | + |
| 849 | + result = await run_single_turn_streamed( |
| 850 | + streamed_result, |
| 851 | + agent, |
| 852 | + RunHooks(), |
| 853 | + context_wrapper, |
| 854 | + RunConfig(), |
| 855 | + should_run_agent_start_hooks=False, |
| 856 | + tool_use_tracker=tool_use_tracker, |
| 857 | + all_tools=[], |
| 858 | + ) |
| 859 | + |
| 860 | + assert isinstance(result.next_step, NextStepHandoff) |
| 861 | + assert result.next_step.new_agent.name == "target" |
| 862 | + |
| 863 | + |
| 864 | +@pytest.mark.asyncio |
| 865 | +async def test_run_single_turn_streamed_propagates_cancelled_queue_without_handoff( |
| 866 | + monkeypatch: pytest.MonkeyPatch, |
| 867 | +) -> None: |
| 868 | + model = FakeModel() |
| 869 | + model.set_next_output([get_text_message("ok")]) |
| 870 | + agent = Agent(name="source", model=model) |
| 871 | + context_wrapper: RunContextWrapper[dict[str, Any]] = RunContextWrapper(context={}) |
| 872 | + tool_use_tracker = AgentToolUseTracker() |
| 873 | + streamed_result = RunResultStreaming( |
| 874 | + input=[cast(TResponseInputItem, {"role": "user", "content": "first"})], |
| 875 | + new_items=[], |
| 876 | + raw_responses=[], |
| 877 | + final_output=None, |
| 878 | + input_guardrail_results=[], |
| 879 | + output_guardrail_results=[], |
| 880 | + tool_input_guardrail_results=[], |
| 881 | + tool_output_guardrail_results=[], |
| 882 | + context_wrapper=context_wrapper, |
| 883 | + current_agent=agent, |
| 884 | + current_turn=1, |
| 885 | + max_turns=2, |
| 886 | + _current_agent_output_schema=None, |
| 887 | + trace=None, |
| 888 | + interruptions=[], |
| 889 | + ) |
| 890 | + |
| 891 | + def _raise_cancelled(*args: Any, **kwargs: Any) -> None: |
| 892 | + raise asyncio.CancelledError |
| 893 | + |
| 894 | + monkeypatch.setattr( |
| 895 | + "agents.run_internal.run_loop.stream_step_result_to_queue", |
| 896 | + _raise_cancelled, |
| 897 | + ) |
| 898 | + |
| 899 | + with pytest.raises(asyncio.CancelledError): |
| 900 | + await run_single_turn_streamed( |
| 901 | + streamed_result, |
| 902 | + agent, |
| 903 | + RunHooks(), |
| 904 | + context_wrapper, |
| 905 | + RunConfig(), |
| 906 | + should_run_agent_start_hooks=False, |
| 907 | + tool_use_tracker=tool_use_tracker, |
| 908 | + all_tools=[], |
| 909 | + ) |
0 commit comments