Skip to content

Commit af17782

Browse files
refactor: centralize flip detection and add comprehensive tests for diff command
- Extract answer extraction and flip detection logic into a new `_annotate_flip_metadata` helper method in `TestRunner` to eliminate code duplication between `compare` and `diff`. - Add unit tests for the `diff` CLI command (`test_cli_diff.py`) to verify summary table rendering, backend metrics, and flip rate formatting. - Introduce `StubBackend` in `test_runner.py` and add async tests to ensure the `diff` runner method accurately detects answer flips across single and multiple test backends.
1 parent 6cad8fa commit af17782

3 files changed

Lines changed: 277 additions & 40 deletions

File tree

src/infer_check/runner.py

Lines changed: 28 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,32 @@ def __init__(self, cache_dir: str | Path = ".infer_check_cache"):
2727
self.cache_dir = Path(cache_dir)
2828
self.cache_dir.mkdir(parents=True, exist_ok=True)
2929

30+
def _annotate_flip_metadata(
    self,
    comp: ComparisonResult,
    text_a: str,
    text_b: str,
    category: str,
) -> None:
    """Extract functional answers from both outputs and record flip metadata.

    Mutates ``comp.metadata`` in place with the extracted answers, whether
    they disagree (``flipped``), the extraction strategy used, and the lower
    of the two extraction confidences.
    """
    # Imported lazily so the analysis package is only loaded when needed.
    from infer_check.analysis.answer_extract import (
        answers_match,
        extract_answer,
    )

    extracted_a = extract_answer(text_a, category)
    extracted_b = extract_answer(text_b, category)

    comp.metadata.update(
        {
            "flipped": not answers_match(extracted_a, extracted_b),
            "answer_a": extracted_a.value,
            "answer_b": extracted_b.value,
            "extraction_strategy": extracted_a.strategy,
            "extraction_confidence": min(
                extracted_a.confidence,
                extracted_b.confidence,
            ),
        }
    )
