Skip to content

Commit 6f2332a

Browse files
authored
feat: Update FastembedRanker to auto call warm_up on first run (#2834)
* Automatically call warm_up() if model is not loaded in FastembedRanker * Remove explicit call to warm_up() in ranker_example.py * Add unit tests * Fix typing issue * Fix tying * Update docstrings * Remove explicit warm_up calls from integration tests
1 parent 266ac3f commit 6f2332a

11 files changed

Lines changed: 72 additions & 17 deletions

integrations/fastembed/examples/ranker_example.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
]
1212

1313
ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2")
14-
ranker.warm_up()
1514
reranked_documents = ranker.run(query=query, documents=documents)["documents"]
1615

1716

integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ class FastembedDocumentEmbedder:
2929
batch_size=256,
3030
)
3131
32-
doc_embedder.warm_up()
33-
3432
# Text taken from PubMed QA Dataset (https://huggingface.co/datasets/pubmed_qa)
3533
document_list = [
3634
Document(

integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ class FastembedSparseDocumentEmbedder:
2828
batch_size=32,
2929
)
3030
31-
sparse_doc_embedder.warm_up()
32-
3331
# Text taken from PubMed QA Dataset (https://huggingface.co/datasets/pubmed_qa)
3432
document_list = [
3533
Document(

integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ class FastembedSparseTextEmbedder:
2828
sparse_text_embedder = FastembedSparseTextEmbedder(
2929
model="prithivida/Splade_PP_en_v1"
3030
)
31-
sparse_text_embedder.warm_up()
3231
3332
sparse_embedding = sparse_text_embedder.run(text)["sparse_embedding"]
3433
```

integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@ class FastembedTextEmbedder:
2424
text_embedder = FastembedTextEmbedder(
2525
model="BAAI/bge-small-en-v1.5"
2626
)
27-
text_embedder.warm_up()
2827
2928
embedding = text_embedder.run(text)["embedding"]
3029
```

integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -177,13 +177,12 @@ def run(self, query: str, documents: list[Document], top_k: int | None = None) -
177177
raise ValueError(msg)
178178

179179
if self._model is None:
180-
msg = "The ranker model has not been loaded. Please call warm_up() before running."
181-
raise RuntimeError(msg)
180+
self.warm_up()
182181

183182
fastembed_input_docs = self._prepare_fastembed_input_docs(documents)
184183

185184
scores = list(
186-
self._model.rerank(
185+
self._model.rerank( # type: ignore[union-attr]
187186
query=query,
188187
documents=fastembed_input_docs,
189188
batch_size=self.batch_size,

integrations/fastembed/tests/test_fastembed_document_embedder.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,12 +281,25 @@ def test_embed_metadata(self):
281281
parallel=None,
282282
)
283283

284+
def test_run_calls_warm_up(self):
285+
embedder = FastembedDocumentEmbedder()
286+
assert embedder.embedding_backend is None
287+
288+
mock_backend = MagicMock()
289+
mock_backend.embed.return_value = [[0.1, 0.2, 0.3]]
290+
291+
with patch.object(
292+
embedder, "warm_up", side_effect=lambda: setattr(embedder, "embedding_backend", mock_backend)
293+
) as mock_warm_up:
294+
embedder.run(documents=[Document(content="test document")])
295+
296+
mock_warm_up.assert_called_once()
297+
284298
@pytest.mark.integration
285299
def test_run(self):
286300
embedder = FastembedDocumentEmbedder(
287301
model="BAAI/bge-small-en-v1.5",
288302
)
289-
embedder.warm_up()
290303

291304
doc = Document(content="Parton energy loss in QCD matter")
292305

integrations/fastembed/tests/test_fastembed_ranker.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from unittest.mock import MagicMock
5+
from unittest.mock import MagicMock, patch
66

77
import numpy as np
88
import pytest
@@ -257,10 +257,23 @@ def test_embed_metadata(self):
257257
parallel=None,
258258
)
259259

260+
def test_run_calls_warm_up(self):
261+
"""
262+
Unit test to check that warm_up is called when run is called for the first time.
263+
"""
264+
ranker = FastembedRanker()
265+
266+
mock_model = MagicMock()
267+
mock_model.rerank.return_value = [0.5]
268+
269+
with patch.object(ranker, "warm_up", side_effect=lambda: setattr(ranker, "_model", mock_model)) as mock_warm_up:
270+
ranker.run(query="test query", documents=[Document(content="test document")])
271+
272+
mock_warm_up.assert_called_once()
273+
260274
@pytest.mark.integration
261275
def test_run(self):
262276
ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2", top_k=2)
263-
ranker.warm_up()
264277

265278
query = "Who is maintaining Qdrant?"
266279
documents = [

integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,19 @@ def test_init_with_model_kwargs_parameters(self):
304304

305305
assert embedder.model_kwargs == bm25_config
306306

307+
def test_run_calls_warm_up(self):
308+
embedder = FastembedSparseDocumentEmbedder()
309+
310+
mock_backend = MagicMock()
311+
mock_backend.embed.return_value = [{"indices": [0], "values": [0.5]}]
312+
313+
with patch.object(
314+
embedder, "warm_up", side_effect=lambda: setattr(embedder, "embedding_backend", mock_backend)
315+
) as mock_warm_up:
316+
embedder.run(documents=[Document(content="test document")])
317+
318+
mock_warm_up.assert_called_once()
319+
307320
@pytest.mark.integration
308321
def test_run_with_model_kwargs(self):
309322
"""
@@ -317,7 +330,6 @@ def test_run_with_model_kwargs(self):
317330
model="Qdrant/bm42-all-minilm-l6-v2-attentions",
318331
model_kwargs=bm42_config,
319332
)
320-
embedder.warm_up()
321333

322334
doc = Document(content="Example content using BM42")
323335

@@ -336,7 +348,6 @@ def test_run(self):
336348
embedder = FastembedSparseDocumentEmbedder(
337349
model="prithivida/Splade_PP_en_v1",
338350
)
339-
embedder.warm_up()
340351

341352
doc = Document(content="Parton energy loss in QCD matter")
342353

integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,20 @@ def test_init_with_model_kwargs_parameters(self):
224224

225225
assert embedder.model_kwargs == bm25_config
226226

227+
def test_run_calls_warm_up(self):
228+
embedder = FastembedSparseTextEmbedder()
229+
assert embedder.embedding_backend is None
230+
231+
mock_backend = MagicMock()
232+
mock_backend.embed.return_value = [[0.1, 0.2, 0.3]]
233+
234+
with patch.object(
235+
embedder, "warm_up", side_effect=lambda: setattr(embedder, "embedding_backend", mock_backend)
236+
) as mock_warm_up:
237+
embedder.run(text="test text")
238+
239+
mock_warm_up.assert_called_once()
240+
227241
@pytest.mark.integration
228242
def test_run_with_model_kwargs(self):
229243
"""
@@ -258,7 +272,6 @@ def test_run(self):
258272
embedder = FastembedSparseTextEmbedder(
259273
model="prithivida/Splade_PP_en_v1",
260274
)
261-
embedder.warm_up()
262275

263276
text = "Parton energy loss in QCD matter"
264277

0 commit comments

Comments
 (0)