Skip to content

Commit 77b5de2

Browse files
Apply suggestions from code review
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 6366429 commit 77b5de2

1 file changed

Lines changed: 18 additions & 10 deletions

File tree

src/infer_check/cli.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -129,9 +129,13 @@ def sweep(
129129
tag = " (baseline)" if label == baseline_label else ""
130130
console.print(f" {label}: {path}{tag}")
131131
prompt_list = load_suite(_resolve_prompts(prompts))
132-
# Apply global max_tokens
132+
# Apply global max_tokens only when a prompt did not explicitly set its own.
133133
for p in prompt_list:
134-
p.max_tokens = ctx.obj["max_tokens"]
134+
fields_set = getattr(p, "model_fields_set", None)
135+
if fields_set is None:
136+
fields_set = getattr(p, "__pydantic_fields_set__", set())
137+
if "max_tokens" not in fields_set:
138+
p.max_tokens = ctx.obj["max_tokens"]
135139

136140
# Build a separate backend for each model
137141
backend_map: dict[str, Any] = {}
@@ -320,9 +324,10 @@ def compare(
320324
)
321325

322326
prompt_list = load_suite(_resolve_prompts(prompts))
323-
# Apply global max_tokens
327+
# Apply global max_tokens only when the suite did not set one.
324328
for p in prompt_list:
325-
p.max_tokens = ctx.obj["max_tokens"]
329+
if getattr(p, "max_tokens", None) is None:
330+
p.max_tokens = ctx.obj["max_tokens"]
326331

327332
console.print(f" prompts: {len(prompt_list)} from '{prompts}'")
328333

@@ -554,9 +559,10 @@ def diff(
554559
console.print(f"[bold cyan]diff[/bold cyan] model={model} backends={backend_names} quant={quant}")
555560

556561
prompt_list = load_suite(_resolve_prompts(prompts))
557-
# Apply global max_tokens
562+
# Apply global max_tokens only when the prompt does not already specify one.
558563
for p in prompt_list:
559-
p.max_tokens = ctx.obj["max_tokens"]
564+
if getattr(p, "max_tokens", None) is None:
565+
p.max_tokens = ctx.obj["max_tokens"]
560566

561567
backend_instances = []
562568
for name, url in zip(backend_names, url_list, strict=True):
@@ -671,9 +677,10 @@ def stress(
671677
)
672678

673679
prompt_list = load_suite(_resolve_prompts(prompts))
674-
# Apply global max_tokens
680+
# Apply global max_tokens only to prompts that do not already specify one.
675681
for p in prompt_list:
676-
p.max_tokens = ctx.obj["max_tokens"]
682+
if p.max_tokens is None:
683+
p.max_tokens = ctx.obj["max_tokens"]
677684

678685
runner = TestRunner()
679686
stress_results = asyncio.run(
@@ -757,9 +764,10 @@ def determinism(
757764
console.print(f"[bold cyan]determinism[/bold cyan] model={model} backend={backend_instance.name} runs={runs}")
758765

759766
prompt_list = load_suite(_resolve_prompts(prompts))
760-
# Apply global max_tokens
767+
# Apply global max_tokens only when a prompt does not define its own value.
761768
for p in prompt_list:
762-
p.max_tokens = ctx.obj["max_tokens"]
769+
if getattr(p, "max_tokens", None) is None:
770+
p.max_tokens = ctx.obj["max_tokens"]
763771

764772
runner = TestRunner()
765773
det_results = asyncio.run(

0 commit comments

Comments
 (0)