Skip to content

Commit fc5c896

Browse files
committed
🐛 avoid network calls in disconnected environments whn loading fastembed models
1 parent 20b3945 commit fc5c896

1 file changed

Lines changed: 39 additions & 0 deletions

File tree

nemoguardrails/embeddings/providers/fastembed.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# limitations under the License.
1515

1616
import asyncio
17+
import os
18+
from pathlib import Path
1719
from typing import List
1820

1921
from .base import EmbeddingModel
@@ -49,6 +51,15 @@ def __init__(self, embedding_model: str, **kwargs):
4951
embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
5052
self.embedding_model = embedding_model
5153

54+
# If a pre-downloaded model exists in the cache, pass its path directly
55+
# to skip all download attempts. This avoids network calls in disconnected
56+
# environments where fastembed would otherwise try HuggingFace Hub and GCS.
57+
# See https://github.com/qdrant/fastembed/pull/614
58+
if "specific_model_path" not in kwargs:
59+
cached_path = self._find_cached_model(embedding_model, Embedding)
60+
if cached_path:
61+
kwargs["specific_model_path"] = str(cached_path)
62+
5263
try:
5364
self.model = Embedding(embedding_model, **kwargs)
5465
except ValueError as ex:
@@ -63,6 +74,34 @@ def __init__(self, embedding_model: str, **kwargs):
6374
# Get the embedding dimension of the model
6475
self.embedding_size = len(list(self.model.embed("test"))[0].tolist())
6576

77+
@staticmethod
78+
def _find_cached_model(embedding_model: str, embedding_cls) -> Path | None:
79+
"""Find a pre-downloaded model in the fastembed cache.
80+
81+
Looks up the HF source repo for the model in fastembed's registry,
82+
then checks if it exists in the HF-convention cache directory
83+
(models--{org}--{repo}/snapshots/{hash}/).
84+
"""
85+
cache_dir = Path(os.environ.get("FASTEMBED_CACHE_PATH", str(Path.home() / ".cache" / "fastembed")))
86+
if not cache_dir.is_dir():
87+
return None
88+
89+
for model_desc in embedding_cls.list_supported_models():
90+
if model_desc["model"] != embedding_model:
91+
continue
92+
hf_source = (model_desc.get("sources") or {}).get("hf")
93+
if not hf_source:
94+
break
95+
snapshots_dir = cache_dir / f"models--{hf_source.replace('/', '--')}" / "snapshots"
96+
if not snapshots_dir.is_dir():
97+
break
98+
snapshot_dirs = [d for d in snapshots_dir.iterdir() if d.is_dir()]
99+
if snapshot_dirs:
100+
return snapshot_dirs[0]
101+
break
102+
103+
return None
104+
66105
async def encode_async(self, documents: List[str]) -> List[List[float]]:
67106
"""Encode a list of documents into their corresponding sentence embeddings.
68107

0 commit comments

Comments
 (0)