Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions nemoguardrails/embeddings/providers/fastembed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.

import asyncio
import os
from pathlib import Path
from typing import List

from .base import EmbeddingModel
Expand Down Expand Up @@ -49,6 +51,15 @@ def __init__(self, embedding_model: str, **kwargs):
embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
self.embedding_model = embedding_model

# If a pre-downloaded model exists in the cache, pass its path directly
# to skip all download attempts. This avoids network calls in disconnected
# environments where fastembed would otherwise try HuggingFace Hub and GCS.
# See https://github.com/qdrant/fastembed/pull/614
if "specific_model_path" not in kwargs:
cached_path = self._find_cached_model(embedding_model, Embedding)
if cached_path:
kwargs["specific_model_path"] = str(cached_path)

try:
self.model = Embedding(embedding_model, **kwargs)
except ValueError as ex:
Expand All @@ -63,6 +74,34 @@ def __init__(self, embedding_model: str, **kwargs):
# Get the embedding dimension of the model
self.embedding_size = len(list(self.model.embed("test"))[0].tolist())

@staticmethod
def _find_cached_model(embedding_model: str, embedding_cls) -> Path | None:
"""Find a pre-downloaded model in the fastembed cache.

Looks up the HF source repo for the model in fastembed's registry,
then checks if it exists in the HF-convention cache directory
(models--{org}--{repo}/snapshots/{hash}/).
"""
cache_dir = Path(os.environ.get("FASTEMBED_CACHE_PATH", str(Path.home() / ".cache" / "fastembed")))
if not cache_dir.is_dir():
return None

for model_desc in embedding_cls.list_supported_models():
if model_desc["model"] != embedding_model:
continue
hf_source = (model_desc.get("sources") or {}).get("hf")
if not hf_source:
break
snapshots_dir = cache_dir / f"models--{hf_source.replace('/', '--')}" / "snapshots"
if not snapshots_dir.is_dir():
break
snapshot_dirs = [d for d in snapshots_dir.iterdir() if d.is_dir()]
if snapshot_dirs:
return snapshot_dirs[0]
break

return None

async def encode_async(self, documents: List[str]) -> List[List[float]]:
"""Encode a list of documents into their corresponding sentence embeddings.

Expand Down
Loading