Skip to content

Commit 846331b

Browse files
Bug/max tokens (#10)
* feat: add global --max-tokens CLI flag and increase default max_tokens to 1024 * fix: handle missing content by falling back to reasoning_content in the OpenAI-compatible backend * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * fix: ensure prompt-level max_tokens overrides global CLI flag and add tests for precedence - Track explicitly set fields in prompt metadata to distinguish prompt-level max_tokens - Only apply global --max-tokens if max_tokens is not set in prompt JSONL - Add tests to verify prompt-level max_tokens takes precedence over global flag - Make model cleanup in mlx_lm async-safe * fix: use model_fields_set for max_tokens override detection and update tests - Replace usage of metadata["__fields_set__"] with model_fields_set to check if max_tokens is set in prompt - Remove manual tracking of __fields_set__ in loader.py - Update test_cli_max_tokens.py to use SweepResult instead of MagicMock for mock results - Change ruff target-version to py311 in pyproject.toml * fix: set correct default and type for --max-tokens CLI option, update python_version in mypy config --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 4d0ea3e commit 846331b

9 files changed

Lines changed: 235 additions & 9 deletions

File tree

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# 0.2.0 (2026-04-02)
2+
3+
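- Added global `--max-tokens` flag (defaults to 1024) to the main CLI. Each subcommand also accepts `--max-tokens`, and a `max_tokens` value set on an individual prompt in the JSONL suite takes precedence over both.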
4+
- Increased default `max_tokens` for all prompts from 256 to 1024.
5+
16
# 0.1.0 (2026-03-11)
27

38
- Initial release

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ pip install "infer-check[mlx]"
6868

6969
### Quantization sweep
7070

71-
Compare pre-quantized models against a baseline. Each model is a separate HuggingFace repo.
71+
Compare pre-quantized models against a baseline. Each model is a separate HuggingFace repo. Use `--max-tokens` to control generation length (defaults to 1024).
7272

7373
```
7474
infer-check sweep \
@@ -77,6 +77,7 @@ infer-check sweep \
7777
4bit=mlx-community/Meta-Llama-3.1-8B-Instruct-4bit" \
7878
--backend mlx-lm \
7979
--prompts reasoning \
80+
--max-tokens 512 \
8081
--output ./results/sweep/
8182
```
8283

@@ -161,7 +162,7 @@ Curated prompts targeting known quantization failure modes:
161162
| `quant-sensitive.jsonl` | 20 | Multi-digit arithmetic, long CoT, precise syntax |
162163
| `determinism.jsonl` | 50 | High-entropy continuations for determinism testing |
163164

164-
All suites ship with the package — no need to clone the repo. Custom suites are JSONL files with one object per line:
165+
All suites ship with the package — no need to clone the repo. Custom suites are JSONL files with one object per line; `max_tokens` is optional and defaults to 1024 when omitted:
165166

166167
```json
167168
{"id": "custom-001", "text": "Your prompt here", "category": "math", "max_tokens": 512}

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ dev = [
6565
]
6666

6767
[tool.ruff]
68-
target-version = "py313"
68+
target-version = "py311"
6969
line-length = 120
7070
src = ["src"]
7171

@@ -76,7 +76,7 @@ select = ["E", "F", "I", "N", "W", "UP", "B", "SIM"]
7676
"html.py" = ["E501"]
7777

7878
[tool.mypy]
79-
python_version = "3.13"
79+
python_version = "3.11"
8080
strict = true
8181

8282
[tool.pytest.ini_options]

src/infer_check/backends/mlx_lm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import asyncio
56
import gc
67
import time
78
from typing import Any, cast
@@ -85,7 +86,7 @@ async def cleanup(self) -> None:
8586
"""Release model references and trigger garbage collection."""
8687
self._model = None
8788
self._tokenizer = None
88-
gc.collect()
89+
await asyncio.to_thread(gc.collect)
8990

9091
# ------------------------------------------------------------------
9192
# Internal helpers

src/infer_check/backends/openai_compat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ async def _generate_chat(self, prompt: Prompt) -> InferenceResult:
101101
choice = data["choices"][0]
102102
message = choice.get("message", {})
103103
text: str = message.get("content", "")
104+
if not text:
105+
text = message.get("reasoning_content", "")
104106
tokens = text.split()
105107

106108
usage = data.get("usage", {})

src/infer_check/cli.py

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,34 @@
44

55
import asyncio
66
import json
7+
from collections.abc import Callable
78
from datetime import UTC, datetime
89
from pathlib import Path
9-
from typing import Any, Literal
10+
from typing import Any, Literal, TypeVar
1011

1112
import click
1213
from rich.console import Console
1314
from rich.table import Table
1415

1516
console = Console()
1617

18+
F = TypeVar("F", bound=Callable[..., Any])
19+
20+
21+
def common_options(f: F) -> F:
22+
"""Add common options to all subcommands."""
23+
options = [
24+
click.option(
25+
"--max-tokens",
26+
default=None,
27+
type=click.IntRange(min=1, clamp=True),
28+
help="Override default max tokens for generation.",
29+
),
30+
]
31+
for option in reversed(options):
32+
f = option(f)
33+
return f
34+
1735

1836
def _resolve_prompts(prompts: str) -> Path:
1937
"""Resolve a prompt suite name or path to an actual file path."""
@@ -27,8 +45,17 @@ def _resolve_prompts(prompts: str) -> Path:
2745

2846
@click.group()
2947
@click.version_option(package_name="infer-check")
30-
def main() -> None:
48+
@click.option(
49+
"--max-tokens",
50+
default=1024,
51+
show_default=True,
52+
help="Default max tokens for generation (applies to all prompts unless they specify their own).",
53+
)
54+
@click.pass_context
55+
def main(ctx: click.Context, max_tokens: int) -> None:
3156
"""infer-check: correctness and reliability testing for LLM inference engines."""
57+
ctx.ensure_object(dict)
58+
ctx.obj["max_tokens"] = max_tokens
3259

3360

3461
# ---------------------------------------------------------------------------
@@ -65,13 +92,17 @@ def main() -> None:
6592
help="Baseline label (defaults to first in --models).",
6693
)
6794
@click.option("--base-url", default=None, help="Base URL for HTTP backends.")
95+
@common_options
96+
@click.pass_context
6897
def sweep(
98+
ctx: click.Context,
6999
models: str,
70100
backend: str | None,
71101
prompts: str,
72102
output: Path,
73103
baseline: str | None,
74104
base_url: str | None,
105+
max_tokens: int | None,
75106
) -> None:
76107
"""Run a quantization sweep: compare pre-quantized models against a baseline.
77108
@@ -90,6 +121,10 @@ def sweep(
90121
from infer_check.runner import TestRunner
91122
from infer_check.suites.loader import load_suite
92123

124+
# Update max_tokens from subcommand if provided
125+
if max_tokens is not None:
126+
ctx.obj["max_tokens"] = max_tokens
127+
93128
# Parse label=model_path pairs
94129
model_map: dict[str, str] = {}
95130
for entry in models.split(","):
@@ -117,8 +152,11 @@ def sweep(
117152
for label, path in model_map.items():
118153
tag = " (baseline)" if label == baseline_label else ""
119154
console.print(f" {label}: {path}{tag}")
120-
121155
prompt_list = load_suite(_resolve_prompts(prompts))
156+
# Apply global max_tokens only if not explicitly set in the prompt JSONL
157+
for p in prompt_list:
158+
if "max_tokens" not in p.model_fields_set:
159+
p.max_tokens = ctx.obj["max_tokens"]
122160

123161
# Build a separate backend for each model
124162
backend_map: dict[str, Any] = {}
@@ -257,7 +295,10 @@ def sweep(
257295
show_default=True,
258296
help="Generate an HTML comparison report after the run.",
259297
)
298+
@common_options
299+
@click.pass_context
260300
def compare(
301+
ctx: click.Context,
261302
model_a: str,
262303
model_b: str,
263304
prompts: str,
@@ -266,6 +307,7 @@ def compare(
266307
label_a: str | None,
267308
label_b: str | None,
268309
report: bool,
310+
max_tokens: int | None,
269311
) -> None:
270312
"""Compare two quantizations of the same model.
271313
@@ -295,6 +337,10 @@ def compare(
295337
from infer_check.runner import TestRunner
296338
from infer_check.suites.loader import load_suite
297339

340+
# Update max_tokens from subcommand if provided
341+
if max_tokens is not None:
342+
ctx.obj["max_tokens"] = max_tokens
343+
298344
resolved_a = resolve_model(model_a, base_url=base_url, label=label_a)
299345
resolved_b = resolve_model(model_b, base_url=base_url, label=label_b)
300346

@@ -305,6 +351,11 @@ def compare(
305351
)
306352

307353
prompt_list = load_suite(_resolve_prompts(prompts))
354+
# Apply global max_tokens only if not explicitly set in the prompt JSONL
355+
for p in prompt_list:
356+
if "max_tokens" not in p.model_fields_set:
357+
p.max_tokens = ctx.obj["max_tokens"]
358+
308359
console.print(f" prompts: {len(prompt_list)} from '{prompts}'")
309360

310361
# ── Build backends ───────────────────────────────────────────────
@@ -510,20 +561,28 @@ def compare(
510561
show_default=True,
511562
help="Use /v1/chat/completions for HTTP backends (applies chat template server-side).",
512563
)
564+
@common_options
565+
@click.pass_context
513566
def diff(
567+
ctx: click.Context,
514568
model: str,
515569
backends: str,
516570
prompts: str,
517571
output: Path,
518572
quant: str | None,
519573
base_urls: str | None,
520574
chat: bool,
575+
max_tokens: int | None,
521576
) -> None:
522577
"""Compare outputs across different backends for the same model and prompts."""
523578
from infer_check.backends.base import BackendConfig, get_backend
524579
from infer_check.runner import TestRunner
525580
from infer_check.suites.loader import load_suite
526581

582+
# Update max_tokens from subcommand if provided
583+
if max_tokens is not None:
584+
ctx.obj["max_tokens"] = max_tokens
585+
527586
backend_names = [b.strip() for b in backends.split(",") if b.strip()]
528587
url_list: list[str | None] = [u.strip() for u in base_urls.split(",")] if base_urls else [None] * len(backend_names)
529588
# Pad url_list if shorter than backend_names
@@ -533,6 +592,10 @@ def diff(
533592
console.print(f"[bold cyan]diff[/bold cyan] model={model} backends={backend_names} quant={quant}")
534593

535594
prompt_list = load_suite(_resolve_prompts(prompts))
595+
# Apply global max_tokens only if not explicitly set in the prompt JSONL
596+
for p in prompt_list:
597+
if "max_tokens" not in p.model_fields_set:
598+
p.max_tokens = ctx.obj["max_tokens"]
536599

537600
backend_instances = []
538601
for name, url in zip(backend_names, url_list, strict=True):
@@ -619,19 +682,27 @@ def diff(
619682
help="Comma-separated concurrency levels.",
620683
)
621684
@click.option("--base-url", default=None, help="Base URL for HTTP backends.")
685+
@common_options
686+
@click.pass_context
622687
def stress(
688+
ctx: click.Context,
623689
model: str,
624690
backend: str | None,
625691
prompts: str,
626692
output: Path,
627693
concurrency: str,
628694
base_url: str | None,
695+
max_tokens: int | None,
629696
) -> None:
630697
"""Stress-test a backend with varying concurrency levels."""
631698
from infer_check.backends.base import get_backend_for_model
632699
from infer_check.runner import TestRunner
633700
from infer_check.suites.loader import load_suite
634701

702+
# Update max_tokens from subcommand if provided
703+
if max_tokens is not None:
704+
ctx.obj["max_tokens"] = max_tokens
705+
635706
concurrency_levels = [int(c.strip()) for c in concurrency.split(",") if c.strip()]
636707

637708
backend_instance = get_backend_for_model(
@@ -645,6 +716,10 @@ def stress(
645716
)
646717

647718
prompt_list = load_suite(_resolve_prompts(prompts))
719+
# Apply global max_tokens only if not explicitly set in the prompt JSONL
720+
for p in prompt_list:
721+
if "max_tokens" not in p.model_fields_set:
722+
p.max_tokens = ctx.obj["max_tokens"]
648723

649724
runner = TestRunner()
650725
stress_results = asyncio.run(
@@ -704,19 +779,27 @@ def stress(
704779
)
705780
@click.option("--runs", default=100, show_default=True, type=int, help="Number of runs per prompt.")
706781
@click.option("--base-url", default=None, help="Base URL for HTTP backends.")
782+
@common_options
783+
@click.pass_context
707784
def determinism(
785+
ctx: click.Context,
708786
model: str,
709787
backend: str | None,
710788
prompts: str,
711789
output: Path,
712790
runs: int,
713791
base_url: str | None,
792+
max_tokens: int | None,
714793
) -> None:
715794
"""Test whether a backend produces identical outputs across repeated runs at temperature=0."""
716795
from infer_check.backends.base import get_backend_for_model
717796
from infer_check.runner import TestRunner
718797
from infer_check.suites.loader import load_suite
719798

799+
# Update max_tokens from subcommand if provided
800+
if max_tokens is not None:
801+
ctx.obj["max_tokens"] = max_tokens
802+
720803
backend_instance = get_backend_for_model(
721804
model_str=model,
722805
backend_type=backend,
@@ -726,6 +809,10 @@ def determinism(
726809
console.print(f"[bold cyan]determinism[/bold cyan] model={model} backend={backend_instance.name} runs={runs}")
727810

728811
prompt_list = load_suite(_resolve_prompts(prompts))
812+
# Apply global max_tokens only if not explicitly set in the prompt JSONL
813+
for p in prompt_list:
814+
if "max_tokens" not in p.model_fields_set:
815+
p.max_tokens = ctx.obj["max_tokens"]
729816

730817
runner = TestRunner()
731818
det_results = asyncio.run(

src/infer_check/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class Prompt(BaseInferModel):
4848
id: str = Field(default_factory=_generate_uuid)
4949
text: str
5050
category: str = "general"
51-
max_tokens: int = 256
51+
max_tokens: int = 1024
5252
metadata: dict[str, Any] = Field(default_factory=dict)
5353

5454

0 commit comments

Comments
 (0)