diff --git a/integrations/fastembed/examples/ranker_example.py b/integrations/fastembed/examples/ranker_example.py index 593334e905..235fe4f814 100644 --- a/integrations/fastembed/examples/ranker_example.py +++ b/integrations/fastembed/examples/ranker_example.py @@ -11,7 +11,6 @@ ] ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2") -ranker.warm_up() reranked_documents = ranker.run(query=query, documents=documents)["documents"] diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py index c10b05d8c3..b2c28918dc 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_document_embedder.py @@ -29,8 +29,6 @@ class FastembedDocumentEmbedder: batch_size=256, ) - doc_embedder.warm_up() - # Text taken from PubMed QA Dataset (https://huggingface.co/datasets/pubmed_qa) document_list = [ Document( diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py index f20ce84d20..5fda8ffbbe 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_document_embedder.py @@ -28,8 +28,6 @@ class FastembedSparseDocumentEmbedder: batch_size=32, ) - sparse_doc_embedder.warm_up() - # Text taken from PubMed QA Dataset (https://huggingface.co/datasets/pubmed_qa) document_list = [ Document( diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py index 8824458a05..437ad3c7f3 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_sparse_text_embedder.py @@ -28,7 +28,6 @@ class FastembedSparseTextEmbedder: sparse_text_embedder = FastembedSparseTextEmbedder( model="prithivida/Splade_PP_en_v1" ) - sparse_text_embedder.warm_up() sparse_embedding = sparse_text_embedder.run(text)["sparse_embedding"] ``` diff --git a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py index b6e24500f3..d593f2b9b3 100644 --- a/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py +++ b/integrations/fastembed/src/haystack_integrations/components/embedders/fastembed/fastembed_text_embedder.py @@ -24,7 +24,6 @@ class FastembedTextEmbedder: text_embedder = FastembedTextEmbedder( model="BAAI/bge-small-en-v1.5" ) - text_embedder.warm_up() embedding = text_embedder.run(text)["embedding"] ``` diff --git a/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py index 6120bac4a3..26878387f4 100644 --- a/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py +++ b/integrations/fastembed/src/haystack_integrations/components/rankers/fastembed/ranker.py @@ -177,13 +177,12 @@ def run(self, query: str, documents: list[Document], top_k: int | None = None) - raise ValueError(msg) if self._model is None: - msg = "The ranker model has not been loaded. Please call warm_up() before running." - raise RuntimeError(msg) + self.warm_up() fastembed_input_docs = self._prepare_fastembed_input_docs(documents) scores = list( - self._model.rerank( + self._model.rerank( # type: ignore[union-attr] query=query, documents=fastembed_input_docs, batch_size=self.batch_size, diff --git a/integrations/fastembed/tests/test_fastembed_document_embedder.py b/integrations/fastembed/tests/test_fastembed_document_embedder.py index 37a7a0c581..61be6f4409 100644 --- a/integrations/fastembed/tests/test_fastembed_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_document_embedder.py @@ -281,12 +281,25 @@ def test_embed_metadata(self): parallel=None, ) + def test_run_calls_warm_up(self): + embedder = FastembedDocumentEmbedder() + assert embedder.embedding_backend is None + + mock_backend = MagicMock() + mock_backend.embed.return_value = [[0.1, 0.2, 0.3]] + + with patch.object( + embedder, "warm_up", side_effect=lambda: setattr(embedder, "embedding_backend", mock_backend) + ) as mock_warm_up: + embedder.run(documents=[Document(content="test document")]) + + mock_warm_up.assert_called_once() + @pytest.mark.integration def test_run(self): embedder = FastembedDocumentEmbedder( model="BAAI/bge-small-en-v1.5", ) - embedder.warm_up() doc = Document(content="Parton energy loss in QCD matter") diff --git a/integrations/fastembed/tests/test_fastembed_ranker.py b/integrations/fastembed/tests/test_fastembed_ranker.py index 0985762613..3d78991def 100644 --- a/integrations/fastembed/tests/test_fastembed_ranker.py +++ b/integrations/fastembed/tests/test_fastembed_ranker.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import numpy as np import pytest @@ -257,10 +257,23 @@ def test_embed_metadata(self): parallel=None, ) + def test_run_calls_warm_up(self): + """ + Unit test to check that warm_up is called when run is called for the first time. + """ + ranker = FastembedRanker() + + mock_model = MagicMock() + mock_model.rerank.return_value = [0.5] + + with patch.object(ranker, "warm_up", side_effect=lambda: setattr(ranker, "_model", mock_model)) as mock_warm_up: + ranker.run(query="test query", documents=[Document(content="test document")]) + + mock_warm_up.assert_called_once() + @pytest.mark.integration def test_run(self): ranker = FastembedRanker(model_name="Xenova/ms-marco-MiniLM-L-6-v2", top_k=2) - ranker.warm_up() query = "Who is maintaining Qdrant?" documents = [ diff --git a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py index 0b58d818aa..8b80b884f1 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_document_embedder.py @@ -304,6 +304,19 @@ def test_init_with_model_kwargs_parameters(self): assert embedder.model_kwargs == bm25_config + def test_run_calls_warm_up(self): + embedder = FastembedSparseDocumentEmbedder() + + mock_backend = MagicMock() + mock_backend.embed.return_value = [{"indices": [0], "values": [0.5]}] + + with patch.object( + embedder, "warm_up", side_effect=lambda: setattr(embedder, "embedding_backend", mock_backend) + ) as mock_warm_up: + embedder.run(documents=[Document(content="test document")]) + + mock_warm_up.assert_called_once() + @pytest.mark.integration def test_run_with_model_kwargs(self): """ @@ -317,7 +330,6 @@ def test_run_with_model_kwargs(self): model="Qdrant/bm42-all-minilm-l6-v2-attentions", model_kwargs=bm42_config, ) - embedder.warm_up() doc = Document(content="Example content using BM42") @@ -336,7 +348,6 @@ def test_run(self): embedder = FastembedSparseDocumentEmbedder( model="prithivida/Splade_PP_en_v1", ) - embedder.warm_up() doc = Document(content="Parton energy loss in QCD matter") diff --git a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py index c9e3f77130..0a7936dad6 100644 --- a/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_sparse_text_embedder.py @@ -224,6 +224,20 @@ def test_init_with_model_kwargs_parameters(self): assert embedder.model_kwargs == bm25_config + def test_run_calls_warm_up(self): + embedder = FastembedSparseTextEmbedder() + assert embedder.embedding_backend is None + + mock_backend = MagicMock() + mock_backend.embed.return_value = [[0.1, 0.2, 0.3]] + + with patch.object( + embedder, "warm_up", side_effect=lambda: setattr(embedder, "embedding_backend", mock_backend) + ) as mock_warm_up: + embedder.run(text="test text") + + mock_warm_up.assert_called_once() + @pytest.mark.integration def test_run_with_model_kwargs(self): """ @@ -258,7 +272,6 @@ def test_run(self): embedder = FastembedSparseTextEmbedder( model="prithivida/Splade_PP_en_v1", ) - embedder.warm_up() text = "Parton energy loss in QCD matter" diff --git a/integrations/fastembed/tests/test_fastembed_text_embedder.py b/integrations/fastembed/tests/test_fastembed_text_embedder.py index da969dffa8..008e905cab 100644 --- a/integrations/fastembed/tests/test_fastembed_text_embedder.py +++ b/integrations/fastembed/tests/test_fastembed_text_embedder.py @@ -202,12 +202,25 @@ def test_run_wrong_incorrect_format(self): with pytest.raises(TypeError, match="FastembedTextEmbedder expects a string as input"): embedder.run(text=list_integers_input) + def test_run_calls_warm_up(self): + embedder = FastembedTextEmbedder() + assert embedder.embedding_backend is None + + mock_backend = MagicMock() + mock_backend.embed.return_value = [[0.1, 0.2, 0.3]] + + with patch.object( + embedder, "warm_up", side_effect=lambda: setattr(embedder, "embedding_backend", mock_backend) + ) as mock_warm_up: + embedder.run(text="test text") + + mock_warm_up.assert_called_once() + @pytest.mark.integration def test_run(self): embedder = FastembedTextEmbedder( model="BAAI/bge-small-en-v1.5", ) - embedder.warm_up() text = "Parton energy loss in QCD matter"