Skip to content

Commit 2a519a9

Browse files
committed
fix(oci): terminate stream on finishReason instead of waiting for [DONE]
OCI Generative AI does not send a `data: [DONE]` SSE marker to signal end-of-stream. Instead, it sends a final event containing `finishReason` and keeps the connection open, which causes `chat_stream()` to hang indefinitely. This change emits the closing events (`message-end` for v2, `stream-end` for v1) and returns from the generator as soon as `finishReason` is detected. The `[DONE]` path is kept as a fallback for forward compatibility. Fixes cohere-ai#756
1 parent 2598c9a commit 2a519a9

1 file changed

Lines changed: 42 additions & 25 deletions

File tree

src/cohere/oci_client.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,36 +1055,45 @@ def _transform_v1_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Itera
10551055
final_v1_finish_reason = oci_event.get("finishReason", final_v1_finish_reason)
10561056
yield _emit_v1_event(event)
10571057

1058+
stream_finished = False
1059+
1060+
def _emit_closing_events() -> typing.Iterator[bytes]:
1061+
"""Emit the final closing events for the stream."""
1062+
if is_v2:
1063+
if emitted_start:
1064+
if not emitted_content_end:
1065+
yield _emit_v2_event({"type": "content-end", "index": current_content_index})
1066+
message_end_event: typing.Dict[str, typing.Any] = {
1067+
"type": "message-end",
1068+
"id": generation_id,
1069+
"delta": {"finish_reason": final_finish_reason},
1070+
}
1071+
if final_usage:
1072+
message_end_event["delta"]["usage"] = final_usage
1073+
yield _emit_v2_event(message_end_event)
1074+
else:
1075+
yield _emit_v1_event(
1076+
{
1077+
"event_type": "stream-end",
1078+
"finish_reason": final_v1_finish_reason,
1079+
"response": {
1080+
"text": full_v1_text,
1081+
"generation_id": generation_id,
1082+
"finish_reason": final_v1_finish_reason,
1083+
},
1084+
}
1085+
)
1086+
10581087
def _process_line(line: str) -> typing.Iterator[bytes]:
1088+
nonlocal stream_finished
10591089
if not line.startswith("data: "):
10601090
return
10611091

10621092
data_str = line[6:]
10631093
if data_str.strip() == "[DONE]":
1064-
if is_v2:
1065-
if emitted_start:
1066-
if not emitted_content_end:
1067-
yield _emit_v2_event({"type": "content-end", "index": current_content_index})
1068-
message_end_event: typing.Dict[str, typing.Any] = {
1069-
"type": "message-end",
1070-
"id": generation_id,
1071-
"delta": {"finish_reason": final_finish_reason},
1072-
}
1073-
if final_usage:
1074-
message_end_event["delta"]["usage"] = final_usage
1075-
yield _emit_v2_event(message_end_event)
1076-
else:
1077-
yield _emit_v1_event(
1078-
{
1079-
"event_type": "stream-end",
1080-
"finish_reason": final_v1_finish_reason,
1081-
"response": {
1082-
"text": full_v1_text,
1083-
"generation_id": generation_id,
1084-
"finish_reason": final_v1_finish_reason,
1085-
},
1086-
}
1087-
)
1094+
for event_bytes in _emit_closing_events():
1095+
yield event_bytes
1096+
stream_finished = True
10881097
return
10891098

10901099
try:
@@ -1102,15 +1111,23 @@ def _process_line(line: str) -> typing.Iterator[bytes]:
11021111
except Exception as exc:
11031112
raise RuntimeError(f"OCI stream event transformation failed for endpoint '{endpoint}': {exc}") from exc
11041113

1114+
# OCI may not send [DONE] — treat finishReason as stream termination
1115+
if "finishReason" in oci_event:
1116+
for event_bytes in _emit_closing_events():
1117+
yield event_bytes
1118+
stream_finished = True
1119+
11051120
for chunk in stream:
11061121
buffer += chunk
11071122
while b"\n" in buffer:
11081123
line_bytes, buffer = buffer.split(b"\n", 1)
11091124
line = line_bytes.decode("utf-8").strip()
11101125
for event_bytes in _process_line(line):
11111126
yield event_bytes
1127+
if stream_finished:
1128+
return
11121129

1113-
if buffer.strip():
1130+
if buffer.strip() and not stream_finished:
11141131
line = buffer.decode("utf-8").strip()
11151132
for event_bytes in _process_line(line):
11161133
yield event_bytes

0 commit comments

Comments
 (0)