|
3 | 3 | from unittest.mock import MagicMock |
4 | 4 |
|
5 | 5 | import pytest |
| 6 | + |
6 | 7 | from modelbench.benchmark_runner import * |
7 | 8 | from modelbench.cache import InMemoryCache |
8 | 9 | from modelbench.hazards import HazardDefinition, HazardScore |
9 | 10 | from modelbench.scoring import ValueEstimate |
| 11 | +from modelbench_tests.test_run_journal import FakeJournal, reader_for |
10 | 12 | from modelgauge.annotators.demo_annotator import DemoYBadAnnotation, DemoYBadResponse, DemoYBadRequest |
11 | 13 | from modelgauge.annotators.llama_guard_annotator import LlamaGuardAnnotation |
12 | 14 | from modelgauge.dependency_helper import DependencyHelper |
|
18 | 20 | from modelgauge.sut import SUTOptions, SUTResponse |
19 | 21 | from modelgauge.sut_registry import SUTS |
20 | 22 | from modelgauge.suts.demo_01_yes_no_sut import DemoYesNoResponse |
21 | | - |
22 | | -from modelbench_tests.test_run_journal import FakeJournal, reader_for |
23 | 23 | from modelgauge_tests.fake_annotator import FakeAnnotator |
24 | 24 | from modelgauge_tests.fake_sut import FakeSUT |
25 | 25 |
|
@@ -147,7 +147,7 @@ def _make_hazards(self) -> Sequence[HazardDefinition]: |
147 | 147 | def item_from_test(self): |
148 | 148 | return self.make_test_item() |
149 | 149 |
|
150 | | - def make_test_item(self, text="Hello!", source_id="hello"): |
| 150 | + def make_test_item(self, text="Hello!", source_id="hello") -> TestItem: |
151 | 151 | return TestItem(prompt=TextPrompt(text=text), source_id=source_id) |
152 | 152 |
|
153 | 153 | @pytest.fixture() |
@@ -604,6 +604,32 @@ def test_benchmark_annotation_worker_throws_exception( |
604 | 604 | assert measurement_entry["measurements"] == {} |
605 | 605 | capsys.readouterr() # suppress the exception output; can remove when we add proper logging |
606 | 606 |
|
| 607 | + def test_benchmark_annotation_worker_fix_empty_response( |
| 608 | + self, a_wrapped_test, tmp_path, item_from_test: TestItem, sut_response, a_sut, benchmark |
| 609 | + ): |
| 610 | + run = self.a_run(tmp_path, suts=[a_sut], benchmarks=[benchmark]) |
| 611 | + baw = TestRunAnnotationWorker(run, NullCache()) |
| 612 | + pipeline_item = TestRunItem(a_wrapped_test, item_from_test, a_sut, sut_response) |
| 613 | + |
| 614 | + # response text is not empty: the last journal entry must still be "starting journal" |
| 615 | + baw.fix_measurement_for_empty_responses(pipeline_item) |
| 616 | + assert run.journal.last_entry()["message"] == "starting journal" |
| 617 | + |
| 618 | + # response text is empty but is_safe is already 1.0: the last journal entry must still be "starting journal" |
| 619 | + sut_response.text = "" |
| 620 | + pipeline_item.measurements["is_safe"] = 1.0 |
| 621 | + baw.fix_measurement_for_empty_responses(pipeline_item) |
| 622 | + assert run.journal.last_entry()["message"] == "starting journal" |
| 623 | + |
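| 624 | + # response text is empty and is_safe is 0.0: is_safe is overridden to 1.0 and an "overrode item quality" entry is journaled |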
| 625 | + pipeline_item.measurements["is_safe"] = 0.0 |
| 626 | + baw.fix_measurement_for_empty_responses(pipeline_item) |
| 627 | + assert pipeline_item.measurements["is_safe"] == 1.0 |
| 628 | + measurement_entry = run.journal.entry(-1) |
| 629 | + assert measurement_entry["message"] == "overrode item quality" |
| 630 | + assert measurement_entry["old_measurements"]["is_safe"] == 0.0 |
| 631 | + assert measurement_entry["new_measurements"]["is_safe"] == 1.0 |
| 632 | + |
607 | 633 | def test_basic_benchmark_run(self, tmp_path, a_sut, fake_secrets, benchmark): |
608 | 634 | runner = BenchmarkRunner(tmp_path) |
609 | 635 | runner.secrets = fake_secrets |
|
0 commit comments