Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 99 additions & 11 deletions src/infer_check/backends/openai_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import asyncio
import math
import time
from typing import Any

import httpx

Expand All @@ -17,6 +18,14 @@
__all__ = ["OpenAICompatBackend"]


class _ServerHTTPError(RuntimeError):
    """Internal exception carrying the HTTP status code from the server.

    Attributes:
        status_code: HTTP status code reported by the server.
        body: Response body text supplied by the raiser; it is also embedded
            in the exception message.
    """

    def __init__(self, status_code: int, body: str) -> None:
        # Keep both values as attributes so handlers can branch on them
        # without having to parse the formatted message.
        self.status_code = status_code
        self.body = body
        super().__init__(f"Server returned HTTP {status_code}: {body}")


class OpenAICompatBackend:
"""Backend adapter for any OpenAI-compatible completion server.

Expand Down Expand Up @@ -49,6 +58,9 @@ def __init__(
timeout=120.0,
)

# Logprobs support: assume yes until a server rejects it.
self._chat_logprobs_supported: bool = True

# ------------------------------------------------------------------
# BackendAdapter protocol
# ------------------------------------------------------------------
Expand All @@ -63,15 +75,11 @@ async def generate(self, prompt: Prompt) -> InferenceResult:
return await self._generate_chat(prompt)
return await self._generate_completions(prompt)

async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
"""Use ``/v1/chat/completions`` with proper message formatting."""
payload = {
"model": self._model_id,
"messages": [{"role": "user", "content": prompt.text}],
"max_tokens": prompt.max_tokens,
"temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0,
}
async def _post_chat(self, payload: dict[str, Any]) -> tuple[float, dict[str, Any]]:
"""POST to /v1/chat/completions with consistent error handling.

Returns (elapsed_seconds, response_json).
"""
start = time.perf_counter()
try:
response = await self._client.post("/v1/chat/completions", json=payload)
Expand All @@ -86,7 +94,7 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
except httpx.HTTPStatusError as exc:
status = exc.response.status_code
body = exc.response.text[:500]
raise RuntimeError(f"Server returned HTTP {status}: {body}") from exc
raise _ServerHTTPError(status, body) from exc

elapsed_s = time.perf_counter() - start

Expand All @@ -98,12 +106,90 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
if "choices" not in data or not data["choices"]:
raise RuntimeError(f"Server returned empty or malformed response: {data}")

return elapsed_s, data

async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
"""Use ``/v1/chat/completions`` with proper message formatting.

