-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathembeddings.py
More file actions
155 lines (132 loc) · 5.71 KB
/
Copy pathembeddings.py
File metadata and controls
155 lines (132 loc) · 5.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from __future__ import annotations
import os
from typing import Any
import requests
from observability.logger import get_runtime_logger
logger = get_runtime_logger(__name__)

# Base URL of the Ollama server. A trailing "/" is stripped so the request
# URLs built as f"{OLLAMA_HOST}/api/..." in this module never contain "//".
OLLAMA_HOST = os.getenv("OLLAMA_HOST", "http://localhost:11434").rstrip("/")
# Model name sent with every embedding request.
OLLAMA_EMBED_MODEL = os.getenv("OLLAMA_EMBED_MODEL", "nomic-embed-text")
# Length of the all-zero vector returned for blank input (not checked against
# what the server actually returns).
EMBED_DIM = int(os.getenv("OLLAMA_EMBED_DIM", "768"))
# Passed to /api/embed as options.num_ctx.
OLLAMA_EMBED_MAX_INPUT_TOKENS = int(os.getenv("OLLAMA_EMBED_MAX_INPUT_TOKENS", "2048"))
# UTF-8 byte cap applied by the legacy /api/embeddings path; also the starting
# value of the adaptive cap below.
OLLAMA_EMBED_MAX_INPUT_BYTES = int(os.getenv("OLLAMA_EMBED_MAX_INPUT_BYTES", "3000"))
# Lowest value the adaptive byte cap is allowed to take.
ADAPTIVE_MIN_INPUT_BYTES = 800
# Current UTF-8 byte cap for /api/embed input. Module-global and mutable:
# embed_text lowers it after a context-length error is resolved by a smaller retry.
_adaptive_max_input_bytes = max(ADAPTIVE_MIN_INPUT_BYTES, OLLAMA_EMBED_MAX_INPUT_BYTES)
def _truncate_utf8_bytes(text: str, max_bytes: int) -> str:
if max_bytes <= 0:
return ""
encoded = text.encode("utf-8")
if len(encoded) <= max_bytes:
return text
return encoded[:max_bytes].decode("utf-8", errors="ignore")
def _is_context_length_error(message: str) -> bool:
normalized = (message or "").lower()
return (
"input length exceeds the context length" in normalized
or "exceeds the context length" in normalized
or ("context length" in normalized
and "exceeds" in normalized)
)
def _extract_embedding(data: dict[str, Any]) -> list[float] | None:
embeddings = data.get("embeddings")
if isinstance(embeddings, list) and embeddings:
first = embeddings[0]
if isinstance(first, list):
return [float(value) for value in first]
if isinstance(first, (int, float)):
return [float(value) for value in embeddings]
legacy_embedding = data.get("embedding")
if isinstance(legacy_embedding, list) and legacy_embedding:
return [float(value) for value in legacy_embedding]
return None
def _embed_text_with_legacy_endpoint(text: str) -> list[float]:
    """Embed *text* through the older ``/api/embeddings`` endpoint.

    The text is first capped at ``OLLAMA_EMBED_MAX_INPUT_BYTES`` UTF-8 bytes.
    If nothing remains, a zero vector of length ``EMBED_DIM`` is returned and
    no request is sent.

    Raises:
        requests.HTTPError: the server answered with an error status.
        RuntimeError: the JSON body held no usable embedding.
    """
    prompt = _truncate_utf8_bytes(text, OLLAMA_EMBED_MAX_INPUT_BYTES)
    if not prompt:
        return [0.0] * EMBED_DIM
    payload = {"model": OLLAMA_EMBED_MODEL, "prompt": prompt}
    resp = requests.post(f"{OLLAMA_HOST}/api/embeddings", json=payload, timeout=30)
    resp.raise_for_status()
    body = resp.json()
    vector = _extract_embedding(body)
    if not vector:
        raise RuntimeError(f"Ollama embeddings response missing data: {body}")
    return vector
