Skip to content

Commit d1d0abe

Browse files
committed
fix: emit integer chat stream output indexes
1 parent 54f737b commit d1d0abe

2 files changed

Lines changed: 31 additions & 17 deletions

File tree

src/agents/models/chatcmpl_stream_handler.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,10 @@ def get_and_increment(self) -> int:
8484

8585

8686
class ChatCmplStreamHandler:
87+
@staticmethod
88+
def _assistant_message_output_index(state: StreamingState) -> int:
89+
return 1 if state.reasoning_content_index_and_output is not None else 0
90+
8791
@classmethod
8892
def _finish_reasoning_summary_part(
8993
cls,
@@ -341,16 +345,14 @@ async def handle_stream(
341345
# Notify consumers of the start of a new output message + first content part
342346
yield ResponseOutputItemAddedEvent(
343347
item=assistant_item,
344-
output_index=state.reasoning_content_index_and_output
345-
is not None, # fixed 0 -> 0 or 1
348+
output_index=cls._assistant_message_output_index(state),
346349
type="response.output_item.added",
347350
sequence_number=sequence_number.get_and_increment(),
348351
)
349352
yield ResponseContentPartAddedEvent(
350353
content_index=state.text_content_index_and_output[0],
351354
item_id=FAKE_RESPONSES_ID,
352-
output_index=state.reasoning_content_index_and_output
353-
is not None, # fixed 0 -> 0 or 1
355+
output_index=cls._assistant_message_output_index(state),
354356
part=ResponseOutputText(
355357
text="",
356358
type="output_text",
@@ -374,8 +376,7 @@ async def handle_stream(
374376
content_index=state.text_content_index_and_output[0],
375377
delta=delta.content,
376378
item_id=FAKE_RESPONSES_ID,
377-
output_index=state.reasoning_content_index_and_output
378-
is not None, # fixed 0 -> 0 or 1
379+
output_index=cls._assistant_message_output_index(state),
379380
type="response.output_text.delta",
380381
sequence_number=sequence_number.get_and_increment(),
381382
logprobs=delta_logprobs,
@@ -415,15 +416,14 @@ async def handle_stream(
415416
# Notify downstream that assistant message + first content part are starting
416417
yield ResponseOutputItemAddedEvent(
417418
item=assistant_item,
418-
output_index=state.reasoning_content_index_and_output
419-
is not None, # fixed 0 -> 0 or 1
419+
output_index=cls._assistant_message_output_index(state),
420420
type="response.output_item.added",
421421
sequence_number=sequence_number.get_and_increment(),
422422
)
423423
yield ResponseContentPartAddedEvent(
424424
content_index=state.refusal_content_index_and_output[0],
425425
item_id=FAKE_RESPONSES_ID,
426-
output_index=(1 if state.reasoning_content_index_and_output else 0),
426+
output_index=cls._assistant_message_output_index(state),
427427
part=ResponseOutputRefusal(
428428
refusal="",
429429
type="refusal",
@@ -436,8 +436,7 @@ async def handle_stream(
436436
content_index=state.refusal_content_index_and_output[0],
437437
delta=delta.refusal,
438438
item_id=FAKE_RESPONSES_ID,
439-
output_index=state.reasoning_content_index_and_output
440-
is not None, # fixed 0 -> 0 or 1
439+
output_index=cls._assistant_message_output_index(state),
441440
type="response.refusal.delta",
442441
sequence_number=sequence_number.get_and_increment(),
443442
)
@@ -603,8 +602,7 @@ async def handle_stream(
603602
yield ResponseContentPartDoneEvent(
604603
content_index=state.text_content_index_and_output[0],
605604
item_id=FAKE_RESPONSES_ID,
606-
output_index=state.reasoning_content_index_and_output
607-
is not None, # fixed 0 -> 0 or 1
605+
output_index=cls._assistant_message_output_index(state),
608606
part=state.text_content_index_and_output[1],
609607
type="response.content_part.done",
610608
sequence_number=sequence_number.get_and_increment(),
@@ -616,8 +614,7 @@ async def handle_stream(
616614
yield ResponseContentPartDoneEvent(
617615
content_index=state.refusal_content_index_and_output[0],
618616
item_id=FAKE_RESPONSES_ID,
619-
output_index=state.reasoning_content_index_and_output
620-
is not None, # fixed 0 -> 0 or 1
617+
output_index=cls._assistant_message_output_index(state),
621618
part=state.refusal_content_index_and_output[1],
622619
type="response.content_part.done",
623620
sequence_number=sequence_number.get_and_increment(),
@@ -750,8 +747,7 @@ async def handle_stream(
750747
# send a ResponseOutputItemDone for the assistant message
751748
yield ResponseOutputItemDoneEvent(
752749
item=assistant_msg,
753-
output_index=state.reasoning_content_index_and_output
754-
is not None, # fixed 0 -> 0 or 1
750+
output_index=cls._assistant_message_output_index(state),
755751
type="response.output_item.done",
756752
sequence_number=sequence_number.get_and_increment(),
757753
)

tests/models/test_reasoning_content.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,24 @@ async def patched_fetch_response(self, *args, **kwargs):
160160
assert content_delta_events[0].delta == "The answer"
161161
assert content_delta_events[1].delta == " is 42"
162162

163+
assistant_message_index_events = []
164+
for event in output_events:
165+
event_any = cast(Any, event)
166+
if event.type in {"response.output_item.added", "response.output_item.done"}:
167+
if event_any.item.type == "message":
168+
assistant_message_index_events.append(event_any)
169+
elif event.type in {
170+
"response.content_part.added",
171+
"response.output_text.delta",
172+
"response.content_part.done",
173+
}:
174+
assistant_message_index_events.append(event_any)
175+
176+
assert assistant_message_index_events
177+
for event in assistant_message_index_events:
178+
assert event.output_index == 1
179+
assert type(event.output_index) is int
180+
163181
# verify the final response contains both types of content
164182
response_event = output_events[-1]
165183
assert response_event.type == "response.completed"

0 commit comments

Comments
 (0)