Skip to content

Commit bd0173f

Browse files
refactor: extract chat completion POST logic and improve logprobs retry handling
1 parent 7c44989 commit bd0173f

1 file changed

Lines changed: 39 additions & 31 deletions

File tree

src/infer_check/backends/openai_compat.py

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import asyncio
1010
import math
1111
import time
12+
from typing import Any
1213

1314
import httpx
1415

@@ -66,24 +67,11 @@ async def generate(self, prompt: Prompt) -> InferenceResult:
6667
return await self._generate_chat(prompt)
6768
return await self._generate_completions(prompt)
6869

69-
async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
70-
"""Use ``/v1/chat/completions`` with proper message formatting.
70+
async def _post_chat(self, payload: dict[str, Any]) -> tuple[float, dict[str, Any]]:
71+
"""POST to /v1/chat/completions with consistent error handling.
7172
72-
Requests logprobs when the server supports them. If the first
73-
request fails with a 4xx (unsupported parameter), the backend
74-
automatically retries without logprobs and disables them for
75-
all subsequent requests.
73+
Returns (elapsed_seconds, response_json).
7674
"""
77-
payload: dict[str, object] = {
78-
"model": self._model_id,
79-
"messages": [{"role": "user", "content": prompt.text}],
80-
"max_tokens": prompt.max_tokens,
81-
"temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0,
82-
}
83-
if self._chat_logprobs_supported:
84-
payload["logprobs"] = True
85-
payload["top_logprobs"] = 5
86-
8775
start = time.perf_counter()
8876
try:
8977
response = await self._client.post("/v1/chat/completions", json=payload)
@@ -97,21 +85,8 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
9785
raise RuntimeError(f"Request to {self._base_url}/v1/chat/completions timed out after 120s.") from exc
9886
except httpx.HTTPStatusError as exc:
9987
status = exc.response.status_code
100-
# If the server rejected logprobs, retry without them.
101-
if 400 <= status < 500 and self._chat_logprobs_supported:
102-
self._chat_logprobs_supported = False
103-
payload.pop("logprobs", None)
104-
payload.pop("top_logprobs", None)
105-
start = time.perf_counter()
106-
try:
107-
response = await self._client.post("/v1/chat/completions", json=payload)
108-
response.raise_for_status()
109-
except httpx.HTTPStatusError as retry_exc:
110-
body = retry_exc.response.text[:500]
111-
raise RuntimeError(f"Server returned HTTP {retry_exc.response.status_code}: {body}") from retry_exc
112-
else:
113-
body = exc.response.text[:500]
114-
raise RuntimeError(f"Server returned HTTP {status}: {body}") from exc
88+
body = exc.response.text[:500]
89+
raise RuntimeError(f"Server returned HTTP {status}: {body}") from exc
11590

11691
elapsed_s = time.perf_counter() - start
11792

@@ -123,6 +98,39 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
12398
if "choices" not in data or not data["choices"]:
12499
raise RuntimeError(f"Server returned empty or malformed response: {data}")
125100

101+
return elapsed_s, data
102+
103+
async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
104+
"""Use ``/v1/chat/completions`` with proper message formatting.
105+
106+
Requests logprobs when the server supports them. If the first
107+
request fails with 400 or 422 (unsupported parameter), the backend
108+
automatically retries without logprobs and disables them for
109+
all subsequent requests.
110+
"""
111+
payload: dict[str, object] = {
112+
"model": self._model_id,
113+
"messages": [{"role": "user", "content": prompt.text}],
114+
"max_tokens": prompt.max_tokens,
115+
"temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0,
116+
}
117+
if self._chat_logprobs_supported:
118+
payload["logprobs"] = True
119+
payload["top_logprobs"] = 5
120+
121+
try:
122+
elapsed_s, data = await self._post_chat(payload)
123+
except RuntimeError as exc:
124+
# Retry without logprobs only on 400/422 (unsupported parameter).
125+
msg = str(exc)
126+
if self._chat_logprobs_supported and ("HTTP 400" in msg or "HTTP 422" in msg):
127+
self._chat_logprobs_supported = False
128+
payload.pop("logprobs", None)
129+
payload.pop("top_logprobs", None)
130+
elapsed_s, data = await self._post_chat(payload)
131+
else:
132+
raise
133+
126134
choice = data["choices"][0]
127135
message = choice.get("message", {})
128136
text: str = message.get("content", "")

0 commit comments

Comments
 (0)