44from dataclasses import replace
55from typing import Any
66
7- import requests
7+ import httpx
88from haystack import Document , component , default_from_dict , default_to_dict
99from haystack .utils import Secret , deserialize_secrets_inplace
1010from tqdm import tqdm
@@ -89,14 +89,11 @@ def __init__(
8989 self .progress_bar = progress_bar
9090 self .meta_fields_to_embed = meta_fields_to_embed or []
9191 self .embedding_separator = embedding_separator
92- self ._session = requests .Session ()
93- self ._session .headers .update (
94- {
95- "Authorization" : f"Bearer { resolved_api_key } " ,
96- "Accept-Encoding" : "identity" ,
97- "Content-type" : "application/json" ,
98- }
99- )
92+ self ._headers = {
93+ "Authorization" : f"Bearer { resolved_api_key } " ,
94+ "Accept-Encoding" : "identity" ,
95+ "Content-type" : "application/json" ,
96+ }
10097 self .task = task
10198 self .dimensions = dimensions
10299 self .late_chunking = late_chunking
@@ -164,40 +161,96 @@ def _prepare_texts_to_embed(self, documents: list[Document]) -> list[str]:
164161 texts_to_embed .append (text_to_embed )
165162 return texts_to_embed
166163
164+ def _validate_input (self , documents : list [Document ]) -> None :
165+ if not isinstance (documents , list ) or (documents and not isinstance (documents [0 ], Document )):
166+ msg = (
167+ "JinaDocumentEmbedder expects a list of Documents as input."
168+ "In case you want to embed a string, please use the JinaTextEmbedder."
169+ )
170+ raise TypeError (msg )
171+
172+ def _prepare_parameters (self ) -> dict [str , Any ]:
173+ parameters : dict [str , Any ] = {}
174+ if self .task is not None :
175+ parameters ["task" ] = self .task
176+ if self .dimensions is not None :
177+ parameters ["dimensions" ] = self .dimensions
178+ if self .late_chunking is not None :
179+ parameters ["late_chunking" ] = self .late_chunking
180+ return parameters
181+
182+ @staticmethod
183+ def _process_batch_response (
184+ response : dict [str , Any ], all_embeddings : list [list [float ]], metadata : dict [str , Any ]
185+ ) -> None :
186+ if "data" not in response :
187+ raise RuntimeError (response ["detail" ])
188+
189+ # Sort resulting embeddings by index
190+ sorted_embeddings = sorted (response ["data" ], key = lambda e : e ["index" ])
191+ embeddings = [result ["embedding" ] for result in sorted_embeddings ]
192+ all_embeddings .extend (embeddings )
193+ if "model" not in metadata :
194+ metadata ["model" ] = response ["model" ]
195+ if "usage" not in metadata :
196+ metadata ["usage" ] = dict (response ["usage" ].items ())
197+ else :
198+ metadata ["usage" ]["prompt_tokens" ] += response ["usage" ]["prompt_tokens" ]
199+ metadata ["usage" ]["total_tokens" ] += response ["usage" ]["total_tokens" ]
200+
167201 def _embed_batch (
168202 self , texts_to_embed : list [str ], batch_size : int , parameters : dict | None = None
169203 ) -> tuple [list [list [float ]], dict [str , Any ]]:
170- """
171- Embed a list of texts in batches.
172- """
204+ """Embed a list of texts in batches."""
205+ all_embeddings : list [list [float ]] = []
206+ metadata : dict [str , Any ] = {}
207+ with httpx .Client () as client :
208+ for i in tqdm (
209+ range (0 , len (texts_to_embed ), batch_size ),
210+ disable = not self .progress_bar ,
211+ desc = "Calculating embeddings" ,
212+ ):
213+ batch = texts_to_embed [i : i + batch_size ]
214+ response = client .post (
215+ self .base_url ,
216+ json = {"input" : batch , "model" : self .model_name , ** (parameters or {})},
217+ headers = self ._headers ,
218+ ).json ()
219+ self ._process_batch_response (response , all_embeddings , metadata )
173220
174- all_embeddings = []
175- metadata = {}
176- for i in tqdm (
177- range (0 , len (texts_to_embed ), batch_size ), disable = not self .progress_bar , desc = "Calculating embeddings"
178- ):
179- batch = texts_to_embed [i : i + batch_size ]
180- response = self ._session .post (
181- self .base_url ,
182- json = {"input" : batch , "model" : self .model_name , ** (parameters or {})},
183- ).json ()
184- if "data" not in response :
185- raise RuntimeError (response ["detail" ])
186-
187- # Sort resulting embeddings by index
188- sorted_embeddings = sorted (response ["data" ], key = lambda e : e ["index" ])
189- embeddings = [result ["embedding" ] for result in sorted_embeddings ]
190- all_embeddings .extend (embeddings )
191- if "model" not in metadata :
192- metadata ["model" ] = response ["model" ]
193- if "usage" not in metadata :
194- metadata ["usage" ] = dict (response ["usage" ].items ())
195- else :
196- metadata ["usage" ]["prompt_tokens" ] += response ["usage" ]["prompt_tokens" ]
197- metadata ["usage" ]["total_tokens" ] += response ["usage" ]["total_tokens" ]
221+ return all_embeddings , metadata
222+
223+ async def _embed_batch_async (
224+ self , texts_to_embed : list [str ], batch_size : int , parameters : dict | None = None
225+ ) -> tuple [list [list [float ]], dict [str , Any ]]:
226+ """Asynchronously embed a list of texts in batches."""
227+ all_embeddings : list [list [float ]] = []
228+ metadata : dict [str , Any ] = {}
229+ async with httpx .AsyncClient () as client :
230+ for i in tqdm (
231+ range (0 , len (texts_to_embed ), batch_size ),
232+ disable = not self .progress_bar ,
233+ desc = "Calculating embeddings" ,
234+ ):
235+ batch = texts_to_embed [i : i + batch_size ]
236+ response = await client .post (
237+ self .base_url ,
238+ json = {"input" : batch , "model" : self .model_name , ** (parameters or {})},
239+ headers = self ._headers ,
240+ )
241+ self ._process_batch_response (response .json (), all_embeddings , metadata )
198242
199243 return all_embeddings , metadata
200244
245+ @staticmethod
246+ def _build_result (
247+ documents : list [Document ], embeddings : list [list [float ]], metadata : dict [str , Any ]
248+ ) -> dict [str , Any ]:
249+ new_documents : list [Document ] = []
250+ for doc , emb in zip (documents , embeddings , strict = True ):
251+ new_documents .append (replace (doc , embedding = emb ))
252+ return {"documents" : new_documents , "meta" : metadata }
253+
201254 @component .output_types (documents = list [Document ], meta = dict [str , Any ])
202255 def run (self , documents : list [Document ]) -> dict [str , Any ]:
203256 """
@@ -209,27 +262,36 @@ def run(self, documents: list[Document]) -> dict[str, Any]:
209262 - `meta`: A dictionary with metadata including the model name and usage statistics.
210263 :raises TypeError: If the input is not a list of Documents.
211264 """
212- if not isinstance (documents , list ) or (documents and not isinstance (documents [0 ], Document )):
213- msg = (
214- "JinaDocumentEmbedder expects a list of Documents as input."
215- "In case you want to embed a string, please use the JinaTextEmbedder."
216- )
217- raise TypeError (msg )
265+ self ._validate_input (documents )
218266
219267 texts_to_embed = self ._prepare_texts_to_embed (documents = documents )
220- parameters : dict [str , Any ] = {}
221- if self .task is not None :
222- parameters ["task" ] = self .task
223- if self .dimensions is not None :
224- parameters ["dimensions" ] = self .dimensions
225- if self .late_chunking is not None :
226- parameters ["late_chunking" ] = self .late_chunking
268+ parameters = self ._prepare_parameters ()
227269 embeddings , metadata = self ._embed_batch (
228270 texts_to_embed = texts_to_embed , batch_size = self .batch_size , parameters = parameters
229271 )
230272
231- new_documents : list [Document ] = []
232- for doc , emb in zip (documents , embeddings , strict = True ):
233- new_documents .append (replace (doc , embedding = emb ))
273+ return self ._build_result (documents , embeddings , metadata )
234274
235- return {"documents" : new_documents , "meta" : metadata }
275+ @component .output_types (documents = list [Document ], meta = dict [str , Any ])
276+ async def run_async (self , documents : list [Document ]) -> dict [str , Any ]:
277+ """
278+ Asynchronously compute the embeddings for a list of Documents.
279+
280+ This is the asynchronous version of the `run` method. It has the same parameters and return values
281+ but can be used with `await` in async code.
282+
283+ :param documents: A list of Documents to embed.
284+ :returns: A dictionary with following keys:
285+ - `documents`: List of Documents, each with an `embedding` field containing the computed embedding.
286+ - `meta`: A dictionary with metadata including the model name and usage statistics.
287+ :raises TypeError: If the input is not a list of Documents.
288+ """
289+ self ._validate_input (documents )
290+
291+ texts_to_embed = self ._prepare_texts_to_embed (documents = documents )
292+ parameters = self ._prepare_parameters ()
293+ embeddings , metadata = await self ._embed_batch_async (
294+ texts_to_embed = texts_to_embed , batch_size = self .batch_size , parameters = parameters
295+ )
296+
297+ return self ._build_result (documents , embeddings , metadata )
0 commit comments