Skip to content

Commit 5fa3fd4

Browse files
Recover streamed nested agent output before cancellation fallback (#2714)
Co-authored-by: Kazuhiro Sera <seratch@openai.com>
1 parent 05dc068 commit 5fa3fd4

File tree

5 files changed

+400
-14
lines changed

5 files changed

+400
-14
lines changed

src/agents/agent.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -789,25 +789,37 @@ async def dispatch_stream_events() -> None:
789789
break
790790

791791
dispatch_task = asyncio.create_task(dispatch_stream_events())
792+
stream_iteration_cancelled = False
792793

793794
try:
794795
from .stream_events import AgentUpdatedStreamEvent
795796

796797
current_agent = run_result_streaming.current_agent
797-
async for event in run_result_streaming.stream_events():
798-
if isinstance(event, AgentUpdatedStreamEvent):
799-
current_agent = event.new_agent
800-
801-
payload: AgentToolStreamEvent = {
802-
"event": event,
803-
"agent": current_agent,
804-
"tool_call": context.tool_call,
805-
}
806-
await event_queue.put(payload)
798+
try:
799+
async for event in run_result_streaming.stream_events():
800+
if isinstance(event, AgentUpdatedStreamEvent):
801+
current_agent = event.new_agent
802+
803+
payload: AgentToolStreamEvent = {
804+
"event": event,
805+
"agent": current_agent,
806+
"tool_call": context.tool_call,
807+
}
808+
await event_queue.put(payload)
809+
except asyncio.CancelledError:
810+
stream_iteration_cancelled = True
811+
raise
807812
finally:
808-
await event_queue.put(None)
809-
await event_queue.join()
810-
await dispatch_task
813+
if stream_iteration_cancelled:
814+
dispatch_task.cancel()
815+
try:
816+
await dispatch_task
817+
except asyncio.CancelledError:
818+
pass
819+
else:
820+
await event_queue.put(None)
821+
await event_queue.join()
822+
await dispatch_task
811823
run_result = run_result_streaming
812824
else:
813825
run_result = await Runner.run(

src/agents/items.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -675,6 +675,19 @@ def extract_last_text(cls, message: TResponseOutputItem) -> str | None:
675675

676676
return None
677677

678+
@classmethod
679+
def extract_text(cls, message: TResponseOutputItem) -> str | None:
680+
"""Extracts all text content from a message, if any. Ignores refusals."""
681+
if not isinstance(message, ResponseOutputMessage):
682+
return None
683+
684+
text = ""
685+
for content_item in message.content:
686+
if isinstance(content_item, ResponseOutputText):
687+
text += content_item.text
688+
689+
return text or None
690+
678691
@classmethod
679692
def input_to_new_input_list(
680693
cls, input: str | list[TResponseInputItem]

src/agents/run_internal/turn_resolution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,7 @@ async def execute_tools_and_side_effects(
615615

616616
message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)]
617617
potential_final_output_text = (
618-
ItemHelpers.extract_last_text(message_items[-1].raw_item) if message_items else None
618+
ItemHelpers.extract_text(message_items[-1].raw_item) if message_items else None
619619
)
620620

621621
if not processed_response.has_tools_or_approvals_to_run():

0 commit comments

Comments
 (0)