Skip to content

Commit c029eb3

Browse files
committed
feat: add LiteLLM as embedding provider
1 parent df5e8d8 commit c029eb3

4 files changed

Lines changed: 308 additions & 0 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dependencies = [
4747
"fastembed>=0.7.4",
4848
"sqlite-vec>=0.1.6",
4949
"openai>=1.100.2",
50+
"litellm>=1.60.0,<2.0.0",
5051
"logfire>=4.19.0",
5152
"psutil>=5.9.0",
5253
]

src/basic_memory/repository/embedding_provider_factory.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,16 @@ def create_embedding_provider(app_config: BasicMemoryConfig) -> EmbeddingProvide
151151
request_concurrency=app_config.semantic_embedding_request_concurrency,
152152
**extra_kwargs,
153153
)
154+
elif provider_name == "litellm":
155+
from basic_memory.repository.litellm_provider import LiteLLMEmbeddingProvider
156+
157+
model_name = app_config.semantic_embedding_model or "openai/text-embedding-3-small"
158+
provider = LiteLLMEmbeddingProvider(
159+
model_name=model_name,
160+
batch_size=app_config.semantic_embedding_batch_size,
161+
request_concurrency=app_config.semantic_embedding_request_concurrency,
162+
**extra_kwargs,
163+
)
154164
else:
155165
raise ValueError(f"Unsupported semantic embedding provider: {provider_name}")
156166

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""LiteLLM-based embedding provider for semantic indexing.
2+
3+
Routes embedding requests to 100+ providers (OpenAI, Anthropic, Google, Azure,
4+
Bedrock, Cohere, etc.) via the litellm SDK. No proxy server needed.
5+
6+
Model strings use the ``provider/model`` format, e.g.
7+
``openai/text-embedding-3-small``, ``cohere/embed-english-v3.0``,
8+
``azure/my-embedding-deployment``.
9+
10+
See https://docs.litellm.ai/docs/embedding/supported_embedding for all
11+
supported embedding models.
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import asyncio
17+
from typing import Any
18+
19+
from basic_memory.repository.embedding_provider import EmbeddingProvider
20+
from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError
21+
22+
23+
class LiteLLMEmbeddingProvider(EmbeddingProvider):
24+
"""Embedding provider backed by the litellm SDK."""
25+
26+
def __init__(
27+
self,
28+
model_name: str = "openai/text-embedding-3-small",
29+
*,
30+
batch_size: int = 64,
31+
request_concurrency: int = 4,
32+
dimensions: int = 1536,
33+
api_key: str | None = None,
34+
timeout: float = 30.0,
35+
) -> None:
36+
self.model_name = model_name
37+
self.dimensions = dimensions
38+
self.batch_size = batch_size
39+
self.request_concurrency = request_concurrency
40+
self._api_key = api_key
41+
self._timeout = timeout
42+
43+
def runtime_log_attrs(self) -> dict[str, int]:
44+
"""Return provider-specific runtime settings suitable for startup logs."""
45+
return {
46+
"provider_batch_size": self.batch_size,
47+
"request_concurrency": self.request_concurrency,
48+
}
49+
50+
async def embed_documents(self, texts: list[str]) -> list[list[float]]:
51+
if not texts:
52+
return []
53+
54+
try:
55+
import litellm
56+
except ImportError as exc:
57+
raise SemanticDependenciesMissingError(
58+
"litellm dependency is missing. Install with: pip install litellm"
59+
) from exc
60+
61+
batches = [
62+
texts[start : start + self.batch_size]
63+
for start in range(0, len(texts), self.batch_size)
64+
]
65+
batch_vectors: list[list[list[float]] | None] = [None] * len(batches)
66+
semaphore = asyncio.Semaphore(self.request_concurrency)
67+
68+
async def embed_batch(batch_index: int, batch: list[str]) -> None:
69+
async with semaphore:
70+
params: dict[str, Any] = {
71+
"model": self.model_name,
72+
"input": batch,
73+
"drop_params": True,
74+
"timeout": self._timeout,
75+
}
76+
if self._api_key:
77+
params["api_key"] = self._api_key
78+
79+
response = await litellm.aembedding(**params)
80+
81+
vectors_by_index: dict[int, list[float]] = {}
82+
for item in response.data:
83+
response_index = int(item["index"])
84+
vectors_by_index[response_index] = [float(v) for v in item["embedding"]]
85+
86+
ordered_vectors: list[list[float]] = []
87+
for index in range(len(batch)):
88+
vector = vectors_by_index.get(index)
89+
if vector is None:
90+
raise RuntimeError(
91+
"LiteLLM embedding response is missing expected vector index."
92+
)
93+
ordered_vectors.append(vector)
94+
95+
batch_vectors[batch_index] = ordered_vectors
96+
97+
await asyncio.gather(
98+
*(embed_batch(batch_index, batch) for batch_index, batch in enumerate(batches))
99+
)
100+
101+
all_vectors: list[list[float]] = []
102+
for vectors in batch_vectors:
103+
if vectors is None:
104+
raise RuntimeError("LiteLLM embedding batch did not produce vectors.")
105+
all_vectors.extend(vectors)
106+
107+
if all_vectors and len(all_vectors[0]) != self.dimensions:
108+
raise RuntimeError(
109+
f"Embedding model returned {len(all_vectors[0])}-dimensional vectors "
110+
f"but provider was configured for {self.dimensions} dimensions."
111+
)
112+
return all_vectors
113+
114+
async def embed_query(self, text: str) -> list[float]:
115+
vectors = await self.embed_documents([text])
116+
return vectors[0] if vectors else [0.0] * self.dimensions
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
"""Tests for LiteLLMEmbeddingProvider.
2+
3+
Uses AST parsing and direct SDK mocking to avoid importing the full
4+
basic_memory dependency chain (logfire, alembic, etc.).
5+
"""
6+
7+
import ast
8+
import sys
9+
import types
10+
from pathlib import Path
11+
from unittest.mock import AsyncMock, MagicMock
12+
13+
import pytest
14+
15+
PROVIDER_PATH = (
16+
Path(__file__).resolve().parents[2]
17+
/ "src"
18+
/ "basic_memory"
19+
/ "repository"
20+
/ "litellm_provider.py"
21+
)
22+
FACTORY_PATH = (
23+
Path(__file__).resolve().parents[2]
24+
/ "src"
25+
/ "basic_memory"
26+
/ "repository"
27+
/ "embedding_provider_factory.py"
28+
)
29+
30+
31+
class TestLiteLLMProviderStructure:
32+
"""Verify the provider file has the correct structure."""
33+
34+
def _parse(self):
35+
return ast.parse(PROVIDER_PATH.read_text())
36+
37+
def test_file_exists(self):
38+
assert PROVIDER_PATH.exists()
39+
40+
def test_has_litellm_embedding_provider_class(self):
41+
tree = self._parse()
42+
classes = [n.name for n in ast.walk(tree) if isinstance(n, ast.ClassDef)]
43+
assert "LiteLLMEmbeddingProvider" in classes
44+
45+
def test_has_embed_documents_method(self):
46+
tree = self._parse()
47+
for node in ast.walk(tree):
48+
if isinstance(node, ast.ClassDef) and node.name == "LiteLLMEmbeddingProvider":
49+
methods = [
50+
n.name
51+
for n in node.body
52+
if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
53+
]
54+
assert "embed_documents" in methods
55+
assert "embed_query" in methods
56+
return
57+
pytest.fail("LiteLLMEmbeddingProvider class not found")
58+
59+
def test_embed_documents_is_async(self):
60+
tree = self._parse()
61+
for node in ast.walk(tree):
62+
if isinstance(node, ast.ClassDef) and node.name == "LiteLLMEmbeddingProvider":
63+
for item in node.body:
64+
if isinstance(item, ast.AsyncFunctionDef) and item.name == "embed_documents":
65+
return
66+
pytest.fail("embed_documents is not async")
67+
68+
def test_uses_drop_params_true(self):
69+
src = PROVIDER_PATH.read_text()
70+
assert "drop_params" in src
71+
72+
def test_uses_litellm_aembedding(self):
73+
src = PROVIDER_PATH.read_text()
74+
assert "aembedding" in src
75+
76+
def test_has_runtime_log_attrs(self):
77+
tree = self._parse()
78+
for node in ast.walk(tree):
79+
if isinstance(node, ast.ClassDef) and node.name == "LiteLLMEmbeddingProvider":
80+
methods = [
81+
n.name
82+
for n in node.body
83+
if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))
84+
]
85+
assert "runtime_log_attrs" in methods
86+
return
87+
88+
def test_default_model_in_source(self):
89+
src = PROVIDER_PATH.read_text()
90+
assert "openai/text-embedding-3-small" in src
91+
92+
93+
class TestFactoryRegistration:
94+
"""Verify the factory recognizes litellm as a provider."""
95+
96+
def test_litellm_branch_in_factory(self):
97+
src = FACTORY_PATH.read_text()
98+
assert 'provider_name == "litellm"' in src
99+
100+
def test_imports_litellm_provider(self):
101+
src = FACTORY_PATH.read_text()
102+
assert "LiteLLMEmbeddingProvider" in src
103+
104+
105+
class TestLiteLLMSDKInteraction:
106+
"""Test litellm SDK calls directly (no basic_memory deps needed)."""
107+
108+
def test_aembedding_called_with_drop_params(self):
109+
fake = types.ModuleType("litellm")
110+
response = MagicMock()
111+
response.data = [{"index": 0, "embedding": [0.1, 0.2]}]
112+
fake.aembedding = AsyncMock(return_value=response)
113+
sys.modules["litellm"] = fake
114+
115+
try:
116+
import asyncio
117+
118+
async def run():
119+
await fake.aembedding(
120+
model="openai/text-embedding-3-small",
121+
input=["hello"],
122+
drop_params=True,
123+
)
124+
125+
asyncio.run(run())
126+
kwargs = fake.aembedding.call_args.kwargs
127+
assert kwargs["drop_params"] is True
128+
assert kwargs["model"] == "openai/text-embedding-3-small"
129+
finally:
130+
del sys.modules["litellm"]
131+
132+
def test_aembedding_forwards_api_key(self):
133+
fake = types.ModuleType("litellm")
134+
response = MagicMock()
135+
response.data = [{"index": 0, "embedding": [0.1]}]
136+
fake.aembedding = AsyncMock(return_value=response)
137+
sys.modules["litellm"] = fake
138+
139+
try:
140+
import asyncio
141+
142+
async def run():
143+
await fake.aembedding(
144+
model="openai/text-embedding-3-small",
145+
input=["hello"],
146+
api_key="sk-test",
147+
drop_params=True,
148+
)
149+
150+
asyncio.run(run())
151+
assert fake.aembedding.call_args.kwargs["api_key"] == "sk-test"
152+
finally:
153+
del sys.modules["litellm"]
154+
155+
def test_aembedding_response_has_vectors(self):
156+
fake = types.ModuleType("litellm")
157+
response = MagicMock()
158+
response.data = [
159+
{"index": 0, "embedding": [0.1, 0.2, 0.3]},
160+
{"index": 1, "embedding": [0.4, 0.5, 0.6]},
161+
]
162+
fake.aembedding = AsyncMock(return_value=response)
163+
sys.modules["litellm"] = fake
164+
165+
try:
166+
import asyncio
167+
168+
async def run():
169+
resp = await fake.aembedding(
170+
model="openai/text-embedding-3-small",
171+
input=["hello", "world"],
172+
drop_params=True,
173+
)
174+
return resp
175+
176+
resp = asyncio.run(run())
177+
assert len(resp.data) == 2
178+
assert resp.data[0]["embedding"] == [0.1, 0.2, 0.3]
179+
assert resp.data[1]["embedding"] == [0.4, 0.5, 0.6]
180+
finally:
181+
del sys.modules["litellm"]

0 commit comments

Comments
 (0)