Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dependencies = [
"fastembed>=0.7.4",
"sqlite-vec>=0.1.6",
"openai>=1.100.2",
"litellm>=1.60.0,<2.0.0",
"logfire>=4.19.0",
"psutil>=5.9.0",
]
Expand Down
12 changes: 12 additions & 0 deletions src/basic_memory/repository/embedding_provider_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,18 @@ def create_embedding_provider(app_config: BasicMemoryConfig) -> EmbeddingProvide
request_concurrency=app_config.semantic_embedding_request_concurrency,
**extra_kwargs,
)
elif provider_name == "litellm":
from basic_memory.repository.litellm_provider import LiteLLMEmbeddingProvider

model_name = app_config.semantic_embedding_model or "openai/text-embedding-3-small"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Map the built-in default model for LiteLLM

When users switch only `semantic_embedding_provider` to `litellm`, `BasicMemoryConfig` still supplies the non-empty default model `bge-small-en-v1.5`, so this `or` fallback never selects the LiteLLM provider default. The factory then instantiates `LiteLLMEmbeddingProvider(model_name="bge-small-en-v1.5")` instead of a LiteLLM-routable model such as `openai/text-embedding-3-small`, making the new provider fail for the documented minimal configuration; mirror the OpenAI branch's remapping of the FastEmbed default or otherwise treat it as unset.

Useful? React with 👍 / 👎.

if model_name == "bge-small-en-v1.5":
model_name = "openai/text-embedding-3-small"
provider = LiteLLMEmbeddingProvider(
model_name=model_name,
batch_size=app_config.semantic_embedding_batch_size,
request_concurrency=app_config.semantic_embedding_request_concurrency,
**extra_kwargs,
)
else:
raise ValueError(f"Unsupported semantic embedding provider: {provider_name}")

Expand Down
116 changes: 116 additions & 0 deletions src/basic_memory/repository/litellm_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""LiteLLM-based embedding provider for semantic indexing.

Routes embedding requests through the litellm SDK to whichever embedding
backend the model string selects (OpenAI, Azure, Bedrock, Cohere, Google,
Voyage, etc.). No proxy server needed.

Model strings use the ``provider/model`` format, e.g.
``openai/text-embedding-3-small``, ``cohere/embed-english-v3.0``,
``azure/my-embedding-deployment``.

