-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Expand file tree
/
Copy pathopenai_document_embedder.py
More file actions
353 lines (297 loc) · 13.6 KB
/
openai_document_embedder.py
File metadata and controls
353 lines (297 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
#
# SPDX-License-Identifier: Apache-2.0
import os
from dataclasses import replace
from typing import Any
from more_itertools import batched
from openai import APIError, AsyncOpenAI, OpenAI
from tqdm import tqdm
from tqdm.asyncio import tqdm as async_tqdm
from haystack import Document, component, default_from_dict, default_to_dict, logging
from haystack.utils import Secret, deserialize_secrets_inplace, get_progress_bar_setting
from haystack.utils.http_client import init_http_client
# Module-level logger, obtained through the `logging` helper imported from haystack; used below to report
# embedding batches whose API request failed.
logger = logging.getLogger(__name__)
@component
class OpenAIDocumentEmbedder:
    """
    Computes document embeddings using OpenAI models.

    ### Usage example

    ```python
    from haystack import Document
    from haystack.components.embedders import OpenAIDocumentEmbedder

    doc = Document(content="I love pizza!")

    document_embedder = OpenAIDocumentEmbedder()

    result = document_embedder.run([doc])
    print(result['documents'][0].embedding)

    # [0.017020374536514282, -0.023255806416273117, ...]
    ```
    """

    def __init__(  # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-positional-arguments
        self,
        api_key: Secret = Secret.from_env_var("OPENAI_API_KEY"),
        model: str = "text-embedding-ada-002",
        dimensions: int | None = None,
        api_base_url: str | None = None,
        organization: str | None = None,
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        meta_fields_to_embed: list[str] | None = None,
        embedding_separator: str = "\n",
        timeout: float | None = None,
        max_retries: int | None = None,
        http_client_kwargs: dict[str, Any] | None = None,
        *,
        raise_on_failure: bool = False,
    ):
        """
        Creates an OpenAIDocumentEmbedder component.

        Before initializing the component, you can set the 'OPENAI_TIMEOUT' and 'OPENAI_MAX_RETRIES'
        environment variables to override the `timeout` and `max_retries` parameters respectively
        in the OpenAI client.

        :param api_key:
            The OpenAI API key.
            You can set it with an environment variable `OPENAI_API_KEY`, or pass with this parameter
            during initialization.
        :param model:
            The name of the model to use for calculating embeddings.
            The default model is `text-embedding-ada-002`.
        :param dimensions:
            The number of dimensions of the resulting embeddings. Only `text-embedding-3` and
            later models support this parameter.
        :param api_base_url:
            Overrides the default base URL for all HTTP requests.
        :param organization:
            Your OpenAI organization ID. See OpenAI's
            [Setting Up Your Organization](https://platform.openai.com/docs/guides/production-best-practices/setting-up-your-organization)
            for more information.
        :param prefix:
            A string to add at the beginning of each text.
        :param suffix:
            A string to add at the end of each text.
        :param batch_size:
            Number of documents to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar when running.
        :param meta_fields_to_embed:
            List of metadata fields to embed along with the document text.
        :param embedding_separator:
            Separator used to concatenate the metadata fields to the document text.
        :param timeout:
            Timeout for OpenAI client calls. If not set, it defaults to either the
            `OPENAI_TIMEOUT` environment variable, or 30 seconds.
        :param max_retries:
            Maximum number of retries to contact OpenAI after an internal error.
            If not set, it defaults to either the `OPENAI_MAX_RETRIES` environment variable, or 5 retries.
        :param http_client_kwargs:
            A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
            For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
        :param raise_on_failure:
            Whether to raise an exception if the embedding request fails. If `False`, the component will log the error
            and continue processing the remaining documents. If `True`, it will raise an exception on failure.
        """
        self.api_key = api_key
        self.model = model
        self.dimensions = dimensions
        self.api_base_url = api_base_url
        self.organization = organization
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        # Keep the raw argument for serialization; the effective setting may be overridden globally.
        self._progress_bar_param = progress_bar
        self.progress_bar = get_progress_bar_setting(progress_bar)
        self.meta_fields_to_embed = meta_fields_to_embed or []
        self.embedding_separator = embedding_separator
        # Store the user-supplied values (possibly None) so `to_dict` round-trips them unchanged;
        # the env-var / hard-coded fallbacks below only feed the client construction.
        self.timeout = timeout
        self.max_retries = max_retries
        self.http_client_kwargs = http_client_kwargs
        self.raise_on_failure = raise_on_failure

        if timeout is None:
            timeout = float(os.environ.get("OPENAI_TIMEOUT", "30.0"))
        if max_retries is None:
            max_retries = int(os.environ.get("OPENAI_MAX_RETRIES", "5"))

        client_kwargs: dict[str, Any] = {
            "api_key": api_key.resolve_value(),
            "organization": organization,
            "base_url": api_base_url,
            "timeout": timeout,
            "max_retries": max_retries,
        }

        self.client = OpenAI(http_client=init_http_client(self.http_client_kwargs, async_client=False), **client_kwargs)
        self.async_client = AsyncOpenAI(
            http_client=init_http_client(self.http_client_kwargs, async_client=True), **client_kwargs
        )

    def _get_telemetry_data(self) -> dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.
        """
        return {"model": self.model}

    def to_dict(self) -> dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            api_key=self.api_key.to_dict(),
            model=self.model,
            dimensions=self.dimensions,
            api_base_url=self.api_base_url,
            organization=self.organization,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self._progress_bar_param,
            meta_fields_to_embed=self.meta_fields_to_embed,
            embedding_separator=self.embedding_separator,
            timeout=self.timeout,
            max_retries=self.max_retries,
            http_client_kwargs=self.http_client_kwargs,
            raise_on_failure=self.raise_on_failure,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "OpenAIDocumentEmbedder":
        """
        Deserializes the component from a dictionary.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        deserialize_secrets_inplace(data["init_parameters"], keys=["api_key"])
        return default_from_dict(cls, data)

    def _prepare_texts_to_embed(self, documents: list[Document]) -> dict[str, str]:
        """
        Prepare the texts to embed by concatenating the Document text with the metadata fields to embed.

        :param documents: Documents to prepare.
        :returns: Mapping of document ID to `prefix + <meta values and content joined by separator> + suffix`.
            Missing or `None` meta values are skipped; `None` content is treated as an empty string.
        """
        texts_to_embed = {}
        for doc in documents:
            meta_values_to_embed = [
                str(doc.meta[key]) for key in self.meta_fields_to_embed if key in doc.meta and doc.meta[key] is not None
            ]
            texts_to_embed[doc.id] = (
                self.prefix + self.embedding_separator.join(meta_values_to_embed + [doc.content or ""]) + self.suffix
            )
        return texts_to_embed

    def _build_request_args(self, batch: tuple[tuple[str, str], ...]) -> dict[str, Any]:
        """
        Build the keyword arguments for one `embeddings.create` call.

        Shared by the sync and async paths so both send identical requests.

        :param batch: `(document ID, text)` pairs to embed in this request.
        :returns: Keyword arguments for the OpenAI embeddings endpoint.
        """
        # Request plain floats explicitly so the response format does not depend on client defaults.
        args: dict[str, Any] = {"model": self.model, "input": [text for _, text in batch], "encoding_format": "float"}
        if self.dimensions is not None:
            args["dimensions"] = self.dimensions
        return args

    @staticmethod
    def _log_failed_batch(batch: tuple[tuple[str, str], ...], exc: APIError) -> None:
        """
        Log a failed embedding request, listing the IDs of the documents in the batch.

        Must be called from inside the `except` clause so the traceback is attached to the log record.
        """
        ids = ", ".join(doc_id for doc_id, _ in batch)
        msg = "Failed embedding of documents {ids} caused by {exc}"
        logger.exception(msg, ids=ids, exc=exc)

    @staticmethod
    def _map_embeddings(batch: tuple[tuple[str, str], ...], response: Any) -> dict[str, list[float]]:
        """
        Pair each document ID in `batch` with the embedding returned at the same position in `response.data`.
        """
        return dict(zip((doc_id for doc_id, _ in batch), (item.embedding for item in response.data)))

    @staticmethod
    def _update_meta(meta: dict[str, Any], response: Any) -> None:
        """
        Merge the model name and token usage of one embeddings response into `meta` (in place).

        The model name is taken from the first successful response; token counts are summed across responses.
        """
        if "model" not in meta:
            meta["model"] = response.model
        if "usage" not in meta:
            meta["usage"] = dict(response.usage)
        else:
            meta["usage"]["prompt_tokens"] += response.usage.prompt_tokens
            meta["usage"]["total_tokens"] += response.usage.total_tokens

    def _embed_batch(
        self, texts_to_embed: dict[str, str], batch_size: int
    ) -> tuple[dict[str, list[float]], dict[str, Any]]:
        """
        Embed a list of texts in batches.

        :param texts_to_embed: Mapping of document ID to the text to embed.
        :param batch_size: Maximum number of texts sent per API request.
        :returns: A tuple of (document ID -> embedding, aggregated model/usage metadata).
            When `raise_on_failure` is `False`, IDs from failed batches are absent from the mapping.
        :raises APIError: If a request fails and `raise_on_failure` is `True`.
        """
        doc_ids_to_embeddings: dict[str, list[float]] = {}
        meta: dict[str, Any] = {}

        for batch in tqdm(
            batched(texts_to_embed.items(), batch_size), disable=not self.progress_bar, desc="Calculating embeddings"
        ):
            args = self._build_request_args(batch)
            try:
                response = self.client.embeddings.create(**args)
            except APIError as exc:
                self._log_failed_batch(batch, exc)
                if self.raise_on_failure:
                    raise
                continue

            doc_ids_to_embeddings.update(self._map_embeddings(batch, response))
            self._update_meta(meta, response)

        return doc_ids_to_embeddings, meta

    async def _embed_batch_async(
        self, texts_to_embed: dict[str, str], batch_size: int
    ) -> tuple[dict[str, list[float]], dict[str, Any]]:
        """
        Embed a list of texts in batches asynchronously.

        Batches are sent one after another (not concurrently), mirroring `_embed_batch`.

        :param texts_to_embed: Mapping of document ID to the text to embed.
        :param batch_size: Maximum number of texts sent per API request.
        :returns: A tuple of (document ID -> embedding, aggregated model/usage metadata).
            When `raise_on_failure` is `False`, IDs from failed batches are absent from the mapping.
        :raises APIError: If a request fails and `raise_on_failure` is `True`.
        """
        doc_ids_to_embeddings: dict[str, list[float]] = {}
        meta: dict[str, Any] = {}

        batches = list(batched(texts_to_embed.items(), batch_size))
        if self.progress_bar:
            batches = async_tqdm(batches, desc="Calculating embeddings")

        for batch in batches:
            args = self._build_request_args(batch)
            try:
                response = await self.async_client.embeddings.create(**args)
            except APIError as exc:
                self._log_failed_batch(batch, exc)
                if self.raise_on_failure:
                    raise
                continue

            doc_ids_to_embeddings.update(self._map_embeddings(batch, response))
            self._update_meta(meta, response)

        return doc_ids_to_embeddings, meta

    @staticmethod
    def _validate_documents(documents: Any) -> None:
        """
        Raise `TypeError` unless `documents` is a list whose first element (if any) is a `Document`.
        """
        if not isinstance(documents, list) or (documents and not isinstance(documents[0], Document)):
            raise TypeError(
                "OpenAIDocumentEmbedder expects a list of Documents as input. "
                "In case you want to embed a string, please use the OpenAITextEmbedder."
            )

    @staticmethod
    def _attach_embeddings(
        documents: list[Document], doc_ids_to_embeddings: dict[str, list[float]]
    ) -> list[Document]:
        """
        Return copies of `documents`, with embeddings set for those whose ID appears in `doc_ids_to_embeddings`.

        Documents without a computed embedding are still copied so callers never receive the input objects.
        """
        return [
            replace(doc, embedding=doc_ids_to_embeddings[doc.id]) if doc.id in doc_ids_to_embeddings else replace(doc)
            for doc in documents
        ]

    @component.output_types(documents=list[Document], meta=dict[str, Any])
    def run(self, documents: list[Document]):
        """
        Embeds a list of documents.

        :param documents:
            A list of documents to embed.
        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings. Documents from batches that failed
              (with `raise_on_failure=False`) are returned without embeddings.
            - `meta`: Information about the usage of the model.
        :raises TypeError:
            If `documents` is not a list of `Document` objects.
        """
        self._validate_documents(documents)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)
        doc_ids_to_embeddings, meta = self._embed_batch(texts_to_embed=texts_to_embed, batch_size=self.batch_size)

        return {"documents": self._attach_embeddings(documents, doc_ids_to_embeddings), "meta": meta}

    @component.output_types(documents=list[Document], meta=dict[str, Any])
    async def run_async(self, documents: list[Document]):
        """
        Embeds a list of documents asynchronously.

        :param documents:
            A list of documents to embed.
        :returns:
            A dictionary with the following keys:
            - `documents`: A list of documents with embeddings. Documents from batches that failed
              (with `raise_on_failure=False`) are returned without embeddings.
            - `meta`: Information about the usage of the model.
        :raises TypeError:
            If `documents` is not a list of `Document` objects.
        """
        self._validate_documents(documents)

        texts_to_embed = self._prepare_texts_to_embed(documents=documents)
        doc_ids_to_embeddings, meta = await self._embed_batch_async(
            texts_to_embed=texts_to_embed, batch_size=self.batch_size
        )

        return {"documents": self._attach_embeddings(documents, doc_ids_to_embeddings), "meta": meta}