Skip to content

Commit 9b199c6

Browse files
phernandez and Claude committed
fix: add FastEmbed runtime tuning knobs and provider caching
Add configurable cache_dir, threads, and parallel settings for FastEmbed to
support cloud deployments where defaults fail. Cache embedding providers at
the process level to avoid re-creating heavy ONNX model instances.

- Add semantic_embedding_cache_dir, semantic_embedding_threads, and
  semantic_embedding_parallel config fields
- Thread-safe provider cache with double-checked locking in factory
- Forward runtime knobs through to TextEmbedding and embed() calls
- Fix if/elif chain in factory for correct error handling

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Signed-off-by: phernandez <paul@basicmachines.co>
1 parent fe4a7b1 commit 9b199c6

9 files changed

Lines changed: 232 additions & 26 deletions

File tree

src/basic_memory/cli/commands/project.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -227,9 +227,7 @@ async def _list_projects(ws: str | None = None):
227227

228228
console.print(table)
229229
if cloud_error is not None:
230-
console.print(
231-
f"[yellow]Cloud project discovery failed: {cloud_error}[/yellow]"
232-
)
230+
console.print(f"[yellow]Cloud project discovery failed: {cloud_error}[/yellow]")
233231
console.print(
234232
"[dim]Showing local projects only. "
235233
"Run 'bm cloud login' or 'bm cloud api-key save <key>' if this is a credentials issue.[/dim]"

src/basic_memory/config.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,20 @@ class BasicMemoryConfig(BaseSettings):
173173
description="Batch size for embedding generation.",
174174
gt=0,
175175
)
176+
semantic_embedding_cache_dir: str | None = Field(
177+
default=None,
178+
description="Optional cache directory for FastEmbed model artifacts.",
179+
)
180+
semantic_embedding_threads: int | None = Field(
181+
default=None,
182+
description="Optional FastEmbed runtime thread count override.",
183+
gt=0,
184+
)
185+
semantic_embedding_parallel: int | None = Field(
186+
default=None,
187+
description="Optional FastEmbed embed() parallelism override.",
188+
gt=0,
189+
)
176190
semantic_vector_k: int = Field(
177191
default=100,
178192
description="Vector candidate count for vector and hybrid retrieval.",
@@ -709,9 +723,7 @@ def load_config(self) -> BasicMemoryConfig:
709723
# Create backup before overwriting so users can revert if needed
710724
backup_path = self.config_file.with_suffix(".json.bak")
711725
shutil.copy2(self.config_file, backup_path)
712-
logger.info(
713-
f"Migrating config to current format (backup: {backup_path})"
714-
)
726+
logger.info(f"Migrating config to current format (backup: {backup_path})")
715727
save_basic_memory_config(self.config_file, _CONFIG_CACHE)
716728

717729
return _CONFIG_CACHE
Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,85 @@
11
"""Factory for creating configured semantic embedding providers."""
22

3+
from threading import Lock
4+
35
from basic_memory.config import BasicMemoryConfig
46
from basic_memory.repository.embedding_provider import EmbeddingProvider
57

8+
type ProviderCacheKey = tuple[str, str, int | None, int, str | None, int | None, int | None]
9+
10+
_EMBEDDING_PROVIDER_CACHE: dict[ProviderCacheKey, EmbeddingProvider] = {}
11+
_EMBEDDING_PROVIDER_CACHE_LOCK = Lock()
12+
13+
14+
def _provider_cache_key(app_config: BasicMemoryConfig) -> ProviderCacheKey:
    """Return the tuple of semantic-embedding settings that identifies a provider.

    Two configs producing the same key can safely share one provider instance,
    so every field that influences provider construction must appear here.
    """
    # Normalize the provider name exactly the way the factory does, so that
    # e.g. "FastEmbed " and "fastembed" land in the same cache slot.
    normalized_provider = app_config.semantic_embedding_provider.strip().lower()
    return (
        normalized_provider,
        app_config.semantic_embedding_model,
        app_config.semantic_embedding_dimensions,
        app_config.semantic_embedding_batch_size,
        app_config.semantic_embedding_cache_dir,
        app_config.semantic_embedding_threads,
        app_config.semantic_embedding_parallel,
    )
25+
26+
27+
def reset_embedding_provider_cache() -> None:
    """Drop all cached provider instances (primarily for test isolation)."""
    # Take the lock so a concurrent create_embedding_provider call never sees
    # the cache in a half-cleared state.
    with _EMBEDDING_PROVIDER_CACHE_LOCK:
        _EMBEDDING_PROVIDER_CACHE.clear()
31+
632

733
def create_embedding_provider(app_config: BasicMemoryConfig) -> EmbeddingProvider:
    """Return a (possibly cached) embedding provider for the semantic config.

    Providers are cached at process level keyed on every construction-relevant
    config field, so repeated calls with an identical config reuse one heavy
    instance. When semantic_embedding_dimensions is set in config, it
    overrides the provider's default dimensions (384 for FastEmbed, 1536 for
    OpenAI).

    Raises:
        ValueError: if semantic_embedding_provider names an unknown backend.
    """
    cache_key = _provider_cache_key(app_config)

    # Fast path: reuse an already-constructed provider for this exact config.
    with _EMBEDDING_PROVIDER_CACHE_LOCK:
        existing = _EMBEDDING_PROVIDER_CACHE.get(cache_key)
        if existing:
            return existing

    provider_name = app_config.semantic_embedding_provider.strip().lower()
    if provider_name not in ("fastembed", "openai"):
        raise ValueError(f"Unsupported semantic embedding provider: {provider_name}")

    extra_kwargs: dict = {}
    if app_config.semantic_embedding_dimensions is not None:
        extra_kwargs["dimensions"] = app_config.semantic_embedding_dimensions

    provider: EmbeddingProvider
    if provider_name == "fastembed":
        # Import lazily: fastembed (and its onnxruntime dep) is optional.
        from basic_memory.repository.fastembed_provider import FastEmbedEmbeddingProvider

        # Forward only the runtime knobs that were explicitly configured.
        for knob_name, knob_value in (
            ("cache_dir", app_config.semantic_embedding_cache_dir),
            ("threads", app_config.semantic_embedding_threads),
            ("parallel", app_config.semantic_embedding_parallel),
        ):
            if knob_value is not None:
                extra_kwargs[knob_name] = knob_value

        provider = FastEmbedEmbeddingProvider(
            model_name=app_config.semantic_embedding_model,
            batch_size=app_config.semantic_embedding_batch_size,
            **extra_kwargs,
        )
    else:
        # Import lazily: openai is optional.
        from basic_memory.repository.openai_provider import OpenAIEmbeddingProvider

        # Map an unset or FastEmbed-default model name onto the OpenAI default.
        model_name = app_config.semantic_embedding_model or "text-embedding-3-small"
        if model_name == "bge-small-en-v1.5":
            model_name = "text-embedding-3-small"
        provider = OpenAIEmbeddingProvider(
            model_name=model_name,
            batch_size=app_config.semantic_embedding_batch_size,
            **extra_kwargs,
        )

    # Publish under the lock. If another thread raced us and won, keep its
    # instance and discard ours so all callers share a single provider.
    with _EMBEDDING_PROVIDER_CACHE_LOCK:
        raced = _EMBEDDING_PROVIDER_CACHE.get(cache_key)
        if raced:
            return raced
        _EMBEDDING_PROVIDER_CACHE[cache_key] = provider
    return provider

src/basic_memory/repository/fastembed_provider.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,16 @@ def __init__(
2525
*,
2626
batch_size: int = 64,
2727
dimensions: int = 384,
28+
cache_dir: str | None = None,
29+
threads: int | None = None,
30+
parallel: int | None = None,
2831
) -> None:
2932
self.model_name = model_name
3033
self.dimensions = dimensions
3134
self.batch_size = batch_size
35+
self.cache_dir = cache_dir
36+
self.threads = threads
37+
self.parallel = parallel
3238
self._model: TextEmbedding | None = None
3339
self._model_lock = asyncio.Lock()
3440

@@ -52,6 +58,16 @@ def _create_model() -> "TextEmbedding":
5258
"pip install -U basic-memory"
5359
) from exc
5460
resolved_model_name = self._MODEL_ALIASES.get(self.model_name, self.model_name)
61+
if self.cache_dir is not None and self.threads is not None:
62+
return TextEmbedding(
63+
model_name=resolved_model_name,
64+
cache_dir=self.cache_dir,
65+
threads=self.threads,
66+
)
67+
if self.cache_dir is not None:
68+
return TextEmbedding(model_name=resolved_model_name, cache_dir=self.cache_dir)
69+
if self.threads is not None:
70+
return TextEmbedding(model_name=resolved_model_name, threads=self.threads)
5571
return TextEmbedding(model_name=resolved_model_name)
5672

