Skip to content

Commit e352097

Browse files
Mijamind719 and codex committed
fix: fallback local batch embedding to sequential mode
Co-authored-by: GPT-5.4 <noreply@openai.com>
1 parent dd16683 commit e352097

2 files changed

Lines changed: 59 additions & 9 deletions

File tree

openviking/models/embedder/local_embedders.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -205,16 +205,16 @@ def _extract_embeddings(payload: Any) -> List[List[float]]:
205205
return vectors
206206
raise RuntimeError("Unexpected llama-cpp-python batch embedding response format")
207207

208+
def _embed_formatted_text(self, formatted: str) -> EmbedResult:
    """Embed one already-formatted string via the llama backend.

    Shared by the single-text path and the sequential batch fallback so
    both produce identical per-text results.
    """
    response = self._llama.create_embedding(formatted)
    vector = self._extract_embedding(response)
    return EmbedResult(dense_vector=vector)
211+
208212
def embed(self, text: str, is_query: bool = False) -> EmbedResult:
209213
formatted = self._format_text(text, is_query=is_query)
210214

211-
def _call() -> EmbedResult:
212-
payload = self._llama.create_embedding(formatted)
213-
return EmbedResult(dense_vector=self._extract_embedding(payload))
214-
215215
try:
216216
result = self._run_with_retry(
217-
_call,
217+
lambda: self._embed_formatted_text(formatted),
218218
logger=logger,
219219
operation_name="local embedding",
220220
)
@@ -236,20 +236,35 @@ def embed_batch(self, texts: List[str], is_query: bool = False) -> List[EmbedRes
236236

237237
formatted = [self._format_text(text, is_query=is_query) for text in texts]
238238

239-
def _call() -> List[EmbedResult]:
239+
def _call_batch() -> List[EmbedResult]:
240240
payload = self._llama.create_embedding(formatted)
241241
return [
242242
EmbedResult(dense_vector=vector) for vector in self._extract_embeddings(payload)
243243
]
244244

245245
try:
246246
results = self._run_with_retry(
247-
_call,
247+
_call_batch,
248248
logger=logger,
249249
operation_name="local batch embedding",
250250
)
251-
except Exception as exc:
252-
raise RuntimeError(f"Local batch embedding failed: {exc}") from exc
251+
except Exception as batch_exc:
252+
logger.warning(
253+
"Local batch embedding failed for model=%s (%s); falling back to sequential embedding",
254+
self.model_name,
255+
batch_exc,
256+
)
257+
try:
258+
results = [
259+
self._run_with_retry(
260+
lambda formatted_text=text: self._embed_formatted_text(formatted_text),
261+
logger=logger,
262+
operation_name="local sequential batch embedding",
263+
)
264+
for text in formatted
265+
]
266+
except Exception as exc:
267+
raise RuntimeError(f"Local batch embedding failed: {exc}") from exc
253268

254269
estimated_tokens = sum(self._estimate_tokens(text) for text in formatted)
255270
self.update_token_usage(

tests/unit/test_local_embedder.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@ def create_embedding(self, payload):
5252
return {"data": [{"embedding": [0.1] * 512}]}
5353

5454

55+
class _FakeLlamaFailBatch(_FakeLlama):
    """Fake llama backend that rejects multi-item batches.

    Single-string payloads succeed, so the embedder's sequential
    fallback path can be exercised deterministically.
    """

    def create_embedding(self, payload):
        # Record every payload so tests can inspect the call sequence.
        self.__class__.inputs.append(payload)
        is_multi_item_batch = isinstance(payload, list) and len(payload) > 1
        if is_multi_item_batch:
            raise RuntimeError("llama_decode returned -1")
        return {"data": [{"embedding": [0.2] * 512}]}
61+
62+
5563
@pytest.fixture(autouse=True)
5664
def _reset_fake_llama():
5765
_FakeLlama.init_kwargs = []
@@ -149,3 +157,30 @@ def test_local_embedder_embed_batch_preserves_count(monkeypatch, tmp_path):
149157
f"{DEFAULT_BGE_ZH_QUERY_INSTRUCTION}a",
150158
f"{DEFAULT_BGE_ZH_QUERY_INSTRUCTION}b",
151159
]
160+
161+
162+
def test_local_embedder_embed_batch_falls_back_to_sequential(monkeypatch, tmp_path):
    """A failing batch call must fall back to one embedding call per text."""
    model_path = tmp_path / "model.gguf"
    model_path.write_bytes(b"gguf")

    _FakeLlamaFailBatch.init_kwargs = []
    _FakeLlamaFailBatch.inputs = []

    monkeypatch.setattr(
        "openviking.models.embedder.local_embedders.importlib.import_module",
        lambda _name: SimpleNamespace(Llama=_FakeLlamaFailBatch),
    )

    embedder = LocalDenseEmbedder(model_path=str(model_path))
    results = embedder.embed_batch(["a", "b"], is_query=True)

    formatted = [
        f"{DEFAULT_BGE_ZH_QUERY_INSTRUCTION}a",
        f"{DEFAULT_BGE_ZH_QUERY_INSTRUCTION}b",
    ]

    assert len(results) == 2
    assert all(len(item.dense_vector) == 512 for item in results)
    # Call 0 is the rejected batch (a list of both formatted texts);
    # the remaining calls are the sequential retries, one string each.
    assert _FakeLlamaFailBatch.inputs[0] == formatted
    assert _FakeLlamaFailBatch.inputs[1:] == formatted

0 commit comments

Comments
 (0)