def embed_text(text: str) -> list[float]:
    """Return an embedding vector for *text* from the configured Ollama server.

    Flow:

    1. Strip *text*. Blank input returns a zero vector of length
       ``EMBED_DIM`` without sending a request.
    2. Truncate to the current adaptive UTF-8 byte cap, which is never below
       ``ADAPTIVE_MIN_INPUT_BYTES``. POST to ``/api/embed`` with
       ``"truncate": True`` and ``num_ctx`` set to
       ``OLLAMA_EMBED_MAX_INPUT_TOKENS``.
    3. On a 404, fall back to the legacy ``/api/embeddings`` endpoint.
    4. On any other status >= 400 whose body looks like a context-length
       error, retry with progressively smaller byte caps.
       - A successful retry lowers the module-level adaptive cap, so later
         calls start smaller.
       - If no retry yields an embedding, fall back to the legacy endpoint.
    5. Any other error status is raised via ``raise_for_status``.

    NOTE(review): ``_adaptive_max_input_bytes`` is module-global and is
    mutated here without a lock. Confirm whether concurrent callers are
    expected.

    Raises:
        requests.HTTPError: an error status that is neither 404 nor a
            context-length error.
        RuntimeError: a successful response carried no usable embedding.
    """
    global _adaptive_max_input_bytes
    text = (text or "").strip()
    if not text:
        return [0.0] * EMBED_DIM
    # Re-apply the floor here. The 404-during-retry branch below lowers the
    # cap with min() and no floor, so the stored value can drop under it.
    active_byte_cap = max(ADAPTIVE_MIN_INPUT_BYTES, _adaptive_max_input_bytes)
    input_text = _truncate_utf8_bytes(text, active_byte_cap)
    # Sizes of the text actually sent. Used for logging and to decide which
    # retry caps are still smaller than the input.
    input_chars = len(input_text)
    input_bytes = len(input_text.encode("utf-8"))

    def _post_embed(payload_text: str) -> requests.Response:
        # One /api/embed request. The payload asks the server to truncate
        # ("truncate": True) and sets the context window via options.num_ctx.
        return requests.post(
            f"{OLLAMA_HOST}/api/embed",
            json={
                "model": OLLAMA_EMBED_MODEL,
                "input": payload_text,
                "truncate": True,
                "options": {"num_ctx": OLLAMA_EMBED_MAX_INPUT_TOKENS},
            },
            timeout=30,
        )

    response = _post_embed(input_text)
    if response.status_code == 404:
        # This server has no /api/embed; use the older endpoint instead.
        logger.warning("[embeddings] /api/embed not available; using legacy /api/embeddings")
        return _embed_text_with_legacy_endpoint(input_text)
    if response.status_code >= 400:
        body = (response.text or "").strip()
        logger.warning(
            "[embeddings] embed request failed status=%s chars=%s bytes=%s body=%s",
            response.status_code,
            input_chars,
            input_bytes,
            body[:300],
        )
        if _is_context_length_error(body):
            # Descending byte caps to try. Caps not smaller than the current
            # input are skipped, so nothing is retried when
            # input_bytes <= 800.
            retry_caps = [2500, 2000, 1600, 1200, 1000, 800]
            for cap in retry_caps:
                if input_bytes <= cap:
                    continue
                reduced = _truncate_utf8_bytes(input_text, cap)
                if not reduced:
                    continue
                reduced_chars = len(reduced)
                reduced_bytes = len(reduced.encode("utf-8"))
                logger.info(
                    "[embeddings] retrying after context error chars=%s bytes=%s",
                    reduced_chars,
                    reduced_bytes,
                )
                retry_response = _post_embed(reduced)
                if retry_response.status_code == 404:
                    # /api/embed disappeared mid-retry. Keep the smaller cap
                    # (no ADAPTIVE_MIN floor here) and use the legacy
                    # endpoint with the reduced text.
                    _adaptive_max_input_bytes = min(_adaptive_max_input_bytes, reduced_bytes)
                    return _embed_text_with_legacy_endpoint(reduced)
                if retry_response.status_code >= 400:
                    retry_body = (retry_response.text or "").strip()
                    logger.warning(
                        "[embeddings] retry failed status=%s body=%s",
                        retry_response.status_code,
                        retry_body[:300],
                    )
                    # Any error status moves on to the next, smaller cap.
                    continue
                retry_data = retry_response.json()
                retry_embedding = _extract_embedding(retry_data)
                # A success response without an embedding also falls through
                # to the next cap.
                if retry_embedding:
                    # Remember the size that worked so later calls start
                    # there. The cap only ever decreases here.
                    if reduced_bytes < _adaptive_max_input_bytes:
                        _adaptive_max_input_bytes = max(ADAPTIVE_MIN_INPUT_BYTES, reduced_bytes)
                        logger.info(
                            "[embeddings] adapting byte cap after successful retry new_byte_cap=%s",
                            _adaptive_max_input_bytes,
                        )
                    return retry_embedding
            # No retry produced an embedding, or none was attempted.
            # NOTE(review): this resends the full first-pass text, which
            # /api/embed already rejected as too long. Consider whether the
            # smallest reduced text would be a better fallback.
            return _embed_text_with_legacy_endpoint(input_text)
    # Only error statuses that are not context-length errors reach this and
    # raise. Successful responses pass through to parsing.
    response.raise_for_status()
    data = response.json()
    embedding = _extract_embedding(data)
    if embedding:
        return embedding
    raise RuntimeError(f"Ollama embed response missing data: {data}")