5773
self._model = await asyncio.to_thread(_create_model)
@@ -64,7 +80,10 @@ async def embed_documents(self, texts: list[str]) -> list[list[float]]:
6480
model = await self._load_model()
6581

6682
def _embed_batch() -> list[list[float]]:
67-
vectors = list(model.embed(texts, batch_size=self.batch_size))
83+
embed_kwargs: dict[str, int] = {"batch_size": self.batch_size}
84+
if self.parallel is not None:
85+
embed_kwargs["parallel"] = self.parallel
86+
vectors = list(model.embed(texts, **embed_kwargs))
6887
normalized: list[list[float]] = []
6988
for vector in vectors:
7089
values = vector.tolist() if hasattr(vector, "tolist") else vector

tests/api/v2/test_schema_router.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,8 @@ async def test_validate_reads_schema_from_file_not_database(
668668

669669
# Overwrite the file on disk with validation=strict
670670
file_path = Path(file_service.base_path) / schema_entity.file_path
671-
file_path.write_text(dedent("""\
671+
file_path.write_text(
672+
dedent("""\
672673
---
673674
title: Editable Schema
674675
permalink: schemas/editable-schema
@@ -685,7 +686,8 @@ async def test_validate_reads_schema_from_file_not_database(
685686
686687
## Observations
687688
- [note] Schema that will be edited on disk
688-
"""))
689+
""")
690+
)
689691

690692
# Create a note missing "role" — strict mode should produce errors, not warnings
691693
note_entity, _ = await entity_service.create_or_update_entity(
@@ -749,7 +751,8 @@ async def test_validate_falls_back_to_db_on_incomplete_frontmatter(
749751

750752
# Overwrite file with frontmatter missing the 'schema' key
751753
file_path = Path(file_service.base_path) / schema_entity.file_path
752-
file_path.write_text(dedent("""\
754+
file_path.write_text(
755+
dedent("""\
753756
---
754757
title: Incomplete Schema
755758
permalink: schemas/incomplete-schema
@@ -761,7 +764,8 @@ async def test_validate_falls_back_to_db_on_incomplete_frontmatter(
761764
762765
## Observations
763766
- [note] Mid-edit state
764-
"""))
767+
""")
768+
)
765769

766770
# Create a note to validate against this schema
767771
note_entity, _ = await entity_service.create_or_update_entity(

tests/cli/test_cloud_status.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,7 @@ def is_token_valid(self, t):
6363
monkeypatch.setattr(
6464
"basic_memory.cli.commands.cloud.core_commands.ConfigManager", FakeConfigManager
6565
)
66-
monkeypatch.setattr(
67-
"basic_memory.cli.commands.cloud.core_commands.CLIAuth", FakeAuth
68-
)
66+
monkeypatch.setattr("basic_memory.cli.commands.cloud.core_commands.CLIAuth", FakeAuth)
6967
monkeypatch.setattr(
7068
"basic_memory.cli.commands.cloud.core_commands.get_cloud_config",
7169
lambda: ("cid", "domain", "https://cloud.example.com"),

tests/repository/test_fastembed_provider.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,25 @@ def tolist(self):
1919

2020
class _StubTextEmbedding:
2121
init_count = 0
22+
last_init_kwargs: dict = {}
23+
last_embed_kwargs: dict = {}
2224

23-
def __init__(self, model_name: str):
25+
def __init__(self, model_name: str, cache_dir: str | None = None, threads: int | None = None):
2426
self.model_name = model_name
2527
self.embed_calls = 0
28+
_StubTextEmbedding.last_init_kwargs = {
29+
"model_name": model_name,
30+
"cache_dir": cache_dir,
31+
"threads": threads,
32+
}
2633
_StubTextEmbedding.init_count += 1
2734

28-
def embed(self, texts: list[str], batch_size: int = 64):
35+
def embed(self, texts: list[str], batch_size: int = 64, parallel: int | None = None):
2936
self.embed_calls += 1
37+
_StubTextEmbedding.last_embed_kwargs = {
38+
"batch_size": batch_size,
39+
"parallel": parallel,
40+
}
3041
for text in texts:
3142
if "wide" in text:
3243
yield _StubVector([1.0, 0.0, 0.0, 0.0, 0.5])
@@ -85,3 +96,30 @@ def _raising_import(name, globals=None, locals=None, fromlist=(), level=0):
8596
await provider.embed_query("test")
8697

8798
assert "pip install -U basic-memory" in str(error.value)
99+
100+
101+
@pytest.mark.asyncio
async def test_fastembed_provider_passes_runtime_knobs_to_fastembed(monkeypatch):
    """Provider should pass optional runtime tuning knobs through to FastEmbed."""
    # Reset stub capture state so assertions see only this test's calls.
    _StubTextEmbedding.last_init_kwargs = {}
    _StubTextEmbedding.last_embed_kwargs = {}

    # Install a fake `fastembed` module whose TextEmbedding records kwargs.
    fake_fastembed = type(sys)("fastembed")
    fake_fastembed.TextEmbedding = _StubTextEmbedding
    monkeypatch.setitem(sys.modules, "fastembed", fake_fastembed)

    provider = FastEmbedEmbeddingProvider(
        model_name="stub-model",
        dimensions=4,
        batch_size=8,
        cache_dir="/tmp/fastembed-cache",
        threads=3,
        parallel=2,
    )
    await provider.embed_documents(["runtime knobs"])

    expected_init = {
        "model_name": "stub-model",
        "cache_dir": "/tmp/fastembed-cache",
        "threads": 3,
    }
    expected_embed = {"batch_size": 8, "parallel": 2}
    assert _StubTextEmbedding.last_init_kwargs == expected_init
    assert _StubTextEmbedding.last_embed_kwargs == expected_embed

0 commit comments

Comments (0)