Skip to content

Commit 50210ee

Browse files
feat: add global --max-tokens CLI flag and increase default max_tokens to 1024
1 parent 4d0ea3e commit 50210ee

4 files changed

Lines changed: 45 additions & 5 deletions

File tree

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
# 0.2.0 (2026-04-02)
2+
3+
- Added global `--max-tokens` flag (defaults to 1024) to the main CLI.
4+
- Increased default `max_tokens` for all prompts from 256 to 1024.
5+
16
# 0.1.0 (2026-03-11)
27

38
- Initial release

README.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ pip install "infer-check[mlx]"
6868

6969
### Quantization sweep
7070

71-
Compare pre-quantized models against a baseline. Each model is a separate HuggingFace repo.
71+
Compare pre-quantized models against a baseline. Each model is a separate HuggingFace repo. Use `--max-tokens` to control generation length (defaults to 1024).
7272

7373
```
7474
infer-check sweep \
@@ -77,6 +77,7 @@ infer-check sweep \
7777
4bit=mlx-community/Meta-Llama-3.1-8B-Instruct-4bit" \
7878
--backend mlx-lm \
7979
--prompts reasoning \
80+
--max-tokens 512 \
8081
--output ./results/sweep/
8182
```
8283

@@ -161,7 +162,7 @@ Curated prompts targeting known quantization failure modes:
161162
| `quant-sensitive.jsonl` | 20 | Multi-digit arithmetic, long CoT, precise syntax |
162163
| `determinism.jsonl` | 50 | High-entropy continuations for determinism testing |
163164

164-
All suites ship with the package — no need to clone the repo. Custom suites are JSONL files with one object per line:
165+
All suites ship with the package — no need to clone the repo. Custom suites are JSONL files with one object per line (default `max_tokens` is 1024):
165166

166167
```json
167168
{"id": "custom-001", "text": "Your prompt here", "category": "math", "max_tokens": 512}

src/infer_check/cli.py

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,17 @@ def _resolve_prompts(prompts: str) -> Path:
2727

2828
@click.group()
2929
@click.version_option(package_name="infer-check")
30-
def main() -> None:
30+
@click.option(
31+
"--max-tokens",
32+
default=1024,
33+
show_default=True,
34+
help="Default max tokens for generation (applies to all prompts unless they specify their own).",
35+
)
36+
@click.pass_context
37+
def main(ctx: click.Context, max_tokens: int) -> None:
3138
"""infer-check: correctness and reliability testing for LLM inference engines."""
39+
ctx.ensure_object(dict)
40+
ctx.obj["max_tokens"] = max_tokens
3241

3342

3443
# ---------------------------------------------------------------------------
@@ -65,7 +74,9 @@ def main() -> None:
6574
help="Baseline label (defaults to first in --models).",
6675
)
6776
@click.option("--base-url", default=None, help="Base URL for HTTP backends.")
77+
@click.pass_context
6878
def sweep(
79+
ctx: click.Context,
6980
models: str,
7081
backend: str | None,
7182
prompts: str,
@@ -117,8 +128,10 @@ def sweep(
117128
for label, path in model_map.items():
118129
tag = " (baseline)" if label == baseline_label else ""
119130
console.print(f" {label}: {path}{tag}")
120-
121131
prompt_list = load_suite(_resolve_prompts(prompts))
132+
# Apply global max_tokens
133+
for p in prompt_list:
134+
p.max_tokens = ctx.obj["max_tokens"]
122135

123136
# Build a separate backend for each model
124137
backend_map: dict[str, Any] = {}
@@ -257,7 +270,9 @@ def sweep(
257270
show_default=True,
258271
help="Generate an HTML comparison report after the run.",
259272
)
273+
@click.pass_context
260274
def compare(
275+
ctx: click.Context,
261276
model_a: str,
262277
model_b: str,
263278
prompts: str,
@@ -305,6 +320,10 @@ def compare(
305320
)
306321

307322
prompt_list = load_suite(_resolve_prompts(prompts))
323+
# Apply global max_tokens
324+
for p in prompt_list:
325+
p.max_tokens = ctx.obj["max_tokens"]
326+
308327
console.print(f" prompts: {len(prompt_list)} from '{prompts}'")
309328

310329
# ── Build backends ───────────────────────────────────────────────
@@ -510,7 +529,9 @@ def compare(
510529
show_default=True,
511530
help="Use /v1/chat/completions for HTTP backends (applies chat template server-side).",
512531
)
532+
@click.pass_context
513533
def diff(
534+
ctx: click.Context,
514535
model: str,
515536
backends: str,
516537
prompts: str,
@@ -533,6 +554,9 @@ def diff(
533554
console.print(f"[bold cyan]diff[/bold cyan] model={model} backends={backend_names} quant={quant}")
534555

535556
prompt_list = load_suite(_resolve_prompts(prompts))
557+
# Apply global max_tokens
558+
for p in prompt_list:
559+
p.max_tokens = ctx.obj["max_tokens"]
536560

537561
backend_instances = []
538562
for name, url in zip(backend_names, url_list, strict=True):
@@ -619,7 +643,9 @@ def diff(
619643
help="Comma-separated concurrency levels.",
620644
)
621645
@click.option("--base-url", default=None, help="Base URL for HTTP backends.")
646+
@click.pass_context
622647
def stress(
648+
ctx: click.Context,
623649
model: str,
624650
backend: str | None,
625651
prompts: str,
@@ -645,6 +671,9 @@ def stress(
645671
)
646672

647673
prompt_list = load_suite(_resolve_prompts(prompts))
674+
# Apply global max_tokens
675+
for p in prompt_list:
676+
p.max_tokens = ctx.obj["max_tokens"]
648677

649678
runner = TestRunner()
650679
stress_results = asyncio.run(
@@ -704,7 +733,9 @@ def stress(
704733
)
705734
@click.option("--runs", default=100, show_default=True, type=int, help="Number of runs per prompt.")
706735
@click.option("--base-url", default=None, help="Base URL for HTTP backends.")
736+
@click.pass_context
707737
def determinism(
738+
ctx: click.Context,
708739
model: str,
709740
backend: str | None,
710741
prompts: str,
@@ -726,6 +757,9 @@ def determinism(
726757
console.print(f"[bold cyan]determinism[/bold cyan] model={model} backend={backend_instance.name} runs={runs}")
727758

728759
prompt_list = load_suite(_resolve_prompts(prompts))
760+
# Apply global max_tokens
761+
for p in prompt_list:
762+
p.max_tokens = ctx.obj["max_tokens"]
729763

730764
runner = TestRunner()
731765
det_results = asyncio.run(

src/infer_check/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class Prompt(BaseInferModel):
4848
id: str = Field(default_factory=_generate_uuid)
4949
text: str
5050
category: str = "general"
51-
max_tokens: int = 256
51+
max_tokens: int = 1024
5252
metadata: dict[str, Any] = Field(default_factory=dict)
5353

5454

0 commit comments

Comments
 (0)