Skip to content
Closed
94 changes: 39 additions & 55 deletions src/agents/models/chatcmpl_stream_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,27 @@ def get_and_increment(self) -> int:


class ChatCmplStreamHandler:
@staticmethod
def _assistant_message_output_index(state: StreamingState) -> int:
return 1 if state.reasoning_content_index_and_output is not None else 0

@staticmethod
def _function_call_output_base(state: StreamingState) -> int:
output_index = 0
if state.reasoning_content_index_and_output:
output_index += 1
if state.text_content_index_and_output or state.refusal_content_index_and_output:
output_index += 1
return output_index

@classmethod
def _function_call_output_index(cls, state: StreamingState, function_call_index: int) -> int:
for offset, index in enumerate(state.function_calls):
if index == function_call_index:
return cls._function_call_output_base(state) + offset

raise KeyError(f"Function call index {function_call_index} has not been tracked")

@classmethod
def _finish_reasoning_summary_part(
cls,
Expand Down Expand Up @@ -341,16 +362,14 @@ async def handle_stream(
# Notify consumers of the start of a new output message + first content part
yield ResponseOutputItemAddedEvent(
item=assistant_item,
output_index=state.reasoning_content_index_and_output
is not None, # fixed 0 -> 0 or 1
output_index=cls._assistant_message_output_index(state),
type="response.output_item.added",
sequence_number=sequence_number.get_and_increment(),
)
yield ResponseContentPartAddedEvent(
content_index=state.text_content_index_and_output[0],
item_id=FAKE_RESPONSES_ID,
output_index=state.reasoning_content_index_and_output
is not None, # fixed 0 -> 0 or 1
output_index=cls._assistant_message_output_index(state),
part=ResponseOutputText(
text="",
type="output_text",
Expand All @@ -374,8 +393,7 @@ async def handle_stream(
content_index=state.text_content_index_and_output[0],
delta=delta.content,
item_id=FAKE_RESPONSES_ID,
output_index=state.reasoning_content_index_and_output
is not None, # fixed 0 -> 0 or 1
output_index=cls._assistant_message_output_index(state),
type="response.output_text.delta",
sequence_number=sequence_number.get_and_increment(),
logprobs=delta_logprobs,
Expand Down Expand Up @@ -415,15 +433,14 @@ async def handle_stream(
# Notify downstream that assistant message + first content part are starting
yield ResponseOutputItemAddedEvent(
item=assistant_item,
output_index=state.reasoning_content_index_and_output
is not None, # fixed 0 -> 0 or 1
output_index=cls._assistant_message_output_index(state),
type="response.output_item.added",
sequence_number=sequence_number.get_and_increment(),
)
yield ResponseContentPartAddedEvent(
content_index=state.refusal_content_index_and_output[0],
item_id=FAKE_RESPONSES_ID,
output_index=(1 if state.reasoning_content_index_and_output else 0),
output_index=cls._assistant_message_output_index(state),
part=ResponseOutputRefusal(
refusal="",
type="refusal",
Expand All @@ -436,8 +453,7 @@ async def handle_stream(
content_index=state.refusal_content_index_and_output[0],
delta=delta.refusal,
item_id=FAKE_RESPONSES_ID,
output_index=state.reasoning_content_index_and_output
is not None, # fixed 0 -> 0 or 1
output_index=cls._assistant_message_output_index(state),
type="response.refusal.delta",
sequence_number=sequence_number.get_and_increment(),
)
Expand Down Expand Up @@ -527,25 +543,13 @@ async def handle_stream(
and function_call.name
and function_call.call_id
):
# Calculate the output index for this function call
function_call_starting_index = 0
if state.reasoning_content_index_and_output:
function_call_starting_index += 1
if state.text_content_index_and_output:
function_call_starting_index += 1
if state.refusal_content_index_and_output:
function_call_starting_index += 1

# Add offset for already started function calls
function_call_starting_index += sum(
1 for streaming in state.function_call_streaming.values() if streaming
)

# Mark this function call as streaming and store its output index
state.function_call_streaming[tc_delta.index] = True
state.function_call_output_idx[tc_delta.index] = (
function_call_starting_index
function_call_output_index = cls._function_call_output_index(
state,
tc_delta.index,
)
state.function_call_output_idx[tc_delta.index] = function_call_output_index
Comment thread
seratch marked this conversation as resolved.
Outdated

# Send initial function call added event
func_call_item = ResponseFunctionToolCall(
Expand All @@ -570,7 +574,7 @@ async def handle_stream(
func_call_item.provider_data = merged_provider_data # type: ignore[attr-defined]
yield ResponseOutputItemAddedEvent(
item=func_call_item,
output_index=function_call_starting_index,
output_index=function_call_output_index,
type="response.output_item.added",
sequence_number=sequence_number.get_and_increment(),
)
Expand All @@ -593,31 +597,23 @@ async def handle_stream(
for event in cls._finish_reasoning_item(state, sequence_number):
yield event

function_call_starting_index = 0
if state.reasoning_content_index_and_output:
function_call_starting_index += 1

if state.text_content_index_and_output:
function_call_starting_index += 1
# Send end event for this content part
yield ResponseContentPartDoneEvent(
content_index=state.text_content_index_and_output[0],
item_id=FAKE_RESPONSES_ID,
output_index=state.reasoning_content_index_and_output
is not None, # fixed 0 -> 0 or 1
output_index=cls._assistant_message_output_index(state),
part=state.text_content_index_and_output[1],
type="response.content_part.done",
sequence_number=sequence_number.get_and_increment(),
)

if state.refusal_content_index_and_output:
function_call_starting_index += 1
# Send end event for this content part
yield ResponseContentPartDoneEvent(
content_index=state.refusal_content_index_and_output[0],
item_id=FAKE_RESPONSES_ID,
output_index=state.reasoning_content_index_and_output
is not None, # fixed 0 -> 0 or 1
output_index=cls._assistant_message_output_index(state),
part=state.refusal_content_index_and_output[1],
type="response.content_part.done",
sequence_number=sequence_number.get_and_increment(),
Expand Down Expand Up @@ -656,18 +652,7 @@ async def handle_stream(
else:
# Function call was not streamed (fallback to old behavior)
# This handles edge cases where function name never arrived
fallback_starting_index = 0
if state.reasoning_content_index_and_output:
fallback_starting_index += 1
if state.text_content_index_and_output:
fallback_starting_index += 1
if state.refusal_content_index_and_output:
fallback_starting_index += 1

# Add offset for already started function calls
fallback_starting_index += sum(
1 for streaming in state.function_call_streaming.values() if streaming
)
fallback_output_index = cls._function_call_output_index(state, index)

# Build function call kwargs, include provider_data if present
fallback_func_call_kwargs: dict[str, Any] = {
Expand All @@ -690,20 +675,20 @@ async def handle_stream(
# Send all events at once (backward compatibility)
yield ResponseOutputItemAddedEvent(
item=ResponseFunctionToolCall(**fallback_func_call_kwargs),
output_index=fallback_starting_index,
output_index=fallback_output_index,
type="response.output_item.added",
sequence_number=sequence_number.get_and_increment(),
)
yield ResponseFunctionCallArgumentsDeltaEvent(
delta=function_call.arguments,
item_id=FAKE_RESPONSES_ID,
output_index=fallback_starting_index,
output_index=fallback_output_index,
type="response.function_call_arguments.delta",
sequence_number=sequence_number.get_and_increment(),
)
yield ResponseOutputItemDoneEvent(
item=ResponseFunctionToolCall(**fallback_func_call_kwargs),
output_index=fallback_starting_index,
output_index=fallback_output_index,
type="response.output_item.done",
sequence_number=sequence_number.get_and_increment(),
)
Expand Down Expand Up @@ -747,8 +732,7 @@ async def handle_stream(
# send a ResponseOutputItemDone for the assistant message
yield ResponseOutputItemDoneEvent(
item=assistant_msg,
output_index=state.reasoning_content_index_and_output
is not None, # fixed 0 -> 0 or 1
output_index=cls._assistant_message_output_index(state),
type="response.output_item.done",
sequence_number=sequence_number.get_and_increment(),
)
Expand Down
Loading