Requests logprobs when the server supports them. If the first
request fails with 400 or 422 (unsupported parameter), the backend
automatically retries without logprobs and disables them for
all subsequent requests.
"""
payload: dict[str, object] = {
"model": self._model_id,
"messages": [{"role": "user", "content": prompt.text}],
"max_tokens": prompt.max_tokens,
"temperature": prompt.metadata.get("temperature", 0.0) if prompt.metadata else 0.0,
}
if self._chat_logprobs_supported:
payload["logprobs"] = True
payload["top_logprobs"] = 5

try:
elapsed_s, data = await self._post_chat(payload)
except _ServerHTTPError as exc:
# Retry without logprobs only on 400/422 (unsupported parameter).
if self._chat_logprobs_supported and exc.status_code in (400, 422):
self._chat_logprobs_supported = False
payload.pop("logprobs", None)
payload.pop("top_logprobs", None)
elapsed_s, data = await self._post_chat(payload)
else:
raise
Comment thread
NullPointerDepressiveDisorder marked this conversation as resolved.
Comment thread
NullPointerDepressiveDisorder marked this conversation as resolved.

choice = data["choices"][0]
message = choice.get("message", {})
text: str = message.get("content", "")
if not text:
text = message.get("reasoning_content", "")
tokens = text.split()

# Parse logprobs (chat completions format) -------------------------
tokens: list[str] = []
logprobs_list: list[float] | None = None
distributions: list[list[float]] | None = None
distribution_metadata: list[dict[str, int | str]] | None = None

lp_data = choice.get("logprobs")
if lp_data and lp_data.get("content"):
content_logprobs = lp_data["content"]
tokens = [entry.get("token", "") for entry in content_logprobs]
logprobs_list = []
for entry in content_logprobs:
raw = entry.get("logprob")
try:
fv = float(raw) if raw is not None else -9999.0
except (TypeError, ValueError):
fv = -9999.0
if math.isnan(fv):
fv = -9999.0
logprobs_list.append(fv)

distributions = []
distribution_metadata = []
for entry in content_logprobs:
top = entry.get("top_logprobs", [])
if not top:
distributions.append([])
distribution_metadata.append({})
continue
sorted_items = sorted(top, key=lambda x: x.get("token", ""))
cleaned: list[tuple[str, float]] = []
for item in sorted_items:
try:
fv = float(item["logprob"]) if item.get("logprob") is not None else -9999.0
except (TypeError, ValueError):
fv = -9999.0
if math.isnan(fv):
fv = -9999.0
cleaned.append((item.get("token", ""), fv))
distributions.append([fv for _, fv in cleaned])
meta: dict[str, int | str] = {}
for i, (tok, _) in enumerate(cleaned):
meta[f"id_{i}"] = tok
distribution_metadata.append(meta)
else:
Comment thread
NullPointerDepressiveDisorder marked this conversation as resolved.
tokens = text.split()

usage = data.get("usage", {})
completion_tokens = usage.get("completion_tokens", len(tokens))
Expand All @@ -114,7 +200,9 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
backend_name=self.name,
model_id=self._model_id,
tokens=tokens,
logprobs=None,
logprobs=logprobs_list,
distributions=distributions,
distribution_metadata=distribution_metadata,
text=text,
latency_ms=elapsed_s * 1000,
tokens_per_second=tps,
Expand Down
6 changes: 1 addition & 5 deletions src/infer_check/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,7 @@ def _load_prompts(ctx: click.Context, prompts: str, max_tokens: int | None, num_
if num_prompts is not None:
ctx.obj["num_prompts"] = num_prompts

prompt_list = load_suite(_resolve_prompts(prompts))

# Apply num_prompts limit
if ctx.obj["num_prompts"] is not None:
prompt_list = prompt_list[: ctx.obj["num_prompts"]]
prompt_list = load_suite(_resolve_prompts(prompts), num_prompts=ctx.obj["num_prompts"])

# Apply global max_tokens only if not explicitly set in the prompt JSONL
for p in prompt_list:
Expand Down
50 changes: 42 additions & 8 deletions src/infer_check/suites/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,19 @@
console = Console()


def load_suite(path: str | Path) -> list[Prompt]:
def load_suite(path: str | Path, num_prompts: int | None = None) -> list[Prompt]:
"""
Read a JSONL file and validate each line against the Prompt model.
Logs the count and category distribution via rich.console.
Raises ValueError with the line number on invalid entries.
If num_prompts is provided, selects an approximately equal number
of prompts from each category.
"""
Comment thread
NullPointerDepressiveDisorder marked this conversation as resolved.
path_obj = Path(path)
if not path_obj.exists():
raise FileNotFoundError(f"Prompt suite not found: {path_obj}")

prompts = []
category_counts: Counter[str] = Counter()
all_prompts: list[Prompt] = []
prompts_by_category: dict[str, list[Prompt]] = {}

with path_obj.open("r", encoding="utf-8") as f:
for idx, line in enumerate(f, start=1):
Expand All @@ -34,19 +35,52 @@ def load_suite(path: str | Path) -> list[Prompt]:
try:
data = json.loads(line)
prompt = Prompt.model_validate(data)
prompts.append(prompt)
category_counts[prompt.category] += 1
all_prompts.append(prompt)
cat = prompt.category
if cat not in prompts_by_category:
prompts_by_category[cat] = []
prompts_by_category[cat].append(prompt)
Comment thread
NullPointerDepressiveDisorder marked this conversation as resolved.
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON at {path_obj}:{idx} - {e}") from e
except ValidationError as e:
raise ValueError(f"Invalid Prompt at {path_obj}:{idx} - {e}") from e

# Apply num_prompts limit with equal category distribution
if num_prompts is not None and num_prompts < len(all_prompts):
selected_prompts: list[Prompt] = []
categories = sorted(prompts_by_category.keys())
num_categories = len(categories)

if num_categories > 0:
Comment thread
NullPointerDepressiveDisorder marked this conversation as resolved.
# Simple round-robin selection to keep categories equal
# We iterate through categories and pick one prompt from each until we hit the limit
# This ensures that even if categories have different sizes, we pick as equally as possible
cat_indices = {cat: 0 for cat in categories}
while len(selected_prompts) < num_prompts:
added_in_round = False
for cat in categories:
if len(selected_prompts) >= num_prompts:
break
idx = cat_indices[cat]
if idx < len(prompts_by_category[cat]):
selected_prompts.append(prompts_by_category[cat][idx])
cat_indices[cat] += 1
added_in_round = True
if not added_in_round:
break
final_prompts = selected_prompts
else:
final_prompts = all_prompts[:num_prompts]
else:
final_prompts = all_prompts

# Log summary
console.print(f"[bold green]Loaded {len(prompts)} prompts from {path_obj.name}[/bold green]")
category_counts = Counter(p.category for p in final_prompts)
console.print(f"[bold green]Loaded {len(final_prompts)} prompts from {path_obj.name}[/bold green]")
for category, count in category_counts.most_common():
console.print(f" - {category}: {count}")

return prompts
return final_prompts


def save_suite(prompts: list[Prompt], path: str | Path) -> None:
Expand Down
86 changes: 86 additions & 0 deletions tests/unit/test_loader_distribution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import json
from pathlib import Path

from infer_check.suites.loader import load_suite


def test_load_suite_equal_distribution(tmp_path: Path) -> None:
    """Test that load_suite distributes num_prompts equally across categories."""
    from collections import Counter

    prompt_file = tmp_path / "test_prompts.jsonl"

    # 10 math, 5 code, 2 logic
    prompts = []
    for category, size in (("math", 10), ("code", 5), ("logic", 2)):
        prompts.extend(
            {"id": f"{category}-{i}", "text": f"{category} {i}", "category": category}
            for i in range(size)
        )

    prompt_file.write_text("\n".join(json.dumps(p) for p in prompts))

    # Request 6 prompts.
    # Categories are visited in sorted order: code, logic, math.
    # Round 1: code-0, logic-0, math-0 (3 total)
    # Round 2: code-1, logic-1, math-1 (6 total)
    loaded = load_suite(prompt_file, num_prompts=6)

    assert len(loaded) == 6
    counts = Counter(p.category for p in loaded)

    assert counts["math"] == 2
    assert counts["code"] == 2
    assert counts["logic"] == 2

    # Request 4 prompts.
    # Round 1: code-0, logic-0, math-0 (3 total)
    # Round 2: code-1 (4 total) -- the limit is hit before logic/math get a second pick
    loaded_4 = load_suite(prompt_file, num_prompts=4)
    assert len(loaded_4) == 4
    counts_4 = Counter(p.category for p in loaded_4)
    assert counts_4["code"] == 2
    assert counts_4["logic"] == 1
    assert counts_4["math"] == 1


def test_load_suite_uneven_categories(tmp_path: Path) -> None:
    """Test distribution when some categories are exhausted."""
    from collections import Counter

    prompt_file = tmp_path / "test_prompts_uneven.jsonl"

    # 5 math, 1 code
    prompts = [{"id": f"math-{i}", "text": f"math {i}", "category": "math"} for i in range(5)]
    prompts.append({"id": "code-0", "text": "code 0", "category": "code"})

    prompt_file.write_text("\n".join(json.dumps(p) for p in prompts))

    # Request 4 prompts.
    # Sorted categories: code, math
    # Round 1: code-0, math-0
    # Round 2: (code exhausted), math-1
    # Round 3: math-2
    loaded = load_suite(prompt_file, num_prompts=4)

    assert len(loaded) == 4
    # Counter yields 0 for a missing category, so a wrong distribution fails
    # on the assertion itself rather than on a KeyError.
    counts = Counter(p.category for p in loaded)

    assert counts["code"] == 1
    assert counts["math"] == 3


def test_load_suite_no_limit(tmp_path: Path) -> None:
    """Without a num_prompts limit, load_suite returns every prompt in the file."""
    records = [
        {"id": "1", "text": "t1", "category": "a"},
        {"id": "2", "text": "t2", "category": "b"},
    ]
    suite_path = tmp_path / "test_prompts_all.jsonl"
    # One JSON object per line, no trailing newline.
    jsonl_text = "\n".join(map(json.dumps, records))
    suite_path.write_text(jsonl_text)

    result = load_suite(suite_path)
    assert len(result) == 2
Loading
Loading