Skip to content

Commit 105fc5d

Browse files
authored
feat: Update FastEmbed components to auto call run warm_up and don't modify Documents in place (#2678)
* Add license headers, don't edit documents in-place, auto-call warm_up * Fix bug * Formatting * mypy
1 parent c038259 commit 105fc5d

14 files changed

Lines changed: 103 additions & 53 deletions

integrations/fastembed/LICENSE.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ APPENDIX: How to apply the Apache License to your work.
5858

5959
To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives.
6060

61-
Copyright [yyyy] [name of copyright owner]
61+
Copyright 2024 deepset GmbH
6262

6363
Licensed under the Apache License, Version 2.0 (the "License");
6464
you may not use this file except in compliance with the License.

integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
22
#
33
# SPDX-License-Identifier: Apache-2.0
4+
45
from .fastembed_document_embedder import FastembedDocumentEmbedder
56
from .fastembed_sparse_document_embedder import FastembedSparseDocumentEmbedder
67
from .fastembed_sparse_text_embedder import FastembedSparseTextEmbedder

integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from dataclasses import replace
16
from typing import Any, Optional
27

38
from haystack import Document, component, default_to_dict
49

5-
from .embedding_backend.fastembed_backend import _FastembedEmbeddingBackendFactory
10+
from .embedding_backend.fastembed_backend import _FastembedEmbeddingBackend, _FastembedEmbeddingBackendFactory
611

712

813
@component
@@ -68,7 +73,7 @@ def __init__(
6873
local_files_only: bool = False,
6974
meta_fields_to_embed: Optional[list[str]] = None,
7075
embedding_separator: str = "\n",
71-
):
76+
) -> None:
7277
"""
7378
Create an FastembedDocumentEmbedder component.
7479
@@ -102,6 +107,7 @@ def __init__(
102107
self.local_files_only = local_files_only
103108
self.meta_fields_to_embed = meta_fields_to_embed or []
104109
self.embedding_separator = embedding_separator
110+
self.embedding_backend: Optional[_FastembedEmbeddingBackend] = None
105111

106112
def to_dict(self) -> dict[str, Any]:
107113
"""
@@ -124,11 +130,11 @@ def to_dict(self) -> dict[str, Any]:
124130
embedding_separator=self.embedding_separator,
125131
)
126132

