Skip to content

Commit b644eb3

Browse files
committed
fix: remove dead chat_stream endpoint and body-based stream detection
The "stream" in endpoint check was dead code — both V1 and V2 SDK always route through endpoint "chat" (v1/chat and v2/chat paths). Streaming is reliably signalled via body["stream"], which the SDK always sets. - Drop "stream" in endpoint guard on is_stream and isStream detection - Remove "chat_stream" from action_map, transform, and response branches - Update unit tests to use "chat" endpoint (the only real one)
1 parent 636e76e commit b644eb3

2 files changed

Lines changed: 6 additions & 9 deletions

File tree

src/cohere/oci_client.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ def _event_hook(request: httpx.Request) -> None:
480480
request.stream = ByteStream(oci_body_bytes)
481481
request._content = oci_body_bytes
482482
request.extensions["endpoint"] = endpoint
483-
request.extensions["is_stream"] = "stream" in endpoint or body.get("stream", False)
483+
request.extensions["is_stream"] = body.get("stream", False)
484484
request.extensions["is_v2"] = is_v2_client
485485

486486
return _event_hook
@@ -554,7 +554,6 @@ def get_oci_url(
554554
action_map = {
555555
"embed": "embedText",
556556
"chat": "chat",
557-
"chat_stream": "chat",
558557
}
559558

560559
action = action_map.get(endpoint)
@@ -658,7 +657,7 @@ def transform_request_to_oci(
658657

659658
return oci_body
660659

661-
elif endpoint in ["chat", "chat_stream"]:
660+
elif endpoint == "chat":
662661
# Validate that the request body matches the client type
663662
has_messages = "messages" in cohere_body
664663
has_message = "message" in cohere_body
@@ -800,7 +799,7 @@ def transform_request_to_oci(
800799
chat_request["priority"] = cohere_body["priority"]
801800

802801
# Handle streaming for both versions
803-
if "stream" in endpoint or cohere_body.get("stream"):
802+
if cohere_body.get("stream"):
804803
chat_request["isStream"] = True
805804

806805
# Top level OCI request structure
@@ -861,7 +860,7 @@ def transform_oci_response_to_cohere(
861860
"meta": meta,
862861
}
863862

864-
elif endpoint == "chat" or endpoint == "chat_stream":
863+
elif endpoint == "chat":
865864
chat_response = oci_response.get("chatResponse", {})
866865

867866
if is_v2:
@@ -1096,7 +1095,7 @@ def transform_stream_event(
10961095
Returns:
10971096
V2: List of transformed events. V1: Single transformed event dict.
10981097
"""
1099-
if endpoint in ["chat_stream", "chat"]:
1098+
if endpoint == "chat":
11001099
if is_v2:
11011100
content_type = "text"
11021101
content_value = ""

tests/test_oci_client.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,7 @@ def test_v1_stream_wrapper_preserves_finish_reason(self):
715715

716716
events = [
717717
json.loads(raw.decode("utf-8"))
718-
for raw in transform_oci_stream_wrapper(iter(chunks), "chat_stream", is_v2=False)
718+
for raw in transform_oci_stream_wrapper(iter(chunks), "chat", is_v2=False)
719719
]
720720

721721
self.assertEqual(events[2]["event_type"], "stream-end")
@@ -759,8 +759,6 @@ def test_get_oci_url_known_endpoints(self):
759759
url = get_oci_url("us-chicago-1", "chat")
760760
self.assertIn("/actions/chat", url)
761761

762-
url = get_oci_url("us-chicago-1", "chat_stream")
763-
self.assertIn("/actions/chat", url)
764762

765763
def test_get_oci_url_unknown_endpoint_raises(self):
766764
"""Test that unknown endpoints raise ValueError instead of producing bad URLs."""

0 commit comments

Comments
 (0)