Skip to content

Commit 752bc21

Browse files
authored
feat:add embedding and similarity search for knowledge graph (#17)
* feat:add embedding and similarity search for knowledge graph * Fix lint
1 parent e1ef07d commit 752bc21

4 files changed

Lines changed: 611 additions & 96 deletions

File tree

python/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ requires = ["maturin>=1.4"]
4141
build-backend = "maturin"
4242

4343
[project.optional-dependencies]
44-
tests = ["pytest", "pyarrow>=14", "pandas"]
44+
tests = ["pytest", "pyarrow>=14", "pandas", "ruff"]
4545
dev = ["ruff", "pyright"]
4646
llm = ["openai>=1.52.0"]
4747
lance-storage = ["lance>=0.17.0"]
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
"""Embedding utilities backed by the OpenAI client."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
import math
7+
from typing import Any, Mapping, MutableMapping, Sequence
8+
9+
LOGGER = logging.getLogger(__name__)
10+
11+
DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small"
12+
13+
14+
class EmbeddingGenerator:
15+
"""Generate embeddings using OpenAI's embeddings API."""
16+
17+
def __init__(
18+
self,
19+
*,
20+
model: str = DEFAULT_EMBEDDING_MODEL,
21+
client_options: Mapping[str, Any] | None = None,
22+
) -> None:
23+
self._model = model
24+
self._client_options: MutableMapping[str, Any] = dict(client_options or {})
25+
self._client = None
26+
27+
def embed(self, texts: Sequence[str]) -> list[list[float]]:
28+
"""Return embeddings for the provided texts."""
29+
sanitized = [text for text in texts if text]
30+
if not sanitized:
31+
return []
32+
33+
client = self._ensure_client()
34+
response = client.embeddings.create(model=self._model, input=sanitized)
35+
data = getattr(response, "data", None)
36+
if not data:
37+
raise RuntimeError("OpenAI embedding response missing 'data' entries.")
38+
39+
embeddings: list[list[float]] = []
40+
for item in data:
41+
vector = getattr(item, "embedding", None)
42+
if vector is None and isinstance(item, Mapping):
43+
vector = item.get("embedding")
44+
if vector is None:
45+
raise RuntimeError("OpenAI embedding response missing 'embedding'.")
46+
try:
47+
embeddings.append([float(value) for value in vector])
48+
except (TypeError, ValueError) as exc:
49+
raise RuntimeError(
50+
"Embedding vector contains non-numeric values."
51+
) from exc
52+
return embeddings
53+
54+
def embed_one(self, text: str) -> list[float] | None:
55+
"""Return a single embedding for convenience."""
56+
vectors = self.embed([text])
57+
return vectors[0] if vectors else None
58+
59+
def _ensure_client(self):
60+
if self._client is None:
61+
try:
62+
from openai import OpenAI # type: ignore[import-not-found]
63+
except ImportError as exc:
64+
raise RuntimeError(
65+
"The `openai` package is required for embeddings. "
66+
"Install it or supply a custom client."
67+
) from exc
68+
69+
sanitized_opts = _sanitize_options(self._client_options)
70+
LOGGER.debug(
71+
"Initializing OpenAI embeddings client",
72+
extra={
73+
"lance_graph": {
74+
"openai_model": self._model,
75+
"openai_options": sanitized_opts,
76+
}
77+
},
78+
)
79+
self._client = OpenAI(**self._client_options)
80+
return self._client
81+
82+
83+
def cosine_similarity(lhs: Sequence[float], rhs: Sequence[float]) -> float:
84+
"""Return cosine similarity between two vectors."""
85+
if len(lhs) != len(rhs):
86+
LOGGER.debug(
87+
"Unable to compute cosine similarity due to mismatched lengths: %s vs %s",
88+
len(lhs),
89+
len(rhs),
90+
)
91+
return 0.0
92+
dot = sum(x * y for x, y in zip(lhs, rhs))
93+
lhs_norm = math.sqrt(sum(x * x for x in lhs))
94+
rhs_norm = math.sqrt(sum(y * y for y in rhs))
95+
if lhs_norm == 0 or rhs_norm == 0:
96+
return 0.0
97+
return dot / (lhs_norm * rhs_norm)
98+
99+
100+
def _sanitize_options(options: Mapping[str, Any]) -> dict[str, Any]:
101+
"""Strip sensitive values for logging."""
102+
sanitized: dict[str, Any] = {}
103+
for key, value in options.items():
104+
if key.lower() in {"api_key", "api-key", "authorization"}:
105+
sanitized[key] = "***"
106+
else:
107+
sanitized[key] = value
108+
return sanitized

0 commit comments

Comments
 (0)