@@ -669,7 +669,8 @@ def transform_request_to_oci(
669669 oci_body ["truncate" ] = cohere_body ["truncate" ].upper ()
670670
671671 if "embedding_types" in cohere_body :
672- oci_body ["embeddingTypes" ] = [et .upper () for et in cohere_body ["embedding_types" ]]
672+ # OCI expects lowercase embedding types (float, int8, binary, etc.)
673+ oci_body ["embeddingTypes" ] = [et .lower () for et in cohere_body ["embedding_types" ]]
673674 if "max_tokens" in cohere_body :
674675 oci_body ["maxTokens" ] = cohere_body ["max_tokens" ]
675676 if "output_dimension" in cohere_body :
@@ -731,7 +732,13 @@ def transform_request_to_oci(
731732 oci_msg ["content" ] = msg .get ("content" ) or []
732733
733734 if "tool_calls" in msg :
734- oci_msg ["toolCalls" ] = msg ["tool_calls" ]
735+ oci_tool_calls = []
736+ for tc in msg ["tool_calls" ]:
737+ oci_tc = {** tc }
738+ if "type" in oci_tc :
739+ oci_tc ["type" ] = oci_tc ["type" ].upper ()
740+ oci_tool_calls .append (oci_tc )
741+ oci_msg ["toolCalls" ] = oci_tool_calls
735742 if "tool_call_id" in msg :
736743 oci_msg ["toolCallId" ] = msg ["tool_call_id" ]
737744 if "tool_plan" in msg :
@@ -759,7 +766,13 @@ def transform_request_to_oci(
759766 if "stop_sequences" in cohere_body :
760767 chat_request ["stopSequences" ] = cohere_body ["stop_sequences" ]
761768 if "tools" in cohere_body :
762- chat_request ["tools" ] = cohere_body ["tools" ]
769+ oci_tools = []
770+ for tool in cohere_body ["tools" ]:
771+ oci_tool = {** tool }
772+ if "type" in oci_tool :
773+ oci_tool ["type" ] = oci_tool ["type" ].upper ()
774+ oci_tools .append (oci_tool )
775+ chat_request ["tools" ] = oci_tools
763776 if "strict_tools" in cohere_body :
764777 chat_request ["strictTools" ] = cohere_body ["strict_tools" ]
765778 if "documents" in cohere_body :
@@ -768,8 +781,8 @@ def transform_request_to_oci(
768781 chat_request ["citationOptions" ] = cohere_body ["citation_options" ]
769782 if "response_format" in cohere_body :
770783 chat_request ["responseFormat" ] = cohere_body ["response_format" ]
771- if "safety_mode" in cohere_body :
772- chat_request ["safetyMode" ] = cohere_body ["safety_mode" ]
784+ if "safety_mode" in cohere_body and cohere_body [ "safety_mode" ] is not None :
785+ chat_request ["safetyMode" ] = cohere_body ["safety_mode" ]. upper ()
773786 if "logprobs" in cohere_body :
774787 chat_request ["logprobs" ] = cohere_body ["logprobs" ]
775788 if "tool_choice" in cohere_body :
@@ -813,13 +826,19 @@ def transform_request_to_oci(
813826 if "documents" in cohere_body :
814827 chat_request ["documents" ] = cohere_body ["documents" ]
815828 if "tools" in cohere_body :
816- chat_request ["tools" ] = cohere_body ["tools" ]
829+ oci_tools = []
830+ for tool in cohere_body ["tools" ]:
831+ oci_tool = {** tool }
832+ if "type" in oci_tool :
833+ oci_tool ["type" ] = oci_tool ["type" ].upper ()
834+ oci_tools .append (oci_tool )
835+ chat_request ["tools" ] = oci_tools
817836 if "tool_results" in cohere_body :
818837 chat_request ["toolResults" ] = cohere_body ["tool_results" ]
819838 if "response_format" in cohere_body :
820839 chat_request ["responseFormat" ] = cohere_body ["response_format" ]
821- if "safety_mode" in cohere_body :
822- chat_request ["safetyMode" ] = cohere_body ["safety_mode" ]
840+ if "safety_mode" in cohere_body and cohere_body [ "safety_mode" ] is not None :
841+ chat_request ["safetyMode" ] = cohere_body ["safety_mode" ]. upper ()
823842 if "priority" in cohere_body :
824843 chat_request ["priority" ] = cohere_body ["priority" ]
825844
@@ -860,7 +879,8 @@ def transform_oci_response_to_cohere(
860879 Transformed response in Cohere format
861880 """
862881 if endpoint == "embed" :
863- embeddings_data = oci_response .get ("embeddings" , {})
882+ # OCI returns "embeddings" by default, or "embeddingsByType" when embeddingTypes is specified
883+ embeddings_data = oci_response .get ("embeddingsByType" ) or oci_response .get ("embeddings" , {})
864884
865885 if isinstance (embeddings_data , dict ):
866886 normalized_embeddings = {str (key ).lower (): value for key , value in embeddings_data .items ()}
@@ -914,7 +934,12 @@ def transform_oci_response_to_cohere(
914934 message = {** message , "content" : transformed_content }
915935
916936 if "toolCalls" in message :
917- tool_calls = message ["toolCalls" ]
937+ tool_calls = []
938+ for tc in message ["toolCalls" ]:
939+ lowered_tc = {** tc }
940+ if "type" in lowered_tc :
941+ lowered_tc ["type" ] = lowered_tc ["type" ].lower ()
942+ tool_calls .append (lowered_tc )
918943 message = {k : v for k , v in message .items () if k != "toolCalls" }
919944 message ["tool_calls" ] = tool_calls
920945 if "toolPlan" in message :
@@ -1058,36 +1083,45 @@ def _transform_v1_event(oci_event: typing.Dict[str, typing.Any]) -> typing.Itera
10581083 final_v1_finish_reason = oci_event .get ("finishReason" , final_v1_finish_reason )
10591084 yield _emit_v1_event (event )
10601085
1086+ stream_finished = False
1087+
1088+ def _emit_closing_events () -> typing .Iterator [bytes ]:
1089+ """Emit the final closing events for the stream."""
1090+ if is_v2 :
1091+ if emitted_start :
1092+ if not emitted_content_end :
1093+ yield _emit_v2_event ({"type" : "content-end" , "index" : current_content_index })
1094+ message_end_event : typing .Dict [str , typing .Any ] = {
1095+ "type" : "message-end" ,
1096+ "id" : generation_id ,
1097+ "delta" : {"finish_reason" : final_finish_reason },
1098+ }
1099+ if final_usage :
1100+ message_end_event ["delta" ]["usage" ] = final_usage
1101+ yield _emit_v2_event (message_end_event )
1102+ else :
1103+ yield _emit_v1_event (
1104+ {
1105+ "event_type" : "stream-end" ,
1106+ "finish_reason" : final_v1_finish_reason ,
1107+ "response" : {
1108+ "text" : full_v1_text ,
1109+ "generation_id" : generation_id ,
1110+ "finish_reason" : final_v1_finish_reason ,
1111+ },
1112+ }
1113+ )
1114+
10611115 def _process_line (line : str ) -> typing .Iterator [bytes ]:
1116+ nonlocal stream_finished
10621117 if not line .startswith ("data: " ):
10631118 return
10641119
10651120 data_str = line [6 :]
10661121 if data_str .strip () == "[DONE]" :
1067- if is_v2 :
1068- if emitted_start :
1069- if not emitted_content_end :
1070- yield _emit_v2_event ({"type" : "content-end" , "index" : current_content_index })
1071- message_end_event : typing .Dict [str , typing .Any ] = {
1072- "type" : "message-end" ,
1073- "id" : generation_id ,
1074- "delta" : {"finish_reason" : final_finish_reason },
1075- }
1076- if final_usage :
1077- message_end_event ["delta" ]["usage" ] = final_usage
1078- yield _emit_v2_event (message_end_event )
1079- else :
1080- yield _emit_v1_event (
1081- {
1082- "event_type" : "stream-end" ,
1083- "finish_reason" : final_v1_finish_reason ,
1084- "response" : {
1085- "text" : full_v1_text ,
1086- "generation_id" : generation_id ,
1087- "finish_reason" : final_v1_finish_reason ,
1088- },
1089- }
1090- )
1122+ for event_bytes in _emit_closing_events ():
1123+ yield event_bytes
1124+ stream_finished = True
10911125 return
10921126
10931127 try :
@@ -1105,15 +1139,23 @@ def _process_line(line: str) -> typing.Iterator[bytes]:
11051139 except Exception as exc :
11061140 raise RuntimeError (f"OCI stream event transformation failed for endpoint '{ endpoint } ': { exc } " ) from exc
11071141
1142+ # OCI may not send [DONE] — treat finishReason as stream termination
1143+ if "finishReason" in oci_event :
1144+ for event_bytes in _emit_closing_events ():
1145+ yield event_bytes
1146+ stream_finished = True
1147+
11081148 for chunk in stream :
11091149 buffer += chunk
11101150 while b"\n " in buffer :
11111151 line_bytes , buffer = buffer .split (b"\n " , 1 )
11121152 line = line_bytes .decode ("utf-8" ).strip ()
11131153 for event_bytes in _process_line (line ):
11141154 yield event_bytes
1155+ if stream_finished :
1156+ return
11151157
1116- if buffer .strip ():
1158+ if buffer .strip () and not stream_finished :
11171159 line = buffer .decode ("utf-8" ).strip ()
11181160 for event_bytes in _process_line (line ):
11191161 yield event_bytes
0 commit comments