@@ -129,9 +129,13 @@ def sweep(
129129 tag = " (baseline)" if label == baseline_label else ""
130130 console .print (f" { label } : { path } { tag } " )
131131 prompt_list = load_suite (_resolve_prompts (prompts ))
132- # Apply global max_tokens
132+ # Apply global max_tokens only when a prompt did not explicitly set its own.
133133 for p in prompt_list :
134- p .max_tokens = ctx .obj ["max_tokens" ]
134+ fields_set = getattr (p , "model_fields_set" , None )
135+ if fields_set is None :
136+ fields_set = getattr (p , "__pydantic_fields_set__" , set ())
137+ if "max_tokens" not in fields_set :
138+ p .max_tokens = ctx .obj ["max_tokens" ]
135139
136140 # Build a separate backend for each model
137141 backend_map : dict [str , Any ] = {}
@@ -320,9 +324,10 @@ def compare(
320324 )
321325
322326 prompt_list = load_suite (_resolve_prompts (prompts ))
323- # Apply global max_tokens
327+ # Apply global max_tokens only when the suite did not set one.
324328 for p in prompt_list :
325- p .max_tokens = ctx .obj ["max_tokens" ]
329+ if getattr (p , "max_tokens" , None ) is None :
330+ p .max_tokens = ctx .obj ["max_tokens" ]
326331
327332 console .print (f" prompts: { len (prompt_list )} from '{ prompts } '" )
328333
@@ -554,9 +559,10 @@ def diff(
554559 console .print (f"[bold cyan]diff[/bold cyan] model={ model } backends={ backend_names } quant={ quant } " )
555560
556561 prompt_list = load_suite (_resolve_prompts (prompts ))
557- # Apply global max_tokens
562+ # Apply global max_tokens only when the prompt does not already specify one.
558563 for p in prompt_list :
559- p .max_tokens = ctx .obj ["max_tokens" ]
564+ if getattr (p , "max_tokens" , None ) is None :
565+ p .max_tokens = ctx .obj ["max_tokens" ]
560566
561567 backend_instances = []
562568 for name , url in zip (backend_names , url_list , strict = True ):
@@ -671,9 +677,10 @@ def stress(
671677 )
672678
673679 prompt_list = load_suite (_resolve_prompts (prompts ))
674- # Apply global max_tokens
680+ # Apply global max_tokens only to prompts that do not already specify one.
675681 for p in prompt_list :
676- p .max_tokens = ctx .obj ["max_tokens" ]
682+ if p .max_tokens is None :
683+ p .max_tokens = ctx .obj ["max_tokens" ]
677684
678685 runner = TestRunner ()
679686 stress_results = asyncio .run (
@@ -757,9 +764,10 @@ def determinism(
757764 console .print (f"[bold cyan]determinism[/bold cyan] model={ model } backend={ backend_instance .name } runs={ runs } " )
758765
759766 prompt_list = load_suite (_resolve_prompts (prompts ))
760- # Apply global max_tokens
767+ # Apply global max_tokens only when a prompt does not define its own value.
761768 for p in prompt_list :
762- p .max_tokens = ctx .obj ["max_tokens" ]
769+ if getattr (p , "max_tokens" , None ) is None :
770+ p .max_tokens = ctx .obj ["max_tokens" ]
763771
764772 runner = TestRunner ()
765773 det_results = asyncio .run (
0 commit comments