Skip to content

Commit b106193

Browse files
committed
fix: rewrite tests to exercise provider, fix default model, update lockfile
1 parent 849f9f5 commit b106193

3 files changed

Lines changed: 809 additions & 178 deletions

File tree

src/basic_memory/repository/embedding_provider_factory.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -155,6 +155,8 @@ def create_embedding_provider(app_config: BasicMemoryConfig) -> EmbeddingProvide
155155
from basic_memory.repository.litellm_provider import LiteLLMEmbeddingProvider
156156

157157
model_name = app_config.semantic_embedding_model or "openai/text-embedding-3-small"
158+
if model_name == "bge-small-en-v1.5":
159+
model_name = "openai/text-embedding-3-small"
158160
provider = LiteLLMEmbeddingProvider(
159161
model_name=model_name,
160162
batch_size=app_config.semantic_embedding_batch_size,
Lines changed: 198 additions & 175 deletions
Original file line number | Diff line number | Diff line change
@@ -1,181 +1,204 @@
1-
"""Tests for LiteLLMEmbeddingProvider.
1+
"""Tests for LiteLLMEmbeddingProvider and factory litellm branch."""
22

3-
Uses AST parsing and direct SDK mocking to avoid importing the full
4-
basic_memory dependency chain (logfire, alembic, etc.).
5-
"""
6-
7-
import ast
3+
import asyncio
4+
import builtins
85
import sys
9-
import types
10-
from pathlib import Path
11-
from unittest.mock import AsyncMock, MagicMock
6+
from types import SimpleNamespace
127

138
import pytest
149

15-
PROVIDER_PATH = (
16-
Path(__file__).resolve().parents[2]
17-
/ "src"
18-
/ "basic_memory"
19-
/ "repository"
20-
/ "litellm_provider.py"
21-
)
22-
FACTORY_PATH = (
23-
Path(__file__).resolve().parents[2]
24-
/ "src"
25-
/ "basic_memory"
26-
/ "repository"
27-
/ "embedding_provider_factory.py"
10+
from basic_memory.config import BasicMemoryConfig
11+
from basic_memory.repository.embedding_provider_factory import (
12+
create_embedding_provider,
13+
reset_embedding_provider_cache,
2814
)
29-
30-
31-
class TestLiteLLMProviderStructure:
    """Static checks on the provider source file at ``PROVIDER_PATH``.

    The file is read as text and parsed with ``ast``; it is never imported.
    """

    def _source(self):
        """Return the provider file's text, decoded explicitly as UTF-8."""
        return PROVIDER_PATH.read_text(encoding="utf-8")

    def _parse(self):
        """Return the AST of the provider file."""
        return ast.parse(self._source())

    def _provider_methods(self):
        """Return the function nodes defined directly on LiteLLMEmbeddingProvider.

        Fails the calling test when no class of that name exists, so that no
        structural check can pass vacuously.
        """
        for node in ast.walk(self._parse()):
            if isinstance(node, ast.ClassDef) and node.name == "LiteLLMEmbeddingProvider":
                return [
                    item
                    for item in node.body
                    if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef))
                ]
        pytest.fail("LiteLLMEmbeddingProvider class not found")

    def test_file_exists(self):
        assert PROVIDER_PATH.exists()

    def test_has_litellm_embedding_provider_class(self):
        classes = [n.name for n in ast.walk(self._parse()) if isinstance(n, ast.ClassDef)]
        assert "LiteLLMEmbeddingProvider" in classes

    def test_has_embed_documents_method(self):
        methods = [m.name for m in self._provider_methods()]
        assert "embed_documents" in methods
        assert "embed_query" in methods

    def test_embed_documents_is_async(self):
        is_async = any(
            isinstance(m, ast.AsyncFunctionDef) and m.name == "embed_documents"
            for m in self._provider_methods()
        )
        if not is_async:
            pytest.fail("embed_documents is not async")

    def test_uses_drop_params_true(self):
        # Only checks that the identifier appears somewhere in the file.
        assert "drop_params" in self._source()

    def test_uses_litellm_aembedding(self):
        assert "aembedding" in self._source()

    def test_has_runtime_log_attrs(self):
        # Previously this test passed silently when the class was missing;
        # _provider_methods() now fails in that case.
        methods = [m.name for m in self._provider_methods()]
        assert "runtime_log_attrs" in methods

    def test_default_model_in_source(self):
        assert "openai/text-embedding-3-small" in self._source()
