@@ -27,8 +27,17 @@ def _resolve_prompts(prompts: str) -> Path:
2727
2828@click .group ()
2929@click .version_option (package_name = "infer-check" )
30- def main () -> None :
30+ @click .option (
31+ "--max-tokens" ,
32+ default = 1024 ,
33+ show_default = True ,
34+ help = "Default max tokens for generation (applies to all prompts unless they specify their own)." ,
35+ )
36+ @click .pass_context
37+ def main (ctx : click .Context , max_tokens : int ) -> None :
3138 """infer-check: correctness and reliability testing for LLM inference engines."""
39+ ctx .ensure_object (dict )
40+ ctx .obj ["max_tokens" ] = max_tokens
3241
3342
3443# ---------------------------------------------------------------------------
@@ -65,7 +74,9 @@ def main() -> None:
6574 help = "Baseline label (defaults to first in --models)." ,
6675)
6776@click .option ("--base-url" , default = None , help = "Base URL for HTTP backends." )
77+ @click .pass_context
6878def sweep (
79+ ctx : click .Context ,
6980 models : str ,
7081 backend : str | None ,
7182 prompts : str ,
@@ -117,8 +128,10 @@ def sweep(
117128 for label , path in model_map .items ():
118129 tag = " (baseline)" if label == baseline_label else ""
119130 console .print (f" { label } : { path } { tag } " )
120-
121131 prompt_list = load_suite (_resolve_prompts (prompts ))
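132+ # Apply the global --max-tokens value to every prompt (unconditionally overwrites any per-prompt max_tokens; NOTE(review): the --max-tokens help text says per-prompt values win - confirm intent)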
133+ for p in prompt_list :
134+ p .max_tokens = ctx .obj ["max_tokens" ]
122135
123136 # Build a separate backend for each model
124137 backend_map : dict [str , Any ] = {}
@@ -257,7 +270,9 @@ def sweep(
257270 show_default = True ,
258271 help = "Generate an HTML comparison report after the run." ,
259272)
273+ @click .pass_context
260274def compare (
275+ ctx : click .Context ,
261276 model_a : str ,
262277 model_b : str ,
263278 prompts : str ,
@@ -305,6 +320,10 @@ def compare(
305320 )
306321
307322 prompt_list = load_suite (_resolve_prompts (prompts ))
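323+ # Apply the global --max-tokens value to every prompt (unconditionally overwrites any per-prompt max_tokens; NOTE(review): the --max-tokens help text says per-prompt values win - confirm intent)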
324+ for p in prompt_list :
325+ p .max_tokens = ctx .obj ["max_tokens" ]
326+
308327 console .print (f" prompts: { len (prompt_list )} from '{ prompts } '" )
309328
310329 # ── Build backends ───────────────────────────────────────────────
@@ -510,7 +529,9 @@ def compare(
510529 show_default = True ,
511530 help = "Use /v1/chat/completions for HTTP backends (applies chat template server-side)." ,
512531)
532+ @click .pass_context
513533def diff (
534+ ctx : click .Context ,
514535 model : str ,
515536 backends : str ,
516537 prompts : str ,
@@ -533,6 +554,9 @@ def diff(
533554 console .print (f"[bold cyan]diff[/bold cyan] model={ model } backends={ backend_names } quant={ quant } " )
534555
535556 prompt_list = load_suite (_resolve_prompts (prompts ))
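557+ # Apply the global --max-tokens value to every prompt (unconditionally overwrites any per-prompt max_tokens; NOTE(review): the --max-tokens help text says per-prompt values win - confirm intent)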
558+ for p in prompt_list :
559+ p .max_tokens = ctx .obj ["max_tokens" ]
536560
537561 backend_instances = []
538562 for name , url in zip (backend_names , url_list , strict = True ):
@@ -619,7 +643,9 @@ def diff(
619643 help = "Comma-separated concurrency levels." ,
620644)
621645@click .option ("--base-url" , default = None , help = "Base URL for HTTP backends." )
646+ @click .pass_context
622647def stress (
648+ ctx : click .Context ,
623649 model : str ,
624650 backend : str | None ,
625651 prompts : str ,
@@ -645,6 +671,9 @@ def stress(
645671 )
646672
647673 prompt_list = load_suite (_resolve_prompts (prompts ))
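674+ # Apply the global --max-tokens value to every prompt (unconditionally overwrites any per-prompt max_tokens; NOTE(review): the --max-tokens help text says per-prompt values win - confirm intent)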
675+ for p in prompt_list :
676+ p .max_tokens = ctx .obj ["max_tokens" ]
648677
649678 runner = TestRunner ()
650679 stress_results = asyncio .run (
@@ -704,7 +733,9 @@ def stress(
704733)
705734@click .option ("--runs" , default = 100 , show_default = True , type = int , help = "Number of runs per prompt." )
706735@click .option ("--base-url" , default = None , help = "Base URL for HTTP backends." )
736+ @click .pass_context
707737def determinism (
738+ ctx : click .Context ,
708739 model : str ,
709740 backend : str | None ,
710741 prompts : str ,
@@ -726,6 +757,9 @@ def determinism(
726757 console .print (f"[bold cyan]determinism[/bold cyan] model={ model } backend={ backend_instance .name } runs={ runs } " )
727758
728759 prompt_list = load_suite (_resolve_prompts (prompts ))
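760+ # Apply the global --max-tokens value to every prompt (unconditionally overwrites any per-prompt max_tokens; NOTE(review): the --max-tokens help text says per-prompt values win - confirm intent)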
761+ for p in prompt_list :
762+ p .max_tokens = ctx .obj ["max_tokens" ]
729763
730764 runner = TestRunner ()
731765 det_results = asyncio .run (
0 commit comments