Skip to content

Commit bb20fc7

Browse files
committed
improvements
1 parent 9d0edd7 commit bb20fc7

6 files changed

Lines changed: 67 additions & 37 deletions

File tree

integrations/vllm/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ classifiers = [
2222
"Programming Language :: Python :: Implementation :: CPython",
2323
"Programming Language :: Python :: Implementation :: PyPy",
2424
]
25-
dependencies = ["haystack-ai>=2.23.0", "openai"]
25+
dependencies = ["haystack-ai>=2.23.0", "openai", "more_itertools", "tqdm"]
2626

2727
[project.urls]
2828
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/vllm#readme"

integrations/vllm/src/haystack_integrations/common/vllm/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,9 @@ def _create_openai_clients(
1919
"""
2020
Build sync and async OpenAI clients pointing at a vLLM server.
2121
22-
A placeholder api key is used when the user did not supply one and no `VLLM_API_KEY` env var is
23-
set, because the OpenAI client requires a non-empty value. `timeout` and `max_retries` are only
24-
forwarded when provided: when None, the OpenAI client's own defaults apply and no `OPENAI_*`
25-
env vars are read.
22+
A placeholder api key is used when the user did not supply one and no `VLLM_API_KEY` env var is set, because the
23+
OpenAI client requires a non-empty value.
24+
`timeout` and `max_retries` are only forwarded when provided: when None, the OpenAI client's own defaults apply.
2625
"""
2726
resolved_api_key = "placeholder-api-key"
2827
if api_key is not None and (value := api_key.resolve_value()):

integrations/vllm/src/haystack_integrations/components/embedders/vllm/document_embedder.py

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class VLLMDocumentEmbedder:
3131
Before using this component, start a vLLM server with an embedding model:
3232
3333
```bash
34-
vllm serve intfloat/e5-mistral-7b-instruct
34+
vllm serve google/embeddinggemma-300m
3535
```
3636
3737
For details on server options, see the [vLLM CLI docs](https://docs.vllm.ai/en/stable/cli/serve/).
@@ -44,7 +44,7 @@ class VLLMDocumentEmbedder:
4444
4545
doc = Document(content="I love pizza!")
4646
47-
document_embedder = VLLMDocumentEmbedder(model="intfloat/e5-mistral-7b-instruct")
47+
document_embedder = VLLMDocumentEmbedder(model="google/embeddinggemma-300m")
4848
4949
result = document_embedder.run([doc])
5050
print(result["documents"][0].embedding)
@@ -57,8 +57,8 @@ class VLLMDocumentEmbedder:
5757
5858
```python
5959
document_embedder = VLLMDocumentEmbedder(
60-
model="jinaai/jina-embeddings-v3",
61-
extra_parameters={"dimensions": 32, "truncate_prompt_tokens": 256},
60+
model="google/embeddinggemma-300m",
61+
extra_parameters={"truncate_prompt_tokens": 256, "truncation_side": "right"},
6262
)
6363
```
6464
"""
@@ -71,6 +71,7 @@ def __init__(
7171
api_base_url: str = "http://localhost:8000/v1",
7272
prefix: str = "",
7373
suffix: str = "",
74+
dimensions: int | None = None,
7475
batch_size: int = 32,
7576
progress_bar: bool = True,
7677
meta_fields_to_embed: list[str] | None = None,
@@ -84,16 +85,21 @@ def __init__(
8485
"""
8586
Creates an instance of VLLMDocumentEmbedder.
8687
87-
:param model: The name of the model served by vLLM (e.g., "intfloat/e5-mistral-7b-instruct").
88+
:param model: The name of the model served by vLLM. Check
89+
[vLLM's documentation](https://docs.vllm.ai/en/stable/models/pooling_models) for more information.
8890
:param api_key: The vLLM API key. Defaults to the `VLLM_API_KEY` environment variable.
8991
Only required if the vLLM server was started with `--api-key`.
9092
:param api_base_url: The base URL of the vLLM server.
9193
:param prefix: A string to add at the beginning of each text.
9294
:param suffix: A string to add at the end of each text.
93-
:param batch_size: Number of Documents to encode at once.
94-
:param progress_bar: Whether to show a progress bar. Disable in production to keep logs clean.
95-
:param meta_fields_to_embed: List of meta fields to embed along with the Document text.
96-
:param embedding_separator: Separator used to concatenate the meta fields to the Document text.
95+
:param dimensions: The number of dimensions of the resulting embedding. Only models trained with
96+
Matryoshka Representation Learning support this parameter. See
97+
[vLLM's documentation](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#matryoshka-embeddings)
98+
for more information.
99+
:param batch_size: Number of documents to encode at once.
100+
:param progress_bar: Whether to show a progress bar.
101+
:param meta_fields_to_embed: List of meta fields to embed along with the document text.
102+
:param embedding_separator: Separator used to concatenate the meta fields to the document text.
97103
:param timeout: Timeout in seconds for vLLM client calls. If not set, the OpenAI client default applies.
98104
:param max_retries: Maximum number of retries for failed requests. If not set, the OpenAI client
99105
default applies.
@@ -104,15 +110,15 @@ def __init__(
104110
the component logs the error and continues processing the remaining documents.
105111
:param extra_parameters: Additional parameters forwarded as `extra_body` to the vLLM embeddings
106112
endpoint. Use this to pass parameters not part of the standard OpenAI Embeddings API, such as
107-
`dimensions` (for Matryoshka models), `truncate_prompt_tokens`, `truncation_side`,
108-
`additional_data`, `use_activation`, etc. See the
109-
[vLLM Embeddings API docs](https://docs.vllm.ai/en/stable/models/pooling_models.html#openai-compatible-embeddings-api).
113+
`truncate_prompt_tokens`, `truncation_side`, etc. See the
114+
[vLLM Embeddings API docs](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#openai-compatible-embeddings-api).
110115
"""
111116
self.model = model
112117
self.api_key = api_key
113118
self.api_base_url = api_base_url
114119
self.prefix = prefix
115120
self.suffix = suffix
121+
self.dimensions = dimensions
116122
self.batch_size = batch_size
117123
self.progress_bar = progress_bar
118124
self.meta_fields_to_embed = meta_fields_to_embed or []
@@ -149,10 +155,11 @@ def to_dict(self) -> dict[str, Any]:
149155
return default_to_dict(
150156
self,
151157
model=self.model,
152-
api_key=self.api_key.to_dict() if self.api_key else None,
158+
api_key=self.api_key,
153159
api_base_url=self.api_base_url,
154160
prefix=self.prefix,
155161
suffix=self.suffix,
162+
dimensions=self.dimensions,
156163
batch_size=self.batch_size,
157164
progress_bar=self.progress_bar,
158165
meta_fields_to_embed=self.meta_fields_to_embed,
@@ -183,6 +190,8 @@ def _prepare_texts_to_embed(self, documents: list[Document]) -> dict[str, str]:
183190

184191
def _prepare_input(self, inputs: list[str]) -> dict[str, Any]:
185192
kwargs: dict[str, Any] = {"model": self.model, "input": inputs, "encoding_format": "float"}
193+
if self.dimensions is not None:
194+
kwargs["dimensions"] = self.dimensions
186195
if self.extra_parameters:
187196
kwargs["extra_body"] = self.extra_parameters
188197
return kwargs

integrations/vllm/src/haystack_integrations/components/embedders/vllm/text_embedder.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class VLLMTextEmbedder:
2525
Before using this component, start a vLLM server with an embedding model:
2626
2727
```bash
28-
vllm serve intfloat/e5-mistral-7b-instruct
28+
vllm serve google/embeddinggemma-300m
2929
```
3030
3131
For details on server options, see the [vLLM CLI docs](https://docs.vllm.ai/en/stable/cli/serve/).
@@ -35,7 +35,7 @@ class VLLMTextEmbedder:
3535
```python
3636
from haystack_integrations.components.embedders.vllm import VLLMTextEmbedder
3737
38-
text_embedder = VLLMTextEmbedder(model="intfloat/e5-mistral-7b-instruct")
38+
text_embedder = VLLMTextEmbedder(model="google/embeddinggemma-300m")
3939
print(text_embedder.run("I love pizza!"))
4040
```
4141
@@ -46,8 +46,8 @@ class VLLMTextEmbedder:
4646
4747
```python
4848
text_embedder = VLLMTextEmbedder(
49-
model="jinaai/jina-embeddings-v3",
50-
extra_parameters={"dimensions": 32, "truncate_prompt_tokens": 256},
49+
model="google/embeddinggemma-300m",
50+
extra_parameters={"truncate_prompt_tokens": 256, "truncation_side": "right"},
5151
)
5252
```
5353
"""
@@ -60,6 +60,7 @@ def __init__(
6060
api_base_url: str = "http://localhost:8000/v1",
6161
prefix: str = "",
6262
suffix: str = "",
63+
dimensions: int | None = None,
6364
timeout: float | None = None,
6465
max_retries: int | None = None,
6566
http_client_kwargs: dict[str, Any] | None = None,
@@ -74,6 +75,10 @@ def __init__(
7475
:param api_base_url: The base URL of the vLLM server.
7576
:param prefix: A string to add at the beginning of each text to embed.
7677
:param suffix: A string to add at the end of each text to embed.
78+
:param dimensions: The number of dimensions of the resulting embedding. Only models trained with
79+
Matryoshka Representation Learning support this parameter. See
80+
[vLLM's documentation](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#matryoshka-embeddings)
81+
for more information.
7782
:param timeout: Timeout in seconds for vLLM client calls. If not set, the OpenAI client default applies.
7883
:param max_retries: Maximum number of retries for failed requests. If not set, the OpenAI client
7984
default applies.
@@ -82,15 +87,15 @@ def __init__(
8287
[HTTPX documentation](https://www.python-httpx.org/api/#client).
8388
:param extra_parameters: Additional parameters forwarded as `extra_body` to the vLLM embeddings
8489
endpoint. Use this to pass parameters not part of the standard OpenAI Embeddings API, such as
85-
`dimensions` (for Matryoshka models), `truncate_prompt_tokens`, `truncation_side`,
86-
`additional_data`, `use_activation`, etc. See the
87-
[vLLM Embeddings API docs](https://docs.vllm.ai/en/stable/models/pooling_models.html#openai-compatible-embeddings-api).
90+
`truncate_prompt_tokens`, `truncation_side`, `additional_data`, `use_activation`, etc. See the
91+
[vLLM Embeddings API docs](https://docs.vllm.ai/en/stable/models/pooling_models/embed/#openai-compatible-embeddings-api).
8892
"""
8993
self.model = model
9094
self.api_key = api_key
9195
self.api_base_url = api_base_url
9296
self.prefix = prefix
9397
self.suffix = suffix
98+
self.dimensions = dimensions
9499
self.timeout = timeout
95100
self.max_retries = max_retries
96101
self.http_client_kwargs = http_client_kwargs
@@ -126,6 +131,7 @@ def to_dict(self) -> dict[str, Any]:
126131
api_base_url=self.api_base_url,
127132
prefix=self.prefix,
128133
suffix=self.suffix,
134+
dimensions=self.dimensions,
129135
timeout=self.timeout,
130136
max_retries=self.max_retries,
131137
http_client_kwargs=self.http_client_kwargs,
@@ -150,6 +156,8 @@ def _prepare_input(self, text: str) -> dict[str, Any]:
150156
"input": self.prefix + text + self.suffix,
151157
"encoding_format": "float",
152158
}
159+
if self.dimensions is not None:
160+
kwargs["dimensions"] = self.dimensions
153161
if self.extra_parameters:
154162
kwargs["extra_body"] = self.extra_parameters
155163
return kwargs

integrations/vllm/tests/test_document_embedder.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def test_init_default(self, monkeypatch):
3939
assert embedder.api_base_url == "http://localhost:8000/v1"
4040
assert embedder.prefix == ""
4141
assert embedder.suffix == ""
42+
assert embedder.dimensions is None
4243
assert embedder.batch_size == 32
4344
assert embedder.progress_bar is True
4445
assert embedder.meta_fields_to_embed == []
@@ -56,6 +57,7 @@ def test_init_with_parameters(self):
5657
api_base_url="http://my-vllm-server:8000/v1",
5758
prefix="START",
5859
suffix="END",
60+
dimensions=64,
5961
batch_size=64,
6062
progress_bar=False,
6163
meta_fields_to_embed=["test_field"],
@@ -67,6 +69,7 @@ def test_init_with_parameters(self):
6769
assert embedder.api_base_url == "http://my-vllm-server:8000/v1"
6870
assert embedder.prefix == "START"
6971
assert embedder.suffix == "END"
72+
assert embedder.dimensions == 64
7073
assert embedder.batch_size == 64
7174
assert embedder.progress_bar is False
7275
assert embedder.meta_fields_to_embed == ["test_field"]
@@ -101,6 +104,7 @@ def test_to_dict(self, monkeypatch):
101104
"api_base_url": "http://localhost:8000/v1",
102105
"prefix": "",
103106
"suffix": "",
107+
"dimensions": None,
104108
"batch_size": 32,
105109
"progress_bar": True,
106110
"meta_fields_to_embed": [],
@@ -123,6 +127,7 @@ def test_from_dict(self, monkeypatch):
123127
"api_base_url": "http://localhost:8000/v1",
124128
"prefix": "",
125129
"suffix": "",
130+
"dimensions": 32,
126131
"batch_size": 32,
127132
"progress_bar": True,
128133
"meta_fields_to_embed": [],
@@ -131,15 +136,15 @@ def test_from_dict(self, monkeypatch):
131136
"max_retries": None,
132137
"http_client_kwargs": None,
133138
"raise_on_failure": False,
134-
"extra_parameters": {"dimensions": 32},
139+
"extra_parameters": None,
135140
},
136141
}
137142
embedder = VLLMDocumentEmbedder.from_dict(data)
138143
assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False)
139144
assert embedder.model == MODEL
140145
assert embedder.api_base_url == "http://localhost:8000/v1"
141146
assert embedder.batch_size == 32
142-
assert embedder.extra_parameters == {"dimensions": 32}
147+
assert embedder.dimensions == 32
143148

144149
def test_prepare_texts_to_embed(self):
145150
embedder = VLLMDocumentEmbedder(
@@ -149,14 +154,15 @@ def test_prepare_texts_to_embed(self):
149154
texts = embedder._prepare_texts_to_embed([doc])
150155
assert texts == {doc.id: "[ML | hello]"}
151156

152-
def test_prepare_input_adds_extra_body(self):
153-
embedder = VLLMDocumentEmbedder(model=MODEL, extra_parameters={"dimensions": 32})
157+
def test_prepare_input_adds_dimensions_and_extra_body(self):
158+
embedder = VLLMDocumentEmbedder(model=MODEL, dimensions=32, extra_parameters={"truncate_prompt_tokens": 256})
154159
kwargs = embedder._prepare_input(["a", "b"])
155160
assert kwargs == {
156161
"model": MODEL,
157162
"input": ["a", "b"],
158163
"encoding_format": "float",
159-
"extra_body": {"dimensions": 32},
164+
"dimensions": 32,
165+
"extra_body": {"truncate_prompt_tokens": 256},
160166
}
161167

162168
def test_run_wrong_input_format(self):

integrations/vllm/tests/test_text_embedder.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def test_init_default(self, monkeypatch):
3333
assert embedder.model == MODEL
3434
assert embedder.prefix == ""
3535
assert embedder.suffix == ""
36+
assert embedder.dimensions is None
3637
assert embedder.timeout is None
3738
assert embedder.max_retries is None
3839
assert embedder.http_client_kwargs is None
@@ -48,6 +49,7 @@ def test_init_with_parameters(self):
4849
api_base_url="http://my-vllm-server:8000/v1",
4950
prefix="START",
5051
suffix="END",
52+
dimensions=64,
5153
timeout=10.0,
5254
max_retries=2,
5355
http_client_kwargs={"proxy": "https://proxy.example.com"},
@@ -58,6 +60,7 @@ def test_init_with_parameters(self):
5860
assert embedder.model == MODEL
5961
assert embedder.prefix == "START"
6062
assert embedder.suffix == "END"
63+
assert embedder.dimensions == 64
6164
assert embedder.timeout == 10.0
6265
assert embedder.max_retries == 2
6366
assert embedder.http_client_kwargs == {"proxy": "https://proxy.example.com"}
@@ -90,6 +93,7 @@ def test_to_dict(self, monkeypatch):
9093
"api_base_url": "http://localhost:8000/v1",
9194
"prefix": "",
9295
"suffix": "",
96+
"dimensions": None,
9397
"timeout": None,
9498
"max_retries": None,
9599
"http_client_kwargs": None,
@@ -107,26 +111,30 @@ def test_from_dict(self, monkeypatch):
107111
"api_base_url": "http://localhost:8000/v1",
108112
"prefix": "",
109113
"suffix": "",
114+
"dimensions": 32,
110115
"timeout": None,
111116
"max_retries": None,
112117
"http_client_kwargs": None,
113-
"extra_parameters": {"dimensions": 32},
118+
"extra_parameters": None,
114119
},
115120
}
116121
embedder = VLLMTextEmbedder.from_dict(data)
117122
assert embedder.api_key == Secret.from_env_var("VLLM_API_KEY", strict=False)
118123
assert embedder.model == MODEL
119124
assert embedder.api_base_url == "http://localhost:8000/v1"
120-
assert embedder.extra_parameters == {"dimensions": 32}
125+
assert embedder.dimensions == 32
121126

122-
def test_prepare_input_adds_extra_body(self):
123-
embedder = VLLMTextEmbedder(model=MODEL, prefix="[", suffix="]", extra_parameters={"dimensions": 32})
127+
def test_prepare_input_adds_dimensions_and_extra_body(self):
128+
embedder = VLLMTextEmbedder(
129+
model=MODEL, prefix="[", suffix="]", dimensions=32, extra_parameters={"truncate_prompt_tokens": 256}
130+
)
124131
kwargs = embedder._prepare_input("hello")
125132
assert kwargs == {
126133
"model": MODEL,
127134
"input": "[hello]",
128135
"encoding_format": "float",
129-
"extra_body": {"dimensions": 32},
136+
"dimensions": 32,
137+
"extra_body": {"truncate_prompt_tokens": 256},
130138
}
131139

132140
def test_run_wrong_input_format(self):
@@ -135,7 +143,7 @@ def test_run_wrong_input_format(self):
135143
embedder.run(text=["text_1", "text_2"])
136144

137145
def test_run_with_mock(self):
138-
embedder = VLLMTextEmbedder(model=MODEL, prefix="[", suffix="]", extra_parameters={"dimensions": 2})
146+
embedder = VLLMTextEmbedder(model=MODEL, prefix="[", suffix="]", dimensions=2)
139147
embedder._client = MagicMock()
140148
embedder._client.embeddings.create.return_value = _fake_response([[0.1, 0.2]])
141149
embedder._is_warmed_up = True
@@ -144,7 +152,7 @@ def test_run_with_mock(self):
144152

145153
call_kwargs = embedder._client.embeddings.create.call_args.kwargs
146154
assert call_kwargs["input"] == "[hello]"
147-
assert call_kwargs["extra_body"] == {"dimensions": 2}
155+
assert call_kwargs["dimensions"] == 2
148156
assert result == {
149157
"embedding": [0.1, 0.2],
150158
"meta": {"model": "fake-model", "usage": {"prompt_tokens": 5, "total_tokens": 5}},

0 commit comments

Comments
 (0)