Skip to content

Commit 9829426

Browse files
test: add coverage for load_suite category distribution logic
- Add unit tests to validate equal and uneven category prompt distribution
- Extend `load_suite` to support limiting prompts by category balance
- Update OpenAI backend to parse logprobs and token distributions
1 parent 45e6b67 commit 9829426

4 files changed

Lines changed: 175 additions & 16 deletions

File tree

src/infer_check/backends/openai_compat.py

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,13 @@ async def generate(self, prompt: Prompt) -> InferenceResult:
6565

6666
async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
6767
"""Use ``/v1/chat/completions`` with proper message formatting."""
68-
payload = {
68+
payload: dict[str, object] = {
6969
"model": self._model_id,
7070
"messages": [{"role": "user", "content": prompt.text}],
7171
"max_tokens": prompt.max_tokens,
7272
"temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0,
73+
"logprobs": True,
74+
"top_logprobs": 5,
7375
}
7476

7577
start = time.perf_counter()
@@ -103,7 +105,46 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
103105
text: str = message.get("content", "")
104106
if not text:
105107
text = message.get("reasoning_content", "")
106-
tokens = text.split()
108+
109+
# Parse logprobs (chat completions format) -------------------------
110+
tokens: list[str] = []
111+
logprobs_list: list[float] | None = None
112+
distributions: list[list[float]] | None = None
113+
distribution_metadata: list[dict[str, int | str]] | None = None
114+
115+
lp_data = choice.get("logprobs")
116+
if lp_data and lp_data.get("content"):
117+
content_logprobs = lp_data["content"]
118+
tokens = [entry["token"] for entry in content_logprobs]
119+
logprobs_list = [
120+
float(entry["logprob"]) if entry.get("logprob") is not None else -9999.0 for entry in content_logprobs
121+
]
122+
123+
distributions = []
124+
distribution_metadata = []
125+
for entry in content_logprobs:
126+
top = entry.get("top_logprobs", [])
127+
if not top:
128+
distributions.append([])
129+
distribution_metadata.append({})
130+
continue
131+
sorted_items = sorted(top, key=lambda x: x.get("token", ""))
132+
cleaned: list[tuple[str, float]] = []
133+
for item in sorted_items:
134+
try:
135+
fv = float(item["logprob"]) if item.get("logprob") is not None else -9999.0
136+
except (TypeError, ValueError):
137+
fv = -9999.0
138+
if math.isnan(fv):
139+
fv = -9999.0
140+
cleaned.append((item.get("token", ""), fv))
141+
distributions.append([fv for _, fv in cleaned])
142+
meta: dict[str, int | str] = {}
143+
for i, (tok, _) in enumerate(cleaned):
144+
meta[f"id_{i}"] = tok
145+
distribution_metadata.append(meta)
146+
else:
147+
tokens = text.split()
107148

108149
usage = data.get("usage", {})
109150
completion_tokens = usage.get("completion_tokens", len(tokens))
@@ -114,7 +155,9 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
114155
backend_name=self.name,
115156
model_id=self._model_id,
116157
tokens=tokens,
117-
logprobs=None,
158+
logprobs=logprobs_list,
159+
distributions=distributions,
160+
distribution_metadata=distribution_metadata,
118161
text=text,
119162
latency_ms=elapsed_s * 1000,
120163
tokens_per_second=tps,

src/infer_check/cli.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,7 @@ def _load_prompts(ctx: click.Context, prompts: str, max_tokens: int | None, num_
5959
if num_prompts is not None:
6060
ctx.obj["num_prompts"] = num_prompts
6161

62-
prompt_list = load_suite(_resolve_prompts(prompts))
63-
64-
# Apply num_prompts limit
65-
if ctx.obj["num_prompts"] is not None:
66-
prompt_list = prompt_list[: ctx.obj["num_prompts"]]
62+
prompt_list = load_suite(_resolve_prompts(prompts), num_prompts=ctx.obj["num_prompts"])
6763

6864
# Apply global max_tokens only if not explicitly set in the prompt JSONL
6965
for p in prompt_list:

src/infer_check/suites/loader.py

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,18 +12,19 @@
1212
console = Console()
1313

1414

15-
def load_suite(path: str | Path) -> list[Prompt]:
15+
def load_suite(path: str | Path, num_prompts: int | None = None) -> list[Prompt]:
1616
"""
1717
Read a JSONL file and validate each line against the Prompt model.
1818
Logs the count and category distribution via rich.console.
19-
Raises ValueError with the line number on invalid entries.
19+
If num_prompts is provided, selects an approximately equal number
20+
of prompts from each category.
2021
"""
2122
path_obj = Path(path)
2223
if not path_obj.exists():
2324
raise FileNotFoundError(f"Prompt suite not found: {path_obj}")
2425

25-
prompts = []
26-
category_counts: Counter[str] = Counter()
26+
all_prompts: list[Prompt] = []
27+
prompts_by_category: dict[str, list[Prompt]] = {}
2728

2829
with path_obj.open("r", encoding="utf-8") as f:
2930
for idx, line in enumerate(f, start=1):
@@ -34,19 +35,52 @@ def load_suite(path: str | Path) -> list[Prompt]:
3435
try:
3536
data = json.loads(line)
3637
prompt = Prompt.model_validate(data)
37-
prompts.append(prompt)
38-
category_counts[prompt.category] += 1
38+
all_prompts.append(prompt)
39+
cat = prompt.category or "default"
40+
if cat not in prompts_by_category:
41+
prompts_by_category[cat] = []
42+
prompts_by_category[cat].append(prompt)
3943
except json.JSONDecodeError as e:
4044
raise ValueError(f"Invalid JSON at {path_obj}:{idx} - {e}") from e
4145
except ValidationError as e:
4246
raise ValueError(f"Invalid Prompt at {path_obj}:{idx} - {e}") from e
4347

48+
# Apply num_prompts limit with equal category distribution
49+
if num_prompts is not None and num_prompts < len(all_prompts):
50+
selected_prompts: list[Prompt] = []
51+
categories = sorted(prompts_by_category.keys())
52+
num_categories = len(categories)
53+
54+
if num_categories > 0:
55+
# Simple round-robin selection to keep categories equal
56+
# We iterate through categories and pick one prompt from each until we hit the limit
57+
# This ensures that even if categories have different sizes, we pick as equally as possible
58+
cat_indices = {cat: 0 for cat in categories}
59+
while len(selected_prompts) < num_prompts:
60+
added_in_round = False
61+
for cat in categories:
62+
if len(selected_prompts) >= num_prompts:
63+
break
64+
idx = cat_indices[cat]
65+
if idx < len(prompts_by_category[cat]):
66+
selected_prompts.append(prompts_by_category[cat][idx])
67+
cat_indices[cat] += 1
68+
added_in_round = True
69+
if not added_in_round:
70+
break
71+
final_prompts = selected_prompts
72+
else:
73+
final_prompts = all_prompts[:num_prompts]
74+
else:
75+
final_prompts = all_prompts
76+
4477
# Log summary
45-
console.print(f"[bold green]Loaded {len(prompts)} prompts from {path_obj.name}[/bold green]")
78+
category_counts = Counter(p.category or "default" for p in final_prompts)
79+
console.print(f"[bold green]Loaded {len(final_prompts)} prompts from {path_obj.name}[/bold green]")
4680
for category, count in category_counts.most_common():
4781
console.print(f" - {category}: {count}")
4882

49-
return prompts
83+
return final_prompts
5084

5185

5286
def save_suite(prompts: list[Prompt], path: str | Path) -> None:
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
import json
2+
from pathlib import Path
3+
4+
from infer_check.suites.loader import load_suite
5+
6+
7+
def test_load_suite_equal_distribution(tmp_path: Path) -> None:
    """Test that load_suite distributes num_prompts equally across categories.

    The suite file holds 10 math, 5 code and 2 logic prompts (in that file
    order). ``load_suite`` visits the categories in sorted order
    (code, logic, math) and takes one prompt per category per round, in file
    order within each category, until ``num_prompts`` prompts are selected.
    """
    from collections import Counter

    prompt_file = tmp_path / "test_prompts.jsonl"

    # Insertion order of this dict is the order the prompts appear in the file.
    category_sizes = {"math": 10, "code": 5, "logic": 2}
    prompts = [
        {"id": f"{category}-{i}", "text": f"{category} {i}", "category": category}
        for category, size in category_sizes.items()
        for i in range(size)
    ]
    prompt_file.write_text("\n".join(json.dumps(p) for p in prompts), encoding="utf-8")

    # Request 6 prompts: exactly two full rounds over the three categories.
    #   Round 1: code 0, logic 0, math 0
    #   Round 2: code 1, logic 1, math 1
    loaded = load_suite(prompt_file, num_prompts=6)

    assert len(loaded) == 6
    assert Counter(p.category for p in loaded) == {"math": 2, "code": 2, "logic": 2}
    # Pin the round-robin order as well, not only the per-category totals.
    assert [p.text for p in loaded] == [
        "code 0",
        "logic 0",
        "math 0",
        "code 1",
        "logic 1",
        "math 1",
    ]

    # Request 4 prompts: one full round (3 prompts), then the limit is reached
    # right after the first pick of round 2, which is code 1.
    loaded_4 = load_suite(prompt_file, num_prompts=4)

    assert len(loaded_4) == 4
    assert Counter(p.category for p in loaded_4) == {"code": 2, "logic": 1, "math": 1}
    assert [p.text for p in loaded_4] == ["code 0", "logic 0", "math 0", "code 1"]
49+
50+
51+
def test_load_suite_uneven_categories(tmp_path: Path) -> None:
    """Test distribution when some categories are exhausted.

    With 5 math prompts and a single code prompt, the code category runs out
    after the first round, so the remaining slots must be filled from math.
    """
    from collections import Counter

    prompt_file = tmp_path / "test_prompts_uneven.jsonl"

    # 5 math, 1 code
    prompts = [{"id": f"math-{i}", "text": f"math {i}", "category": "math"} for i in range(5)]
    prompts.append({"id": "code-0", "text": "code 0", "category": "code"})
    prompt_file.write_text("\n".join(json.dumps(p) for p in prompts), encoding="utf-8")

    # Request 4 prompts. Categories are visited in sorted order: code, math.
    #   Round 1: code 0, math 0
    #   Round 2: (code exhausted) math 1
    #   Round 3: math 2
    loaded = load_suite(prompt_file, num_prompts=4)

    assert len(loaded) == 4
    assert Counter(p.category for p in loaded) == {"code": 1, "math": 3}
    # The exhausted category is skipped without disturbing the order of the rest.
    assert [p.text for p in loaded] == ["code 0", "math 0", "math 1", "math 2"]
77+
78+
79+
def test_load_suite_no_limit(tmp_path: Path) -> None:
    """Check that load_suite returns every prompt in the file when no limit is given."""
    rows = [
        {"id": "1", "text": "t1", "category": "a"},
        {"id": "2", "text": "t2", "category": "b"},
    ]
    suite_path = tmp_path / "test_prompts_all.jsonl"
    # One JSON object per line, with no trailing newline.
    suite_path.write_text("\n".join(map(json.dumps, rows)))

    assert len(load_suite(suite_path)) == 2

0 commit comments

Comments
 (0)