Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 28 additions & 107 deletions src/modelgauge/annotation_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,91 +10,23 @@
from modelgauge.annotation import Annotation
from modelgauge.annotator import Annotator
from modelgauge.annotator_set import AnnotatorSet
from modelgauge.dataset import AnnotationDataset, PromptResponseDataset
from modelgauge.data_schema import DEFAULT_PROMPT_RESPONSE_SCHEMA, PromptResponseSchema
from modelgauge.pipeline import CachingPipe, Pipe, Sink, Source
from modelgauge.prompt import TextPrompt
from modelgauge.prompt_pipeline import PromptOutput, SutInteraction
from modelgauge.single_turn_prompt_response import SUTResponseAnnotations, TestItem
from modelgauge.single_turn_prompt_response import (
AnnotatedSUTInteraction,
SUTResponseAnnotations,
SUTInteraction,
TestItem,
)
from modelgauge.sut import PromptResponseSUT, SUTResponse

logger = logging.getLogger(__name__)

ANNOTATOR_CSV_INPUT_COLUMNS = ["UID", "Prompt", "SUT", "Response"]


class AnnotatorInput(metaclass=ABCMeta):
@abstractmethod
def __iter__(self) -> Iterable[SutInteraction]:
pass

def __len__(self):
count = 0
for prompt in self:
count += 1
return count


class CsvAnnotatorInput(AnnotatorInput):
def __init__(self, path):
super().__init__()
self.path = path
self._validate_file()

def __iter__(self) -> Iterable[SutInteraction]:
with open(self.path, newline="") as f:
csvreader = csv.DictReader(f)
for row in csvreader:
prompt = TestItem(
prompt=TextPrompt(text=row["Prompt"]),
# Forward the underlying id to help make data tracking easier.
source_id=row["UID"],
# Context can be any type you want.
context=row,
)
response = SUTResponse(text=row["Response"])
yield SutInteraction(prompt, row["SUT"], response)

def _validate_file(self):
with open(self.path, newline="") as f:
csvreader = csv.reader(f)
columns = next(csvreader)
assert all(
c in columns for c in ANNOTATOR_CSV_INPUT_COLUMNS
), f"Invalid input file. Must have columns: {', '.join(ANNOTATOR_CSV_INPUT_COLUMNS)}."


class JsonlAnnotatorOutput(PromptOutput):
def __init__(self, path):
super().__init__()
assert path.suffix.lower() == ".jsonl", f"Invalid output file {path}. Must be of type JSONL."

self.path = path
self.file = None
self.writer = None

def __enter__(self):
self.file = open(self.path, "w", newline="")
self.writer = jsonlines.Writer(self.file)
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.writer.close()
self.file.close()

def write(self, item: SutInteraction, results):
if not isinstance(item.prompt.prompt, TextPrompt):
raise Exception(f"Error handling {item}. Can only handle TextPrompts.")
output_obj = {
"UID": item.prompt.source_id,
"Prompt": item.prompt.prompt.text,
"SUT": item.sut_uid,
"Response": item.response.text,
"Annotations": results,
}
self.writer.write(output_obj)


class AnnotatorSource(Source):
def __init__(self, input: AnnotatorInput):
def __init__(self, input: PromptResponseDataset):
super().__init__()
self.input = input

Expand All @@ -107,7 +39,7 @@ def __init__(self, annotators: dict[str, Annotator]):
super().__init__()
self.annotators = annotators

def handle_item(self, item: SutInteraction):
def handle_item(self, item: SUTInteraction):
for annotator_uid in self.annotators:
self.downstream_put((item, annotator_uid))

Expand Down Expand Up @@ -146,7 +78,7 @@ def handle_uncached_item(self, item):
time.sleep(self.sleep_time)
result = annotator.translate_response(request, response)
self.annotation_counts[annotator_uid] += 1
return sut_interaction, annotator_uid, result
return AnnotatedSUTInteraction(annotator_uid=annotator_uid, annotation=result, sut_interaction=sut_interaction)


