Skip to content

Commit 40daa58

Browse files
committed
chore: Improve Jina API embeddings
1 parent 5f26498 commit 40daa58

3 files changed

Lines changed: 28 additions & 62 deletions

File tree

server/embeddings/jina_api.py

Lines changed: 16 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515
# uniform with the OpenAI/Voyage providers.
1616
_BATCH_SIZE = 128
1717
_BACKOFF_DELAYS = [10, 20, 30, 40]
18-
# Conservative character cap (~8 k tokens at ~4 chars/token) to avoid
19-
# "Failed to encode text" 400s on models with limited context windows.
20-
_MAX_TEXT_CHARS = 32_000
2118

2219
# Native output dimensions for known models. The jina-code-embeddings family
2320
# supports Matryoshka truncation via the `dimensions` API parameter —
@@ -39,7 +36,10 @@
3936
_TASK_AWARE_PREFIXES = ("jina-code-embeddings-",)
4037

4138
# jina-code-embeddings models use a different task vocabulary than the generic
42-
# "retrieval.*" tasks accepted by other Jina models.
39+
# "retrieval.*" tasks accepted by other Jina models. The family also supports
40+
# code2code.*, code2nl.*, qa.*, and code2completion.* — we map all retrieval
41+
# traffic to nl2code because our queries are natural language and our passages
42+
# are source code.
4343
_JINA_CODE_TASK_MAP = {
4444
"retrieval.passage": "nl2code.passage",
4545
"retrieval.query": "nl2code.query",
@@ -94,10 +94,12 @@ def _sanitize(self, text: str) -> str:
9494
# cause Jina's tokenizer to return 400 "Failed to encode text".
9595
cleaned = text.encode("utf-8", errors="replace").decode("utf-8")
9696
cleaned = "".join(ch for ch in cleaned if ch >= " " or ch in "\t\n\r")
97-
return cleaned[:_MAX_TEXT_CHARS].strip() or "."
97+
return cleaned.strip()
9898

9999
def _make_body(self, inputs: list[str], task: str) -> dict:
100-
body: dict = {"model": self._model, "input": inputs}
100+
# truncate=True lets Jina trim oversized inputs server-side on token
101+
# boundaries instead of returning 400 "Failed to encode text".
102+
body: dict = {"model": self._model, "input": inputs, "truncate": True}
101103
if self._supports_task:
102104
body["task"] = (
103105
_JINA_CODE_TASK_MAP.get(task, task) if self._uses_code_tasks else task
@@ -124,69 +126,21 @@ async def _post_with_retry(self, body: dict) -> dict:
124126
resp.raise_for_status()
125127
return resp.json()
126128

127-
async def _embed_batch_with_fallback(
128-
self, batch: list[str], task: str
129-
) -> list[list[float]]:
130-
"""Embed one item at a time, halving on failure, substituting '.' only as last resort."""
131-
vectors: list[list[float]] = []
132-
for idx, text in enumerate(batch):
133-
candidate = text
134-
embedded = False
135-
while candidate:
136-
try:
137-
data = await self._post_with_retry(
138-
self._make_body([candidate], task)
139-
)
140-
vectors.append(data["data"][0]["embedding"])
141-
if len(candidate) < len(text):
142-
logger.info(
143-
"Encoded truncated text at batch index %d (%d → %d chars)",
144-
idx,
145-
len(text),
146-
len(candidate),
147-
)
148-
embedded = True
149-
break
150-
except Exception:
151-
half = len(candidate) // 2
152-
if half < 64:
153-
break
154-
logger.warning(
155-
"Text at batch index %d (len=%d) failed — retrying with first %d chars",
156-
idx,
157-
len(candidate),
158-
half,
159-
)
160-
candidate = candidate[:half]
161-
if not embedded:
162-
logger.warning(
163-
"Skipping unencodable text at batch index %d (original len=%d), using placeholder.",
164-
idx,
165-
len(text),
166-
)
167-
data = await self._post_with_retry(self._make_body(["."], task))
168-
vectors.append(data["data"][0]["embedding"])
169-
return vectors
170-
171129
async def _embed(self, texts: list[str], task: str) -> list[list[float]]:
172130
if not texts:
173131
return []
174132
sanitized = [self._sanitize(t) for t in texts]
133+
empty_indices = [i for i, t in enumerate(sanitized) if not t]
134+
if empty_indices:
135+
raise ValueError(
136+
f"Jina embed: received empty/whitespace input(s) at index "
137+
f"{empty_indices[:5]} of {len(sanitized)} — callers must filter "
138+
f"empty strings before calling embed_batch/embed_query."
139+
)
175140
all_vectors: list[list[float]] = []
176141
for i in range(0, len(sanitized), _BATCH_SIZE):
177142
batch = sanitized[i : i + _BATCH_SIZE]
178-
try:
179-
data = await self._post_with_retry(self._make_body(batch, task))
180-
except Exception as exc:
181-
if "400" in str(exc):
182-
logger.warning(
183-
"Batch of %d failed with 400 — retrying one-by-one", len(batch)
184-
)
185-
all_vectors.extend(
186-
await self._embed_batch_with_fallback(batch, task)
187-
)
188-
continue
189-
raise
143+
data = await self._post_with_retry(self._make_body(batch, task))
190144
batch_vectors = [item["embedding"] for item in data.get("data", [])]
191145
if len(batch_vectors) != len(batch):
192146
raise ValueError(

server/indexer/pipeline.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ async def index_service(
137137
force: bool = False,
138138
progress_callback: Callable[[ProgressEvent], Awaitable[None]] | None = None,
139139
) -> dict[str, int]:
140+
await self._store.ensure_collection()
140141
services = settings.load_services()
141142
svc = next((s for s in services if s.name == service_name), None)
142143
if svc is None:

tests/embeddings/test_jina_api.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ async def test_embed_batch_omits_task_for_v2_model(provider):
6565
body = json.loads(route.calls.last.request.read())
6666
assert body["model"] == "jina-embeddings-v2-base-code"
6767
assert body["input"] == ["a", "b"]
68+
assert body["truncate"] is True # server-side truncation on token boundary
6869
assert "task" not in body # v2 models don't accept the task parameter
6970
assert "dimensions" not in body
7071

@@ -161,6 +162,16 @@ async def test_embed_batch_empty(provider):
161162
assert await provider.embed_batch([]) == []
162163

163164

165+
async def test_embed_batch_raises_on_empty_string(provider):
166+
with pytest.raises(ValueError, match="empty/whitespace"):
167+
await provider.embed_batch([""])
168+
169+
170+
async def test_embed_batch_raises_on_whitespace_only(provider):
171+
with pytest.raises(ValueError, match="empty/whitespace"):
172+
await provider.embed_batch([" \n\t "])
173+
174+
164175
@respx.mock
165176
async def test_rate_limit_backoff_delays(provider, monkeypatch):
166177
sleep_calls: list[float] = []

0 commit comments

Comments
 (0)