From 9d0edd784a135ef6955692119798945a0c4551ac Mon Sep 17 00:00:00 2001 From: anakin87 Date: Tue, 14 Apr 2026 16:23:03 +0200 Subject: [PATCH 1/7] feat: add vLLM embedders --- README.md | 2 +- integrations/vllm/pydoc/config_docusaurus.yml | 2 + integrations/vllm/pyproject.toml | 2 +- .../src/haystack_integrations/common/py.typed | 0 .../common/vllm/__init__.py | 3 + .../common/vllm/utils.py | 39 +++ .../components/embedders/py.typed | 0 .../components/embedders/vllm/__init__.py | 8 + .../embedders/vllm/document_embedder.py | 314 ++++++++++++++++++ .../embedders/vllm/text_embedder.py | 196 +++++++++++ .../generators/vllm/chat/chat_generator.py | 27 +- .../vllm/tests/test_document_embedder.py | 240 +++++++++++++ integrations/vllm/tests/test_text_embedder.py | 172 ++++++++++ integrations/vllm/tests/test_utils.py | 31 ++ 14 files changed, 1015 insertions(+), 21 deletions(-) create mode 100644 integrations/vllm/src/haystack_integrations/common/py.typed create mode 100644 integrations/vllm/src/haystack_integrations/common/vllm/__init__.py create mode 100644 integrations/vllm/src/haystack_integrations/common/vllm/utils.py create mode 100644 integrations/vllm/src/haystack_integrations/components/embedders/py.typed create mode 100644 integrations/vllm/src/haystack_integrations/components/embedders/vllm/__init__.py create mode 100644 integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py create mode 100644 integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py create mode 100644 integrations/vllm/tests/test_document_embedder.py create mode 100644 integrations/vllm/tests/test_text_embedder.py create mode 100644 integrations/vllm/tests/test_utils.py diff --git a/README.md b/README.md index 42891b8360..00bffbd95c 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,7 @@ Please check out our [Contribution Guidelines](CONTRIBUTING.md) for all the deta | [togetherai-haystack](integrations/togetherai/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/togetherai-haystack.svg)](https://pypi.org/project/togetherai-haystack) | [![Test / togetherai](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/togetherai.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/togetherai.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-togetherai/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-togetherai/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-togetherai-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-togetherai-combined/htmlcov/index.html) | | [unstructured-fileconverter-haystack](integrations/unstructured/) | File converter | [![PyPI - Version](https://img.shields.io/pypi/v/unstructured-fileconverter-haystack.svg)](https://pypi.org/project/unstructured-fileconverter-haystack) | [![Test / unstructured](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/unstructured.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-unstructured/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-unstructured/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-unstructured-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-unstructured-combined/htmlcov/index.html) | | [valkey-haystack](integrations/valkey/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/valkey-haystack.svg)](https://pypi.org/project/valkey-haystack) | [![Test / valkey](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/valkey.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/valkey.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-valkey/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-valkey/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-valkey-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-valkey-combined/htmlcov/index.html) | -| [vllm-haystack](integrations/vllm/) | Generator | [![PyPI - Version](https://img.shields.io/pypi/v/vllm-haystack.svg)](https://pypi.org/project/vllm-haystack) | [![Test / vllm](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/vllm.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/vllm.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-vllm/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-vllm/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-vllm-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-vllm-combined/htmlcov/index.html) | +| [vllm-haystack](integrations/vllm/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/vllm-haystack.svg)](https://pypi.org/project/vllm-haystack) | [![Test / vllm](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/vllm.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/vllm.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-vllm/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-vllm/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-vllm-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-vllm-combined/htmlcov/index.html) | | [watsonx-haystack](integrations/watsonx/) | Embedder, Generator | [![PyPI - Version](https://img.shields.io/pypi/v/watsonx-haystack.svg?color=orange)](https://pypi.org/project/watsonx-haystack) | [![Test / watsonx](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/watsonx.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/watsonx.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-watsonx/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-watsonx/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-watsonx-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-watsonx-combined/htmlcov/index.html) | | [weave-haystack](integrations/weave/) | Tracer | [![PyPI - Version](https://img.shields.io/pypi/v/weave-haystack.svg)](https://pypi.org/project/weave-haystack) | [![Test / weave](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weave.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weave.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-weave/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-weave/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-weave-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-weave-combined/htmlcov/index.html) | | [weaviate-haystack](integrations/weaviate/) | Document Store | [![PyPI - Version](https://img.shields.io/pypi/v/weaviate-haystack.svg)](https://pypi.org/project/weaviate-haystack) | [![Test / weaviate](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml/badge.svg)](https://github.com/deepset-ai/haystack-core-integrations/actions/workflows/weaviate.yml) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-weaviate/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-weaviate/htmlcov/index.html) | [![Coverage badge](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/deepset-ai/haystack-core-integrations/python-coverage-comment-action-data-weaviate-combined/endpoint.json&label=)](https://htmlpreview.github.io/?https://github.com/deepset-ai/haystack-core-integrations/blob/python-coverage-comment-action-data-weaviate-combined/htmlcov/index.html) | diff --git a/integrations/vllm/pydoc/config_docusaurus.yml b/integrations/vllm/pydoc/config_docusaurus.yml index 13c26dd968..daf8d8b75c 100644 --- a/integrations/vllm/pydoc/config_docusaurus.yml +++ b/integrations/vllm/pydoc/config_docusaurus.yml @@ -1,6 +1,8 @@ loaders: - modules: - haystack_integrations.components.generators.vllm.chat.chat_generator + - haystack_integrations.components.embedders.vllm.text_embedder + - haystack_integrations.components.embedders.vllm.document_embedder search_path: [../src] processors: - type: filter diff --git a/integrations/vllm/pyproject.toml b/integrations/vllm/pyproject.toml index 53cec11e33..d3c5e75a1a 100644 --- a/integrations/vllm/pyproject.toml +++ b/integrations/vllm/pyproject.toml @@ -66,7 +66,7 @@ integration = 'pytest -m "integration" {args:tests}' all = 'pytest {args:tests}' unit-cov-retry = 'pytest --cov=haystack_integrations --reruns 3 --reruns-delay 30 -x -m "not integration" {args:tests}' integration-cov-append-retry = 'pytest --cov=haystack_integrations --cov-append --reruns 3 --reruns-delay 30 -x -m "integration" {args:tests}' -types = "mypy -p haystack_integrations.components.generators.vllm {args}" +types = "mypy -p haystack_integrations.components.generators.vllm -p haystack_integrations.components.embedders.vllm -p haystack_integrations.common.vllm {args}" [tool.mypy] install_types = true diff --git a/integrations/vllm/src/haystack_integrations/common/py.typed b/integrations/vllm/src/haystack_integrations/common/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/integrations/vllm/src/haystack_integrations/common/vllm/__init__.py b/integrations/vllm/src/haystack_integrations/common/vllm/__init__.py new file mode 100644 index 0000000000..c1764a6e03 --- /dev/null +++ b/integrations/vllm/src/haystack_integrations/common/vllm/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 diff --git a/integrations/vllm/src/haystack_integrations/common/vllm/utils.py b/integrations/vllm/src/haystack_integrations/common/vllm/utils.py new file mode 100644 index 0000000000..92615f5906 --- /dev/null +++ b/integrations/vllm/src/haystack_integrations/common/vllm/utils.py @@ -0,0 +1,39 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +from haystack.utils import Secret +from haystack.utils.http_client import init_http_client +from openai import AsyncOpenAI, OpenAI + + +def _create_openai_clients( + api_key: Secret | None, + api_base_url: str, + timeout: float | None, + max_retries: int | None, + http_client_kwargs: dict[str, Any] | None, +) -> tuple[OpenAI, AsyncOpenAI]: + """ + Build sync and async OpenAI clients pointing at a vLLM server. + + A placeholder api key is used when the user did not supply one and no `VLLM_API_KEY` env var is + set, because the OpenAI client requires a non-empty value. `timeout` and `max_retries` are only + forwarded when provided: when None, the OpenAI client's own defaults apply and no `OPENAI_*` + env vars are read. + """ + resolved_api_key = "placeholder-api-key" + if api_key is not None and (value := api_key.resolve_value()): + resolved_api_key = value + + client_kwargs: dict[str, Any] = {"api_key": resolved_api_key, "base_url": api_base_url} + if timeout is not None: + client_kwargs["timeout"] = timeout + if max_retries is not None: + client_kwargs["max_retries"] = max_retries + + sync_client = OpenAI(http_client=init_http_client(http_client_kwargs, async_client=False), **client_kwargs) + async_client = AsyncOpenAI(http_client=init_http_client(http_client_kwargs, async_client=True), **client_kwargs) + return sync_client, async_client diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/py.typed b/integrations/vllm/src/haystack_integrations/components/embedders/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/__init__.py b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/__init__.py new file mode 100644 index 0000000000..1ffc2e2931 --- /dev/null +++ b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/__init__.py @@ -0,0 +1,8 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from .document_embedder import VLLMDocumentEmbedder +from .text_embedder import VLLMTextEmbedder + +__all__ = ["VLLMDocumentEmbedder", "VLLMTextEmbedder"] diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py new file mode 100644 index 0000000000..882a8cd597 --- /dev/null +++ b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py @@ -0,0 +1,314 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from dataclasses import replace +from typing import Any + +from haystack import Document, component, default_from_dict, default_to_dict, logging +from haystack.utils import Secret +from more_itertools import batched +from openai import APIError, AsyncOpenAI, OpenAI +from tqdm import tqdm +from tqdm.asyncio import tqdm as async_tqdm + +from haystack_integrations.common.vllm.utils import _create_openai_clients + +logger = logging.getLogger(__name__) + + +@component +class VLLMDocumentEmbedder: + """ + A component for computing Document embeddings using models served with [vLLM](https://docs.vllm.ai/). + + The embedding of each Document is stored in the `embedding` field of the Document. + It expects a vLLM server to be running and accessible at the `api_base_url` parameter and uses the + OpenAI-compatible Embeddings API exposed by vLLM. + + ### Starting the vLLM server + + Before using this component, start a vLLM server with an embedding model: + + ```bash + vllm serve intfloat/e5-mistral-7b-instruct + ``` + + For details on server options, see the [vLLM CLI docs](https://docs.vllm.ai/en/stable/cli/serve/). + + ### Usage example + + ```python + from haystack import Document + from haystack_integrations.components.embedders.vllm import VLLMDocumentEmbedder + + doc = Document(content="I love pizza!") + + document_embedder = VLLMDocumentEmbedder(model="intfloat/e5-mistral-7b-instruct") + + result = document_embedder.run([doc]) + print(result["documents"][0].embedding) + ``` + + ### Usage example with vLLM-specific parameters + + Pass vLLM-specific parameters via the `extra_parameters` dictionary. They are forwarded as `extra_body` + to the OpenAI-compatible endpoint. + + ```python + document_embedder = VLLMDocumentEmbedder( + model="jinaai/jina-embeddings-v3", + extra_parameters={"dimensions": 32, "truncate_prompt_tokens": 256}, + ) + ``` + """ + + def __init__( + self, + *, + model: str, + api_key: Secret | None = Secret.from_env_var("VLLM_API_KEY", strict=False), + api_base_url: str = "http://localhost:8000/v1", + prefix: str = "", + suffix: str = "", + batch_size: int = 32, + progress_bar: bool = True, + meta_fields_to_embed: list[str] | None = None, + embedding_separator: str = "\n", + timeout: float | None = None, + max_retries: int | None = None, + http_client_kwargs: dict[str, Any] | None = None, + raise_on_failure: bool = False, + extra_parameters: dict[str, Any] | None = None, + ) -> None: + """ + Creates an instance of VLLMDocumentEmbedder. + + :param model: The name of the model served by vLLM (e.g., "intfloat/e5-mistral-7b-instruct"). + :param api_key: The vLLM API key. Defaults to the `VLLM_API_KEY` environment variable. + Only required if the vLLM server was started with `--api-key`. + :param api_base_url: The base URL of the vLLM server. + :param prefix: A string to add at the beginning of each text. + :param suffix: A string to add at the end of each text. + :param batch_size: Number of Documents to encode at once. + :param progress_bar: Whether to show a progress bar. Disable in production to keep logs clean. + :param meta_fields_to_embed: List of meta fields to embed along with the Document text. + :param embedding_separator: Separator used to concatenate the meta fields to the Document text. + :param timeout: Timeout in seconds for vLLM client calls. If not set, the OpenAI client default applies. + :param max_retries: Maximum number of retries for failed requests. If not set, the OpenAI client + default applies. + :param http_client_kwargs: A dictionary of keyword arguments to configure a custom `httpx.Client` or + `httpx.AsyncClient`. For more information, see the + [HTTPX documentation](https://www.python-httpx.org/api/#client). + :param raise_on_failure: Whether to raise an exception if the embedding request fails. If `False`, + the component logs the error and continues processing the remaining documents. + :param extra_parameters: Additional parameters forwarded as `extra_body` to the vLLM embeddings + endpoint. Use this to pass parameters not part of the standard OpenAI Embeddings API, such as + `dimensions` (for Matryoshka models), `truncate_prompt_tokens`, `truncation_side`, + `additional_data`, `use_activation`, etc. See the + [vLLM Embeddings API docs](https://docs.vllm.ai/en/stable/models/pooling_models.html#openai-compatible-embeddings-api). + """ + self.model = model + self.api_key = api_key + self.api_base_url = api_base_url + self.prefix = prefix + self.suffix = suffix + self.batch_size = batch_size + self.progress_bar = progress_bar + self.meta_fields_to_embed = meta_fields_to_embed or [] + self.embedding_separator = embedding_separator + self.timeout = timeout + self.max_retries = max_retries + self.http_client_kwargs = http_client_kwargs + self.raise_on_failure = raise_on_failure + self.extra_parameters = extra_parameters + + self._client: OpenAI | None = None + self._async_client: AsyncOpenAI | None = None + self._is_warmed_up = False + + def warm_up(self) -> None: + """Create the OpenAI clients.""" + if self._is_warmed_up: + return + self._client, self._async_client = _create_openai_clients( + api_key=self.api_key, + api_base_url=self.api_base_url, + timeout=self.timeout, + max_retries=self.max_retries, + http_client_kwargs=self.http_client_kwargs, + ) + self._is_warmed_up = True + + def to_dict(self) -> dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: The serialized component as a dictionary. + """ + return default_to_dict( + self, + model=self.model, + api_key=self.api_key.to_dict() if self.api_key else None, + api_base_url=self.api_base_url, + prefix=self.prefix, + suffix=self.suffix, + batch_size=self.batch_size, + progress_bar=self.progress_bar, + meta_fields_to_embed=self.meta_fields_to_embed, + embedding_separator=self.embedding_separator, + timeout=self.timeout, + max_retries=self.max_retries, + http_client_kwargs=self.http_client_kwargs, + raise_on_failure=self.raise_on_failure, + extra_parameters=self.extra_parameters, + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "VLLMDocumentEmbedder": + """Deserialize this component from a dictionary.""" + return default_from_dict(cls, data) + + def _prepare_texts_to_embed(self, documents: list[Document]) -> dict[str, str]: + """Concatenate each Document's text with the selected meta fields.""" + texts_to_embed = {} + for doc in documents: + meta_values_to_embed = [ + str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None + ] + texts_to_embed[doc.id] = ( + self.prefix + self.embedding_separator.join([*meta_values_to_embed, doc.content or ""]) + self.suffix + ) + return texts_to_embed + + def _prepare_input(self, inputs: list[str]) -> dict[str, Any]: + kwargs: dict[str, Any] = {"model": self.model, "input": inputs, "encoding_format": "float"} + if self.extra_parameters: + kwargs["extra_body"] = self.extra_parameters + return kwargs + + @staticmethod + def _update_meta(meta: dict[str, Any], response: Any) -> None: + if "model" not in meta: + meta["model"] = response.model + if "usage" not in meta: + meta["usage"] = dict(response.usage) + else: + meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens + meta["usage"]["total_tokens"] += response.usage.total_tokens + + def _embed_batch( + self, texts_to_embed: dict[str, str], batch_size: int + ) -> tuple[dict[str, list[float]], dict[str, Any]]: + assert self._client is not None # noqa: S101 + doc_ids_to_embeddings: dict[str, list[float]] = {} + meta: dict[str, Any] = {} + + for batch in tqdm( + batched(texts_to_embed.items(), batch_size), + disable=not self.progress_bar, + desc="Calculating embeddings", + ): + kwargs = self._prepare_input([b[1] for b in batch]) + try: + response = self._client.embeddings.create(**kwargs) + except APIError as exc: + ids = ", ".join(b[0] for b in batch) + logger.exception("Failed embedding of documents {ids} caused by {exc}", ids=ids, exc=exc) + if self.raise_on_failure: + raise + continue + + embeddings = [el.embedding for el in response.data] + doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings, strict=True))) + self._update_meta(meta, response) + + return doc_ids_to_embeddings, meta + + async def _embed_batch_async( + self, texts_to_embed: dict[str, str], batch_size: int + ) -> tuple[dict[str, list[float]], dict[str, Any]]: + assert self._async_client is not None # noqa: S101 + doc_ids_to_embeddings: dict[str, list[float]] = {} + meta: dict[str, Any] = {} + + batches = list(batched(texts_to_embed.items(), batch_size)) + iterator = async_tqdm(batches, desc="Calculating embeddings") if self.progress_bar else batches + + for batch in iterator: + kwargs = self._prepare_input([b[1] for b in batch]) + try: + response = await self._async_client.embeddings.create(**kwargs) + except APIError as exc: + ids = ", ".join(b[0] for b in batch) + logger.exception("Failed embedding of documents {ids} caused by {exc}", ids=ids, exc=exc) + if self.raise_on_failure: + raise + continue + + embeddings = [el.embedding for el in response.data] + doc_ids_to_embeddings.update(dict(zip((b[0] for b in batch), embeddings, strict=True))) + self._update_meta(meta, response) + + return doc_ids_to_embeddings, meta + + @staticmethod + def _validate_documents(documents: list[Document]) -> None: + if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)): + msg = ( + "VLLMDocumentEmbedder expects a list of Documents as input. " + "In case you want to embed a string, please use the VLLMTextEmbedder." + ) + raise TypeError(msg) + + @component.output_types(documents=list[Document], meta=dict[str, Any]) + def run(self, documents: list[Document]) -> dict[str, Any]: + """ + Embed a list of Documents. + + :param documents: Documents to embed. + :returns: A dictionary with: + - `documents`: The input documents with their `embedding` field populated. + - `meta`: Information about the usage of the model. + """ + self._validate_documents(documents) + if not documents: + return {"documents": [], "meta": {}} + + if not self._is_warmed_up: + self.warm_up() + + texts_to_embed = self._prepare_texts_to_embed(documents) + doc_ids_to_embeddings, meta = self._embed_batch(texts_to_embed, self.batch_size) + + new_documents = [ + replace(doc, embedding=doc_ids_to_embeddings[doc.id]) if doc.id in doc_ids_to_embeddings else replace(doc) + for doc in documents + ] + return {"documents": new_documents, "meta": meta} + + @component.output_types(documents=list[Document], meta=dict[str, Any]) + async def run_async(self, documents: list[Document]) -> dict[str, Any]: + """ + Asynchronously embed a list of Documents. + + :param documents: Documents to embed. + :returns: A dictionary with: + - `documents`: The input documents with their `embedding` field populated. + - `meta`: Information about the usage of the model. + """ + self._validate_documents(documents) + if not documents: + return {"documents": [], "meta": {}} + + if not self._is_warmed_up: + self.warm_up() + + texts_to_embed = self._prepare_texts_to_embed(documents) + doc_ids_to_embeddings, meta = await self._embed_batch_async(texts_to_embed, self.batch_size) + + new_documents = [ + replace(doc, embedding=doc_ids_to_embeddings[doc.id]) if doc.id in doc_ids_to_embeddings else replace(doc) + for doc in documents + ] + return {"documents": new_documents, "meta": meta} diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py new file mode 100644 index 0000000000..090e1708cb --- /dev/null +++ b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py @@ -0,0 +1,196 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + +from haystack import component, default_from_dict, default_to_dict +from haystack.utils import Secret +from openai import AsyncOpenAI, OpenAI +from openai.types import CreateEmbeddingResponse + +from haystack_integrations.common.vllm.utils import _create_openai_clients + + +@component +class VLLMTextEmbedder: + """ + A component for embedding strings using models served with [vLLM](https://docs.vllm.ai/). + + It expects a vLLM server to be running and accessible at the `api_base_url` parameter and uses the + OpenAI-compatible Embeddings API exposed by vLLM. + + ### Starting the vLLM server + + Before using this component, start a vLLM server with an embedding model: + + ```bash + vllm serve intfloat/e5-mistral-7b-instruct + ``` + + For details on server options, see the [vLLM CLI docs](https://docs.vllm.ai/en/stable/cli/serve/). + + ### Usage example + + ```python + from haystack_integrations.components.embedders.vllm import VLLMTextEmbedder + + text_embedder = VLLMTextEmbedder(model="intfloat/e5-mistral-7b-instruct") + print(text_embedder.run("I love pizza!")) + ``` + + ### Usage example with vLLM-specific parameters + + Pass vLLM-specific parameters via the `extra_parameters` dictionary. They are forwarded as `extra_body` + to the OpenAI-compatible endpoint. + + ```python + text_embedder = VLLMTextEmbedder( + model="jinaai/jina-embeddings-v3", + extra_parameters={"dimensions": 32, "truncate_prompt_tokens": 256}, + ) + ``` + """ + + def __init__( + self, + *, + model: str, + api_key: Secret | None = Secret.from_env_var("VLLM_API_KEY", strict=False), + api_base_url: str = "http://localhost:8000/v1", + prefix: str = "", + suffix: str = "", + timeout: float | None = None, + max_retries: int | None = None, + http_client_kwargs: dict[str, Any] | None = None, + extra_parameters: dict[str, Any] | None = None, + ) -> None: + """ + Creates an instance of VLLMTextEmbedder. + + :param model: The name of the model served by vLLM (e.g., "intfloat/e5-mistral-7b-instruct"). + :param api_key: The vLLM API key. Defaults to the `VLLM_API_KEY` environment variable. + Only required if the vLLM server was started with `--api-key`. + :param api_base_url: The base URL of the vLLM server. + :param prefix: A string to add at the beginning of each text to embed. + :param suffix: A string to add at the end of each text to embed. + :param timeout: Timeout in seconds for vLLM client calls. If not set, the OpenAI client default applies. + :param max_retries: Maximum number of retries for failed requests. If not set, the OpenAI client + default applies. + :param http_client_kwargs: A dictionary of keyword arguments to configure a custom `httpx.Client` or + `httpx.AsyncClient`. For more information, see the + [HTTPX documentation](https://www.python-httpx.org/api/#client). + :param extra_parameters: Additional parameters forwarded as `extra_body` to the vLLM embeddings + endpoint. Use this to pass parameters not part of the standard OpenAI Embeddings API, such as + `dimensions` (for Matryoshka models), `truncate_prompt_tokens`, `truncation_side`, + `additional_data`, `use_activation`, etc. See the + [vLLM Embeddings API docs](https://docs.vllm.ai/en/stable/models/pooling_models.html#openai-compatible-embeddings-api). + """ + self.model = model + self.api_key = api_key + self.api_base_url = api_base_url + self.prefix = prefix + self.suffix = suffix + self.timeout = timeout + self.max_retries = max_retries + self.http_client_kwargs = http_client_kwargs + self.extra_parameters = extra_parameters + + self._client: OpenAI | None = None + self._async_client: AsyncOpenAI | None = None + self._is_warmed_up = False + + def warm_up(self) -> None: + """Create the OpenAI clients.""" + if self._is_warmed_up: + return + self._client, self._async_client = _create_openai_clients( + api_key=self.api_key, + api_base_url=self.api_base_url, + timeout=self.timeout, + max_retries=self.max_retries, + http_client_kwargs=self.http_client_kwargs, + ) + self._is_warmed_up = True + + def to_dict(self) -> dict[str, Any]: + """ + Serialize this component to a dictionary. + + :returns: The serialized component as a dictionary. + """ + return default_to_dict( + self, + model=self.model, + api_key=self.api_key.to_dict() if self.api_key else None, + api_base_url=self.api_base_url, + prefix=self.prefix, + suffix=self.suffix, + timeout=self.timeout, + max_retries=self.max_retries, + http_client_kwargs=self.http_client_kwargs, + extra_parameters=self.extra_parameters, + ) + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "VLLMTextEmbedder": + """Deserialize this component from a dictionary.""" + return default_from_dict(cls, data) + + def _prepare_input(self, text: str) -> dict[str, Any]: + if not isinstance(text, str): + msg = ( + "VLLMTextEmbedder expects a string as an input. " + "In case you want to embed a list of Documents, please use the VLLMDocumentEmbedder." + ) + raise TypeError(msg) + + kwargs: dict[str, Any] = { + "model": self.model, + "input": self.prefix + text + self.suffix, + "encoding_format": "float", + } + if self.extra_parameters: + kwargs["extra_body"] = self.extra_parameters + return kwargs + + @staticmethod + def _prepare_output(response: CreateEmbeddingResponse) -> dict[str, Any]: + return { + "embedding": response.data[0].embedding, + "meta": {"model": response.model, "usage": dict(response.usage)}, + } + + @component.output_types(embedding=list[float], meta=dict[str, Any]) + def run(self, text: str) -> dict[str, Any]: + """ + Embed a single string. + + :param text: Text to embed. + :returns: A dictionary with: + - `embedding`: The embedding of the input text. + - `meta`: Information about the usage of the model. + """ + kwargs = self._prepare_input(text) + if not self._is_warmed_up: + self.warm_up() + assert self._client is not None # noqa: S101 + response = self._client.embeddings.create(**kwargs) + return self._prepare_output(response) + + @component.output_types(embedding=list[float], meta=dict[str, Any]) + async def run_async(self, text: str) -> dict[str, Any]: + """ + Asynchronously embed a single string. + + :param text: Text to embed. + :returns: A dictionary with: + - `embedding`: The embedding of the input text. + - `meta`: Information about the usage of the model. + """ + kwargs = self._prepare_input(text) + if not self._is_warmed_up: + self.warm_up() + assert self._async_client is not None # noqa: S101 + response = await self._async_client.embeddings.create(**kwargs) + return self._prepare_output(response) diff --git a/integrations/vllm/src/haystack_integrations/components/generators/vllm/chat/chat_generator.py b/integrations/vllm/src/haystack_integrations/components/generators/vllm/chat/chat_generator.py index 5d21c3abaa..356d9da3f0 100644 --- a/integrations/vllm/src/haystack_integrations/components/generators/vllm/chat/chat_generator.py +++ b/integrations/vllm/src/haystack_integrations/components/generators/vllm/chat/chat_generator.py @@ -29,11 +29,12 @@ warm_up_tools, ) from haystack.utils import Secret, deserialize_callable, serialize_callable -from haystack.utils.http_client import init_http_client from openai import AsyncOpenAI, AsyncStream, OpenAI, Stream from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat.chat_completion import Choice +from haystack_integrations.common.vllm.utils import _create_openai_clients + logger = logging.getLogger(__name__) @@ -257,24 +258,12 @@ def warm_up(self) -> None: if self._is_warmed_up: return - api_key = "placeholder-api-key" - if self.api_key and (resolved_value := self.api_key.resolve_value()): - api_key = resolved_value - - client_kwargs: dict[str, Any] = { - "api_key": api_key, - "base_url": self.api_base_url, - } - if self.timeout is not None: - client_kwargs["timeout"] = self.timeout - if self.max_retries is not None: - client_kwargs["max_retries"] = self.max_retries - - self._client = OpenAI( - http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs - ) - self._async_client = AsyncOpenAI( - http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs + self._client, self._async_client = _create_openai_clients( + api_key=self.api_key, + api_base_url=self.api_base_url, + timeout=self.timeout, + max_retries=self.max_retries, + http_client_kwargs=self.http_client_kwargs, ) warm_up_tools(self.tools) self._is_warmed_up = True diff --git a/integrations/vllm/tests/test_document_embedder.py b/integrations/vllm/tests/test_document_embedder.py new file mode 100644 index 0000000000..d8196d31e7 --- /dev/null +++ b/integrations/vllm/tests/test_document_embedder.py @@ -0,0 +1,240 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from unittest.mock import AsyncMock, MagicMock + +import pytest +from haystack import Document +from haystack.utils import Secret +from openai import APIError +from openai.types import CreateEmbeddingResponse, Embedding +from openai.types.create_embedding_response import Usage + +from haystack_integrations.components.embedders.vllm import VLLMDocumentEmbedder + +MODEL = "intfloat/e5-mistral-7b-instruct" + + +def _fake_response(embeddings: list[list[float]], prompt_tokens: int = 1, total_tokens: int = 1): + return CreateEmbeddingResponse( + object="list", + model="fake-model", + data=[Embedding(object="embedding", index=i, embedding=e) for i, e in enumerate(embeddings)], + usage=Usage(prompt_tokens=prompt_tokens, total_tokens=total_tokens), + ) + + +def _api_error() -> APIError: + return APIError(message="boom", request=MagicMock(), body=None) + + +class TestVLLMDocumentEmbedder: + def test_init_default(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + + embedder = VLLMDocumentEmbedder(model=MODEL) + assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False) + assert embedder.model == MODEL + assert embedder.api_base_url == "http://localhost:8000/v1" + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.batch_size == 32 + assert embedder.progress_bar is True + assert embedder.meta_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + assert embedder.raise_on_failure is False + assert embedder.extra_parameters is None + assert embedder._client is None + assert embedder._async_client is None + assert embedder._is_warmed_up is False + + def test_init_with_parameters(self): + embedder = VLLMDocumentEmbedder( + model=MODEL, + api_key=Secret.from_token("test-api-key"), + api_base_url="http://my-vllm-server:8000/v1", + prefix="START", + suffix="END", + batch_size=64, + progress_bar=False, + meta_fields_to_embed=["test_field"], + embedding_separator="-", + raise_on_failure=True, + extra_parameters={"dimensions": 32, "truncate_prompt_tokens": 256}, + ) + assert embedder.api_key == Secret.from_token("test-api-key") + assert embedder.api_base_url == "http://my-vllm-server:8000/v1" + assert embedder.prefix == "START" + assert embedder.suffix == "END" + assert embedder.batch_size == 64 + assert embedder.progress_bar is False + assert embedder.meta_fields_to_embed == ["test_field"] + assert embedder.embedding_separator == "-" + assert embedder.raise_on_failure is True + assert embedder.extra_parameters == {"dimensions": 32, "truncate_prompt_tokens": 256} + + def test_warm_up(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + + embedder = VLLMDocumentEmbedder(model=MODEL) + embedder.warm_up() + + assert embedder._is_warmed_up is True + assert embedder._client is not None + assert embedder._async_client is not None + + # idempotent + client_before = embedder._client + embedder.warm_up() + assert embedder._client is client_before + + def test_to_dict(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + + component_dict = VLLMDocumentEmbedder(model=MODEL).to_dict() + assert component_dict == { + "type": "haystack_integrations.components.embedders.vllm.document_embedder.VLLMDocumentEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["VLLM_API_KEY"], "strict": False, "type": "env_var"}, + "model": MODEL, + "api_base_url": "http://localhost:8000/v1", + "prefix": "", + "suffix": "", + "batch_size": 32, + "progress_bar": True, + "meta_fields_to_embed": [], + "embedding_separator": "\n", + "timeout": None, + "max_retries": None, + "http_client_kwargs": None, + "raise_on_failure": False, + "extra_parameters": None, + }, + } + + def test_from_dict(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + data = { + "type": "haystack_integrations.components.embedders.vllm.document_embedder.VLLMDocumentEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["VLLM_API_KEY"], "strict": False, "type": "env_var"}, + "model": MODEL, + "api_base_url": "http://localhost:8000/v1", + "prefix": "", + "suffix": "", + "batch_size": 32, + "progress_bar": True, + "meta_fields_to_embed": [], + "embedding_separator": "\n", + "timeout": None, + "max_retries": None, + "http_client_kwargs": None, + "raise_on_failure": False, + "extra_parameters": {"dimensions": 32}, + }, + } + embedder = VLLMDocumentEmbedder.from_dict(data) + assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False) + assert embedder.model == MODEL + assert embedder.api_base_url == "http://localhost:8000/v1" + assert embedder.batch_size == 32 + assert embedder.extra_parameters == {"dimensions": 32} + + def test_prepare_texts_to_embed(self): + embedder = VLLMDocumentEmbedder( + model=MODEL, prefix="[", suffix="]", meta_fields_to_embed=["topic"], embedding_separator=" | " + ) + doc = Document(content="hello", meta={"topic": "ML"}) + texts = embedder._prepare_texts_to_embed([doc]) + assert texts == {doc.id: "[ML | hello]"} + + def test_prepare_input_adds_extra_body(self): + embedder = VLLMDocumentEmbedder(model=MODEL, extra_parameters={"dimensions": 32}) + kwargs = embedder._prepare_input(["a", "b"]) + assert kwargs == { + "model": MODEL, + "input": ["a", "b"], + "encoding_format": "float", + "extra_body": {"dimensions": 32}, + } + + def test_run_wrong_input_format(self): + embedder = VLLMDocumentEmbedder(model=MODEL) + + with pytest.raises(TypeError, match=r"VLLMDocumentEmbedder expects a list of Documents as input\."): + embedder.run(documents="text") + with pytest.raises(TypeError, match=r"VLLMDocumentEmbedder expects a list of Documents as input\."): + embedder.run(documents=[1, 2, 3]) + + assert embedder.run(documents=[]) == {"documents": [], "meta": {}} + + def test_run_batches_and_aggregates_meta(self): + """Multi-batch run: embeddings stitched back to the right docs, usage meta accumulates.""" + embedder = VLLMDocumentEmbedder(model=MODEL, batch_size=2, progress_bar=False) + embedder._client = MagicMock() + embedder._client.embeddings.create.side_effect = [ + _fake_response([[0.1], [0.2]], prompt_tokens=2, total_tokens=2), + _fake_response([[0.3]], prompt_tokens=1, total_tokens=1), + ] + embedder._is_warmed_up = True + + docs = [Document(content=f"doc-{i}") for i in range(3)] + result = embedder.run(docs) + + assert [d.embedding for d in result["documents"]] == [[0.1], [0.2], [0.3]] + assert result["meta"] == {"model": "fake-model", "usage": {"prompt_tokens": 3, "total_tokens": 3}} + + def test_run_continues_on_api_error(self): + """raise_on_failure=False: failed batches are skipped, surviving docs keep their embedding.""" + embedder = VLLMDocumentEmbedder(model=MODEL, batch_size=1, progress_bar=False) + embedder._client = MagicMock() + embedder._client.embeddings.create.side_effect = [_fake_response([[0.1]]), _api_error()] + embedder._is_warmed_up = True + + result = embedder.run([Document(content="a"), Document(content="b")]) + + assert result["documents"][0].embedding == [0.1] + assert result["documents"][1].embedding is None + + def test_run_raise_on_failure(self): + embedder = VLLMDocumentEmbedder(model=MODEL, raise_on_failure=True, progress_bar=False) + embedder._client = MagicMock() + embedder._client.embeddings.create.side_effect = _api_error() + embedder._is_warmed_up = True + + with pytest.raises(APIError): + embedder.run([Document(content="a")]) + + @pytest.mark.asyncio + async def test_run_async(self): + embedder = VLLMDocumentEmbedder(model=MODEL, progress_bar=True) + embedder._async_client = MagicMock() + embedder._async_client.embeddings.create = AsyncMock(return_value=_fake_response([[0.5], [0.6]])) + embedder._is_warmed_up = True + + docs = [Document(content="a"), Document(content="b")] + result = await embedder.run_async(docs) + + assert [d.embedding for d in result["documents"]] == [[0.5], [0.6]] + + @pytest.mark.skipif( + not os.environ.get("VLLM_API_BASE_URL", None), + reason="Export VLLM_API_BASE_URL pointing to a running vLLM embedding server to run this test.", + ) + @pytest.mark.integration + def test_run(self): + embedder = VLLMDocumentEmbedder(model=MODEL, api_base_url=os.environ["VLLM_API_BASE_URL"]) + + docs = [ + Document(content="I love cheese", meta={"topic": "Cuisine"}), + Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + ] + + result = embedder.run(docs) + docs_with_embeddings = result["documents"] + + assert len(docs_with_embeddings) == len(docs) + for doc in docs_with_embeddings: + assert isinstance(doc.embedding, list) + assert isinstance(doc.embedding[0], float) diff --git a/integrations/vllm/tests/test_text_embedder.py b/integrations/vllm/tests/test_text_embedder.py new file mode 100644 index 0000000000..9ec515e572 --- /dev/null +++ b/integrations/vllm/tests/test_text_embedder.py @@ -0,0 +1,172 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 +import os +from unittest.mock import AsyncMock, MagicMock + +import pytest +from haystack.utils import Secret +from openai.types import CreateEmbeddingResponse, Embedding +from openai.types.create_embedding_response import Usage + +from haystack_integrations.components.embedders.vllm import VLLMTextEmbedder + +MODEL = "intfloat/e5-mistral-7b-instruct" + + +def _fake_response(embeddings: list[list[float]], prompt_tokens: int = 5, total_tokens: int = 5): + return CreateEmbeddingResponse( + object="list", + model="fake-model", + data=[Embedding(object="embedding", index=i, embedding=e) for i, e in enumerate(embeddings)], + usage=Usage(prompt_tokens=prompt_tokens, total_tokens=total_tokens), + ) + + +class TestVLLMTextEmbedder: + def test_init_default(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + + embedder = VLLMTextEmbedder(model=MODEL) + assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False) + assert embedder.api_base_url == "http://localhost:8000/v1" + assert embedder.model == MODEL + assert embedder.prefix == "" + assert embedder.suffix == "" + assert embedder.timeout is None + assert embedder.max_retries is None + assert embedder.http_client_kwargs is None + assert embedder.extra_parameters is None + assert embedder._client is None + assert embedder._async_client is None + assert embedder._is_warmed_up is False + + def test_init_with_parameters(self): + embedder = VLLMTextEmbedder( + model=MODEL, + api_key=Secret.from_token("test-api-key"), + api_base_url="http://my-vllm-server:8000/v1", + prefix="START", + suffix="END", + timeout=10.0, + max_retries=2, + http_client_kwargs={"proxy": "https://proxy.example.com"}, + extra_parameters={"dimensions": 32, "truncate_prompt_tokens": 256}, + ) + assert embedder.api_key == Secret.from_token("test-api-key") + assert embedder.api_base_url == "http://my-vllm-server:8000/v1" + assert embedder.model == MODEL + assert embedder.prefix == "START" + assert embedder.suffix == "END" + assert embedder.timeout == 10.0 + assert embedder.max_retries == 2 + assert embedder.http_client_kwargs == {"proxy": "https://proxy.example.com"} + assert embedder.extra_parameters == {"dimensions": 32, "truncate_prompt_tokens": 256} + + def test_warm_up(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + + embedder = VLLMTextEmbedder(model=MODEL) + embedder.warm_up() + + assert embedder._is_warmed_up is True + assert embedder._client is not None + assert embedder._async_client is not None + + # idempotent: calling again does not recreate clients + client_before = embedder._client + embedder.warm_up() + assert embedder._client is client_before + + def test_to_dict(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + + component_dict = VLLMTextEmbedder(model=MODEL).to_dict() + assert component_dict == { + "type": "haystack_integrations.components.embedders.vllm.text_embedder.VLLMTextEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["VLLM_API_KEY"], "strict": False, "type": "env_var"}, + "model": MODEL, + "api_base_url": "http://localhost:8000/v1", + "prefix": "", + "suffix": "", + "timeout": None, + "max_retries": None, + "http_client_kwargs": None, + "extra_parameters": None, + }, + } + + def test_from_dict(self, monkeypatch): + monkeypatch.delenv("VLLM_API_KEY", raising=False) + data = { + "type": "haystack_integrations.components.embedders.vllm.text_embedder.VLLMTextEmbedder", + "init_parameters": { + "api_key": {"env_vars": ["VLLM_API_KEY"], "strict": False, "type": "env_var"}, + "model": MODEL, + "api_base_url": "http://localhost:8000/v1", + "prefix": "", + "suffix": "", + "timeout": None, + "max_retries": None, + "http_client_kwargs": None, + "extra_parameters": {"dimensions": 32}, + }, + } + embedder = VLLMTextEmbedder.from_dict(data) + assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False) + assert embedder.model == MODEL + assert embedder.api_base_url == "http://localhost:8000/v1" + assert embedder.extra_parameters == {"dimensions": 32} + + def test_prepare_input_adds_extra_body(self): + embedder = VLLMTextEmbedder(model=MODEL, prefix="[", suffix="]", extra_parameters={"dimensions": 32}) + kwargs = embedder._prepare_input("hello") + assert kwargs == { + "model": MODEL, + "input": "[hello]", + "encoding_format": "float", + "extra_body": {"dimensions": 32}, + } + + def test_run_wrong_input_format(self): + embedder = VLLMTextEmbedder(model=MODEL) + with pytest.raises(TypeError, match=r"VLLMTextEmbedder expects a string as an input\."): + embedder.run(text=["text_1", "text_2"]) + + def test_run_with_mock(self): + embedder = VLLMTextEmbedder(model=MODEL, prefix="[", suffix="]", extra_parameters={"dimensions": 2}) + embedder._client = MagicMock() + embedder._client.embeddings.create.return_value = _fake_response([[0.1, 0.2]]) + embedder._is_warmed_up = True + + result = embedder.run("hello") + + call_kwargs = embedder._client.embeddings.create.call_args.kwargs + assert call_kwargs["input"] == "[hello]" + assert call_kwargs["extra_body"] == {"dimensions": 2} + assert result == { + "embedding": [0.1, 0.2], + "meta": {"model": "fake-model", "usage": {"prompt_tokens": 5, "total_tokens": 5}}, + } + + @pytest.mark.asyncio + async def test_run_async(self): + embedder = VLLMTextEmbedder(model=MODEL) + embedder._async_client = MagicMock() + embedder._async_client.embeddings.create = AsyncMock(return_value=_fake_response([[0.3, 0.4]])) + embedder._is_warmed_up = True + + result = await embedder.run_async("world") + assert result["embedding"] == [0.3, 0.4] + + @pytest.mark.skipif( + not os.environ.get("VLLM_API_BASE_URL", None), + reason="Export VLLM_API_BASE_URL pointing to a running vLLM embedding server to run this test.", + ) + @pytest.mark.integration + def test_run(self): + embedder = VLLMTextEmbedder(model=MODEL, api_base_url=os.environ["VLLM_API_BASE_URL"]) + result = embedder.run("The food was delicious") + assert isinstance(result["embedding"], list) + assert all(isinstance(x, float) for x in result["embedding"]) diff --git a/integrations/vllm/tests/test_utils.py b/integrations/vllm/tests/test_utils.py new file mode 100644 index 0000000000..62efd9634c --- /dev/null +++ b/integrations/vllm/tests/test_utils.py @@ -0,0 +1,31 @@ +# SPDX-FileCopyrightText: 2022-present deepset GmbH +# +# SPDX-License-Identifier: Apache-2.0 + +from haystack.utils import Secret + +from haystack_integrations.common.vllm.utils import _create_openai_clients + + +def test_create_openai_clients_placeholder_when_no_key(): + """When api_key is None or unresolved, a placeholder is used.""" + sync_client, async_client = _create_openai_clients( + api_key=None, api_base_url="http://localhost:8000/v1", timeout=None, max_retries=None, http_client_kwargs=None + ) + assert sync_client.api_key == "placeholder-api-key" + assert async_client.api_key == "placeholder-api-key" + assert str(sync_client.base_url) == "http://localhost:8000/v1/" + + +def test_create_openai_clients_uses_resolved_key_and_forwards_options(): + """When api_key resolves, it's used; timeout/max_retries forwarded only when set.""" + sync_client, _ = _create_openai_clients( + api_key=Secret.from_token("real-key"), + api_base_url="http://vllm:8000/v1", + timeout=12.5, + max_retries=7, + http_client_kwargs=None, + ) + assert sync_client.api_key == "real-key" + assert sync_client.timeout == 12.5 + assert sync_client.max_retries == 7 From bb20fc7c6f96a755fe1e64097d18186200a18201 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Tue, 14 Apr 2026 16:52:45 +0200 Subject: [PATCH 2/7] improvements --- integrations/vllm/pyproject.toml | 2 +- .../common/vllm/utils.py | 7 ++-- .../embedders/vllm/document_embedder.py | 35 ++++++++++++------- .../embedders/vllm/text_embedder.py | 22 ++++++++---- .../vllm/tests/test_document_embedder.py | 16 ++++++--- integrations/vllm/tests/test_text_embedder.py | 22 ++++++++---- 6 files changed, 67 insertions(+), 37 deletions(-) diff --git a/integrations/vllm/pyproject.toml b/integrations/vllm/pyproject.toml index d3c5e75a1a..93f7e41c67 100644 --- a/integrations/vllm/pyproject.toml +++ b/integrations/vllm/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai>=2.23.0", "openai"] +dependencies = ["haystack-ai>=2.23.0", "openai", "more_itertools", "tqdm"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/vllm#readme" diff --git a/integrations/vllm/src/haystack_integrations/common/vllm/utils.py b/integrations/vllm/src/haystack_integrations/common/vllm/utils.py index 92615f5906..becf3cf870 100644 --- a/integrations/vllm/src/haystack_integrations/common/vllm/utils.py +++ b/integrations/vllm/src/haystack_integrations/common/vllm/utils.py @@ -19,10 +19,9 @@ def _create_openai_clients( """ Build sync and async OpenAI clients pointing at a vLLM server. - A placeholder api key is used when the user did not supply one and no `VLLM_API_KEY` env var is - set, because the OpenAI client requires a non-empty value. `timeout` and `max_retries` are only - forwarded when provided: when None, the OpenAI client's own defaults apply and no `OPENAI_*` - env vars are read. + A placeholder api key is used when the user did not supply one and no `VLLM_API_KEY` env var is set, because the + OpenAI client requires a non-empty value. + `timeout` and `max_retries` are only forwarded when provided: when None, the OpenAI client's own defaults apply. """ resolved_api_key = "placeholder-api-key" if api_key is not None and (value := api_key.resolve_value()): diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py index 882a8cd597..a1483ed3b8 100644 --- a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py +++ b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py @@ -31,7 +31,7 @@ class VLLMDocumentEmbedder: Before using this component, start a vLLM server with an embedding model: ```bash - vllm serve intfloat/e5-mistral-7b-instruct + vllm serve google/embeddinggemma-300m ``` For details on server options, see the [vLLM CLI docs](https://docs.vllm.ai/en/stable/cli/serve/). @@ -44,7 +44,7 @@ class VLLMDocumentEmbedder: doc = Document(content="I love pizza!") - document_embedder = VLLMDocumentEmbedder(model="intfloat/e5-mistral-7b-instruct") + document_embedder = VLLMDocumentEmbedder(model="google/embeddinggemma-300m") result = document_embedder.run([doc]) print(result["documents"][0].embedding) @@ -57,8 +57,8 @@ class VLLMDocumentEmbedder: ```python document_embedder = VLLMDocumentEmbedder( - model="jinaai/jina-embeddings-v3", - extra_parameters={"dimensions": 32, "truncate_prompt_tokens": 256}, + model="google/embeddinggemma-300m", + extra_parameters={"truncate_prompt_tokens": 256, "truncation_side": "right"}, ) ``` """ @@ -71,6 +71,7 @@ def __init__( api_base_url: str = "http://localhost:8000/v1", prefix: str = "", suffix: str = "", + dimensions: int | None = None, batch_size: int = 32, progress_bar: bool = True, meta_fields_to_embed: list[str] | None = None, @@ -84,16 +85,21 @@ def __init__( """ Creates an instance of VLLMDocumentEmbedder. - :param model: The name of the model served by vLLM (e.g., "intfloat/e5-mistral-7b-instruct"). + :param model: The name of the model served by vLLM. Check + [vLLM's documentation](https://docs.vllm.ai/en/stable/models/pooling_models) for more information. :param api_key: The vLLM API key. Defaults to the `VLLM_API_KEY` environment variable. Only required if the vLLM server was started with `--api-key`. :param api_base_url: The base URL of the vLLM server. :param prefix: A string to add at the beginning of each text. :param suffix: A string to add at the end of each text. - :param batch_size: Number of Documents to encode at once. - :param progress_bar: Whether to show a progress bar. Disable in production to keep logs clean. - :param meta_fields_to_embed: List of meta fields to embed along with the Document text. - :param embedding_separator: Separator used to concatenate the meta fields to the Document text. + :param dimensions: The number of dimensions of the resulting embedding. Only models trained with + Matryoshka Representation Learning support this parameter. See + [vLLMs documentation](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#matryoshka-embeddings) + for more information. + :param batch_size: Number of documents to encode at once. + :param progress_bar: Whether to show a progress bar. + :param meta_fields_to_embed: List of meta fields to embed along with the document text. + :param embedding_separator: Separator used to concatenate the meta fields to the document text. :param timeout: Timeout in seconds for vLLM client calls. If not set, the OpenAI client default applies. :param max_retries: Maximum number of retries for failed requests. If not set, the OpenAI client default applies. @@ -104,15 +110,15 @@ def __init__( the component logs the error and continues processing the remaining documents. :param extra_parameters: Additional parameters forwarded as `extra_body` to the vLLM embeddings endpoint. Use this to pass parameters not part of the standard OpenAI Embeddings API, such as - `dimensions` (for Matryoshka models), `truncate_prompt_tokens`, `truncation_side`, - `additional_data`, `use_activation`, etc. See the - [vLLM Embeddings API docs](https://docs.vllm.ai/en/stable/models/pooling_models.html#openai-compatible-embeddings-api). + `truncate_prompt_tokens`, `truncation_side`, etc. See the + [vLLM Embeddings API docs](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#openai-compatible-embeddings-api). """ self.model = model self.api_key = api_key self.api_base_url = api_base_url self.prefix = prefix self.suffix = suffix + self.dimensions = dimensions self.batch_size = batch_size self.progress_bar = progress_bar self.meta_fields_to_embed = meta_fields_to_embed or [] @@ -149,10 +155,11 @@ def to_dict(self) -> dict[str, Any]: return default_to_dict( self, model=self.model, - api_key=self.api_key.to_dict() if self.api_key else None, + api_key=self.api_key, api_base_url=self.api_base_url, prefix=self.prefix, suffix=self.suffix, + dimensions=self.dimensions, batch_size=self.batch_size, progress_bar=self.progress_bar, meta_fields_to_embed=self.meta_fields_to_embed, @@ -183,6 +190,8 @@ def _prepare_texts_to_embed(self, documents: list[Document]) -> dict[str, str]: def _prepare_input(self, inputs: list[str]) -> dict[str, Any]: kwargs: dict[str, Any] = {"model": self.model, "input": inputs, "encoding_format": "float"} + if self.dimensions is not None: + kwargs["dimensions"] = self.dimensions if self.extra_parameters: kwargs["extra_body"] = self.extra_parameters return kwargs diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py index 090e1708cb..c624967ede 100644 --- a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py +++ b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py @@ -25,7 +25,7 @@ class VLLMTextEmbedder: Before using this component, start a vLLM server with an embedding model: ```bash - vllm serve intfloat/e5-mistral-7b-instruct + vllm serve google/embeddinggemma-300m ``` For details on server options, see the [vLLM CLI docs](https://docs.vllm.ai/en/stable/cli/serve/). @@ -35,7 +35,7 @@ class VLLMTextEmbedder: ```python from haystack_integrations.components.embedders.vllm import VLLMTextEmbedder - text_embedder = VLLMTextEmbedder(model="intfloat/e5-mistral-7b-instruct") + text_embedder = VLLMTextEmbedder(model="google/embeddinggemma-300m") print(text_embedder.run("I love pizza!")) ``` @@ -46,8 +46,8 @@ class VLLMTextEmbedder: ```python text_embedder = VLLMTextEmbedder( - model="jinaai/jina-embeddings-v3", - extra_parameters={"dimensions": 32, "truncate_prompt_tokens": 256}, + model="google/embeddinggemma-300m", + extra_parameters={"truncate_prompt_tokens": 256, "truncation_side": "right"}, ) ``` """ @@ -60,6 +60,7 @@ def __init__( api_base_url: str = "http://localhost:8000/v1", prefix: str = "", suffix: str = "", + dimensions: int | None = None, timeout: float | None = None, max_retries: int | None = None, http_client_kwargs: dict[str, Any] | None = None, @@ -74,6 +75,10 @@ def __init__( :param api_base_url: The base URL of the vLLM server. :param prefix: A string to add at the beginning of each text to embed. :param suffix: A string to add at the end of each text to embed. + :param dimensions: The number of dimensions of the resulting embedding. Only models trained with + Matryoshka Representation Learning support this parameter. See + [vLLMs documentation](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#matryoshka-embeddings) + for more information. :param timeout: Timeout in seconds for vLLM client calls. If not set, the OpenAI client default applies. :param max_retries: Maximum number of retries for failed requests. If not set, the OpenAI client default applies. @@ -82,15 +87,15 @@ def __init__( [HTTPX documentation](https://www.python-httpx.org/api/#client). :param extra_parameters: Additional parameters forwarded as `extra_body` to the vLLM embeddings endpoint. Use this to pass parameters not part of the standard OpenAI Embeddings API, such as - `dimensions` (for Matryoshka models), `truncate_prompt_tokens`, `truncation_side`, - `additional_data`, `use_activation`, etc. See the - [vLLM Embeddings API docs](https://docs.vllm.ai/en/stable/models/pooling_models.html#openai-compatible-embeddings-api). + `truncate_prompt_tokens`, `truncation_side`, `additional_data`, `use_activation`, etc. See the + [vLLM Embeddings API docs](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#openai-compatible-embeddings-api). """ self.model = model self.api_key = api_key self.api_base_url = api_base_url self.prefix = prefix self.suffix = suffix + self.dimensions = dimensions self.timeout = timeout self.max_retries = max_retries self.http_client_kwargs = http_client_kwargs @@ -126,6 +131,7 @@ def to_dict(self) -> dict[str, Any]: api_base_url=self.api_base_url, prefix=self.prefix, suffix=self.suffix, + dimensions=self.dimensions, timeout=self.timeout, max_retries=self.max_retries, http_client_kwargs=self.http_client_kwargs, @@ -150,6 +156,8 @@ def _prepare_input(self, text: str) -> dict[str, Any]: "input": self.prefix + text + self.suffix, "encoding_format": "float", } + if self.dimensions is not None: + kwargs["dimensions"] = self.dimensions if self.extra_parameters: kwargs["extra_body"] = self.extra_parameters return kwargs diff --git a/integrations/vllm/tests/test_document_embedder.py b/integrations/vllm/tests/test_document_embedder.py index d8196d31e7..c41fd512e4 100644 --- a/integrations/vllm/tests/test_document_embedder.py +++ b/integrations/vllm/tests/test_document_embedder.py @@ -39,6 +39,7 @@ def test_init_default(self, monkeypatch): assert embedder.api_base_url == "http://localhost:8000/v1" assert embedder.prefix == "" assert embedder.suffix == "" + assert embedder.dimensions is None assert embedder.batch_size == 32 assert embedder.progress_bar is True assert embedder.meta_fields_to_embed == [] @@ -56,6 +57,7 @@ def test_init_with_parameters(self): api_base_url="http://my-vllm-server:8000/v1", prefix="START", suffix="END", + dimensions=64, batch_size=64, progress_bar=False, meta_fields_to_embed=["test_field"], @@ -67,6 +69,7 @@ def test_init_with_parameters(self): assert embedder.api_base_url == "http://my-vllm-server:8000/v1" assert embedder.prefix == "START" assert embedder.suffix == "END" + assert embedder.dimensions == 64 assert embedder.batch_size == 64 assert embedder.progress_bar is False assert embedder.meta_fields_to_embed == ["test_field"] @@ -101,6 +104,7 @@ def test_to_dict(self, monkeypatch): "api_base_url": "http://localhost:8000/v1", "prefix": "", "suffix": "", + "dimensions": None, "batch_size": 32, "progress_bar": True, "meta_fields_to_embed": [], @@ -123,6 +127,7 @@ def test_from_dict(self, monkeypatch): "api_base_url": "http://localhost:8000/v1", "prefix": "", "suffix": "", + "dimensions": 32, "batch_size": 32, "progress_bar": True, "meta_fields_to_embed": [], @@ -131,7 +136,7 @@ def test_from_dict(self, monkeypatch): "max_retries": None, "http_client_kwargs": None, "raise_on_failure": False, - "extra_parameters": {"dimensions": 32}, + "extra_parameters": None, }, } embedder = VLLMDocumentEmbedder.from_dict(data) @@ -139,7 +144,7 @@ def test_from_dict(self, monkeypatch): assert embedder.model == MODEL assert embedder.api_base_url == "http://localhost:8000/v1" assert embedder.batch_size == 32 - assert embedder.extra_parameters == {"dimensions": 32} + assert embedder.dimensions == 32 def test_prepare_texts_to_embed(self): embedder = VLLMDocumentEmbedder( @@ -149,14 +154,15 @@ def test_prepare_texts_to_embed(self): texts = embedder._prepare_texts_to_embed([doc]) assert texts == {doc.id: "[ML | hello]"} - def test_prepare_input_adds_extra_body(self): - embedder = VLLMDocumentEmbedder(model=MODEL, extra_parameters={"dimensions": 32}) + def test_prepare_input_adds_dimensions_and_extra_body(self): + embedder = VLLMDocumentEmbedder(model=MODEL, dimensions=32, extra_parameters={"truncate_prompt_tokens": 256}) kwargs = embedder._prepare_input(["a", "b"]) assert kwargs == { "model": MODEL, "input": ["a", "b"], "encoding_format": "float", - "extra_body": {"dimensions": 32}, + "dimensions": 32, + "extra_body": {"truncate_prompt_tokens": 256}, } def test_run_wrong_input_format(self): diff --git a/integrations/vllm/tests/test_text_embedder.py b/integrations/vllm/tests/test_text_embedder.py index 9ec515e572..05f83f4b79 100644 --- a/integrations/vllm/tests/test_text_embedder.py +++ b/integrations/vllm/tests/test_text_embedder.py @@ -33,6 +33,7 @@ def test_init_default(self, monkeypatch): assert embedder.model == MODEL assert embedder.prefix == "" assert embedder.suffix == "" + assert embedder.dimensions is None assert embedder.timeout is None assert embedder.max_retries is None assert embedder.http_client_kwargs is None @@ -48,6 +49,7 @@ def test_init_with_parameters(self): api_base_url="http://my-vllm-server:8000/v1", prefix="START", suffix="END", + dimensions=64, timeout=10.0, max_retries=2, http_client_kwargs={"proxy": "https://proxy.example.com"}, @@ -58,6 +60,7 @@ def test_init_with_parameters(self): assert embedder.model == MODEL assert embedder.prefix == "START" assert embedder.suffix == "END" + assert embedder.dimensions == 64 assert embedder.timeout == 10.0 assert embedder.max_retries == 2 assert embedder.http_client_kwargs == {"proxy": "https://proxy.example.com"} @@ -90,6 +93,7 @@ def test_to_dict(self, monkeypatch): "api_base_url": "http://localhost:8000/v1", "prefix": "", "suffix": "", + "dimensions": None, "timeout": None, "max_retries": None, "http_client_kwargs": None, @@ -107,26 +111,30 @@ def test_from_dict(self, monkeypatch): "api_base_url": "http://localhost:8000/v1", "prefix": "", "suffix": "", + "dimensions": 32, "timeout": None, "max_retries": None, "http_client_kwargs": None, - "extra_parameters": {"dimensions": 32}, + "extra_parameters": None, }, } embedder = VLLMTextEmbedder.from_dict(data) assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False) assert embedder.model == MODEL assert embedder.api_base_url == "http://localhost:8000/v1" - assert embedder.extra_parameters == {"dimensions": 32} + assert embedder.dimensions == 32 - def test_prepare_input_adds_extra_body(self): - embedder = VLLMTextEmbedder(model=MODEL, prefix="[", suffix="]", extra_parameters={"dimensions": 32}) + def test_prepare_input_adds_dimensions_and_extra_body(self): + embedder = VLLMTextEmbedder( + model=MODEL, prefix="[", suffix="]", dimensions=32, extra_parameters={"truncate_prompt_tokens": 256} + ) kwargs = embedder._prepare_input("hello") assert kwargs == { "model": MODEL, "input": "[hello]", "encoding_format": "float", - "extra_body": {"dimensions": 32}, + "dimensions": 32, + "extra_body": {"truncate_prompt_tokens": 256}, } def test_run_wrong_input_format(self): @@ -135,7 +143,7 @@ def test_run_wrong_input_format(self): embedder.run(text=["text_1", "text_2"]) def test_run_with_mock(self): - embedder = VLLMTextEmbedder(model=MODEL, prefix="[", suffix="]", extra_parameters={"dimensions": 2}) + embedder = VLLMTextEmbedder(model=MODEL, prefix="[", suffix="]", dimensions=2) embedder._client = MagicMock() embedder._client.embeddings.create.return_value = _fake_response([[0.1, 0.2]]) embedder._is_warmed_up = True @@ -144,7 +152,7 @@ def test_run_with_mock(self): call_kwargs = embedder._client.embeddings.create.call_args.kwargs assert call_kwargs["input"] == "[hello]" - assert call_kwargs["extra_body"] == {"dimensions": 2} + assert call_kwargs["dimensions"] == 2 assert result == { "embedding": [0.1, 0.2], "meta": {"model": "fake-model", "usage": {"prompt_tokens": 5, "total_tokens": 5}}, From 5207261c87aa9d180b89cfa2a2aa17aaa9b914a4 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Tue, 14 Apr 2026 19:16:45 +0200 Subject: [PATCH 3/7] integration tests on the ci --- integrations/vllm/README.md | 10 ++++++--- .../embedders/vllm/document_embedder.py | 4 ++-- .../embedders/vllm/text_embedder.py | 2 +- .../vllm/tests/test_document_embedder.py | 22 ++++++++++++------- integrations/vllm/tests/test_text_embedder.py | 16 ++++++++------ integrations/vllm/tests/test_utils.py | 2 -- 6 files changed, 33 insertions(+), 23 deletions(-) diff --git a/integrations/vllm/README.md b/integrations/vllm/README.md index 3d987bbefb..5f86dc33a1 100644 --- a/integrations/vllm/README.md +++ b/integrations/vllm/README.md @@ -11,10 +11,14 @@ Refer to the general [Contribution Guidelines](https://github.com/deepset-ai/haystack-core-integrations/blob/main/CONTRIBUTING.md). -To run integration tests locally, you need to have a running vLLM server. Refer to the [workflow file](https://github.com/deepset-ai/haystack-core-integrations/blob/main/.github/workflows/vllm.yml) for more details. +To run integration tests locally, you need two vLLM servers running in parallel: one for the chat generator on port `8000` and one for the embedders on port `8001`. Refer to the [workflow file](https://github.com/deepset-ai/haystack-core-integrations/blob/main/.github/workflows/vllm.yml) for more details. -For example, on macOs, you can install [vLLM-metal](https://github.com/vllm-project/vllm-metal) and run the server with: +For example, on macOs, you can install [vLLM-metal](https://github.com/vllm-project/vllm-metal) and start both servers with: ```bash -source ~/.venv-vllm-metal/bin/activate && vllm serve Qwen/Qwen3-0.6B --reasoning-parser qwen3 --max-model-len 1024 --enforce-eager --enable-auto-tool-choice --tool-call-parser hermes +# chat generator server (port 8000) +source ~/.venv-vllm-metal/bin/activate && vllm serve Qwen/Qwen3-0.6B --reasoning-parser qwen3 --max-model-len 1024 --enforce-eager --enable-auto-tool-choice --tool-call-parser hermes + +# embedders server (port 8001) +source ~/.venv-vllm-metal/bin/activate && vllm serve sergeyzh/rubert-tiny-turbo --port 8001 --enforce-eager --max-num-seqs 1 ``` \ No newline at end of file diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py index a1483ed3b8..686eecb0a6 100644 --- a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py +++ b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py @@ -86,7 +86,7 @@ def __init__( Creates an instance of VLLMDocumentEmbedder. :param model: The name of the model served by vLLM. Check - [vLLM's documentation](https://docs.vllm.ai/en/stable/models/pooling_models) for more information. + [vLLM documentation](https://docs.vllm.ai/en/stable/models/pooling_models) for more information. :param api_key: The vLLM API key. Defaults to the `VLLM_API_KEY` environment variable. Only required if the vLLM server was started with `--api-key`. :param api_base_url: The base URL of the vLLM server. @@ -94,7 +94,7 @@ def __init__( :param suffix: A string to add at the end of each text. :param dimensions: The number of dimensions of the resulting embedding. Only models trained with Matryoshka Representation Learning support this parameter. See - [vLLMs documentation](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#matryoshka-embeddings) + [vLLM documentation](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#matryoshka-embeddings) for more information. :param batch_size: Number of documents to encode at once. :param progress_bar: Whether to show a progress bar. diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py index c624967ede..dca1c9a4fd 100644 --- a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py +++ b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py @@ -77,7 +77,7 @@ def __init__( :param suffix: A string to add at the end of each text to embed. :param dimensions: The number of dimensions of the resulting embedding. Only models trained with Matryoshka Representation Learning support this parameter. See - [vLLMs documentation](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#matryoshka-embeddings) + [vLLM documentation](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#matryoshka-embeddings) for more information. :param timeout: Timeout in seconds for vLLM client calls. If not set, the OpenAI client default applies. :param max_retries: Maximum number of retries for failed requests. If not set, the OpenAI client diff --git a/integrations/vllm/tests/test_document_embedder.py b/integrations/vllm/tests/test_document_embedder.py index c41fd512e4..e9f60c11e1 100644 --- a/integrations/vllm/tests/test_document_embedder.py +++ b/integrations/vllm/tests/test_document_embedder.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -import os from unittest.mock import AsyncMock, MagicMock import pytest @@ -13,7 +12,8 @@ from haystack_integrations.components.embedders.vllm import VLLMDocumentEmbedder -MODEL = "intfloat/e5-mistral-7b-instruct" +MODEL = "sergeyzh/rubert-tiny-turbo" +API_BASE_URL = "http://localhost:8001/v1" def _fake_response(embeddings: list[list[float]], prompt_tokens: int = 1, total_tokens: int = 1): @@ -143,8 +143,18 @@ def test_from_dict(self, monkeypatch): assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False) assert embedder.model == MODEL assert embedder.api_base_url == "http://localhost:8000/v1" - assert embedder.batch_size == 32 + assert embedder.prefix == "" + assert embedder.suffix == "" assert embedder.dimensions == 32 + assert embedder.batch_size == 32 + assert embedder.progress_bar is True + assert embedder.meta_fields_to_embed == [] + assert embedder.embedding_separator == "\n" + assert embedder.timeout is None + assert embedder.max_retries is None + assert embedder.http_client_kwargs is None + assert embedder.raise_on_failure is False + assert embedder.extra_parameters is None def test_prepare_texts_to_embed(self): embedder = VLLMDocumentEmbedder( @@ -224,13 +234,9 @@ async def test_run_async(self): assert [d.embedding for d in result["documents"]] == [[0.5], [0.6]] - @pytest.mark.skipif( - not os.environ.get("VLLM_API_BASE_URL", None), - reason="Export VLLM_API_BASE_URL pointing to a running vLLM embedding server to run this test.", - ) @pytest.mark.integration def test_run(self): - embedder = VLLMDocumentEmbedder(model=MODEL, api_base_url=os.environ["VLLM_API_BASE_URL"]) + embedder = VLLMDocumentEmbedder(model=MODEL, api_base_url=API_BASE_URL) docs = [ Document(content="I love cheese", meta={"topic": "Cuisine"}), diff --git a/integrations/vllm/tests/test_text_embedder.py b/integrations/vllm/tests/test_text_embedder.py index 05f83f4b79..1713d9c47f 100644 --- a/integrations/vllm/tests/test_text_embedder.py +++ b/integrations/vllm/tests/test_text_embedder.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: 2022-present deepset GmbH # # SPDX-License-Identifier: Apache-2.0 -import os from unittest.mock import AsyncMock, MagicMock import pytest @@ -11,7 +10,8 @@ from haystack_integrations.components.embedders.vllm import VLLMTextEmbedder -MODEL = "intfloat/e5-mistral-7b-instruct" +MODEL = "sergeyzh/rubert-tiny-turbo" +API_BASE_URL = "http://localhost:8001/v1" def _fake_response(embeddings: list[list[float]], prompt_tokens: int = 5, total_tokens: int = 5): @@ -122,7 +122,13 @@ def test_from_dict(self, monkeypatch): assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False) assert embedder.model == MODEL assert embedder.api_base_url == "http://localhost:8000/v1" + assert embedder.prefix == "" + assert embedder.suffix == "" assert embedder.dimensions == 32 + assert embedder.timeout is None + assert embedder.max_retries is None + assert embedder.http_client_kwargs is None + assert embedder.extra_parameters is None def test_prepare_input_adds_dimensions_and_extra_body(self): embedder = VLLMTextEmbedder( @@ -168,13 +174,9 @@ async def test_run_async(self): result = await embedder.run_async("world") assert result["embedding"] == [0.3, 0.4] - @pytest.mark.skipif( - not os.environ.get("VLLM_API_BASE_URL", None), - reason="Export VLLM_API_BASE_URL pointing to a running vLLM embedding server to run this test.", - ) @pytest.mark.integration def test_run(self): - embedder = VLLMTextEmbedder(model=MODEL, api_base_url=os.environ["VLLM_API_BASE_URL"]) + embedder = VLLMTextEmbedder(model=MODEL, api_base_url=API_BASE_URL) result = embedder.run("The food was delicious") assert isinstance(result["embedding"], list) assert all(isinstance(x, float) for x in result["embedding"]) diff --git a/integrations/vllm/tests/test_utils.py b/integrations/vllm/tests/test_utils.py index 62efd9634c..b017762955 100644 --- a/integrations/vllm/tests/test_utils.py +++ b/integrations/vllm/tests/test_utils.py @@ -8,7 +8,6 @@ def test_create_openai_clients_placeholder_when_no_key(): - """When api_key is None or unresolved, a placeholder is used.""" sync_client, async_client = _create_openai_clients( api_key=None, api_base_url="http://localhost:8000/v1", timeout=None, max_retries=None, http_client_kwargs=None ) @@ -18,7 +17,6 @@ def test_create_openai_clients_placeholder_when_no_key(): def test_create_openai_clients_uses_resolved_key_and_forwards_options(): - """When api_key resolves, it's used; timeout/max_retries forwarded only when set.""" sync_client, _ = _create_openai_clients( api_key=Secret.from_token("real-key"), api_base_url="http://vllm:8000/v1", From 5d62dd66996b8485719ce5fafaad60b1c7fbb69b Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 15 Apr 2026 09:02:57 +0200 Subject: [PATCH 4/7] fixes --- .github/workflows/vllm.yml | 37 ++++++++++++++++--- integrations/vllm/README.md | 9 ++++- .../vllm/tests/test_document_embedder.py | 36 ++++++++++++++++-- integrations/vllm/tests/test_text_embedder.py | 12 +++++- 4 files changed, 81 insertions(+), 13 deletions(-) diff --git a/.github/workflows/vllm.yml b/.github/workflows/vllm.yml index 3dad3d8363..4c021b1473 100644 --- a/.github/workflows/vllm.yml +++ b/.github/workflows/vllm.yml @@ -30,6 +30,7 @@ env: PYTHONUNBUFFERED: "1" FORCE_COLOR: "1" VLLM_MODEL: "Qwen/Qwen3-0.6B" + VLLM_EMBEDDING_MODEL: "sentence-transformers/all-MiniLM-L6-v2" # we only test on Ubuntu to keep vLLM server running simple TEST_MATRIX_OS: '["ubuntu-latest"]' # vLLM is not compatible with Python 3.14. https://github.com/vllm-project/vllm/issues/34096 @@ -88,12 +89,13 @@ jobs: "https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cpu-cp38-abi3-manylinux_2_35_x86_64.whl" \ --torch-backend cpu - - name: Start vLLM server + - name: Start vLLM chat server env: VLLM_TARGET_DEVICE: "cpu" VLLM_CPU_KVCACHE_SPACE: "4" run: | nohup hatch run -- vllm serve ${{ env.VLLM_MODEL }} \ + --port 8000 \ --reasoning-parser qwen3 \ --max-model-len 1024 \ --enforce-eager \ @@ -102,20 +104,45 @@ jobs: --tool-call-parser hermes \ --max-num-seqs 1 & - # Wait for the vLLM server to be ready with a timeout of 300 seconds + # Wait for the vLLM chat server to be ready with a timeout of 300 seconds timeout=300 while [ $timeout -gt 0 ] && ! curl -sSf http://localhost:8000/health > /dev/null 2>&1; do - echo "Waiting for vLLM server to start..." + echo "Waiting for vLLM chat server to start..." sleep 10 ((timeout-=10)) done if [ $timeout -eq 0 ]; then - echo "Timed out waiting for vLLM server to start." + echo "Timed out waiting for vLLM chat server to start." exit 1 fi - echo "vLLM server started successfully." + echo "vLLM chat server started successfully." + + - name: Start vLLM embedding server + env: + VLLM_TARGET_DEVICE: "cpu" + VLLM_CPU_KVCACHE_SPACE: "4" + run: | + nohup hatch run -- vllm serve ${{ env.VLLM_EMBEDDING_MODEL }} \ + --port 8001 \ + --enforce-eager \ + --max-num-seqs 1 & + + # Wait for the vLLM embedding server to be ready with a timeout of 300 seconds + timeout=300 + while [ $timeout -gt 0 ] && ! curl -sSf http://localhost:8001/health > /dev/null 2>&1; do + echo "Waiting for vLLM embedding server to start..." + sleep 10 + ((timeout-=10)) + done + + if [ $timeout -eq 0 ]; then + echo "Timed out waiting for vLLM embedding server to start." + exit 1 + fi + + echo "vLLM embedding server started successfully." - name: Lint if: matrix.python-version == '3.10' && runner.os == 'Linux' diff --git a/integrations/vllm/README.md b/integrations/vllm/README.md index 5f86dc33a1..a0a498cf8e 100644 --- a/integrations/vllm/README.md +++ b/integrations/vllm/README.md @@ -13,12 +13,17 @@ Refer to the general [Contribution Guidelines](https://github.com/deepset-ai/hay To run integration tests locally, you need two vLLM servers running in parallel: one for the chat generator on port `8000` and one for the embedders on port `8001`. Refer to the [workflow file](https://github.com/deepset-ai/haystack-core-integrations/blob/main/.github/workflows/vllm.yml) for more details. -For example, on macOs, you can install [vLLM-metal](https://github.com/vllm-project/vllm-metal) and start both servers with: +For example, on macOs, you can install [vLLM-metal](https://github.com/vllm-project/vllm-metal) and start the chat generator server with: ```bash # chat generator server (port 8000) source ~/.venv-vllm-metal/bin/activate && vllm serve Qwen/Qwen3-0.6B --reasoning-parser qwen3 --max-model-len 1024 --enforce-eager --enable-auto-tool-choice --tool-call-parser hermes +``` +vLLM-metal does not support embedding models. On macOS, you can run the embedding server via CPU Docker image: + +```bash # embedders server (port 8001) -source ~/.venv-vllm-metal/bin/activate && vllm serve sergeyzh/rubert-tiny-turbo --port 8001 --enforce-eager --max-num-seqs 1 +docker run --rm -p 8001:8000 -e VLLM_CPU_OMP_THREADS_BIND=0-3 vllm/vllm-openai-cpu:latest \ + --model sentence-transformers/all-MiniLM-L6-v2 --enforce-eager ``` \ No newline at end of file diff --git a/integrations/vllm/tests/test_document_embedder.py b/integrations/vllm/tests/test_document_embedder.py index e9f60c11e1..6f7bcc1bdb 100644 --- a/integrations/vllm/tests/test_document_embedder.py +++ b/integrations/vllm/tests/test_document_embedder.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 from unittest.mock import AsyncMock, MagicMock +import numpy as np import pytest from haystack import Document from haystack.utils import Secret @@ -12,7 +13,7 @@ from haystack_integrations.components.embedders.vllm import VLLMDocumentEmbedder -MODEL = "sergeyzh/rubert-tiny-turbo" +MODEL = "sentence-transformers/all-MiniLM-L6-v2" API_BASE_URL = "http://localhost:8001/v1" @@ -235,12 +236,13 @@ async def test_run_async(self): assert [d.embedding for d in result["documents"]] == [[0.5], [0.6]] @pytest.mark.integration - def test_run(self): + def test_live_run(self): embedder = VLLMDocumentEmbedder(model=MODEL, api_base_url=API_BASE_URL) docs = [ - Document(content="I love cheese", meta={"topic": "Cuisine"}), - Document(content="A transformer is a deep learning architecture", meta={"topic": "ML"}), + Document(content="I love cheese"), + Document(content="Cheddar is my favorite food"), + Document(content="A transformer is a deep learning architecture"), ] result = embedder.run(docs) @@ -250,3 +252,29 @@ def test_run(self): for doc in docs_with_embeddings: assert isinstance(doc.embedding, list) assert isinstance(doc.embedding[0], float) + + embeddings = [np.array(d.embedding) for d in docs_with_embeddings] + + def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: + return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))) + + assert cosine_similarity(embeddings[0], embeddings[1]) > cosine_similarity(embeddings[0], embeddings[2]) + + @pytest.mark.integration + @pytest.mark.asyncio + async def test_live_run_async(self): + embedder = VLLMDocumentEmbedder(model=MODEL, api_base_url=API_BASE_URL) + + docs = [ + Document(content="I love cheese"), + Document(content="Cheddar is my favorite food"), + Document(content="A transformer is a deep learning architecture"), + ] + + result = await embedder.run_async(docs) + docs_with_embeddings = result["documents"] + + assert len(docs_with_embeddings) == len(docs) + for doc in docs_with_embeddings: + assert isinstance(doc.embedding, list) + assert isinstance(doc.embedding[0], float) diff --git a/integrations/vllm/tests/test_text_embedder.py b/integrations/vllm/tests/test_text_embedder.py index 1713d9c47f..af302083e4 100644 --- a/integrations/vllm/tests/test_text_embedder.py +++ b/integrations/vllm/tests/test_text_embedder.py @@ -10,7 +10,7 @@ from haystack_integrations.components.embedders.vllm import VLLMTextEmbedder -MODEL = "sergeyzh/rubert-tiny-turbo" +MODEL = "sentence-transformers/all-MiniLM-L6-v2" API_BASE_URL = "http://localhost:8001/v1" @@ -175,8 +175,16 @@ async def test_run_async(self): assert result["embedding"] == [0.3, 0.4] @pytest.mark.integration - def test_run(self): + def test_live_run(self): embedder = VLLMTextEmbedder(model=MODEL, api_base_url=API_BASE_URL) result = embedder.run("The food was delicious") assert isinstance(result["embedding"], list) assert all(isinstance(x, float) for x in result["embedding"]) + + @pytest.mark.asyncio + @pytest.mark.integration + async def test_live_run_async(self): + embedder = VLLMTextEmbedder(model=MODEL, api_base_url=API_BASE_URL) + result = await embedder.run_async("The food was delicious") + assert isinstance(result["embedding"], list) + assert all(isinstance(x, float) for x in result["embedding"]) From 78be48fdd094a7c211fde68c2e1328f9ae623621 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 15 Apr 2026 09:12:23 +0200 Subject: [PATCH 5/7] lower bound pin for more-itertools --- integrations/vllm/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/vllm/pyproject.toml b/integrations/vllm/pyproject.toml index 93f7e41c67..746841c4a8 100644 --- a/integrations/vllm/pyproject.toml +++ b/integrations/vllm/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai>=2.23.0", "openai", "more_itertools", "tqdm"] +dependencies = ["haystack-ai>=2.23.0", "openai", "more_itertools>=9.0.0", "tqdm"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/vllm#readme" From edc9fe1ace4350c9eb7d72bbffdcf17c2ea4071e Mon Sep 17 00:00:00 2001 From: anakin87 Date: Wed, 15 Apr 2026 09:20:31 +0200 Subject: [PATCH 6/7] more pins --- integrations/vllm/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/integrations/vllm/pyproject.toml b/integrations/vllm/pyproject.toml index 746841c4a8..91d1a9aaa7 100644 --- a/integrations/vllm/pyproject.toml +++ b/integrations/vllm/pyproject.toml @@ -22,7 +22,7 @@ classifiers = [ "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] -dependencies = ["haystack-ai>=2.23.0", "openai", "more_itertools>=9.0.0", "tqdm"] +dependencies = ["haystack-ai>=2.23.0", "openai", "more_itertools>=9.0.0", "tqdm>=4.48.0"] [project.urls] Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/vllm#readme" From 700f44d30e3588cc2e3c7588aac8751e7399fa9c Mon Sep 17 00:00:00 2001 From: anakin87 Date: Thu, 16 Apr 2026 12:16:54 +0200 Subject: [PATCH 7/7] rm serde methods --- .../embedders/vllm/document_embedder.py | 32 +------------------ .../embedders/vllm/text_embedder.py | 27 +--------------- .../vllm/tests/test_document_embedder.py | 5 +-- integrations/vllm/tests/test_text_embedder.py | 5 +-- 4 files changed, 8 insertions(+), 61 deletions(-) diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py index 686eecb0a6..16d5f69206 100644 --- a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py +++ b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py @@ -5,7 +5,7 @@ from dataclasses import replace from typing import Any -from haystack import Document, component, default_from_dict, default_to_dict, logging +from haystack import Document, component, logging from haystack.utils import Secret from more_itertools import batched from openai import APIError, AsyncOpenAI, OpenAI @@ -146,36 +146,6 @@ def warm_up(self) -> None: ) self._is_warmed_up = True - def to_dict(self) -> dict[str, Any]: - """ - Serialize this component to a dictionary. - - :returns: The serialized component as a dictionary. - """ - return default_to_dict( - self, - model=self.model, - api_key=self.api_key, - api_base_url=self.api_base_url, - prefix=self.prefix, - suffix=self.suffix, - dimensions=self.dimensions, - batch_size=self.batch_size, - progress_bar=self.progress_bar, - meta_fields_to_embed=self.meta_fields_to_embed, - embedding_separator=self.embedding_separator, - timeout=self.timeout, - max_retries=self.max_retries, - http_client_kwargs=self.http_client_kwargs, - raise_on_failure=self.raise_on_failure, - extra_parameters=self.extra_parameters, - ) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> "VLLMDocumentEmbedder": - """Deserialize this component from a dictionary.""" - return default_from_dict(cls, data) - def _prepare_texts_to_embed(self, documents: list[Document]) -> dict[str, str]: """Concatenate each Document's text with the selected meta fields.""" texts_to_embed = {} diff --git a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py index dca1c9a4fd..2749ea393e 100644 --- a/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py +++ b/integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py @@ -4,7 +4,7 @@ from typing import Any -from haystack import component, default_from_dict, default_to_dict +from haystack import component from haystack.utils import Secret from openai import AsyncOpenAI, OpenAI from openai.types import CreateEmbeddingResponse @@ -118,31 +118,6 @@ def warm_up(self) -> None: ) self._is_warmed_up = True - def to_dict(self) -> dict[str, Any]: - """ - Serialize this component to a dictionary. - - :returns: The serialized component as a dictionary. - """ - return default_to_dict( - self, - model=self.model, - api_key=self.api_key.to_dict() if self.api_key else None, - api_base_url=self.api_base_url, - prefix=self.prefix, - suffix=self.suffix, - dimensions=self.dimensions, - timeout=self.timeout, - max_retries=self.max_retries, - http_client_kwargs=self.http_client_kwargs, - extra_parameters=self.extra_parameters, - ) - - @classmethod - def from_dict(cls, data: dict[str, Any]) -> "VLLMTextEmbedder": - """Deserialize this component from a dictionary.""" - return default_from_dict(cls, data) - def _prepare_input(self, text: str) -> dict[str, Any]: if not isinstance(text, str): msg = ( diff --git a/integrations/vllm/tests/test_document_embedder.py b/integrations/vllm/tests/test_document_embedder.py index 6f7bcc1bdb..a2d2ea6e35 100644 --- a/integrations/vllm/tests/test_document_embedder.py +++ b/integrations/vllm/tests/test_document_embedder.py @@ -6,6 +6,7 @@ import numpy as np import pytest from haystack import Document +from haystack.core.serialization import component_from_dict, component_to_dict from haystack.utils import Secret from openai import APIError from openai.types import CreateEmbeddingResponse, Embedding @@ -96,7 +97,7 @@ def test_warm_up(self, monkeypatch): def test_to_dict(self, monkeypatch): monkeypatch.delenv("VLLM_API_KEY", raising=False) - component_dict = VLLMDocumentEmbedder(model=MODEL).to_dict() + component_dict = component_to_dict(VLLMDocumentEmbedder(model=MODEL), "embedder") assert component_dict == { "type": "haystack_integrations.components.embedders.vllm.document_embedder.VLLMDocumentEmbedder", "init_parameters": { @@ -140,7 +141,7 @@ def test_from_dict(self, monkeypatch): "extra_parameters": None, }, } - embedder = VLLMDocumentEmbedder.from_dict(data) + embedder = component_from_dict(VLLMDocumentEmbedder, data, "embedder") assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False) assert embedder.model == MODEL assert embedder.api_base_url == "http://localhost:8000/v1" diff --git a/integrations/vllm/tests/test_text_embedder.py b/integrations/vllm/tests/test_text_embedder.py index af302083e4..7ee2b42051 100644 --- a/integrations/vllm/tests/test_text_embedder.py +++ b/integrations/vllm/tests/test_text_embedder.py @@ -4,6 +4,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from haystack.core.serialization import component_from_dict, component_to_dict from haystack.utils import Secret from openai.types import CreateEmbeddingResponse, Embedding from openai.types.create_embedding_response import Usage @@ -84,7 +85,7 @@ def test_warm_up(self, monkeypatch): def test_to_dict(self, monkeypatch): monkeypatch.delenv("VLLM_API_KEY", raising=False) - component_dict = VLLMTextEmbedder(model=MODEL).to_dict() + component_dict = component_to_dict(VLLMTextEmbedder(model=MODEL), "embedder") assert component_dict == { "type": "haystack_integrations.components.embedders.vllm.text_embedder.VLLMTextEmbedder", "init_parameters": { @@ -118,7 +119,7 @@ def test_from_dict(self, monkeypatch): "extra_parameters": None, }, } - embedder = VLLMTextEmbedder.from_dict(data) + embedder = component_from_dict(VLLMTextEmbedder, data, "embedder") assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False) assert embedder.model == MODEL assert embedder.api_base_url == "http://localhost:8000/v1"