Skip to content

Commit dbb62df

Browse files
feat(ollama): add dimensions parameter to OllamaDocumentEmbedder and OllamaTextEmbedder (#3322)
Co-authored-by: bogdankostic <bogdankostic@web.de>
1 parent c797f9b commit dbb62df

5 files changed

Lines changed: 204 additions & 14 deletions

File tree

integrations/ollama/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ classifiers = [
2727
"Programming Language :: Python :: Implementation :: CPython",
2828
"Programming Language :: Python :: Implementation :: PyPy",
2929
]
30-
dependencies = ["haystack-ai>=2.22.0", "ollama>=0.5.0", "pydantic>=2.12.0", "tenacity>=8.2.3"]
30+
dependencies = ["haystack-ai>=2.22.0", "ollama>=0.5.4", "pydantic>=2.12.0", "tenacity>=8.2.3"]
3131

3232
[project.urls]
3333
Documentation = "https://github.com/deepset-ai/haystack-core-integrations/tree/main/integrations/ollama#readme"

integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(
4141
meta_fields_to_embed: list[str] | None = None,
4242
embedding_separator: str = "\n",
4343
batch_size: int = 32,
44+
dimensions: int | None = None,
4445
) -> None:
4546
"""
4647
Create a new OllamaDocumentEmbedder instance.
@@ -76,6 +77,11 @@ def __init__(
7677
Separator used to concatenate the metadata fields to the document text.
7778
:param batch_size:
7879
Number of documents to process at once.
80+
:param dimensions:
81+
The desired number of dimensions in the embedding output. Only supported by models
82+
that implement Matryoshka Representation Learning (MRL), such as nomic-embed-text-v1.5,
83+
mxbai-embed-large, and qwen3-embedding. If None (default), the full vector is returned.
84+
Requires ollama-python >= 0.6.2.
7985
"""
8086
self.keep_alive = keep_alive
8187
self.timeout = timeout
@@ -88,6 +94,7 @@ def __init__(
8894
self.embedding_separator = embedding_separator
8995
self.suffix = suffix
9096
self.prefix = prefix
97+
self.dimensions = dimensions
9198

9299
self._client = Client(host=self.url, timeout=self.timeout)
93100
self._async_client = AsyncClient(host=self.url, timeout=self.timeout)
@@ -145,6 +152,7 @@ def _embed_batch(
145152
input=batch,
146153
options=generation_kwargs,
147154
keep_alive=self.keep_alive,
155+
dimensions=self.dimensions,
148156
)
149157
all_embeddings.extend(result["embeddings"])
150158

@@ -166,6 +174,7 @@ async def _embed_batch_async(
166174
input=batch,
167175
options=generation_kwargs,
168176
keep_alive=self.keep_alive,
177+
dimensions=self.dimensions,
169178
)
170179
for batch in batches
171180
]

integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,7 @@
88
@component
99
class OllamaTextEmbedder:
1010
"""
11-
Computes the embeddings of a list of Documents and stores the obtained vectors in each Document's embedding field.
12-
13-
It uses embedding models compatible with the Ollama Library.
11+
Computes the embeddings of a string using embedding models compatible with the Ollama Library.
1412
1513
Usage example:
1614
```python
@@ -29,6 +27,7 @@ def __init__(
2927
generation_kwargs: dict[str, Any] | None = None,
3028
timeout: int = 120,
3129
keep_alive: float | str | None = None,
30+
dimensions: int | None = None,
3231
) -> None:
3332
"""
3433
Create a new OllamaTextEmbedder instance.
@@ -51,12 +50,17 @@ def __init__(
5150
- a number in seconds (such as 3600)
5251
- any negative number which will keep the model loaded in memory (e.g. -1 or "-1m")
5352
- '0' which will unload the model immediately after generating a response.
53+
:param dimensions:
54+
The desired number of dimensions in the embedding output. Only supported by models
55+
that implement Matryoshka Representation Learning (MRL), such as nomic-embed-text-v1.5,
56+
mxbai-embed-large, and qwen3-embedding. If None (default), the full vector is returned.
5457
"""
5558
self.keep_alive = keep_alive
5659
self.timeout = timeout
5760
self.generation_kwargs = generation_kwargs or {}
5861
self.url = url
5962
self.model = model
63+
self.dimensions = dimensions
6064

6165
self._client = Client(host=self.url, timeout=self.timeout)
6266
self._async_client = AsyncClient(host=self.url, timeout=self.timeout)
@@ -78,15 +82,15 @@ def run(
7882
- `embedding`: The computed embeddings
7983
- `meta`: The metadata collected during the embedding process
8084
"""
81-
result = self._client.embeddings(
85+
result = self._client.embed(
8286
model=self.model,
83-
prompt=text,
87+
input=text,
8488
options=generation_kwargs,
8589
keep_alive=self.keep_alive,
86-
).model_dump()
87-
result["meta"] = {"model": self.model}
90+
dimensions=self.dimensions,
91+
)
8892

89-
return result
93+
return {"embedding": result["embeddings"][0], "meta": {"model": self.model}}
9094

9195
@component.output_types(embedding=list[float], meta=dict[str, Any])
9296
async def run_async(
@@ -105,13 +109,12 @@ async def run_async(
105109
- `embedding`: The computed embeddings
106110
- `meta`: The metadata collected during the embedding process
107111
"""
108-
response = await self._async_client.embeddings(
112+
result = await self._async_client.embed(
109113
model=self.model,
110-
prompt=text,
114+
input=text,
111115
options=generation_kwargs,
112116
keep_alive=self.keep_alive,
117+
dimensions=self.dimensions,
113118
)
114-
result = response.model_dump()
115-
result["meta"] = {"model": self.model}
116119

117-
return result
120+
return {"embedding": result["embeddings"][0], "meta": {"model": self.model}}

integrations/ollama/tests/test_document_embedder.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
from unittest.mock import AsyncMock, MagicMock
2+
13
import pytest
24
from haystack import Document
5+
from haystack.core.serialization import default_from_dict, default_to_dict
36
from ollama._types import ResponseError
47

58
from haystack_integrations.components.embedders.ollama import OllamaDocumentEmbedder
@@ -74,3 +77,92 @@ async def test_run_async(self):
7477
documents = result["documents"]
7578
assert len(documents) == 3
7679
assert all(isinstance(element, float) for document in documents for element in document.embedding)
80+
81+
def test_dimensions_default_is_none(self):
82+
embedder = OllamaDocumentEmbedder()
83+
assert embedder.dimensions is None
84+
85+
def test_dimensions_stored_on_instance(self):
86+
embedder = OllamaDocumentEmbedder(dimensions=512)
87+
assert embedder.dimensions == 512
88+
89+
def test_dimensions_passed_to_embed_client(self):
90+
embedder = OllamaDocumentEmbedder(dimensions=512)
91+
mock_response = {"embeddings": [[0.1, 0.2, 0.3]]}
92+
embedder._client.embed = MagicMock(return_value=mock_response)
93+
94+
embedder._embed_batch(["hello world"], batch_size=32)
95+
96+
call_kwargs = embedder._client.embed.call_args.kwargs
97+
assert call_kwargs["dimensions"] == 512
98+
99+
def test_none_dimensions_passed_to_embed_client(self):
100+
embedder = OllamaDocumentEmbedder(dimensions=None)
101+
mock_response = {"embeddings": [[0.1, 0.2, 0.3]]}
102+
embedder._client.embed = MagicMock(return_value=mock_response)
103+
104+
embedder._embed_batch(["hello"], batch_size=32)
105+
106+
call_kwargs = embedder._client.embed.call_args.kwargs
107+
assert call_kwargs["dimensions"] is None
108+
109+
@pytest.mark.asyncio
110+
async def test_dimensions_passed_to_async_embed_client(self):
111+
embedder = OllamaDocumentEmbedder(dimensions=256)
112+
mock_response = {"embeddings": [[0.1, 0.2, 0.3]]}
113+
embedder._async_client.embed = AsyncMock(return_value=mock_response)
114+
115+
await embedder._embed_batch_async(["hello"], batch_size=32)
116+
117+
call_kwargs = embedder._async_client.embed.call_args.kwargs
118+
assert call_kwargs["dimensions"] == 256
119+
120+
def test_to_dict_contains_dimensions(self):
121+
embedder = OllamaDocumentEmbedder(dimensions=512)
122+
embedder_dict = default_to_dict(
123+
embedder,
124+
model=embedder.model,
125+
url=embedder.url,
126+
generation_kwargs=embedder.generation_kwargs,
127+
timeout=embedder.timeout,
128+
keep_alive=embedder.keep_alive,
129+
prefix=embedder.prefix,
130+
suffix=embedder.suffix,
131+
progress_bar=embedder.progress_bar,
132+
meta_fields_to_embed=embedder.meta_fields_to_embed,
133+
embedding_separator=embedder.embedding_separator,
134+
batch_size=embedder.batch_size,
135+
dimensions=embedder.dimensions,
136+
)
137+
assert embedder_dict["init_parameters"]["dimensions"] == 512
138+
139+
def test_to_dict_contains_dimensions_none(self):
140+
embedder = OllamaDocumentEmbedder()
141+
embedder_dict = default_to_dict(
142+
embedder,
143+
model=embedder.model,
144+
url=embedder.url,
145+
generation_kwargs=embedder.generation_kwargs,
146+
timeout=embedder.timeout,
147+
keep_alive=embedder.keep_alive,
148+
prefix=embedder.prefix,
149+
suffix=embedder.suffix,
150+
progress_bar=embedder.progress_bar,
151+
meta_fields_to_embed=embedder.meta_fields_to_embed,
152+
embedding_separator=embedder.embedding_separator,
153+
batch_size=embedder.batch_size,
154+
dimensions=embedder.dimensions,
155+
)
156+
assert embedder_dict["init_parameters"]["dimensions"] is None
157+
158+
def test_from_dict_restores_dimensions(self):
159+
embedder_dict = {
160+
"type": "haystack_integrations.components.embedders.ollama.document_embedder.OllamaDocumentEmbedder",
161+
"init_parameters": {
162+
"model": "nomic-embed-text",
163+
"url": "http://localhost:11434",
164+
"dimensions": 512,
165+
},
166+
}
167+
embedder = default_from_dict(OllamaDocumentEmbedder, embedder_dict)
168+
assert embedder.dimensions == 512

integrations/ollama/tests/test_text_embedder.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1+
import asyncio
2+
from unittest.mock import AsyncMock, MagicMock
3+
14
import pytest
5+
from haystack.core.serialization import default_from_dict, default_to_dict
26
from ollama._types import ResponseError
37

48
from haystack_integrations.components.embedders.ollama import OllamaTextEmbedder
@@ -58,3 +62,85 @@ async def test_run_async(self):
5862
assert isinstance(reply, dict)
5963
assert all(isinstance(element, float) for element in reply["embedding"])
6064
assert reply["meta"]["model"] == "all-minilm"
65+
66+
def test_dimensions_default_is_none(self):
67+
embedder = OllamaTextEmbedder()
68+
assert embedder.dimensions is None
69+
70+
def test_dimensions_stored_on_instance(self):
71+
embedder = OllamaTextEmbedder(dimensions=256)
72+
assert embedder.dimensions == 256
73+
74+
def test_dimensions_passed_to_embed_client(self):
75+
76+
embedder = OllamaTextEmbedder(dimensions=256)
77+
mock_response = {"embeddings": [[0.1, 0.2, 0.3]]}
78+
embedder._client.embed = MagicMock(return_value=mock_response)
79+
80+
embedder.run(text="hello world")
81+
82+
call_kwargs = embedder._client.embed.call_args.kwargs
83+
assert call_kwargs["dimensions"] == 256
84+
85+
def test_none_dimensions_passed_to_embed_client(self):
86+
87+
embedder = OllamaTextEmbedder(dimensions=None)
88+
mock_response = {"embeddings": [[0.1, 0.2, 0.3]]}
89+
embedder._client.embed = MagicMock(return_value=mock_response)
90+
91+
embedder.run(text="hello")
92+
93+
call_kwargs = embedder._client.embed.call_args.kwargs
94+
assert call_kwargs["dimensions"] is None
95+
96+
def test_dimensions_passed_to_async_embed_client(self):
97+
98+
embedder = OllamaTextEmbedder(dimensions=128)
99+
mock_response = {"embeddings": [[0.1, 0.2, 0.3]]}
100+
embedder._async_client.embed = AsyncMock(return_value=mock_response)
101+
102+
asyncio.run(embedder.run_async(text="hello"))
103+
104+
call_kwargs = embedder._async_client.embed.call_args.kwargs
105+
assert call_kwargs["dimensions"] == 128
106+
107+
def test_to_dict_contains_dimensions(self):
108+
109+
embedder = OllamaTextEmbedder(dimensions=256)
110+
embedder_dict = default_to_dict(
111+
embedder,
112+
model=embedder.model,
113+
url=embedder.url,
114+
generation_kwargs=embedder.generation_kwargs,
115+
timeout=embedder.timeout,
116+
keep_alive=embedder.keep_alive,
117+
dimensions=embedder.dimensions,
118+
)
119+
assert embedder_dict["init_parameters"]["dimensions"] == 256
120+
121+
def test_to_dict_contains_dimensions_none(self):
122+
123+
embedder = OllamaTextEmbedder()
124+
embedder_dict = default_to_dict(
125+
embedder,
126+
model=embedder.model,
127+
url=embedder.url,
128+
generation_kwargs=embedder.generation_kwargs,
129+
timeout=embedder.timeout,
130+
keep_alive=embedder.keep_alive,
131+
dimensions=embedder.dimensions,
132+
)
133+
assert embedder_dict["init_parameters"]["dimensions"] is None
134+
135+
def test_from_dict_restores_dimensions(self):
136+
137+
embedder_dict = {
138+
"type": "haystack_integrations.components.embedders.ollama.text_embedder.OllamaTextEmbedder",
139+
"init_parameters": {
140+
"model": "nomic-embed-text",
141+
"url": "http://localhost:11434",
142+
"dimensions": 256,
143+
},
144+
}
145+
embedder = default_from_dict(OllamaTextEmbedder, embedder_dict)
146+
assert embedder.dimensions == 256

0 commit comments

Comments
 (0)