|
1 | | -"""Tests for LiteLLMEmbeddingProvider. |
| 1 | +"""Tests for LiteLLMEmbeddingProvider and factory litellm branch.""" |
2 | 2 |
|
3 | | -Uses AST parsing and direct SDK mocking to avoid importing the full |
4 | | -basic_memory dependency chain (logfire, alembic, etc.). |
5 | | -""" |
6 | | - |
7 | | -import ast |
| 3 | +import asyncio |
| 4 | +import builtins |
8 | 5 | import sys |
9 | | -import types |
10 | | -from pathlib import Path |
11 | | -from unittest.mock import AsyncMock, MagicMock |
| 6 | +from types import SimpleNamespace |
12 | 7 |
|
13 | 8 | import pytest |
14 | 9 |
|
15 | | -PROVIDER_PATH = ( |
16 | | - Path(__file__).resolve().parents[2] |
17 | | - / "src" |
18 | | - / "basic_memory" |
19 | | - / "repository" |
20 | | - / "litellm_provider.py" |
21 | | -) |
22 | | -FACTORY_PATH = ( |
23 | | - Path(__file__).resolve().parents[2] |
24 | | - / "src" |
25 | | - / "basic_memory" |
26 | | - / "repository" |
27 | | - / "embedding_provider_factory.py" |
| 10 | +from basic_memory.config import BasicMemoryConfig |
| 11 | +from basic_memory.repository.embedding_provider_factory import ( |
| 12 | + create_embedding_provider, |
| 13 | + reset_embedding_provider_cache, |
28 | 14 | ) |
29 | | - |
30 | | - |
31 | | -class TestLiteLLMProviderStructure: |
32 | | - """Verify the provider file has the correct structure.""" |
33 | | - |
34 | | - def _parse(self): |
35 | | - return ast.parse(PROVIDER_PATH.read_text()) |
36 | | - |
37 | | - def test_file_exists(self): |
38 | | - assert PROVIDER_PATH.exists() |
39 | | - |
40 | | - def test_has_litellm_embedding_provider_class(self): |
41 | | - tree = self._parse() |
42 | | - classes = [n.name for n in ast.walk(tree) if isinstance(n, ast.ClassDef)] |
43 | | - assert "LiteLLMEmbeddingProvider" in classes |
44 | | - |
45 | | - def test_has_embed_documents_method(self): |
46 | | - tree = self._parse() |
47 | | - for node in ast.walk(tree): |
48 | | - if isinstance(node, ast.ClassDef) and node.name == "LiteLLMEmbeddingProvider": |
49 | | - methods = [ |
50 | | - n.name |
51 | | - for n in node.body |
52 | | - if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef)) |
53 | | - ] |
54 | | - assert "embed_documents" in methods |
55 | | - assert "embed_query" in methods |
56 | | - return |
57 | | - pytest.fail("LiteLLMEmbeddingProvider class not found") |
58 | | - |
59 | | - def test_embed_documents_is_async(self): |
60 | | - tree = self._parse() |
61 | | - for node in ast.walk(tree): |
62 | | - if isinstance(node, ast.ClassDef) and node.name == "LiteLLMEmbeddingProvider": |
63 | | - for item in node.body: |
64 | | - if isinstance(item, ast.AsyncFunctionDef) and item.name == "embed_documents": |
65 | | - return |
66 | | - pytest.fail("embed_documents is not async") |
67 | | - |
68 | | - def test_uses_drop_params_true(self): |
69 | | - src = PROVIDER_PATH.read_text() |
70 | | - assert "drop_params" in src |
71 | | - |
72 | | - def test_uses_litellm_aembedding(self): |
73 | | - src = PROVIDER_PATH.read_text() |
74 | | - assert "aembedding" in src |
75 | | - |
76 | | - def test_has_runtime_log_attrs(self): |
77 | | - tree = self._parse() |
78 | | - for node in ast.walk(tree): |
79 | | - if isinstance(node, ast.ClassDef) and node.name == "LiteLLMEmbeddingProvider": |
80 | | - methods = [ |
81 | | - n.name |
82 | | - for n in node.body |
83 | | - if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef)) |
84 | | - ] |
85 | | - assert "runtime_log_attrs" in methods |
86 | | - return |
87 | | - |
88 | | - def test_default_model_in_source(self): |
89 | | - src = PROVIDER_PATH.read_text() |
90 | | - assert "openai/text-embedding-3-small" in src |
91 | | - |
92 | | - |
93 | | -class TestFactoryRegistration: |
94 | | - """Verify the factory recognizes litellm as a provider.""" |
95 | | - |
96 | | - def test_litellm_branch_in_factory(self): |
97 | | - src = FACTORY_PATH.read_text() |
98 | | - assert 'provider_name == "litellm"' in src |
99 | | - |
100 | | - def test_imports_litellm_provider(self): |
101 | | - src = FACTORY_PATH.read_text() |
102 | | - assert "LiteLLMEmbeddingProvider" in src |
103 | | - |
104 | | - |
105 | | -class TestLiteLLMSDKInteraction: |
106 | | - """Test litellm SDK calls directly (no basic_memory deps needed).""" |
107 | | - |
108 | | - def test_aembedding_called_with_drop_params(self): |
109 | | - fake = types.ModuleType("litellm") |
110 | | - response = MagicMock() |
111 | | - response.data = [{"index": 0, "embedding": [0.1, 0.2]}] |
112 | | - fake.aembedding = AsyncMock(return_value=response) |
113 | | - sys.modules["litellm"] = fake |
114 | | - |
115 | | - try: |
116 | | - import asyncio |
117 | | - |
118 | | - async def run(): |
119 | | - await fake.aembedding( |
120 | | - model="openai/text-embedding-3-small", |
121 | | - input=["hello"], |
122 | | - drop_params=True, |
123 | | - ) |
124 | | - |
125 | | - asyncio.run(run()) |
126 | | - kwargs = fake.aembedding.call_args.kwargs |
127 | | - assert kwargs["drop_params"] is True |
128 | | - assert kwargs["model"] == "openai/text-embedding-3-small" |
129 | | - finally: |
130 | | - del sys.modules["litellm"] |
131 | | - |
132 | | - def test_aembedding_forwards_api_key(self): |
133 | | - fake = types.ModuleType("litellm") |
134 | | - response = MagicMock() |
135 | | - response.data = [{"index": 0, "embedding": [0.1]}] |
136 | | - fake.aembedding = AsyncMock(return_value=response) |
137 | | - sys.modules["litellm"] = fake |
138 | | - |
139 | | - try: |
140 | | - import asyncio |
141 | | - |
142 | | - async def run(): |
143 | | - await fake.aembedding( |
144 | | - model="openai/text-embedding-3-small", |
145 | | - input=["hello"], |
146 | | - api_key="sk-test", |
147 | | - drop_params=True, |
148 | | - ) |
149 | | - |
150 | | - asyncio.run(run()) |
151 | | - assert fake.aembedding.call_args.kwargs["api_key"] == "sk-test" |
152 | | - finally: |
153 | | - del sys.modules["litellm"] |
154 | | - |
155 | | - def test_aembedding_response_has_vectors(self): |
156 | | - fake = types.ModuleType("litellm") |
157 | | - response = MagicMock() |
158 | | - response.data = [ |
159 | | - {"index": 0, "embedding": [0.1, 0.2, 0.3]}, |
160 | | - {"index": 1, "embedding": [0.4, 0.5, 0.6]}, |
161 | | - ] |
162 | | - fake.aembedding = AsyncMock(return_value=response) |
163 | | - sys.modules["litellm"] = fake |
164 | | - |
165 | | - try: |
166 | | - import asyncio |
167 | | - |
168 | | - async def run(): |
169 | | - resp = await fake.aembedding( |
170 | | - model="openai/text-embedding-3-small", |
171 | | - input=["hello", "world"], |
172 | | - drop_params=True, |
173 | | - ) |
174 | | - return resp |
175 | | - |
176 | | - resp = asyncio.run(run()) |
177 | | - assert len(resp.data) == 2 |
178 | | - assert resp.data[0]["embedding"] == [0.1, 0.2, 0.3] |
179 | | - assert resp.data[1]["embedding"] == [0.4, 0.5, 0.6] |
180 | | - finally: |
181 | | - del sys.modules["litellm"] |
| 15 | +from basic_memory.repository.litellm_provider import LiteLLMEmbeddingProvider |
| 16 | +from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError |
| 17 | + |
| 18 | + |
| 19 | +def _make_embedding_response(inputs: list[str], dim: int = 3): |
| 20 | + """Build a fake litellm.aembedding response matching the real shape.""" |
| 21 | + data = [] |
| 22 | + for index, text in enumerate(inputs): |
| 23 | + base = float(len(text)) |
| 24 | + data.append({"index": index, "embedding": [base + float(d) for d in range(dim)]}) |
| 25 | + return SimpleNamespace(data=data) |
| 26 | + |
| 27 | + |
| 28 | +def _install_litellm_stub(monkeypatch, dim: int = 3): |
| 29 | + """Install a fake litellm module and return the mock aembedding callable.""" |
| 30 | + calls: list[dict] = [] |
| 31 | + |
| 32 | + async def _aembedding(**kwargs): |
| 33 | + calls.append(kwargs) |
| 34 | + return _make_embedding_response(kwargs["input"], dim) |
| 35 | + |
| 36 | + module = type(sys)("litellm") |
| 37 | + setattr(module, "aembedding", _aembedding) |
| 38 | + monkeypatch.setitem(sys.modules, "litellm", module) |
| 39 | + return calls |
| 40 | + |
| 41 | + |
| 42 | +@pytest.fixture(autouse=True) |
| 43 | +def _reset_cache(): |
| 44 | + reset_embedding_provider_cache() |
| 45 | + yield |
| 46 | + reset_embedding_provider_cache() |
| 47 | + |
| 48 | + |
| 49 | +@pytest.mark.asyncio |
| 50 | +async def test_litellm_provider_embed_query(monkeypatch): |
| 51 | + """embed_query should return a single vector through litellm.aembedding.""" |
| 52 | + _install_litellm_stub(monkeypatch) |
| 53 | + provider = LiteLLMEmbeddingProvider( |
| 54 | + model_name="openai/text-embedding-3-small", batch_size=2, dimensions=3 |
| 55 | + ) |
| 56 | + result = await provider.embed_query("hello world") |
| 57 | + assert len(result) == 3 |
| 58 | + assert all(isinstance(v, float) for v in result) |
| 59 | + |
| 60 | + |
| 61 | +@pytest.mark.asyncio |
| 62 | +async def test_litellm_provider_embed_documents(monkeypatch): |
| 63 | + """embed_documents should return vectors for each input text.""" |
| 64 | + _install_litellm_stub(monkeypatch) |
| 65 | + provider = LiteLLMEmbeddingProvider( |
| 66 | + model_name="openai/text-embedding-3-small", batch_size=2, dimensions=3 |
| 67 | + ) |
| 68 | + texts = ["first doc", "second doc", "third doc"] |
| 69 | + result = await provider.embed_documents(texts) |
| 70 | + assert len(result) == 3 |
| 71 | + assert all(len(v) == 3 for v in result) |
| 72 | + |
| 73 | + |
| 74 | +@pytest.mark.asyncio |
| 75 | +async def test_litellm_provider_empty_input(monkeypatch): |
| 76 | + """embed_documents with empty list should return empty list.""" |
| 77 | + _install_litellm_stub(monkeypatch) |
| 78 | + provider = LiteLLMEmbeddingProvider(dimensions=3) |
| 79 | + result = await provider.embed_documents([]) |
| 80 | + assert result == [] |
| 81 | + |
| 82 | + |
| 83 | +@pytest.mark.asyncio |
| 84 | +async def test_litellm_provider_batching(monkeypatch): |
| 85 | + """Provider should split inputs into batches of batch_size.""" |
| 86 | + calls = _install_litellm_stub(monkeypatch) |
| 87 | + provider = LiteLLMEmbeddingProvider( |
| 88 | + model_name="openai/text-embedding-3-small", batch_size=2, dimensions=3 |
| 89 | + ) |
| 90 | + texts = ["a", "b", "c", "d", "e"] |
| 91 | + result = await provider.embed_documents(texts) |
| 92 | + |
| 93 | + assert len(result) == 5 |
| 94 | + assert len(calls) == 3 # 2 + 2 + 1 |
| 95 | + |
| 96 | + |
| 97 | +@pytest.mark.asyncio |
| 98 | +async def test_litellm_provider_api_key_forwarded(monkeypatch): |
| 99 | + """api_key should be passed to litellm.aembedding when set.""" |
| 100 | + calls = _install_litellm_stub(monkeypatch) |
| 101 | + provider = LiteLLMEmbeddingProvider( |
| 102 | + model_name="openai/text-embedding-3-small", |
| 103 | + api_key="sk-test-key", |
| 104 | + dimensions=3, |
| 105 | + ) |
| 106 | + await provider.embed_query("test") |
| 107 | + assert calls[0]["api_key"] == "sk-test-key" |
| 108 | + |
| 109 | + |
| 110 | +@pytest.mark.asyncio |
| 111 | +async def test_litellm_provider_api_key_omitted_when_none(monkeypatch): |
| 112 | + """api_key should not appear in kwargs when not set.""" |
| 113 | + calls = _install_litellm_stub(monkeypatch) |
| 114 | + provider = LiteLLMEmbeddingProvider( |
| 115 | + model_name="openai/text-embedding-3-small", dimensions=3 |
| 116 | + ) |
| 117 | + await provider.embed_query("test") |
| 118 | + assert "api_key" not in calls[0] |
| 119 | + |
| 120 | + |
| 121 | +@pytest.mark.asyncio |
| 122 | +async def test_litellm_provider_drop_params_always_set(monkeypatch): |
| 123 | + """drop_params=True should always be in the call kwargs.""" |
| 124 | + calls = _install_litellm_stub(monkeypatch) |
| 125 | + provider = LiteLLMEmbeddingProvider(dimensions=3) |
| 126 | + await provider.embed_query("test") |
| 127 | + assert calls[0]["drop_params"] is True |
| 128 | + |
| 129 | + |
| 130 | +@pytest.mark.asyncio |
| 131 | +async def test_litellm_provider_dimension_mismatch_raises_error(monkeypatch): |
| 132 | + """Provider should fail fast when response dimensions differ from configured.""" |
| 133 | + _install_litellm_stub(monkeypatch, dim=3) |
| 134 | + provider = LiteLLMEmbeddingProvider(dimensions=5) |
| 135 | + with pytest.raises(RuntimeError, match="3-dimensional vectors"): |
| 136 | + await provider.embed_documents(["test text"]) |
| 137 | + |
| 138 | + |
| 139 | +@pytest.mark.asyncio |
| 140 | +async def test_litellm_provider_missing_dependency_raises_actionable_error(monkeypatch): |
| 141 | + """Missing litellm package should raise SemanticDependenciesMissingError.""" |
| 142 | + monkeypatch.delitem(sys.modules, "litellm", raising=False) |
| 143 | + original_import = builtins.__import__ |
| 144 | + |
| 145 | + def _raising_import(name, globals=None, locals=None, fromlist=(), level=0): |
| 146 | + if name == "litellm": |
| 147 | + raise ImportError("litellm not installed") |
| 148 | + return original_import(name, globals, locals, fromlist, level) |
| 149 | + |
| 150 | + monkeypatch.setattr(builtins, "__import__", _raising_import) |
| 151 | + |
| 152 | + provider = LiteLLMEmbeddingProvider(model_name="openai/text-embedding-3-small") |
| 153 | + with pytest.raises(SemanticDependenciesMissingError): |
| 154 | + await provider.embed_query("test") |
| 155 | + |
| 156 | + |
| 157 | +@pytest.mark.asyncio |
| 158 | +async def test_litellm_provider_output_ordering(monkeypatch): |
| 159 | + """Vectors should be returned in the same order as input texts.""" |
| 160 | + _install_litellm_stub(monkeypatch) |
| 161 | + provider = LiteLLMEmbeddingProvider(dimensions=3, batch_size=2) |
| 162 | + texts = ["short", "a longer text here"] |
| 163 | + result = await provider.embed_documents(texts) |
| 164 | + |
| 165 | + assert result[0][0] == float(len("short")) |
| 166 | + assert result[1][0] == float(len("a longer text here")) |
| 167 | + |
| 168 | + |
| 169 | +def test_factory_selects_litellm_provider(): |
| 170 | + """Factory should select LiteLLMEmbeddingProvider for litellm config.""" |
| 171 | + config = BasicMemoryConfig( |
| 172 | + env="test", |
| 173 | + projects={"test": "/tmp/basic-memory-test"}, |
| 174 | + default_project="test", |
| 175 | + semantic_search_enabled=True, |
| 176 | + semantic_embedding_provider="litellm", |
| 177 | + semantic_embedding_model="openai/text-embedding-3-small", |
| 178 | + ) |
| 179 | + provider = create_embedding_provider(config) |
| 180 | + assert isinstance(provider, LiteLLMEmbeddingProvider) |
| 181 | + assert provider.model_name == "openai/text-embedding-3-small" |
| 182 | + |
| 183 | + |
| 184 | +def test_factory_maps_default_model_for_litellm(): |
| 185 | + """Factory should remap bge-small-en-v1.5 default to openai/text-embedding-3-small.""" |
| 186 | + config = BasicMemoryConfig( |
| 187 | + env="test", |
| 188 | + projects={"test": "/tmp/basic-memory-test"}, |
| 189 | + default_project="test", |
| 190 | + semantic_search_enabled=True, |
| 191 | + semantic_embedding_provider="litellm", |
| 192 | + semantic_embedding_model="bge-small-en-v1.5", |
| 193 | + ) |
| 194 | + provider = create_embedding_provider(config) |
| 195 | + assert isinstance(provider, LiteLLMEmbeddingProvider) |
| 196 | + assert provider.model_name == "openai/text-embedding-3-small" |
| 197 | + |
| 198 | + |
| 199 | +def test_runtime_log_attrs(): |
| 200 | + """runtime_log_attrs should return batch_size and concurrency.""" |
| 201 | + provider = LiteLLMEmbeddingProvider(batch_size=32, request_concurrency=8) |
| 202 | + attrs = provider.runtime_log_attrs() |
| 203 | + assert attrs["provider_batch_size"] == 32 |
| 204 | + assert attrs["request_concurrency"] == 8 |
0 commit comments