class EnsembleVoter(Pipe):
Expand All @@ -159,47 +91,36 @@ def __init__(self, ensemble: AnnotatorSet):
def handle_item(self, item):
# Always pass the original item through
self.downstream_put(item)
sut_interaction, annotator_uid, annotation = item
if annotator_uid in self.ensemble.annotators:
self.annotations[sut_interaction][annotator_uid] = annotation
if len(self.annotations[sut_interaction]) == len(self.ensemble.annotators):
if item.annotator_uid in self.ensemble.annotators:
self.annotations[item.sut_interaction][item.annotator_uid] = item.annotation
if len(self.annotations[item.sut_interaction]) == len(self.ensemble.annotators):
# All annotators have responded, so we can compute the ensemble response.
annotations = {k: Annotation.from_instance(v) for k, v in self.annotations[sut_interaction].items()}
annotations = {
k: Annotation.from_instance(v) for k, v in self.annotations[item.sut_interaction].items()
}
result = self.ensemble.evaluate(
SUTResponseAnnotations(
test_item=sut_interaction.prompt, sut_response=sut_interaction.response, annotations=annotations
test_item=item.sut_interaction.prompt,
sut_response=item.sut_interaction.response,
annotations=annotations,
)
)
self.downstream_put(
AnnotatedSUTInteraction(
annotator_uid="ensemble", annotation=result, sut_interaction=item.sut_interaction
)
)
self.downstream_put((sut_interaction, "ensemble", result))
self.num_ensemble_votes += 1


class AnnotatorSink(Sink):
unfinished: defaultdict[SutInteraction, dict[str, str]]

def __init__(self, annotators: dict[str, Annotator], writer: JsonlAnnotatorOutput, ensemble: bool = False):
def __init__(self, writer: AnnotationDataset):
super().__init__()
self.annotators = annotators
self.ensemble = ensemble
self.writer = writer
self.unfinished = defaultdict(lambda: dict())

def run(self):
with self.writer:
super().run()

def interaction_is_complete(self, sut_interaction) -> bool:
num_expected_annotations = len(self.annotators)
if self.ensemble:
num_expected_annotations += 1
return len(self.unfinished[sut_interaction]) == num_expected_annotations

def handle_item(self, item):
sut_interaction, annotator_uid, annotation = item
if isinstance(annotation, BaseModel):
annotation = annotation.model_dump()
self.unfinished[sut_interaction][annotator_uid] = annotation
if self.interaction_is_complete(sut_interaction):
self.writer.write(sut_interaction, self.unfinished[sut_interaction])
self._debug(f"wrote {sut_interaction}")
del self.unfinished[sut_interaction]
def handle_item(self, item: AnnotatedSUTInteraction):
self.writer.write(item)
139 changes: 139 additions & 0 deletions src/modelgauge/data_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# The first value is the preferred name.
PROMPT_UID_COLS = ["prompt_uid", "release_prompt_id"]
PROMPT_TEXT_COLS = ["prompt_text"]
SUT_UID_COLS = ["sut_uid", "sut"]
SUT_RESPONSE_COLS = ["sut_response", "response_text", "response"]
ANNOTATOR_UID_COLS = ["annotator_uid"]
ANNOTATION_COLS = ["annotation_json"]


class SchemaValidationError(ValueError):
"""Exception raised when schema validation fails."""

def __init__(self, missing_columns):
"""missing_columns: a list where each element is a string or a list of strings. List elements are used to indicate that the column can be one of several options."""
self.missing_columns = missing_columns
super().__init__(str(self))

def __str__(self):
message = "Missing required columns:"
for column in self.missing_columns:
if isinstance(column, str):
message += f"\n\t{column}"
elif len(column) == 1:
message += f"\n\t{column[0]}"
else:
message += f"\n\tone of: {column}"
return message


