 """Factory for creating configured semantic embedding providers."""
 
+from threading import Lock
+
 from basic_memory.config import BasicMemoryConfig
 from basic_memory.repository.embedding_provider import EmbeddingProvider
 
+type ProviderCacheKey = tuple[str, str, int | None, int, str | None, int | None, int | None]
+
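+# Process-level cache keyed by the provider-relevant config fields, so repeated
+# factory calls with an identical config reuse one provider instance instead of
+# re-initializing it (a potentially expensive step, e.g. loading local model weights).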
+_EMBEDDING_PROVIDER_CACHE: dict[ProviderCacheKey, EmbeddingProvider] = {}
+_EMBEDDING_PROVIDER_CACHE_LOCK = Lock()
+
+
+def _provider_cache_key(app_config: BasicMemoryConfig) -> ProviderCacheKey:
+    """Build a stable cache key from provider-relevant semantic embedding config."""
+    return (
+        app_config.semantic_embedding_provider.strip().lower(),
+        app_config.semantic_embedding_model,
+        app_config.semantic_embedding_dimensions,
+        app_config.semantic_embedding_batch_size,
+        app_config.semantic_embedding_cache_dir,
+        app_config.semantic_embedding_threads,
+        app_config.semantic_embedding_parallel,
+    )
+
+
+def reset_embedding_provider_cache() -> None:
+    """Clear process-level embedding provider cache (used by tests)."""
+    with _EMBEDDING_PROVIDER_CACHE_LOCK:
+        _EMBEDDING_PROVIDER_CACHE.clear()
+
 
 def create_embedding_provider(app_config: BasicMemoryConfig) -> EmbeddingProvider:
     """Create an embedding provider based on semantic config.
 
     When semantic_embedding_dimensions is set in config, it overrides
     the provider's default dimensions (384 for FastEmbed, 1536 for OpenAI).
     """
+    cache_key = _provider_cache_key(app_config)
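+    # Fast path: return an already-constructed provider for an identical config.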
+    with _EMBEDDING_PROVIDER_CACHE_LOCK:
+        if cached_provider := _EMBEDDING_PROVIDER_CACHE.get(cache_key):
+            return cached_provider
+
     provider_name = app_config.semantic_embedding_provider.strip().lower()
     extra_kwargs: dict = {}
     if app_config.semantic_embedding_dimensions is not None:
         extra_kwargs["dimensions"] = app_config.semantic_embedding_dimensions
 
+    provider: EmbeddingProvider
     if provider_name == "fastembed":
         # Deferred import: fastembed (and its onnxruntime dep) may not be installed
         from basic_memory.repository.fastembed_provider import FastEmbedEmbeddingProvider
 
-        return FastEmbedEmbeddingProvider(
+        if app_config.semantic_embedding_cache_dir is not None:
+            extra_kwargs["cache_dir"] = app_config.semantic_embedding_cache_dir
+        if app_config.semantic_embedding_threads is not None:
+            extra_kwargs["threads"] = app_config.semantic_embedding_threads
+        if app_config.semantic_embedding_parallel is not None:
+            extra_kwargs["parallel"] = app_config.semantic_embedding_parallel
+
+        provider = FastEmbedEmbeddingProvider(
             model_name=app_config.semantic_embedding_model,
             batch_size=app_config.semantic_embedding_batch_size,
             **extra_kwargs,
         )
-
-    if provider_name == "openai":
+    elif provider_name == "openai":
         # Deferred import: openai may not be installed
         from basic_memory.repository.openai_provider import OpenAIEmbeddingProvider
 
         model_name = app_config.semantic_embedding_model or "text-embedding-3-small"
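+        # bge-small-en-v1.5 is a local FastEmbed model that OpenAI can't serve;
+        # fall back to OpenAI's small embedding model instead.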
         if model_name == "bge-small-en-v1.5":
             model_name = "text-embedding-3-small"
-        return OpenAIEmbeddingProvider(
+        provider = OpenAIEmbeddingProvider(
             model_name=model_name,
             batch_size=app_config.semantic_embedding_batch_size,
             **extra_kwargs,
         )
+    else:
+        raise ValueError(f"Unsupported semantic embedding provider: {provider_name}")
 
-    raise ValueError(f"Unsupported semantic embedding provider: {provider_name}")
+    with _EMBEDDING_PROVIDER_CACHE_LOCK:
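+        # Re-check under the lock: another thread may have cached an equivalent
+        # provider while this one was being constructed outside the lock.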
+        if cached_provider := _EMBEDDING_PROVIDER_CACHE.get(cache_key):
+            return cached_provider
+        _EMBEDDING_PROVIDER_CACHE[cache_key] = provider
+    return provider