See https://docs.litellm.ai/docs/embedding/supported_embedding for all
supported embedding models.
"""

from __future__ import annotations

import asyncio
from typing import Any

from basic_memory.repository.embedding_provider import EmbeddingProvider
from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError


class LiteLLMEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by the litellm SDK.

    Input texts are split into batches of ``batch_size`` and sent through
    ``litellm.aembedding`` with at most ``request_concurrency`` requests in
    flight. Vectors are returned in input order and every vector is checked
    against the configured ``dimensions``.
    """

    def __init__(
        self,
        model_name: str = "openai/text-embedding-3-small",
        *,
        batch_size: int = 64,
        request_concurrency: int = 4,
        dimensions: int = 1536,
        api_key: str | None = None,
        timeout: float = 30.0,
    ) -> None:
        """Store provider settings; no request is made at construction time.

        Args:
            model_name: Model string passed to litellm as ``model``.
            batch_size: Maximum number of texts sent per embedding request.
            request_concurrency: Maximum number of requests in flight at once.
            dimensions: Vector length the responses are validated against.
            api_key: Optional key forwarded to litellm; omitted when falsy.
            timeout: Per-request timeout forwarded to litellm.

        Raises:
            ValueError: If ``batch_size`` or ``request_concurrency`` is < 1.
        """
        # A non-positive batch size either breaks range() or silently yields
        # no batches (and therefore no vectors); a zero-permit semaphore would
        # block every request forever. Fail loudly at construction instead.
        if batch_size < 1:
            raise ValueError(f"batch_size must be at least 1, got {batch_size}.")
        if request_concurrency < 1:
            raise ValueError(
                f"request_concurrency must be at least 1, got {request_concurrency}."
            )

        self.model_name = model_name
        self.dimensions = dimensions
        self.batch_size = batch_size
        self.request_concurrency = request_concurrency
        self._api_key = api_key
        self._timeout = timeout

    def runtime_log_attrs(self) -> dict[str, int]:
        """Return provider-specific runtime settings suitable for startup logs."""
        return {
            "provider_batch_size": self.batch_size,
            "request_concurrency": self.request_concurrency,
        }

    async def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` and return one vector per text, in input order.

        Returns an empty list for empty input without importing litellm.

        Raises:
            SemanticDependenciesMissingError: If litellm cannot be imported.
            RuntimeError: If a response lacks a vector for some input index,
                or any vector's length differs from ``self.dimensions``.
        """
        if not texts:
            return []

        # Imported lazily so the module can be loaded without litellm installed.
        try:
            import litellm
        except ImportError as exc:
            raise SemanticDependenciesMissingError(
                "litellm dependency is missing. Install with: pip install litellm"
            ) from exc

        batches = [
            texts[start : start + self.batch_size]
            for start in range(0, len(texts), self.batch_size)
        ]
        semaphore = asyncio.Semaphore(self.request_concurrency)

        async def embed_batch(batch: list[str]) -> list[list[float]]:
            async with semaphore:
                params: dict[str, Any] = {
                    "model": self.model_name,
                    "input": batch,
                    "drop_params": True,
                    "timeout": self._timeout,
                }
                if self._api_key:
                    params["api_key"] = self._api_key

                response = await litellm.aembedding(**params)

                # Key vectors by the index reported in the response rather
                # than trusting the order of response.data.
                vectors_by_index: dict[int, list[float]] = {
                    int(item["index"]): [float(v) for v in item["embedding"]]
                    for item in response.data
                }

                ordered_vectors: list[list[float]] = []
                for index in range(len(batch)):
                    vector = vectors_by_index.get(index)
                    if vector is None:
                        raise RuntimeError(
                            "LiteLLM embedding response is missing expected vector index."
                        )
                    ordered_vectors.append(vector)
                return ordered_vectors

        # gather() returns results in the order the awaitables were passed,
        # so batch order (and therefore input order) is preserved.
        batch_vectors = await asyncio.gather(*(embed_batch(batch) for batch in batches))

        all_vectors: list[list[float]] = [
            vector for vectors in batch_vectors for vector in vectors
        ]

        # Validate every vector, not just the first, so a ragged response
        # cannot slip through.
        for vector in all_vectors:
            if len(vector) != self.dimensions:
                raise RuntimeError(
                    f"Embedding model returned {len(vector)}-dimensional vectors "
                    f"but provider was configured for {self.dimensions} dimensions."
                )
        return all_vectors

    async def embed_query(self, text: str) -> list[float]:
        """Embed a single text via ``embed_documents`` and return its vector.

        Falls back to a zero vector of ``self.dimensions`` length if no
        vector is produced.
        """
        vectors = await self.embed_documents([text])
        return vectors[0] if vectors else [0.0] * self.dimensions
204 changes: 204 additions & 0 deletions tests/repository/test_litellm_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
"""Tests for LiteLLMEmbeddingProvider and factory litellm branch."""

import asyncio
import builtins
import sys
from types import SimpleNamespace

import pytest

from basic_memory.config import BasicMemoryConfig
from basic_memory.repository.embedding_provider_factory import (
create_embedding_provider,
reset_embedding_provider_cache,
)
from basic_memory.repository.litellm_provider import LiteLLMEmbeddingProvider
from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError


def _make_embedding_response(inputs: list[str], dim: int = 3):
"""Build a fake litellm.aembedding response matching the real shape."""
data = []
for index, text in enumerate(inputs):
base = float(len(text))
data.append({"index": index, "embedding": [base + float(d) for d in range(dim)]})
return SimpleNamespace(data=data)


def _install_litellm_stub(monkeypatch, dim: int = 3):
    """Register a fake ``litellm`` module and return its call log.

    The fake module exposes an async ``aembedding`` that records the keyword
    arguments of every call in the returned list and answers with
    ``_make_embedding_response`` for the given ``input``.
    """
    recorded: list[dict] = []

    async def _aembedding(**kwargs):
        recorded.append(kwargs)
        return _make_embedding_response(kwargs["input"], dim)

    # type(sys) is the module type; build an empty module named "litellm".
    stub = type(sys)("litellm")
    stub.aembedding = _aembedding
    monkeypatch.setitem(sys.modules, "litellm", stub)
    return recorded


@pytest.fixture(autouse=True)
def _reset_cache():
    """Reset the embedding provider cache before and after each test."""
    reset_embedding_provider_cache()
    yield
    reset_embedding_provider_cache()


@pytest.mark.asyncio
async def test_litellm_provider_embed_query(monkeypatch):
    """A single query yields one float vector of the configured width."""
    _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(
        model_name="openai/text-embedding-3-small",
        batch_size=2,
        dimensions=3,
    )

    vector = await provider.embed_query("hello world")

    assert len(vector) == 3
    for component in vector:
        assert isinstance(component, float)


@pytest.mark.asyncio
async def test_litellm_provider_embed_documents(monkeypatch):
    """Every input text gets its own vector of the configured width."""
    _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(
        model_name="openai/text-embedding-3-small",
        batch_size=2,
        dimensions=3,
    )

    vectors = await provider.embed_documents(["first doc", "second doc", "third doc"])

    assert len(vectors) == 3
    for vector in vectors:
        assert len(vector) == 3


@pytest.mark.asyncio
async def test_litellm_provider_empty_input(monkeypatch):
    """An empty list of texts produces an empty list of vectors."""
    _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(dimensions=3)

    vectors = await provider.embed_documents([])

    assert vectors == []


@pytest.mark.asyncio
async def test_litellm_provider_batching(monkeypatch):
    """Five texts with batch_size=2 are sent as three requests."""
    recorded_calls = _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(
        model_name="openai/text-embedding-3-small",
        batch_size=2,
        dimensions=3,
    )

    vectors = await provider.embed_documents(["a", "b", "c", "d", "e"])

    assert len(vectors) == 5
    # Batches of 2 + 2 + 1 texts.
    assert len(recorded_calls) == 3


@pytest.mark.asyncio
async def test_litellm_provider_api_key_forwarded(monkeypatch):
    """A configured api_key is passed through to litellm.aembedding."""
    recorded_calls = _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(
        model_name="openai/text-embedding-3-small",
        api_key="sk-test-key",
        dimensions=3,
    )

    await provider.embed_query("test")

    first_call = recorded_calls[0]
    assert first_call["api_key"] == "sk-test-key"


@pytest.mark.asyncio
async def test_litellm_provider_api_key_omitted_when_none(monkeypatch):
    """Without a configured api_key, the request kwargs carry no api_key."""
    recorded_calls = _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(
        model_name="openai/text-embedding-3-small",
        dimensions=3,
    )

    await provider.embed_query("test")

    first_call = recorded_calls[0]
    assert "api_key" not in first_call


@pytest.mark.asyncio
async def test_litellm_provider_drop_params_always_set(monkeypatch):
    """Every request is sent with drop_params=True."""
    recorded_calls = _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(dimensions=3)

    await provider.embed_query("test")

    first_call = recorded_calls[0]
    assert first_call["drop_params"] is True


@pytest.mark.asyncio
async def test_litellm_provider_dimension_mismatch_raises_error(monkeypatch):
    """A response whose vector width differs from the configured one is rejected."""
    # Stub answers with 3-wide vectors while the provider expects 5.
    _install_litellm_stub(monkeypatch, dim=3)
    provider = LiteLLMEmbeddingProvider(dimensions=5)

    with pytest.raises(RuntimeError, match="3-dimensional vectors"):
        await provider.embed_documents(["test text"])


@pytest.mark.asyncio
async def test_litellm_provider_missing_dependency_raises_actionable_error(monkeypatch):
    """An unimportable litellm surfaces as SemanticDependenciesMissingError."""
    # Drop any cached module so the provider's import statement really runs.
    monkeypatch.delitem(sys.modules, "litellm", raising=False)
    real_import = builtins.__import__

    def _blocked_import(name, globals=None, locals=None, fromlist=(), level=0):
        # Only litellm is blocked; every other import is delegated unchanged.
        if name != "litellm":
            return real_import(name, globals, locals, fromlist, level)
        raise ImportError("litellm not installed")

    monkeypatch.setattr(builtins, "__import__", _blocked_import)
    provider = LiteLLMEmbeddingProvider(model_name="openai/text-embedding-3-small")

    with pytest.raises(SemanticDependenciesMissingError):
        await provider.embed_query("test")


@pytest.mark.asyncio
async def test_litellm_provider_output_ordering(monkeypatch):
    """Output vectors line up with the order of the input texts."""
    _install_litellm_stub(monkeypatch)
    provider = LiteLLMEmbeddingProvider(dimensions=3, batch_size=2)
    short_text = "short"
    long_text = "a longer text here"

    vectors = await provider.embed_documents([short_text, long_text])

    # The stub encodes len(text) in the first component of each vector.
    assert vectors[0][0] == float(len("short"))
    assert vectors[1][0] == float(len("a longer text here"))


def test_factory_selects_litellm_provider():
    """A litellm provider setting makes the factory build LiteLLMEmbeddingProvider."""
    settings = {
        "env": "test",
        "projects": {"test": "/tmp/basic-memory-test"},
        "default_project": "test",
        "semantic_search_enabled": True,
        "semantic_embedding_provider": "litellm",
        "semantic_embedding_model": "openai/text-embedding-3-small",
    }

    built = create_embedding_provider(BasicMemoryConfig(**settings))

    assert isinstance(built, LiteLLMEmbeddingProvider)
    assert built.model_name == "openai/text-embedding-3-small"


def test_factory_maps_default_model_for_litellm():
    """The bge-small-en-v1.5 default is remapped to openai/text-embedding-3-small."""
    settings = {
        "env": "test",
        "projects": {"test": "/tmp/basic-memory-test"},
        "default_project": "test",
        "semantic_search_enabled": True,
        "semantic_embedding_provider": "litellm",
        "semantic_embedding_model": "bge-small-en-v1.5",
    }

    built = create_embedding_provider(BasicMemoryConfig(**settings))

    assert isinstance(built, LiteLLMEmbeddingProvider)
    assert built.model_name == "openai/text-embedding-3-small"


def test_runtime_log_attrs():
    """runtime_log_attrs reports the configured batch size and concurrency."""
    provider = LiteLLMEmbeddingProvider(batch_size=32, request_concurrency=8)

    log_attrs = provider.runtime_log_attrs()

    assert log_attrs["provider_batch_size"] == 32
    assert log_attrs["request_concurrency"] == 8
Loading