|
| 1 | +from unittest.mock import AsyncMock, MagicMock |
| 2 | + |
1 | 3 | import pytest |
2 | 4 | from haystack import Document |
| 5 | +from haystack.core.serialization import default_from_dict, default_to_dict |
3 | 6 | from ollama._types import ResponseError |
4 | 7 |
|
5 | 8 | from haystack_integrations.components.embedders.ollama import OllamaDocumentEmbedder |
@@ -74,3 +77,92 @@ async def test_run_async(self): |
74 | 77 | documents = result["documents"] |
75 | 78 | assert len(documents) == 3 |
76 | 79 | assert all(isinstance(element, float) for document in documents for element in document.embedding) |
| 80 | + |
| 81 | + def test_dimensions_default_is_none(self): |
| 82 | + embedder = OllamaDocumentEmbedder() |
| 83 | + assert embedder.dimensions is None |
| 84 | + |
| 85 | + def test_dimensions_stored_on_instance(self): |
| 86 | + embedder = OllamaDocumentEmbedder(dimensions=512) |
| 87 | + assert embedder.dimensions == 512 |
| 88 | + |
| 89 | + def test_dimensions_passed_to_embed_client(self): |
| 90 | + embedder = OllamaDocumentEmbedder(dimensions=512) |
| 91 | + mock_response = {"embeddings": [[0.1, 0.2, 0.3]]} |
| 92 | + embedder._client.embed = MagicMock(return_value=mock_response) |
| 93 | + |
| 94 | + embedder._embed_batch(["hello world"], batch_size=32) |
| 95 | + |
| 96 | + call_kwargs = embedder._client.embed.call_args.kwargs |
| 97 | + assert call_kwargs["dimensions"] == 512 |
| 98 | + |
| 99 | + def test_none_dimensions_passed_to_embed_client(self): |
| 100 | + embedder = OllamaDocumentEmbedder(dimensions=None) |
| 101 | + mock_response = {"embeddings": [[0.1, 0.2, 0.3]]} |
| 102 | + embedder._client.embed = MagicMock(return_value=mock_response) |
| 103 | + |
| 104 | + embedder._embed_batch(["hello"], batch_size=32) |
| 105 | + |
| 106 | + call_kwargs = embedder._client.embed.call_args.kwargs |
| 107 | + assert call_kwargs["dimensions"] is None |
| 108 | + |
| 109 | + @pytest.mark.asyncio |
| 110 | + async def test_dimensions_passed_to_async_embed_client(self): |
| 111 | + embedder = OllamaDocumentEmbedder(dimensions=256) |
| 112 | + mock_response = {"embeddings": [[0.1, 0.2, 0.3]]} |
| 113 | + embedder._async_client.embed = AsyncMock(return_value=mock_response) |
| 114 | + |
| 115 | + await embedder._embed_batch_async(["hello"], batch_size=32) |
| 116 | + |
| 117 | + call_kwargs = embedder._async_client.embed.call_args.kwargs |
| 118 | + assert call_kwargs["dimensions"] == 256 |
| 119 | + |
| 120 | + def test_to_dict_contains_dimensions(self): |
| 121 | + embedder = OllamaDocumentEmbedder(dimensions=512) |
| 122 | + embedder_dict = default_to_dict( |
| 123 | + embedder, |
| 124 | + model=embedder.model, |
| 125 | + url=embedder.url, |
| 126 | + generation_kwargs=embedder.generation_kwargs, |
| 127 | + timeout=embedder.timeout, |
| 128 | + keep_alive=embedder.keep_alive, |
| 129 | + prefix=embedder.prefix, |
| 130 | + suffix=embedder.suffix, |
| 131 | + progress_bar=embedder.progress_bar, |
| 132 | + meta_fields_to_embed=embedder.meta_fields_to_embed, |
| 133 | + embedding_separator=embedder.embedding_separator, |
| 134 | + batch_size=embedder.batch_size, |
| 135 | + dimensions=embedder.dimensions, |
| 136 | + ) |
| 137 | + assert embedder_dict["init_parameters"]["dimensions"] == 512 |
| 138 | + |
| 139 | + def test_to_dict_contains_dimensions_none(self): |
| 140 | + embedder = OllamaDocumentEmbedder() |
| 141 | + embedder_dict = default_to_dict( |
| 142 | + embedder, |
| 143 | + model=embedder.model, |
| 144 | + url=embedder.url, |
| 145 | + generation_kwargs=embedder.generation_kwargs, |
| 146 | + timeout=embedder.timeout, |
| 147 | + keep_alive=embedder.keep_alive, |
| 148 | + prefix=embedder.prefix, |
| 149 | + suffix=embedder.suffix, |
| 150 | + progress_bar=embedder.progress_bar, |
| 151 | + meta_fields_to_embed=embedder.meta_fields_to_embed, |
| 152 | + embedding_separator=embedder.embedding_separator, |
| 153 | + batch_size=embedder.batch_size, |
| 154 | + dimensions=embedder.dimensions, |
| 155 | + ) |
| 156 | + assert embedder_dict["init_parameters"]["dimensions"] is None |
| 157 | + |
| 158 | + def test_from_dict_restores_dimensions(self): |
| 159 | + embedder_dict = { |
| 160 | + "type": "haystack_integrations.components.embedders.ollama.document_embedder.OllamaDocumentEmbedder", |
| 161 | + "init_parameters": { |
| 162 | + "model": "nomic-embed-text", |
| 163 | + "url": "http://localhost:11434", |
| 164 | + "dimensions": 512, |
| 165 | + }, |
| 166 | + } |
| 167 | + embedder = default_from_dict(OllamaDocumentEmbedder, embedder_dict) |
| 168 | + assert embedder.dimensions == 512 |
0 commit comments