Skip to content

Commit 81c57c5

Browse files
authored
fix: backfill streamed terminal output (#3000)
1 parent 638388a commit 81c57c5

2 files changed

Lines changed: 229 additions & 1 deletion

File tree

src/agents/run_internal/run_loop.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
ResponseFunctionToolCall,
1818
ResponseOutputItemDoneEvent,
1919
)
20-
from openai.types.responses.response_output_item import McpCall, McpListTools
20+
from openai.types.responses.response_output_item import McpCall, McpListTools, ResponseOutputItem
2121
from openai.types.responses.response_prompt_param import ResponsePromptParam
2222
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
2323

@@ -1337,6 +1337,7 @@ def _tool_search_fingerprint(raw_item: Any) -> str:
13371337
model_settings = maybe_reset_tool_choice(public_agent, tool_use_tracker, model_settings)
13381338

13391339
final_response: ModelResponse | None = None
1340+
streamed_response_output: list[ResponseOutputItem] = []
13401341

13411342
if server_conversation_tracker is not None:
13421343
items_for_input = (
@@ -1474,14 +1475,21 @@ async def rewind_model_request() -> None:
14741475
streamed_result._event_queue.put_nowait(RawResponsesStreamEvent(data=event))
14751476

14761477
terminal_response: Response | None = None
1478+
is_completed_event = False
14771479
if isinstance(event, ResponseCompletedEvent):
1480+
is_completed_event = True
14781481
terminal_response = event.response
14791482
elif getattr(event, "type", None) in {"response.incomplete", "response.failed"}:
14801483
maybe_response = getattr(event, "response", None)
14811484
if isinstance(maybe_response, Response):
14821485
terminal_response = maybe_response
14831486

14841487
if terminal_response is not None:
1488+
if is_completed_event and not terminal_response.output and streamed_response_output:
1489+
# Some streaming backends emit output items during item.done events while leaving
1490+
# the terminal response output empty. Preserve those items so the runner can
1491+
# resolve the completed step correctly.
1492+
terminal_response.output = list(streamed_response_output)
14851493
usage = (
14861494
apply_retry_attempt_usage(
14871495
Usage(
@@ -1506,6 +1514,7 @@ async def rewind_model_request() -> None:
15061514

15071515
if isinstance(event, ResponseOutputItemDoneEvent):
15081516
output_item = event.item
1517+
streamed_response_output.append(output_item)
15091518
output_item_type = getattr(output_item, "type", None)
15101519

15111520
if output_item_type == "tool_search_call":
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
from __future__ import annotations
2+
3+
import json
4+
from collections.abc import AsyncIterator
5+
from typing import Any
6+
7+
import pytest
8+
from openai.types.responses import (
9+
ResponseCompletedEvent,
10+
ResponseCreatedEvent,
11+
ResponseInProgressEvent,
12+
ResponseOutputItemDoneEvent,
13+
)
14+
15+
from agents import Agent, Runner
16+
from agents.agent_output import AgentOutputSchemaBase
17+
from agents.handoffs import Handoff
18+
from agents.items import TResponseInputItem, TResponseOutputItem, TResponseStreamEvent
19+
from agents.model_settings import ModelSettings
20+
from agents.models.interface import ModelTracing
21+
from agents.tool import Tool, function_tool
22+
23+
from .fake_model import FakeModel, get_response_obj
24+
from .test_responses import get_final_output_message, get_function_tool_call
25+
26+
27+
class TerminalOutputStreamModel(FakeModel):
    """A ``FakeModel`` whose terminal ``Response`` output is queued separately
    from the streamed output items.

    This lets tests emit a ``response.completed`` event whose ``output`` list
    differs from (e.g. is empty despite) the ``response.output_item.done``
    events streamed before it — the scenario the runner's backfill logic
    must handle.
    """

    def __init__(self) -> None:
        super().__init__()
        # Per-turn output lists used to build the *terminal* Response object,
        # consumed one entry per call to stream_response.
        self.terminal_turn_outputs: list[list[TResponseOutputItem]] = []

    def add_terminal_turn_outputs(
        self,
        outputs: list[list[TResponseOutputItem]],
    ) -> None:
        """Queue terminal-response output lists, one per upcoming turn."""
        self.terminal_turn_outputs.extend(outputs)

    def get_next_terminal_output(self) -> list[TResponseOutputItem]:
        """Pop and return the next queued terminal output (empty if none)."""
        if self.terminal_turn_outputs:
            return self.terminal_turn_outputs.pop(0)
        return []

    async def stream_response(
        self,
        system_instructions: str | None,
        input: str | list[TResponseInputItem],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: AgentOutputSchemaBase | None,
        handoffs: list[Handoff],
        tracing: ModelTracing,
        *,
        previous_response_id: str | None = None,
        conversation_id: str | None = None,
        prompt: Any | None = None,
    ) -> AsyncIterator[TResponseStreamEvent]:
        """Yield created/in_progress events, one ``output_item.done`` per
        streamed item, then a completed event carrying the separately queued
        terminal response."""
        turn_args = {
            "system_instructions": system_instructions,
            "input": input,
            "model_settings": model_settings,
            "tools": tools,
            "output_schema": output_schema,
            "previous_response_id": previous_response_id,
            "conversation_id": conversation_id,
        }

        # Record call arguments the same way FakeModel does, so assertions on
        # first/last turn args keep working.
        if self.first_turn_args is None:
            self.first_turn_args = turn_args.copy()
        self.last_turn_args = turn_args

        items_to_stream = self.get_next_output()
        if isinstance(items_to_stream, Exception):
            raise items_to_stream

        # The terminal response's output comes from the separate queue — it is
        # intentionally decoupled from the items streamed below.
        terminal_response = get_response_obj(
            self.get_next_terminal_output(),
            usage=self.hardcoded_usage,
        )

        seq = 0

        def next_seq() -> int:
            # Monotonic sequence number shared by every emitted event.
            nonlocal seq
            current = seq
            seq += 1
            return current

        yield ResponseCreatedEvent(
            type="response.created",
            response=terminal_response,
            sequence_number=next_seq(),
        )

        yield ResponseInProgressEvent(
            type="response.in_progress",
            response=terminal_response,
            sequence_number=next_seq(),
        )

        for item_index, streamed_item in enumerate(items_to_stream):
            yield ResponseOutputItemDoneEvent(
                type="response.output_item.done",
                item=streamed_item,
                output_index=item_index,
                sequence_number=next_seq(),
            )

        yield ResponseCompletedEvent(
            type="response.completed",
            response=terminal_response,
            sequence_number=next_seq(),
        )
109+
110+
111+
@pytest.mark.asyncio
async def test_streamed_runner_backfills_empty_terminal_output_before_step_resolution() -> None:
    """An empty completed-response output is backfilled from the streamed
    ``output_item.done`` items, so the function tool call still executes."""
    recorded_inputs: list[str] = []

    async def test_tool(a: str) -> str:
        recorded_inputs.append(a)
        return "tool_result"

    tool = function_tool(test_tool, name_override="foo")
    model = TerminalOutputStreamModel()
    agent = Agent(name="test", model=model, tools=[tool])

    # Streamed items: a tool call on turn 1, a final message on turn 2.
    model.add_multiple_turn_outputs(
        [
            [get_function_tool_call("foo", json.dumps({"a": "b"}), call_id="call-1")],
            [get_final_output_message("done")],
        ]
    )
    # Terminal responses: turn 1 arrives with an EMPTY output list.
    model.add_terminal_turn_outputs(
        [
            [],
            [get_final_output_message("done")],
        ]
    )

    result = Runner.run_streamed(agent, input="test")
    async for _ in result.stream_events():
        pass

    assert recorded_inputs == ["b"]
    assert [item.type for item in result.raw_responses[0].output] == ["function_call"]
    assert result.final_output == "done"
143+
144+
145+
@pytest.mark.asyncio
async def test_streamed_runner_preserves_populated_terminal_output() -> None:
    """A terminal response that already carries output is used as-is; the
    streamed tool-call items must not overwrite it."""
    recorded_inputs: list[str] = []

    async def test_tool(a: str) -> str:
        recorded_inputs.append(a)
        return "tool_result"

    tool = function_tool(test_tool, name_override="foo")
    model = TerminalOutputStreamModel()
    agent = Agent(name="test", model=model, tools=[tool])

    # Streamed items contain a tool call...
    model.add_multiple_turn_outputs(
        [
            [get_function_tool_call("foo", json.dumps({"a": "b"}), call_id="call-1")],
        ]
    )
    # ...but the terminal response already has a final message.
    model.add_terminal_turn_outputs(
        [
            [get_final_output_message("done")],
        ]
    )

    result = Runner.run_streamed(agent, input="test")
    async for _ in result.stream_events():
        pass

    # The populated terminal output wins, so no tool ever runs.
    assert recorded_inputs == []
    assert [item.type for item in result.raw_responses[0].output] == ["message"]
    assert result.final_output == "done"
175+
176+
177+
@pytest.mark.asyncio
async def test_streamed_runner_backfills_multiple_tool_calls_in_order() -> None:
    """Backfilled output preserves the order in which items were streamed,
    so multiple tool calls execute in stream order."""
    recorded_calls: list[tuple[str, str]] = []

    async def foo_tool(a: str) -> str:
        recorded_calls.append(("foo", a))
        return "foo_result"

    async def bar_tool(b: str) -> str:
        recorded_calls.append(("bar", b))
        return "bar_result"

    foo = function_tool(foo_tool, name_override="foo")
    bar = function_tool(bar_tool, name_override="bar")
    model = TerminalOutputStreamModel()
    agent = Agent(name="test", model=model, tools=[foo, bar])

    # Two tool calls streamed on turn 1, final message on turn 2.
    model.add_multiple_turn_outputs(
        [
            [
                get_function_tool_call("foo", json.dumps({"a": "first"}), call_id="call-1"),
                get_function_tool_call("bar", json.dumps({"b": "second"}), call_id="call-2"),
            ],
            [get_final_output_message("done")],
        ]
    )
    # Turn 1's terminal response is empty, forcing the backfill path.
    model.add_terminal_turn_outputs(
        [
            [],
            [get_final_output_message("done")],
        ]
    )

    result = Runner.run_streamed(agent, input="test")
    async for _ in result.stream_events():
        pass

    assert recorded_calls == [("foo", "first"), ("bar", "second")]
    assert [item.type for item in result.raw_responses[0].output] == [
        "function_call",
        "function_call",
    ]
    assert result.final_output == "done"

0 commit comments

Comments
 (0)