Skip to content

Commit b1ed6d5

Browse files
committed
fix(oci): address Bugbot review — deduplicate Streamer and add V1 stream-start
1. Remove duplicate Streamer class (manually_maintained/streaming.py) and import from aws_client.py instead. Both were identical SyncByteStream wrappers. 2. Emit stream-start event with generation_id at the beginning of V1 streams, matching the standard Cohere V1 streaming chat format. Consumers relying on stream-start for state initialization will now receive it before text-generation events. Updated test_v1_stream_wrapper_preserves_finish_reason to verify stream-start is emitted first.
1 parent 49f4c34 commit b1ed6d5

3 files changed

Lines changed: 20 additions & 24 deletions

File tree

src/cohere/manually_maintained/streaming.py

Lines changed: 0 additions & 15 deletions
This file was deleted.

src/cohere/oci_client.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
import requests
1212
from .client import Client, ClientEnvironment
1313
from .client_v2 import ClientV2
14+
from .aws_client import Streamer
1415
from .manually_maintained.lazy_oci_deps import lazy_oci
15-
from .manually_maintained.streaming import Streamer
1616
from httpx import URL, ByteStream
1717

1818

@@ -1038,16 +1038,22 @@ def _transform_v2_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Itera
10381038
final_usage = _usage_from_oci(oci_event.get("usage"))
10391039
yield _emit_v2_event(cohere_event)
10401040

1041-
def _transform_v1_event(oci_event: typing.Dict[str, typing.Any]) -> bytes:
1042-
nonlocal full_v1_text, final_v1_finish_reason
1041+
def _transform_v1_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Iterator[bytes]:
1042+
nonlocal emitted_start, full_v1_text, final_v1_finish_reason
1043+
if not emitted_start:
1044+
yield _emit_v1_event({
1045+
"event_type": "stream-start",
1046+
"generation_id": generation_id,
1047+
"is_finished": False,
1048+
})
1049+
emitted_start = True
10431050
event = transform_stream_event(endpoint, oci_event, is_v2=False)
10441051
if isinstance(event, dict):
10451052
if event.get("event_type") == "text-generation" and event.get("text"):
10461053
full_v1_text += typing.cast(str, event["text"])
10471054
if "finishReason" in oci_event:
10481055
final_v1_finish_reason = oci_event.get("finishReason", final_v1_finish_reason)
1049-
return _emit_v1_event(event)
1050-
return b""
1056+
yield _emit_v1_event(event)
10511057

10521058
def _process_line(line: str) -> typing.Iterator[bytes]:
10531059
if not line.startswith("data: "):
@@ -1091,7 +1097,8 @@ def _process_line(line: str) -> typing.Iterator[bytes]:
10911097
for event_bytes in _transform_v2_event(oci_event):
10921098
yield event_bytes
10931099
else:
1094-
yield _transform_v1_event(oci_event)
1100+
for event_bytes in _transform_v1_event(oci_event):
1101+
yield event_bytes
10951102
except Exception as exc:
10961103
raise RuntimeError(f"OCI stream event transformation failed for endpoint '{endpoint}': {exc}") from exc
10971104

tests/test_oci_client.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -800,9 +800,13 @@ def test_v1_stream_wrapper_preserves_finish_reason(self):
800800
for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=False)
801801
]
802802

803-
self.assertEqual(events[2]["event_type"], "stream-end")
804-
self.assertEqual(events[2]["finish_reason"], "MAX_TOKENS")
805-
self.assertEqual(events[2]["response"]["text"], "Hello world")
803+
# First event should be stream-start with generation_id
804+
self.assertEqual(events[0]["event_type"], "stream-start")
805+
self.assertIn("generation_id", events[0])
806+
807+
self.assertEqual(events[3]["event_type"], "stream-end")
808+
self.assertEqual(events[3]["finish_reason"], "MAX_TOKENS")
809+
self.assertEqual(events[3]["response"]["text"], "Hello world")
806810

807811
def test_transform_chat_request_tool_message_fields(self):
808812
"""Test tool message fields are converted to OCI names."""

0 commit comments

Comments
 (0)