Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from dataclasses import replace
from typing import Any

from botocore.config import Config
Expand Down Expand Up @@ -186,10 +187,11 @@ def _embed_cohere(self, documents: list[Document]) -> list[Document]:
)
all_embeddings.extend(embeddings_list)

new_documents = []
for doc, emb in zip(documents, all_embeddings, strict=True):
doc.embedding = emb
new_documents.append(replace(doc, embedding=emb))

return documents
return new_documents

def _embed_titan(self, documents: list[Document]) -> list[Document]:
"""
Expand All @@ -214,10 +216,11 @@ def _embed_titan(self, documents: list[Document]) -> list[Document]:
embedding = response_body["embedding"]
all_embeddings.append(embedding)

new_documents = []
for doc, emb in zip(documents, all_embeddings, strict=True):
doc.embedding = emb
new_documents.append(replace(doc, embedding=emb))

return documents
return new_documents

@component.output_types(documents=list[Document])
def run(self, documents: list[Document]) -> dict[str, list[Document]]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from dataclasses import replace
from typing import Any

from botocore.exceptions import ClientError
Expand Down Expand Up @@ -251,8 +252,7 @@ def resolve_secret(secret: Secret | None) -> str | None:
idx = result["index"]
score = result["relevanceScore"]
doc = documents[idx]
doc.score = score
sorted_docs.append(doc)
sorted_docs.append(replace(doc, score=score))

return {"documents": sorted_docs}
except ClientError as client_error:
Expand Down
68 changes: 68 additions & 0 deletions integrations/amazon_bedrock/tests/test_document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,74 @@ def mock_invoke_model(*args, **kwargs):
assert doc.content == docs[i].content
assert doc.embedding == [0.1, 0.2, 0.3]

def test_run_cohere_does_not_modify_original_documents(self, mock_boto3_session):
embedder = AmazonBedrockDocumentEmbedder(model="cohere.embed-english-v3")

original_docs = [
Document(content="test 1", id="doc1"),
Document(content="test 2", id="doc2"),
]

# Store original IDs to verify they're the same objects
original_doc_ids = [id(doc) for doc in original_docs]
original_embeddings = [doc.embedding for doc in original_docs]

with patch.object(embedder, "_client") as mock_client:
mock_client.invoke_model.return_value = {
"body": io.StringIO('{"embeddings": [[0.1, 0.2], [0.3, 0.4]]}'),
}

result = embedder.run(documents=original_docs)

# Verify originals are unchanged
assert all(doc.embedding is None for doc in original_docs)
assert original_embeddings == [None, None]

# Verify returned documents are NEW instances
returned_doc_ids = [id(doc) for doc in result["documents"]]
assert original_doc_ids != returned_doc_ids

# Verify returned documents have embeddings
assert result["documents"][0].embedding == [0.1, 0.2]
assert result["documents"][1].embedding == [0.3, 0.4]
assert result["documents"][0].content == "test 1"
assert result["documents"][1].content == "test 2"

def test_run_titan_does_not_modify_original_documents(self, mock_boto3_session):
embedder = AmazonBedrockDocumentEmbedder(model="amazon.titan-embed-text-v1")

original_docs = [
Document(content="test 1", id="doc1"),
Document(content="test 2", id="doc2"),
]

# Store original IDs to verify they're the same objects
original_doc_ids = [id(doc) for doc in original_docs]
original_embeddings = [doc.embedding for doc in original_docs]

with patch.object(embedder, "_client") as mock_client:
# Titan returns one embedding at a time
mock_client.invoke_model.side_effect = [
{"body": io.StringIO('{"embedding": [0.1, 0.2]}')},
{"body": io.StringIO('{"embedding": [0.3, 0.4]}')},
]

result = embedder.run(documents=original_docs)

# Verify originals are unchanged
assert all(doc.embedding is None for doc in original_docs)
assert original_embeddings == [None, None]

# Verify returned documents are NEW instances
returned_doc_ids = [id(doc) for doc in result["documents"]]
assert original_doc_ids != returned_doc_ids

# Verify returned documents have embeddings
assert result["documents"][0].embedding == [0.1, 0.2]
assert result["documents"][1].embedding == [0.3, 0.4]
assert result["documents"][0].content == "test 1"
assert result["documents"][1].content == "test 2"

@pytest.mark.integration
@pytest.mark.skipif(
not os.getenv("AWS_ACCESS_KEY_ID")
Expand Down