-
Notifications
You must be signed in to change notification settings - Fork 190
Expand file tree
/
Copy pathopenai_provider.py
More file actions
132 lines (110 loc) · 4.85 KB
/
openai_provider.py
File metadata and controls
132 lines (110 loc) · 4.85 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""OpenAI-based embedding provider for cloud or API-backed semantic indexing."""
from __future__ import annotations
import asyncio
import os
from typing import Any
from basic_memory.repository.embedding_provider import EmbeddingProvider
from basic_memory.repository.semantic_errors import SemanticDependenciesMissingError
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """Embedding provider backed by OpenAI's embeddings API.

    Texts are split into fixed-size batches and embedded concurrently, with
    fan-out bounded by ``request_concurrency``. The OpenAI client is created
    lazily on first use so importing this module requires neither the
    ``openai`` package nor an API key.
    """

    def __init__(
        self,
        model_name: str = "text-embedding-3-small",
        *,
        batch_size: int = 64,
        request_concurrency: int = 4,
        dimensions: int = 1536,
        api_key: str | None = None,
        base_url: str | None = None,
        timeout: float = 30.0,
    ) -> None:
        """Configure the provider.

        Args:
            model_name: OpenAI embedding model identifier.
            batch_size: Maximum number of texts per embeddings request.
            request_concurrency: Maximum number of in-flight API requests.
            dimensions: Expected vector dimensionality; vectors that do not
                match cause ``embed_documents`` to raise ``RuntimeError``.
            api_key: Explicit API key; falls back to ``OPENAI_API_KEY``.
            base_url: Optional endpoint override (e.g. a proxy or an
                OpenAI-compatible server).
            timeout: Per-request timeout in seconds.
        """
        self.model_name = model_name
        self.dimensions = dimensions
        self.batch_size = batch_size
        self.request_concurrency = request_concurrency
        self._api_key = api_key
        self._base_url = base_url
        self._timeout = timeout
        # Lazily-created AsyncOpenAI client; _client_lock guarantees that
        # concurrent first callers build it exactly once.
        self._client: Any | None = None
        self._client_lock = asyncio.Lock()

    def runtime_log_attrs(self) -> dict[str, int]:
        """Return the request fan-out knobs that shape API embedding batches."""
        return {
            "provider_batch_size": self.batch_size,
            "request_concurrency": self.request_concurrency,
        }

    async def _get_client(self) -> Any:
        """Return the shared AsyncOpenAI client, creating it on first call.

        Raises:
            SemanticDependenciesMissingError: if the ``openai`` package is
                not installed, or no API key was given and ``OPENAI_API_KEY``
                is unset.
        """
        if self._client is not None:
            return self._client
        async with self._client_lock:
            # Double-checked under the lock: another coroutine may have
            # finished construction while we waited.
            if self._client is not None:
                return self._client
            try:
                from openai import AsyncOpenAI  # type: ignore[import-not-found]
            except ImportError as exc:  # pragma: no cover - covered via monkeypatch tests
                raise SemanticDependenciesMissingError(
                    "OpenAI dependency is missing. "
                    "Install/update basic-memory to include semantic dependencies: "
                    "pip install -U basic-memory"
                ) from exc
            api_key = self._api_key or os.getenv("OPENAI_API_KEY")
            if not api_key:
                raise SemanticDependenciesMissingError(
                    "OpenAI embedding provider requires OPENAI_API_KEY."
                )
            self._client = AsyncOpenAI(
                api_key=api_key,
                base_url=self._base_url,
                timeout=self._timeout,
            )
            return self._client

    async def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed *texts*, returning one vector per input in input order.

        Texts are chunked into ``batch_size`` groups and embedded
        concurrently (bounded by ``request_concurrency``); vectors are
        reassembled in the original order using the per-item ``index``
        field the API returns.

        Raises:
            RuntimeError: if the response contains duplicate or missing
                vector indexes, or any vector's dimensionality does not
                match the configured ``dimensions``.
        """
        if not texts:
            return []
        client = await self._get_client()
        batches = [
            texts[start : start + self.batch_size]
            for start in range(0, len(texts), self.batch_size)
        ]
        batch_vectors: list[list[list[float]] | None] = [None] * len(batches)
        semaphore = asyncio.Semaphore(self.request_concurrency)

        async def embed_batch(batch_index: int, batch: list[str]) -> None:
            # Hold the semaphore only for the API call itself: the CPU-only
            # post-processing below must not block other requests' slots.
            async with semaphore:
                response = await client.embeddings.create(
                    model=self.model_name,
                    input=batch,
                )
            # The API may return items in any order; rebuild by index and
            # verify each input position received exactly one vector.
            vectors_by_index: dict[int, list[float]] = {}
            for item in response.data:
                response_index = int(item.index)
                if response_index in vectors_by_index:
                    raise RuntimeError(
                        "OpenAI embedding response returned duplicate vector indexes."
                    )
                vectors_by_index[response_index] = [float(value) for value in item.embedding]
            ordered_vectors: list[list[float]] = []
            for index in range(len(batch)):
                vector = vectors_by_index.get(index)
                if vector is None:
                    raise RuntimeError(
                        "OpenAI embedding response is missing expected vector index."
                    )
                ordered_vectors.append(vector)
            batch_vectors[batch_index] = ordered_vectors

        await asyncio.gather(
            *(embed_batch(batch_index, batch) for batch_index, batch in enumerate(batches))
        )
        all_vectors: list[list[float]] = []
        for vectors in batch_vectors:
            if vectors is None:
                raise RuntimeError("OpenAI embedding batch did not produce vectors.")
            all_vectors.extend(vectors)
        # Validate every vector, not just the first: a malformed response
        # with mixed dimensionality must never reach the vector store.
        for vector in all_vectors:
            if len(vector) != self.dimensions:
                raise RuntimeError(
                    f"Embedding model returned {len(vector)}-dimensional vectors "
                    f"but provider was configured for {self.dimensions} dimensions."
                )
        return all_vectors

    async def embed_query(self, text: str) -> list[float]:
        """Embed a single query string and return its vector."""
        vectors = await self.embed_documents([text])
        # Defensive fallback: embed_documents([text]) always yields one
        # vector, but a zero vector keeps the contract total regardless.
        return vectors[0] if vectors else [0.0] * self.dimensions