diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml index 4c021b1473..515e022421 100644 --- a/.github/workflows/vllm.yml +++ b/.github/workflows/vllm.yml @@ -31,6 +31,9 @@ env: FORCE_COLOR: "1" VLLM_MODEL: "Qwen/Qwen3-0.6B" VLLM_EMBEDDING_MODEL: "sentence-transformers/all-MiniLM-L6-v2" + VLLM_RANKER_MODEL: "BAAI/bge-reranker-base" + VLLM_TARGET_DEVICE: "cpu" + VLLM_CPU_KVCACHE_SPACE: "4" # we only test on Ubuntu to keep vLLM server running simple TEST_MATRIX_OS: '["ubuntu-latest"]' # vLLM is not compatible with Python 3.14. https://github.com/vllm-project/vllm/issues/34096 @@ -90,9 +93,6 @@ jobs: --torch-backend cpu - name: Start vLLM chat server - env: - VLLM_TARGET_DEVICE: "cpu" - VLLM_CPU_KVCACHE_SPACE: "4" run: | nohup hatch run -- vllm serve ${{ env.VLLM_MODEL }} \ --port 8000 \ @@ -120,9 +120,6 @@ jobs: echo "vLLM chat server started successfully." - name: Start vLLM embedding server - env: - VLLM_TARGET_DEVICE: "cpu" - VLLM_CPU_KVCACHE_SPACE: "4" run: | nohup hatch run -- vllm serve ${{ env.VLLM_EMBEDDING_MODEL }} \ --port 8001 \ @@ -144,6 +141,27 @@ jobs: echo "vLLM embedding server started successfully." + - name: Start vLLM ranker server + run: | + nohup hatch run -- vllm serve ${{ env.VLLM_RANKER_MODEL }} \ + --port 8002 \ + --enforce-eager \ + --max-num-seqs 1 & + + # Wait for the vLLM ranker server to be ready with a timeout of 300 seconds + timeout=300 + while [ $timeout -gt 0 ] && ! curl -sSf http://localhost:8002/health > /dev/null 2>&1; do + echo "Waiting for vLLM ranker server to start..." + sleep 10 + ((timeout-=10)) + done + + if [ $timeout -eq 0 ]; then + echo "Timed out waiting for vLLM ranker server to start." + exit 1 + fi + + echo "vLLM ranker server started successfully." 
- name: Lint if: matrix.python-version == '3.10' && runner.os == 'Linux' run: hatch run fmt-check && hatch run test:types diff --git a/README.md b/README.md index 6e1ace16a3..1c213b7a70 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta | [togetherai-haystack](integrations/togetherai/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/togetherai-haystack.svg)](https://pypi.org/project/togetherai-haystack) | [![Test / togetherai](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/togetherai.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/togetherai.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-togetherai/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-togetherai/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-togetherai-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-togetherai-combined/htmlcov/index.html) | | [unstructured-fileconverter-haystack](integrations/unstructured/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml) | [![Coverage 
badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-unstructured/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-unstructured/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-unstructured-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-unstructured-combined/htmlcov/index.html) | | [valkey-haystack](integrations/valkey/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/valkey-haystack.svg)](https://pypi.org/project/valkey-haystack) | [![Test / valkey](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/valkey.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/valkey.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-valkey/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-valkey/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-valkey-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-valkey-combined/htmlcov/index.html) | -| [vllm-haystack](integrations/vllm/) | Embedder, Generator | [![PyPI - 
Version](https://img.shields.io/pypi/v/vllm-haystack.svg)](https://pypi.org/project/vllm-haystack) | [![Test / vllm](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/vllm.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/vllm.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-vllm/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-vllm/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-vllm-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-vllm-combined/htmlcov/index.html) | +| [vllm-haystack](integrations/vllm/) | Embedder, Generator, Ranker | [![PyPI - Version](https://img.shields.io/pypi/v/vllm-haystack.svg)](https://pypi.org/project/vllm-haystack) | [![Test / vllm](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/vllm.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/vllm.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-vllm/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-vllm/htmlcov/index.html) | [![Coverage 
badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-vllm-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-vllm-combined/htmlcov/index.html) | | [watsonx-haystack](integrations/watsonx/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/watsonx-haystack.svg?color=orange)](https://pypi.org/project/watsonx-haystack) | [![Test / watsonx](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/watsonx.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/watsonx.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-watsonx/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-watsonx/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-watsonx-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-watsonx-combined/htmlcov/index.html) | | [weave-haystack](integrations/weave/) | Tracer | [![PyPI - Version](https://img.shields.io/pypi/v/weave-haystack.svg)](https://pypi.org/project/weave-haystack) | [![Test / weave](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weave.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weave.yml) | [![Coverage 
badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-weave/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-weave/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-weave-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-weave-combined/htmlcov/index.html) | | [weaviate-haystack](integrations/weaviate/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/weaviate-haystack.svg)](https://pypi.org/project/weaviate-haystack) | [![Test / weaviate](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-weaviate/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-weaviate/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-weaviate-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-weaviate-combined/htmlcov/index.html) | diff --git a/integrations/vllm/README.md b/integrations/vllm/README.md index a0a498cf8e..543651de99 100644 --- a/integrations/vllm/README.md +++ 
b/integrations/vllm/README.md @@ -26,4 +26,11 @@ vLLM-metal does not support embedding models. On macOS, you can run the embeddin # embedders server (port 8001) docker run --rm -p 8001:8000 -e VLLM_CPU_OMP_THREADS_BIND=0-3 vllm/vllm-openai-cpu:latest \ --model sentence-transformers/all-MiniLM-L6-v2 --enforce-eager +``` + +To run the ranker server, use the CPU Docker image: +```bash +# ranker server (port 8002) +docker run --rm -p 8002:8000 -e VLLM_CPU_OMP_THREADS_BIND=0-3 vllm/vllm-openai-cpu:latest \ + --model BAAI/bge-reranker-base --enforce-eager ``` \ No newline at end of file diff --git a/integrations/vllm/pydoc/config_docusaurus.yml b/integrations/vllm/pydoc/config_docusaurus.yml index daf8d8b75c..932b048ac5 100644 --- a/integrations/vllm/pydoc/config_docusaurus.yml +++ b/integrations/vllm/pydoc/config_docusaurus.yml @@ -3,6 +3,7 @@ loaders: - haystack_integrations.components.generators.vllm.chat.chat_generator - haystack_integrations.components.embedders.vllm.text_embedder - haystack_integrations.components.embedders.vllm.document_embedder + - haystack_integrations.components.rankers.vllm.ranker search_path: [../src] processors: - type: filter diff --git a/integrations/vllm/pyproject.toml b/integrations/vllm/pyproject.toml index 91d1a9aaa7..72dafad7a9 100644 --- a/integrations/vllm/pyproject.toml +++ b/integrations/vllm/pyproject.toml @@ -57,7 +57,6 @@ dependencies = [ "pytest-rerunfailures", "mypy", "pip", - "Pillow", ] [tool.hatch.envs.test.scripts] @@ -66,7 +65,7 @@ integration = 'pytest -m "integration" {args:tests}' all = 'pytest {args:tests}' unit-cov-retry = 'pytest --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x -m "not integration" {args:tests}' integration-cov-append-retry = 'pytest --cov=haystack_integrations --cov-append --reruns 3 --reruns-delay 30 -x -m "integration" {args:tests}' -types = "mypy -p haystack_integrations.components.generators.vllm -p haystack_integrations.components.embedders.vllm -p haystack_integrations.common.vllm 
{args}" +types = "mypy -p haystack_integrations.components.generators.vllm -p haystack_integrations.components.embedders.vllm -p haystack_integrations.components.rankers.vllm -p haystack_integrations.common.vllm {args}" [tool.mypy] install_types = true diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py index 16d5f69206..ff70c64c78 100644 --- a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py +++ b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py @@ -241,7 +241,7 @@ def _validate_documents(documents: list[Document]) -> None: raise TypeError(msg) @component.output_types(documents=list[Document], meta=dict[str, Any]) - def run(self, documents: list[Document]) -> dict[str, Any]: + def run(self, documents: list[Document]) -> dict[str, list[Document] | dict[str, Any]]: """ Embed a list of Documents. @@ -267,7 +267,7 @@ def run(self, documents: list[Document]) -> dict[str, Any]: return {"documents": new_documents, "meta": meta} @component.output_types(documents=list[Document], meta=dict[str, Any]) - async def run_async(self, documents: list[Document]) -> dict[str, Any]: + async def run_async(self, documents: list[Document]) -> dict[str, list[Document] | dict[str, Any]]: """ Asynchronously embed a list of Documents. 
diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py index 2749ea393e..d332fcf513 100644 --- a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py +++ b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py @@ -138,14 +138,14 @@ def _prepare_input(self, text: str) -> dict[str, Any]: return kwargs @staticmethod - def _prepare_output(response: CreateEmbeddingResponse) -> dict[str, Any]: + def _prepare_output(response: CreateEmbeddingResponse) -> dict[str, list[float] | dict[str, Any]]: return { "embedding": response.data[0].embedding, "meta": {"model": response.model, "usage": dict(response.usage)}, } @component.output_types(embedding=list[float], meta=dict[str, Any]) - def run(self, text: str) -> dict[str, Any]: + def run(self, text: str) -> dict[str, list[float] | dict[str, Any]]: """ Embed a single string. @@ -162,7 +162,7 @@ def run(self, text: str) -> dict[str, Any]: return self._prepare_output(response) @component.output_types(embedding=list[float], meta=dict[str, Any]) - async def run_async(self, text: str) -> dict[str, Any]: + async def run_async(self, text: str) -> dict[str, list[float] | dict[str, Any]]: """ Asynchronously embed a single string. 
# ======================================================================================
# integrations/vllm/src/haystack_integrations/components/rankers/py.typed
# (new, empty PEP 561 marker file)
# ======================================================================================

# ======================================================================================
# integrations/vllm/src/haystack_integrations/components/rankers/vllm/__init__.py
# ======================================================================================
# SPDX-FileCopyrightText: 2022-present deepset GmbH
#
# SPDX-License-Identifier: Apache-2.0

from .ranker import VLLMRanker

__all__ = ["VLLMRanker"]

# ======================================================================================
# integrations/vllm/src/haystack_integrations/components/rankers/vllm/ranker.py
# ======================================================================================
# SPDX-FileCopyrightText: 2022-present deepset GmbH
#
# SPDX-License-Identifier: Apache-2.0

from dataclasses import replace
from typing import Any

import httpx
from haystack import Document, component
from haystack.utils import Secret
from haystack.utils.http_client import init_http_client


@component
class VLLMRanker:
    """
    Ranks Documents based on their similarity to a query using models served with [vLLM](https://docs.vllm.ai/).

    It expects a vLLM server to be running and accessible at the `api_base_url` parameter and uses the
    `/rerank` endpoint exposed by vLLM.

    ### Starting the vLLM server

    Before using this component, start a vLLM server with a reranker model:

    ```bash
    vllm serve BAAI/bge-reranker-base
    ```

    For details on server options, see the [vLLM CLI docs](https://docs.vllm.ai/en/stable/cli/serve/).

    ### Usage example

    ```python
    from haystack import Document
    from haystack_integrations.components.rankers.vllm import VLLMRanker

    ranker = VLLMRanker(model="BAAI/bge-reranker-base")
    docs = [
        Document(content="The capital of Brazil is Brasilia."),
        Document(content="The capital of France is Paris."),
    ]
    result = ranker.run(query="What is the capital of France?", documents=docs)
    print(result["documents"][0].content)
    ```

    ### Usage example with vLLM-specific parameters

    Pass vLLM-specific parameters via the `extra_parameters` dictionary. They are merged into the
    request body sent to the `/rerank` endpoint.

    ```python
    ranker = VLLMRanker(
        model="BAAI/bge-reranker-base",
        extra_parameters={"truncate_prompt_tokens": 256},
    )
    ```
    """

    def __init__(
        self,
        *,
        model: str,
        api_key: Secret | None = Secret.from_env_var("VLLM_API_KEY", strict=False),
        api_base_url: str = "http://localhost:8000/v1",
        top_k: int | None = None,
        score_threshold: float | None = None,
        meta_fields_to_embed: list[str] | None = None,
        meta_data_separator: str = "\n",
        http_client_kwargs: dict[str, Any] | None = None,
        extra_parameters: dict[str, Any] | None = None,
    ) -> None:
        """
        Creates an instance of VLLMRanker.

        :param model: The name of the reranker model served by vLLM. Check
            [vLLM documentation](https://docs.vllm.ai/en/stable/models/pooling_models/scoring/#supported-models) for
            information on supported models.
        :param api_key: The vLLM API key. Defaults to the `VLLM_API_KEY` environment variable.
            Only required if the vLLM server was started with `--api-key`.
        :param api_base_url: The base URL of the vLLM server.
        :param top_k: The maximum number of Documents to return. If `None`, all documents are returned.
        :param score_threshold: If set, documents with a relevance score below this value are dropped.
            Applied after `top_k`, so the output may contain fewer than `top_k` documents.
        :param meta_fields_to_embed: List of meta fields that should be concatenated with the document
            content before reranking.
        :param meta_data_separator: Separator used to concatenate the meta fields to the document content.
        :param http_client_kwargs: A dictionary of keyword arguments to configure a custom `httpx.Client` or
            `httpx.AsyncClient`. For more information, see the
            [HTTPX documentation](https://www.python-httpx.org/api/#client).
            When no custom client is built from these arguments, plain `httpx.Client()` /
            `httpx.AsyncClient()` instances are used, which apply HTTPX's default timeout (5 seconds).
            Pass e.g. `{"timeout": 60}` when reranking many documents or running the model on CPU.
        :param extra_parameters: Additional parameters merged into the request body sent to the vLLM
            `/rerank` endpoint. Use this to pass parameters not part of the standard rerank API, such as
            `truncate_prompt_tokens`. See the
            [vLLM docs](https://docs.vllm.ai/en/stable/models/pooling_models/scoring/#rerank-api) for more information.

        :raises ValueError: If `top_k` is not > 0.
        """
        if top_k is not None and top_k <= 0:
            msg = f"top_k must be > 0, but got {top_k}"
            raise ValueError(msg)

        self.model = model
        self.api_key = api_key
        self.api_base_url = api_base_url
        self.top_k = top_k
        self.score_threshold = score_threshold
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.meta_data_separator = meta_data_separator
        self.http_client_kwargs = http_client_kwargs
        self.extra_parameters = extra_parameters

        # The Authorization header is only added when the secret resolves to a non-empty value,
        # so servers started without `--api-key` keep working with the default (unset) env var.
        self._headers = {"Content-Type": "application/json"}
        if self.api_key is not None and (resolved_key := self.api_key.resolve_value()):
            self._headers["Authorization"] = f"Bearer {resolved_key}"

        self._client: httpx.Client | None = None
        self._async_client: httpx.AsyncClient | None = None
        self._is_warmed_up = False

    def warm_up(self) -> None:
        """
        Create the sync and async httpx clients.

        Calling this method more than once is a no-op: the clients created by the first call are kept.
        """
        if self._is_warmed_up:
            return

        client = init_http_client(self.http_client_kwargs, async_client=False)
        async_client = init_http_client(self.http_client_kwargs, async_client=True)
        # Fall back to default clients when `init_http_client` does not build one.
        self._client = client if client is not None else httpx.Client()
        self._async_client = async_client if async_client is not None else httpx.AsyncClient()
        self._is_warmed_up = True

    def _prepare_texts(self, documents: list[Document]) -> list[str]:
        """
        Concatenate each Document's text with the selected meta fields.

        Meta fields that are missing or `None` are skipped; a Document without content contributes
        an empty string.
        """
        texts = []
        for doc in documents:
            meta_values = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
            ]
            texts.append(self.meta_data_separator.join([*meta_values, doc.content or ""]))
        return texts

    def _prepare_request(
        self,
        query: str,
        documents: list[Document],
        top_k: int | None,
    ) -> dict[str, Any]:
        """
        Build the JSON body for the `/rerank` request.

        `top_k` is sent as `top_n` when set. `extra_parameters` are applied last, so they take
        precedence over the keys set here.
        """
        body: dict[str, Any] = {
            "model": self.model,
            "query": query,
            "documents": self._prepare_texts(documents),
        }
        if top_k is not None:
            body["top_n"] = top_k
        if self.extra_parameters:
            body.update(self.extra_parameters)
        return body

    def _rerank_url(self) -> str:
        """Return the full URL of the rerank endpoint, tolerating a trailing slash in `api_base_url`."""
        return f"{self.api_base_url.rstrip('/')}/rerank"

    @staticmethod
    def _decode_response(response: httpx.Response) -> dict[str, Any]:
        """
        Decode the HTTP response body as a JSON object.

        :raises RuntimeError: If the body is not valid JSON (e.g. an HTML error page returned by a proxy)
            or is not a JSON object. The message includes the HTTP status and a snippet of the body.
        """
        try:
            payload = response.json()
        except ValueError as err:
            # JSON decoding errors are ValueError subclasses; surface the status and a body snippet
            # instead of an opaque decode error.
            msg = (
                "vLLM rerank endpoint returned a non-JSON response "
                f"(HTTP {response.status_code}): {response.text[:200]!r}"
            )
            raise RuntimeError(msg) from err

        if not isinstance(payload, dict):
            msg = f"Unexpected response from vLLM rerank endpoint: {payload}"
            raise RuntimeError(msg)
        return payload

    @staticmethod
    def _error_message(resp: dict[str, Any]) -> str:
        """
        Extract a human-readable error message from a response that carries no `results`.

        Looks at the `detail`, `message` and `error` (string or `{"message": ...}`) fields, in that
        order, and falls back to dumping the whole payload.
        """
        error = resp.get("error")
        candidates = [
            resp.get("detail"),
            resp.get("message"),
            error.get("message") if isinstance(error, dict) else error,
        ]
        for candidate in candidates:
            if candidate:
                return str(candidate)
        return f"Unexpected response from vLLM rerank endpoint: {resp}"

    @staticmethod
    def _parse_response(
        resp: dict[str, Any],
        documents: list[Document],
        score_threshold: float | None,
    ) -> dict[str, list[Document] | dict[str, Any]]:
        """
        Convert a decoded `/rerank` response into the component output.

        Results are kept in the order returned by the server. Each result's `index` selects the input
        Document, which is copied with its `score` set to the result's `relevance_score`.

        :raises RuntimeError: If the response does not contain `results`.
        """
        if "results" not in resp:
            raise RuntimeError(VLLMRanker._error_message(resp))

        ranked_docs: list[Document] = []
        for result in resp["results"]:
            score = result["relevance_score"]
            if score_threshold is not None and score < score_threshold:
                continue
            # `replace` returns a copy, so the caller's Documents are never mutated.
            ranked_docs.append(replace(documents[result["index"]], score=score))

        meta = {"model": resp.get("model"), "usage": resp.get("usage", {})}
        return {"documents": ranked_docs, "meta": meta}

    def _resolve_run_params(self, top_k: int | None, score_threshold: float | None) -> tuple[int | None, float | None]:
        """
        Validate the run-time overrides and fall back to the values set at initialization.

        :raises ValueError: If `top_k` is not > 0.
        """
        if top_k is not None and top_k <= 0:
            msg = f"top_k must be > 0, but got {top_k}"
            raise ValueError(msg)
        resolved_top_k = top_k if top_k is not None else self.top_k
        resolved_score_threshold = score_threshold if score_threshold is not None else self.score_threshold
        return resolved_top_k, resolved_score_threshold

    @component.output_types(documents=list[Document], meta=dict[str, Any])
    def run(
        self,
        query: str,
        documents: list[Document],
        top_k: int | None = None,
        score_threshold: float | None = None,
    ) -> dict[str, list[Document] | dict[str, Any]]:
        """
        Returns a list of Documents ranked by their similarity to the given query.

        :param query: Query string.
        :param documents: List of Documents to rank.
        :param top_k: The maximum number of Documents to return. Overrides the value set at initialization.
        :param score_threshold: Minimum relevance score required for a document to be returned. Overrides
            the value set at initialization.
        :returns: A dictionary with:
            - `documents`: Documents sorted from most to least relevant.
            - `meta`: Information about the model and usage.

        :raises ValueError: If `top_k` is not > 0.
        :raises RuntimeError: If the server response is not a JSON object or does not contain results.
        """
        # No request is sent for an empty input.
        if not documents:
            return {"documents": [], "meta": {}}

        top_k, score_threshold = self._resolve_run_params(top_k, score_threshold)

        if not self._is_warmed_up:
            self.warm_up()
        assert self._client is not None  # noqa: S101

        body = self._prepare_request(query, documents, top_k)
        response = self._client.post(self._rerank_url(), json=body, headers=self._headers)
        return self._parse_response(self._decode_response(response), documents, score_threshold)

    @component.output_types(documents=list[Document], meta=dict[str, Any])
    async def run_async(
        self,
        query: str,
        documents: list[Document],
        top_k: int | None = None,
        score_threshold: float | None = None,
    ) -> dict[str, list[Document] | dict[str, Any]]:
        """
        Asynchronously returns a list of Documents ranked by their similarity to the given query.

        :param query: Query string.
        :param documents: List of Documents to rank.
        :param top_k: The maximum number of Documents to return. Overrides the value set at initialization.
        :param score_threshold: Minimum relevance score required for a document to be returned. Overrides
            the value set at initialization.
        :returns: A dictionary with:
            - `documents`: Documents sorted from most to least relevant.
            - `meta`: Information about the model and usage.

        :raises ValueError: If `top_k` is not > 0.
        :raises RuntimeError: If the server response is not a JSON object or does not contain results.
        """
        # No request is sent for an empty input.
        if not documents:
            return {"documents": [], "meta": {}}

        top_k, score_threshold = self._resolve_run_params(top_k, score_threshold)

        if not self._is_warmed_up:
            self.warm_up()
        assert self._async_client is not None  # noqa: S101

        body = self._prepare_request(query, documents, top_k)
        response = await self._async_client.post(self._rerank_url(), json=body, headers=self._headers)
        return self._parse_response(self._decode_response(response), documents, score_threshold)


# ======================================================================================
# integrations/vllm/tests/test_ranker.py
# ======================================================================================
# SPDX-FileCopyrightText: 2022-present deepset GmbH
#
# SPDX-License-Identifier: Apache-2.0
from unittest.mock import AsyncMock, MagicMock

import pytest
from haystack import Document
from haystack.core.serialization import component_from_dict, component_to_dict
from haystack.utils import Secret

from haystack_integrations.components.rankers.vllm import VLLMRanker

MODEL = "BAAI/bge-reranker-base"
API_BASE_URL = "http://localhost:8002/v1"


def _fake_response(results: list[dict], model: str = "fake-model", total_tokens: int = 10):
    """Build a mock HTTP response whose `.json()` returns a rerank payload with the given results."""
    response = MagicMock()
    response.json.return_value = {
        "id": "rerank-fake",
        "model": model,
        "usage": {"total_tokens": total_tokens},
        "results": results,
    }
    return response


class TestVLLMRanker:
    def test_init_default(self, monkeypatch):
        monkeypatch.delenv("VLLM_API_KEY", raising=False)

        ranker = VLLMRanker(model=MODEL)
        assert ranker.model == MODEL
        assert ranker.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False)
        assert ranker.api_base_url == "http://localhost:8000/v1"
        assert ranker.top_k is None
        assert ranker.score_threshold is None
        assert ranker.meta_fields_to_embed == []
        assert ranker.meta_data_separator == "\n"
        assert ranker.http_client_kwargs is None
        assert ranker.extra_parameters is None
        assert ranker._client is None
        assert ranker._async_client is None
        assert ranker._is_warmed_up is False
        assert "Authorization" not in ranker._headers

    def test_init_with_parameters(self):
        ranker = VLLMRanker(
            model=MODEL,
            api_key=Secret.from_token("test-api-key"),
            api_base_url="http://my-vllm-server:8000/v1",
            top_k=5,
            score_threshold=0.5,
            meta_fields_to_embed=["topic"],
            meta_data_separator=" | ",
            http_client_kwargs={"verify": False},
            extra_parameters={"truncate_prompt_tokens": 256},
        )
        assert ranker.api_key == Secret.from_token("test-api-key")
        assert ranker.api_base_url == "http://my-vllm-server:8000/v1"
        assert ranker.top_k == 5
        assert ranker.score_threshold == 0.5
        assert ranker.meta_fields_to_embed == ["topic"]
        assert ranker.meta_data_separator == " | "
        assert ranker.http_client_kwargs == {"verify": False}
        assert ranker.extra_parameters == {"truncate_prompt_tokens": 256}
        assert ranker._headers["Authorization"] == "Bearer test-api-key"

    def test_init_invalid_top_k(self):
        with pytest.raises(ValueError, match="top_k must be > 0"):
            VLLMRanker(model=MODEL, top_k=0)

    def test_warm_up(self, monkeypatch):
        monkeypatch.delenv("VLLM_API_KEY", raising=False)

        ranker = VLLMRanker(model=MODEL)
        ranker.warm_up()

        assert ranker._is_warmed_up is True
        assert ranker._client is not None
        assert ranker._async_client is not None

        # A second call must not replace the existing client.
        client_before = ranker._client
        ranker.warm_up()
        assert ranker._client is client_before

    def test_to_dict(self, monkeypatch):
        monkeypatch.delenv("VLLM_API_KEY", raising=False)

        component_dict = component_to_dict(VLLMRanker(model=MODEL), "ranker")
        assert component_dict == {
            "type": "haystack_integrations.components.rankers.vllm.ranker.VLLMRanker",
            "init_parameters": {
                "model": MODEL,
                "api_key": {"env_vars": ["VLLM_API_KEY"], "strict": False, "type": "env_var"},
                "api_base_url": "http://localhost:8000/v1",
                "top_k": None,
                "score_threshold": None,
                "meta_fields_to_embed": [],
                "meta_data_separator": "\n",
                "http_client_kwargs": None,
                "extra_parameters": None,
            },
        }

    def test_from_dict(self, monkeypatch):
        monkeypatch.delenv("VLLM_API_KEY", raising=False)
        data = {
            "type": "haystack_integrations.components.rankers.vllm.ranker.VLLMRanker",
            "init_parameters": {
                "model": MODEL,
                "api_key": {"env_vars": ["VLLM_API_KEY"], "strict": False, "type": "env_var"},
                "api_base_url": "http://localhost:8000/v1",
                "top_k": 3,
                "score_threshold": 0.5,
                "meta_fields_to_embed": ["topic"],
                "meta_data_separator": " | ",
                "http_client_kwargs": None,
                "extra_parameters": {"truncate_prompt_tokens": 256},
            },
        }
        ranker = component_from_dict(VLLMRanker, data, "ranker")
        assert ranker.model == MODEL
        assert ranker.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False)
        assert ranker.top_k == 3
        assert ranker.score_threshold == 0.5
        assert ranker.meta_fields_to_embed == ["topic"]
        assert ranker.meta_data_separator == " | "
        assert ranker.extra_parameters == {"truncate_prompt_tokens": 256}

    def test_prepare_texts_with_meta(self):
        ranker = VLLMRanker(model=MODEL, meta_fields_to_embed=["topic"], meta_data_separator=" | ")
        docs = [Document(content="hello", meta={"topic": "ML"}), Document(content="world", meta={})]
        assert ranker._prepare_texts(docs) == ["ML | hello", "world"]

    def test_prepare_request_with_top_k_and_extras(self):
        ranker = VLLMRanker(model=MODEL, extra_parameters={"truncate_prompt_tokens": 256})
        docs = [Document(content="a"), Document(content="b")]
        body = ranker._prepare_request(query="q", documents=docs, top_k=2)
        assert body == {
            "model": MODEL,
            "query": "q",
            "documents": ["a", "b"],
            "top_n": 2,
            "truncate_prompt_tokens": 256,
        }

    def test_prepare_request_without_top_k(self):
        ranker = VLLMRanker(model=MODEL)
        body = ranker._prepare_request(query="q", documents=[Document(content="a")], top_k=None)
        assert "top_n" not in body

    def test_parse_response_applies_score_threshold(self):
        docs = [Document(content="a"), Document(content="b")]
        resp = {
            "model": "fake-model",
            "usage": {"total_tokens": 5},
            "results": [
                {"index": 1, "relevance_score": 0.9},
                {"index": 0, "relevance_score": 0.1},
            ],
        }
        out = VLLMRanker._parse_response(resp, docs, score_threshold=0.5)
        assert [d.content for d in out["documents"]] == ["b"]
        assert out["documents"][0].score == 0.9
        assert out["meta"] == {"model": "fake-model", "usage": {"total_tokens": 5}}

    @pytest.mark.parametrize(
        "payload",
        [
            {"detail": "boom"},
            {"object": "error", "message": "boom"},
            {"error": {"message": "boom", "code": 400}},
        ],
    )
    def test_parse_response_raises_on_error(self, payload):
        with pytest.raises(RuntimeError, match="boom"):
            VLLMRanker._parse_response(payload, [], score_threshold=None)

    def test_parse_response_raises_on_unknown_payload(self):
        with pytest.raises(RuntimeError, match="Unexpected response from vLLM rerank endpoint"):
            VLLMRanker._parse_response({"foo": "bar"}, [], score_threshold=None)

    def test_run_invalid_top_k(self):
        ranker = VLLMRanker(model=MODEL)
        with pytest.raises(ValueError, match="top_k must be > 0"):
            ranker.run(query="q", documents=[Document(content="a")], top_k=0)

    def test_run_empty_documents(self):
        ranker = VLLMRanker(model=MODEL)
        assert ranker.run(query="q", documents=[]) == {"documents": [], "meta": {}}

    def test_run(self):
        ranker = VLLMRanker(model=MODEL, top_k=2)
        ranker._client = MagicMock()
        ranker._client.post.return_value = _fake_response(
            results=[
                {"index": 1, "relevance_score": 0.99},
                {"index": 0, "relevance_score": 0.01},
            ]
        )
        ranker._is_warmed_up = True

        docs = [
            Document(content="The capital of Brazil is Brasilia."),
            Document(content="The capital of France is Paris."),
        ]
        out = ranker.run(query="What is the capital of France?", documents=docs)

        ranker._client.post.assert_called_once()
        call_kwargs = ranker._client.post.call_args
        assert call_kwargs.args[0] == "http://localhost:8000/v1/rerank"
        assert call_kwargs.kwargs["json"] == {
            "model": MODEL,
            "query": "What is the capital of France?",
            "documents": [d.content for d in docs],
            "top_n": 2,
        }
        assert [d.content for d in out["documents"]] == [
            "The capital of France is Paris.",
            "The capital of Brazil is Brasilia.",
        ]
        assert out["documents"][0].score == 0.99
        assert out["meta"]["model"] == "fake-model"

    def test_run_raises_on_non_json_response(self):
        ranker = VLLMRanker(model=MODEL)
        bad_response = MagicMock()
        bad_response.json.side_effect = ValueError("Expecting value")
        bad_response.status_code = 502
        bad_response.text = "<html>Bad Gateway</html>"
        ranker._client = MagicMock()
        ranker._client.post.return_value = bad_response
        ranker._is_warmed_up = True

        with pytest.raises(RuntimeError, match="non-JSON response"):
            ranker.run(query="q", documents=[Document(content="a")])

    @pytest.mark.asyncio
    async def test_run_async(self):
        ranker = VLLMRanker(model=MODEL)
        ranker._async_client = MagicMock()
        ranker._async_client.post = AsyncMock(
            return_value=_fake_response(results=[{"index": 0, "relevance_score": 0.42}])
        )
        ranker._is_warmed_up = True

        out = await ranker.run_async(query="q", documents=[Document(content="a")])

        assert out["documents"][0].score == 0.42

    @pytest.mark.integration
    def test_live_run(self):
        ranker = VLLMRanker(model=MODEL, api_base_url=API_BASE_URL)
        docs = [
            Document(content="The capital of Brazil is Brasilia."),
            Document(content="The capital of France is Paris."),
        ]
        out = ranker.run(query="What is the capital of France?", documents=docs)
        assert out["documents"][0].content == "The capital of France is Paris."
        assert out["documents"][0].score is not None
        assert out["meta"]["model"] == MODEL

    @pytest.mark.integration
    @pytest.mark.asyncio
    async def test_live_run_async(self):
        ranker = VLLMRanker(model=MODEL, api_base_url=API_BASE_URL)
        docs = [
            Document(content="The capital of Brazil is Brasilia."),
            Document(content="The capital of France is Paris."),
        ]
        out = await ranker.run_async(query="What is the capital of France?", documents=docs)
        assert out["documents"][0].content == "The capital of France is Paris."
        assert out["documents"][0].score is not None
        assert out["meta"]["model"] == MODEL