Skip to content

Commit 8f40152

Browse files
committed
fix(oci): fix embed embedding_types casing and handle embeddingsByType response
- embedding_types: OCI expects lowercase (float, int8) not uppercase. The .upper() was breaking all embedding_types requests. - Response: OCI returns "embeddingsByType" (not "embeddings") when embeddingTypes is specified. Handle both response keys. - Unit test updated to expect lowercase. - Integration tests added: embedding_types=["float"] and truncate modes.
1 parent 7a45ba6 commit 8f40152

2 files changed

Lines changed: 29 additions & 3 deletions

File tree

src/cohere/oci_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,8 @@ def transform_request_to_oci(
669669
oci_body["truncate"] = cohere_body["truncate"].upper()
670670

671671
if "embedding_types" in cohere_body:
672-
oci_body["embeddingTypes"] = [et.upper() for et in cohere_body["embedding_types"]]
672+
# OCI expects lowercase embedding types (float, int8, binary, etc.)
673+
oci_body["embeddingTypes"] = [et.lower() for et in cohere_body["embedding_types"]]
673674
if "max_tokens" in cohere_body:
674675
oci_body["maxTokens"] = cohere_body["max_tokens"]
675676
if "output_dimension" in cohere_body:
@@ -875,7 +876,8 @@ def transform_oci_response_to_cohere(
875876
Transformed response in Cohere format
876877
"""
877878
if endpoint == "embed":
878-
embeddings_data = oci_response.get("embeddings", {})
879+
# OCI returns "embeddings" by default, or "embeddingsByType" when embeddingTypes is specified
880+
embeddings_data = oci_response.get("embeddingsByType") or oci_response.get("embeddings", {})
879881

880882
if isinstance(embeddings_data, dict):
881883
normalized_embeddings = {str(key).lower(): value for key, value in embeddings_data.items()}

tests/test_oci_client.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,30 @@ def test_embed_search_query_input_type(self):
509509
self.assertIsNotNone(response.embeddings.float_)
510510
self.assertEqual(len(response.embeddings.float_[0]), 1024)
511511

512+
def test_embed_with_embedding_types(self):
513+
"""Test embed with explicit embedding_types parameter."""
514+
response = self.client.embed(
515+
model="embed-english-v3.0",
516+
texts=["Hello world"],
517+
input_type="search_document",
518+
embedding_types=["float"],
519+
)
520+
self.assertIsNotNone(response.embeddings.float_)
521+
self.assertEqual(len(response.embeddings.float_[0]), 1024)
522+
523+
def test_embed_with_truncate(self):
524+
"""Test embed with truncate parameter."""
525+
long_text = "hello " * 1000
526+
for mode in ["NONE", "START", "END"]:
527+
response = self.client.embed(
528+
model="embed-english-v3.0",
529+
texts=[long_text],
530+
input_type="search_document",
531+
truncate=mode,
532+
)
533+
self.assertIsNotNone(response.embeddings.float_)
534+
self.assertEqual(len(response.embeddings.float_[0]), 1024)
535+
512536
def test_command_r_plus_chat(self):
513537
"""Test command-r-plus-08-2024 via V1 client."""
514538
v1_client = cohere.OciClient(
@@ -772,7 +796,7 @@ def test_transform_embed_request(self):
772796
self.assertEqual(result["inputs"], ["hello", "world"])
773797
self.assertEqual(result["inputType"], "SEARCH_DOCUMENT")
774798
self.assertEqual(result["truncate"], "END")
775-
self.assertEqual(result["embeddingTypes"], ["FLOAT", "INT8"])
799+
self.assertEqual(result["embeddingTypes"], ["float", "int8"])
776800
self.assertEqual(result["compartmentId"], "compartment-123")
777801
self.assertEqual(result["servingMode"]["modelId"], "cohere.embed-english-v3.0")
778802

0 commit comments

Comments
 (0)