Commit aa75120

fix: Update OpenAI backend to parse logprobs and token distributions (#13)
* test: add coverage for `load_suite` category distribution logic
  - Add unit tests to validate equal and uneven category prompt distribution
  - Extend `load_suite` to support limiting prompts by category balance
  - Update OpenAI backend to parse logprobs and token distributions
* fix: improve logprobs handling in chat completions
* test: add logprobs handling tests for chat completions
* refactor: extract chat completion POST logic and improve logprobs retry handling
* fix: improve OpenAI compatibility error handling and category assignment
  - Introduce _ServerHTTPError for clearer HTTP error propagation in OpenAICompatBackend
  - Refine logprobs retry logic to use status codes instead of string matching
  - Adjust prompt category assignment to allow None values instead of defaulting to "default"
  - Safeguard token extraction in logprobs parsing
1 parent 45e6b67 commit aa75120

5 files changed

Lines changed: 311 additions & 24 deletions

src/infer_check/backends/openai_compat.py

Lines changed: 99 additions & 11 deletions
@@ -9,6 +9,7 @@
 import asyncio
 import math
 import time
+from typing import Any

 import httpx

@@ -17,6 +18,14 @@
 __all__ = ["OpenAICompatBackend"]


+class _ServerHTTPError(RuntimeError):
+    """Internal exception carrying the HTTP status code from the server."""
+
+    def __init__(self, status_code: int, body: str) -> None:
+        self.status_code = status_code
+        super().__init__(f"Server returned HTTP {status_code}: {body}")
+
+
 class OpenAICompatBackend:
     """Backend adapter for any OpenAI-compatible completion server.

@@ -49,6 +58,9 @@ def __init__(
             timeout=120.0,
         )

+        # Logprobs support: assume yes until a server rejects it.
+        self._chat_logprobs_supported: bool = True
+
     # ------------------------------------------------------------------
     # BackendAdapter protocol
     # ------------------------------------------------------------------
@@ -63,15 +75,11 @@ async def generate(self, prompt: Prompt) -> InferenceResult:
             return await self._generate_chat(prompt)
         return await self._generate_completions(prompt)

-    async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
-        """Use ``/v1/chat/completions`` with proper message formatting."""
-        payload = {
-            "model": self._model_id,
-            "messages": [{"role": "user", "content": prompt.text}],
-            "max_tokens": prompt.max_tokens,
-            "temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0,
-        }
+    async def _post_chat(self, payload: dict[str, Any]) -> tuple[float, dict[str, Any]]:
+        """POST to /v1/chat/completions with consistent error handling.

+        Returns (elapsed_seconds, response_json).
+        """
         start = time.perf_counter()
         try:
             response = await self._client.post("/v1/chat/completions", json=payload)
@@ -86,7 +94,7 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
         except httpx.HTTPStatusError as exc:
             status = exc.response.status_code
             body = exc.response.text[:500]
-            raise RuntimeError(f"Server returned HTTP {status}: {body}") from exc
+            raise _ServerHTTPError(status, body) from exc

         elapsed_s = time.perf_counter() - start

@@ -98,12 +106,90 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
         if "choices" not in data or not data["choices"]:
             raise RuntimeError(f"Server returned empty or malformed response: {data}")

+        return elapsed_s, data
+
+    async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
+        """Use ``/v1/chat/completions`` with proper message formatting.
+
+        Requests logprobs when the server supports them. If the first
+        request fails with 400 or 422 (unsupported parameter), the backend
+        automatically retries without logprobs and disables them for
+        all subsequent requests.
+        """
+        payload: dict[str, object] = {
+            "model": self._model_id,
+            "messages": [{"role": "user", "content": prompt.text}],
+            "max_tokens": prompt.max_tokens,
+            "temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0,
+        }
+        if self._chat_logprobs_supported:
+            payload["logprobs"] = True
+            payload["top_logprobs"] = 5
+
+        try:
+            elapsed_s, data = await self._post_chat(payload)
+        except _ServerHTTPError as exc:
+            # Retry without logprobs only on 400/422 (unsupported parameter).
+            if self._chat_logprobs_supported and exc.status_code in (400, 422):
+                self._chat_logprobs_supported = False
+                payload.pop("logprobs", None)
+                payload.pop("top_logprobs", None)
+                elapsed_s, data = await self._post_chat(payload)
+            else:
+                raise
+
         choice = data["choices"][0]
         message = choice.get("message", {})
         text: str = message.get("content", "")
         if not text:
             text = message.get("reasoning_content", "")
-        tokens = text.split()
+
+        # Parse logprobs (chat completions format) -------------------------
+        tokens: list[str] = []
+        logprobs_list: list[float] | None = None
+        distributions: list[list[float]] | None = None
+        distribution_metadata: list[dict[str, int | str]] | None = None
+
+        lp_data = choice.get("logprobs")
+        if lp_data and lp_data.get("content"):
+            content_logprobs = lp_data["content"]
+            tokens = [entry.get("token", "") for entry in content_logprobs]
+            logprobs_list = []
+            for entry in content_logprobs:
+                raw = entry.get("logprob")
+                try:
+                    fv = float(raw) if raw is not None else -9999.0
+                except (TypeError, ValueError):
+                    fv = -9999.0
+                if math.isnan(fv):
+                    fv = -9999.0
+                logprobs_list.append(fv)
+
+            distributions = []
+            distribution_metadata = []
+            for entry in content_logprobs:
+                top = entry.get("top_logprobs", [])
+                if not top:
+                    distributions.append([])
+                    distribution_metadata.append({})
+                    continue
+                sorted_items = sorted(top, key=lambda x: x.get("token", ""))
+                cleaned: list[tuple[str, float]] = []
+                for item in sorted_items:
+                    try:
+                        fv = float(item["logprob"]) if item.get("logprob") is not None else -9999.0
+                    except (TypeError, ValueError):
+                        fv = -9999.0
+                    if math.isnan(fv):
+                        fv = -9999.0
+                    cleaned.append((item.get("token", ""), fv))
+                distributions.append([fv for _, fv in cleaned])
+                meta: dict[str, int | str] = {}
+                for i, (tok, _) in enumerate(cleaned):
+                    meta[f"id_{i}"] = tok
+                distribution_metadata.append(meta)
+        else:
+            tokens = text.split()

         usage = data.get("usage", {})
         completion_tokens = usage.get("completion_tokens", len(tokens))
@@ -114,7 +200,9 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
             backend_name=self.name,
             model_id=self._model_id,
             tokens=tokens,
-            logprobs=None,
+            logprobs=logprobs_list,
+            distributions=distributions,
+            distribution_metadata=distribution_metadata,
             text=text,
             latency_ms=elapsed_s * 1000,
             tokens_per_second=tps,
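
For context, the chat-completions logprobs shape that `_generate_chat` now parses looks like the following minimal standalone sketch; the token strings and logprob values are invented for illustration, but the structure (`choices[0].logprobs.content[*].top_logprobs`) follows the OpenAI chat API:

# Sketch of the response fragment the parser consumes (sample values invented).
choice = {
    "message": {"content": "Paris"},
    "logprobs": {
        "content": [
            {
                "token": "Paris",
                "logprob": -0.01,
                "top_logprobs": [
                    {"token": "Paris", "logprob": -0.01},
                    {"token": "Lyon", "logprob": -5.2},
                ],
            },
        ],
    },
}

entries = choice["logprobs"]["content"]
tokens = [e["token"] for e in entries]        # -> ["Paris"]
logprobs = [e["logprob"] for e in entries]    # -> [-0.01]

# Per-token alternatives are sorted by token string, as the backend does:
top = sorted(entries[0]["top_logprobs"], key=lambda x: x["token"])
distribution = [t["logprob"] for t in top]    # -> [-5.2, -0.01] (Lyon, Paris)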

src/infer_check/cli.py

Lines changed: 1 addition & 5 deletions
@@ -59,11 +59,7 @@ def _load_prompts(ctx: click.Context, prompts: str, max_tokens: int | None, num_
     if num_prompts is not None:
         ctx.obj["num_prompts"] = num_prompts

-    prompt_list = load_suite(_resolve_prompts(prompts))
-
-    # Apply num_prompts limit
-    if ctx.obj["num_prompts"] is not None:
-        prompt_list = prompt_list[: ctx.obj["num_prompts"]]
+    prompt_list = load_suite(_resolve_prompts(prompts), num_prompts=ctx.obj["num_prompts"])

     # Apply global max_tokens only if not explicitly set in the prompt JSONL
     for p in prompt_list:
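
The old slice kept file order, so whichever category came first in the JSONL could dominate the limited set; delegating to load_suite balances categories instead. A toy sketch of the difference, with hypothetical prompt IDs:

# Toy illustration with hypothetical prompt IDs, not real suite data.
suite = ["math-0", "math-1", "math-2", "code-0", "code-1", "code-2"]

old = suite[:3]  # -> ['math-0', 'math-1', 'math-2'] (all one category)

by_cat = {"math": suite[:3], "code": suite[3:]}
new = [by_cat[c][i] for i in range(3) for c in sorted(by_cat)][:3]
# -> ['code-0', 'math-0', 'code-1'] (round-robin over sorted categories)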

src/infer_check/suites/loader.py

Lines changed: 42 additions & 8 deletions
@@ -12,18 +12,19 @@
 console = Console()


-def load_suite(path: str | Path) -> list[Prompt]:
+def load_suite(path: str | Path, num_prompts: int | None = None) -> list[Prompt]:
     """
     Read a JSONL file and validate each line against the Prompt model.
     Logs the count and category distribution via rich.console.
-    Raises ValueError with the line number on invalid entries.
+    If num_prompts is provided, selects an approximately equal number
+    of prompts from each category.
     """
     path_obj = Path(path)
     if not path_obj.exists():
         raise FileNotFoundError(f"Prompt suite not found: {path_obj}")

-    prompts = []
-    category_counts: Counter[str] = Counter()
+    all_prompts: list[Prompt] = []
+    prompts_by_category: dict[str, list[Prompt]] = {}

     with path_obj.open("r", encoding="utf-8") as f:
         for idx, line in enumerate(f, start=1):
@@ -34,19 +35,52 @@ def load_suite(
             try:
                 data = json.loads(line)
                 prompt = Prompt.model_validate(data)
-                prompts.append(prompt)
-                category_counts[prompt.category] += 1
+                all_prompts.append(prompt)
+                cat = prompt.category
+                if cat not in prompts_by_category:
+                    prompts_by_category[cat] = []
+                prompts_by_category[cat].append(prompt)
             except json.JSONDecodeError as e:
                 raise ValueError(f"Invalid JSON at {path_obj}:{idx} - {e}") from e
             except ValidationError as e:
                 raise ValueError(f"Invalid Prompt at {path_obj}:{idx} - {e}") from e

+    # Apply num_prompts limit with equal category distribution
+    if num_prompts is not None and num_prompts < len(all_prompts):
+        selected_prompts: list[Prompt] = []
+        categories = sorted(prompts_by_category.keys())
+        num_categories = len(categories)
+
+        if num_categories > 0:
+            # Simple round-robin selection to keep categories equal:
+            # iterate through categories and pick one prompt from each
+            # until we hit the limit, so the selection stays as equal as
+            # possible even when categories have different sizes.
+            cat_indices = {cat: 0 for cat in categories}
+            while len(selected_prompts) < num_prompts:
+                added_in_round = False
+                for cat in categories:
+                    if len(selected_prompts) >= num_prompts:
+                        break
+                    idx = cat_indices[cat]
+                    if idx < len(prompts_by_category[cat]):
+                        selected_prompts.append(prompts_by_category[cat][idx])
+                        cat_indices[cat] += 1
+                        added_in_round = True
+                if not added_in_round:
+                    break
+            final_prompts = selected_prompts
+        else:
+            final_prompts = all_prompts[:num_prompts]
+    else:
+        final_prompts = all_prompts

     # Log summary
-    console.print(f"[bold green]Loaded {len(prompts)} prompts from {path_obj.name}[/bold green]")
+    category_counts = Counter(p.category for p in final_prompts)
+    console.print(f"[bold green]Loaded {len(final_prompts)} prompts from {path_obj.name}[/bold green]")
     for category, count in category_counts.most_common():
         console.print(f"  - {category}: {count}")

-    return prompts
+    return final_prompts


 def save_suite(prompts: list[Prompt], path: str | Path) -> None:
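
The round-robin loop above is equivalent to this compact itertools formulation; the helper name round_robin_limit is illustrative, not part of the codebase:

from itertools import chain, zip_longest

def round_robin_limit(by_category: dict[str, list[str]], limit: int) -> list[str]:
    """Interleave categories in sorted order, skip exhausted ones, truncate."""
    columns = [by_category[cat] for cat in sorted(by_category)]
    rounds = zip_longest(*columns)  # one tuple per round, None once exhausted
    return [p for p in chain.from_iterable(rounds) if p is not None][:limit]

# round_robin_limit({"math": ["m0", "m1", "m2"], "code": ["c0"]}, 3)
# -> ['c0', 'm0', 'm1']  (round 1: c0, m0; round 2: code exhausted, m1)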
(new test file)

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+import json
+from pathlib import Path
+
+from infer_check.suites.loader import load_suite
+
+
+def test_load_suite_equal_distribution(tmp_path: Path) -> None:
+    """Test that load_suite distributes num_prompts equally across categories."""
+    prompt_file = tmp_path / "test_prompts.jsonl"
+
+    # 10 math, 5 code, 2 logic
+    prompts = []
+    for i in range(10):
+        prompts.append({"id": f"math-{i}", "text": f"math {i}", "category": "math"})
+    for i in range(5):
+        prompts.append({"id": f"code-{i}", "text": f"code {i}", "category": "code"})
+    for i in range(2):
+        prompts.append({"id": f"logic-{i}", "text": f"logic {i}", "category": "logic"})
+
+    prompt_file.write_text("\n".join(json.dumps(p) for p in prompts))
+
+    # Request 6 prompts. Categories iterate in sorted order: code, logic, math.
+    # Round 1: code-0, logic-0, math-0 (3 total)
+    # Round 2: code-1, logic-1, math-1 (6 total)
+    loaded = load_suite(prompt_file, num_prompts=6)
+
+    assert len(loaded) == 6
+    categories = [p.category for p in loaded]
+    from collections import Counter
+
+    counts = Counter(categories)
+
+    assert counts["math"] == 2
+    assert counts["code"] == 2
+    assert counts["logic"] == 2
+
+    # Request 4 prompts.
+    # Round 1: code-0, logic-0, math-0 (3 total)
+    # Round 2: code-1 (4 total)
+    loaded_4 = load_suite(prompt_file, num_prompts=4)
+    assert len(loaded_4) == 4
+    counts_4 = Counter([p.category for p in loaded_4])
+    assert counts_4["code"] == 2
+    assert counts_4["logic"] == 1
+    assert counts_4["math"] == 1
+
+
+def test_load_suite_uneven_categories(tmp_path: Path) -> None:
+    """Test distribution when some categories are exhausted."""
+    prompt_file = tmp_path / "test_prompts_uneven.jsonl"
+
+    # 5 math, 1 code
+    prompts = []
+    for i in range(5):
+        prompts.append({"id": f"math-{i}", "text": f"math {i}", "category": "math"})
+    prompts.append({"id": "code-0", "text": "code 0", "category": "code"})
+
+    prompt_file.write_text("\n".join(json.dumps(p) for p in prompts))
+
+    # Request 4 prompts. Sorted categories: code, math.
+    # Round 1: code-0, math-0
+    # Round 2: (code exhausted), math-1
+    # Round 3: math-2
+    loaded = load_suite(prompt_file, num_prompts=4)
+
+    assert len(loaded) == 4
+    counts = {p.category: 0 for p in loaded}
+    for p in loaded:
+        counts[p.category] += 1
+
+    assert counts["code"] == 1
+    assert counts["math"] == 3
+
+
+def test_load_suite_no_limit(tmp_path: Path) -> None:
+    """Test that load_suite returns all prompts if no limit is provided."""
+    prompt_file = tmp_path / "test_prompts_all.jsonl"
+    prompts = [{"id": "1", "text": "t1", "category": "a"}, {"id": "2", "text": "t2", "category": "b"}]
+    prompt_file.write_text("\n".join(json.dumps(p) for p in prompts))
+
+    loaded = load_suite(prompt_file)
+    assert len(loaded) == 2
