44
55import asyncio
66import json
7+ from collections .abc import Callable
78from datetime import UTC , datetime
89from pathlib import Path
9- from typing import Any , Literal
10+ from typing import Any , Literal , TypeVar
1011
1112import click
1213from rich .console import Console
1314from rich .table import Table
1415
1516console = Console ()
1617
18+ F = TypeVar ("F" , bound = Callable [..., Any ])
19+
20+
21+ def common_options (f : F ) -> F :
22+ """Add common options to all subcommands."""
23+ options = [
24+ click .option (
25+ "--max-tokens" ,
26+ default = None ,
27+ type = click .IntRange (min = 1 , clamp = True ),
28+ help = "Override default max tokens for generation." ,
29+ ),
30+ ]
31+ for option in reversed (options ):
32+ f = option (f )
33+ return f
34+
1735
1836def _resolve_prompts (prompts : str ) -> Path :
1937 """Resolve a prompt suite name or path to an actual file path."""
@@ -27,8 +45,17 @@ def _resolve_prompts(prompts: str) -> Path:
2745
2846@click .group ()
2947@click .version_option (package_name = "infer-check" )
30- def main () -> None :
48+ @click .option (
49+ "--max-tokens" ,
50+ default = 1024 ,
51+ show_default = True ,
52+ help = "Default max tokens for generation (applies to all prompts unless they specify their own)." ,
53+ )
54+ @click .pass_context
55+ def main (ctx : click .Context , max_tokens : int ) -> None :
3156 """infer-check: correctness and reliability testing for LLM inference engines."""
57+ ctx .ensure_object (dict )
58+ ctx .obj ["max_tokens" ] = max_tokens
3259
3360
3461# ---------------------------------------------------------------------------
@@ -65,13 +92,17 @@ def main() -> None:
6592 help = "Baseline label (defaults to first in --models)." ,
6693)
6794@click .option ("--base-url" , default = None , help = "Base URL for HTTP backends." )
95+ @common_options
96+ @click .pass_context
6897def sweep (
98+ ctx : click .Context ,
6999 models : str ,
70100 backend : str | None ,
71101 prompts : str ,
72102 output : Path ,
73103 baseline : str | None ,
74104 base_url : str | None ,
105+ max_tokens : int | None ,
75106) -> None :
76107 """Run a quantization sweep: compare pre-quantized models against a baseline.
77108
@@ -90,6 +121,10 @@ def sweep(
90121 from infer_check .runner import TestRunner
91122 from infer_check .suites .loader import load_suite
92123
124+ # Update max_tokens from subcommand if provided
125+ if max_tokens is not None :
126+ ctx .obj ["max_tokens" ] = max_tokens
127+
93128 # Parse label=model_path pairs
94129 model_map : dict [str , str ] = {}
95130 for entry in models .split ("," ):
@@ -117,8 +152,11 @@ def sweep(
117152 for label , path in model_map .items ():
118153 tag = " (baseline)" if label == baseline_label else ""
119154 console .print (f" { label } : { path } { tag } " )
120-
121155 prompt_list = load_suite (_resolve_prompts (prompts ))
156+ # Apply global max_tokens only if not explicitly set in the prompt JSONL
157+ for p in prompt_list :
158+ if "max_tokens" not in p .model_fields_set :
159+ p .max_tokens = ctx .obj ["max_tokens" ]
122160
123161 # Build a separate backend for each model
124162 backend_map : dict [str , Any ] = {}
@@ -257,7 +295,10 @@ def sweep(
257295 show_default = True ,
258296 help = "Generate an HTML comparison report after the run." ,
259297)
298+ @common_options
299+ @click .pass_context
260300def compare (
301+ ctx : click .Context ,
261302 model_a : str ,
262303 model_b : str ,
263304 prompts : str ,
@@ -266,6 +307,7 @@ def compare(
266307 label_a : str | None ,
267308 label_b : str | None ,
268309 report : bool ,
310+ max_tokens : int | None ,
269311) -> None :
270312 """Compare two quantizations of the same model.
271313
@@ -295,6 +337,10 @@ def compare(
295337 from infer_check .runner import TestRunner
296338 from infer_check .suites .loader import load_suite
297339
340+ # Update max_tokens from subcommand if provided
341+ if max_tokens is not None :
342+ ctx .obj ["max_tokens" ] = max_tokens
343+
298344 resolved_a = resolve_model (model_a , base_url = base_url , label = label_a )
299345 resolved_b = resolve_model (model_b , base_url = base_url , label = label_b )
300346
@@ -305,6 +351,11 @@ def compare(
305351 )
306352
307353 prompt_list = load_suite (_resolve_prompts (prompts ))
354+ # Apply global max_tokens only if not explicitly set in the prompt JSONL
355+ for p in prompt_list :
356+ if "max_tokens" not in p .model_fields_set :
357+ p .max_tokens = ctx .obj ["max_tokens" ]
358+
308359 console .print (f" prompts: { len (prompt_list )} from '{ prompts } '" )
309360
310361 # ── Build backends ───────────────────────────────────────────────
@@ -510,20 +561,28 @@ def compare(
510561 show_default = True ,
511562 help = "Use /v1/chat/completions for HTTP backends (applies chat template server-side)." ,
512563)
564+ @common_options
565+ @click .pass_context
513566def diff (
567+ ctx : click .Context ,
514568 model : str ,
515569 backends : str ,
516570 prompts : str ,
517571 output : Path ,
518572 quant : str | None ,
519573 base_urls : str | None ,
520574 chat : bool ,
575+ max_tokens : int | None ,
521576) -> None :
522577 """Compare outputs across different backends for the same model and prompts."""
523578 from infer_check .backends .base import BackendConfig , get_backend
524579 from infer_check .runner import TestRunner
525580 from infer_check .suites .loader import load_suite
526581
582+ # Update max_tokens from subcommand if provided
583+ if max_tokens is not None :
584+ ctx .obj ["max_tokens" ] = max_tokens
585+
527586 backend_names = [b .strip () for b in backends .split ("," ) if b .strip ()]
528587 url_list : list [str | None ] = [u .strip () for u in base_urls .split ("," )] if base_urls else [None ] * len (backend_names )
529588 # Pad url_list if shorter than backend_names
@@ -533,6 +592,10 @@ def diff(
533592 console .print (f"[bold cyan]diff[/bold cyan] model={ model } backends={ backend_names } quant={ quant } " )
534593
535594 prompt_list = load_suite (_resolve_prompts (prompts ))
595+ # Apply global max_tokens only if not explicitly set in the prompt JSONL
596+ for p in prompt_list :
597+ if "max_tokens" not in p .model_fields_set :
598+ p .max_tokens = ctx .obj ["max_tokens" ]
536599
537600 backend_instances = []
538601 for name , url in zip (backend_names , url_list , strict = True ):
@@ -619,19 +682,27 @@ def diff(
619682 help = "Comma-separated concurrency levels." ,
620683)
621684@click .option ("--base-url" , default = None , help = "Base URL for HTTP backends." )
685+ @common_options
686+ @click .pass_context
622687def stress (
688+ ctx : click .Context ,
623689 model : str ,
624690 backend : str | None ,
625691 prompts : str ,
626692 output : Path ,
627693 concurrency : str ,
628694 base_url : str | None ,
695+ max_tokens : int | None ,
629696) -> None :
630697 """Stress-test a backend with varying concurrency levels."""
631698 from infer_check .backends .base import get_backend_for_model
632699 from infer_check .runner import TestRunner
633700 from infer_check .suites .loader import load_suite
634701
702+ # Update max_tokens from subcommand if provided
703+ if max_tokens is not None :
704+ ctx .obj ["max_tokens" ] = max_tokens
705+
635706 concurrency_levels = [int (c .strip ()) for c in concurrency .split ("," ) if c .strip ()]
636707
637708 backend_instance = get_backend_for_model (
@@ -645,6 +716,10 @@ def stress(
645716 )
646717
647718 prompt_list = load_suite (_resolve_prompts (prompts ))
719+ # Apply global max_tokens only if not explicitly set in the prompt JSONL
720+ for p in prompt_list :
721+ if "max_tokens" not in p .model_fields_set :
722+ p .max_tokens = ctx .obj ["max_tokens" ]
648723
649724 runner = TestRunner ()
650725 stress_results = asyncio .run (
@@ -704,19 +779,27 @@ def stress(
704779)
705780@click .option ("--runs" , default = 100 , show_default = True , type = int , help = "Number of runs per prompt." )
706781@click .option ("--base-url" , default = None , help = "Base URL for HTTP backends." )
782+ @common_options
783+ @click .pass_context
707784def determinism (
785+ ctx : click .Context ,
708786 model : str ,
709787 backend : str | None ,
710788 prompts : str ,
711789 output : Path ,
712790 runs : int ,
713791 base_url : str | None ,
792+ max_tokens : int | None ,
714793) -> None :
715794 """Test whether a backend produces identical outputs across repeated runs at temperature=0."""
716795 from infer_check .backends .base import get_backend_for_model
717796 from infer_check .runner import TestRunner
718797 from infer_check .suites .loader import load_suite
719798
799+ # Update max_tokens from subcommand if provided
800+ if max_tokens is not None :
801+ ctx .obj ["max_tokens" ] = max_tokens
802+
720803 backend_instance = get_backend_for_model (
721804 model_str = model ,
722805 backend_type = backend ,
@@ -726,6 +809,10 @@ def determinism(
726809 console .print (f"[bold cyan]determinism[/bold cyan] model={ model } backend={ backend_instance .name } runs={ runs } " )
727810
728811 prompt_list = load_suite (_resolve_prompts (prompts ))
812+ # Apply global max_tokens only if not explicitly set in the prompt JSONL
813+ for p in prompt_list :
814+ if "max_tokens" not in p .model_fields_set :
815+ p .max_tokens = ctx .obj ["max_tokens" ]
729816
730817 runner = TestRunner ()
731818 det_results = asyncio .run (
0 commit comments