-
Notifications
You must be signed in to change notification settings - Fork 44
Expand file tree
/
Copy pathembedder.py
More file actions
307 lines (252 loc) · 11.7 KB
/
embedder.py
File metadata and controls
307 lines (252 loc) · 11.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List, Optional, Union, cast
from langchain_community.embeddings import (
CohereEmbeddings,
FakeEmbeddings,
LocalAIEmbeddings,
OpenAIEmbeddings,
)
from airbyte_cdk.destinations.vector_db_based.config import (
AzureOpenAIEmbeddingConfigModel,
CohereEmbeddingConfigModel,
FakeEmbeddingConfigModel,
FromFieldEmbeddingConfigModel,
OpenAICompatibleEmbeddingConfigModel,
OpenAIEmbeddingConfigModel,
ProcessingConfigModel,
)
from airbyte_cdk.destinations.vector_db_based.utils import create_chunks, format_exception
from airbyte_cdk.models import AirbyteRecordMessage
from airbyte_cdk.utils.traced_exception import AirbyteTracedException, FailureType
@dataclass
class Document:
    """A piece of text to embed, paired with the Airbyte record it was taken from."""

    # Text that text-based embedders send to their embedding client.
    page_content: str
    # Originating record; its `data` and `stream` are read by FromFieldEmbedder (vector lookup and error messages).
    record: AirbyteRecordMessage
class Embedder(ABC):
    """
    Abstract interface for turning document text into embedding vectors.

    Indexers hand the text of every document to an embedder and store the returned vectors in the destination.
    A destination connector builds the embedder instance and passes it to the writer.
    The CDK ships basic embedders that every destination should support; destinations with special needs
    may implement their own.
    """

    def __init__(self) -> None:
        # The base class holds no state; concrete embedders configure their own clients.
        pass

    @abstractmethod
    def check(self) -> Optional[str]:
        """Validate the embedder; return a human-readable error message, or None if it is usable."""
        ...

    @abstractmethod
    def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
        """
        Produce one embedding vector per document, in input order.

        A None entry marks a document that could not be embedded or is configured not to be embedded.
        """
        ...

    @property
    @abstractmethod
    def embedding_dimensions(self) -> int:
        """Length of every vector this embedder returns."""
        ...
