-
Notifications
You must be signed in to change notification settings - Fork 186
Expand file tree
/
Copy pathfastembed_provider.py
More file actions
138 lines (118 loc) · 5.58 KB
/
fastembed_provider.py
File metadata and controls
138 lines (118 loc) · 5.58 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""FastEmbed-based local embedding provider."""
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
from loguru import logger
from basic_memory.repository.embedding_provider import EmbeddingProvider
from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError
if TYPE_CHECKING:
from fastembed import TextEmbedding # type: ignore[import-not-found] # pragma: no cover
class FastEmbedEmbeddingProvider(EmbeddingProvider):
    """Local ONNX embedding provider backed by FastEmbed.

    The underlying ``TextEmbedding`` model is loaded lazily on first use,
    constructed in a worker thread (loading is blocking), and cached for the
    lifetime of this provider instance.
    """

    # Short, user-friendly model names accepted in configuration, mapped to
    # the canonical FastEmbed / HuggingFace model identifiers.
    _MODEL_ALIASES = {
        "bge-small-en-v1.5": "BAAI/bge-small-en-v1.5",
    }

    def __init__(
        self,
        model_name: str = "bge-small-en-v1.5",
        *,
        batch_size: int = 64,
        dimensions: int = 384,
        cache_dir: str | None = None,
        threads: int | None = None,
        parallel: int | None = None,
    ) -> None:
        """Configure the provider without loading the model yet.

        Args:
            model_name: Model identifier, or a key of ``_MODEL_ALIASES``.
            batch_size: Number of texts embedded per FastEmbed batch.
            dimensions: Expected embedding dimensionality; a mismatch at
                embed time raises ``RuntimeError``.
            cache_dir: Optional directory for downloaded model files.
            threads: Optional ONNX runtime thread count.
            parallel: Optional data-parallel worker count; values <= 1 are
                treated as "not parallel" (see ``_effective_parallel``).
        """
        self.model_name = model_name
        self.dimensions = dimensions
        self.batch_size = batch_size
        self.cache_dir = cache_dir
        self.threads = threads
        self.parallel = parallel
        self._model: TextEmbedding | None = None
        # Guards one-time model construction against concurrent first calls.
        self._model_lock = asyncio.Lock()

    def _resolved_model_name(self) -> str:
        """Map a configured alias to its canonical FastEmbed model name."""
        return self._MODEL_ALIASES.get(self.model_name, self.model_name)

    def _effective_parallel(self) -> int | None:
        """Return ``parallel`` only when it is a meaningful value (> 1).

        ``None``, 0, and 1 are all normalized to ``None`` so FastEmbed's
        default single-process path is used.
        """
        return self.parallel if self.parallel is not None and self.parallel > 1 else None

    def runtime_log_attrs(self) -> dict[str, int | str | None]:
        """Return the resolved runtime knobs that shape FastEmbed throughput."""
        return {
            "provider_batch_size": self.batch_size,
            "threads": self.threads,
            "configured_parallel": self.parallel,
            "effective_parallel": self._effective_parallel(),
        }

    async def _load_model(self) -> "TextEmbedding":
        """Return the cached FastEmbed model, loading it on first call.

        Uses double-checked locking so that concurrent first callers trigger
        exactly one model construction; the blocking load itself runs in a
        worker thread via ``asyncio.to_thread``.

        Raises:
            SemanticDependenciesMissingError: If ``fastembed`` is not installed.
        """
        if self._model is not None:
            return self._model
        async with self._model_lock:
            # Re-check after acquiring the lock: another coroutine may have
            # finished loading while we waited.
            if self._model is not None:
                return self._model

            def _create_model() -> "TextEmbedding":
                try:
                    from fastembed import TextEmbedding  # type: ignore[import-not-found]
                except (
                    ImportError
                ) as exc:  # pragma: no cover - exercised via tests with monkeypatch
                    raise SemanticDependenciesMissingError(
                        "fastembed package is missing. "
                        "Install/update basic-memory to include semantic dependencies: "
                        "pip install -U basic-memory"
                    ) from exc
                # Pass only explicitly configured knobs so FastEmbed's own
                # defaults apply for anything left as None.
                init_kwargs: dict[str, str | int] = {
                    "model_name": self._resolved_model_name(),
                }
                if self.cache_dir is not None:
                    init_kwargs["cache_dir"] = self.cache_dir
                if self.threads is not None:
                    init_kwargs["threads"] = self.threads
                return TextEmbedding(**init_kwargs)

            self._model = await asyncio.to_thread(_create_model)
            logger.info(
                "FastEmbed model loaded: model_name={model_name} batch_size={batch_size} "
                "threads={threads} configured_parallel={configured_parallel} "
                "effective_parallel={effective_parallel}",
                model_name=self._resolved_model_name(),
                batch_size=self.batch_size,
                threads=self.threads,
                configured_parallel=self.parallel,
                effective_parallel=self._effective_parallel(),
            )
            return self._model

    async def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a batch of documents, returning one vector per input text.

        Args:
            texts: Documents to embed; an empty list short-circuits to ``[]``.

        Returns:
            A list of ``dimensions``-length float vectors, parallel to ``texts``.

        Raises:
            RuntimeError: If the model's output dimensionality does not match
                the configured ``dimensions``.
            SemanticDependenciesMissingError: If ``fastembed`` is not installed.
        """
        if not texts:
            return []
        model = await self._load_model()
        effective_parallel = self._effective_parallel()
        logger.debug(
            "FastEmbed embed_documents call: text_count={text_count} batch_size={batch_size} "
            "threads={threads} configured_parallel={configured_parallel} "
            "effective_parallel={effective_parallel}",
            text_count=len(texts),
            batch_size=self.batch_size,
            threads=self.threads,
            configured_parallel=self.parallel,
            effective_parallel=effective_parallel,
        )

        def _embed_batch() -> list[list[float]]:
            embed_kwargs: dict[str, int] = {"batch_size": self.batch_size}
            if effective_parallel is not None:
                embed_kwargs["parallel"] = effective_parallel
            vectors = list(model.embed(texts, **embed_kwargs))
            # Normalize to plain Python floats whether the model yields
            # numpy arrays (which expose .tolist()) or plain sequences.
            normalized: list[list[float]] = []
            for vector in vectors:
                values = vector.tolist() if hasattr(vector, "tolist") else vector
                normalized.append([float(value) for value in values])
            return normalized

        vectors = await asyncio.to_thread(_embed_batch)
        # All vectors from one model share a dimensionality, so checking the
        # first is sufficient to detect a misconfigured `dimensions`.
        if vectors and len(vectors[0]) != self.dimensions:
            raise RuntimeError(
                f"Embedding model returned {len(vectors[0])}-dimensional vectors "
                f"but provider was configured for {self.dimensions} dimensions."
            )
        return vectors

    async def embed_query(self, text: str) -> list[float]:
        """Embed a single query string.

        Returns a zero vector of the configured dimensionality in the
        (normally unreachable) case that no vector comes back.
        """
        vectors = await self.embed_documents([text])
        return vectors[0] if vectors else [0.0] * self.dimensions