127-
def warm_up(self):
133+
def warm_up(self) -> None:
128134
"""
129135
Initializes the component.
130136
"""
131-
if not hasattr(self, "embedding_backend"):
137+
if self.embedding_backend is None:
132138
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
133139
model_name=self.model_name,
134140
cache_dir=self.cache_dir,
@@ -157,26 +163,28 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]]:
157163
:param documents: List of Documents to embed.
158164
:returns: A dictionary with the following keys:
159165
- `documents`: List of Documents with each Document's `embedding` field set to the computed embeddings.
166+
:raises TypeError: If the input is not a list of Documents.
160167
"""
161168
if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
162169
msg = (
163170
"FastembedDocumentEmbedder expects a list of Documents as input. "
164171
"In case you want to embed a list of strings, please use the FastembedTextEmbedder."
165172
)
166173
raise TypeError(msg)
167-
if not hasattr(self, "embedding_backend"):
168-
msg = "The embedding model has not been loaded. Please call warm_up() before running."
169-
raise RuntimeError(msg)
174+
175+
if self.embedding_backend is None:
176+
self.warm_up()
170177

171178
texts_to_embed = self._prepare_texts_to_embed(documents=documents)
172-
embeddings = self.embedding_backend.embed(
179+
embeddings = self.embedding_backend.embed( # type: ignore[union-attr]
173180
texts_to_embed,
174181
batch_size=self.batch_size,
175182
progress_bar=self.progress_bar,
176183
parallel=self.parallel,
177184
)
178185

186+
new_documents = []
179187
for doc, emb in zip(documents, embeddings):
180-
doc.embedding = emb
188+
new_documents.append(replace(doc, embedding=emb))
181189

182-
return {"documents": documents}
190+
return {"documents": new_documents}

integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from dataclasses import replace
16
from typing import Any, Optional
27

38
from haystack import Document, component, default_to_dict
49

5-
from .embedding_backend.fastembed_backend import _FastembedSparseEmbeddingBackendFactory
10+
from .embedding_backend.fastembed_backend import (
11+
_FastembedSparseEmbeddingBackend,
12+
_FastembedSparseEmbeddingBackendFactory,
13+
)
614

715

816
@component
@@ -63,7 +71,7 @@ def __init__(
6371
meta_fields_to_embed: Optional[list[str]] = None,
6472
embedding_separator: str = "\n",
6573
model_kwargs: Optional[dict[str, Any]] = None,
66-
):
74+
) -> None:
6775
"""
6876
Create an FastembedDocumentEmbedder component.
6977
@@ -95,6 +103,7 @@ def __init__(
95103
self.meta_fields_to_embed = meta_fields_to_embed or []
96104
self.embedding_separator = embedding_separator
97105
self.model_kwargs = model_kwargs
106+
self.embedding_backend: Optional[_FastembedSparseEmbeddingBackend] = None
98107

99108
def to_dict(self) -> dict[str, Any]:
100109
"""
@@ -116,11 +125,11 @@ def to_dict(self) -> dict[str, Any]:
116125
model_kwargs=self.model_kwargs,
117126
)
118127

119-
def warm_up(self):
128+
def warm_up(self) -> None:
120129
"""
121130
Initializes the component.
122131
"""
123-
if not hasattr(self, "embedding_backend"):
132+
if self.embedding_backend is None:
124133
self.embedding_backend = _FastembedSparseEmbeddingBackendFactory.get_embedding_backend(
125134
model_name=self.model_name,
126135
cache_dir=self.cache_dir,
@@ -149,25 +158,28 @@ def run(self, documents: list[Document]) -> dict[str, list[Document]]:
149158
:returns: A dictionary with the following keys:
150159
- `documents`: List of Documents with each Document's `sparse_embedding`
151160
field set to the computed embeddings.
161+
:raises TypeError: If the input is not a list of Documents.
152162
"""
153163
if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
154164
msg = (
155165
"FastembedSparseDocumentEmbedder expects a list of Documents as input. "
156166
"In case you want to embed a list of strings, please use the FastembedTextEmbedder."
157167
)
158168
raise TypeError(msg)
159-
if not hasattr(self, "embedding_backend"):
160-
msg = "The embedding model has not been loaded. Please call warm_up() before running."
161-
raise RuntimeError(msg)
169+
170+
if self.embedding_backend is None:
171+
self.warm_up()
162172

163173
texts_to_embed = self._prepare_texts_to_embed(documents=documents)
164-
embeddings = self.embedding_backend.embed(
174+
embeddings = self.embedding_backend.embed( # type: ignore[union-attr]
165175
texts_to_embed,
166176
batch_size=self.batch_size,
167177
progress_bar=self.progress_bar,
168178
parallel=self.parallel,
169179
)
170180

181+
new_documents = []
171182
for doc, emb in zip(documents, embeddings):
172-
doc.sparse_embedding = emb
173-
return {"documents": documents}
183+
new_documents.append(replace(doc, sparse_embedding=emb))
184+
185+
return {"documents": new_documents}

integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,16 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
15
from typing import Any, Optional
26

37
from haystack import component, default_to_dict
48
from haystack.dataclasses.sparse_embedding import SparseEmbedding
59

6-
from .embedding_backend.fastembed_backend import _FastembedSparseEmbeddingBackendFactory
10+
from .embedding_backend.fastembed_backend import (
11+
_FastembedSparseEmbeddingBackend,
12+
_FastembedSparseEmbeddingBackendFactory,
13+
)
714

815

916
@component
@@ -36,7 +43,7 @@ def __init__(
3643
parallel: Optional[int] = None,
3744
local_files_only: bool = False,
3845
model_kwargs: Optional[dict[str, Any]] = None,
39-
):
46+
) -> None:
4047
"""
4148
Create a FastembedSparseTextEmbedder component.
4249
@@ -61,6 +68,7 @@ def __init__(
6168
self.parallel = parallel
6269
self.local_files_only = local_files_only
6370
self.model_kwargs = model_kwargs
71+
self.embedding_backend: Optional[_FastembedSparseEmbeddingBackend] = None
6472

6573
def to_dict(self) -> dict[str, Any]:
6674
"""
@@ -80,11 +88,11 @@ def to_dict(self) -> dict[str, Any]:
8088
model_kwargs=self.model_kwargs,
8189
)
8290