91-
92-
93-
class TestFactoryRegistration:
    """Source-level checks that the factory file refers to the litellm provider."""

    def test_litellm_branch_in_factory(self):
        # The factory text must contain the literal provider-name comparison.
        factory_source = FACTORY_PATH.read_text()
        expected_condition = 'provider_name == "litellm"'
        assert expected_condition in factory_source

    def test_imports_litellm_provider(self):
        # The provider class name must be mentioned in the factory text.
        factory_source = FACTORY_PATH.read_text()
        assert "LiteLLMEmbeddingProvider" in factory_source
103-
104-
105-
class TestLiteLLMSDKInteraction:
    """Drive a stubbed ``litellm.aembedding`` directly (no basic_memory deps needed).

    Each test installs a fake ``litellm`` module in ``sys.modules``, awaits its
    ``aembedding`` mock once, and inspects the recorded call or the response.
    """

    def _call_fake_aembedding(self, data, **call_kwargs):
        """Await a fake ``litellm.aembedding`` once with ``call_kwargs``.

        Args:
            data: Value assigned to the mocked response's ``data`` attribute.
            **call_kwargs: Keyword arguments forwarded to ``aembedding``.

        Returns:
            A ``(fake_module, response)`` tuple.

        Whatever ``sys.modules["litellm"]`` held before the call is put back
        afterwards (or the key is removed if it was absent), so an already
        imported real ``litellm`` is not discarded.
        """
        import asyncio

        fake = types.ModuleType("litellm")
        response = MagicMock()
        response.data = data
        fake.aembedding = AsyncMock(return_value=response)

        missing = object()
        previous = sys.modules.get("litellm", missing)
        sys.modules["litellm"] = fake
        try:

            async def run():
                return await fake.aembedding(**call_kwargs)

            result = asyncio.run(run())
        finally:
            if previous is missing:
                sys.modules.pop("litellm", None)
            else:
                sys.modules["litellm"] = previous
        return fake, result

    def test_aembedding_called_with_drop_params(self):
        fake, _ = self._call_fake_aembedding(
            [{"index": 0, "embedding": [0.1, 0.2]}],
            model="openai/text-embedding-3-small",
            input=["hello"],
            drop_params=True,
        )
        kwargs = fake.aembedding.call_args.kwargs
        assert kwargs["drop_params"] is True
        assert kwargs["model"] == "openai/text-embedding-3-small"

    def test_aembedding_forwards_api_key(self):
        fake, _ = self._call_fake_aembedding(
            [{"index": 0, "embedding": [0.1]}],
            model="openai/text-embedding-3-small",
            input=["hello"],
            api_key="sk-test",
            drop_params=True,
        )
        assert fake.aembedding.call_args.kwargs["api_key"] == "sk-test"

    def test_aembedding_response_has_vectors(self):
        _, resp = self._call_fake_aembedding(
            [
                {"index": 0, "embedding": [0.1, 0.2, 0.3]},
                {"index": 1, "embedding": [0.4, 0.5, 0.6]},
            ],
            model="openai/text-embedding-3-small",
            input=["hello", "world"],
            drop_params=True,
        )
        assert len(resp.data) == 2
        assert resp.data[0]["embedding"] == [0.1, 0.2, 0.3]
        assert resp.data[1]["embedding"] == [0.4, 0.5, 0.6]
