Skip to content

Commit 537b1eb

Browse files
committed
merge: resolve conflicts with upstream/main for vision + tool use tests
Keep both the vision test (from this branch) and the tool use, type lowercasing, multi-turn, and safety_mode tests (from upstream).
2 parents 1e2da69 + fc167c1 commit 537b1eb

2 files changed

Lines changed: 247 additions & 43 deletions

File tree

src/cohere/oci_client.py

Lines changed: 77 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,8 @@ def transform_request_to_oci(
669669
oci_body["truncate"] = cohere_body["truncate"].upper()
670670

671671
if "embedding_types" in cohere_body:
672-
oci_body["embeddingTypes"] = [et.upper() for et in cohere_body["embedding_types"]]
672+
# OCI expects lowercase embedding types (float, int8, binary, etc.)
673+
oci_body["embeddingTypes"] = [et.lower() for et in cohere_body["embedding_types"]]
673674
if "max_tokens" in cohere_body:
674675
oci_body["maxTokens"] = cohere_body["max_tokens"]
675676
if "output_dimension" in cohere_body:
@@ -731,7 +732,13 @@ def transform_request_to_oci(
731732
oci_msg["content"] = msg.get("content") or []
732733

733734
if "tool_calls" in msg:
734-
oci_msg["toolCalls"] = msg["tool_calls"]
735+
oci_tool_calls = []
736+
for tc in msg["tool_calls"]:
737+
oci_tc = {**tc}
738+
if "type" in oci_tc:
739+
oci_tc["type"] = oci_tc["type"].upper()
740+
oci_tool_calls.append(oci_tc)
741+
oci_msg["toolCalls"] = oci_tool_calls
735742
if "tool_call_id" in msg:
736743
oci_msg["toolCallId"] = msg["tool_call_id"]
737744
if "tool_plan" in msg:
@@ -759,7 +766,13 @@ def transform_request_to_oci(
759766
if "stop_sequences" in cohere_body:
760767
chat_request["stopSequences"] = cohere_body["stop_sequences"]
761768
if "tools" in cohere_body:
762-
chat_request["tools"] = cohere_body["tools"]
769+
oci_tools = []
770+
for tool in cohere_body["tools"]:
771+
oci_tool = {**tool}
772+
if "type" in oci_tool:
773+
oci_tool["type"] = oci_tool["type"].upper()
774+
oci_tools.append(oci_tool)
775+
chat_request["tools"] = oci_tools
763776
if "strict_tools" in cohere_body:
764777
chat_request["strictTools"] = cohere_body["strict_tools"]
765778
if "documents" in cohere_body:
@@ -768,8 +781,8 @@ def transform_request_to_oci(
768781
chat_request["citationOptions"] = cohere_body["citation_options"]
769782
if "response_format" in cohere_body:
770783
chat_request["responseFormat"] = cohere_body["response_format"]
771-
if "safety_mode" in cohere_body:
772-
chat_request["safetyMode"] = cohere_body["safety_mode"]
784+
if "safety_mode" in cohere_body and cohere_body["safety_mode"] is not None:
785+
chat_request["safetyMode"] = cohere_body["safety_mode"].upper()
773786
if "logprobs" in cohere_body:
774787
chat_request["logprobs"] = cohere_body["logprobs"]
775788
if "tool_choice" in cohere_body:
@@ -813,13 +826,19 @@ def transform_request_to_oci(
813826
if "documents" in cohere_body:
814827
chat_request["documents"] = cohere_body["documents"]
815828
if "tools" in cohere_body:
816-
chat_request["tools"] = cohere_body["tools"]
829+
oci_tools = []
830+
for tool in cohere_body["tools"]:
831+
oci_tool = {**tool}
832+
if "type" in oci_tool:
833+
oci_tool["type"] = oci_tool["type"].upper()
834+
oci_tools.append(oci_tool)
835+
chat_request["tools"] = oci_tools
817836
if "tool_results" in cohere_body:
818837
chat_request["toolResults"] = cohere_body["tool_results"]
819838
if "response_format" in cohere_body:
820839
chat_request["responseFormat"] = cohere_body["response_format"]
821-
if "safety_mode" in cohere_body:
822-
chat_request["safetyMode"] = cohere_body["safety_mode"]
840+
if "safety_mode" in cohere_body and cohere_body["safety_mode"] is not None:
841+
chat_request["safetyMode"] = cohere_body["safety_mode"].upper()
823842
if "priority" in cohere_body:
824843
chat_request["priority"] = cohere_body["priority"]
825844

@@ -860,7 +879,8 @@ def transform_oci_response_to_cohere(
860879
Transformed response in Cohere format
861880
"""
862881
if endpoint == "embed":
863-
embeddings_data = oci_response.get("embeddings", {})
882+
# OCI returns "embeddings" by default, or "embeddingsByType" when embeddingTypes is specified
883+
embeddings_data = oci_response.get("embeddingsByType") or oci_response.get("embeddings", {})
864884

865885
if isinstance(embeddings_data, dict):
866886
normalized_embeddings = {str(key).lower(): value for key, value in embeddings_data.items()}
@@ -914,7 +934,12 @@ def transform_oci_response_to_cohere(
914934
message = {**message, "content": transformed_content}
915935

916936
if "toolCalls" in message:
917-
tool_calls = message["toolCalls"]
937+
tool_calls = []
938+
for tc in message["toolCalls"]:
939+
lowered_tc = {**tc}
940+
if "type" in lowered_tc:
941+
lowered_tc["type"] = lowered_tc["type"].lower()
942+
tool_calls.append(lowered_tc)
918943
message = {k: v for k, v in message.items() if k != "toolCalls"}
919944
message["tool_calls"] = tool_calls
920945
if "toolPlan" in message:
@@ -1058,36 +1083,45 @@ def _transform_v1_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Itera
10581083
final_v1_finish_reason = oci_event.get("finishReason", final_v1_finish_reason)
10591084
yield _emit_v1_event(event)
10601085

1086+
stream_finished = False
1087+
1088+
def _emit_closing_events() -> typing.Iterator[bytes]:
1089+
"""Emit the final closing events for the stream."""
1090+
if is_v2:
1091+
if emitted_start:
1092+
if not emitted_content_end:
1093+
yield _emit_v2_event({"type": "content-end", "index": current_content_index})
1094+
message_end_event: typing.Dict[str, typing.Any] = {
1095+
"type": "message-end",
1096+
"id": generation_id,
1097+
"delta": {"finish_reason": final_finish_reason},
1098+
}
1099+
if final_usage:
1100+
message_end_event["delta"]["usage"] = final_usage
1101+
yield _emit_v2_event(message_end_event)
1102+
else:
1103+
yield _emit_v1_event(
1104+
{
1105+
"event_type": "stream-end",
1106+
"finish_reason": final_v1_finish_reason,
1107+
"response": {
1108+
"text": full_v1_text,
1109+
"generation_id": generation_id,
1110+
"finish_reason": final_v1_finish_reason,
1111+
},
1112+
}
1113+
)
1114+
10611115
def _process_line(line: str) -> typing.Iterator[bytes]:
1116+
nonlocal stream_finished
10621117
if not line.startswith("data: "):
10631118
return
10641119

10651120
data_str = line[6:]
10661121
if data_str.strip() == "[DONE]":
1067-
if is_v2:
1068-
if emitted_start:
1069-
if not emitted_content_end:
1070-
yield _emit_v2_event({"type": "content-end", "index": current_content_index})
1071-
message_end_event: typing.Dict[str, typing.Any] = {
1072-
"type": "message-end",
1073-
"id": generation_id,
1074-
"delta": {"finish_reason": final_finish_reason},
1075-
}
1076-
if final_usage:
1077-
message_end_event["delta"]["usage"] = final_usage
1078-
yield _emit_v2_event(message_end_event)
1079-
else:
1080-
yield _emit_v1_event(
1081-
{
1082-
"event_type": "stream-end",
1083-
"finish_reason": final_v1_finish_reason,
1084-
"response": {
1085-
"text": full_v1_text,
1086-
"generation_id": generation_id,
1087-
"finish_reason": final_v1_finish_reason,
1088-
},
1089-
}
1090-
)
1122+
for event_bytes in _emit_closing_events():
1123+
yield event_bytes
1124+
stream_finished = True
10911125
return
10921126

10931127
try:
@@ -1105,15 +1139,23 @@ def _process_line(line: str) -> typing.Iterator[bytes]:
11051139
except Exception as exc:
11061140
raise RuntimeError(f"OCI stream event transformation failed for endpoint '{endpoint}': {exc}") from exc
11071141

1142+
# OCI may not send [DONE] — treat finishReason as stream termination
1143+
if "finishReason" in oci_event:
1144+
for event_bytes in _emit_closing_events():
1145+
yield event_bytes
1146+
stream_finished = True
1147+
11081148
for chunk in stream:
11091149
buffer += chunk
11101150
while b"\n" in buffer:
11111151
line_bytes, buffer = buffer.split(b"\n", 1)
11121152
line = line_bytes.decode("utf-8").strip()
11131153
for event_bytes in _process_line(line):
11141154
yield event_bytes
1155+
if stream_finished:
1156+
return
11151157

1116-
if buffer.strip():
1158+
if buffer.strip() and not stream_finished:
11171159
line = buffer.decode("utf-8").strip()
11181160
for event_bytes in _process_line(line):
11191161
yield event_bytes

0 commit comments

Comments
 (0)