diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml index 3dad3d8363..4c021b1473 100644 --- a/.github/workflows/vllm.yml +++ b/.github/workflows/vllm.yml @@ -30,6 +30,7 @@ env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" VLLM_MODEL: "Qwen/Qwen3-0.6B" + VLLM_EMBEDDING_MODEL: "sentence-transformers/all-MiniLM-L6-v2" # we only test on Ubuntu to keep vLLM server running simple TEST_MATRIX_OS: '["ubuntu-latest"]' # vLLM is not compatible with Python 3.14. https://github.com/vllm-project/vllm/issues/34096 @@ -88,12 +89,13 @@ jobs: "https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cpu-cp38-abi3-manylinux_2_35_x86_64.whl" \ --torch-backend cpu - - name: Start vLLM server + - name: Start vLLM chat server env: VLLM_TARGET_DEVICE: "cpu" VLLM_CPU_KVCACHE_SPACE: "4" run: | nohup hatch run -- vllm serve ${{ env.VLLM_MODEL }} \ + --port 8000 \ --reasoning-parser qwen3 \ --max-model-len 1024 \ --enforce-eager \ @@ -102,20 +104,45 @@ jobs: --tool-call-parser hermes \ --max-num-seqs 1 & - # Wait for the vLLM server to be ready with a timeout of 300 seconds + # Wait for the vLLM chat server to be ready with a timeout of 300 seconds timeout=300 while [ $timeout -gt 0 ] && ! curl -sSf http://localhost:8000/health > /dev/null 2>&1; do - echo "Waiting for vLLM server to start..." + echo "Waiting for vLLM chat server to start..." sleep 10 ((timeout-=10)) done if [ $timeout -eq 0 ]; then - echo "Timed out waiting for vLLM server to start." + echo "Timed out waiting for vLLM chat server to start." exit 1 fi - echo "vLLM server started successfully." + echo "vLLM chat server started successfully." + + - name: Start vLLM embedding server + env: + VLLM_TARGET_DEVICE: "cpu" + VLLM_CPU_KVCACHE_SPACE: "4" + run: | + nohup hatch run -- vllm serve ${{ env.VLLM_EMBEDDING_MODEL }} \ + --port 8001 \ + --enforce-eager \ + --max-num-seqs 1 & + + # Wait for the vLLM embedding server to be ready with a timeout of 300 seconds + timeout=300 + while [ $timeout -gt 0 ] && ! curl -sSf http://localhost:8001/health > /dev/null 2>&1; do + echo "Waiting for vLLM embedding server to start..." + sleep 10 + ((timeout-=10)) + done + + if [ $timeout -eq 0 ]; then + echo "Timed out waiting for vLLM embedding server to start." + exit 1 + fi + + echo "vLLM embedding server started successfully." - name: Lint if: matrix.python-version == '3.10' && runner.os == 'Linux' diff --git a/README.md b/README.md index 938a02ed0a..6e1ace16a3 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,7 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta | [togetherai-haystack](integrations/togetherai/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/togetherai-haystack.svg)](https://pypi.org/project/togetherai-haystack) | [![Test / togetherai](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/togetherai.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/togetherai.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-togetherai/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-togetherai/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-togetherai-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-togetherai-combined/htmlcov/index.html) | | [unstructured-fileconverter-haystack](integrations/unstructured/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-unstructured/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-unstructured/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-unstructured-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-unstructured-combined/htmlcov/index.html) | | [valkey-haystack](integrations/valkey/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/valkey-haystack.svg)](https://pypi.org/project/valkey-haystack) | [![Test / valkey](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/valkey.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/valkey.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-valkey/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-valkey/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-valkey-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-valkey-combined/htmlcov/index.html) | -| [vllm-haystack](integrations/vllm/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/vllm-haystack.svg)](https://pypi.org/project/vllm-haystack) | [![Test / vllm](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/vllm.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/vllm.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-vllm/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-vllm/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-vllm-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-vllm-combined/htmlcov/index.html) | +| [vllm-haystack](integrations/vllm/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/vllm-haystack.svg)](https://pypi.org/project/vllm-haystack) | [![Test / vllm](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/vllm.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/vllm.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-vllm/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-vllm/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-vllm-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-vllm-combined/htmlcov/index.html) | | [watsonx-haystack](integrations/watsonx/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/watsonx-haystack.svg?color=orange)](https://pypi.org/project/watsonx-haystack) | [![Test / watsonx](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/watsonx.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/watsonx.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-watsonx/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-watsonx/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-watsonx-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-watsonx-combined/htmlcov/index.html) | | [weave-haystack](integrations/weave/) | Tracer | [![PyPI - Version](https://img.shields.io/pypi/v/weave-haystack.svg)](https://pypi.org/project/weave-haystack) | [![Test / weave](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weave.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weave.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-weave/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-weave/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-weave-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-weave-combined/htmlcov/index.html) | | [weaviate-haystack](integrations/weaviate/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/weaviate-haystack.svg)](https://pypi.org/project/weaviate-haystack) | [![Test / weaviate](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-weaviate/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-weaviate/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-weaviate-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-weaviate-combined/htmlcov/index.html) | diff --git a/integrations/vllm/README.md b/integrations/vllm/README.md index 3d987bbefb..a0a498cf8e 100644 --- a/integrations/vllm/README.md +++ b/integrations/vllm/README.md @@ -11,10 +11,19 @@ Refer to the general [Contribution Guidelines](https://github.com/deepset-ai/haystack-core-integrations/blob/main/CONTRIBUTING.md). -To run integration tests locally, you need to have a running vLLM server. Refer to the [workflow file](https://github.com/deepset-ai/haystack-core-integrations/blob/main/.github/workflows/vllm.yml) for more details. +To run integration tests locally, you need two vLLM servers running in parallel: one for the chat generator on port `8000` and one for the embedders on port `8001`. Refer to the [workflow file](https://github.com/deepset-ai/haystack-core-integrations/blob/main/.github/workflows/vllm.yml) for more details. -For example, on macOs, you can install [vLLM-metal](https://github.com/vllm-project/vllm-metal) and run the server with: +For example, on macOs, you can install [vLLM-metal](https://github.com/vllm-project/vllm-metal) and start the chat generator server with: ```bash -source ~/.venv-vllm-metal/bin/activate && vllm serve Qwen/Qwen3-0.6B --reasoning-parser qwen3 --max-model-len 1024 --enforce-eager --enable-auto-tool-choice --tool-call-parser hermes +# chat generator server (port 8000) +source ~/.venv-vllm-metal/bin/activate && vllm serve Qwen/Qwen3-0.6B --reasoning-parser qwen3 --max-model-len 1024 --enforce-eager --enable-auto-tool-choice --tool-call-parser hermes +``` + +vLLM-metal does not support embedding models. On macOS, you can run the embedding server via CPU Docker image: + +```bash +# embedders server (port 8001) +docker run --rm -p 8001:8000 -e VLLM_CPU_OMP_THREADS_BIND=0-3 vllm/vllm-openai-cpu:latest \ + --model sentence-transformers/all-MiniLM-L6-v2 --enforce-eager ``` \ No newline at end of file diff --git a/integrations/vllm/pydoc/config_docusaurus.yml b/integrations/vllm/pydoc/config_docusaurus.yml index 13c26dd968..daf8d8b75c 100644 --- a/integrations/vllm/pydoc/config_docusaurus.yml +++ b/integrations/vllm/pydoc/config_docusaurus.yml @@ -1,6 +1,8 @@ loaders: - modules: - haystack_integrations.components.generators.vllm.chat.chat_generator + - haystack_integrations.components.embedders.vllm.text_embedder + - haystack_integrations.components.embedders.vllm.document_embedder search_path: [../src] processors: - type: filter diff --git a/integrations/vllm/pyproject.toml b/integrations/vllm/pyproject.toml index 53cec11e33..91d1a9aaa7 100644 --- a/integrations/vllm/pyproject.toml +++ b/integrations/vllm/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai>=2.23.0", "openai"] +dependencies = ["haystack-ai>=2.23.0", "openai", "more_itertools>=9.0.0", "tqdm>=4.48.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/vllm#readme" @@ -66,7 +66,7 @@ integration = 'pytest -m "integration" {args:tests}' all = 'pytest {args:tests}' unit-cov-retry = 'pytest --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x -m "not integration" {args:tests}' integration-cov-append-retry = 'pytest --cov=haystack_integrations --cov-append --reruns 3 --reruns-delay 30 -x -m "integration" {args:tests}' -types = "mypy -p haystack_integrations.components.generators.vllm {args}" +types = "mypy -p haystack_integrations.components.generators.vllm -p haystack_integrations.components.embedders.vllm -p haystack_integrations.common.vllm {args}" [tool.mypy] install_types = true diff --git a/integrations/vllm/src/haystack_integrations/common/py.typed b/integrations/vllm/src/haystack_integrations/common/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/integrations/vllm/src/haystack_integrations/common/vllm/__init__.py b/integrations/vllm/src/haystack_integrations/common/vllm/__init__.py new file mode 100644 index 0000000000..c1764a6e03 --- /dev/null +++ b/integrations/vllm/src/haystack_integrations/common/vllm/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/vllm/src/haystack_integrations/common/vllm/utils.py b/integrations/vllm/src/haystack_integrations/common/vllm/utils.py new file mode 100644 index 0000000000..becf3cf870 --- /dev/null +++ b/integrations/vllm/src/haystack_integrations/common/vllm/utils.py @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +from haystack.utils import Secret +from haystack.utils.http_client import init_http_client +from openai import AsyncOpenAI, OpenAI + + +def _create_openai_clients( + api_key: Secret | None, + api_base_url: str, + timeout: float | None, + max_retries: int | None, + http_client_kwargs: dict[str, Any] | None, +) -> tuple[OpenAI, AsyncOpenAI]: + """ + Build sync and async OpenAI clients pointing at a vLLM server. + + A placeholder api key is used when the user did not supply one and no `VLLM_API_KEY` env var is set, because the + OpenAI client requires a non-empty value. + `timeout` and `max_retries` are only forwarded when provided: when None, the OpenAI client's own defaults apply. + """ + resolved_api_key = "placeholder-api-key" + if api_key is not None and (value := api_key.resolve_value()): + resolved_api_key = value + + client_kwargs: dict[str, Any] = {"api_key": resolved_api_key, "base_url": api_base_url} + if timeout is not None: + client_kwargs["timeout"] = timeout + if max_retries is not None: + client_kwargs["max_retries"] = max_retries + + sync_client = OpenAI(http_client=init_http_client(http_client_kwargs, async_client=False), **client_kwargs) + async_client = AsyncOpenAI(http_client=init_http_client(http_client_kwargs, async_client=True), **client_kwargs) + return sync_client, async_client diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/py.typed b/integrations/vllm/src/haystack_integrations/components/embedders/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/__init__.py b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/__init__.py new file mode 100644 index 0000000000..1ffc2e2931 --- /dev/null +++ b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .document_embedder import VLLMDocumentEmbedder +from .text_embedder import VLLMTextEmbedder + +__all__ = ["VLLMDocumentEmbedder", "VLLMTextEmbedder"] diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py new file mode 100644 index 0000000000..16d5f69206 --- /dev/null +++ b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py @@ -0,0 +1,293 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import replace +from typing import Any + +from haystack import Document, component, logging +from haystack.utils import Secret +from more_itertools import batched +from openai import APIError, AsyncOpenAI, OpenAI +from tqdm import tqdm +from tqdm.asyncio import tqdm as async_tqdm + +from haystack_integrations.common.vllm.utils import _create_openai_clients + +logger = logging.getLogger(__name__) + + +@component +class VLLMDocumentEmbedder: + """ + A component for computing Document embeddings using models served with [vLLM](https://docs.vllm.ai/). + + The embedding of each Document is stored in the `embedding` field of the Document. + It expects a vLLM server to be running and accessible at the `api_base_url` parameter and uses the + OpenAI-compatible Embeddings API exposed by vLLM. + + ### Starting the vLLM server + + Before using this component, start a vLLM server with an embedding model: + + ```bash + vllm serve google/embeddinggemma-300m + ``` + + For details on server options, see the [vLLM CLI docs](https://docs.vllm.ai/en/stable/cli/serve/). + + ### Usage example + + ```python + from haystack import Document + from haystack_integrations.components.embedders.vllm import VLLMDocumentEmbedder + + doc = Document(content="I love pizza!") + + document_embedder = VLLMDocumentEmbedder(model="google/embeddinggemma-300m") + + result = document_embedder.run([doc]) + print(result["documents"][0].embedding) + ``` + + ### Usage example with vLLM-specific parameters + + Pass vLLM-specific parameters via the `extra_parameters` dictionary. They are forwarded as `extra_body` + to the OpenAI-compatible endpoint. + + ```python + document_embedder = VLLMDocumentEmbedder( + model="google/embeddinggemma-300m", + extra_parameters={"truncate_prompt_tokens": 256, "truncation_side": "right"}, + ) + ``` + """ + + def __init__( + self, + *, + model: str, + api_key: Secret | None = Secret.from_env_var("VLLM_API_KEY", strict=False), + api_base_url: str = "http://localhost:8000/v1", + prefix: str = "", + suffix: str = "", + dimensions: int | None = None, + batch_size: int = 32, + progress_bar: bool = True, + meta_fields_to_embed: list[str] | None = None, + embedding_separator: str = "\n", + timeout: float | None = None, + max_retries: int | None = None, + http_client_kwargs: dict[str, Any] | None = None, + raise_on_failure: bool = False, + extra_parameters: dict[str, Any] | None = None, + ) -> None: + """ + Creates an instance of VLLMDocumentEmbedder. + + :param model: The name of the model served by vLLM. Check + [vLLM documentation](https://docs.vllm.ai/en/stable/models/pooling_models) for more information. + :param api_key: The vLLM API key. Defaults to the `VLLM_API_KEY` environment variable. + Only required if the vLLM server was started with `--api-key`. + :param api_base_url: The base URL of the vLLM server. + :param prefix: A string to add at the beginning of each text. + :param suffix: A string to add at the end of each text. + :param dimensions: The number of dimensions of the resulting embedding. Only models trained with + Matryoshka Representation Learning support this parameter. See + [vLLM documentation](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#matryoshka-embeddings) + for more information. + :param batch_size: Number of documents to encode at once. + :param progress_bar: Whether to show a progress bar. + :param meta_fields_to_embed: List of meta fields to embed along with the document text. + :param embedding_separator: Separator used to concatenate the meta fields to the document text. + :param timeout: Timeout in seconds for vLLM client calls. If not set, the OpenAI client default applies. + :param max_retries: Maximum number of retries for failed requests. If not set, the OpenAI client + default applies. + :param http_client_kwargs: A dictionary of keyword arguments to configure a custom `httpx.Client` or + `httpx.AsyncClient`. For more information, see the + [HTTPX documentation](https://www.python-httpx.org/api/#client). + :param raise_on_failure: Whether to raise an exception if the embedding request fails. If `False`, + the component logs the error and continues processing the remaining documents. + :param extra_parameters: Additional parameters forwarded as `extra_body` to the vLLM embeddings + endpoint. Use this to pass parameters not part of the standard OpenAI Embeddings API, such as + `truncate_prompt_tokens`, `truncation_side`, etc. See the + [vLLM Embeddings API docs](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#openai-compatible-embeddings-api). + """ + self.model = model + self.api_key = api_key + self.api_base_url = api_base_url + self.prefix = prefix + self.suffix = suffix + self.dimensions = dimensions + self.batch_size = batch_size + self.progress_bar = progress_bar + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.embedding_separator = embedding_separator + self.timeout = timeout + self.max_retries = max_retries + self.http_client_kwargs = http_client_kwargs + self.raise_on_failure = raise_on_failure + self.extra_parameters = extra_parameters + + self._client: OpenAI | None = None + self._async_client: AsyncOpenAI | None = None + self._is_warmed_up = False + + def warm_up(self) -> None: + """Create the OpenAI clients.""" + if self._is_warmed_up: + return + self._client, self._async_client = _create_openai_clients( + api_key=self.api_key, + api_base_url=self.api_base_url, + timeout=self.timeout, + max_retries=self.max_retries, + http_client_kwargs=self.http_client_kwargs, + ) + self._is_warmed_up = True + + def _prepare_texts_to_embed(self, documents: list[Document]) -> dict[str, str]: + """Concatenate each Document's text with the selected meta fields.""" + texts_to_embed = {} + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None + ] + texts_to_embed[doc.id] = ( + self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix + ) + return texts_to_embed + + def _prepare_input(self, inputs: list[str]) -> dict[str, Any]: + kwargs: dict[str, Any] = {"model": self.model, "input": inputs, "encoding_format": "float"} + if self.dimensions is not None: + kwargs["dimensions"] = self.dimensions + if self.extra_parameters: + kwargs["extra_body"] = self.extra_parameters + return kwargs + + @staticmethod + def _update_meta(meta: dict[str, Any], response: Any) -> None: + if "model" not in meta: + meta["model"] = response.model + if "usage" not in meta: + meta["usage"] = dict(response.usage) + else: + meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens + meta["usage"]["total_tokens"] += response.usage.total_tokens + + def _embed_batch( + self, texts_to_embed: dict[str, str], batch_size: int + ) -> tuple[dict[str, list[float]], dict[str, Any]]: + assert self._client is not None # noqa: S101 + doc_ids_to_embeddings: dict[str, list[float]] = {} + meta: dict[str, Any] = {} + + for batch in tqdm( + batched(texts_to_embed.items(), batch_size), + disable=not self.progress_bar, + desc="Calculating embeddings", + ): + kwargs = self._prepare_input([b[1] for b in batch]) + try: + response = self._client.embeddings.create(**kwargs) + except APIError as exc: + ids = ", ".join(b[0] for b in batch) + logger.exception("Failed embedding of documents {ids} caused by {exc}", ids=ids, exc=exc) + if self.raise_on_failure: + raise + continue + + embeddings = [el.embedding for el in response.data] + doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings, strict=True))) + self._update_meta(meta, response) + + return doc_ids_to_embeddings, meta + + async def _embed_batch_async( + self, texts_to_embed: dict[str, str], batch_size: int + ) -> tuple[dict[str, list[float]], dict[str, Any]]: + assert self._async_client is not None # noqa: S101 + doc_ids_to_embeddings: dict[str, list[float]] = {} + meta: dict[str, Any] = {} + + batches = list(batched(texts_to_embed.items(), batch_size)) + iterator = async_tqdm(batches, desc="Calculating embeddings") if self.progress_bar else batches + + for batch in iterator: + kwargs = self._prepare_input([b[1] for b in batch]) + try: + response = await self._async_client.embeddings.create(**kwargs) + except APIError as exc: + ids = ", ".join(b[0] for b in batch) + logger.exception("Failed embedding of documents {ids} caused by {exc}", ids=ids, exc=exc) + if self.raise_on_failure: + raise + continue + + embeddings = [el.embedding for el in response.data] + doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings, strict=True))) + self._update_meta(meta, response) + + return doc_ids_to_embeddings, meta + + @staticmethod + def _validate_documents(documents: list[Document]) -> None: + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): + msg = ( + "VLLMDocumentEmbedder expects a list of Documents as input. " + "In case you want to embed a string, please use the VLLMTextEmbedder." + ) + raise TypeError(msg) + + @component.output_types(documents=list[Document], meta=dict[str, Any]) + def run(self, documents: list[Document]) -> dict[str, Any]: + """ + Embed a list of Documents. + + :param documents: Documents to embed. + :returns: A dictionary with: + - `documents`: The input documents with their `embedding` field populated. + - `meta`: Information about the usage of the model. + """ + self._validate_documents(documents) + if not documents: + return {"documents": [], "meta": {}} + + if not self._is_warmed_up: + self.warm_up() + + texts_to_embed = self._prepare_texts_to_embed(documents) + doc_ids_to_embeddings, meta = self._embed_batch(texts_to_embed, self.batch_size) + + new_documents = [ + replace(doc, embedding=doc_ids_to_embeddings[doc.id]) if doc.id in doc_ids_to_embeddings else replace(doc) + for doc in documents + ] + return {"documents": new_documents, "meta": meta} + + @component.output_types(documents=list[Document], meta=dict[str, Any]) + async def run_async(self, documents: list[Document]) -> dict[str, Any]: + """ + Asynchronously embed a list of Documents. + + :param documents: Documents to embed. + :returns: A dictionary with: + - `documents`: The input documents with their `embedding` field populated. + - `meta`: Information about the usage of the model. + """ + self._validate_documents(documents) + if not documents: + return {"documents": [], "meta": {}} + + if not self._is_warmed_up: + self.warm_up() + + texts_to_embed = self._prepare_texts_to_embed(documents) + doc_ids_to_embeddings, meta = await self._embed_batch_async(texts_to_embed, self.batch_size) + + new_documents = [ + replace(doc, embedding=doc_ids_to_embeddings[doc.id]) if doc.id in doc_ids_to_embeddings else replace(doc) + for doc in documents + ] + return {"documents": new_documents, "meta": meta} diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py new file mode 100644 index 0000000000..2749ea393e --- /dev/null +++ b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py @@ -0,0 +1,179 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +from haystack import component +from haystack.utils import Secret +from openai import AsyncOpenAI, OpenAI +from openai.types import CreateEmbeddingResponse + +from haystack_integrations.common.vllm.utils import _create_openai_clients + + +@component +class VLLMTextEmbedder: + """ + A component for embedding strings using models served with [vLLM](https://docs.vllm.ai/). + + It expects a vLLM server to be running and accessible at the `api_base_url` parameter and uses the + OpenAI-compatible Embeddings API exposed by vLLM. + + ### Starting the vLLM server + + Before using this component, start a vLLM server with an embedding model: + + ```bash + vllm serve google/embeddinggemma-300m + ``` + + For details on server options, see the [vLLM CLI docs](https://docs.vllm.ai/en/stable/cli/serve/). + + ### Usage example + + ```python + from haystack_integrations.components.embedders.vllm import VLLMTextEmbedder + + text_embedder = VLLMTextEmbedder(model="google/embeddinggemma-300m") + print(text_embedder.run("I love pizza!")) + ``` + + ### Usage example with vLLM-specific parameters + + Pass vLLM-specific parameters via the `extra_parameters` dictionary. They are forwarded as `extra_body` + to the OpenAI-compatible endpoint. + + ```python + text_embedder = VLLMTextEmbedder( + model="google/embeddinggemma-300m", + extra_parameters={"truncate_prompt_tokens": 256, "truncation_side": "right"}, + ) + ``` + """ + + def __init__( + self, + *, + model: str, + api_key: Secret | None = Secret.from_env_var("VLLM_API_KEY", strict=False), + api_base_url: str = "http://localhost:8000/v1", + prefix: str = "", + suffix: str = "", + dimensions: int | None = None, + timeout: float | None = None, + max_retries: int | None = None, + http_client_kwargs: dict[str, Any] | None = None, + extra_parameters: dict[str, Any] | None = None, + ) -> None: + """ + Creates an instance of VLLMTextEmbedder. + + :param model: The name of the model served by vLLM (e.g., "intfloat/e5-mistral-7b-instruct"). + :param api_key: The vLLM API key. Defaults to the `VLLM_API_KEY` environment variable. + Only required if the vLLM server was started with `--api-key`. + :param api_base_url: The base URL of the vLLM server. + :param prefix: A string to add at the beginning of each text to embed. + :param suffix: A string to add at the end of each text to embed. + :param dimensions: The number of dimensions of the resulting embedding. Only models trained with + Matryoshka Representation Learning support this parameter. See + [vLLM documentation](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#matryoshka-embeddings) + for more information. + :param timeout: Timeout in seconds for vLLM client calls. If not set, the OpenAI client default applies. + :param max_retries: Maximum number of retries for failed requests. If not set, the OpenAI client + default applies. + :param http_client_kwargs: A dictionary of keyword arguments to configure a custom `httpx.Client` or + `httpx.AsyncClient`. For more information, see the + [HTTPX documentation](https://www.python-httpx.org/api/#client). + :param extra_parameters: Additional parameters forwarded as `extra_body` to the vLLM embeddings + endpoint. Use this to pass parameters not part of the standard OpenAI Embeddings API, such as + `truncate_prompt_tokens`, `truncation_side`, `additional_data`, `use_activation`, etc. See the + [vLLM Embeddings API docs](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#openai-compatible-embeddings-api). + """ + self.model = model + self.api_key = api_key + self.api_base_url = api_base_url + self.prefix = prefix + self.suffix = suffix + self.dimensions = dimensions + self.timeout = timeout + self.max_retries = max_retries + self.http_client_kwargs = http_client_kwargs + self.extra_parameters = extra_parameters + + self._client: OpenAI | None = None + self._async_client: AsyncOpenAI | None = None + self._is_warmed_up = False + + def warm_up(self) -> None: + """Create the OpenAI clients.""" + if self._is_warmed_up: + return + self._client, self._async_client = _create_openai_clients( + api_key=self.api_key, + api_base_url=self.api_base_url, + timeout=self.timeout, + max_retries=self.max_retries, + http_client_kwargs=self.http_client_kwargs, + ) + self._is_warmed_up = True + + def _prepare_input(self, text: str) -> dict[str, Any]: + if not isinstance(text, str): + msg = ( + "VLLMTextEmbedder expects a string as an input. " + "In case you want to embed a list of Documents, please use the VLLMDocumentEmbedder." + ) + raise TypeError(msg) + + kwargs: dict[str, Any] = { + "model": self.model, + "input": self.prefix + text + self.suffix, + "encoding_format": "float", + } + if self.dimensions is not None: + kwargs["dimensions"] = self.dimensions + if self.extra_parameters: + kwargs["extra_body"] = self.extra_parameters + return kwargs + + @staticmethod + def _prepare_output(response: CreateEmbeddingResponse) -> dict[str, Any]: + return { + "embedding": response.data[0].embedding, + "meta": {"model": response.model, "usage": dict(response.usage)}, + } + + @component.output_types(embedding=list[float], meta=dict[str, Any]) + def run(self, text: str) -> dict[str, Any]: + """ + Embed a single string. + + :param text: Text to embed. + :returns: A dictionary with: + - `embedding`: The embedding of the input text. + - `meta`: Information about the usage of the model. + """ + kwargs = self._prepare_input(text) + if not self._is_warmed_up: + self.warm_up() + assert self._client is not None # noqa: S101 + response = self._client.embeddings.create(**kwargs) + return self._prepare_output(response) + + @component.output_types(embedding=list[float], meta=dict[str, Any]) + async def run_async(self, text: str) -> dict[str, Any]: + """ + Asynchronously embed a single string. + + :param text: Text to embed. + :returns: A dictionary with: + - `embedding`: The embedding of the input text. + - `meta`: Information about the usage of the model. + """ + kwargs = self._prepare_input(text) + if not self._is_warmed_up: + self.warm_up() + assert self._async_client is not None # noqa: S101 + response = await self._async_client.embeddings.create(**kwargs) + return self._prepare_output(response) diff --git a/integrations/vllm/src/haystack_integrations/components/generators/vllm/chat/chat_generator.py b/integrations/vllm/src/haystack_integrations/components/generators/vllm/chat/chat_generator.py index 5d21c3abaa..356d9da3f0 100644 --- a/integrations/vllm/src/haystack_integrations/components/generators/vllm/chat/chat_generator.py +++ b/integrations/vllm/src/haystack_integrations/components/generators/vllm/chat/chat_generator.py @@ -29,11 +29,12 @@ warm_up_tools, ) from haystack.utils import Secret, deserialize_callable, serialize_callable -from haystack.utils.http_client import init_http_client from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion import Choice +from haystack_integrations.common.vllm.utils import _create_openai_clients + logger = logging.getLogger(__name__) @@ -257,24 +258,12 @@ def warm_up(self) -> None: if self._is_warmed_up: return - api_key = "placeholder-api-key" - if self.api_key and (resolved_value := self.api_key.resolve_value()): - api_key = resolved_value - - client_kwargs: dict[str, Any] = { - "api_key": api_key, - "base_url": self.api_base_url, - } - if self.timeout is not None: - client_kwargs["timeout"] = self.timeout - if self.max_retries is not None: - client_kwargs["max_retries"] = self.max_retries - - self._client = OpenAI( - http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs - ) - self._async_client = AsyncOpenAI( - http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs + self._client, self._async_client = _create_openai_clients( + api_key=self.api_key, + api_base_url=self.api_base_url, + timeout=self.timeout, + max_retries=self.max_retries, + http_client_kwargs=self.http_client_kwargs, ) warm_up_tools(self.tools) self._is_warmed_up = True diff --git a/integrations/vllm/tests/test_document_embedder.py b/integrations/vllm/tests/test_document_embedder.py new file mode 100644 index 0000000000..a2d2ea6e35 --- /dev/null +++ b/integrations/vllm/tests/test_document_embedder.py @@ -0,0 +1,281 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import AsyncMock, MagicMock + +import numpy as np +import pytest +from haystack import Document +from haystack.core.serialization import component_from_dict, component_to_dict +from haystack.utils import Secret +from openai import APIError +from openai.types import CreateEmbeddingResponse, Embedding +from openai.types.create_embedding_response import Usage + +from haystack_integrations.components.embedders.vllm import VLLMDocumentEmbedder + +MODEL = "sentence-transformers/all-MiniLM-L6-v2" +API_BASE_URL = "http://localhost:8001/v1" + + +def _fake_response(embeddings: list[list[float]], prompt_tokens: int = 1, total_tokens: int = 1): + return CreateEmbeddingResponse( + object="list", + model="fake-model", + data=[Embedding(object="embedding", index=i, embedding=e) for i, e in enumerate(embeddings)], + usage=Usage(prompt_tokens=prompt_tokens, total_tokens=total_tokens), + ) + + +def _api_error() -> APIError: + return APIError(message="boom", request=MagicMock(), body=None) + + +class TestVLLMDocumentEmbedder: + def test_init_default(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + + embedder = VLLMDocumentEmbedder(model=MODEL) + assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False) + assert embedder.model == MODEL + assert embedder.api_base_url == "http://localhost:8000/v1" + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.dimensions is None + assert embedder.batch_size == 32 + assert embedder.progress_bar is True + assert embedder.meta_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + assert embedder.raise_on_failure is False + assert embedder.extra_parameters is None + assert embedder._client is None + assert embedder._async_client is None + assert embedder._is_warmed_up is False + + def test_init_with_parameters(self): + embedder = VLLMDocumentEmbedder( + model=MODEL, + api_key=Secret.from_token("test-api-key"), + api_base_url="http://my-vllm-server:8000/v1", + prefix="START", + suffix="END", + dimensions=64, + batch_size=64, + progress_bar=False, + meta_fields_to_embed=["test_field"], + embedding_separator="-", + raise_on_failure=True, + extra_parameters={"dimensions": 32, "truncate_prompt_tokens": 256}, + ) + assert embedder.api_key == Secret.from_token("test-api-key") + assert embedder.api_base_url == "http://my-vllm-server:8000/v1" + assert embedder.prefix == "START" + assert embedder.suffix == "END" + assert embedder.dimensions == 64 + assert embedder.batch_size == 64 + assert embedder.progress_bar is False + assert embedder.meta_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == "-" + assert embedder.raise_on_failure is True + assert embedder.extra_parameters == {"dimensions": 32, "truncate_prompt_tokens": 256} + + def test_warm_up(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + + embedder = VLLMDocumentEmbedder(model=MODEL) + embedder.warm_up() + + assert embedder._is_warmed_up is True + assert embedder._client is not None + assert embedder._async_client is not None + + # idempotent + client_before = embedder._client + embedder.warm_up() + assert embedder._client is client_before + + def test_to_dict(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + + component_dict = component_to_dict(VLLMDocumentEmbedder(model=MODEL), "embedder") + assert component_dict == { + "type": "haystack_integrations.components.embedders.vllm.document_embedder.VLLMDocumentEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["VLLM_API_KEY"], "strict": False, "type": "env_var"}, + "model": MODEL, + "api_base_url": "http://localhost:8000/v1", + "prefix": "", + "suffix": "", + "dimensions": None, + "batch_size": 32, + "progress_bar": True, + "meta_fields_to_embed": [], + "embedding_separator": "\n", + "timeout": None, + "max_retries": None, + "http_client_kwargs": None, + "raise_on_failure": False, + "extra_parameters": None, + }, + } + + def test_from_dict(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + data = { + "type": "haystack_integrations.components.embedders.vllm.document_embedder.VLLMDocumentEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["VLLM_API_KEY"], "strict": False, "type": "env_var"}, + "model": MODEL, + "api_base_url": "http://localhost:8000/v1", + "prefix": "", + "suffix": "", + "dimensions": 32, + "batch_size": 32, + "progress_bar": True, + "meta_fields_to_embed": [], + "embedding_separator": "\n", + "timeout": None, + "max_retries": None, + "http_client_kwargs": None, + "raise_on_failure": False, + "extra_parameters": None, + }, + } + embedder = component_from_dict(VLLMDocumentEmbedder, data, "embedder") + assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False) + assert embedder.model == MODEL + assert embedder.api_base_url == "http://localhost:8000/v1" + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.dimensions == 32 + assert embedder.batch_size == 32 + assert embedder.progress_bar is True + assert embedder.meta_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + assert embedder.timeout is None + assert embedder.max_retries is None + assert embedder.http_client_kwargs is None + assert embedder.raise_on_failure is False + assert embedder.extra_parameters is None + + def test_prepare_texts_to_embed(self): + embedder = VLLMDocumentEmbedder( + model=MODEL, prefix="[", suffix="]", meta_fields_to_embed=["topic"], embedding_separator=" | " + ) + doc = Document(content="hello", meta={"topic": "ML"}) + texts = embedder._prepare_texts_to_embed([doc]) + assert texts == {doc.id: "[ML | hello]"} + + def test_prepare_input_adds_dimensions_and_extra_body(self): + embedder = VLLMDocumentEmbedder(model=MODEL, dimensions=32, extra_parameters={"truncate_prompt_tokens": 256}) + kwargs = embedder._prepare_input(["a", "b"]) + assert kwargs == { + "model": MODEL, + "input": ["a", "b"], + "encoding_format": "float", + "dimensions": 32, + "extra_body": {"truncate_prompt_tokens": 256}, + } + + def test_run_wrong_input_format(self): + embedder = VLLMDocumentEmbedder(model=MODEL) + + with pytest.raises(TypeError, match=r"VLLMDocumentEmbedder expects a list of Documents as input\."): + embedder.run(documents="text") + with pytest.raises(TypeError, match=r"VLLMDocumentEmbedder expects a list of Documents as input\."): + embedder.run(documents=[1, 2, 3]) + + assert embedder.run(documents=[]) == {"documents": [], "meta": {}} + + def test_run_batches_and_aggregates_meta(self): + """Multi-batch run: embeddings stitched back to the right docs, usage meta accumulates.""" + embedder = VLLMDocumentEmbedder(model=MODEL, batch_size=2, progress_bar=False) + embedder._client = MagicMock() + embedder._client.embeddings.create.side_effect = [ + _fake_response([[0.1], [0.2]], prompt_tokens=2, total_tokens=2), + _fake_response([[0.3]], prompt_tokens=1, total_tokens=1), + ] + embedder._is_warmed_up = True + + docs = [Document(content=f"doc-{i}") for i in range(3)] + result = embedder.run(docs) + + assert [d.embedding for d in result["documents"]] == [[0.1], [0.2], [0.3]] + assert result["meta"] == {"model": "fake-model", "usage": {"prompt_tokens": 3, "total_tokens": 3}} + + def test_run_continues_on_api_error(self): + """raise_on_failure=False: failed batches are skipped, surviving docs keep their embedding.""" + embedder = VLLMDocumentEmbedder(model=MODEL, batch_size=1, progress_bar=False) + embedder._client = MagicMock() + embedder._client.embeddings.create.side_effect = [_fake_response([[0.1]]), _api_error()] + embedder._is_warmed_up = True + + result = embedder.run([Document(content="a"), Document(content="b")]) + + assert result["documents"][0].embedding == [0.1] + assert result["documents"][1].embedding is None + + def test_run_raise_on_failure(self): + embedder = VLLMDocumentEmbedder(model=MODEL, raise_on_failure=True, progress_bar=False) + embedder._client = MagicMock() + embedder._client.embeddings.create.side_effect = _api_error() + embedder._is_warmed_up = True + + with pytest.raises(APIError): + embedder.run([Document(content="a")]) + + @pytest.mark.asyncio + async def test_run_async(self): + embedder = VLLMDocumentEmbedder(model=MODEL, progress_bar=True) + embedder._async_client = MagicMock() + embedder._async_client.embeddings.create = AsyncMock(return_value=_fake_response([[0.5], [0.6]])) + embedder._is_warmed_up = True + + docs = [Document(content="a"), Document(content="b")] + result = await embedder.run_async(docs) + + assert [d.embedding for d in result["documents"]] == [[0.5], [0.6]] + + @pytest.mark.integration + def test_live_run(self): + embedder = VLLMDocumentEmbedder(model=MODEL, api_base_url=API_BASE_URL) + + docs = [ + Document(content="I love cheese"), + Document(content="Cheddar is my favorite food"), + Document(content="A transformer is a deep learning architecture"), + ] + + result = embedder.run(docs) + docs_with_embeddings = result["documents"] + + assert len(docs_with_embeddings) == len(docs) + for doc in docs_with_embeddings: + assert isinstance(doc.embedding, list) + assert isinstance(doc.embedding[0], float) + + embeddings = [np.array(d.embedding) for d in docs_with_embeddings] + + def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: + return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))) + + assert cosine_similarity(embeddings[0], embeddings[1]) > cosine_similarity(embeddings[0], embeddings[2]) + + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_async(self): + embedder = VLLMDocumentEmbedder(model=MODEL, api_base_url=API_BASE_URL) + + docs = [ + Document(content="I love cheese"), + Document(content="Cheddar is my favorite food"), + Document(content="A transformer is a deep learning architecture"), + ] + + result = await embedder.run_async(docs) + docs_with_embeddings = result["documents"] + + assert len(docs_with_embeddings) == len(docs) + for doc in docs_with_embeddings: + assert isinstance(doc.embedding, list) + assert isinstance(doc.embedding[0], float) diff --git a/integrations/vllm/tests/test_text_embedder.py b/integrations/vllm/tests/test_text_embedder.py new file mode 100644 index 0000000000..7ee2b42051 --- /dev/null +++ b/integrations/vllm/tests/test_text_embedder.py @@ -0,0 +1,191 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +from unittest.mock import AsyncMock, MagicMock + +import pytest +from haystack.core.serialization import component_from_dict, component_to_dict +from haystack.utils import Secret +from openai.types import CreateEmbeddingResponse, Embedding +from openai.types.create_embedding_response import Usage + +from haystack_integrations.components.embedders.vllm import VLLMTextEmbedder + +MODEL = "sentence-transformers/all-MiniLM-L6-v2" +API_BASE_URL = "http://localhost:8001/v1" + + +def _fake_response(embeddings: list[list[float]], prompt_tokens: int = 5, total_tokens: int = 5): + return CreateEmbeddingResponse( + object="list", + model="fake-model", + data=[Embedding(object="embedding", index=i, embedding=e) for i, e in enumerate(embeddings)], + usage=Usage(prompt_tokens=prompt_tokens, total_tokens=total_tokens), + ) + + +class TestVLLMTextEmbedder: + def test_init_default(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + + embedder = VLLMTextEmbedder(model=MODEL) + assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False) + assert embedder.api_base_url == "http://localhost:8000/v1" + assert embedder.model == MODEL + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.dimensions is None + assert embedder.timeout is None + assert embedder.max_retries is None + assert embedder.http_client_kwargs is None + assert embedder.extra_parameters is None + assert embedder._client is None + assert embedder._async_client is None + assert embedder._is_warmed_up is False + + def test_init_with_parameters(self): + embedder = VLLMTextEmbedder( + model=MODEL, + api_key=Secret.from_token("test-api-key"), + api_base_url="http://my-vllm-server:8000/v1", + prefix="START", + suffix="END", + dimensions=64, + timeout=10.0, + max_retries=2, + http_client_kwargs={"proxy": "https://proxy.example.com"}, + extra_parameters={"dimensions": 32, "truncate_prompt_tokens": 256}, + ) + assert embedder.api_key == Secret.from_token("test-api-key") + assert embedder.api_base_url == "http://my-vllm-server:8000/v1" + assert embedder.model == MODEL + assert embedder.prefix == "START" + assert embedder.suffix == "END" + assert embedder.dimensions == 64 + assert embedder.timeout == 10.0 + assert embedder.max_retries == 2 + assert embedder.http_client_kwargs == {"proxy": "https://proxy.example.com"} + assert embedder.extra_parameters == {"dimensions": 32, "truncate_prompt_tokens": 256} + + def test_warm_up(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + + embedder = VLLMTextEmbedder(model=MODEL) + embedder.warm_up() + + assert embedder._is_warmed_up is True + assert embedder._client is not None + assert embedder._async_client is not None + + # idempotent: calling again does not recreate clients + client_before = embedder._client + embedder.warm_up() + assert embedder._client is client_before + + def test_to_dict(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + + component_dict = component_to_dict(VLLMTextEmbedder(model=MODEL), "embedder") + assert component_dict == { + "type": "haystack_integrations.components.embedders.vllm.text_embedder.VLLMTextEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["VLLM_API_KEY"], "strict": False, "type": "env_var"}, + "model": MODEL, + "api_base_url": "http://localhost:8000/v1", + "prefix": "", + "suffix": "", + "dimensions": None, + "timeout": None, + "max_retries": None, + "http_client_kwargs": None, + "extra_parameters": None, + }, + } + + def test_from_dict(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + data = { + "type": "haystack_integrations.components.embedders.vllm.text_embedder.VLLMTextEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["VLLM_API_KEY"], "strict": False, "type": "env_var"}, + "model": MODEL, + "api_base_url": "http://localhost:8000/v1", + "prefix": "", + "suffix": "", + "dimensions": 32, + "timeout": None, + "max_retries": None, + "http_client_kwargs": None, + "extra_parameters": None, + }, + } + embedder = component_from_dict(VLLMTextEmbedder, data, "embedder") + assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False) + assert embedder.model == MODEL + assert embedder.api_base_url == "http://localhost:8000/v1" + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.dimensions == 32 + assert embedder.timeout is None + assert embedder.max_retries is None + assert embedder.http_client_kwargs is None + assert embedder.extra_parameters is None + + def test_prepare_input_adds_dimensions_and_extra_body(self): + embedder = VLLMTextEmbedder( + model=MODEL, prefix="[", suffix="]", dimensions=32, extra_parameters={"truncate_prompt_tokens": 256} + ) + kwargs = embedder._prepare_input("hello") + assert kwargs == { + "model": MODEL, + "input": "[hello]", + "encoding_format": "float", + "dimensions": 32, + "extra_body": {"truncate_prompt_tokens": 256}, + } + + def test_run_wrong_input_format(self): + embedder = VLLMTextEmbedder(model=MODEL) + with pytest.raises(TypeError, match=r"VLLMTextEmbedder expects a string as an input\."): + embedder.run(text=["text_1", "text_2"]) + + def test_run_with_mock(self): + embedder = VLLMTextEmbedder(model=MODEL, prefix="[", suffix="]", dimensions=2) + embedder._client = MagicMock() + embedder._client.embeddings.create.return_value = _fake_response([[0.1, 0.2]]) + embedder._is_warmed_up = True + + result = embedder.run("hello") + + call_kwargs = embedder._client.embeddings.create.call_args.kwargs + assert call_kwargs["input"] == "[hello]" + assert call_kwargs["dimensions"] == 2 + assert result == { + "embedding": [0.1, 0.2], + "meta": {"model": "fake-model", "usage": {"prompt_tokens": 5, "total_tokens": 5}}, + } + + @pytest.mark.asyncio + async def test_run_async(self): + embedder = VLLMTextEmbedder(model=MODEL) + embedder._async_client = MagicMock() + embedder._async_client.embeddings.create = AsyncMock(return_value=_fake_response([[0.3, 0.4]])) + embedder._is_warmed_up = True + + result = await embedder.run_async("world") + assert result["embedding"] == [0.3, 0.4] + + @pytest.mark.integration + def test_live_run(self): + embedder = VLLMTextEmbedder(model=MODEL, api_base_url=API_BASE_URL) + result = embedder.run("The food was delicious") + assert isinstance(result["embedding"], list) + assert all(isinstance(x, float) for x in result["embedding"]) + + @pytest.mark.asyncio + @pytest.mark.integration + async def test_live_run_async(self): + embedder = VLLMTextEmbedder(model=MODEL, api_base_url=API_BASE_URL) + result = await embedder.run_async("The food was delicious") + assert isinstance(result["embedding"], list) + assert all(isinstance(x, float) for x in result["embedding"]) diff --git a/integrations/vllm/tests/test_utils.py b/integrations/vllm/tests/test_utils.py new file mode 100644 index 0000000000..b017762955 --- /dev/null +++ b/integrations/vllm/tests/test_utils.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack.utils import Secret + +from haystack_integrations.common.vllm.utils import _create_openai_clients + + +def test_create_openai_clients_placeholder_when_no_key(): + sync_client, async_client = _create_openai_clients( + api_key=None, api_base_url="http://localhost:8000/v1", timeout=None, max_retries=None, http_client_kwargs=None + ) + assert sync_client.api_key == "placeholder-api-key" + assert async_client.api_key == "placeholder-api-key" + assert str(sync_client.base_url) == "http://localhost:8000/v1/" + + +def test_create_openai_clients_uses_resolved_key_and_forwards_options(): + sync_client, _ = _create_openai_clients( + api_key=Secret.from_token("real-key"), + api_base_url="http://vllm:8000/v1", + timeout=12.5, + max_retries=7, + http_client_kwargs=None, + ) + assert sync_client.api_key == "real-key" + assert sync_client.timeout == 12.5 + assert sync_client.max_retries == 7