Skip to content

Commit f8d5e89

Browse files
feat: add dimensions parameter to OllamaDocumentEmbedder and OllamaTextEmbedder
1 parent 63e66be commit f8d5e89

4 files changed

Lines changed: 203 additions & 10 deletions

File tree

integrations/ollama/src/haystack_integrations/components/embedders/ollama/document_embedder.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def __init__(
4141
meta_fields_to_embed: list[str] | None = None,
4242
embedding_separator: str = "\n",
4343
batch_size: int = 32,
44+
dimensions: int | None = None,
4445
) -> None:
4546
"""
4647
Create a new OllamaDocumentEmbedder instance.
@@ -76,6 +77,11 @@ def __init__(
7677
Separator used to concatenate the metadata fields to the document text.
7778
:param batch_size:
7879
Number of documents to process at once.
80+
:param dimensions:
81+
The desired number of dimensions in the embedding output. Only supported by models
82+
that implement Matryoshka Representation Learning (MRL), such as nomic-embed-text-v1.5,
83+
mxbai-embed-large, and qwen3-embedding. If None (default), the full vector is returned.
84+
Requires ollama-python >= 0.6.2.
7985
"""
8086
self.keep_alive = keep_alive
8187
self.timeout = timeout
@@ -88,6 +94,7 @@ def __init__(
8894
self.embedding_separator = embedding_separator
8995
self.suffix = suffix
9096
self.prefix = prefix
97+
self.dimensions = dimensions
9198

9299
self._client = Client(host=self.url, timeout=self.timeout)
93100
self._async_client = AsyncClient(host=self.url, timeout=self.timeout)
@@ -145,6 +152,7 @@ def _embed_batch(
145152
input=batch,
146153
options=generation_kwargs,
147154
keep_alive=self.keep_alive,
155+
dimensions=self.dimensions,
148156
)
149157
all_embeddings.extend(result["embeddings"])
150158

@@ -166,6 +174,7 @@ async def _embed_batch_async(
166174
input=batch,
167175
options=generation_kwargs,
168176
keep_alive=self.keep_alive,
177+
dimensions=self.dimensions,
169178
)
170179
for batch in batches
171180
]

integrations/ollama/src/haystack_integrations/components/embedders/ollama/text_embedder.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def __init__(
2929
generation_kwargs: dict[str, Any] | None = None,
3030
timeout: int = 120,
3131
keep_alive: float | str | None = None,
32+
dimensions: int | None = None,
3233
) -> None:
3334
"""
3435
Create a new OllamaTextEmbedder instance.
@@ -51,12 +52,18 @@ def __init__(
5152
- a number in seconds (such as 3600)
5253
- any negative number which will keep the model loaded in memory (e.g. -1 or "-1m")
5354
- '0' which will unload the model immediately after generating a response.
55+
:param dimensions:
56+
The desired number of dimensions in the embedding output. Only supported by models
57+
that implement Matryoshka Representation Learning (MRL), such as nomic-embed-text-v1.5,
58+
mxbai-embed-large, and qwen3-embedding. If None (default), the full vector is returned.
59+
Requires ollama-python >= 0.6.2.
5460
"""
5561
self.keep_alive = keep_alive
5662
self.timeout = timeout
5763
self.generation_kwargs = generation_kwargs or {}
5864
self.url = url
5965
self.model = model
66+
self.dimensions = dimensions
6067

