# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd.
# SPDX-License-Identifier: AGPL-3.0
"""Local GGUF embedders powered by llama-cpp-python."""

from __future__ import annotations

import importlib
import logging
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

import requests

from openviking.models.embedder.base import DenseEmbedderBase, EmbedResult
from openviking.storage.errors import EmbeddingConfigurationError

logger = logging.getLogger(__name__)

DEFAULT_LOCAL_MODEL_CACHE_DIR = "~/.cache/openviking/models"
DEFAULT_LOCAL_DENSE_MODEL = "bge-small-zh-v1.5-f16"
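# Query-side instruction prefix recommended for the BGE Chinese retrieval models
# (roughly: "Generate a representation for this sentence to retrieve related articles:").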
DEFAULT_BGE_ZH_QUERY_INSTRUCTION = "为这个句子生成表示以用于检索相关文章:"


@dataclass(frozen=True)
class LocalModelSpec:
    model_name: str
    dimension: int
    filename: str
    download_url: str
    query_instruction: Optional[str] = None

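# Supported local GGUF dense embedding models, keyed by model name.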
LOCAL_DENSE_MODEL_SPECS: Dict[str, LocalModelSpec] = {
    DEFAULT_LOCAL_DENSE_MODEL: LocalModelSpec(
        model_name=DEFAULT_LOCAL_DENSE_MODEL,
        dimension=512,
        filename="bge-small-zh-v1.5-f16.gguf",
        download_url=(
            "https://huggingface.co/CompendiumLabs/bge-small-zh-v1.5-gguf/resolve/main/"
            "bge-small-zh-v1.5-f16.gguf?download=true"
        ),
        query_instruction=DEFAULT_BGE_ZH_QUERY_INSTRUCTION,
    )
}


def get_local_model_spec(model_name: str) -> LocalModelSpec:
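    """Return the spec for a supported local model, raising ValueError for unknown names."""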
    try:
        return LOCAL_DENSE_MODEL_SPECS[model_name]
    except KeyError as exc:
        raise ValueError(
            f"Unknown local embedding model '{model_name}'. "
            f"Supported models: {list(LOCAL_DENSE_MODEL_SPECS.keys())}"
        ) from exc


def get_local_model_default_dimension(model_name: str) -> int:
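    """Return the fixed output dimension of a supported local model."""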
    return get_local_model_spec(model_name).dimension


def get_local_model_cache_path(model_name: str, cache_dir: Optional[str] = None) -> Path:
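    """Return the path where the model's GGUF file is (or will be) cached."""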
    spec = get_local_model_spec(model_name)
    cache_root = Path(cache_dir or DEFAULT_LOCAL_MODEL_CACHE_DIR).expanduser().resolve()
    return cache_root / spec.filename


def get_local_model_identity(model_name: str, model_path: Optional[str] = None) -> str:
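    """Return a stable identity string: the resolved file path if one is given, else the spec's GGUF filename."""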
    if model_path:
        resolved = Path(model_path).expanduser().resolve()
        return str(resolved)
    return get_local_model_spec(model_name).filename


class LocalDenseEmbedder(DenseEmbedderBase):
    """Dense embedder backed by a local GGUF model via llama-cpp-python."""

    def __init__(
        self,
        model_name: str = DEFAULT_LOCAL_DENSE_MODEL,
        model_path: Optional[str] = None,
        cache_dir: Optional[str] = None,
        dimension: Optional[int] = None,
        query_instruction: Optional[str] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
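        """Resolve the GGUF model file (downloading it on first use) and load it through llama-cpp-python."""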
        runtime_config = dict(config or {})
        runtime_config.setdefault("provider", "local")
        super().__init__(model_name, runtime_config)

        self.model_spec = get_local_model_spec(model_name)
        self.model_path = model_path
        self.cache_dir = cache_dir or DEFAULT_LOCAL_MODEL_CACHE_DIR
        self.query_instruction = (
            query_instruction
            if query_instruction is not None
            else self.model_spec.query_instruction
        )
        self._dimension = dimension or self.model_spec.dimension
        if self._dimension != self.model_spec.dimension:
            raise ValueError(
                f"Local model '{model_name}' has fixed dimension {self.model_spec.dimension}, "
                f"but got dimension={self._dimension}"
            )

        self._resolved_model_path = self._resolve_model_path()
        self._llama = self._load_model()

    def _import_llama(self):
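        """Import llama-cpp-python and return its Llama class, with a clear error if it is missing."""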
        try:
            module = importlib.import_module("llama_cpp")
        except ImportError as exc:
            raise EmbeddingConfigurationError(
                "Local embedding is enabled but 'llama-cpp-python' is not installed. "
                'Install it with: pip install "openviking[local-embed]". '
                "If you prefer a remote provider, set embedding.dense.provider explicitly in ov.conf."
            ) from exc

        llama_cls = getattr(module, "Llama", None)
        if llama_cls is None:
            raise EmbeddingConfigurationError(
                "llama_cpp.Llama is unavailable in the installed llama-cpp-python package."
            )
        return llama_cls

    def _resolve_model_path(self) -> Path:
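        """Return the GGUF file to load: an explicit model_path if configured,
        otherwise the cached copy, downloading it on first use."""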
        if self.model_path:
            resolved = Path(self.model_path).expanduser().resolve()
            if not resolved.exists():
                raise EmbeddingConfigurationError(
                    f"Local embedding model file not found: {resolved}"
                )
            return resolved

        cache_root = Path(self.cache_dir).expanduser().resolve()
        cache_root.mkdir(parents=True, exist_ok=True)
        target = get_local_model_cache_path(self.model_name, self.cache_dir)
        if target.exists():
            return target

        self._download_model(self.model_spec.download_url, target)
        return target

    def _download_model(self, url: str, target: Path) -> None:
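        """Stream the GGUF file to a temporary .part file, then atomically move it into place."""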
        logger.info("Downloading local embedding model %s to %s", self.model_name, target)
        tmp_target = target.with_suffix(target.suffix + ".part")
        try:
            with requests.get(url, stream=True, timeout=(10, 300)) as response:
                response.raise_for_status()
                with tmp_target.open("wb") as fh:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        if chunk:
                            fh.write(chunk)
            os.replace(tmp_target, target)
        except Exception as exc:
            tmp_target.unlink(missing_ok=True)
            raise EmbeddingConfigurationError(
                f"Failed to download local embedding model '{self.model_name}' from {url} "
                f"to {target}: {exc}"
            ) from exc

    def _load_model(self):
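        """Instantiate the llama-cpp-python model in embedding mode."""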
        llama_cls = self._import_llama()
        try:
            return llama_cls(
                model_path=str(self._resolved_model_path),
                embedding=True,
                verbose=False,
            )
        except Exception as exc:
            raise EmbeddingConfigurationError(
                f"Failed to load GGUF embedding model from {self._resolved_model_path}: {exc}"
            ) from exc

    def _format_text(self, text: str, *, is_query: bool) -> str:
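        """Prefix query texts with the model's query instruction; passages are embedded as-is."""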
        if is_query and self.query_instruction:
            return f"{self.query_instruction}{text}"
        return text

    def _supports_native_batch_embeddings(self) -> bool:
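        """Check whether the loaded context can embed more than one sequence per call."""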
        context_params = getattr(self._llama, "context_params", None)
        n_seq_max = getattr(context_params, "n_seq_max", 1)
        return n_seq_max > 1

    @staticmethod
    def _extract_embedding(payload: Any) -> List[float]:
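        """Extract a single dense vector from a llama-cpp-python embedding response."""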
        if isinstance(payload, dict):
            data = payload.get("data")
            if isinstance(data, list) and data:
                item = data[0]
                if isinstance(item, dict) and "embedding" in item:
                    return list(item["embedding"])
            if "embedding" in payload:
                return list(payload["embedding"])
        raise RuntimeError("Unexpected llama-cpp-python embedding response format")

    @staticmethod
    def _extract_embeddings(payload: Any) -> List[List[float]]:
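        """Extract one dense vector per input from a llama-cpp-python batch embedding response."""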
        if isinstance(payload, dict):
            data = payload.get("data")
            if isinstance(data, list):
                vectors: List[List[float]] = []
                for item in data:
                    if not isinstance(item, dict) or "embedding" not in item:
                        raise RuntimeError(
                            "Unexpected llama-cpp-python batch embedding response format"
                        )
                    vectors.append(list(item["embedding"]))
                return vectors
        raise RuntimeError("Unexpected llama-cpp-python batch embedding response format")

    def _embed_formatted_text(self, formatted: str) -> EmbedResult:
        payload = self._llama.create_embedding(formatted)
        return EmbedResult(dense_vector=self._extract_embedding(payload))

    def _embed_formatted_texts_sequential(self, formatted: List[str]) -> List[EmbedResult]:
        return [
            self._run_with_retry(
                lambda formatted_text=text: self._embed_formatted_text(formatted_text),
                logger=logger,
                operation_name="local sequential batch embedding",
            )
            for text in formatted
        ]

    def embed(self, text: str, is_query: bool = False) -> EmbedResult:
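        """Embed a single text, retrying on transient failures and recording estimated token usage."""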
        formatted = self._format_text(text, is_query=is_query)

        try:
            result = self._run_with_retry(
                lambda: self._embed_formatted_text(formatted),
                logger=logger,
                operation_name="local embedding",
            )
        except Exception as exc:
            raise RuntimeError(f"Local embedding failed: {exc}") from exc

        estimated_tokens = self._estimate_tokens(formatted)
        self.update_token_usage(
            model_name=self.model_name,
            provider="local",
            prompt_tokens=estimated_tokens,
            completion_tokens=0,
        )
        return result

    def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedResult]:
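        """Embed a list of texts, preferring a native batch call and falling back to sequential embedding."""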
        if not texts:
            return []

        formatted = [self._format_text(text, is_query=is_query) for text in texts]
        if len(formatted) > 1 and not self._supports_native_batch_embeddings():
            logger.info(
                "Local model %s does not support native multi-sequence batch embedding "
                "(n_seq_max <= 1); using sequential mode",
                self.model_name,
            )
            results = self._embed_formatted_texts_sequential(formatted)
            estimated_tokens = sum(self._estimate_tokens(text) for text in formatted)
            self.update_token_usage(
                model_name=self.model_name,
                provider="local",
                prompt_tokens=estimated_tokens,
                completion_tokens=0,
            )
            return results

        def _call_batch() -> List[EmbedResult]:
            payload = self._llama.create_embedding(formatted)
            return [
                EmbedResult(dense_vector=vector) for vector in self._extract_embeddings(payload)
            ]

        try:
            results = self._run_with_retry(
                _call_batch,
                logger=logger,
                operation_name="local batch embedding",
            )
        except Exception as batch_exc:
            logger.warning(
                "Local batch embedding failed for model=%s (%s); falling back to sequential embedding",
                self.model_name,
                batch_exc,
            )
            try:
                results = self._embed_formatted_texts_sequential(formatted)
            except Exception as exc:
                raise RuntimeError(f"Local batch embedding failed: {exc}") from exc

        estimated_tokens = sum(self._estimate_tokens(text) for text in formatted)
        self.update_token_usage(
            model_name=self.model_name,
            provider="local",
            prompt_tokens=estimated_tokens,
            completion_tokens=0,
        )
        return results

    def get_dimension(self) -> int:
        return self._dimension

    def close(self):
        close_fn = getattr(self._llama, "close", None)
        if callable(close_fn):
            close_fn()
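

# Illustrative usage sketch (comments only, not executed). It assumes llama-cpp-python is
# installed and that the default model can be downloaded on first use; the query string is
# just a hypothetical example input.
#
#     embedder = LocalDenseEmbedder()  # downloads bge-small-zh-v1.5-f16.gguf if not cached
#     result = embedder.embed("什么是向量检索?", is_query=True)
#     assert len(result.dense_vector) == embedder.get_dimension()  # 512 for the default model
#     embedder.close()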