Skip to content

Commit aaca35e

Browse files
committed
Fix chat completions fallback output indexes
1 parent 1b7d878 commit aaca35e

2 files changed

Lines changed: 81 additions & 3 deletions

File tree

src/agents/models/chatcmpl_stream_handler.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,7 @@ async def handle_stream(
624624
)
625625

626626
# Send completion events for function calls
627+
fallback_emitted_count = 0
627628
for index, function_call in state.function_calls.items():
628629
if state.function_call_streaming.get(index, False):
629630
# Function call was streamed, just send the completion event
@@ -668,6 +669,7 @@ async def handle_stream(
668669
fallback_starting_index += sum(
669670
1 for streaming in state.function_call_streaming.values() if streaming
670671
)
672+
fallback_output_index = fallback_starting_index + fallback_emitted_count
671673

672674
# Build function call kwargs, include provider_data if present
673675
fallback_func_call_kwargs: dict[str, Any] = {
@@ -690,23 +692,24 @@ async def handle_stream(
690692
# Send all events at once (backward compatibility)
691693
yield ResponseOutputItemAddedEvent(
692694
item=ResponseFunctionToolCall(**fallback_func_call_kwargs),
693-
output_index=fallback_starting_index,
695+
output_index=fallback_output_index,
694696
type="response.output_item.added",
695697
sequence_number=sequence_number.get_and_increment(),
696698
)
697699
yield ResponseFunctionCallArgumentsDeltaEvent(
698700
delta=function_call.arguments,
699701
item_id=FAKE_RESPONSES_ID,
700-
output_index=fallback_starting_index,
702+
output_index=fallback_output_index,
701703
type="response.function_call_arguments.delta",
702704
sequence_number=sequence_number.get_and_increment(),
703705
)
704706
yield ResponseOutputItemDoneEvent(
705707
item=ResponseFunctionToolCall(**fallback_func_call_kwargs),
706-
output_index=fallback_starting_index,
708+
output_index=fallback_output_index,
707709
type="response.output_item.done",
708710
sequence_number=sequence_number.get_and_increment(),
709711
)
712+
fallback_emitted_count += 1
710713

711714
# Finally, send the Response completed event
712715
outputs: list[ResponseOutputItem] = []

tests/models/test_openai_chatcompletions_stream.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,3 +553,78 @@ async def patched_fetch_response(self, *args, **kwargs):
553553
assert isinstance(function_call_output, ResponseFunctionToolCall)
554554
assert function_call_output.name == "write_file"
555555
assert function_call_output.arguments == '{"filename": "test.py", "content": "print(hello)"}'
556+
557+
558+
@pytest.mark.allow_call_model_methods
559+
@pytest.mark.asyncio
560+
async def test_stream_response_fallback_tool_calls_use_distinct_output_indexes(monkeypatch) -> None:
561+
tool_call_delta1 = ChoiceDeltaToolCall(
562+
index=0,
563+
function=ChoiceDeltaToolCallFunction(name="first_tool", arguments='{"a": 1}'),
564+
type="function",
565+
)
566+
tool_call_delta2 = ChoiceDeltaToolCall(
567+
index=1,
568+
function=ChoiceDeltaToolCallFunction(name="second_tool", arguments='{"b": 2}'),
569+
type="function",
570+
)
571+
chunk1 = ChatCompletionChunk(
572+
id="chunk-id",
573+
created=1,
574+
model="fake",
575+
object="chat.completion.chunk",
576+
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta1]))],
577+
)
578+
chunk2 = ChatCompletionChunk(
579+
id="chunk-id",
580+
created=1,
581+
model="fake",
582+
object="chat.completion.chunk",
583+
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta2]))],
584+
usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2),
585+
)
586+
587+
async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
588+
for chunk in (chunk1, chunk2):
589+
yield chunk
590+
591+
async def patched_fetch_response(self, *args, **kwargs):
592+
response = Response(
593+
id="resp-id",
594+
created_at=0,
595+
model="fake-model",
596+
object="response",
597+
output=[],
598+
tool_choice="none",
599+
tools=[],
600+
parallel_tool_calls=False,
601+
)
602+
return response, fake_stream()
603+
604+
monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
605+
model = OpenAIProvider(use_responses=False).get_model("gpt-4")
606+
output_events = []
607+
608+
async for event in model.stream_response(
609+
system_instructions=None,
610+
input="",
611+
model_settings=ModelSettings(),
612+
tools=[],
613+
output_schema=None,
614+
handoffs=[],
615+
tracing=ModelTracing.DISABLED,
616+
previous_response_id=None,
617+
conversation_id=None,
618+
prompt=None,
619+
):
620+
output_events.append(event)
621+
622+
added_events = [event for event in output_events if event.type == "response.output_item.added"]
623+
delta_events = [
624+
event for event in output_events if event.type == "response.function_call_arguments.delta"
625+
]
626+
done_events = [event for event in output_events if event.type == "response.output_item.done"]
627+
628+
assert [event.output_index for event in added_events] == [0, 1]
629+
assert [event.output_index for event in delta_events] == [0, 1]
630+
assert [event.output_index for event in done_events] == [0, 1]

0 commit comments

Comments (0)