Skip to content

Commit afe1cbf

Browse files
authored
Standardize columns + Data refactor (#1113)
* Data schema objects * pipeline runners use data schema for input * Prompt runner outputs each sut response in different row * annotation data schema * New dataset objects * Use PromptDataset as input to PromptRunner. Delete CsvPromptInput * Use PromptResponseDataset instead of CSVPromptOutput in prompt runner * Quote all + only accept csv files in datasets * Replace annotator input objects with PromptResponseDataset * New AnnotatedSUTInteraction object. * Annotation column is a json dict + annotation dataset tests * Annotation dataset dumps annotations as json strings and reads them back as dictionaries * Use AnnotationDataset object in annotation runner * mypy * remove prints
1 parent dd735c8 commit afe1cbf

12 files changed

Lines changed: 1200 additions & 566 deletions

src/modelgauge/annotation_pipeline.py

Lines changed: 28 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -10,91 +10,23 @@
1010
from modelgauge.annotation import Annotation
1111
from modelgauge.annotator import Annotator
1212
from modelgauge.annotator_set import AnnotatorSet
13+
from modelgauge.dataset import AnnotationDataset, PromptResponseDataset
14+
from modelgauge.data_schema import DEFAULT_PROMPT_RESPONSE_SCHEMA, PromptResponseSchema
1315
from modelgauge.pipeline import CachingPipe, Pipe, Sink, Source
1416
from modelgauge.prompt import TextPrompt
15-
from modelgauge.prompt_pipeline import PromptOutput, SutInteraction
16-
from modelgauge.single_turn_prompt_response import SUTResponseAnnotations, TestItem
17+
from modelgauge.single_turn_prompt_response import (
18+
AnnotatedSUTInteraction,
19+
SUTResponseAnnotations,
20+
SUTInteraction,
21+
TestItem,
22+
)
1723
from modelgauge.sut import PromptResponseSUT, SUTResponse
1824

1925
logger = logging.getLogger(__name__)
2026

21-
ANNOTATOR_CSV_INPUT_COLUMNS = ["UID", "Prompt", "SUT", "Response"]
22-
23-
24-
class AnnotatorInput(metaclass=ABCMeta):
25-
@abstractmethod
26-
def __iter__(self) -> Iterable[SutInteraction]:
27-
pass
28-
29-
def __len__(self):
30-
count = 0
31-
for prompt in self:
32-
count += 1
33-
return count
34-
35-
36-
class CsvAnnotatorInput(AnnotatorInput):
37-
def __init__(self, path):
38-
super().__init__()
39-
self.path = path
40-
self._validate_file()
41-
42-
def __iter__(self) -> Iterable[SutInteraction]:
43-
with open(self.path, newline="") as f:
44-
csvreader = csv.DictReader(f)
45-
for row in csvreader:
46-
prompt = TestItem(
47-
prompt=TextPrompt(text=row["Prompt"]),
48-
# Forward the underlying id to help make data tracking easier.
49-
source_id=row["UID"],
50-
# Context can be any type you want.
51-
context=row,
52-
)
53-
response = SUTResponse(text=row["Response"])
54-
yield SutInteraction(prompt, row["SUT"], response)
55-
56-
def _validate_file(self):
57-
with open(self.path, newline="") as f:
58-
csvreader = csv.reader(f)
59-
columns = next(csvreader)
60-
assert all(
61-
c in columns for c in ANNOTATOR_CSV_INPUT_COLUMNS
62-
), f"Invalid input file. Must have columns: {', '.join(ANNOTATOR_CSV_INPUT_COLUMNS)}."
63-
64-
65-
class JsonlAnnotatorOutput(PromptOutput):
66-
def __init__(self, path):
67-
super().__init__()
68-
assert path.suffix.lower() == ".jsonl", f"Invalid output file {path}. Must be of type JSONL."
69-
70-
self.path = path
71-
self.file = None
72-
self.writer = None
73-
74-
def __enter__(self):
75-
self.file = open(self.path, "w", newline="")
76-
self.writer = jsonlines.Writer(self.file)
77-
return self
78-
79-
def __exit__(self, exc_type, exc_val, exc_tb):
80-
self.writer.close()
81-
self.file.close()
82-
83-
def write(self, item: SutInteraction, results):
84-
if not isinstance(item.prompt.prompt, TextPrompt):
85-
raise Exception(f"Error handling {item}. Can only handle TextPrompts.")
86-
output_obj = {
87-
"UID": item.prompt.source_id,
88-
"Prompt": item.prompt.prompt.text,
89-
"SUT": item.sut_uid,
90-
"Response": item.response.text,
91-
"Annotations": results,
92-
}
93-
self.writer.write(output_obj)
94-
9527

9628
class AnnotatorSource(Source):
97-
def __init__(self, input: AnnotatorInput):
29+
def __init__(self, input: PromptResponseDataset):
9830
super().__init__()
9931
self.input = input
10032

@@ -107,7 +39,7 @@ def __init__(self, annotators: dict[str, Annotator]):
10739
super().__init__()
10840
self.annotators = annotators
10941

110-
def handle_item(self, item: SutInteraction):
42+
def handle_item(self, item: SUTInteraction):
11143
for annotator_uid in self.annotators:
11244
self.downstream_put((item, annotator_uid))
11345

@@ -146,7 +78,7 @@ def handle_uncached_item(self, item):
14678
time.sleep(self.sleep_time)
14779
result = annotator.translate_response(request, response)
14880
self.annotation_counts[annotator_uid] += 1
149-
return sut_interaction, annotator_uid, result
81+
return AnnotatedSUTInteraction(annotator_uid=annotator_uid, annotation=result, sut_interaction=sut_interaction)
15082

15183

15284
class EnsembleVoter(Pipe):
@@ -159,47 +91,36 @@ def __init__(self, ensemble: AnnotatorSet):
15991
def handle_item(self, item):
16092
# Always pass the original item through
16193
self.downstream_put(item)
162-
sut_interaction, annotator_uid, annotation = item
163-
if annotator_uid in self.ensemble.annotators:
164-
self.annotations[sut_interaction][annotator_uid] = annotation
165-
if len(self.annotations[sut_interaction]) == len(self.ensemble.annotators):
94+
if item.annotator_uid in self.ensemble.annotators:
95+
self.annotations[item.sut_interaction][item.annotator_uid] = item.annotation
96+
if len(self.annotations[item.sut_interaction]) == len(self.ensemble.annotators):
16697
# All annotators have responded, so we can compute the ensemble response.
167-
annotations = {k: Annotation.from_instance(v) for k, v in self.annotations[sut_interaction].items()}
98+
annotations = {
99+
k: Annotation.from_instance(v) for k, v in self.annotations[item.sut_interaction].items()
100+
}
168101
result = self.ensemble.evaluate(
169102
SUTResponseAnnotations(
170-
test_item=sut_interaction.prompt, sut_response=sut_interaction.response, annotations=annotations
103+
test_item=item.sut_interaction.prompt,
104+
sut_response=item.sut_interaction.response,
105+
annotations=annotations,
106+
)
107+
)
108+
self.downstream_put(
109+
AnnotatedSUTInteraction(
110+
annotator_uid="ensemble", annotation=result, sut_interaction=item.sut_interaction
171111
)
172112
)
173-
self.downstream_put((sut_interaction, "ensemble", result))
174113
self.num_ensemble_votes += 1
175114

176115

177116
class AnnotatorSink(Sink):
178-
unfinished: defaultdict[SutInteraction, dict[str, str]]
179-
180-
def __init__(self, annotators: dict[str, Annotator], writer: JsonlAnnotatorOutput, ensemble: bool = False):
117+
def __init__(self, writer: AnnotationDataset):
181118
super().__init__()
182-
self.annotators = annotators
183-
self.ensemble = ensemble
184119
self.writer = writer
185-
self.unfinished = defaultdict(lambda: dict())
186120

187121
def run(self):
188122
with self.writer:
189123
super().run()
190124

191-
def interaction_is_complete(self, sut_interaction) -> bool:
192-
num_expected_annotations = len(self.annotators)
193-
if self.ensemble:
194-
num_expected_annotations += 1
195-
return len(self.unfinished[sut_interaction]) == num_expected_annotations
196-
197-
def handle_item(self, item):
198-
sut_interaction, annotator_uid, annotation = item
199-
if isinstance(annotation, BaseModel):
200-
annotation = annotation.model_dump()
201-
self.unfinished[sut_interaction][annotator_uid] = annotation
202-
if self.interaction_is_complete(sut_interaction):
203-
self.writer.write(sut_interaction, self.unfinished[sut_interaction])
204-
self._debug(f"wrote {sut_interaction}")
205-
del self.unfinished[sut_interaction]
125+
def handle_item(self, item: AnnotatedSUTInteraction):
126+
self.writer.write(item)

src/modelgauge/data_schema.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
# Accepted header names for each logical column of a dataset file.
# The first value is the preferred name.
# (Index 0 of each list is what the DEFAULT_* schemas in this module use.)
# Header names are lower-cased before the membership test in
# PromptSchema._find_column, so every entry here must itself be lowercase.
2+
# Column holding the prompt's unique id.
PROMPT_UID_COLS = ["prompt_uid", "release_prompt_id"]
3+
# Column holding the prompt text.
PROMPT_TEXT_COLS = ["prompt_text"]
4+
# Column holding the uid of the SUT that produced the response.
SUT_UID_COLS = ["sut_uid", "sut"]
5+
# Column holding the SUT's response text.
SUT_RESPONSE_COLS = ["sut_response", "response_text", "response"]
6+
# Column holding the uid of the annotator.
ANNOTATOR_UID_COLS = ["annotator_uid"]
7+
# Column holding the annotation itself.
ANNOTATION_COLS = ["annotation_json"]
8+
9+
10+
class SchemaValidationError(ValueError):
    """Raised when a header does not contain every column a schema requires.

    Attributes:
        missing_columns: The requirements that were not satisfied. Each
            element is either a single column name (``str``) or a list of
            acceptable alternative names for one logical column.
    """

    def __init__(self, missing_columns):
        """Store the unmet requirements and build the exception message.

        Args:
            missing_columns: A list where each element is a string or a list
                of strings. A list element means the column may be any one of
                several options.
        """
        # Must be assigned before str(self) is evaluated below, because
        # __str__ reads this attribute.
        self.missing_columns = missing_columns
        super().__init__(str(self))

    def __str__(self):
        """Return a multi-line message with one tab-indented line per requirement."""
        parts = ["Missing required columns:"]
        for entry in self.missing_columns:
            if isinstance(entry, str):
                parts.append(f"\n\t{entry}")
            elif len(entry) == 1:
                # A single alternative reads better without the "one of" prefix.
                parts.append(f"\n\t{entry[0]}")
            else:
                parts.append(f"\n\tone of: {entry}")
        return "".join(parts)
28+
29+
30+
class PromptSchema:
    """A case-insensitive schema for a prompts file that is used as input to get SUT responses.

    Attributes:
        header: The header list the schema was built from.
        prompt_uid: The column name for the prompt uid, as spelled in the header.
        prompt_text: The column name for the prompt text, as spelled in the header.
    """

    def __init__(self, header: list[str]):
        """Resolve the prompt columns in ``header`` and validate the result.

        Raises:
            SchemaValidationError: If any required column is missing.
        """
        self.header = header
        self.prompt_uid = self._find_column(header, PROMPT_UID_COLS)
        self.prompt_text = self._find_column(header, PROMPT_TEXT_COLS)
        # Dispatches to the most-derived _validate, so subclasses are checked too.
        self._validate()

    def _find_column(self, header, columns):
        """Return the first header entry whose lower-cased form is in ``columns``.

        The entry is returned with its original casing. Returns None when no
        header entry matches.
        """
        for name in header:
            if name.lower() in columns:
                return name
        return None

    def _validate(self):
        """Validates that all required columns were found in the header.

        Raises:
            SchemaValidationError: If any required columns are missing.
        """
        absent = []
        if self.prompt_uid is None:
            absent.append(PROMPT_UID_COLS)
        if self.prompt_text is None:
            absent.append(PROMPT_TEXT_COLS)

        if not absent:
            return
        raise SchemaValidationError(absent)
61+
62+
63+
class PromptResponseSchema(PromptSchema):
    """A schema for a prompt + response file that is used as prompt-response output or annotation input.

    Attributes:
        prompt_uid: The column name for the prompt uid. (same as PromptSchema)
        prompt_text: The column name for the prompt text. (same as PromptSchema)
        sut_uid: The column name for the SUT uid.
        sut_response: The column name for the SUT response.
    """

    def __init__(self, header: list[str]):
        """Resolve the SUT columns, then let the parent resolve the rest and validate.

        Raises:
            SchemaValidationError: If any required column is missing.
        """
        # These must exist before the parent constructor runs: it calls
        # self._validate(), which reads them.
        self.sut_uid = self._find_column(header, SUT_UID_COLS)
        self.sut_response = self._find_column(header, SUT_RESPONSE_COLS)
        super().__init__(header)  # Initialize the prompt schema columns and then validate.

    def _validate(self):
        """Check the prompt columns and the SUT columns, reporting every gap at once.

        Raises:
            SchemaValidationError: If any required columns are missing.
        """
        # Collect (rather than propagate) the parent's failures so a single
        # error can list all missing columns together.
        try:
            super()._validate()
        except SchemaValidationError as err:
            absent = list(err.missing_columns)
        else:
            absent = []

        # Now the columns specific to this schema: SUT uid and SUT response.
        if self.sut_uid is None:
            absent.append(SUT_UID_COLS)
        if self.sut_response is None:
            absent.append(SUT_RESPONSE_COLS)

        if not absent:
            return
        raise SchemaValidationError(absent)
91+
92+
93+
class AnnotationSchema(PromptResponseSchema):
    """A schema for a prompt + response + annotation file that is used as annotation output.

    Attributes:
        prompt_uid: The column name for the prompt uid. (same as PromptSchema)
        prompt_text: The column name for the prompt text. (same as PromptSchema)
        sut_uid: The column name for the SUT uid. (same as PromptResponseSchema)
        sut_response: The column name for the SUT response. (same as PromptResponseSchema)
        annotator_uid: The column name for the annotator uid.
        annotation: The column name for the text annotation.
    """

    def __init__(self, header: list[str]):
        """Resolve the annotation columns, then let the parent resolve the rest and validate.

        Raises:
            SchemaValidationError: If any required column is missing.
        """
        # These must exist before the parent constructor runs: it calls
        # self._validate(), which reads them.
        self.annotator_uid = self._find_column(header, ANNOTATOR_UID_COLS)
        self.annotation = self._find_column(header, ANNOTATION_COLS)
        super().__init__(header)  # Initialize the prompt-response schema columns and then validate.

    def _validate(self):
        """Check the inherited columns and the annotation columns, reporting every gap at once.

        Raises:
            SchemaValidationError: If any required columns are missing.
        """
        # Collect (rather than propagate) the parent's failures so a single
        # error can list all missing columns together.
        try:
            super()._validate()
        except SchemaValidationError as err:
            absent = list(err.missing_columns)
        else:
            absent = []

        # Now the columns specific to this schema: annotator uid and annotation.
        if self.annotator_uid is None:
            absent.append(ANNOTATOR_UID_COLS)
        if self.annotation is None:
            absent.append(ANNOTATION_COLS)

        if not absent:
            return
        raise SchemaValidationError(absent)
123+
124+
125+
# Ready-made schemas whose header uses the preferred (first-listed) name of
# every column.
DEFAULT_PROMPT_SCHEMA = PromptSchema(
    [names[0] for names in (PROMPT_UID_COLS, PROMPT_TEXT_COLS)]
)
DEFAULT_PROMPT_RESPONSE_SCHEMA = PromptResponseSchema(
    [
        names[0]
        for names in (
            PROMPT_UID_COLS,
            PROMPT_TEXT_COLS,
            SUT_UID_COLS,
            SUT_RESPONSE_COLS,
        )
    ]
)
DEFAULT_ANNOTATION_SCHEMA = AnnotationSchema(
    [
        names[0]
        for names in (
            PROMPT_UID_COLS,
            PROMPT_TEXT_COLS,
            SUT_UID_COLS,
            SUT_RESPONSE_COLS,
            ANNOTATOR_UID_COLS,
            ANNOTATION_COLS,
        )
    ]
)

0 commit comments

Comments
 (0)