6168
self._client = Client(host=self.url, timeout=self.timeout)
6269
self._async_client = AsyncClient(host=self.url, timeout=self.timeout)
@@ -78,15 +85,15 @@ def run(
7885
- `embedding`: The computed embeddings
7986
- `meta`: The metadata collected during the embedding process
8087
"""
81-
result = self._client.embeddings(
88+
result = self._client.embed(
8289
model=self.model,
83-
prompt=text,
90+
input=text,
8491
options=generation_kwargs,
8592
keep_alive=self.keep_alive,
86-
).model_dump()
87-
result["meta"] = {"model": self.model}
93+
dimensions=self.dimensions,
94+
)
8895

89-
return result
96+
return {"embedding": result["embeddings"][0], "meta": {"model": self.model}}
9097

9198
@component.output_types(embedding=list[float], meta=dict[str, Any])
9299
async def run_async(
@@ -105,13 +112,12 @@ async def run_async(
105112
- `embedding`: The computed embeddings
106113
- `meta`: The metadata collected during the embedding process
107114
"""
108-
response = await self._async_client.embeddings(
115+
result = await self._async_client.embed(
109116
model=self.model,
110-
prompt=text,
117+
input=text,
111118
options=generation_kwargs,
112119
keep_alive=self.keep_alive,
120+
dimensions=self.dimensions,
113121
)
114-
result = response.model_dump()
115-
result["meta"] = {"model": self.model}
116122

117-
return result
123+
return {"embedding": result["embeddings"][0], "meta": {"model": self.model}}

integrations/ollama/tests/test_document_embedder.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,3 +74,98 @@ async def test_run_async(self):
7474
documents = result["documents"]
7575
assert len(documents) == 3
7676
assert all(isinstance(element, float) for document in documents for element in document.embedding)
77+
78+
def test_dimensions_default_is_none(self):
79+
embedder = OllamaDocumentEmbedder()
80+
assert embedder.dimensions is None
81+
82+
def test_dimensions_stored_on_instance(self):
83+
embedder = OllamaDocumentEmbedder(dimensions=512)
84+
assert embedder.dimensions == 512
85+
86+
def test_dimensions_passed_to_embed_client(self):
87+
from unittest.mock import MagicMock
88+
embedder = OllamaDocumentEmbedder(dimensions=512)
89+
mock_response = {"embeddings": [[0.1, 0.2, 0.3]]}
90+
embedder._client.embed = MagicMock(return_value=mock_response)
91+
92+
embedder._embed_batch(["hello world"], batch_size=32)
93+
94+
call_kwargs = embedder._client.embed.call_args.kwargs
95+
assert call_kwargs["dimensions"] == 512
96+
97+
def test_none_dimensions_passed_to_embed_client(self):
98+
from unittest.mock import MagicMock
99+
embedder = OllamaDocumentEmbedder(dimensions=None)
100+
mock_response = {"embeddings": [[0.1, 0.2, 0.3]]}
101+
embedder._client.embed = MagicMock(return_value=mock_response)
102+
103+
embedder._embed_batch(["hello"], batch_size=32)
104+
105+
call_kwargs = embedder._client.embed.call_args.kwargs
106+
assert call_kwargs["dimensions"] is None
107+
108+
def test_dimensions_passed_to_async_embed_client(self):
109+
import asyncio
110+
from unittest.mock import AsyncMock
111+
embedder = OllamaDocumentEmbedder(dimensions=256)
112+
mock_response = {"embeddings": [[0.1, 0.2, 0.3]]}
113+
embedder._async_client.embed = AsyncMock(return_value=mock_response)
114+
115+
asyncio.run(embedder._embed_batch_async(["hello"], batch_size=32))
116+
117+
call_kwargs = embedder._async_client.embed.call_args.kwargs
118+
assert call_kwargs["dimensions"] == 256
119+
120+
def test_to_dict_contains_dimensions(self):
121+
from haystack.core.serialization import default_to_dict
122+
embedder = OllamaDocumentEmbedder(dimensions=512)
123+
embedder_dict = default_to_dict(
124+
embedder,
125+
model=embedder.model,
126+
url=embedder.url,
127+
generation_kwargs=embedder.generation_kwargs,
128+
timeout=embedder.timeout,
129+
keep_alive=embedder.keep_alive,
130+
prefix=embedder.prefix,
131+
suffix=embedder.suffix,
132+
progress_bar=embedder.progress_bar,
133+
meta_fields_to_embed=embedder.meta_fields_to_embed,
134+
embedding_separator=embedder.embedding_separator,
135+
batch_size=embedder.batch_size,
136+
dimensions=embedder.dimensions,
137+
)
138+
assert embedder_dict["init_parameters"]["dimensions"] == 512
139+
140+
def test_to_dict_contains_dimensions_none(self):
141+
from haystack.core.serialization import default_to_dict
142+
embedder = OllamaDocumentEmbedder()
143+
embedder_dict = default_to_dict(
144+
embedder,
145+
model=embedder.model,
146+
url=embedder.url,
147+
generation_kwargs=embedder.generation_kwargs,
148+
timeout=embedder.timeout,
149+
keep_alive=embedder.keep_alive,
150+
prefix=embedder.prefix,
151+
suffix=embedder.suffix,
152+
progress_bar=embedder.progress_bar,
153+
meta_fields_to_embed=embedder.meta_fields_to_embed,
154+
embedding_separator=embedder.embedding_separator,
155+
batch_size=embedder.batch_size,
156+
dimensions=embedder.dimensions,
157+
)
158+
assert embedder_dict["init_parameters"]["dimensions"] is None
159+
160+
def test_from_dict_restores_dimensions(self):
161+
from haystack.core.serialization import default_from_dict
162+
embedder_dict = {
163+
"type": "haystack_integrations.components.embedders.ollama.document_embedder.OllamaDocumentEmbedder",
164+
"init_parameters": {
165+
"model": "nomic-embed-text",
166+
"url": "http://localhost:11434",
167+
"dimensions": 512,
168+
},
169+
}
170+
embedder = default_from_dict(OllamaDocumentEmbedder, embedder_dict)
171+
assert embedder.dimensions == 512

integrations/ollama/tests/test_text_embedder.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,86 @@ async def test_run_async(self):
5858
assert isinstance(reply, dict)
5959
assert all(isinstance(element, float) for element in reply["embedding"])
6060
assert reply["meta"]["model"] == "all-minilm"
61+
62+
def test_dimensions_default_is_none(self):
63+
embedder = OllamaTextEmbedder()
64+
assert embedder.dimensions is None
65+
66+
def test_dimensions_stored_on_instance(self):
67+
embedder = OllamaTextEmbedder(dimensions=256)
68+
assert embedder.dimensions == 256
69+
70+
def test_dimensions_passed_to_embed_client(self):
71+
from unittest.mock import MagicMock
72+
embedder = OllamaTextEmbedder(dimensions=256)
73+
mock_response = {"embeddings": [[0.1, 0.2, 0.3]]}
74+
embedder._client.embed = MagicMock(return_value=mock_response)
75+
76+
embedder.run(text="hello world")
77+
78+
call_kwargs = embedder._client.embed.call_args.kwargs
79+
assert call_kwargs["dimensions"] == 256
80+
81+
def test_none_dimensions_passed_to_embed_client(self):
82+
from unittest.mock import MagicMock
83+
embedder = OllamaTextEmbedder(dimensions=None)
84+
mock_response = {"embeddings": [[0.1, 0.2, 0.3]]}
85+
embedder._client.embed = MagicMock(return_value=mock_response)
86+
87+
embedder.run(text="hello")
88+
89+
call_kwargs = embedder._client.embed.call_args.kwargs
90+
assert call_kwargs["dimensions"] is None
91+
92+
def test_dimensions_passed_to_async_embed_client(self):
93+
import asyncio
94+
from unittest.mock import AsyncMock
95+
embedder = OllamaTextEmbedder(dimensions=128)
96+
mock_response = {"embeddings": [[0.1, 0.2, 0.3]]}
97+
embedder._async_client.embed = AsyncMock(return_value=mock_response)
98+
99+
asyncio.run(embedder.run_async(text="hello"))
100+
101+
call_kwargs = embedder._async_client.embed.call_args.kwargs
102+
assert call_kwargs["dimensions"] == 128
103+
104+
def test_to_dict_contains_dimensions(self):
105+
from haystack.core.serialization import default_to_dict
106+
embedder = OllamaTextEmbedder(dimensions=256)
107+
embedder_dict = default_to_dict(
108+
embedder,
109+
model=embedder.model,
110+
url=embedder.url,
111+
generation_kwargs=embedder.generation_kwargs,
112+
timeout=embedder.timeout,
113+
keep_alive=embedder.keep_alive,
114+
dimensions=embedder.dimensions,
115+
)
116+
assert embedder_dict["init_parameters"]["dimensions"] == 256
117+
118+
def test_to_dict_contains_dimensions_none(self):
119+
from haystack.core.serialization import default_to_dict
120+
embedder = OllamaTextEmbedder()
121+
embedder_dict = default_to_dict(
122+
embedder,
123+
model=embedder.model,
124+
url=embedder.url,
125+
generation_kwargs=embedder.generation_kwargs,
126+
timeout=embedder.timeout,
127+
keep_alive=embedder.keep_alive,
128+
dimensions=embedder.dimensions,
129+
)
130+
assert embedder_dict["init_parameters"]["dimensions"] is None
131+
132+
def test_from_dict_restores_dimensions(self):
133+
from haystack.core.serialization import default_from_dict
134+
embedder_dict = {
135+
"type": "haystack_integrations.components.embedders.ollama.text_embedder.OllamaTextEmbedder",
136+
"init_parameters": {
137+
"model": "nomic-embed-text",
138+
"url": "http://localhost:11434",
139+
"dimensions": 256,
140+
},
141+
}
142+
embedder = default_from_dict(OllamaTextEmbedder, embedder_dict)
143+
assert embedder.dimensions == 256

0 commit comments

Comments
 (0)