3056
def _save_checkpoint(self, results: Any, path: Path) -> None:
3157
"""Write intermediate results as JSON for resumability."""
3258
path.parent.mkdir(parents=True, exist_ok=True)
@@ -401,30 +427,12 @@ async def compare(
401427
progress.advance(task)
402428

403429
# ── Build comparisons with answer extraction ────────────────
404-
from infer_check.analysis.answer_extract import (
405-
answers_match,
406-
extract_answer,
407-
)
408-
409430
for prompt in prompts:
410431
a = results_a.get(prompt.id)
411432
b = results_b.get(prompt.id)
412433
if a and b:
413434
comp = self._compare(a, b)
414-
415-
# Extract functional answers and check for flips.
416-
ans_a = extract_answer(a.text, prompt.category)
417-
ans_b = extract_answer(b.text, prompt.category)
418-
flipped = not answers_match(ans_a, ans_b)
419-
420-
comp.metadata["flipped"] = flipped
421-
comp.metadata["answer_a"] = ans_a.value
422-
comp.metadata["answer_b"] = ans_b.value
423-
comp.metadata["extraction_strategy"] = ans_a.strategy
424-
comp.metadata["extraction_confidence"] = min(
425-
ans_a.confidence,
426-
ans_b.confidence,
427-
)
435+
self._annotate_flip_metadata(comp, a.text, b.text, prompt.category)
428436
comparisons.append(comp)
429437

430438
# ── Aggregate metrics ────────────────────────────────────────
@@ -491,11 +499,6 @@ async def diff(
491499
prompts: list[Prompt],
492500
) -> list[ComparisonResult]:
493501
"""Compare outputs across different backends against a baseline."""
494-
from infer_check.analysis.answer_extract import (
495-
answers_match,
496-
extract_answer,
497-
)
498-
499502
baseline_results: dict[str, InferenceResult] = {}
500503
comparisons: list[ComparisonResult] = []
501504

@@ -538,21 +541,7 @@ async def diff(
538541
baseline = baseline_results.get(test_res.prompt_id)
539542
if baseline:
540543
comp = self._compare(baseline, test_res)
541-
542-
# Answer extraction and flip detection
543-
ans_a = extract_answer(baseline.text, prompt.category)
544-
ans_b = extract_answer(test_res.text, prompt.category)
545-
flipped = not answers_match(ans_a, ans_b)
546-
547-
comp.metadata["flipped"] = flipped
548-
comp.metadata["answer_a"] = ans_a.value
549-
comp.metadata["answer_b"] = ans_b.value
550-
comp.metadata["extraction_strategy"] = ans_a.strategy
551-
comp.metadata["extraction_confidence"] = min(
552-
ans_a.confidence,
553-
ans_b.confidence,
554-
)
555-
544+
self._annotate_flip_metadata(comp, baseline.text, test_res.text, prompt.category)
556545
comparisons.append(comp)
557546
progress.advance(task)
558547

tests/unit/test_cli_diff.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
from pathlib import Path
2+
from typing import Any
3+
from unittest.mock import patch
4+
5+
import pytest
6+
from click.testing import CliRunner
7+
8+
from infer_check.cli import main
9+
from infer_check.types import ComparisonResult, InferenceResult
10+
11+
12+
@pytest.fixture
def runner() -> CliRunner:
    """Provide a fresh Click test runner for each test case."""
    return CliRunner()
15+
16+
17+
def test_cli_diff_summary_table(runner: CliRunner, tmp_path: Path) -> None:
    """The diff summary table renders headers and per-backend metrics.

    Mocks ``TestRunner.diff`` to return one passing and one flipped
    comparison, then checks the rendered failure rate, flip rate, and mean
    similarity.
    """
    # Dummy prompt suite on disk; the runner itself is mocked below, so the
    # contents only need to parse.
    dummy_suite = tmp_path / "dummy.jsonl"
    dummy_suite.write_text(
        '{"id":"p1", "text":"hi", "category":"general"}\n{"id":"p2", "text":"bye", "category":"general"}'
    )

    # Inference results: p1 agrees across backends, p2 differs.
    inf_res_baseline_1 = InferenceResult(
        prompt_id="p1", backend_name="baseline", model_id="m", tokens=["hi"], text="hi", latency_ms=1.0
    )
    inf_res_test_1 = InferenceResult(
        prompt_id="p1", backend_name="test", model_id="m", tokens=["hi"], text="hi", latency_ms=1.0
    )

    inf_res_baseline_2 = InferenceResult(
        prompt_id="p2", backend_name="baseline", model_id="m", tokens=["bye"], text="bye", latency_ms=1.0
    )
    inf_res_test_2 = InferenceResult(
        prompt_id="p2", backend_name="test", model_id="m", tokens=["hello"], text="hello", latency_ms=1.0
    )

    # comp1: identical outputs, not flipped.
    comp1 = ComparisonResult(
        baseline=inf_res_baseline_1,
        test=inf_res_test_1,
        text_similarity=1.0,
        is_failure=False,
        metadata={"flipped": False},
    )
    # comp2: divergent outputs, flipped and a failure.
    comp2 = ComparisonResult(
        baseline=inf_res_baseline_2,
        test=inf_res_test_2,
        text_similarity=0.4,
        is_failure=True,
        metadata={"flipped": True},
    )

    with (
        patch("infer_check.backends.base.get_backend"),
        patch("infer_check.runner.TestRunner.diff") as mock_diff,
    ):
        # The CLI awaits diff(), so the side effect must be a coroutine
        # function returning the canned comparisons.
        async def mock_diff_async(*args: Any, **kwargs: Any) -> list[ComparisonResult]:
            return [comp1, comp2]

        mock_diff.side_effect = mock_diff_async

        result = runner.invoke(
            main,
            [
                "diff",
                "--model",
                "m1",
                "--backends",
                "mlx-lm,llama-cpp",
                "--prompts",
                str(dummy_suite),
                "--output",
                str(tmp_path),
            ],
        )

    assert result.exit_code == 0

    # Table headers are present.
    assert "test_backend" in result.output
    assert "failures" in result.output
    assert "failure_rate" in result.output
    assert "flip_rate" in result.output
    assert "mean_similarity" in result.output

    # The CLI groups comparisons by ``comp.test.backend_name``, so the mocked
    # results appear under "test"; the requested backend names are still
    # echoed in the output.
    assert "llama-cpp" in result.output
    assert "test" in result.output

    # 2 prompts, 1 failure -> 50.00%.
    assert "50.00%" in result.output

    # 1 flip out of 2 -> rendered as "50.0%" (the Rich colour markup is not
    # part of CliRunner's captured output).
    assert "50.0%" in result.output

    # Mean similarity: (1.0 + 0.4) / 2 = 0.7, shown to four decimals.
    assert "0.7000" in result.output
113+
114+
def test_cli_diff_summary_no_flips(runner: CliRunner, tmp_path: Path) -> None:
    """A diff run with no flipped answers reports a 0.0% flip rate."""
    dummy_suite = tmp_path / "dummy_no_flips.jsonl"
    dummy_suite.write_text('{"id":"p1", "text":"hi", "category":"general"}')

    inf_res_baseline = InferenceResult(
        prompt_id="p1", backend_name="baseline", model_id="m", tokens=["hi"], text="hi", latency_ms=1.0
    )
    inf_res_test = InferenceResult(
        prompt_id="p1", backend_name="test", model_id="m", tokens=["hi"], text="hi", latency_ms=1.0
    )

    # Single comparison: identical outputs, nothing flipped.
    comp = ComparisonResult(
        baseline=inf_res_baseline, test=inf_res_test, text_similarity=1.0, is_failure=False, metadata={"flipped": False}
    )

    with (
        patch("infer_check.backends.base.get_backend"),
        patch("infer_check.runner.TestRunner.diff") as mock_diff,
    ):

        async def mock_diff_async(*args: Any, **kwargs: Any) -> list[ComparisonResult]:
            return [comp]

        mock_diff.side_effect = mock_diff_async

        cli_args = [
            "diff",
            "--model",
            "m1",
            "--backends",
            "mlx-lm,llama-cpp",
            "--prompts",
            str(dummy_suite),
            "--output",
            str(tmp_path),
        ]
        result = runner.invoke(main, cli_args)

    assert result.exit_code == 0
    assert "0.0%" in result.output

tests/unit/test_runner.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import asyncio
2+
13
import pytest
24

35
from infer_check.runner import TestRunner
4-
from infer_check.types import InferenceResult
6+
from infer_check.types import InferenceResult, Prompt
57

68

79
@pytest.fixture
@@ -115,3 +117,94 @@ def test_compare_threshold_edge_cases(runner: TestRunner) -> None:
115117
# Since similarity is ~0.8, it should fail.
116118
comp_strict = runner._compare(baseline, test_res, threshold=0.9)
117119
assert comp_strict.is_failure is True
120+
121+
122+
class StubBackend:
    """In-memory backend that serves canned responses keyed by prompt id."""

    def __init__(self, name: str, responses: dict[str, str]):
        self._name = name
        self._responses = responses
        # Set by cleanup() so tests can verify teardown happened.
        self.cleanup_called = False

    @property
    def name(self) -> str:
        return self._name

    async def generate(self, prompt: Prompt) -> InferenceResult:
        """Return a deterministic result built from the canned response."""
        reply = self._responses.get(prompt.id, "Default response")
        return InferenceResult(
            prompt_id=prompt.id,
            backend_name=self._name,
            model_id="stub-model",
            tokens=reply.split(),
            text=reply,
            latency_ms=1.0,
            metadata={},
        )

    async def generate_batch(self, prompts: list[Prompt]) -> list[InferenceResult]:
        """Generate results for all prompts concurrently."""
        pending = [self.generate(p) for p in prompts]
        return await asyncio.gather(*pending)

    async def health_check(self) -> bool:
        # The stub is always healthy.
        return True

    async def cleanup(self) -> None:
        self.cleanup_called = True
152+
153+
154+
@pytest.mark.asyncio
async def test_diff_flip_detection(runner: TestRunner) -> None:
    """diff() flags a flipped answer on p1 and agreement on p2."""
    arithmetic_prompt = Prompt(id="p1", text="What is 2+2?", category="arithmetic")
    capital_prompt = Prompt(id="p2", text="What is the capital of France?", category="general")

    # p1: baseline answers 4, test answers 5 -> flip.
    # p2: both answer Paris -> no flip.
    baseline_backend = StubBackend("baseline", {"p1": "The answer is 4", "p2": "Paris"})
    test_backend = StubBackend("test", {"p1": "The answer is 5", "p2": "Paris"})

    comparisons = await runner.diff(
        baseline_backend=baseline_backend,
        test_backends=[test_backend],
        prompts=[arithmetic_prompt, capital_prompt],
    )

    assert len(comparisons) == 2

    by_prompt = {c.baseline.prompt_id: c for c in comparisons}

    # p1 flipped: the extracted numeric answers disagree.
    flipped_comp = by_prompt["p1"]
    assert flipped_comp.metadata["flipped"] is True
    assert flipped_comp.metadata["answer_a"] == "4"
    assert flipped_comp.metadata["answer_b"] == "5"
    assert "extraction_confidence" in flipped_comp.metadata

    # p2 stable: both sides extracted the same answer.
    stable_comp = by_prompt["p2"]
    assert stable_comp.metadata["flipped"] is False
    assert stable_comp.metadata["answer_a"].lower() == "paris"
    assert stable_comp.metadata["answer_b"].lower() == "paris"
    assert "extraction_confidence" in stable_comp.metadata
189+
190+
191+
@pytest.mark.asyncio
async def test_diff_multiple_test_backends(runner: TestRunner) -> None:
    """Each test backend is diffed against the baseline independently."""
    prompt = Prompt(id="p1", text="Test", category="general")

    baseline = StubBackend("baseline", {"p1": "A"})
    # test1 disagrees with the baseline; test2 agrees.
    test1 = StubBackend("test1", {"p1": "B"})
    test2 = StubBackend("test2", {"p1": "A"})

    comparisons = await runner.diff(
        baseline_backend=baseline,
        test_backends=[test1, test2],
        prompts=[prompt],
    )

    # One comparison per test backend, in backend order.
    assert len(comparisons) == 2
    first, second = comparisons
    assert first.test.backend_name == "test1"
    assert first.metadata["flipped"] is True
    assert second.test.backend_name == "test2"
    assert second.metadata["flipped"] is False

0 commit comments

Comments
 (0)