Skip to content

Commit 7abbed9

Browse files
Recover single cancelled function tools
Co-authored-by: Codex <noreply@openai.com>
1 parent 681320e commit 7abbed9

4 files changed

Lines changed: 110 additions & 7 deletions

File tree

src/agents/run_internal/tool_execution.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,9 +1400,10 @@ async def _drain_cancelled_tasks(
14001400
self,
14011401
tasks: set[asyncio.Task[Any]],
14021402
) -> tuple[_FunctionToolFailure | None, set[asyncio.Task[Any]]]:
1403-
late_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = {
1404-
task: "cancelled_teardown" for task in tasks
1405-
}
1403+
late_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = dict.fromkeys(
1404+
tasks,
1405+
"cancelled_teardown",
1406+
)
14061407
return await _drain_cancelled_function_tool_tasks(
14071408
pending_tasks=tasks,
14081409
task_states=self.task_states,
@@ -1415,9 +1416,9 @@ async def _wait_post_invoke_tasks(
14151416
self,
14161417
tasks: set[asyncio.Task[Any]],
14171418
) -> tuple[_FunctionToolFailure | None, set[asyncio.Task[Any]]]:
1418-
post_invoke_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = {
1419-
task: "post_invoke" for task in tasks
1420-
}
1419+
post_invoke_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = (
1420+
dict.fromkeys(tasks, "post_invoke")
1421+
)
14211422
return await _wait_pending_function_tool_tasks_for_timeout(
14221423
pending_tasks=tasks,
14231424
task_states=self.task_states,
@@ -1638,7 +1639,7 @@ async def _invoke_tool_and_run_post_invoke(
16381639
arguments=tool_call.arguments,
16391640
)
16401641
except asyncio.CancelledError as e:
1641-
if not self.isolate_parallel_failures or outer_task in self.teardown_cancelled_tasks:
1642+
if outer_task in self.teardown_cancelled_tasks:
16421643
raise
16431644

16441645
result = await maybe_invoke_function_tool_failure_error_function(

tests/test_agent_runner.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,45 @@ async def _cancel_tool() -> str:
641641
]
642642

643643

644+
@pytest.mark.asyncio
645+
async def test_single_tool_call_with_cancelled_tool_reaches_final_output() -> None:
646+
async def _cancel_tool() -> str:
647+
raise asyncio.CancelledError("tool-cancelled")
648+
649+
model = FakeModel()
650+
agent = Agent(
651+
name="test",
652+
model=model,
653+
tools=[function_tool(_cancel_tool, name_override="cancel_tool")],
654+
)
655+
656+
model.add_multiple_turn_outputs(
657+
[
658+
[get_function_tool_call("cancel_tool", "{}", call_id="call_cancel")],
659+
[get_text_message("final answer")],
660+
]
661+
)
662+
663+
result = await Runner.run(agent, input="user_message")
664+
665+
assert result.final_output == "final answer"
666+
assert len(result.raw_responses) == 2
667+
668+
second_turn_input = cast(list[dict[str, Any]], model.last_turn_args["input"])
669+
tool_outputs = [
670+
item for item in second_turn_input if item.get("type") == "function_call_output"
671+
]
672+
assert tool_outputs == [
673+
{
674+
"call_id": "call_cancel",
675+
"output": (
676+
"An error occurred while running the tool. Please try again. Error: tool-cancelled"
677+
),
678+
"type": "function_call_output",
679+
},
680+
]
681+
682+
644683
@pytest.mark.asyncio
645684
async def test_reasoning_item_id_policy_omits_follow_up_reasoning_ids() -> None:
646685
model = FakeModel()

tests/test_agent_runner_streamed.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,46 @@ async def _cancel_tool() -> str:
496496
]
497497

498498

499+
@pytest.mark.asyncio
500+
async def test_streamed_single_tool_call_with_cancelled_tool_reaches_final_output() -> None:
501+
async def _cancel_tool() -> str:
502+
raise asyncio.CancelledError("tool-cancelled")
503+
504+
model = FakeModel()
505+
agent = Agent(
506+
name="test",
507+
model=model,
508+
tools=[function_tool(_cancel_tool, name_override="cancel_tool")],
509+
)
510+
511+
model.add_multiple_turn_outputs(
512+
[
513+
[get_function_tool_call("cancel_tool", "{}", call_id="call_cancel")],
514+
[get_text_message("final answer")],
515+
]
516+
)
517+
518+
result = Runner.run_streamed(agent, input="user_message")
519+
await consume_stream(result)
520+
521+
assert result.final_output == "final answer"
522+
assert len(result.raw_responses) == 2
523+
524+
second_turn_input = cast(list[dict[str, Any]], model.last_turn_args["input"])
525+
tool_outputs = [
526+
item for item in second_turn_input if item.get("type") == "function_call_output"
527+
]
528+
assert tool_outputs == [
529+
{
530+
"call_id": "call_cancel",
531+
"output": (
532+
"An error occurred while running the tool. Please try again. Error: tool-cancelled"
533+
),
534+
"type": "function_call_output",
535+
},
536+
]
537+
538+
499539
@pytest.mark.asyncio
500540
async def test_streamed_reasoning_item_id_policy_omits_follow_up_reasoning_ids() -> None:
501541
model = FakeModel()

tests/test_run_step_execution.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -740,6 +740,29 @@ async def _manual_on_invoke_tool(_ctx: ToolContext[Any], _args: str) -> str:
740740
)
741741

742742

743+
@pytest.mark.asyncio
744+
async def test_single_tool_call_uses_default_failure_error_function_for_cancelled_tool():
745+
async def _cancel_tool() -> str:
746+
raise asyncio.CancelledError("tool-cancelled")
747+
748+
cancel_tool = function_tool(_cancel_tool, name_override="cancel_tool")
749+
agent = Agent(name="test", tools=[cancel_tool])
750+
response = ModelResponse(
751+
output=[get_function_tool_call("cancel_tool", "{}", call_id="1")],
752+
usage=Usage(),
753+
response_id=None,
754+
)
755+
756+
result = await get_execute_result(agent, response)
757+
758+
assert len(result.generated_items) == 2
759+
assert isinstance(result.next_step, NextStepRunAgain)
760+
assert_item_is_function_tool_call_output(
761+
result.generated_items[1],
762+
"An error occurred while running the tool. Please try again. Error: tool-cancelled",
763+
)
764+
765+
743766
@pytest.mark.asyncio
744767
async def test_multiple_tool_calls_surface_hook_failure_over_sibling_cancellation():
745768
hook_started = asyncio.Event()

0 commit comments

Comments
 (0)