Skip to content

Commit 8c8a2eb

Browse files
authored
fix: #3104 stabilize chat completions tool call output indexes (#3161)
1 parent ff8e3db commit 8c8a2eb

2 files changed

Lines changed: 228 additions & 39 deletions

File tree

src/agents/models/chatcmpl_stream_handler.py

Lines changed: 23 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ class StreamingState:
6565
function_calls: dict[int, ResponseFunctionToolCall] = field(default_factory=dict)
6666
# Fields for real-time function call streaming
6767
function_call_streaming: dict[int, bool] = field(default_factory=dict)
68+
# Stable output indexes for function calls, including fallback calls.
6869
function_call_output_idx: dict[int, int] = field(default_factory=dict)
6970
# Store accumulated thinking text and signature for Anthropic compatibility
7071
thinking_text: str = ""
@@ -145,6 +146,17 @@ def _finish_reasoning_item(
145146
)
146147
state.reasoning_item_done = True
147148

149+
@staticmethod
150+
def _function_call_starting_index(state: StreamingState) -> int:
151+
starting_index = 0
152+
if state.reasoning_content_index_and_output:
153+
starting_index += 1
154+
if state.text_content_index_and_output:
155+
starting_index += 1
156+
if state.refusal_content_index_and_output:
157+
starting_index += 1
158+
return starting_index
159+
148160
@classmethod
149161
async def handle_stream(
150162
cls,
@@ -456,6 +468,10 @@ async def handle_stream(
456468
call_id="",
457469
)
458470
state.function_call_streaming[tc_delta.index] = False
471+
state.function_call_output_idx[tc_delta.index] = (
472+
cls._function_call_starting_index(state)
473+
+ len(state.function_call_output_idx)
474+
)
459475

460476
tc_function = tc_delta.function
461477

@@ -527,25 +543,10 @@ async def handle_stream(
527543
and function_call.name
528544
and function_call.call_id
529545
):
530-
# Calculate the output index for this function call
531-
function_call_starting_index = 0
532-
if state.reasoning_content_index_and_output:
533-
function_call_starting_index += 1
534-
if state.text_content_index_and_output:
535-
function_call_starting_index += 1
536-
if state.refusal_content_index_and_output:
537-
function_call_starting_index += 1
538-
539-
# Add offset for already started function calls
540-
function_call_starting_index += sum(
541-
1 for streaming in state.function_call_streaming.values() if streaming
542-
)
546+
output_index = state.function_call_output_idx[tc_delta.index]
543547

544-
# Mark this function call as streaming and store its output index
548+
# Mark this function call as streaming.
545549
state.function_call_streaming[tc_delta.index] = True
546-
state.function_call_output_idx[tc_delta.index] = (
547-
function_call_starting_index
548-
)
549550

550551
# Send initial function call added event
551552
func_call_item = ResponseFunctionToolCall(
@@ -570,7 +571,7 @@ async def handle_stream(
570571
func_call_item.provider_data = merged_provider_data # type: ignore[attr-defined]
571572
yield ResponseOutputItemAddedEvent(
572573
item=func_call_item,
573-
output_index=function_call_starting_index,
574+
output_index=output_index,
574575
type="response.output_item.added",
575576
sequence_number=sequence_number.get_and_increment(),
576577
)
@@ -593,12 +594,7 @@ async def handle_stream(
593594
for event in cls._finish_reasoning_item(state, sequence_number):
594595
yield event
595596

596-
function_call_starting_index = 0
597-
if state.reasoning_content_index_and_output:
598-
function_call_starting_index += 1
599-
600597
if state.text_content_index_and_output:
601-
function_call_starting_index += 1
602598
# Send end event for this content part
603599
yield ResponseContentPartDoneEvent(
604600
content_index=state.text_content_index_and_output[0],
@@ -611,7 +607,6 @@ async def handle_stream(
611607
)
612608

613609
if state.refusal_content_index_and_output:
614-
function_call_starting_index += 1
615610
# Send end event for this content part
616611
yield ResponseContentPartDoneEvent(
617612
content_index=state.refusal_content_index_and_output[0],
@@ -656,18 +651,7 @@ async def handle_stream(
656651
else:
657652
# Function call was not streamed (fallback to old behavior)
658653
# This handles edge cases where function name never arrived
659-
fallback_starting_index = 0
660-
if state.reasoning_content_index_and_output:
661-
fallback_starting_index += 1
662-
if state.text_content_index_and_output:
663-
fallback_starting_index += 1
664-
if state.refusal_content_index_and_output:
665-
fallback_starting_index += 1
666-
667-
# Add offset for already started function calls
668-
fallback_starting_index += sum(
669-
1 for streaming in state.function_call_streaming.values() if streaming
670-
)
654+
output_index = state.function_call_output_idx[index]
671655

672656
# Build function call kwargs, include provider_data if present
673657
fallback_func_call_kwargs: dict[str, Any] = {
@@ -690,20 +674,20 @@ async def handle_stream(
690674
# Send all events at once (backward compatibility)
691675
yield ResponseOutputItemAddedEvent(
692676
item=ResponseFunctionToolCall(**fallback_func_call_kwargs),
693-
output_index=fallback_starting_index,
677+
output_index=output_index,
694678
type="response.output_item.added",
695679
sequence_number=sequence_number.get_and_increment(),
696680
)
697681
yield ResponseFunctionCallArgumentsDeltaEvent(
698682
delta=function_call.arguments,
699683
item_id=FAKE_RESPONSES_ID,
700-
output_index=fallback_starting_index,
684+
output_index=output_index,
701685
type="response.function_call_arguments.delta",
702686
sequence_number=sequence_number.get_and_increment(),
703687
)
704688
yield ResponseOutputItemDoneEvent(
705689
item=ResponseFunctionToolCall(**fallback_func_call_kwargs),
706-
output_index=fallback_starting_index,
690+
output_index=output_index,
707691
type="response.output_item.done",
708692
sequence_number=sequence_number.get_and_increment(),
709693
)

tests/models/test_openai_chatcompletions_stream.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,3 +553,208 @@ async def patched_fetch_response(self, *args, **kwargs):
553553
assert isinstance(function_call_output, ResponseFunctionToolCall)
554554
assert function_call_output.name == "write_file"
555555
assert function_call_output.arguments == '{"filename": "test.py", "content": "print(hello)"}'
556+
557+
558+
@pytest.mark.allow_call_model_methods
559+
@pytest.mark.asyncio
560+
async def test_fallback_function_calls_have_unique_output_indexes(monkeypatch) -> None:
561+
tool_call_delta1 = ChoiceDeltaToolCall(
562+
index=0,
563+
function=ChoiceDeltaToolCallFunction(
564+
name="first_tool",
565+
arguments='{"a": 1}',
566+
),
567+
type="function",
568+
)
569+
tool_call_delta2 = ChoiceDeltaToolCall(
570+
index=1,
571+
function=ChoiceDeltaToolCallFunction(
572+
name="second_tool",
573+
arguments='{"b": 2}',
574+
),
575+
type="function",
576+
)
577+
578+
chunk1 = ChatCompletionChunk(
579+
id="chunk-id",
580+
created=1,
581+
model="fake",
582+
object="chat.completion.chunk",
583+
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta1]))],
584+
)
585+
chunk2 = ChatCompletionChunk(
586+
id="chunk-id",
587+
created=1,
588+
model="fake",
589+
object="chat.completion.chunk",
590+
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[tool_call_delta2]))],
591+
usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2),
592+
)
593+
594+
async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
595+
for c in (chunk1, chunk2):
596+
yield c
597+
598+
async def patched_fetch_response(self, *args, **kwargs):
599+
resp = Response(
600+
id="resp-id",
601+
created_at=0,
602+
model="fake-model",
603+
object="response",
604+
output=[],
605+
tool_choice="none",
606+
tools=[],
607+
parallel_tool_calls=False,
608+
)
609+
return resp, fake_stream()
610+
611+
monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
612+
model = OpenAIProvider(use_responses=False).get_model("gpt-4")
613+
614+
output_events = []
615+
async for event in model.stream_response(
616+
system_instructions=None,
617+
input="",
618+
model_settings=ModelSettings(),
619+
tools=[],
620+
output_schema=None,
621+
handoffs=[],
622+
tracing=ModelTracing.DISABLED,
623+
previous_response_id=None,
624+
conversation_id=None,
625+
prompt=None,
626+
):
627+
output_events.append(event)
628+
629+
added_indexes = [
630+
event.output_index for event in output_events if event.type == "response.output_item.added"
631+
]
632+
delta_indexes = [
633+
event.output_index
634+
for event in output_events
635+
if event.type == "response.function_call_arguments.delta"
636+
]
637+
done_indexes = [
638+
event.output_index for event in output_events if event.type == "response.output_item.done"
639+
]
640+
641+
assert added_indexes == [0, 1]
642+
assert delta_indexes == [0, 1]
643+
assert done_indexes == [0, 1]
644+
645+
646+
@pytest.mark.allow_call_model_methods
647+
@pytest.mark.asyncio
648+
async def test_fallback_function_call_keeps_index_before_streamed_call(monkeypatch) -> None:
649+
fallback_first = ChoiceDeltaToolCall(
650+
index=0,
651+
function=ChoiceDeltaToolCallFunction(
652+
name="fallback_first",
653+
arguments='{"a": 1}',
654+
),
655+
type="function",
656+
)
657+
streamed_second_start = ChoiceDeltaToolCall(
658+
index=1,
659+
id="tool-call-2",
660+
function=ChoiceDeltaToolCallFunction(
661+
name="streamed_second",
662+
arguments="",
663+
),
664+
type="function",
665+
)
666+
streamed_second_args = ChoiceDeltaToolCall(
667+
index=1,
668+
function=ChoiceDeltaToolCallFunction(arguments='{"b": 2}'),
669+
type="function",
670+
)
671+
672+
chunk1 = ChatCompletionChunk(
673+
id="chunk-id",
674+
created=1,
675+
model="fake",
676+
object="chat.completion.chunk",
677+
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[fallback_first]))],
678+
)
679+
chunk2 = ChatCompletionChunk(
680+
id="chunk-id",
681+
created=1,
682+
model="fake",
683+
object="chat.completion.chunk",
684+
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[streamed_second_start]))],
685+
)
686+
chunk3 = ChatCompletionChunk(
687+
id="chunk-id",
688+
created=1,
689+
model="fake",
690+
object="chat.completion.chunk",
691+
choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[streamed_second_args]))],
692+
usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2),
693+
)
694+
695+
async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
696+
for c in (chunk1, chunk2, chunk3):
697+
yield c
698+
699+
async def patched_fetch_response(self, *args, **kwargs):
700+
resp = Response(
701+
id="resp-id",
702+
created_at=0,
703+
model="fake-model",
704+
object="response",
705+
output=[],
706+
tool_choice="none",
707+
tools=[],
708+
parallel_tool_calls=False,
709+
)
710+
return resp, fake_stream()
711+
712+
monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
713+
model = OpenAIProvider(use_responses=False).get_model("gpt-4")
714+
715+
output_events = []
716+
async for event in model.stream_response(
717+
system_instructions=None,
718+
input="",
719+
model_settings=ModelSettings(),
720+
tools=[],
721+
output_schema=None,
722+
handoffs=[],
723+
tracing=ModelTracing.DISABLED,
724+
previous_response_id=None,
725+
conversation_id=None,
726+
prompt=None,
727+
):
728+
output_events.append(event)
729+
730+
completed = next(
731+
event.response for event in output_events if event.type == "response.completed"
732+
)
733+
assert [
734+
item.name for item in completed.output if isinstance(item, ResponseFunctionToolCall)
735+
] == [
736+
"fallback_first",
737+
"streamed_second",
738+
]
739+
740+
added_by_name = {
741+
event.item.name: event.output_index
742+
for event in output_events
743+
if event.type == "response.output_item.added"
744+
and isinstance(event.item, ResponseFunctionToolCall)
745+
}
746+
delta_indexes = [
747+
event.output_index
748+
for event in output_events
749+
if event.type == "response.function_call_arguments.delta"
750+
]
751+
done_by_name = {
752+
event.item.name: event.output_index
753+
for event in output_events
754+
if event.type == "response.output_item.done"
755+
and isinstance(event.item, ResponseFunctionToolCall)
756+
}
757+
758+
assert added_by_name == {"fallback_first": 0, "streamed_second": 1}
759+
assert delta_indexes == [1, 0]
760+
assert done_by_name == {"streamed_second": 1, "fallback_first": 0}

0 commit comments

Comments (0)