Skip to content

Commit bac195c

Browse files
committed
fix: don't trigger content-type transition on finish-only stream events
_current_content_type now returns None for events with no message content (e.g. {"finishReason": "COMPLETE"}). The transition branch in _transform_v2_event is skipped when event_content_type is None, so a finish-only event after a thinking block no longer opens a spurious empty text block before emitting content-end.
1 parent b644eb3 commit bac195c

2 files changed

Lines changed: 33 additions & 7 deletions

File tree

src/cohere/oci_client.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -953,21 +953,21 @@ def _emit_v2_event(event: typing.Dict[str, typing.Any]) -> bytes:
953953
def _emit_v1_event(event: typing.Dict[str, typing.Any]) -> bytes:
954954
return json.dumps(event).encode("utf-8") + b"\n"
955955

956-
def _current_content_type(oci_event: typing.Dict[str, typing.Any]) -> str:
956+
def _current_content_type(oci_event: typing.Dict[str, typing.Any]) -> typing.Optional[str]:
957957
message = oci_event.get("message")
958958
if isinstance(message, dict):
959959
content_list = message.get("content")
960960
if content_list and isinstance(content_list, list) and len(content_list) > 0:
961961
oci_type = content_list[0].get("type", "TEXT").upper()
962-
if oci_type == "THINKING":
963-
return "thinking"
964-
return "text"
962+
return "thinking" if oci_type == "THINKING" else "text"
963+
return None # finish-only or non-content event — don't trigger a type transition
965964

966965
def _transform_v2_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Iterator[bytes]:
967966
nonlocal emitted_start, emitted_content_end, current_content_type, current_content_index
968967
nonlocal final_finish_reason, final_usage
969968

970969
event_content_type = _current_content_type(oci_event)
970+
open_type = event_content_type or "text"
971971

972972
if not emitted_start:
973973
yield _emit_v2_event(
@@ -981,12 +981,12 @@ def _transform_v2_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Itera
981981
{
982982
"type": "content-start",
983983
"index": current_content_index,
984-
"delta": {"message": {"content": {"type": event_content_type}}},
984+
"delta": {"message": {"content": {"type": open_type}}},
985985
}
986986
)
987987
emitted_start = True
988-
current_content_type = event_content_type
989-
elif current_content_type != event_content_type:
988+
current_content_type = open_type
989+
elif event_content_type is not None and current_content_type != event_content_type:
990990
yield _emit_v2_event({"type": "content-end", "index": current_content_index})
991991
current_content_index += 1
992992
yield _emit_v2_event(

tests/test_oci_client.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,6 +954,32 @@ def test_stream_wrapper_emits_new_content_block_on_thinking_transition(self):
954954
self.assertEqual(events[5]["type"], "content-delta")
955955
self.assertEqual(events[5]["index"], 1)
956956

957+
def test_stream_wrapper_no_spurious_block_on_finish_only_event(self):
958+
"""Finish-only event after thinking block must not open a spurious empty text block."""
959+
import json
960+
from cohere.oci_client import transform_oci_stream_wrapper
961+
962+
chunks = [
963+
b'data: {"message": {"content": [{"type": "THINKING", "thinking": "Reasoning..."}]}}\n',
964+
b'data: {"finishReason": "COMPLETE"}\n',
965+
b"data: [DONE]\n",
966+
]
967+
968+
events = []
969+
for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=True):
970+
line = raw.decode("utf-8").strip()
971+
if line.startswith("data: "):
972+
events.append(json.loads(line[6:]))
973+
974+
types = [e["type"] for e in events]
975+
# Must not contain two content-start events
976+
self.assertEqual(types.count("content-start"), 1)
977+
# The single content block must be thinking
978+
cs = next(e for e in events if e["type"] == "content-start")
979+
self.assertEqual(cs["delta"]["message"]["content"]["type"], "thinking")
980+
# Must end cleanly
981+
self.assertEqual(events[-1]["type"], "message-end")
982+
957983
def test_stream_wrapper_skips_malformed_json_with_warning(self):
958984
"""Test that malformed JSON in SSE stream is skipped."""
959985
from cohere.oci_client import transform_oci_stream_wrapper

0 commit comments

Comments
 (0)