83-
def warm_up(self):
91+
def warm_up(self) -> None:
8492
"""
8593
Initializes the component.
8694
"""
87-
if not hasattr(self, "embedding_backend"):
95+
if self.embedding_backend is None:
8896
self.embedding_backend = _FastembedSparseEmbeddingBackendFactory.get_embedding_backend(
8997
model_name=self.model_name,
9098
cache_dir=self.cache_dir,
@@ -102,19 +110,18 @@ def run(self, text: str) -> dict[str, SparseEmbedding]:
102110
:returns: A dictionary with the following keys:
103111
- `embedding`: A list of floats representing the embedding of the input text.
104112
:raises TypeError: If the input is not a string.
105-
:raises RuntimeError: If the embedding model has not been loaded.
106113
"""
107114
if not isinstance(text, str):
108115
msg = (
109116
"FastembedSparseTextEmbedder expects a string as input. "
110117
"In case you want to embed a list of Documents, please use the FastembedDocumentEmbedder."
111118
)
112119
raise TypeError(msg)
113-
if not hasattr(self, "embedding_backend"):
114-
msg = "The embedding model has not been loaded. Please call warm_up() before running."
115-
raise RuntimeError(msg)
116120

117-
embedding = self.embedding_backend.embed(
121+
if self.embedding_backend is None:
122+
self.warm_up()
123+
124+
embedding = self.embedding_backend.embed( # type: ignore[union-attr]
118125
[text],
119126
progress_bar=self.progress_bar,
120127
parallel=self.parallel,

integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,12 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
15
from typing import Any, Optional
26

37
from haystack import component, default_to_dict
48

5-
from .embedding_backend.fastembed_backend import _FastembedEmbeddingBackendFactory
9+
from .embedding_backend.fastembed_backend import _FastembedEmbeddingBackend, _FastembedEmbeddingBackendFactory
610

711

812
@component
@@ -36,7 +40,7 @@ def __init__(
3640
progress_bar: bool = True,
3741
parallel: Optional[int] = None,
3842
local_files_only: bool = False,
39-
):
43+
) -> None:
4044
"""
4145
Create a FastembedTextEmbedder component.
4246
@@ -63,6 +67,7 @@ def __init__(
6367
self.progress_bar = progress_bar
6468
self.parallel = parallel
6569
self.local_files_only = local_files_only
70+
self.embedding_backend: Optional[_FastembedEmbeddingBackend] = None
6671

6772
def to_dict(self) -> dict[str, Any]:
6873
"""
@@ -83,11 +88,11 @@ def to_dict(self) -> dict[str, Any]:
8388
local_files_only=self.local_files_only,
8489
)
8590

86-
def warm_up(self):
91+
def warm_up(self) -> None:
8792
"""
8893
Initializes the component.
8994
"""
90-
if not hasattr(self, "embedding_backend"):
95+
if self.embedding_backend is None:
9196
self.embedding_backend = _FastembedEmbeddingBackendFactory.get_embedding_backend(
9297
model_name=self.model_name,
9398
cache_dir=self.cache_dir,
@@ -104,21 +109,20 @@ def run(self, text: str) -> dict[str, list[float]]:
104109
:returns: A dictionary with the following keys:
105110
- `embedding`: A list of floats representing the embedding of the input text.
106111
:raises TypeError: If the input is not a string.
107-
:raises RuntimeError: If the embedding model has not been loaded.
108112
"""
109113
if not isinstance(text, str):
110114
msg = (
111115
"FastembedTextEmbedder expects a string as input. "
112116
"In case you want to embed a list of Documents, please use the FastembedDocumentEmbedder."
113117
)
114118
raise TypeError(msg)
115-
if not hasattr(self, "embedding_backend"):
116-
msg = "The embedding model has not been loaded. Please call warm_up() before running."
117-
raise RuntimeError(msg)
119+
120+
if self.embedding_backend is None:
121+
self.warm_up()
118122

119123
text_to_embed = [self.prefix + text + self.suffix]
120124
embedding = list(
121-
self.embedding_backend.embed(
125+
self.embedding_backend.embed( # type: ignore[union-attr]
122126
text_to_embed,
123127
progress_bar=self.progress_bar,
124128
parallel=self.parallel,
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
15
from .ranker import FastembedRanker
26

37
__all__ = ["FastembedRanker"]

integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
15
from typing import Any, Optional
26

37
from haystack import Document, component, default_from_dict, default_to_dict, logging

integrations/fastembed/tests/test_fastembed_backend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
15
from unittest.mock import patch
26

37
from haystack_integrations.components.embedders.fastembed.embedding_backend.fastembed_backend import (

integrations/fastembed/tests/test_fastembed_document_embedder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
# SPDX-FileCopyrightText: 2024-present deepset GmbH <info@deepset.ai>
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
15
from unittest.mock import MagicMock, patch
26

37
import numpy as np

0 commit comments

Comments
 (0)