15+
from basic_memory.repository.litellm_provider import LiteLLMEmbeddingProvider
16+
from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError
17+
18+
19+
def _make_embedding_response(inputs: list[str], dim: int = 3):
20+
"""Build a fake litellm.aembedding response matching the real shape."""
21+
data = []
22+
for index, text in enumerate(inputs):
23+
base = float(len(text))
24+
data.append({"index": index, "embedding": [base + float(d) for d in range(dim)]})
25+
return SimpleNamespace(data=data)
26+
27+
28+
def _install_litellm_stub(monkeypatch, dim: int = 3):
    """Register a stand-in ``litellm`` module in ``sys.modules``.

    The stand-in exposes one coroutine, ``aembedding``, which records the
    keyword arguments of every call and answers with
    ``_make_embedding_response(kwargs["input"], dim)``.

    Args:
        monkeypatch: pytest's ``monkeypatch`` fixture, used to set the
            ``sys.modules`` entry.
        dim: Size of the vectors the stand-in returns.

    Returns:
        The list that collects one kwargs dict per ``aembedding`` call.
    """
    recorded: list[dict] = []

    async def _aembedding(**kwargs):
        recorded.append(kwargs)
        texts = kwargs["input"]
        return _make_embedding_response(texts, dim)

    # type(sys) is the module type, so this creates an empty module object.
    stub = type(sys)("litellm")
    setattr(stub, "aembedding", _aembedding)
    monkeypatch.setitem(sys.modules, "litellm", stub)
    return recorded
40+
41+
42+
@pytest.fixture(autouse=True)
def _reset_cache():
    """Call ``reset_embedding_provider_cache`` before and after every test."""
    # Before the test body runs.
    reset_embedding_provider_cache()
    yield
    # After the test body has finished.
    reset_embedding_provider_cache()
47+
48+
49+
@pytest.mark.asyncio
async def test_litellm_provider_embed_query(monkeypatch):
    """A single query comes back as one vector of three floats."""
    _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(
        model_name="openai/text-embedding-3-small",
        batch_size=2,
        dimensions=3,
    )

    vector = await provider.embed_query("hello world")

    assert len(vector) == 3
    for component in vector:
        assert isinstance(component, float)
59+
60+
61+
@pytest.mark.asyncio
async def test_litellm_provider_embed_documents(monkeypatch):
    """Three documents in, three 3-component vectors out."""
    _install_litellm_stub(monkeypatch)
    documents = ["first doc", "second doc", "third doc"]
    provider = LiteLLMEmbeddingProvider(
        model_name="openai/text-embedding-3-small",
        batch_size=2,
        dimensions=3,
    )

    vectors = await provider.embed_documents(documents)

    assert len(vectors) == 3
    for vector in vectors:
        assert len(vector) == 3
72+
73+
74+
@pytest.mark.asyncio
async def test_litellm_provider_empty_input(monkeypatch):
    """An empty document list yields an empty result list."""
    _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(dimensions=3)
    no_documents = []

    vectors = await provider.embed_documents(no_documents)

    assert vectors == []
81+
82+
83+
@pytest.mark.asyncio
async def test_litellm_provider_batching(monkeypatch):
    """Five texts with batch_size=2 produce three stub calls and five vectors."""
    recorded_calls = _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(
        model_name="openai/text-embedding-3-small",
        batch_size=2,
        dimensions=3,
    )
    five_texts = list("abcde")

    vectors = await provider.embed_documents(five_texts)

    assert len(vectors) == 5
    # Five texts at two per batch: 2 + 2 + 1.
    assert len(recorded_calls) == 3
95+
96+
97+
@pytest.mark.asyncio
async def test_litellm_provider_api_key_forwarded(monkeypatch):
    """A configured api_key shows up in the kwargs of the stub call."""
    recorded_calls = _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(
        model_name="openai/text-embedding-3-small",
        api_key="sk-test-key",
        dimensions=3,
    )

    await provider.embed_query("test")

    first_call = recorded_calls[0]
    assert first_call["api_key"] == "sk-test-key"
108+
109+
110+
@pytest.mark.asyncio
async def test_litellm_provider_api_key_omitted_when_none(monkeypatch):
    """Without a configured api_key, the stub call carries no api_key kwarg."""
    recorded_calls = _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(
        model_name="openai/text-embedding-3-small",
        dimensions=3,
    )

    await provider.embed_query("test")

    first_call = recorded_calls[0]
    assert "api_key" not in first_call
