Skip to content

Commit e2bb1e0

Browse files
fix: improve OpenAI compatibility error handling and category assignment
- Introduce _ServerHTTPError for clearer HTTP error propagation in OpenAICompatBackend
- Refine logprobs retry logic to use status codes instead of string matching
- Adjust prompt category assignment to allow None values instead of defaulting to "default"
- Safeguard token extraction in logprobs parsing
1 parent bd0173f commit e2bb1e0

2 files changed

Lines changed: 14 additions & 7 deletions

File tree

src/infer_check/backends/openai_compat.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,14 @@
1818
__all__ = ["OpenAICompatBackend"]
1919

2020

21+
class _ServerHTTPError(RuntimeError):
    """Internal exception carrying the HTTP status code from the server.

    Subclasses ``RuntimeError`` so existing ``except RuntimeError`` handlers
    keep working, while callers that need to branch on the status can read
    ``status_code`` instead of parsing the message text.

    Attributes:
        status_code: HTTP status code reported by the server.
        body: Response body text passed in by the caller.
    """

    def __init__(self, status_code: int, body: str) -> None:
        """Build the error from the response status and body text.

        Args:
            status_code: HTTP status code reported by the server.
            body: Response body text to embed in the error message.
        """
        self.status_code = status_code
        self.body = body
        super().__init__(f"Server returned HTTP {status_code}: {body}")

    def __reduce__(self) -> tuple[type, tuple[int, str]]:
        """Return the constructor arguments used for pickling and copying.

        ``BaseException`` rebuilds an instance as ``cls(*self.args)``, but
        ``args`` holds only the formatted message while ``__init__`` needs two
        positional arguments. Without this override, ``pickle`` and ``copy``
        round-trips raise ``TypeError``.
        """
        return (type(self), (self.status_code, self.body))
27+
28+
2129
class OpenAICompatBackend:
2230
"""Backend adapter for any OpenAI-compatible completion server.
2331
@@ -86,7 +94,7 @@ async def _post_chat(self, payload: dict[str, Any]) -> tuple[float, dict[str, An
8694
except httpx.HTTPStatusError as exc:
8795
status = exc.response.status_code
8896
body = exc.response.text[:500]
89-
raise RuntimeError(f"Server returned HTTP {status}: {body}") from exc
97+
raise _ServerHTTPError(status, body) from exc
9098

9199
elapsed_s = time.perf_counter() - start
92100

@@ -120,10 +128,9 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
120128

121129
try:
122130
elapsed_s, data = await self._post_chat(payload)
123-
except RuntimeError as exc:
131+
except _ServerHTTPError as exc:
124132
# Retry without logprobs only on 400/422 (unsupported parameter).
125-
msg = str(exc)
126-
if self._chat_logprobs_supported and ("HTTP 400" in msg or "HTTP 422" in msg):
133+
if self._chat_logprobs_supported and exc.status_code in (400, 422):
127134
self._chat_logprobs_supported = False
128135
payload.pop("logprobs", None)
129136
payload.pop("top_logprobs", None)
@@ -146,7 +153,7 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
146153
lp_data = choice.get("logprobs")
147154
if lp_data and lp_data.get("content"):
148155
content_logprobs = lp_data["content"]
149-
tokens = [entry["token"] for entry in content_logprobs]
156+
tokens = [entry.get("token", "") for entry in content_logprobs]
150157
logprobs_list = []
151158
for entry in content_logprobs:
152159
raw = entry.get("logprob")

src/infer_check/suites/loader.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def load_suite(path: str | Path, num_prompts: int | None = None) -> list[Prompt]
3636
data = json.loads(line)
3737
prompt = Prompt.model_validate(data)
3838
all_prompts.append(prompt)
39-
cat = prompt.category or "default"
39+
cat = prompt.category
4040
if cat not in prompts_by_category:
4141
prompts_by_category[cat] = []
4242
prompts_by_category[cat].append(prompt)
@@ -75,7 +75,7 @@ def load_suite(path: str | Path, num_prompts: int | None = None) -> list[Prompt]
7575
final_prompts = all_prompts
7676

7777
# Log summary
78-
category_counts = Counter(p.category or "default" for p in final_prompts)
78+
category_counts = Counter(p.category for p in final_prompts)
7979
console.print(f"[bold green]Loaded {len(final_prompts)} prompts from {path_obj.name}[/bold green]")
8080
for category, count in category_counts.most_common():
8181
console.print(f" - {category}: {count}")

0 commit comments

Comments
 (0)