33# SPDX-License-Identifier: AGPL-3.0-or-later
44#
55import logging
6+ from time import sleep
67from typing import Literal , TypedDict
78
89import httpx
@@ -36,7 +37,7 @@ class CreateEmbeddingResponse(TypedDict):
3637class NetworkEmbeddings (Embeddings , BaseModel ):
3738 app_config : TConfig
3839
39- def _get_embedding (self , input_ : str | list [str ]) -> list [float ] | list [list [float ]]:
40+ def _get_embedding (self , input_ : str | list [str ], try_ : int = 3 ) -> list [float ] | list [list [float ]]:
4041 emconf = self .app_config .embedding
4142
4243 lengths = [len (text ) for text in (input_ if isinstance (input_ , list ) else [input_ ])]
@@ -52,14 +53,21 @@ def _get_embedding(self, input_: str | list[str]) -> list[float] | list[list[flo
5253 json = {'input' : input_ },
5354 timeout = emconf .request_timeout ,
5455 )
55- except Exception as e :
56+ if response .status_code != 200 :
57+ raise EmbeddingException (response .text )
58+ except (
59+ EmbeddingException ,
60+ httpx .RemoteProtocolError ,
61+ httpx .ReadError ,
62+ httpx .LocalProtocolError ,
63+ httpx .PoolTimeout ,
64+ ) as e :
65+ if try_ > 0 :
66+ logger .debug ('Retrying embedding request in 5 secs' , extra = {'try' : try_ })
67+ sleep (5 )
68+ return self ._get_embedding (input_ , try_ - 1 )
5669 raise EmbeddingException ('Error: request to get embeddings failed' ) from e
5770
58- try :
59- response .raise_for_status ()
60- except Exception as e :
61- raise EmbeddingException (f'Error: failed to get embeddings: { response .text } ' ) from e
62-
6371 # converts TypedDict to a pydantic model
6472 resp = CreateEmbeddingResponse (** response .json ())
6573 if isinstance (input_ , str ):
0 commit comments