Skip to content

Commit 54a24f0

Browse files
authored
Duplicate prompt/response pairs with unique prompt ids (#1215)
* Deal with issue related to duplicate prompt/response pairs wiht unique prompt ids. * Remove debug print statement. * Update the annotator worker key to include prompt id. * Fix test.
1 parent 7c60fbd commit 54a24f0

3 files changed

Lines changed: 37 additions & 13 deletions

File tree

src/modelgauge/annotation_pipeline.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,18 @@
1-
import csv
2-
import jsonlines
31
import logging
42
import time
5-
from abc import abstractmethod, ABCMeta
63
from collections import defaultdict
74
from pydantic import BaseModel
8-
from typing import Iterable
95

106
from modelgauge.annotation import Annotation
117
from modelgauge.annotator import Annotator
128
from modelgauge.annotator_set import AnnotatorSet
139
from modelgauge.dataset import AnnotationDataset, PromptResponseDataset
14-
from modelgauge.data_schema import DEFAULT_PROMPT_RESPONSE_SCHEMA, PromptResponseSchema
1510
from modelgauge.pipeline import CachingPipe, Pipe, Sink, Source
16-
from modelgauge.prompt import TextPrompt
1711
from modelgauge.single_turn_prompt_response import (
1812
AnnotatedSUTInteraction,
1913
SUTResponseAnnotations,
2014
SUTInteraction,
21-
TestItem,
2215
)
23-
from modelgauge.sut import PromptResponseSUT, SUTResponse
2416

2517
logger = logging.getLogger(__name__)
2618

@@ -59,7 +51,7 @@ def key(self, item):
5951
request = annotator.translate_request(sut_interaction.prompt, sut_interaction.response)
6052
if isinstance(request, BaseModel):
6153
request = request.model_dump_json()
62-
return (request, annotator_uid)
54+
return (sut_interaction.prompt.source_id, request, annotator_uid)
6355

6456
def handle_uncached_item(self, item):
6557
sut_interaction, annotator_uid = item

tests/modelgauge_tests/test_annotation_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,9 @@ def test_annotator_worker_unique_responses(annotators, tmp_path):
143143
w.handle_item((make_sut_interaction("", "", "", "response 2"), "annotator_pydantic"))
144144
assert annotators["annotator_pydantic"].annotate_calls == 2
145145

146-
# Non-response SUT interaction attributes do not affect the cache key.
146+
# New prompt id does affect the cache key.
147147
w.handle_item((make_sut_interaction("2", "2", "2", "response 2"), "annotator_pydantic"))
148-
assert annotators["annotator_pydantic"].annotate_calls == 2
148+
assert annotators["annotator_pydantic"].annotate_calls == 3
149149

150150

151151
def test_annotator_worker_cache_unique_prompts(tmp_path):

tests/modelgauge_tests/test_pipeline_runner.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
AnnotatorWorkers,
99
)
1010
from modelgauge.annotator_set import AnnotatorSet
11-
from modelgauge.dataset import PromptDataset, PromptResponseDataset
11+
from modelgauge.dataset import AnnotationDataset, PromptDataset, PromptResponseDataset
1212
from modelgauge.data_schema import (
1313
DEFAULT_PROMPT_RESPONSE_SCHEMA as PROMPT_RESPONSE_SCHEMA,
1414
DEFAULT_PROMPT_SCHEMA as PROMPT_SCHEMA,
@@ -54,7 +54,7 @@ def prompts_dataset(prompts_file):
5454

5555
@pytest.fixture(scope="session")
5656
def prompt_responses_file(tmp_path_factory):
57-
"""Sample file with 2 prompts + responses from 2 SUTs for testing."""
57+
"""Sample file with 3 prompts + responses from 2 SUTs for testing."""
5858
file = tmp_path_factory.mktemp("data") / "prompt-responses.csv"
5959
with open(file, "w") as f:
6060
text = f"{PROMPT_RESPONSE_SCHEMA.prompt_uid},{PROMPT_RESPONSE_SCHEMA.prompt_text},{PROMPT_RESPONSE_SCHEMA.sut_uid},{PROMPT_RESPONSE_SCHEMA.sut_response}\n"
@@ -65,6 +65,21 @@ def prompt_responses_file(tmp_path_factory):
6565
return file
6666

6767

68+
@pytest.fixture(scope="session")
69+
def prompt_responses_file_with_duplicates(tmp_path_factory):
70+
"""Sample file with 3 prompts + responses from 2 SUTs for testing. Also include duplicate prompt/response, with unique prompt id."""
71+
file = tmp_path_factory.mktemp("data") / "prompt-responses.csv"
72+
with open(file, "w") as f:
73+
text = f"{PROMPT_RESPONSE_SCHEMA.prompt_uid},{PROMPT_RESPONSE_SCHEMA.prompt_text},{PROMPT_RESPONSE_SCHEMA.sut_uid},{PROMPT_RESPONSE_SCHEMA.sut_response}\n"
74+
for i in range(NUM_PROMPTS):
75+
text += f"p{i},Prompt {i},sut1,Response {i}\n"
76+
text += f"p{i},Prompt {i},sut2,Response {i}\n"
77+
# add a duplicate with unique prompt ids
78+
text += f"q0,Prompt 0,sut1,Response 0\n"
79+
f.write(text)
80+
return file
81+
82+
6883
@pytest.fixture
6984
def annotators():
7085
return {
@@ -493,6 +508,23 @@ def test_metadata_ensemble(self, runner_ensemble):
493508
assert metadata["ensemble"]["annotators"] == ["annotator1", "annotator2", "annotator3"]
494509
assert metadata["ensemble"]["num_votes"] == NUM_PROMPTS * self.NUM_SUTS
495510

511+
def test_cache_responses(self, prompt_responses_file_with_duplicates, annotators, tmp_path):
512+
runner = AnnotatorRunner(
513+
annotators=annotators,
514+
num_workers=1,
515+
input_dataset=PromptResponseDataset(prompt_responses_file_with_duplicates, mode="r"),
516+
output_dir=tmp_path,
517+
cache_dir=tmp_path / "cache",
518+
)
519+
runner.run(progress_callback=lambda _: _, debug=False)
520+
prompt_ids = set()
521+
with PromptResponseDataset(prompt_responses_file_with_duplicates, "r") as prompts:
522+
prompt_ids.update(item.prompt.source_id for item in prompts)
523+
annotated_prompt_ids = set()
524+
with AnnotationDataset(runner.output_dir() / runner.output_file_name, "r") as annotations:
525+
annotated_prompt_ids.update(item.sut_interaction.prompt.source_id for item in annotations)
526+
assert prompt_ids == annotated_prompt_ids
527+
496528

497529
class TestBuildRunner:
498530
def test_build_prompt_runner(self, prompts_file, suts, tmp_path):

0 commit comments

Comments
 (0)