@@ -39,7 +39,7 @@ class OpenAIDocumentEmbedder:
3939 ```
4040 """
4141
42- def __init__ ( # pylint: disable=too-many-positional-arguments
42+ def __init__ ( # noqa: PLR0913 (too-many-arguments) # pylint: disable=too-many-positional-arguments
4343 self ,
4444 api_key : Secret = Secret .from_env_var ("OPENAI_API_KEY" ),
4545 model : str = "text-embedding-ada-002" ,
@@ -55,6 +55,8 @@ def __init__( # pylint: disable=too-many-positional-arguments
5555 timeout : Optional [float ] = None ,
5656 max_retries : Optional [int ] = None ,
5757 http_client_kwargs : Optional [Dict [str , Any ]] = None ,
58+ * ,
59+ raise_on_failure : bool = False ,
5860 ):
5961 """
6062 Creates an OpenAIDocumentEmbedder component.
@@ -100,6 +102,9 @@ def __init__( # pylint: disable=too-many-positional-arguments
100102 :param http_client_kwargs:
100103 A dictionary of keyword arguments to configure a custom `httpx.Client` or `httpx.AsyncClient`.
102104 For more information, see the [HTTPX documentation](https://www.python-httpx.org/api/#client).
105+ :param raise_on_failure:
106+ Whether to raise an exception if the embedding request fails. If `False`, the component will log the error
107+ and continue processing the remaining documents. If `True`, it will raise an exception on failure.
103108 """
104109 self .api_key = api_key
105110 self .model = model
@@ -115,6 +120,7 @@ def __init__( # pylint: disable=too-many-positional-arguments
115120 self .timeout = timeout
116121 self .max_retries = max_retries
117122 self .http_client_kwargs = http_client_kwargs
123+ self .raise_on_failure = raise_on_failure
118124
119125 if timeout is None :
120126 timeout = float (os .environ .get ("OPENAI_TIMEOUT" , "30.0" ))
@@ -163,6 +169,7 @@ def to_dict(self) -> Dict[str, Any]:
163169 timeout = self .timeout ,
164170 max_retries = self .max_retries ,
165171 http_client_kwargs = self .http_client_kwargs ,
172+ raise_on_failure = self .raise_on_failure ,
166173 )
167174
168175 @classmethod
@@ -194,12 +201,14 @@ def _prepare_texts_to_embed(self, documents: List[Document]) -> Dict[str, str]:
194201
195202 return texts_to_embed
196203
197- def _embed_batch (self , texts_to_embed : Dict [str , str ], batch_size : int ) -> Tuple [List [List [float ]], Dict [str , Any ]]:
204+ def _embed_batch (
205+ self , texts_to_embed : Dict [str , str ], batch_size : int
206+ ) -> Tuple [Dict [str , List [float ]], Dict [str , Any ]]:
198207 """
199208 Embed a list of texts in batches.
200209 """
201210
202- all_embeddings = []
211+ doc_ids_to_embeddings : Dict [ str , List [ float ]] = {}
203212 meta : Dict [str , Any ] = {}
204213 for batch in tqdm (
205214 batched (texts_to_embed .items (), batch_size ), disable = not self .progress_bar , desc = "Calculating embeddings"
@@ -215,10 +224,12 @@ def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple
215224 ids = ", " .join (b [0 ] for b in batch )
216225 msg = "Failed embedding of documents {ids} caused by {exc}"
217226 logger .exception (msg , ids = ids , exc = exc )
227+ if self .raise_on_failure :
228+ raise exc
218229 continue
219230
220231 embeddings = [el .embedding for el in response .data ]
221- all_embeddings . extend ( embeddings )
232+ doc_ids_to_embeddings . update ( dict ( zip (( b [ 0 ] for b in batch ), embeddings )) )
222233
223234 if "model" not in meta :
224235 meta ["model" ] = response .model
@@ -228,16 +239,16 @@ def _embed_batch(self, texts_to_embed: Dict[str, str], batch_size: int) -> Tuple
228239 meta ["usage" ]["prompt_tokens" ] += response .usage .prompt_tokens
229240 meta ["usage" ]["total_tokens" ] += response .usage .total_tokens
230241
231- return all_embeddings , meta
242+ return doc_ids_to_embeddings , meta
232243
233244 async def _embed_batch_async (
234245 self , texts_to_embed : Dict [str , str ], batch_size : int
235- ) -> Tuple [List [ List [float ]], Dict [str , Any ]]:
246+ ) -> Tuple [Dict [ str , List [float ]], Dict [str , Any ]]:
236247 """
237248 Embed a list of texts in batches asynchronously.
238249 """
239250
240- all_embeddings = []
251+ doc_ids_to_embeddings : Dict [ str , List [ float ]] = {}
241252 meta : Dict [str , Any ] = {}
242253
243254 batches = list (batched (texts_to_embed .items (), batch_size ))
@@ -256,10 +267,12 @@ async def _embed_batch_async(
256267 ids = ", " .join (b [0 ] for b in batch )
257268 msg = "Failed embedding of documents {ids} caused by {exc}"
258269 logger .exception (msg , ids = ids , exc = exc )
270+ if self .raise_on_failure :
271+ raise exc
259272 continue
260273
261274 embeddings = [el .embedding for el in response .data ]
262- all_embeddings . extend ( embeddings )
275+ doc_ids_to_embeddings . update ( dict ( zip (( b [ 0 ] for b in batch ), embeddings )) )
263276
264277 if "model" not in meta :
265278 meta ["model" ] = response .model
@@ -269,7 +282,7 @@ async def _embed_batch_async(
269282 meta ["usage" ]["prompt_tokens" ] += response .usage .prompt_tokens
270283 meta ["usage" ]["total_tokens" ] += response .usage .total_tokens
271284
272- return all_embeddings , meta
285+ return doc_ids_to_embeddings , meta
273286
274287 @component .output_types (documents = List [Document ], meta = Dict [str , Any ])
275288 def run (self , documents : List [Document ]):
@@ -292,12 +305,13 @@ def run(self, documents: List[Document]):
292305
293306 texts_to_embed = self ._prepare_texts_to_embed (documents = documents )
294307
295- embeddings , meta = self ._embed_batch (texts_to_embed = texts_to_embed , batch_size = self .batch_size )
308+ doc_ids_to_embeddings , meta = self ._embed_batch (texts_to_embed = texts_to_embed , batch_size = self .batch_size )
296309
297- for doc , emb in zip (documents , embeddings ):
298- doc .embedding = emb
310+ doc_id_to_document = {doc .id : doc for doc in documents }
311+ for doc_id , emb in doc_ids_to_embeddings .items ():
312+ doc_id_to_document [doc_id ].embedding = emb
299313
300- return {"documents" : documents , "meta" : meta }
314+ return {"documents" : list ( doc_id_to_document . values ()) , "meta" : meta }
301315
302316 @component .output_types (documents = List [Document ], meta = Dict [str , Any ])
303317 async def run_async (self , documents : List [Document ]):
@@ -320,9 +334,12 @@ async def run_async(self, documents: List[Document]):
320334
321335 texts_to_embed = self ._prepare_texts_to_embed (documents = documents )
322336
323- embeddings , meta = await self ._embed_batch_async (texts_to_embed = texts_to_embed , batch_size = self .batch_size )
337+ doc_ids_to_embeddings , meta = await self ._embed_batch_async (
338+ texts_to_embed = texts_to_embed , batch_size = self .batch_size
339+ )
324340
325- for doc , emb in zip (documents , embeddings ):
326- doc .embedding = emb
341+ doc_id_to_document = {doc .id : doc for doc in documents }
342+ for doc_id , emb in doc_ids_to_embeddings .items ():
343+ doc_id_to_document [doc_id ].embedding = emb
327344
328- return {"documents" : documents , "meta" : meta }
345+ return {"documents" : list ( doc_id_to_document . values ()) , "meta" : meta }
0 commit comments