|
2 | 2 | # |
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
| 5 | +from dataclasses import replace |
5 | 6 | from typing import Any, Optional, Union |
6 | 7 |
|
7 | 8 | from haystack import Document, component, default_from_dict, default_to_dict |
@@ -56,7 +57,7 @@ def __init__( |
56 | 57 | progress_bar: bool = True, |
57 | 58 | meta_fields_to_embed: Optional[list[str]] = None, |
58 | 59 | embedding_separator: str = "\n", |
59 | | - ): |
| 60 | + ) -> None: |
60 | 61 | """ |
61 | 62 | Create a OptimumDocumentEmbedder component. |
62 | 63 |
|
@@ -140,7 +141,7 @@ def __init__( |
140 | 141 | self._backend = _EmbedderBackend(params) |
141 | 142 | self._initialized = False |
142 | 143 |
|
143 | | - def warm_up(self): |
| 144 | + def warm_up(self) -> None: |
144 | 145 | """ |
145 | 146 | Initializes the component. |
146 | 147 | """ |
@@ -223,7 +224,9 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]]: |
223 | 224 |
|
224 | 225 | texts_to_embed = self._prepare_texts_to_embed(documents=documents) |
225 | 226 | embeddings = self._backend.embed_texts(texts_to_embed) |
| 227 | + |
| 228 | + new_documents = [] |
226 | 229 | for doc, emb in zip(documents, embeddings): |
227 | | - doc.embedding = emb |
| 230 | + new_documents.append(replace(doc, embedding=emb)) |
228 | 231 |
|
229 | | - return {"documents": documents} |
| 232 | + return {"documents": new_documents} |
0 commit comments