119+
120+
121+
@pytest.mark.asyncio
async def test_litellm_provider_drop_params_always_set(monkeypatch):
    """The stub call is made with drop_params set to True."""
    recorded_calls = _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(dimensions=3)

    await provider.embed_query("test")

    first_call = recorded_calls[0]
    assert first_call["drop_params"] is True
128+
129+
130+
@pytest.mark.asyncio
async def test_litellm_provider_dimension_mismatch_raises_error(monkeypatch):
    """3-component stub vectors against dimensions=5 must raise RuntimeError."""
    _install_litellm_stub(monkeypatch, dim=3)
    provider = LiteLLMEmbeddingProvider(dimensions=5)

    expected_failure = pytest.raises(RuntimeError, match="3-dimensional vectors")
    with expected_failure:
        await provider.embed_documents(["test text"])
137+
138+
139+
@pytest.mark.asyncio
async def test_litellm_provider_missing_dependency_raises_actionable_error(monkeypatch):
    """With ``import litellm`` failing, embed_query raises SemanticDependenciesMissingError."""
    # Remove any cached module so an import cannot be served from sys.modules.
    monkeypatch.delitem(sys.modules, "litellm", raising=False)
    real_import = builtins.__import__

    def _raising_import(name, globals=None, locals=None, fromlist=(), level=0):
        # Every import other than "litellm" goes to the real implementation.
        if name != "litellm":
            return real_import(name, globals, locals, fromlist, level)
        raise ImportError("litellm not installed")

    monkeypatch.setattr(builtins, "__import__", _raising_import)

    provider = LiteLLMEmbeddingProvider(model_name="openai/text-embedding-3-small")
    with pytest.raises(SemanticDependenciesMissingError):
        await provider.embed_query("test")
155+
156+
157+
@pytest.mark.asyncio
async def test_litellm_provider_output_ordering(monkeypatch):
    """Result vectors line up positionally with the input texts."""
    _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(dimensions=3, batch_size=2)
    short_text, long_text = "short", "a longer text here"

    vectors = await provider.embed_documents([short_text, long_text])

    # The stub puts float(len(text)) in the first component of each vector,
    # which identifies the text a vector was produced for.
    # NOTE(review): both texts fit in one batch of two, so ordering across
    # several batches is not covered here - confirm whether that is intended.
    assert vectors[0][0] == float(len(short_text))
    assert vectors[1][0] == float(len(long_text))
167+
168+
169+
def test_factory_selects_litellm_provider():
    """A litellm provider config produces a LiteLLMEmbeddingProvider with the given model."""
    settings = {
        "env": "test",
        "projects": {"test": "/tmp/basic-memory-test"},
        "default_project": "test",
        "semantic_search_enabled": True,
        "semantic_embedding_provider": "litellm",
        "semantic_embedding_model": "openai/text-embedding-3-small",
    }
    config = BasicMemoryConfig(**settings)

    provider = create_embedding_provider(config)

    assert isinstance(provider, LiteLLMEmbeddingProvider)
    assert provider.model_name == "openai/text-embedding-3-small"
182+
183+
184+
def test_factory_maps_default_model_for_litellm():
    """A litellm config naming bge-small-en-v1.5 ends up on openai/text-embedding-3-small."""
    settings = {
        "env": "test",
        "projects": {"test": "/tmp/basic-memory-test"},
        "default_project": "test",
        "semantic_search_enabled": True,
        "semantic_embedding_provider": "litellm",
        "semantic_embedding_model": "bge-small-en-v1.5",
    }
    config = BasicMemoryConfig(**settings)

    provider = create_embedding_provider(config)

    assert isinstance(provider, LiteLLMEmbeddingProvider)
    # The factory swaps the bge model name for the OpenAI one.
    assert provider.model_name == "openai/text-embedding-3-small"
197+
198+
199+
def test_runtime_log_attrs():
    """runtime_log_attrs reports the configured batch size and request concurrency."""
    provider = LiteLLMEmbeddingProvider(batch_size=32, request_concurrency=8)

    reported = provider.runtime_log_attrs()

    batch_size = reported["provider_batch_size"]
    assert batch_size == 32
    concurrency = reported["request_concurrency"]
    assert concurrency == 8

0 commit comments

Comments
 (0)