@@ -26,18 +26,21 @@ def __init__(self, provider_config: dict, provider_settings: dict) -> None:
2626 )
2727 self .truncate = provider_config .get ("nvidia_rerank_truncate" , "" )
2828
29- headers = {
30- "Authorization" : f"Bearer { self .api_key } " ,
31- "Content-Type" : "application/json" ,
32- "Accept" : "application/json" ,
33- }
34-
35- self .client = aiohttp .ClientSession (
36- headers = headers , timeout = aiohttp .ClientTimeout (total = self .timeout )
37- )
38-
29+ self .client = None
3930 self .set_model (self .model )
4031
32+ async def _get_client (self ):
33+ if self .client is None or self .client .closed :
34+ headers = {
35+ "Authorization" : f"Bearer { self .api_key } " ,
36+ "Content-Type" : "application/json" ,
37+ "Accept" : "application/json" ,
38+ }
39+ self .client = aiohttp .ClientSession (
40+ headers = headers , timeout = aiohttp .ClientTimeout (total = self .timeout )
41+ )
42+ return self .client
43+
4144 def _get_endpoint (self ) -> str :
4245 """
4346 构建完整API URL。
@@ -111,7 +114,8 @@ async def rerank(
111114 documents : list [str ],
112115 top_n : int | None = None ,
113116 ) -> list [RerankResult ]:
114- if not self .client or self .client .closed :
117+ client = await self ._get_client ()
118+ if not client or client .closed :
115119 logger .error ("[NVIDIA Rerank] Client session not initialized or closed" )
116120 return []
117121
@@ -125,7 +129,7 @@ async def rerank(
125129 payload = self ._build_payload (query , documents )
126130 request_url = self ._get_endpoint ()
127131
128- async with self . client .post (request_url , json = payload ) as response :
132+ async with client .post (request_url , json = payload ) as response :
129133 response_data = await response .json ()
130134 logger .debug (f"[NVIDIA Rerank] API Response: { response_data } " )
131135
0 commit comments