99import asyncio
1010import math
1111import time
12+ from typing import Any
1213
1314import httpx
1415
@@ -66,24 +67,11 @@ async def generate(self, prompt: Prompt) -> InferenceResult:
6667 return await self ._generate_chat (prompt )
6768 return await self ._generate_completions (prompt )
6869
69- async def _generate_chat (self , prompt : Prompt ) -> InferenceResult :
70- """Use `` /v1/chat/completions`` with proper message formatting .
70+ async def _post_chat (self , payload : dict [ str , Any ] ) -> tuple [ float , dict [ str , Any ]] :
71+ """POST to /v1/chat/completions with consistent error handling .
7172
72- Requests logprobs when the server supports them. If the first
73- request fails with a 4xx (unsupported parameter), the backend
74- automatically retries without logprobs and disables them for
75- all subsequent requests.
73+ Returns (elapsed_seconds, response_json).
7674 """
77- payload : dict [str , object ] = {
78- "model" : self ._model_id ,
79- "messages" : [{"role" : "user" , "content" : prompt .text }],
80- "max_tokens" : prompt .max_tokens ,
81- "temperature" : prompt .metadata .get ("temperature" , 0.0 ) if prompt .metadata else 0.0 ,
82- }
83- if self ._chat_logprobs_supported :
84- payload ["logprobs" ] = True
85- payload ["top_logprobs" ] = 5
86-
8775 start = time .perf_counter ()
8876 try :
8977 response = await self ._client .post ("/v1/chat/completions" , json = payload )
@@ -97,21 +85,8 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
9785 raise RuntimeError (f"Request to { self ._base_url } /v1/chat/completions timed out after 120s." ) from exc
9886 except httpx .HTTPStatusError as exc :
9987 status = exc .response .status_code
100- # If the server rejected logprobs, retry without them.
101- if 400 <= status < 500 and self ._chat_logprobs_supported :
102- self ._chat_logprobs_supported = False
103- payload .pop ("logprobs" , None )
104- payload .pop ("top_logprobs" , None )
105- start = time .perf_counter ()
106- try :
107- response = await self ._client .post ("/v1/chat/completions" , json = payload )
108- response .raise_for_status ()
109- except httpx .HTTPStatusError as retry_exc :
110- body = retry_exc .response .text [:500 ]
111- raise RuntimeError (f"Server returned HTTP { retry_exc .response .status_code } : { body } " ) from retry_exc
112- else :
113- body = exc .response .text [:500 ]
114- raise RuntimeError (f"Server returned HTTP { status } : { body } " ) from exc
88+ body = exc .response .text [:500 ]
89+ raise RuntimeError (f"Server returned HTTP { status } : { body } " ) from exc
11590
11691 elapsed_s = time .perf_counter () - start
11792
@@ -123,6 +98,39 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
12398 if "choices" not in data or not data ["choices" ]:
12499 raise RuntimeError (f"Server returned empty or malformed response: { data } " )
125100
101+ return elapsed_s , data
102+
103+ async def _generate_chat (self , prompt : Prompt ) -> InferenceResult :
104+ """Use ``/v1/chat/completions`` with proper message formatting.
105+
106+ Requests logprobs when the server supports them. If the first
107+ request fails with 400 or 422 (unsupported parameter), the backend
108+ automatically retries without logprobs and disables them for
109+ all subsequent requests.
110+ """
111+ payload : dict [str , object ] = {
112+ "model" : self ._model_id ,
113+ "messages" : [{"role" : "user" , "content" : prompt .text }],
114+ "max_tokens" : prompt .max_tokens ,
115+ "temperature" : prompt .metadata .get ("temperature" , 0.0 ) if prompt .metadata else 0.0 ,
116+ }
117+ if self ._chat_logprobs_supported :
118+ payload ["logprobs" ] = True
119+ payload ["top_logprobs" ] = 5
120+
121+ try :
122+ elapsed_s , data = await self ._post_chat (payload )
123+ except RuntimeError as exc :
124+ # Retry without logprobs only on 400/422 (unsupported parameter).
125+ msg = str (exc )
126+ if self ._chat_logprobs_supported and ("HTTP 400" in msg or "HTTP 422" in msg ):
127+ self ._chat_logprobs_supported = False
128+ payload .pop ("logprobs" , None )
129+ payload .pop ("top_logprobs" , None )
130+ elapsed_s , data = await self ._post_chat (payload )
131+ else :
132+ raise
133+
126134 choice = data ["choices" ][0 ]
127135 message = choice .get ("message" , {})
128136 text : str = message .get ("content" , "" )
0 commit comments