# Vector length of OpenAI's text-embedding-ada-002 model (also reused by the fake embedder).
OPEN_AI_VECTOR_SIZE: int = 1536
# OpenAI tokens-per-minute budget; used to decide how many chunks go into one embedding request.
OPEN_AI_TOKEN_LIMIT: int = 150 * 1_000  # limit of tokens per minute
class BaseOpenAIEmbedder(Embedder):
    """
    Shared implementation for embedders backed by a langchain ``OpenAIEmbeddings`` client.

    Subclasses only differ in how they configure the client; batching, health check and
    dimensions are implemented here.
    """

    def __init__(self, embeddings: OpenAIEmbeddings, chunk_size: int):
        """
        :param embeddings: preconfigured langchain OpenAI embeddings client.
        :param chunk_size: maximum number of tokens per document chunk, used to size request batches.
        """
        super().__init__()
        self.embeddings = embeddings
        self.chunk_size = chunk_size

    def check(self) -> Optional[str]:
        """Embed a probe string; return the formatted exception on failure, None on success."""
        try:
            self.embeddings.embed_query("test")
        except Exception as e:
            return format_exception(e)
        return None

    def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
        """
        Embed the text of each chunk and return the resulting embedding vectors.
        As the OpenAI API will fail if more than the per-minute limit worth of tokens is sent at once, we split the request into batches and embed each batch separately.
        It's still possible to run into the rate limit between each embed call because the available token budget hasn't recovered between the calls,
        but the built-in retry mechanism of the OpenAI client handles that.
        """
        # Each chunk can hold at most self.chunk_size tokens, so tokens-per-minute divided by maximum tokens per chunk
        # is the number of documents that can be embedded at once without exhausting the limit in a single request.
        # Clamp to at least 1: a chunk size larger than the token budget would otherwise give a batch size of 0,
        # and a batch of size 0 cannot hold any document.
        embedding_batch_size = max(1, OPEN_AI_TOKEN_LIMIT // self.chunk_size)
        batches = create_chunks(documents, batch_size=embedding_batch_size)
        embeddings: List[Optional[List[float]]] = []
        for batch in batches:
            embeddings.extend(
                self.embeddings.embed_documents([chunk.page_content for chunk in batch])
            )
        return embeddings

    @property
    def embedding_dimensions(self) -> int:
        # vector size produced by text-embedding-ada-002 model
        return OPEN_AI_VECTOR_SIZE
class OpenAIEmbedder(BaseOpenAIEmbedder):
    """Embedder for the public OpenAI embeddings API, configured from an OpenAI embedding config."""

    def __init__(self, config: OpenAIEmbeddingConfigModel, chunk_size: int):
        # Build the langchain client first, then let the shared OpenAI base class handle batching.
        openai_client = OpenAIEmbeddings(  # type: ignore [call-arg]
            openai_api_key=config.openai_key,
            max_retries=15,
            disallowed_special=(),
        )
        super().__init__(openai_client, chunk_size)
class AzureOpenAIEmbedder(BaseOpenAIEmbedder):
    """Embedder for an Azure OpenAI deployment, reusing the shared OpenAI batching logic."""

    def __init__(self, config: AzureOpenAIEmbeddingConfigModel, chunk_size: int):
        # Azure OpenAI API has — as of 20230927 — a limit of 16 documents per request, which is why the
        # client's own chunk_size is pinned to 16. The chunk_size argument of this constructor is a separate
        # value and is forwarded unchanged to the base class.
        azure_client = OpenAIEmbeddings(  # type: ignore [call-arg]
            openai_api_key=config.openai_key,
            chunk_size=16,
            max_retries=15,
            openai_api_type="azure",
            openai_api_version="2023-05-15",
            openai_api_base=config.api_base,
            deployment=config.deployment,
            disallowed_special=(),
        )
        super().__init__(azure_client, chunk_size)
# Vector length reported by the Cohere embedder, which uses the "embed-english-light-v2.0" model.
COHERE_VECTOR_SIZE: int = 1024
class CohereEmbedder(Embedder):
    """Embedder backed by Cohere's ``embed-english-light-v2.0`` model through langchain."""

    def __init__(self, config: CohereEmbeddingConfigModel):
        super().__init__()
        # Client is set internally
        self.embeddings = CohereEmbeddings(
            cohere_api_key=config.cohere_key,
            model="embed-english-light-v2.0",
            user_agent="airbyte-cdk",
        )  # type: ignore

    def check(self) -> Optional[str]:
        """Embed a probe string; return the formatted exception on failure, None on success."""
        try:
            self.embeddings.embed_query("test")
        except Exception as e:
            return format_exception(e)
        return None

    def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
        """Embed every document's text in a single client call, preserving input order."""
        return cast(
            List[Optional[List[float]]],
            self.embeddings.embed_documents([document.page_content for document in documents]),
        )

    @property
    def embedding_dimensions(self) -> int:
        # vector size produced by Cohere's embed-english-light-v2.0 model
        return COHERE_VECTOR_SIZE
class FakeEmbedder(Embedder):
    """
    Embedder backed by langchain's FakeEmbeddings, producing OpenAI-sized vectors.

    The config argument is accepted for a uniform constructor signature but is not read.
    """

    def __init__(self, config: FakeEmbeddingConfigModel):
        super().__init__()
        self.embeddings = FakeEmbeddings(size=OPEN_AI_VECTOR_SIZE)

    def check(self) -> Optional[str]:
        """Run one probe embedding; report the formatted exception if it fails."""
        try:
            self.embeddings.embed_query("test")
        except Exception as error:
            return format_exception(error)
        else:
            return None

    def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
        """Embed the text of every document in one call, keeping input order."""
        texts = [doc.page_content for doc in documents]
        vectors = self.embeddings.embed_documents(texts)
        return cast(List[Optional[List[float]]], vectors)

    @property
    def embedding_dimensions(self) -> int:
        # Mirror the OpenAI vector size so the fake output looks like real embeddings.
        return OPEN_AI_VECTOR_SIZE
# DEPLOYMENT_MODE environment value (compared case-insensitively) that marks a cloud deployment.
CLOUD_DEPLOYMENT_MODE: str = "cloud"
class OpenAICompatibleEmbedder(Embedder):
    """Embedder for any service exposing an OpenAI-compatible embeddings API at a configurable base URL."""

    def __init__(self, config: OpenAICompatibleEmbeddingConfigModel):
        super().__init__()
        self.config = config
        # Client is set internally
        # Always set an API key even if there is none defined in the config because the validator will fail otherwise. Embedding APIs that don't require an API key don't fail if one is provided, so this is not breaking usage.
        self.embeddings = LocalAIEmbeddings(
            model=config.model_name,
            openai_api_key=config.api_key or "dummy-api-key",
            openai_api_base=config.base_url,
            max_retries=15,
            disallowed_special=(),
        )  # type: ignore

    def check(self) -> Optional[str]:
        """
        Validate the configuration, then embed a probe string.

        When the DEPLOYMENT_MODE environment variable equals "cloud" (ignoring case), the base URL
        must use the https scheme. Returns an error message on failure, or None when the embedder works.
        """
        deployment_mode = os.environ.get("DEPLOYMENT_MODE", "")
        # URL schemes are case-insensitive (RFC 3986, section 3.1), so compare the prefix case-insensitively;
        # otherwise a valid "HTTPS://..." base URL would be rejected. The `and` keeps the URL from being
        # inspected at all outside cloud mode.
        if (
            deployment_mode.casefold() == CLOUD_DEPLOYMENT_MODE
            and not self.config.base_url.lower().startswith("https://")
        ):
            return "Base URL must start with https://"
        try:
            self.embeddings.embed_query("test")
        except Exception as e:
            return format_exception(e)
        return None

    def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
        """Embed every document's text in a single client call, preserving input order."""
        return cast(
            List[Optional[List[float]]],
            self.embeddings.embed_documents([document.page_content for document in documents]),
        )

    @property
    def embedding_dimensions(self) -> int:
        # vector size produced by the model
        return self.config.dimensions
class FromFieldEmbedder(Embedder):
    """
    Embedder that takes precomputed vectors from a field of each record instead of calling an embedding service.
    """

    def __init__(self, config: FromFieldEmbeddingConfigModel):
        super().__init__()
        self.config = config

    def check(self) -> Optional[str]:
        # Nothing to probe up front; every record is validated in embed_documents.
        return None

    def embed_documents(self, documents: List[Document]) -> List[Optional[List[float]]]:
        """
        From each chunk, pull the embedding from the field specified in the config.
        Check that the field exists, is a list of numbers and is the correct size. If not, raise an AirbyteTracedException explaining the problem.

        :raises AirbyteTracedException: (config_error) if a record lacks the field, the field is not a list
            of numbers (booleans are rejected), or its length differs from the configured dimensions.
        """
        field_name = self.config.field_name
        dimensions = self.config.dimensions
        embeddings: List[Optional[List[float]]] = []
        for document in documents:
            data = document.record.data
            if field_name not in data:
                raise AirbyteTracedException(
                    internal_message="Embedding vector field not found",
                    failure_type=FailureType.config_error,
                    message=f"Record {str(data)[:250]}... in stream {document.record.stream} does not contain embedding vector field {self.config.field_name}. Please check your embedding configuration, the embedding vector field has to be set correctly on every record.",
                )
            field = data[field_name]
            # bool is a subclass of int, so exclude it explicitly: JSON true/false are not vector components.
            if not isinstance(field, list) or not all(
                isinstance(x, (int, float)) and not isinstance(x, bool) for x in field
            ):
                raise AirbyteTracedException(
                    internal_message="Embedding vector field not a list of numbers",
                    failure_type=FailureType.config_error,
                    message=f"Record {str(data)[:250]}... in stream {document.record.stream} does contain embedding vector field {self.config.field_name}, but it is not a list of numbers. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.",
                )
            if len(field) != dimensions:
                raise AirbyteTracedException(
                    internal_message="Embedding vector field has wrong length",
                    failure_type=FailureType.config_error,
                    message=f"Record {str(data)[:250]}... in stream {document.record.stream} does contain embedding vector field {self.config.field_name}, but it has length {len(field)} instead of the configured {self.config.dimensions}. Please check your embedding configuration, the embedding vector field has to be a list of numbers of length {self.config.dimensions} on every record.",
                )
            embeddings.append(field)
        return embeddings

    @property
    def embedding_dimensions(self) -> int:
        return self.config.dimensions
# Registry from an embedding config's `mode` value to the Embedder subclass that implements it.
embedder_map = dict(
    openai=OpenAIEmbedder,
    cohere=CohereEmbedder,
    fake=FakeEmbedder,
    azure_openai=AzureOpenAIEmbedder,
    from_field=FromFieldEmbedder,
    openai_compatible=OpenAICompatibleEmbedder,
)
def create_from_config(
    embedding_config: Union[
        AzureOpenAIEmbeddingConfigModel,
        CohereEmbeddingConfigModel,
        FakeEmbeddingConfigModel,
        FromFieldEmbeddingConfigModel,
        OpenAIEmbeddingConfigModel,
        OpenAICompatibleEmbeddingConfigModel,
    ],
    processing_config: ProcessingConfigModel,
) -> Embedder:
    """
    Build the embedder selected by ``embedding_config.mode``.

    The OpenAI and Azure OpenAI embedders also receive ``processing_config.chunk_size``; every other
    embedder is constructed from its embedding config alone. An unknown mode raises KeyError.
    """
    mode = embedding_config.mode
    if mode not in ("azure_openai", "openai"):
        return cast(Embedder, embedder_map[mode](embedding_config))
    return cast(
        Embedder,
        embedder_map[mode](embedding_config, processing_config.chunk_size),
    )