Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,24 @@ async def dispatch_stream_events() -> None:
if custom_output_extractor:
return await custom_output_extractor(run_result)

if run_result.final_output is not None and (
Comment thread
elainegan-openai marked this conversation as resolved.
not isinstance(run_result.final_output, str) or run_result.final_output != ""
):
return run_result.final_output

from .items import ItemHelpers

text_output = ItemHelpers.text_message_outputs(run_result.new_items)
if text_output:
return text_output

for item in reversed(run_result.to_input_list()):
if item.get("type") != "function_call_output":
continue
output = item.get("output")
if isinstance(output, str) and output:
return output

return run_result.final_output

run_agent_tool = _build_wrapped_function_tool(
Expand Down
15 changes: 8 additions & 7 deletions src/agents/run_internal/tool_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,9 +1400,10 @@ async def _drain_cancelled_tasks(
self,
tasks: set[asyncio.Task[Any]],
) -> tuple[_FunctionToolFailure | None, set[asyncio.Task[Any]]]:
late_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = {
task: "cancelled_teardown" for task in tasks
}
late_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = dict.fromkeys(
tasks,
"cancelled_teardown",
)
return await _drain_cancelled_function_tool_tasks(
pending_tasks=tasks,
task_states=self.task_states,
Expand All @@ -1415,9 +1416,9 @@ async def _wait_post_invoke_tasks(
self,
tasks: set[asyncio.Task[Any]],
) -> tuple[_FunctionToolFailure | None, set[asyncio.Task[Any]]]:
post_invoke_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = {
task: "post_invoke" for task in tasks
}
post_invoke_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = (
dict.fromkeys(tasks, "post_invoke")
)
return await _wait_pending_function_tool_tasks_for_timeout(
pending_tasks=tasks,
task_states=self.task_states,
Expand Down Expand Up @@ -1638,7 +1639,7 @@ async def _invoke_tool_and_run_post_invoke(
arguments=tool_call.arguments,
)
except asyncio.CancelledError as e:
if not self.isolate_parallel_failures or outer_task in self.teardown_cancelled_tasks:
if outer_task in self.teardown_cancelled_tasks:
raise

result = await maybe_invoke_function_tool_failure_error_function(
Expand Down
39 changes: 39 additions & 0 deletions tests/test_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,45 @@ async def _cancel_tool() -> str:
]


@pytest.mark.asyncio
async def test_single_tool_call_with_cancelled_tool_reaches_final_output() -> None:
async def _cancel_tool() -> str:
raise asyncio.CancelledError("tool-cancelled")

model = FakeModel()
agent = Agent(
name="test",
model=model,
tools=[function_tool(_cancel_tool, name_override="cancel_tool")],
)

model.add_multiple_turn_outputs(
[
[get_function_tool_call("cancel_tool", "{}", call_id="call_cancel")],
[get_text_message("final answer")],
]
)

result = await Runner.run(agent, input="user_message")

assert result.final_output == "final answer"
assert len(result.raw_responses) == 2

second_turn_input = cast(list[dict[str, Any]], model.last_turn_args["input"])
tool_outputs = [
item for item in second_turn_input if item.get("type") == "function_call_output"
]
assert tool_outputs == [
{
"call_id": "call_cancel",
"output": (
"An error occurred while running the tool. Please try again. Error: tool-cancelled"
),
"type": "function_call_output",
},
]


@pytest.mark.asyncio
async def test_reasoning_item_id_policy_omits_follow_up_reasoning_ids() -> None:
model = FakeModel()
Expand Down
40 changes: 40 additions & 0 deletions tests/test_agent_runner_streamed.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,46 @@ async def _cancel_tool() -> str:
]


@pytest.mark.asyncio
async def test_streamed_single_tool_call_with_cancelled_tool_reaches_final_output() -> None:
async def _cancel_tool() -> str:
raise asyncio.CancelledError("tool-cancelled")

model = FakeModel()
agent = Agent(
name="test",
model=model,
tools=[function_tool(_cancel_tool, name_override="cancel_tool")],
)

model.add_multiple_turn_outputs(
[
[get_function_tool_call("cancel_tool", "{}", call_id="call_cancel")],
[get_text_message("final answer")],
]
)

result = Runner.run_streamed(agent, input="user_message")
await consume_stream(result)

assert result.final_output == "final answer"
assert len(result.raw_responses) == 2

second_turn_input = cast(list[dict[str, Any]], model.last_turn_args["input"])
tool_outputs = [
item for item in second_turn_input if item.get("type") == "function_call_output"
]
assert tool_outputs == [
{
"call_id": "call_cancel",
"output": (
"An error occurred while running the tool. Please try again. Error: tool-cancelled"
),
"type": "function_call_output",
},
]


@pytest.mark.asyncio
async def test_streamed_reasoning_item_id_policy_omits_follow_up_reasoning_ids() -> None:
model = FakeModel()
Expand Down
23 changes: 23 additions & 0 deletions tests/test_run_step_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,29 @@ async def _manual_on_invoke_tool(_ctx: ToolContext[Any], _args: str) -> str:
)


@pytest.mark.asyncio
async def test_single_tool_call_uses_default_failure_error_function_for_cancelled_tool():
async def _cancel_tool() -> str:
raise asyncio.CancelledError("tool-cancelled")

cancel_tool = function_tool(_cancel_tool, name_override="cancel_tool")
agent = Agent(name="test", tools=[cancel_tool])
response = ModelResponse(
output=[get_function_tool_call("cancel_tool", "{}", call_id="1")],
usage=Usage(),
response_id=None,
)

result = await get_execute_result(agent, response)

assert len(result.generated_items) == 2
assert isinstance(result.next_step, NextStepRunAgain)
assert_item_is_function_tool_call_output(
result.generated_items[1],
"An error occurred while running the tool. Please try again. Error: tool-cancelled",
)


@pytest.mark.asyncio
async def test_multiple_tool_calls_surface_hook_failure_over_sibling_cancellation():
hook_started = asyncio.Event()
Expand Down
Loading