Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,26 @@ async def dispatch_stream_events() -> None:
if custom_output_extractor:
return await custom_output_extractor(run_result)

if run_result.final_output is not None and (
Comment thread
elainegan-openai marked this conversation as resolved.
not isinstance(run_result.final_output, str) or run_result.final_output != ""
):
return run_result.final_output

from .items import ItemHelpers, MessageOutputItem, ToolCallOutputItem

for item in reversed(run_result.new_items):
if isinstance(item, MessageOutputItem):
text_output = ItemHelpers.text_message_output(item)
if text_output:
return text_output
Comment thread
seratch marked this conversation as resolved.

if (
isinstance(item, ToolCallOutputItem)
and isinstance(item.output, str)
and item.output
):
return item.output

return run_result.final_output

run_agent_tool = _build_wrapped_function_tool(
Expand Down
15 changes: 8 additions & 7 deletions src/agents/run_internal/tool_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,9 +1400,10 @@ async def _drain_cancelled_tasks(
self,
tasks: set[asyncio.Task[Any]],
) -> tuple[_FunctionToolFailure | None, set[asyncio.Task[Any]]]:
late_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = {
task: "cancelled_teardown" for task in tasks
}
late_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = dict.fromkeys(
tasks,
"cancelled_teardown",
)
return await _drain_cancelled_function_tool_tasks(
pending_tasks=tasks,
task_states=self.task_states,
Expand All @@ -1415,9 +1416,9 @@ async def _wait_post_invoke_tasks(
self,
tasks: set[asyncio.Task[Any]],
) -> tuple[_FunctionToolFailure | None, set[asyncio.Task[Any]]]:
post_invoke_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = {
task: "post_invoke" for task in tasks
}
post_invoke_failure_sources: dict[asyncio.Task[Any], _FunctionToolFailureSource] = (
dict.fromkeys(tasks, "post_invoke")
)
return await _wait_pending_function_tool_tasks_for_timeout(
pending_tasks=tasks,
task_states=self.task_states,
Expand Down Expand Up @@ -1638,7 +1639,7 @@ async def _invoke_tool_and_run_post_invoke(
arguments=tool_call.arguments,
)
except asyncio.CancelledError as e:
if not self.isolate_parallel_failures or outer_task in self.teardown_cancelled_tasks:
if outer_task in self.teardown_cancelled_tasks:
raise

result = await maybe_invoke_function_tool_failure_error_function(
Expand Down
178 changes: 178 additions & 0 deletions tests/test_agent_as_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
Session,
SessionSettings,
ToolApprovalItem,
ToolCallOutputItem,
TResponseInputItem,
Usage,
tool_namespace,
Expand Down Expand Up @@ -407,6 +408,183 @@ async def extractor(result) -> str:
assert output == "custom output"


@pytest.mark.asyncio
async def test_agent_as_tool_fallback_uses_current_run_items_only(
monkeypatch: pytest.MonkeyPatch,
) -> None:
agent = Agent(name="summarizer")

message = ResponseOutputMessage(
id="msg_current",
role="assistant",
status="completed",
type="message",
content=[
ResponseOutputText(
annotations=[],
text="Current run summary",
type="output_text",
logprobs=[],
)
],
)

class DummyResult:
def __init__(self) -> None:
self.final_output = ""
self.new_items = [
ToolCallOutputItem(
agent=agent,
raw_item={
"call_id": "call_current",
"output": "Current tool output",
"type": "function_call_output",
},
output="Current tool output",
),
MessageOutputItem(agent=agent, raw_item=message),
]

def to_input_list(self) -> list[dict[str, Any]]:
return [
{
"call_id": "call_old",
"output": "Old output from prior history",
"type": "function_call_output",
}
]

run_result = DummyResult()

async def fake_run(
cls,
starting_agent,
input,
*,
context,
max_turns,
hooks,
run_config,
previous_response_id,
conversation_id,
session,
):
del (
cls,
starting_agent,
input,
context,
max_turns,
hooks,
run_config,
previous_response_id,
conversation_id,
session,
)
return run_result

monkeypatch.setattr(Runner, "run", classmethod(fake_run))

tool = agent.as_tool(
tool_name="summary_tool",
tool_description="Summarize current run output",
)
tool_context = ToolContext(
context=None,
tool_name="summary_tool",
tool_call_id="call_1",
tool_arguments='{"input": "hello"}',
)

output = await tool.on_invoke_tool(tool_context, '{"input": "hello"}')

assert output == "Current run summary"


@pytest.mark.asyncio
async def test_agent_as_tool_fallback_returns_most_recent_current_run_output(
monkeypatch: pytest.MonkeyPatch,
) -> None:
agent = Agent(name="summarizer")

older_message = ResponseOutputMessage(
id="msg_older",
role="assistant",
status="completed",
type="message",
content=[
ResponseOutputText(
annotations=[],
text="Older message output",
type="output_text",
logprobs=[],
)
],
)

class DummyResult:
def __init__(self) -> None:
self.final_output = ""
self.new_items = [
MessageOutputItem(agent=agent, raw_item=older_message),
ToolCallOutputItem(
agent=agent,
raw_item={
"call_id": "call_current",
"output": "Newest tool output",
"type": "function_call_output",
},
output="Newest tool output",
),
]

run_result = DummyResult()

async def fake_run(
cls,
starting_agent,
input,
*,
context,
max_turns,
hooks,
run_config,
previous_response_id,
conversation_id,
session,
):
del (
cls,
starting_agent,
input,
context,
max_turns,
hooks,
run_config,
previous_response_id,
conversation_id,
session,
)
return run_result

monkeypatch.setattr(Runner, "run", classmethod(fake_run))

tool = agent.as_tool(
tool_name="summary_tool",
tool_description="Summarize current run output",
)
tool_context = ToolContext(
context=None,
tool_name="summary_tool",
tool_call_id="call_1",
tool_arguments='{"input": "hello"}',
)

output = await tool.on_invoke_tool(tool_context, '{"input": "hello"}')

assert output == "Newest tool output"


@pytest.mark.asyncio
async def test_agent_as_tool_extractor_can_access_agent_tool_invocation(
monkeypatch: pytest.MonkeyPatch,
Expand Down
39 changes: 39 additions & 0 deletions tests/test_agent_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,45 @@ async def _cancel_tool() -> str:
]


@pytest.mark.asyncio
async def test_single_tool_call_with_cancelled_tool_reaches_final_output() -> None:
async def _cancel_tool() -> str:
raise asyncio.CancelledError("tool-cancelled")

model = FakeModel()
agent = Agent(
name="test",
model=model,
tools=[function_tool(_cancel_tool, name_override="cancel_tool")],
)

model.add_multiple_turn_outputs(
[
[get_function_tool_call("cancel_tool", "{}", call_id="call_cancel")],
[get_text_message("final answer")],
]
)

result = await Runner.run(agent, input="user_message")

assert result.final_output == "final answer"
assert len(result.raw_responses) == 2

second_turn_input = cast(list[dict[str, Any]], model.last_turn_args["input"])
tool_outputs = [
item for item in second_turn_input if item.get("type") == "function_call_output"
]
assert tool_outputs == [
{
"call_id": "call_cancel",
"output": (
"An error occurred while running the tool. Please try again. Error: tool-cancelled"
),
"type": "function_call_output",
},
]


@pytest.mark.asyncio
async def test_reasoning_item_id_policy_omits_follow_up_reasoning_ids() -> None:
model = FakeModel()
Expand Down
40 changes: 40 additions & 0 deletions tests/test_agent_runner_streamed.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,46 @@ async def _cancel_tool() -> str:
]


@pytest.mark.asyncio
async def test_streamed_single_tool_call_with_cancelled_tool_reaches_final_output() -> None:
async def _cancel_tool() -> str:
raise asyncio.CancelledError("tool-cancelled")

model = FakeModel()
agent = Agent(
name="test",
model=model,
tools=[function_tool(_cancel_tool, name_override="cancel_tool")],
)

model.add_multiple_turn_outputs(
[
[get_function_tool_call("cancel_tool", "{}", call_id="call_cancel")],
[get_text_message("final answer")],
]
)

result = Runner.run_streamed(agent, input="user_message")
await consume_stream(result)

assert result.final_output == "final answer"
assert len(result.raw_responses) == 2

second_turn_input = cast(list[dict[str, Any]], model.last_turn_args["input"])
tool_outputs = [
item for item in second_turn_input if item.get("type") == "function_call_output"
]
assert tool_outputs == [
{
"call_id": "call_cancel",
"output": (
"An error occurred while running the tool. Please try again. Error: tool-cancelled"
),
"type": "function_call_output",
},
]


@pytest.mark.asyncio
async def test_streamed_reasoning_item_id_policy_omits_follow_up_reasoning_ids() -> None:
model = FakeModel()
Expand Down
23 changes: 23 additions & 0 deletions tests/test_run_step_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -740,6 +740,29 @@ async def _manual_on_invoke_tool(_ctx: ToolContext[Any], _args: str) -> str:
)


@pytest.mark.asyncio
async def test_single_tool_call_uses_default_failure_error_function_for_cancelled_tool():
async def _cancel_tool() -> str:
raise asyncio.CancelledError("tool-cancelled")

cancel_tool = function_tool(_cancel_tool, name_override="cancel_tool")
agent = Agent(name="test", tools=[cancel_tool])
response = ModelResponse(
output=[get_function_tool_call("cancel_tool", "{}", call_id="1")],
usage=Usage(),
response_id=None,
)

result = await get_execute_result(agent, response)

assert len(result.generated_items) == 2
assert isinstance(result.next_step, NextStepRunAgain)
assert_item_is_function_tool_call_output(
result.generated_items[1],
"An error occurred while running the tool. Please try again. Error: tool-cancelled",
)


@pytest.mark.asyncio
async def test_multiple_tool_calls_surface_hook_failure_over_sibling_cancellation():
hook_started = asyncio.Event()
Expand Down
Loading