class PromptSchema:
"""A case-insensitive schema for a prompts file that is used as input to get SUT responses.

Attributes:
prompt_uid: The column name for the prompt uid.
prompt_text: The column name for the prompt text.
"""

def __init__(self, header: list[str]):
self.header = header
self.prompt_uid = self._find_column(header, PROMPT_UID_COLS)
self.prompt_text = self._find_column(header, PROMPT_TEXT_COLS)
self._validate()

def _find_column(self, header, columns):
return next((col for col in header if col.lower() in columns), None)

def _validate(self):
"""Validates that all required columns were found in the header.

Raises:
SchemaValidationError: If any required columns are missing.
"""
missing = []
if self.prompt_uid is None:
missing.append(PROMPT_UID_COLS)
if self.prompt_text is None:
missing.append(PROMPT_TEXT_COLS)

if missing:
raise SchemaValidationError(missing)


class PromptResponseSchema(PromptSchema):
"""A schema for a prompt + response file that is used as prompt-response output or annotation input.
Attributes:
prompt_uid: The column name for the prompt uid. (same as PromptSchema)
prompt_text: The column name for the prompt text. (same as PromptSchema)
sut_uid: The column name for the SUT uid.
sut_response: The column name for the SUT response.
"""

def __init__(self, header: list[str]):
self.sut_uid = self._find_column(header, SUT_UID_COLS)
self.sut_response = self._find_column(header, SUT_RESPONSE_COLS)
super().__init__(header) # Iniitalize the prompt schema columns and then validate.

def _validate(self):
missing = []
# Validate that the prompt schema is valid
try:
super()._validate()
except SchemaValidationError as e:
missing.extend(e.missing_columns)
# Validate that the SUT uid and response columns are present
if self.sut_uid is None:
missing.append(SUT_UID_COLS)
if self.sut_response is None:
missing.append(SUT_RESPONSE_COLS)
if missing:
raise SchemaValidationError(missing)


class AnnotationSchema(PromptResponseSchema):
"""A schema for a prompt + response + annotation file that is used as annotation output.
Attributes:
prompt_uid: The column name for the prompt uid. (same as PromptSchema)
prompt_text: The column name for the prompt text. (same as PromptSchema)
sut_uid: The column name for the SUT uid. (same as PromptResponseSchema)
sut_response: The column name for the SUT response. (same as PromptResponseSchema)
annotator_uid: The column name for the annotator uid.
annotation: The column name for the text annotation.
"""

def __init__(self, header: list[str]):
self.annotator_uid = self._find_column(header, ANNOTATOR_UID_COLS)
self.annotation = self._find_column(header, ANNOTATION_COLS)
super().__init__(header) # Iniitalize the prompt schema columns and then validate.

def _validate(self):
missing = []
# Validate that the prompt schema is valid
try:
super()._validate()
except SchemaValidationError as e:
missing.extend(e.missing_columns)
# Validate that the SUT uid and response columns are present
if self.annotator_uid is None:
missing.append(ANNOTATOR_UID_COLS)
if self.annotation is None:
missing.append(ANNOTATION_COLS)
if missing:
raise SchemaValidationError(missing)


# Schemas with preferred names.
DEFAULT_PROMPT_SCHEMA = PromptSchema([PROMPT_UID_COLS[0], PROMPT_TEXT_COLS[0]])
DEFAULT_PROMPT_RESPONSE_SCHEMA = PromptResponseSchema(
[PROMPT_UID_COLS[0], PROMPT_TEXT_COLS[0], SUT_UID_COLS[0], SUT_RESPONSE_COLS[0]]
)
DEFAULT_ANNOTATION_SCHEMA = AnnotationSchema(
[
PROMPT_UID_COLS[0],
PROMPT_TEXT_COLS[0],
SUT_UID_COLS[0],
SUT_RESPONSE_COLS[0],
ANNOTATOR_UID_COLS[0],
ANNOTATION_COLS[0],
]
)
Loading