33# SPDX-License-Identifier: AGPL-3.0-or-later
44#
55import logging
6- from collections .abc import Generator
76from time import sleep
87from typing import Literal , TypedDict
98
10- import httpx
9+ import niquests
1110from langchain_core .embeddings import Embeddings
1211from pydantic import BaseModel
1312
@@ -42,15 +41,6 @@ class CreateEmbeddingResponse(TypedDict):
4241 usage : EmbeddingUsage
4342
4443
45- class ApiKeyAuth (httpx .Auth ):
46- def __init__ (self , apikey : str | bytes ) -> None :
47- self ._apikey = apikey
48-
49- def auth_flow (self , request : httpx .Request ) -> Generator [httpx .Request , httpx .Response , None ]:
50- request .headers ['Authorization' ] = f'Bearer { self ._apikey } '
51- yield request
52-
53-
5444class NetworkEmbeddings (Embeddings , BaseModel ):
5545 app_config : TConfig
5646
@@ -66,43 +56,46 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float]
6656 try :
6757 match emconf .auth :
6858 case None :
69- auth = httpx . USE_CLIENT_DEFAULT
59+ auth = None
7060 case TEmbeddingAuthApiKey (apikey = apikey ):
71- auth = ApiKeyAuth ( apikey = apikey )
61+ auth = niquests . auth . BearerTokenAuth ( token = apikey ) # pyright: ignore[reportAttributeAccessIssue]
7262 case TEmbeddingAuthBasic (username = username , password = password ):
73- auth = httpx . BasicAuth (username = username , password = password )
63+ auth = niquests . auth . HTTPBasicAuth (username = username , password = password ) # pyright: ignore[reportAttributeAccessIssue]
7464
7565 data = {'input' : input_ }
7666 if emconf .model_name :
7767 data ['model' ] = emconf .model_name
7868
79- with httpx .Client (verify = self .app_config .httpx_verify_ssl ) as client :
80- response = client .post (
81- f'{ emconf .base_url .removesuffix ("/" )} /embeddings' ,
82- json = data ,
83- timeout = emconf .request_timeout ,
84- auth = auth ,
85- )
86- if response .status_code // 100 == 4 :
87- raise FatalEmbeddingException (response .text )
88- if response .status_code // 100 != 2 :
89- raise EmbeddingException (response .text )
69+ response = niquests .post (
70+ f'{ emconf .base_url .removesuffix ("/" )} /embeddings' ,
71+ json = data ,
72+ timeout = emconf .request_timeout ,
73+ auth = auth ,
74+ verify = self .app_config .verify_ssl ,
75+ )
76+ if response .status_code is None :
77+ raise EmbeddingException ('Error: no response from embedding service' )
78+ if response .status_code // 100 == 4 :
79+ raise FatalEmbeddingException (response .text )
80+ if response .status_code // 100 != 2 :
81+ raise EmbeddingException (response .text )
9082 except FatalEmbeddingException as e :
9183 logger .error ('Fatal error while getting embeddings: %s' , str (e ), exc_info = e )
9284 raise e
93- except (
94- EmbeddingException ,
95- httpx .RemoteProtocolError ,
96- httpx .ReadError ,
97- httpx .LocalProtocolError ,
98- httpx .PoolTimeout ,
99- ) as e :
85+ except EmbeddingException as e :
10086 if try_ > 0 :
10187 logger .debug ('Retrying embedding request in 5 secs' , extra = {'try' : try_ })
10288 sleep (5 )
10389 return self ._get_embedding (input_ , try_ - 1 )
10490 raise RetryableEmbeddingException ('Error: request to get embeddings failed' ) from e
105- except httpx .ConnectError as e :
91+ except niquests .exceptions .Timeout as e :
92+ if try_ > 0 :
93+ logger .debug ('Timeout while getting embeddings, retrying in 5 secs' , extra = {'try' : try_ })
94+ sleep (5 )
95+ return self ._get_embedding (input_ , try_ - 1 )
96+ logger .error ('Timeout while getting embeddings' , exc_info = e )
97+ raise EmbeddingException ('Error: timeout while getting embeddings' ) from e
98+ except niquests .exceptions .ConnectionError as e :
10699 if self .app_config .embedding .workers > 0 :
107100 logger .error (
108101 'Error connecting to the embedding server, check if it is running and the logs' ,
@@ -111,13 +104,6 @@ def _get_embedding(self, input_: str | list[str], try_: int = 3) -> list[float]
111104 raise EmbeddingException ('Error: failed to connect to the embedding service' ) from e
112105 logger .error ('Error connecting to the remote embedding service' , exc_info = e )
113106 raise EmbeddingException ('Error: failed to connect to the remote embedding service' ) from e
114- except httpx .NetworkError as e :
115- if try_ > 0 :
116- logger .debug ('Network error while getting embeddings, retrying in 5 secs' , extra = {'try' : try_ })
117- sleep (5 )
118- return self ._get_embedding (input_ , try_ - 1 )
119- logger .error ('Network error while getting embeddings' , exc_info = e )
120- raise EmbeddingException ('Error: network error while getting embeddings' ) from e
121107 except Exception as e :
122108 logger .error ('Unexpected error while getting embeddings' , exc_info = e )
123109 raise EmbeddingException ('Error: unexpected error while getting embeddings' ) from e
0 commit comments