|
| 1 | +"""Embedding utilities backed by the OpenAI client.""" |
| 2 | + |
| 3 | +from __future__ import annotations |
| 4 | + |
| 5 | +import logging |
| 6 | +import math |
| 7 | +from typing import Any, Mapping, MutableMapping, Sequence |
| 8 | + |
| 9 | +LOGGER = logging.getLogger(__name__) |
| 10 | + |
| 11 | +DEFAULT_EMBEDDING_MODEL = "text-embedding-3-small" |
| 12 | + |
| 13 | + |
| 14 | +class EmbeddingGenerator: |
| 15 | + """Generate embeddings using OpenAI's embeddings API.""" |
| 16 | + |
| 17 | + def __init__( |
| 18 | + self, |
| 19 | + *, |
| 20 | + model: str = DEFAULT_EMBEDDING_MODEL, |
| 21 | + client_options: Mapping[str, Any] | None = None, |
| 22 | + ) -> None: |
| 23 | + self._model = model |
| 24 | + self._client_options: MutableMapping[str, Any] = dict(client_options or {}) |
| 25 | + self._client = None |
| 26 | + |
| 27 | + def embed(self, texts: Sequence[str]) -> list[list[float]]: |
| 28 | + """Return embeddings for the provided texts.""" |
| 29 | + sanitized = [text for text in texts if text] |
| 30 | + if not sanitized: |
| 31 | + return [] |
| 32 | + |
| 33 | + client = self._ensure_client() |
| 34 | + response = client.embeddings.create(model=self._model, input=sanitized) |
| 35 | + data = getattr(response, "data", None) |
| 36 | + if not data: |
| 37 | + raise RuntimeError("OpenAI embedding response missing 'data' entries.") |
| 38 | + |
| 39 | + embeddings: list[list[float]] = [] |
| 40 | + for item in data: |
| 41 | + vector = getattr(item, "embedding", None) |
| 42 | + if vector is None and isinstance(item, Mapping): |
| 43 | + vector = item.get("embedding") |
| 44 | + if vector is None: |
| 45 | + raise RuntimeError("OpenAI embedding response missing 'embedding'.") |
| 46 | + try: |
| 47 | + embeddings.append([float(value) for value in vector]) |
| 48 | + except (TypeError, ValueError) as exc: |
| 49 | + raise RuntimeError( |
| 50 | + "Embedding vector contains non-numeric values." |
| 51 | + ) from exc |
| 52 | + return embeddings |
| 53 | + |
| 54 | + def embed_one(self, text: str) -> list[float] | None: |
| 55 | + """Return a single embedding for convenience.""" |
| 56 | + vectors = self.embed([text]) |
| 57 | + return vectors[0] if vectors else None |
| 58 | + |
| 59 | + def _ensure_client(self): |
| 60 | + if self._client is None: |
| 61 | + try: |
| 62 | + from openai import OpenAI # type: ignore[import-not-found] |
| 63 | + except ImportError as exc: |
| 64 | + raise RuntimeError( |
| 65 | + "The `openai` package is required for embeddings. " |
| 66 | + "Install it or supply a custom client." |
| 67 | + ) from exc |
| 68 | + |
| 69 | + sanitized_opts = _sanitize_options(self._client_options) |
| 70 | + LOGGER.debug( |
| 71 | + "Initializing OpenAI embeddings client", |
| 72 | + extra={ |
| 73 | + "lance_graph": { |
| 74 | + "openai_model": self._model, |
| 75 | + "openai_options": sanitized_opts, |
| 76 | + } |
| 77 | + }, |
| 78 | + ) |
| 79 | + self._client = OpenAI(**self._client_options) |
| 80 | + return self._client |
| 81 | + |
| 82 | + |
| 83 | +def cosine_similarity(lhs: Sequence[float], rhs: Sequence[float]) -> float: |
| 84 | + """Return cosine similarity between two vectors.""" |
| 85 | + if len(lhs) != len(rhs): |
| 86 | + LOGGER.debug( |
| 87 | + "Unable to compute cosine similarity due to mismatched lengths: %s vs %s", |
| 88 | + len(lhs), |
| 89 | + len(rhs), |
| 90 | + ) |
| 91 | + return 0.0 |
| 92 | + dot = sum(x * y for x, y in zip(lhs, rhs)) |
| 93 | + lhs_norm = math.sqrt(sum(x * x for x in lhs)) |
| 94 | + rhs_norm = math.sqrt(sum(y * y for y in rhs)) |
| 95 | + if lhs_norm == 0 or rhs_norm == 0: |
| 96 | + return 0.0 |
| 97 | + return dot / (lhs_norm * rhs_norm) |
| 98 | + |
| 99 | + |
| 100 | +def _sanitize_options(options: Mapping[str, Any]) -> dict[str, Any]: |
| 101 | + """Strip sensitive values for logging.""" |
| 102 | + sanitized: dict[str, Any] = {} |
| 103 | + for key, value in options.items(): |
| 104 | + if key.lower() in {"api_key", "api-key", "authorization"}: |
| 105 | + sanitized[key] = "***" |
| 106 | + else: |
| 107 | + sanitized[key] = value |
| 108 | + return sanitized |
0 commit comments