feat: add LiteLLM as embedding provider #809
Open
RheagalFire wants to merge 3 commits into basicmachines-co:main from RheagalFire:feat/add-litellm-provider
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| """LiteLLM-based embedding provider for semantic indexing. | ||
|
|
||
| Routes embedding requests to 100+ providers (OpenAI, Google, Azure, | ||
| Bedrock, Cohere, etc.) via the litellm SDK. No proxy server needed. | ||
|
|
||
| Model strings use the ``provider/model`` format, e.g. | ||
| ``openai/text-embedding-3-small``, ``cohere/embed-english-v3.0``, | ||
| ``azure/my-embedding-deployment``. | ||
|
|
||
| See https://docs.litellm.ai/docs/embedding/supported_embedding for all | ||
| supported embedding models. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| from typing import Any | ||
|
|
||
| from basic_memory.repository.embedding_provider import EmbeddingProvider | ||
| from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError | ||
|
|
||
|
|
||
| class LiteLLMEmbeddingProvider(EmbeddingProvider): | ||
| """Embedding provider backed by the litellm SDK.""" | ||
|
|
||
| def __init__( | ||
| self, | ||
| model_name: str = "openai/text-embedding-3-small", | ||
| *, | ||
| batch_size: int = 64, | ||
| request_concurrency: int = 4, | ||
| dimensions: int = 1536, | ||
| api_key: str | None = None, | ||
| timeout: float = 30.0, | ||
| ) -> None: | ||
| self.model_name = model_name | ||
| self.dimensions = dimensions | ||
| self.batch_size = batch_size | ||
| self.request_concurrency = request_concurrency | ||
| self._api_key = api_key | ||
| self._timeout = timeout | ||
|
|
||
| def runtime_log_attrs(self) -> dict[str, int]: | ||
| """Return provider-specific runtime settings suitable for startup logs.""" | ||
| return { | ||
| "provider_batch_size": self.batch_size, | ||
| "request_concurrency": self.request_concurrency, | ||
| } | ||
|
|
||
| async def embed_documents(self, texts: list[str]) -> list[list[float]]: | ||
| if not texts: | ||
| return [] | ||
|
|
||
| try: | ||
| import litellm | ||
| except ImportError as exc: | ||
| raise SemanticDependenciesMissingError( | ||
| "litellm dependency is missing. Install with: pip install litellm" | ||
| ) from exc | ||
|
|
||
| batches = [ | ||
| texts[start : start + self.batch_size] | ||
| for start in range(0, len(texts), self.batch_size) | ||
| ] | ||
| batch_vectors: list[list[list[float]] | None] = [None] * len(batches) | ||
| semaphore = asyncio.Semaphore(self.request_concurrency) | ||
|
|
||
| async def embed_batch(batch_index: int, batch: list[str]) -> None: | ||
| async with semaphore: | ||
| params: dict[str, Any] = { | ||
| "model": self.model_name, | ||
| "input": batch, | ||
| "drop_params": True, | ||
| "timeout": self._timeout, | ||
| } | ||
| if self._api_key: | ||
| params["api_key"] = self._api_key | ||
|
|
||
| response = await litellm.aembedding(**params) | ||
|
|
||
| vectors_by_index: dict[int, list[float]] = {} | ||
| for item in response.data: | ||
| response_index = int(item["index"]) | ||
| vectors_by_index[response_index] = [float(v) for v in item["embedding"]] | ||
|
|
||
| ordered_vectors: list[list[float]] = [] | ||
| for index in range(len(batch)): | ||
| vector = vectors_by_index.get(index) | ||
| if vector is None: | ||
| raise RuntimeError( | ||
| "LiteLLM embedding response is missing expected vector index." | ||
| ) | ||
| ordered_vectors.append(vector) | ||
|
|
||
| batch_vectors[batch_index] = ordered_vectors | ||
|
|
||
| await asyncio.gather( | ||
| *(embed_batch(batch_index, batch) for batch_index, batch in enumerate(batches)) | ||
| ) | ||
|
|
||
| all_vectors: list[list[float]] = [] | ||
| for vectors in batch_vectors: | ||
| if vectors is None: | ||
| raise RuntimeError("LiteLLM embedding batch did not produce vectors.") | ||
| all_vectors.extend(vectors) | ||
|
|
||
| if all_vectors and len(all_vectors[0]) != self.dimensions: | ||
| raise RuntimeError( | ||
| f"Embedding model returned {len(all_vectors[0])}-dimensional vectors " | ||
| f"but provider was configured for {self.dimensions} dimensions." | ||
| ) | ||
| return all_vectors | ||
|
|
||
| async def embed_query(self, text: str) -> list[float]: | ||
| vectors = await self.embed_documents([text]) | ||
| return vectors[0] if vectors else [0.0] * self.dimensions |
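The batching in `embed_documents` relies on one result slot per batch, so the output order does not depend on which request finishes first. A minimal, self-contained sketch of that pattern follows. `fake_embed` and `embed_all` are hypothetical names used only for illustration, and `fake_embed` stands in for the remote call; it is not the litellm API.

```python
import asyncio


async def fake_embed(batch: list[str]) -> list[list[float]]:
    """Hypothetical stand-in for a remote embedding call (not litellm)."""
    await asyncio.sleep(0.001 * len(batch))  # larger batches take longer
    return [[float(len(text))] for text in batch]


async def embed_all(
    texts: list[str], batch_size: int = 2, concurrency: int = 2
) -> list[list[float]]:
    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    # One slot per batch keeps output order independent of completion order.
    slots: list = [None] * len(batches)
    semaphore = asyncio.Semaphore(concurrency)

    async def run(index: int, batch: list[str]) -> None:
        async with semaphore:  # at most `concurrency` calls in flight
            slots[index] = await fake_embed(batch)

    await asyncio.gather(*(run(i, b) for i, b in enumerate(batches)))
    # Flatten the per-batch results back into one list, in batch order.
    return [vector for slot in slots if slot is not None for vector in slot]


vectors = asyncio.run(embed_all(["a", "bb", "ccc", "dddd", "eeeee"]))
```

Here the last batch holds a single text and completes before the two-item batches, yet `vectors` still follows the input order because each task writes only to its own slot.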
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,204 @@ | ||
| """Tests for LiteLLMEmbeddingProvider and factory litellm branch.""" | ||
|
|
||
| import asyncio | ||
| import builtins | ||
| import sys | ||
| from types import SimpleNamespace | ||
|
|
||
| import pytest | ||
|
|
||
| from basic_memory.config import BasicMemoryConfig | ||
| from basic_memory.repository.embedding_provider_factory import ( | ||
| create_embedding_provider, | ||
| reset_embedding_provider_cache, | ||
| ) | ||
| from basic_memory.repository.litellm_provider import LiteLLMEmbeddingProvider | ||
| from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError | ||
|
|
||
|
|
||
| def _make_embedding_response(inputs: list[str], dim: int = 3): | ||
| """Build a fake litellm.aembedding response matching the real shape.""" | ||
| data = [] | ||
| for index, text in enumerate(inputs): | ||
| base = float(len(text)) | ||
| data.append({"index": index, "embedding": [base + float(d) for d in range(dim)]}) | ||
| return SimpleNamespace(data=data) | ||
|
|
||
|
|
||
| def _install_litellm_stub(monkeypatch, dim: int = 3): | ||
| """Install a fake litellm module and return the mock aembedding callable.""" | ||
| calls: list[dict] = [] | ||
|
|
||
| async def _aembedding(**kwargs): | ||
| calls.append(kwargs) | ||
| return _make_embedding_response(kwargs["input"], dim) | ||
|
|
||
| module = type(sys)("litellm") | ||
| setattr(module, "aembedding", _aembedding) | ||
| monkeypatch.setitem(sys.modules, "litellm", module) | ||
| return calls | ||
|
|
||
|
|
||
| @pytest.fixture(autouse=True) | ||
| def _reset_cache(): | ||
| reset_embedding_provider_cache() | ||
| yield | ||
| reset_embedding_provider_cache() | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_litellm_provider_embed_query(monkeypatch): | ||
| """embed_query should return a single vector through litellm.aembedding.""" | ||
| _install_litellm_stub(monkeypatch) | ||
| provider = LiteLLMEmbeddingProvider( | ||
| model_name="openai/text-embedding-3-small", batch_size=2, dimensions=3 | ||
| ) | ||
| result = await provider.embed_query("hello world") | ||
| assert len(result) == 3 | ||
| assert all(isinstance(v, float) for v in result) | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_litellm_provider_embed_documents(monkeypatch): | ||
| """embed_documents should return vectors for each input text.""" | ||
| _install_litellm_stub(monkeypatch) | ||
| provider = LiteLLMEmbeddingProvider( | ||
| model_name="openai/text-embedding-3-small", batch_size=2, dimensions=3 | ||
| ) | ||
| texts = ["first doc", "second doc", "third doc"] | ||
| result = await provider.embed_documents(texts) | ||
| assert len(result) == 3 | ||
| assert all(len(v) == 3 for v in result) | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_litellm_provider_empty_input(monkeypatch): | ||
| """embed_documents with empty list should return empty list.""" | ||
| _install_litellm_stub(monkeypatch) | ||
| provider = LiteLLMEmbeddingProvider(dimensions=3) | ||
| result = await provider.embed_documents([]) | ||
| assert result == [] | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_litellm_provider_batching(monkeypatch): | ||
| """Provider should split inputs into batches of batch_size.""" | ||
| calls = _install_litellm_stub(monkeypatch) | ||
| provider = LiteLLMEmbeddingProvider( | ||
| model_name="openai/text-embedding-3-small", batch_size=2, dimensions=3 | ||
| ) | ||
| texts = ["a", "b", "c", "d", "e"] | ||
| result = await provider.embed_documents(texts) | ||
|
|
||
| assert len(result) == 5 | ||
| assert len(calls) == 3 # 2 + 2 + 1 | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_litellm_provider_api_key_forwarded(monkeypatch): | ||
| """api_key should be passed to litellm.aembedding when set.""" | ||
| calls = _install_litellm_stub(monkeypatch) | ||
| provider = LiteLLMEmbeddingProvider( | ||
| model_name="openai/text-embedding-3-small", | ||
| api_key="sk-test-key", | ||
| dimensions=3, | ||
| ) | ||
| await provider.embed_query("test") | ||
| assert calls[0]["api_key"] == "sk-test-key" | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_litellm_provider_api_key_omitted_when_none(monkeypatch): | ||
| """api_key should not appear in kwargs when not set.""" | ||
| calls = _install_litellm_stub(monkeypatch) | ||
| provider = LiteLLMEmbeddingProvider( | ||
| model_name="openai/text-embedding-3-small", dimensions=3 | ||
| ) | ||
| await provider.embed_query("test") | ||
| assert "api_key" not in calls[0] | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_litellm_provider_drop_params_always_set(monkeypatch): | ||
| """drop_params=True should always be in the call kwargs.""" | ||
| calls = _install_litellm_stub(monkeypatch) | ||
| provider = LiteLLMEmbeddingProvider(dimensions=3) | ||
| await provider.embed_query("test") | ||
| assert calls[0]["drop_params"] is True | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_litellm_provider_dimension_mismatch_raises_error(monkeypatch): | ||
| """Provider should fail fast when response dimensions differ from configured.""" | ||
| _install_litellm_stub(monkeypatch, dim=3) | ||
| provider = LiteLLMEmbeddingProvider(dimensions=5) | ||
| with pytest.raises(RuntimeError, match="3-dimensional vectors"): | ||
| await provider.embed_documents(["test text"]) | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_litellm_provider_missing_dependency_raises_actionable_error(monkeypatch): | ||
| """Missing litellm package should raise SemanticDependenciesMissingError.""" | ||
| monkeypatch.delitem(sys.modules, "litellm", raising=False) | ||
| original_import = builtins.__import__ | ||
|
|
||
| def _raising_import(name, globals=None, locals=None, fromlist=(), level=0): | ||
| if name == "litellm": | ||
| raise ImportError("litellm not installed") | ||
| return original_import(name, globals, locals, fromlist, level) | ||
|
|
||
| monkeypatch.setattr(builtins, "__import__", _raising_import) | ||
|
|
||
| provider = LiteLLMEmbeddingProvider(model_name="openai/text-embedding-3-small") | ||
| with pytest.raises(SemanticDependenciesMissingError): | ||
| await provider.embed_query("test") | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_litellm_provider_output_ordering(monkeypatch): | ||
| """Vectors should be returned in the same order as input texts.""" | ||
| _install_litellm_stub(monkeypatch) | ||
| provider = LiteLLMEmbeddingProvider(dimensions=3, batch_size=2) | ||
| texts = ["short", "a longer text here"] | ||
| result = await provider.embed_documents(texts) | ||
|
|
||
| assert result[0][0] == float(len("short")) | ||
| assert result[1][0] == float(len("a longer text here")) | ||
|
|
||
|
|
||
| def test_factory_selects_litellm_provider(): | ||
| """Factory should select LiteLLMEmbeddingProvider for litellm config.""" | ||
| config = BasicMemoryConfig( | ||
| env="test", | ||
| projects={"test": "/tmp/basic-memory-test"}, | ||
| default_project="test", | ||
| semantic_search_enabled=True, | ||
| semantic_embedding_provider="litellm", | ||
| semantic_embedding_model="openai/text-embedding-3-small", | ||
| ) | ||
| provider = create_embedding_provider(config) | ||
| assert isinstance(provider, LiteLLMEmbeddingProvider) | ||
| assert provider.model_name == "openai/text-embedding-3-small" | ||
|
|
||
|
|
||
| def test_factory_maps_default_model_for_litellm(): | ||
| """Factory should remap bge-small-en-v1.5 default to openai/text-embedding-3-small.""" | ||
| config = BasicMemoryConfig( | ||
| env="test", | ||
| projects={"test": "/tmp/basic-memory-test"}, | ||
| default_project="test", | ||
| semantic_search_enabled=True, | ||
| semantic_embedding_provider="litellm", | ||
| semantic_embedding_model="bge-small-en-v1.5", | ||
| ) | ||
| provider = create_embedding_provider(config) | ||
| assert isinstance(provider, LiteLLMEmbeddingProvider) | ||
| assert provider.model_name == "openai/text-embedding-3-small" | ||
|
|
||
|
|
||
| def test_runtime_log_attrs(): | ||
| """runtime_log_attrs should return batch_size and concurrency.""" | ||
| provider = LiteLLMEmbeddingProvider(batch_size=32, request_concurrency=8) | ||
| attrs = provider.runtime_log_attrs() | ||
| assert attrs["provider_batch_size"] == 32 | ||
| assert attrs["request_concurrency"] == 8 |
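The stub used in these tests builds its response with `enumerate`, so items always arrive in input order and the index-based reassembly in `embed_batch` is never exercised against a shuffled response. The sketch below shows that case in isolation. `order_by_index` is a hypothetical helper that mirrors the reassembly loop in the provider; it is not part of the PR, and the response data is made up.

```python
def order_by_index(items: list[dict], expected: int) -> list[list[float]]:
    """Reassemble vectors by their reported index, mirroring embed_batch."""
    by_index = {
        int(item["index"]): [float(v) for v in item["embedding"]] for item in items
    }
    ordered = []
    for index in range(expected):
        if index not in by_index:
            raise RuntimeError("response is missing expected vector index")
        ordered.append(by_index[index])
    return ordered


# Hypothetical response data, deliberately delivered out of order.
shuffled = [
    {"index": 1, "embedding": [2, 2]},
    {"index": 0, "embedding": [1, 1]},
]
ordered = order_by_index(shuffled, expected=2)
```

A stub that reverses `data` before returning it would let `test_litellm_provider_output_ordering` cover the same path through the real provider.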
When users switch only `semantic_embedding_provider` to `litellm`, `BasicMemoryConfig` still supplies the non-empty default model `bge-small-en-v1.5`, so this `or` never selects the LiteLLM provider default. The factory then instantiates `LiteLLMEmbeddingProvider(model_name="bge-small-en-v1.5")` instead of a LiteLLM-routable model such as `openai/text-embedding-3-small`, making the new provider fail for the documented minimal configuration. Mirror the OpenAI branch's remapping of the FastEmbed default, or otherwise treat it as unset.
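One way to express the suggested remapping is sketched below. The factory code is not shown on this page, so `resolve_litellm_model` and both constants are hypothetical names; only the two model strings come from the PR.

```python
from typing import Optional

FASTEMBED_DEFAULT_MODEL = "bge-small-en-v1.5"  # config default named in the comment
LITELLM_DEFAULT_MODEL = "openai/text-embedding-3-small"  # provider default in this PR


def resolve_litellm_model(configured: Optional[str]) -> str:
    """Treat an unset model or the FastEmbed default as 'use the LiteLLM default'."""
    if not configured or configured == FASTEMBED_DEFAULT_MODEL:
        return LITELLM_DEFAULT_MODEL
    return configured  # an explicit provider/model string is passed through


remapped = resolve_litellm_model("bge-small-en-v1.5")
explicit = resolve_litellm_model("cohere/embed-english-v3.0")
```

Comparing against the FastEmbed default explicitly, rather than relying on `or`, matters because the config value is never empty. It matches what `test_factory_maps_default_model_for_litellm` expects from the factory.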