Skip to content

Commit 4f369b9

Browse files
committed
feat:add embedding and similarity search for knowledge graph
1 parent e1ef07d commit 4f369b9

2 files changed

Lines changed: 614 additions & 95 deletions

File tree

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
"""Embedding utilities backed by the OpenAI client."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
import math
7+
from typing import Any, Mapping, MutableMapping, Sequence
8+
9+
LOGGER = logging.getLogger(__name__)
10+
11+
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
12+
13+
14+
class EmbeddingGenerator:
15+
"""Generate embeddings using OpenAI's embeddings API."""
16+
17+
def __init__(
18+
self,
19+
*,
20+
model: str = DEFAULT_EMBEDDING_MODEL,
21+
client_options: Mapping[str, Any] | None = None,
22+
) -> None:
23+
self._model = model
24+
self._client_options: MutableMapping[str, Any] = dict(client_options or {})
25+
self._client = None
26+
27+
def embed(self, texts: Sequence[str]) -> list[list[float]]:
28+
"""Return embeddings for the provided texts."""
29+
sanitized = [text for text in texts if text]
30+
if not sanitized:
31+
return []
32+
33+
client = self._ensure_client()
34+
response = client.embeddings.create(model=self._model, input=sanitized)
35+
data = getattr(response, "data", None)
36+
if not data:
37+
raise RuntimeError("OpenAI embedding response missing 'data' entries.")
38+
39+
embeddings: list[list[float]] = []
40+
for item in data:
41+
vector = getattr(item, "embedding", None)
42+
if vector is None and isinstance(item, Mapping):
43+
vector = item.get("embedding")
44+
if vector is None:
45+
raise RuntimeError("OpenAI embedding response missing 'embedding'.")
46+
try:
47+
embeddings.append([float(value) for value in vector])
48+
except (TypeError, ValueError) as exc:
49+
raise RuntimeError("Embedding vector contains non-numeric values.") from exc
50+
return embeddings
51+
52+
def embed_one(self, text: str) -> list[float] | None:
53+
"""Return a single embedding for convenience."""
54+
vectors = self.embed([text])
55+
return vectors[0] if vectors else None
56+
57+
def _ensure_client(self):
58+
if self._client is None:
59+
try:
60+
from openai import OpenAI # type: ignore[import-not-found]
61+
except ImportError as exc:
62+
raise RuntimeError(
63+
"The `openai` package is required for embeddings. "
64+
"Install it or supply a custom client."
65+
) from exc
66+
67+
sanitized_opts = _sanitize_options(self._client_options)
68+
LOGGER.debug(
69+
"Initializing OpenAI embeddings client",
70+
extra={
71+
"lance_graph": {
72+
"openai_model": self._model,
73+
"openai_options": sanitized_opts,
74+
}
75+
},
76+
)
77+
self._client = OpenAI(**self._client_options)
78+
return self._client
79+
80+
81+
def cosine_similarity(lhs: Sequence[float], rhs: Sequence[float]) -> float:
82+
"""Return cosine similarity between two vectors."""
83+
if len(lhs) != len(rhs):
84+
LOGGER.debug(
85+
"Unable to compute cosine similarity due to mismatched lengths: %s vs %s",
86+
len(lhs),
87+
len(rhs),
88+
)
89+
return 0.0
90+
dot = sum(x * y for x, y in zip(lhs, rhs))
91+
lhs_norm = math.sqrt(sum(x * x for x in lhs))
92+
rhs_norm = math.sqrt(sum(y * y for y in rhs))
93+
if lhs_norm == 0 or rhs_norm == 0:
94+
return 0.0
95+
return dot / (lhs_norm * rhs_norm)
96+
97+
98+
def _sanitize_options(options: Mapping[str, Any]) -> dict[str, Any]:
99+
"""Strip sensitive values for logging."""
100+
sanitized: dict[str, Any] = {}
101+
for key, value in options.items():
102+
if key.lower() in {"api_key", "api-key", "authorization"}:
103+
sanitized[key] = "***"
104+
else:
105+
sanitized[key] = value
106+
return sanitized

0 commit comments

Comments
 (0)