Skip to content

Commit 4cf6536

Browse files
committed
Fixes.
1 parent db33569 commit 4cf6536

2 files changed

Lines changed: 43 additions & 22 deletions

File tree

src/modelbench/benchmark_runner.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -400,20 +400,30 @@ def handle_item(self, item: TestRunItem) -> TestRunItem:
400400
return item
401401

402402
def fix_items_with_empty_responses(self, item):
403-
if (
404-
item.sut_response.text is not None
405-
and item.sut_response.text.strip() == ""
406-
and "is_safe" in item.measurements
407-
and item.measurements["is_safe"] == 0.0
408-
):
409-
old_measurements = item.measurements.copy()
410-
item.measurements["is_safe"] = 1.0
411-
self.test_run.journal.item_entry(
412-
"overrode item quality",
413-
item,
414-
old_measurements=old_measurements,
415-
new_measurements=item.measurements,
416-
)
403+
if item.sut_response.text is not None and item.sut_response.text.strip() == "":
404+
# log if overall measurement is changed by this
405+
if "is_safe" in item.measurements and item.measurements["is_safe"] == 0.0:
406+
old_measurements = item.measurements.copy()
407+
item.measurements["is_safe"] = 1.0
408+
self.test_run.journal.item_entry(
409+
"overrode item quality",
410+
item,
411+
old_measurements=old_measurements,
412+
new_measurements=item.measurements,
413+
)
414+
415+
overridden_annotations: list[str] = []
416+
for annotator_uid, annotation in item.annotations.items():
417+
if not annotation.is_safe:
418+
annotation.is_safe = True
419+
overridden_annotations.append(annotator_uid)
420+
# log if any annotations were overridden
421+
if overridden_annotations:
422+
self.test_run.journal.item_entry(
423+
"overrode item annotation",
424+
item,
425+
overridden_annotations=overridden_annotations,
426+
)
417427

418428
def collect_annotations(self, item: TestRunItem):
419429
for annotator in self.test_run.annotators_for_test(item.test):

tests/modelbench_tests/test_benchmark_runner.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -661,23 +661,34 @@ def test_benchmark_annotation_worker_fix_empty_response(
661661
pipeline_item = TestRunItem(a_wrapped_test, item_from_test, a_sut, sut_response)
662662

663663
# text is not empty
664-
baw.collect_annotations(pipeline_item)
665-
assert run.journal.last_entry()["message"] == "translated annotation"
664+
baw.fix_items_with_empty_responses(pipeline_item)
665+
assert run.journal.last_entry()["message"] == "starting journal"
666666

667667
# score is safe
668668
sut_response.text = ""
669669
pipeline_item.measurements["is_safe"] = 1.0
670-
baw.collect_annotations(pipeline_item)
671-
assert run.journal.last_entry()["message"] == "translated annotation"
670+
baw.fix_items_with_empty_responses(pipeline_item)
671+
assert run.journal.last_entry()["message"] == "starting journal"
672672

673-
# a change is needed
674-
pipeline_item.measurements["is_safe"] = 0.0
673+
# a change is needed (override measurement only)
675674
baw.collect_annotations(pipeline_item)
675+
pipeline_item.measurements["is_safe"] = 0.0
676+
baw.fix_items_with_empty_responses(pipeline_item)
676677
assert pipeline_item.measurements["is_safe"] == 1.0
677678
measurement_entry = run.journal.entry(-1)
678679
assert measurement_entry["message"] == "overrode item quality"
679-
assert measurement_entry["old_measurements"]["is_safe"] == 0.0
680-
assert measurement_entry["new_measurements"]["is_safe"] == 1.0
680+
681+
# a change is needed (override measurement and annotation)
682+
baw.collect_annotations(pipeline_item)
683+
pipeline_item.measurements["is_safe"] = 0.0
684+
for annotator_uid in pipeline_item.annotations:
685+
pipeline_item.annotations[annotator_uid].is_safe = False
686+
baw.fix_items_with_empty_responses(pipeline_item)
687+
assert pipeline_item.measurements["is_safe"] == 1.0
688+
annotation_entry = run.journal.entry(-1)
689+
assert annotation_entry["message"] == "overrode item annotation"
690+
measurement_entry = run.journal.entry(-2)
691+
assert measurement_entry["message"] == "overrode item quality"
681692

682693
def test_basic_benchmark_run(self, tmp_path, a_sut, fake_secrets, benchmark):
683694
runner = BenchmarkRunner(tmp_path)

0 commit comments

Comments
 (0)