Commit ba02386

fix: stabilize chat completions tool call indexes

1 parent d1d0abe commit ba02386

3 files changed
Lines changed: 324 additions & 40 deletions

File tree

src/agents/models/chatcmpl_stream_handler.py

Lines changed: 19 additions & 39 deletions
@@ -88,6 +88,19 @@ class ChatCmplStreamHandler:
     def _assistant_message_output_index(state: StreamingState) -> int:
         return 1 if state.reasoning_content_index_and_output is not None else 0

+    @staticmethod
+    def _function_call_output_base(state: StreamingState) -> int:
+        output_index = 0
+        if state.reasoning_content_index_and_output:
+            output_index += 1
+        if state.text_content_index_and_output or state.refusal_content_index_and_output:
+            output_index += 1
+        return output_index
+
+    @classmethod
+    def _next_function_call_output_index(cls, state: StreamingState) -> int:
+        return cls._function_call_output_base(state) + len(state.function_calls)
+
     @classmethod
     def _finish_reasoning_summary_part(
         cls,
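
The new helpers centralize index bookkeeping that was previously recomputed at each emission site: the base skips one output slot for a reasoning item and one for the assistant message (text and refusal parts share a single message), and each tool call then claims the next slot in arrival order. A minimal standalone sketch of that arithmetic, using a hypothetical `FakeState` stand-in for `StreamingState` with only the fields the helpers read:

```python
from dataclasses import dataclass, field


@dataclass
class FakeState:
    # Hypothetical stand-in for StreamingState: only the fields read below.
    reasoning_content_index_and_output: object = None
    text_content_index_and_output: object = None
    refusal_content_index_and_output: object = None
    function_calls: dict = field(default_factory=dict)


def function_call_output_base(state: FakeState) -> int:
    output_index = 0
    if state.reasoning_content_index_and_output:
        output_index += 1  # a reasoning item occupies one output slot
    if state.text_content_index_and_output or state.refusal_content_index_and_output:
        output_index += 1  # text and refusal parts share one assistant message
    return output_index


def next_function_call_output_index(state: FakeState) -> int:
    # Each tool call claims the next slot, in arrival order.
    return function_call_output_base(state) + len(state.function_calls)


state = FakeState()
assert next_function_call_output_index(state) == 0  # first call, no message content
state.function_calls[0] = "first_tool"
assert next_function_call_output_index(state) == 1  # second call follows at 1
# A text message ahead of the calls shifts the first call to index 1:
assert next_function_call_output_index(FakeState(text_content_index_and_output=(0, "hi"))) == 1
```
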
@@ -447,6 +460,9 @@ async def handle_stream(
             if delta.tool_calls:
                 for tc_delta in delta.tool_calls:
                     if tc_delta.index not in state.function_calls:
+                        state.function_call_output_idx[tc_delta.index] = (
+                            cls._next_function_call_output_index(state)
+                        )
                        state.function_calls[tc_delta.index] = ResponseFunctionToolCall(
                             id=FAKE_RESPONSES_ID,
                             arguments="",
@@ -526,25 +542,9 @@ async def handle_stream(
                         and function_call.name
                         and function_call.call_id
                     ):
-                        # Calculate the output index for this function call
-                        function_call_starting_index = 0
-                        if state.reasoning_content_index_and_output:
-                            function_call_starting_index += 1
-                        if state.text_content_index_and_output:
-                            function_call_starting_index += 1
-                        if state.refusal_content_index_and_output:
-                            function_call_starting_index += 1
-
-                        # Add offset for already started function calls
-                        function_call_starting_index += sum(
-                            1 for streaming in state.function_call_streaming.values() if streaming
-                        )
-
                         # Mark this function call as streaming and store its output index
                         state.function_call_streaming[tc_delta.index] = True
-                        state.function_call_output_idx[tc_delta.index] = (
-                            function_call_starting_index
-                        )
+                        function_call_output_index = state.function_call_output_idx[tc_delta.index]

                         # Send initial function call added event
                         func_call_item = ResponseFunctionToolCall(
@@ -569,7 +569,7 @@ async def handle_stream(
                             func_call_item.provider_data = merged_provider_data  # type: ignore[attr-defined]
                         yield ResponseOutputItemAddedEvent(
                             item=func_call_item,
-                            output_index=function_call_starting_index,
+                            output_index=function_call_output_index,
                             type="response.output_item.added",
                             sequence_number=sequence_number.get_and_increment(),
                         )
@@ -592,12 +592,7 @@ async def handle_stream(
         for event in cls._finish_reasoning_item(state, sequence_number):
             yield event

-        function_call_starting_index = 0
-        if state.reasoning_content_index_and_output:
-            function_call_starting_index += 1
-
         if state.text_content_index_and_output:
-            function_call_starting_index += 1
             # Send end event for this content part
             yield ResponseContentPartDoneEvent(
                 content_index=state.text_content_index_and_output[0],
@@ -609,7 +604,6 @@ async def handle_stream(
             )

         if state.refusal_content_index_and_output:
-            function_call_starting_index += 1
             # Send end event for this content part
             yield ResponseContentPartDoneEvent(
                 content_index=state.refusal_content_index_and_output[0],
@@ -621,7 +615,6 @@ async def handle_stream(
             )

         # Send completion events for function calls
-        fallback_emitted_count = 0
         for index, function_call in state.function_calls.items():
             if state.function_call_streaming.get(index, False):
                 # Function call was streamed, just send the completion event
@@ -654,19 +647,7 @@ async def handle_stream(
             else:
                 # Function call was not streamed (fallback to old behavior)
                 # This handles edge cases where function name never arrived
-                fallback_starting_index = 0
-                if state.reasoning_content_index_and_output:
-                    fallback_starting_index += 1
-                if state.text_content_index_and_output:
-                    fallback_starting_index += 1
-                if state.refusal_content_index_and_output:
-                    fallback_starting_index += 1
-
-                # Add offset for already started function calls
-                fallback_starting_index += sum(
-                    1 for streaming in state.function_call_streaming.values() if streaming
-                )
-                fallback_output_index = fallback_starting_index + fallback_emitted_count
+                fallback_output_index = state.function_call_output_idx[index]

                 # Build function call kwargs, include provider_data if present
                 fallback_func_call_kwargs: dict[str, Any] = {
@@ -706,7 +687,6 @@ async def handle_stream(
                     type="response.output_item.done",
                     sequence_number=sequence_number.get_and_increment(),
                 )
-                fallback_emitted_count += 1

         # Finally, send the Response completed event
         outputs: list[ResponseOutputItem] = []
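
To see why recording the index up front matters, here is a rough replay (not code from the repo) of the mixed scenario the new tests exercise: tool call 0 never meets the streaming preconditions (its `call_id` never arrives), while tool call 1 streams normally. The old arithmetic indexed streamed calls before fallback calls, so the emitted indexes disagreed with the final `response.output` ordering; the per-call bookkeeping keeps them aligned:

```python
# Hypothetical replay: two tool calls arrive in order, but only the second
# satisfies the streaming preconditions (name + call_id), so the first is
# emitted via the fallback path at the end of the stream.
calls = ["first_tool", "second_tool"]  # arrival order == final output order
streams = {"first_tool": False, "second_tool": True}
base = 0  # no reasoning/text/refusal parts in this scenario

# Old scheme: a streamed call's index counted only calls already streaming,
# and fallback calls were appended after every streamed call.
old_indexes = {}
already_streaming = 0
for name in calls:
    if streams[name]:
        old_indexes[name] = base + already_streaming
        already_streaming += 1
fallback_emitted = 0
for name in calls:
    if not streams[name]:
        old_indexes[name] = base + already_streaming + fallback_emitted
        fallback_emitted += 1

# New scheme: every call gets base + number of calls seen so far, assigned
# the moment its first delta arrives, so the index never shifts afterwards.
new_indexes = {name: base + i for i, name in enumerate(calls)}

print(old_indexes)  # {'second_tool': 0, 'first_tool': 1} -- disagrees with output order
print(new_indexes)  # {'first_tool': 0, 'second_tool': 1} -- matches output order
```
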

tests/models/test_openai_chatcompletions_stream.py

Lines changed: 208 additions & 0 deletions
@@ -628,3 +628,211 @@ async def patched_fetch_response(self, *args, **kwargs):
     assert [event.output_index for event in added_events] == [0, 1]
     assert [event.output_index for event in delta_events] == [0, 1]
     assert [event.output_index for event in done_events] == [0, 1]
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_stream_response_mixed_tool_calls_use_final_output_indexes(monkeypatch) -> None:
+    fallback_tool_call = ChoiceDeltaToolCall(
+        index=0,
+        function=ChoiceDeltaToolCallFunction(name="first_tool", arguments='{"a": 1}'),
+        type="function",
+    )
+    streamed_tool_call = ChoiceDeltaToolCall(
+        index=1,
+        id="second-tool-call-id",
+        function=ChoiceDeltaToolCallFunction(name="second_tool", arguments='{"b": 2}'),
+        type="function",
+    )
+    chunk1 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[fallback_tool_call]))],
+    )
+    chunk2 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[streamed_tool_call]))],
+        usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2),
+    )
+
+    async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
+        for chunk in (chunk1, chunk2):
+            yield chunk
+
+    async def patched_fetch_response(self, *args, **kwargs):
+        response = Response(
+            id="resp-id",
+            created_at=0,
+            model="fake-model",
+            object="response",
+            output=[],
+            tool_choice="none",
+            tools=[],
+            parallel_tool_calls=False,
+        )
+        return response, fake_stream()
+
+    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
+    model = OpenAIProvider(use_responses=False).get_model("gpt-4")
+    output_events = []
+
+    async for event in model.stream_response(
+        system_instructions=None,
+        input="",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+        conversation_id=None,
+        prompt=None,
+    ):
+        output_events.append(event)
+
+    added_events = [event for event in output_events if event.type == "response.output_item.added"]
+    delta_events = [
+        event for event in output_events if event.type == "response.function_call_arguments.delta"
+    ]
+    done_events = [event for event in output_events if event.type == "response.output_item.done"]
+    completed_event = next(event for event in output_events if event.type == "response.completed")
+
+    added_event_indexes = {}
+    for event in added_events:
+        assert isinstance(event.item, ResponseFunctionToolCall)
+        added_event_indexes[event.item.name] = event.output_index
+
+    done_event_indexes = {}
+    for event in done_events:
+        assert isinstance(event.item, ResponseFunctionToolCall)
+        done_event_indexes[event.item.name] = event.output_index
+
+    completed_output_names = []
+    for output in completed_event.response.output:
+        assert isinstance(output, ResponseFunctionToolCall)
+        completed_output_names.append(output.name)
+
+    assert added_event_indexes == {
+        "first_tool": 0,
+        "second_tool": 1,
+    }
+    assert {event.delta: event.output_index for event in delta_events} == {
+        '{"a": 1}': 0,
+        '{"b": 2}': 1,
+    }
+    assert done_event_indexes == {
+        "first_tool": 0,
+        "second_tool": 1,
+    }
+    assert completed_output_names == ["first_tool", "second_tool"]
+
+
+@pytest.mark.allow_call_model_methods
+@pytest.mark.asyncio
+async def test_stream_response_text_before_mixed_tool_calls_offsets_tool_indexes(
+    monkeypatch,
+) -> None:
+    fallback_tool_call = ChoiceDeltaToolCall(
+        index=0,
+        function=ChoiceDeltaToolCallFunction(name="first_tool", arguments='{"a": 1}'),
+        type="function",
+    )
+    streamed_tool_call = ChoiceDeltaToolCall(
+        index=1,
+        id="second-tool-call-id",
+        function=ChoiceDeltaToolCallFunction(name="second_tool", arguments='{"b": 2}'),
+        type="function",
+    )
+    chunk1 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(content="Preparing tools"))],
+    )
+    chunk2 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[fallback_tool_call]))],
+    )
+    chunk3 = ChatCompletionChunk(
+        id="chunk-id",
+        created=1,
+        model="fake",
+        object="chat.completion.chunk",
+        choices=[Choice(index=0, delta=ChoiceDelta(tool_calls=[streamed_tool_call]))],
+        usage=CompletionUsage(completion_tokens=1, prompt_tokens=1, total_tokens=2),
+    )
+
+    async def fake_stream() -> AsyncIterator[ChatCompletionChunk]:
+        for chunk in (chunk1, chunk2, chunk3):
+            yield chunk
+
+    async def patched_fetch_response(self, *args, **kwargs):
+        response = Response(
+            id="resp-id",
+            created_at=0,
+            model="fake-model",
+            object="response",
+            output=[],
+            tool_choice="none",
+            tools=[],
+            parallel_tool_calls=False,
+        )
+        return response, fake_stream()
+
+    monkeypatch.setattr(OpenAIChatCompletionsModel, "_fetch_response", patched_fetch_response)
+    model = OpenAIProvider(use_responses=False).get_model("gpt-4")
+    output_events = []
+
+    async for event in model.stream_response(
+        system_instructions=None,
+        input="",
+        model_settings=ModelSettings(),
+        tools=[],
+        output_schema=None,
+        handoffs=[],
+        tracing=ModelTracing.DISABLED,
+        previous_response_id=None,
+        conversation_id=None,
+        prompt=None,
+    ):
+        output_events.append(event)
+
+    added_events = [event for event in output_events if event.type == "response.output_item.added"]
+    delta_events = [
+        event for event in output_events if event.type == "response.function_call_arguments.delta"
+    ]
+    done_events = [event for event in output_events if event.type == "response.output_item.done"]
+    completed_event = next(event for event in output_events if event.type == "response.completed")
+
+    added_tool_indexes = {}
+    for event in added_events:
+        if isinstance(event.item, ResponseFunctionToolCall):
+            added_tool_indexes[event.item.name] = event.output_index
+
+    done_tool_indexes = {}
+    for event in done_events:
+        if isinstance(event.item, ResponseFunctionToolCall):
+            done_tool_indexes[event.item.name] = event.output_index
+
+    assert added_tool_indexes == {"first_tool": 1, "second_tool": 2}
+    assert {event.delta: event.output_index for event in delta_events} == {
+        '{"a": 1}': 1,
+        '{"b": 2}': 2,
+    }
+    assert done_tool_indexes == {"first_tool": 1, "second_tool": 2}
+    assert isinstance(completed_event.response.output[0], ResponseOutputMessage)
+    completed_tool_outputs = completed_event.response.output[1:]
+    completed_tool_names = []
+    for output in completed_tool_outputs:
+        assert isinstance(output, ResponseFunctionToolCall)
+        completed_tool_names.append(output.name)
+    assert completed_tool_names == ["first_tool", "second_tool"]
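
Both regression tests monkeypatch `OpenAIChatCompletionsModel._fetch_response`, so they run offline against the fake chunk stream; assuming a standard checkout, something like `pytest tests/models/test_openai_chatcompletions_stream.py -k "indexes"` should select just the new cases.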
