From dcac74406bfd029b2902b86a706724b96947914d Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Wed, 18 Jun 2025 13:43:58 -0700 Subject: [PATCH 01/15] Data schema objects --- src/modelgauge/data_schema.py | 94 ++++++++++++++++++++++ tests/modelgauge_tests/test_data_schema.py | 73 +++++++++++++++++ 2 files changed, 167 insertions(+) create mode 100644 src/modelgauge/data_schema.py create mode 100644 tests/modelgauge_tests/test_data_schema.py diff --git a/src/modelgauge/data_schema.py b/src/modelgauge/data_schema.py new file mode 100644 index 000000000..38af5a6bd --- /dev/null +++ b/src/modelgauge/data_schema.py @@ -0,0 +1,94 @@ +ANNOTATOR_CSV_INPUT_COLUMNS = ["UID", "Prompt", "SUT", "Response"] +PROMPT_CSV_INPUT_COLUMNS = { + "default": {"id": "UID", "text": "Text"}, + "prompt_set": {"id": "release_prompt_id", "text": "prompt_text"}, # official prompt set files + "db": {"id": "prompt_uid", "text": "prompt_text"}, # database dumps +} + +# The first value is the preferred name. +PROMPT_UID_COLS = ["prompt_uid", "release_prompt_id"] +PROMPT_TEXT_COLS = ["prompt_text"] +SUT_UID_COLS = ["sut_uid", "sut"] +SUT_RESPONSE_COLS = ["sut_response", "response_text", "response"] + + +class SchemaValidationError(ValueError): + """Exception raised when schema validation fails.""" + + def __init__(self, missing_columns): + """missing_columns: a list where each element is a string or a list of strings. List elements are used to indicate that the column can be one of several options.""" + self.missing_columns = missing_columns + super().__init__(str(self)) + + def __str__(self): + message = "Missing required columns:" + for column in self.missing_columns: + if isinstance(column, str): + message += f"\n\t{column}" + elif len(column) == 1: + message += f"\n\t{column[0]}" + else: + message += f"\n\tone of: {column}" + return message + + +class PromptSchema: + """A case-insensitive schema for a prompts file. + + Attributes: + prompt_uid: The column name for the prompt uid. 
+ prompt_text: The column name for the prompt text. + """ + + def __init__(self, header: list[str]): + self.prompt_uid = self._find_column(header, PROMPT_UID_COLS) + self.prompt_text = self._find_column(header, PROMPT_TEXT_COLS) + self._validate() + + def _find_column(self, header, columns): + return next((col for col in header if col.lower() in columns), None) + + def _validate(self): + """Validates that all required columns were found in the header. + + Raises: + SchemaValidationError: If any required columns are missing. + """ + missing = [] + if self.prompt_uid is None: + missing.append(PROMPT_UID_COLS) + if self.prompt_text is None: + missing.append(PROMPT_TEXT_COLS) + + if missing: + raise SchemaValidationError(missing) + + +class PromptResponseSchema(PromptSchema): + """A schema for a prompt + response file that is used as annotation input. + Attributes: + prompt_uid: The column name for the prompt uid. (same as PromptSchema) + prompt_text: The column name for the prompt text. (same as PromptSchema) + sut_uid: The column name for the SUT uid. + sut_response: The column name for the SUT response. + """ + + def __init__(self, header: list[str]): + self.sut_uid = self._find_column(header, SUT_UID_COLS) + self.sut_response = self._find_column(header, SUT_RESPONSE_COLS) + super().__init__(header) # Initialize the prompt schema columns and then validate. 
+ + def _validate(self): + missing = [] + # Validate that the prompt schema is valid + try: + super()._validate() + except SchemaValidationError as e: + missing.extend(e.missing_columns) + # Validate that the SUT uid and response columns are present + if self.sut_uid is None: + missing.append(SUT_UID_COLS) + if self.sut_response is None: + missing.append(SUT_RESPONSE_COLS) + if missing: + raise SchemaValidationError(missing) diff --git a/tests/modelgauge_tests/test_data_schema.py b/tests/modelgauge_tests/test_data_schema.py new file mode 100644 index 000000000..9054c1abf --- /dev/null +++ b/tests/modelgauge_tests/test_data_schema.py @@ -0,0 +1,73 @@ +import pytest + +from modelgauge.data_schema import ( + PROMPT_TEXT_COLS, + PROMPT_UID_COLS, + PromptResponseSchema, + PromptSchema, + SchemaValidationError, +) + + +def test_schema_validation_error(): + error = SchemaValidationError(["one", "two"]) + assert error.missing_columns == ["one", "two"] + assert str(error) == "Missing required columns:\n\tone\n\ttwo" + + +def test_schema_validation_error_multiple_options(): + error = SchemaValidationError([["one", "a"], "two"]) + assert error.missing_columns == [["one", "a"], "two"] + assert str(error) == "Missing required columns:\n\tone of: ['one', 'a']\n\ttwo" + + +@pytest.mark.parametrize( + "header", + [ + ["prompt_uid", "prompt_text"], # Preferred names. + ["Prompt_UID", "Prompt_Text"], # Case-insensitive + ["release_prompt_id", "prompt_text"], + ["release_prompt_id", "prompt_text", "random_column"], # Extra columns are allowed. 
+ ], +) +def test_valid_prompt_schema(header): + schema = PromptSchema(header) + assert schema.prompt_uid == header[0] + assert schema.prompt_text == header[1] + + +def test_invalid_prompt_schema(): + header = ["random_column", "random_column_2"] + with pytest.raises(SchemaValidationError) as e: + schema = PromptSchema(header) + assert set(e.missing_columns) == {PROMPT_UID_COLS, PROMPT_TEXT_COLS} + + +@pytest.mark.parametrize( + "header", + [ + ["prompt_uid", "prompt_text", "sut_uid", "sut_response"], # Preferred names. + ["prompt_UID", "Prompt_Text", "SUT_UID", "SUT_Response"], # Case-insensitive + ["release_prompt_id", "prompt_text", "sut", "response"], + ], +) +def test_valid_prompt_response_schema(header): + schema = PromptResponseSchema(header) + assert schema.prompt_uid == header[0] + assert schema.prompt_text == header[1] + assert schema.sut_uid == header[2] + assert schema.sut_response == header[3] + + +def test_valid_prompt_invalid_response_schema(): + header = ["prompt_uid", "prompt_text", "random_column", "random_column_2"] + with pytest.raises(SchemaValidationError) as e: + schema = PromptResponseSchema(header) + assert set(e.missing_columns) == {SUT_UID_COLS, SUT_RESPONSE_COLS} + + +def test_invalid_prompt_valid_response_schema(): + header = ["random_column", "random_column_2", "prompt_uid", "prompt_text"] + with pytest.raises(SchemaValidationError) as e: + schema = PromptResponseSchema(header) + assert set(e.missing_columns) == {PROMPT_UID_COLS, PROMPT_TEXT_COLS} From e29a34340b8ea9b231b9bc0b6e5c55ed64dce7e2 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Wed, 18 Jun 2025 16:55:11 -0700 Subject: [PATCH 02/15] pipeline runners use data schema for input --- src/modelgauge/annotation_pipeline.py | 35 ++++---- src/modelgauge/data_schema.py | 14 +-- src/modelgauge/prompt_pipeline.py | 40 +++------ .../test_annotation_pipeline.py | 88 +++++++++++-------- tests/modelgauge_tests/test_cli.py | 50 ++++++----- tests/modelgauge_tests/test_data_schema.py | 14 
+++ .../modelgauge_tests/test_pipeline_runner.py | 8 +- .../modelgauge_tests/test_prompt_pipeline.py | 45 +++------- 8 files changed, 146 insertions(+), 148 deletions(-) diff --git a/src/modelgauge/annotation_pipeline.py b/src/modelgauge/annotation_pipeline.py index 23a28ba51..8750914f1 100644 --- a/src/modelgauge/annotation_pipeline.py +++ b/src/modelgauge/annotation_pipeline.py @@ -10,6 +10,7 @@ from modelgauge.annotation import Annotation from modelgauge.annotator import Annotator from modelgauge.annotator_set import AnnotatorSet +from modelgauge.data_schema import DEFAULT_PROMPT_RESPONSE_SCHEMA, PromptResponseSchema from modelgauge.pipeline import CachingPipe, Pipe, Sink, Source from modelgauge.prompt import TextPrompt from modelgauge.prompt_pipeline import PromptOutput, SutInteraction @@ -18,8 +19,6 @@ logger = logging.getLogger(__name__) -ANNOTATOR_CSV_INPUT_COLUMNS = ["UID", "Prompt", "SUT", "Response"] - class AnnotatorInput(metaclass=ABCMeta): @abstractmethod @@ -37,29 +36,26 @@ class CsvAnnotatorInput(AnnotatorInput): def __init__(self, path): super().__init__() self.path = path - self._validate_file() + self.schema = PromptResponseSchema(self._header()) # Validate header and store the schema. + + def _header(self) -> list[str]: + with open(self.path, newline="") as f: + csvreader = csv.reader(f) + return next(csvreader) def __iter__(self) -> Iterable[SutInteraction]: with open(self.path, newline="") as f: csvreader = csv.DictReader(f) for row in csvreader: prompt = TestItem( - prompt=TextPrompt(text=row["Prompt"]), + prompt=TextPrompt(text=row[self.schema.prompt_text]), # Forward the underlying id to help make data tracking easier. - source_id=row["UID"], + source_id=row[self.schema.prompt_uid], # Context can be any type you want. 
context=row, ) - response = SUTResponse(text=row["Response"]) - yield SutInteraction(prompt, row["SUT"], response) - - def _validate_file(self): - with open(self.path, newline="") as f: - csvreader = csv.reader(f) - columns = next(csvreader) - assert all( - c in columns for c in ANNOTATOR_CSV_INPUT_COLUMNS - ), f"Invalid input file. Must have columns: {', '.join(ANNOTATOR_CSV_INPUT_COLUMNS)}." + response = SUTResponse(text=row[self.schema.sut_response]) + yield SutInteraction(prompt, row[self.schema.sut_uid], response) class JsonlAnnotatorOutput(PromptOutput): @@ -83,11 +79,12 @@ def __exit__(self, exc_type, exc_val, exc_tb): def write(self, item: SutInteraction, results): if not isinstance(item.prompt.prompt, TextPrompt): raise Exception(f"Error handling {item}. Can only handle TextPrompts.") + # TODO: Standardize annotation schema. output_obj = { - "UID": item.prompt.source_id, - "Prompt": item.prompt.prompt.text, - "SUT": item.sut_uid, - "Response": item.response.text, + DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_uid: item.prompt.source_id, + DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_text: item.prompt.prompt.text, + DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid: item.sut_uid, + DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response: item.response.text, "Annotations": results, } self.writer.write(output_obj) diff --git a/src/modelgauge/data_schema.py b/src/modelgauge/data_schema.py index 38af5a6bd..c02d16e52 100644 --- a/src/modelgauge/data_schema.py +++ b/src/modelgauge/data_schema.py @@ -1,10 +1,3 @@ -ANNOTATOR_CSV_INPUT_COLUMNS = ["UID", "Prompt", "SUT", "Response"] -PROMPT_CSV_INPUT_COLUMNS = { - "default": {"id": "UID", "text": "Text"}, - "prompt_set": {"id": "release_prompt_id", "text": "prompt_text"}, # official prompt set files - "db": {"id": "prompt_uid", "text": "prompt_text"}, # database dumps -} - # The first value is the preferred name. 
PROMPT_UID_COLS = ["prompt_uid", "release_prompt_id"] PROMPT_TEXT_COLS = ["prompt_text"] @@ -92,3 +85,10 @@ def _validate(self): missing.append(SUT_RESPONSE_COLS) if missing: raise SchemaValidationError(missing) + + +# Schemas with preferred names. +DEFAULT_PROMPT_SCHEMA = PromptSchema([PROMPT_UID_COLS[0], PROMPT_TEXT_COLS[0]]) +DEFAULT_PROMPT_RESPONSE_SCHEMA = PromptResponseSchema( + [PROMPT_UID_COLS[0], PROMPT_TEXT_COLS[0], SUT_UID_COLS[0], SUT_RESPONSE_COLS[0]] +) diff --git a/src/modelgauge/prompt_pipeline.py b/src/modelgauge/prompt_pipeline.py index bd52216e4..3ec1bfed8 100644 --- a/src/modelgauge/prompt_pipeline.py +++ b/src/modelgauge/prompt_pipeline.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import Iterable, Optional +from modelgauge.data_schema import DEFAULT_PROMPT_SCHEMA, PromptSchema from modelgauge.pipeline import CachingPipe, Pipe, Sink, Source from modelgauge.prompt import TextPrompt from modelgauge.single_turn_prompt_response import TestItem @@ -13,12 +14,6 @@ logger = logging.getLogger(__name__) -PROMPT_CSV_INPUT_COLUMNS = { - "default": {"id": "UID", "text": "Text"}, - "prompt_set": {"id": "release_prompt_id", "text": "prompt_text"}, # official prompt set files - "db": {"id": "prompt_uid", "text": "prompt_text"}, # database dumps -} - @dataclass class SutInteraction: @@ -51,39 +46,25 @@ class CsvPromptInput(PromptInput): def __init__(self, path): super().__init__() self.path = path - self.prompt_input_type = "default" - self._identify_input() + self.schema = PromptSchema(self._header()) # Validate header and store the schema. 
- def _extract_field(self, row, field_name): - column_name = PROMPT_CSV_INPUT_COLUMNS[self.prompt_input_type][field_name] - return row[column_name] + def _header(self) -> list[str]: + with open(self.path, newline="") as f: + csvreader = csv.reader(f) + return next(csvreader) def __iter__(self) -> Iterable[TestItem]: with open(self.path, newline="") as f: csvreader = csv.DictReader(f) for row in csvreader: yield TestItem( - prompt=TextPrompt(text=self._extract_field(row, "text")), + prompt=TextPrompt(text=row[self.schema.prompt_text]), # Forward the underlying id to help make data tracking easier. - source_id=self._extract_field(row, "id"), + source_id=row[self.schema.prompt_uid], # Context can be any type you want. context=row, ) - def _identify_input(self): - with open(self.path, newline="") as f: - csvreader = csv.reader(f) - columns = next(csvreader) - is_valid = False - for prompt_input_type, column_names in PROMPT_CSV_INPUT_COLUMNS.items(): - if all(c in columns for c in column_names.values()): - self.prompt_input_type = prompt_input_type - is_valid = True - break - assert ( - is_valid - ), f"Unsupported input file. Required columns are one of: f{PROMPT_CSV_INPUT_COLUMNS.values()}\nActual columns are: f{columns}" - class PromptOutput(metaclass=ABCMeta): def __enter__(self): @@ -110,7 +91,10 @@ def __init__(self, path, suts): def __enter__(self): self.file = open(self.path, "w", newline="") self.writer = csv.writer(self.file, quoting=csv.QUOTE_ALL) - self.writer.writerow(list(PROMPT_CSV_INPUT_COLUMNS["default"].values()) + [s for s in self.suts.keys()]) + # TODO: Standardize SUT columns. 
+ self.writer.writerow( + [DEFAULT_PROMPT_SCHEMA.prompt_uid, DEFAULT_PROMPT_SCHEMA.prompt_text] + [s for s in self.suts.keys()] + ) return self def __exit__(self, exc_type, exc_val, exc_tb): diff --git a/tests/modelgauge_tests/test_annotation_pipeline.py b/tests/modelgauge_tests/test_annotation_pipeline.py index 0add53405..df1477554 100644 --- a/tests/modelgauge_tests/test_annotation_pipeline.py +++ b/tests/modelgauge_tests/test_annotation_pipeline.py @@ -16,6 +16,7 @@ JsonlAnnotatorOutput, ) from modelgauge.annotator_set import AnnotatorSet +from modelgauge.data_schema import DEFAULT_PROMPT_RESPONSE_SCHEMA as PROMPT_RESPONSE_SCHEMA from modelgauge.pipeline import Pipeline from modelgauge.prompt import TextPrompt from modelgauge.prompt_pipeline import ( @@ -45,12 +46,12 @@ def __iter__(self): for row in self.items: time.sleep(next(self.delay)) prompt = TestItem( - prompt=TextPrompt(text=row["Prompt"]), - source_id=row["UID"], + prompt=TextPrompt(text=row[PROMPT_RESPONSE_SCHEMA.prompt_text]), + source_id=row[PROMPT_RESPONSE_SCHEMA.prompt_uid], context=row, ) - response = SUTResponse(text=row["Response"]) - yield SutInteraction(prompt, row["SUT"], response) + response = SUTResponse(text=row[PROMPT_RESPONSE_SCHEMA.sut_response]) + yield SutInteraction(prompt, row[PROMPT_RESPONSE_SCHEMA.sut_uid], response) class FakeAnnotatorOutput(PromptOutput): @@ -81,7 +82,9 @@ def sut_interactions_is_equal(a, b): def test_csv_annotator_input(tmp_path): file_path = tmp_path / "input.csv" - file_path.write_text('UID,Prompt,SUT,Response\n"1","a","s","b"') + file_path.write_text( + f'{PROMPT_RESPONSE_SCHEMA.prompt_uid},{PROMPT_RESPONSE_SCHEMA.prompt_text},{PROMPT_RESPONSE_SCHEMA.sut_uid},{PROMPT_RESPONSE_SCHEMA.sut_response}\n"1","a","s","b"' + ) input = CsvAnnotatorInput(file_path) assert len(input) == 1 @@ -89,24 +92,6 @@ def test_csv_annotator_input(tmp_path): assert sut_interactions_is_equal(item, make_sut_interaction("1", "a", "s", "b")) -@pytest.mark.parametrize( - "header", - [ 
- "Prompt,UID,Extra,Response,Response\n", - "UID,Prompt,SUT\n", - "Extra,Response,Extra\n", - ], -) -def test_csv_annotator_input_invalid_columns(tmp_path, header): - file_path = tmp_path / "input.csv" - file_path.write_text(header) - with pytest.raises( - AssertionError, - match="Invalid input file. Must have columns: UID, Prompt, SUT, Response.", - ): - CsvAnnotatorInput(file_path) - - def test_json_annotator_output(tmp_path): file_path = tmp_path / "output.jsonl" with JsonlAnnotatorOutput(file_path) as output: @@ -117,17 +102,17 @@ def test_json_annotator_output(tmp_path): items: list[dict] = [i for i in reader] assert len(items) == 2 assert items[0] == { - "UID": "1", - "Prompt": "a", - "SUT": "sut1", - "Response": "b", + PROMPT_RESPONSE_SCHEMA.prompt_uid: "1", + PROMPT_RESPONSE_SCHEMA.prompt_text: "a", + PROMPT_RESPONSE_SCHEMA.sut_uid: "sut1", + PROMPT_RESPONSE_SCHEMA.sut_response: "b", "Annotations": {"fake": "x"}, } assert items[1] == { - "UID": "2", - "Prompt": "c", - "SUT": "sut2", - "Response": "d", + PROMPT_RESPONSE_SCHEMA.prompt_uid: "2", + PROMPT_RESPONSE_SCHEMA.prompt_text: "c", + PROMPT_RESPONSE_SCHEMA.sut_uid: "sut2", + PROMPT_RESPONSE_SCHEMA.sut_response: "d", "Annotations": {"fake": "y"}, } @@ -313,8 +298,18 @@ def test_ensemble_worker_computes_ensemble_with_all_annotators(): def test_full_run(annotators): input = FakeAnnotatorInput( [ - {"UID": "1", "Prompt": "a", "Response": "b", "SUT": "s"}, - {"UID": "2", "Prompt": "c", "Response": "d", "SUT": "s"}, + { + PROMPT_RESPONSE_SCHEMA.prompt_uid: "1", + PROMPT_RESPONSE_SCHEMA.prompt_text: "a", + PROMPT_RESPONSE_SCHEMA.sut_response: "b", + PROMPT_RESPONSE_SCHEMA.sut_uid: "s", + }, + { + PROMPT_RESPONSE_SCHEMA.prompt_uid: "2", + PROMPT_RESPONSE_SCHEMA.prompt_text: "c", + PROMPT_RESPONSE_SCHEMA.sut_response: "d", + PROMPT_RESPONSE_SCHEMA.sut_uid: "s", + }, ] ) output = FakeAnnotatorOutput() @@ -346,8 +341,18 @@ def test_full_run(annotators): def test_full_run_with_ensemble(annotators): input = 
FakeAnnotatorInput( [ - {"UID": "1", "Prompt": "a", "Response": "b", "SUT": "s"}, - {"UID": "2", "Prompt": "c", "Response": "d", "SUT": "s"}, + { + PROMPT_RESPONSE_SCHEMA.prompt_uid: "1", + PROMPT_RESPONSE_SCHEMA.prompt_text: "a", + PROMPT_RESPONSE_SCHEMA.sut_response: "b", + PROMPT_RESPONSE_SCHEMA.sut_uid: "s", + }, + { + PROMPT_RESPONSE_SCHEMA.prompt_uid: "2", + PROMPT_RESPONSE_SCHEMA.prompt_text: "c", + PROMPT_RESPONSE_SCHEMA.sut_response: "d", + PROMPT_RESPONSE_SCHEMA.sut_uid: "s", + }, ] ) output = FakeAnnotatorOutput() @@ -378,8 +383,8 @@ def test_full_run_with_ensemble(annotators): def test_prompt_response_annotation_pipeline(annotators, sut_worker_count, annotator_worker_count): input = FakePromptInput( [ - {"UID": "1", "Text": "a"}, - {"UID": "2", "Text": "b"}, + {PROMPT_RESPONSE_SCHEMA.prompt_uid: "1", PROMPT_RESPONSE_SCHEMA.prompt_text: "a"}, + {PROMPT_RESPONSE_SCHEMA.prompt_uid: "2", PROMPT_RESPONSE_SCHEMA.prompt_text: "b"}, ] ) output = FakeAnnotatorOutput() @@ -401,9 +406,14 @@ def test_prompt_response_annotation_pipeline(annotators, sut_worker_count, annot prompt, sut = prompt_sut assert sut_interactions_is_equal( interaction, - make_sut_interaction(prompt["UID"], prompt["Text"], sut, prompt["Text"]), + make_sut_interaction( + prompt[PROMPT_RESPONSE_SCHEMA.prompt_uid], + prompt[PROMPT_RESPONSE_SCHEMA.prompt_text], + sut, + prompt[PROMPT_RESPONSE_SCHEMA.prompt_text], + ), ) - annotation = {"sut_text": prompt["Text"]} + annotation = {"sut_text": prompt[PROMPT_RESPONSE_SCHEMA.prompt_text]} assert output.output[interaction] == { "annotator_pydantic": annotation, "annotator_dict": annotation, diff --git a/tests/modelgauge_tests/test_cli.py b/tests/modelgauge_tests/test_cli.py index 8f9e83ba7..55a8cefd4 100644 --- a/tests/modelgauge_tests/test_cli.py +++ b/tests/modelgauge_tests/test_cli.py @@ -19,6 +19,10 @@ from modelgauge.annotator_set import AnnotatorSet from modelgauge.command_line import _validate_sut_uid, check_secrets, classify_sut_ids, 
validate_uid from modelgauge.config import MissingSecretsFromConfig +from modelgauge.data_schema import ( + DEFAULT_PROMPT_SCHEMA as PROMPT_SCHEMA, + DEFAULT_PROMPT_RESPONSE_SCHEMA as PROMPT_RESPONSE_SCHEMA, +) from modelgauge.secret_values import InjectSecret from modelgauge.sut import SUT, SUTNotFoundException, SUTOptions from modelgauge.sut_decorator import modelgauge_sut @@ -145,7 +149,8 @@ def prompts_file(tmp_path_factory): """Sample file with 2 prompts for testing.""" file = tmp_path_factory.mktemp("data") / "prompts.csv" with open(file, "w") as f: - f.write("UID,Text,Ignored\np1,Say yes,ignored\np2,Refuse,ignored\n") + f.write(f"{PROMPT_SCHEMA.prompt_uid},{PROMPT_SCHEMA.prompt_text}\n") + f.write("p1,Say yes,ignored\np2,Refuse,ignored\n") return file @@ -154,7 +159,10 @@ def prompt_responses_file(tmp_path_factory): """Sample file with 2 prompts + responses from 1 SUT for testing.""" file = tmp_path_factory.mktemp("data") / "prompt-responses.csv" with open(file, "w") as f: - f.write("UID,Prompt,SUT,Response\np1,Say yes,demo_yes_no,Yes\np2,Refuse,demo_yes_no,No\n") + f.write( + f"{PROMPT_RESPONSE_SCHEMA.prompt_uid},{PROMPT_RESPONSE_SCHEMA.prompt_text},{PROMPT_RESPONSE_SCHEMA.sut_uid},{PROMPT_RESPONSE_SCHEMA.sut_response}\n" + ) + f.write("p1,Say yes,demo_yes_no,Yes\np2,Refuse,demo_yes_no,No\n") return file @@ -174,10 +182,10 @@ def test_run_prompts_normal(caplog, tmp_path, prompts_file): reader = csv.DictReader(f) rows = (next(reader), next(reader)) - rows = sorted(rows, key=lambda row: row["UID"]) + rows = sorted(rows, key=lambda row: row[PROMPT_SCHEMA.prompt_uid]) expected = ( - {"UID": "p1", "Text": "Say yes", "demo_yes_no": "Yes"}, - {"UID": "p2", "Text": "Refuse", "demo_yes_no": "No"}, + {PROMPT_SCHEMA.prompt_uid: "p1", PROMPT_SCHEMA.prompt_text: "Say yes", "demo_yes_no": "Yes"}, + {PROMPT_SCHEMA.prompt_uid: "p2", PROMPT_SCHEMA.prompt_text: "Refuse", "demo_yes_no": "No"}, ) assert rows[0] == expected[0] assert rows[1] == expected[1] @@ -243,17 +251,17 @@ 
def test_run_prompts_with_annotators(caplog, tmp_path, prompts_file): output.append(reader.read()) output.append(reader.read()) assert { - "UID": "p1", - "Prompt": "Say yes", - "SUT": "demo_yes_no", - "Response": "Yes", + PROMPT_RESPONSE_SCHEMA.prompt_uid: "p1", + PROMPT_RESPONSE_SCHEMA.prompt_text: "Say yes", + PROMPT_RESPONSE_SCHEMA.sut_uid: "demo_yes_no", + PROMPT_RESPONSE_SCHEMA.sut_response: "Yes", "Annotations": {"demo_annotator": {"badness": 1.0}}, } in output assert { - "UID": "p2", - "Prompt": "Refuse", - "SUT": "demo_yes_no", - "Response": "No", + PROMPT_RESPONSE_SCHEMA.prompt_uid: "p2", + PROMPT_RESPONSE_SCHEMA.prompt_text: "Refuse", + PROMPT_RESPONSE_SCHEMA.sut_uid: "demo_yes_no", + PROMPT_RESPONSE_SCHEMA.sut_response: "No", "Annotations": {"demo_annotator": {"badness": 0.0}}, } in output @@ -326,17 +334,17 @@ def test_run_annotators(caplog, tmp_path, prompt_responses_file): out_path = re.findall(r"\S+\.jsonl", caplog.text)[0] with jsonlines.open(out_path) as reader: assert reader.read() == { - "UID": "p1", - "Prompt": "Say yes", - "SUT": "demo_yes_no", - "Response": "Yes", + PROMPT_RESPONSE_SCHEMA.prompt_uid: "p1", + PROMPT_RESPONSE_SCHEMA.prompt_text: "Say yes", + PROMPT_RESPONSE_SCHEMA.sut_uid: "demo_yes_no", + PROMPT_RESPONSE_SCHEMA.sut_response: "Yes", "Annotations": {"demo_annotator": {"badness": 1.0}}, } assert reader.read() == { - "UID": "p2", - "Prompt": "Refuse", - "SUT": "demo_yes_no", - "Response": "No", + PROMPT_RESPONSE_SCHEMA.prompt_uid: "p2", + PROMPT_RESPONSE_SCHEMA.prompt_text: "Refuse", + PROMPT_RESPONSE_SCHEMA.sut_uid: "demo_yes_no", + PROMPT_RESPONSE_SCHEMA.sut_response: "No", "Annotations": {"demo_annotator": {"badness": 0.0}}, } diff --git a/tests/modelgauge_tests/test_data_schema.py b/tests/modelgauge_tests/test_data_schema.py index 9054c1abf..99ac9d5fc 100644 --- a/tests/modelgauge_tests/test_data_schema.py +++ b/tests/modelgauge_tests/test_data_schema.py @@ -1,6 +1,8 @@ import pytest from modelgauge.data_schema import ( + 
DEFAULT_PROMPT_RESPONSE_SCHEMA, + DEFAULT_PROMPT_SCHEMA, PROMPT_TEXT_COLS, PROMPT_UID_COLS, PromptResponseSchema, @@ -43,6 +45,11 @@ def test_invalid_prompt_schema(): assert set(e.missing_columns) == {PROMPT_UID_COLS, PROMPT_TEXT_COLS} +def test_default_prompt_schema(): + assert DEFAULT_PROMPT_SCHEMA.prompt_uid == "prompt_uid" + assert DEFAULT_PROMPT_SCHEMA.prompt_text == "prompt_text" + + @pytest.mark.parametrize( "header", [ @@ -71,3 +78,10 @@ def test_invalid_prompt_valid_response_schema(): with pytest.raises(SchemaValidationError) as e: schema = PromptResponseSchema(header) assert set(e.missing_columns) == {PROMPT_UID_COLS, PROMPT_TEXT_COLS} + + +def test_default_prompt_response_schema(): + assert DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_uid == "prompt_uid" + assert DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_text == "prompt_text" + assert DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid == "sut_uid" + assert DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response == "sut_response" diff --git a/tests/modelgauge_tests/test_pipeline_runner.py b/tests/modelgauge_tests/test_pipeline_runner.py index 41fa210ea..22aabab84 100644 --- a/tests/modelgauge_tests/test_pipeline_runner.py +++ b/tests/modelgauge_tests/test_pipeline_runner.py @@ -9,6 +9,10 @@ CsvAnnotatorInput, ) from modelgauge.annotator_set import AnnotatorSet +from modelgauge.data_schema import ( + DEFAULT_PROMPT_RESPONSE_SCHEMA as PROMPT_RESPONSE_SCHEMA, + DEFAULT_PROMPT_SCHEMA as PROMPT_SCHEMA, +) from modelgauge.pipeline_runner import ( AnnotatorRunner, EnsembleRunner, @@ -37,7 +41,7 @@ def prompts_file(tmp_path_factory): """Sample file with 3 prompts for testing.""" file = tmp_path_factory.mktemp("data") / "prompts.csv" with open(file, "w") as f: - text = "UID,Text\n" + text = f"{PROMPT_SCHEMA.prompt_uid},{PROMPT_SCHEMA.prompt_text}\n" for i in range(NUM_PROMPTS): text += f"p{i},Prompt {i}\n" f.write(text) @@ -341,7 +345,7 @@ def prompt_responses_file(self, tmp_path_factory): """Sample file with 2 prompts + responses from 2 SUTs for 
testing.""" file = tmp_path_factory.mktemp("data") / "prompt-responses.csv" with open(file, "w") as f: - text = "UID,Prompt,SUT,Response\n" + text = f"{PROMPT_RESPONSE_SCHEMA.prompt_uid},{PROMPT_RESPONSE_SCHEMA.prompt_text},{PROMPT_RESPONSE_SCHEMA.sut_uid},{PROMPT_RESPONSE_SCHEMA.sut_response}\n" for i in range(NUM_PROMPTS): text += f"p{i},Prompt {i},sut1,Response {i}\n" text += f"p{i},Prompt {i},sut2,Response {i}\n" diff --git a/tests/modelgauge_tests/test_prompt_pipeline.py b/tests/modelgauge_tests/test_prompt_pipeline.py index 200f0b14a..1799111ec 100644 --- a/tests/modelgauge_tests/test_prompt_pipeline.py +++ b/tests/modelgauge_tests/test_prompt_pipeline.py @@ -7,6 +7,7 @@ import pytest +from modelgauge.data_schema import DEFAULT_PROMPT_SCHEMA as PROMPT_SCHEMA, SchemaValidationError from modelgauge.pipeline import Pipeline, PipelineSegment from modelgauge.prompt import TextPrompt from modelgauge.prompt_pipeline import ( @@ -51,8 +52,8 @@ def __iter__(self): for row in self.items: time.sleep(next(self.delay)) yield TestItem( - prompt=TextPrompt(text=row["Text"]), - source_id=row["UID"], + prompt=TextPrompt(text=row[PROMPT_SCHEMA.prompt_text]), + source_id=row[PROMPT_SCHEMA.prompt_uid], context=row, ) @@ -83,7 +84,7 @@ def suts(): def test_csv_prompt_input(tmp_path): file_path = tmp_path / "input.csv" - file_path.write_text('UID,Text\n"1","a"') + file_path.write_text(f'{PROMPT_SCHEMA.prompt_uid},{PROMPT_SCHEMA.prompt_text}\n"1","a"') input = CsvPromptInput(file_path) assert len(input) == 1 @@ -97,30 +98,10 @@ def test_csv_prompt_input(tmp_path): def test_csv_prompt_input_invalid_columns(tmp_path, header): file_path = tmp_path / "input.csv" file_path.write_text(header) - with pytest.raises(AssertionError, match="Unsupported input file. 
Required columns are"): + with pytest.raises(SchemaValidationError): CsvPromptInput(file_path) -@pytest.mark.parametrize( - "header", - [ - "UID,Text\n", - "release_prompt_id,prompt_text\n", - "prompt_uid,prompt_text\n", - "UID,Text,other,fields\n", - "release_prompt_id,prompt_text,other,fields\n", - "prompt_uid,prompt_text,other,fields\n", - "Text,spacer,UID\n", - "release_prompt_id,spacer,prompt_text\n", - "prompt_uid,spacer,prompt_text\n", - ], -) -def test_csv_prompt_input_accepts_multiple_formats(tmp_path, header): - file_path = tmp_path / "input.csv" - file_path.write_text(header) - _ = CsvPromptInput(file_path) - - def test_csv_prompt_output(tmp_path, suts): file_path = tmp_path / "output.csv" @@ -134,8 +115,8 @@ def test_csv_prompt_output(tmp_path, suts): # noinspection PyTypeChecker items: list[dict] = [i for i in (DictReader(f))] assert len(items) == 1 - assert items[0]["UID"] == "1" - assert items[0]["Text"] == "a" + assert items[0][PROMPT_SCHEMA.prompt_uid] == "1" + assert items[0][PROMPT_SCHEMA.prompt_text] == "a" assert items[0]["fake1"] == "a1" assert items[0]["fake2"] == "a2" @@ -207,8 +188,8 @@ def test_prompt_sut_worker_retries_until_success(suts): def test_full_run(suts): input = FakePromptInput( [ - {"UID": "1", "Text": "a"}, - {"UID": "2", "Text": "b"}, + {PROMPT_SCHEMA.prompt_uid: "1", PROMPT_SCHEMA.prompt_text: "a"}, + {PROMPT_SCHEMA.prompt_uid: "2", PROMPT_SCHEMA.prompt_text: "b"}, ] ) output = FakePromptOutput() @@ -224,7 +205,7 @@ def test_full_run(suts): p.run() assert len(output.output) == len(input.items) - assert sorted([r["item"].source_id for r in output.output]) == [i["UID"] for i in input.items] + assert sorted([r["item"].source_id for r in output.output]) == [i[PROMPT_SCHEMA.prompt_uid] for i in input.items] row1 = output.output[0] assert "fake1" in row1["results"] assert "fake2" in row1["results"] @@ -245,7 +226,7 @@ def test_concurrency_with_delays(suts, worker_count): "fake2": FakeSUTWithDelay("fake2", delay=sut_delays), } 
input = FakePromptInput( - [{"UID": str(i), "Text": "text" + str(i)} for i in range(prompt_count)], + [{PROMPT_SCHEMA.prompt_uid: str(i), PROMPT_SCHEMA.prompt_text: "text" + str(i)} for i in range(prompt_count)], delay=prompt_delays, ) output = FakePromptOutput() @@ -268,8 +249,8 @@ def test_concurrency_with_delays(suts, worker_count): def test_progress(suts): input = FakePromptInput( [ - {"UID": "1", "Text": "a"}, - {"UID": "2", "Text": "b"}, + {PROMPT_SCHEMA.prompt_uid: "1", PROMPT_SCHEMA.prompt_text: "a"}, + {PROMPT_SCHEMA.prompt_uid: "2", PROMPT_SCHEMA.prompt_text: "b"}, ] ) output = FakePromptOutput() From f4f2dcc3b056d84f12c7b3e83e87bbfe662b0ec9 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Wed, 18 Jun 2025 17:18:02 -0700 Subject: [PATCH 03/15] Prompt runner outputs each sut response in different row --- src/modelgauge/prompt_pipeline.py | 20 +++++++++---------- tests/modelgauge_tests/test_cli.py | 16 ++++++++++++--- .../modelgauge_tests/test_prompt_pipeline.py | 20 +++++++++++++------ 3 files changed, 37 insertions(+), 19 deletions(-) diff --git a/src/modelgauge/prompt_pipeline.py b/src/modelgauge/prompt_pipeline.py index 3ec1bfed8..2c1a191ff 100644 --- a/src/modelgauge/prompt_pipeline.py +++ b/src/modelgauge/prompt_pipeline.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from typing import Iterable, Optional -from modelgauge.data_schema import DEFAULT_PROMPT_SCHEMA, PromptSchema +from modelgauge.data_schema import DEFAULT_PROMPT_RESPONSE_SCHEMA, DEFAULT_PROMPT_SCHEMA, PromptSchema from modelgauge.pipeline import CachingPipe, Pipe, Sink, Source from modelgauge.prompt import TextPrompt from modelgauge.single_turn_prompt_response import TestItem @@ -79,6 +79,10 @@ def write(self, item, results): class CsvPromptOutput(PromptOutput): + """Outputs a CSV file where each row represents one SUT's response to a prompt.""" + + schema = DEFAULT_PROMPT_RESPONSE_SCHEMA + def __init__(self, path, suts): super().__init__() assert path.suffix.lower() == ".csv", 
f"Invalid output file {path}. Must be of type CSV." @@ -91,9 +95,8 @@ def __init__(self, path, suts): def __enter__(self): self.file = open(self.path, "w", newline="") self.writer = csv.writer(self.file, quoting=csv.QUOTE_ALL) - # TODO: Standardize SUT columns. self.writer.writerow( - [DEFAULT_PROMPT_SCHEMA.prompt_uid, DEFAULT_PROMPT_SCHEMA.prompt_text] + [s for s in self.suts.keys()] + [self.schema.prompt_uid, self.schema.prompt_text, self.schema.sut_uid, self.schema.sut_response] ) return self @@ -101,13 +104,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.file.close() def write(self, item: TestItem, results): - row = [item.source_id, item.prompt.text] # type: ignore - for k in self.suts: - if k in results: - row.append(results[k]) - else: - row.append("") - self.writer.writerow(row) + base_row = [item.source_id, item.prompt.text] # type: ignore + for sut in self.suts: + if sut in results: + self.writer.writerow(base_row + [sut, results[sut]]) def launder_the_type_problem(self, item) -> str: return item.prompt.text diff --git a/tests/modelgauge_tests/test_cli.py b/tests/modelgauge_tests/test_cli.py index 55a8cefd4..3e1b560bd 100644 --- a/tests/modelgauge_tests/test_cli.py +++ b/tests/modelgauge_tests/test_cli.py @@ -182,10 +182,20 @@ def test_run_prompts_normal(caplog, tmp_path, prompts_file): reader = csv.DictReader(f) rows = (next(reader), next(reader)) - rows = sorted(rows, key=lambda row: row[PROMPT_SCHEMA.prompt_uid]) + rows = sorted(rows, key=lambda row: row[PROMPT_RESPONSE_SCHEMA.prompt_uid]) expected = ( - {PROMPT_SCHEMA.prompt_uid: "p1", PROMPT_SCHEMA.prompt_text: "Say yes", "demo_yes_no": "Yes"}, - {PROMPT_SCHEMA.prompt_uid: "p2", PROMPT_SCHEMA.prompt_text: "Refuse", "demo_yes_no": "No"}, + { + PROMPT_RESPONSE_SCHEMA.prompt_uid: "p1", + PROMPT_RESPONSE_SCHEMA.prompt_text: "Say yes", + PROMPT_RESPONSE_SCHEMA.sut_uid: "demo_yes_no", + PROMPT_RESPONSE_SCHEMA.sut_response: "Yes", + }, + { + PROMPT_RESPONSE_SCHEMA.prompt_uid: "p2", + 
PROMPT_RESPONSE_SCHEMA.prompt_text: "Refuse", + PROMPT_RESPONSE_SCHEMA.sut_uid: "demo_yes_no", + PROMPT_RESPONSE_SCHEMA.sut_response: "No", + }, ) assert rows[0] == expected[0] assert rows[1] == expected[1] diff --git a/tests/modelgauge_tests/test_prompt_pipeline.py b/tests/modelgauge_tests/test_prompt_pipeline.py index 1799111ec..5b1c7bdc5 100644 --- a/tests/modelgauge_tests/test_prompt_pipeline.py +++ b/tests/modelgauge_tests/test_prompt_pipeline.py @@ -7,7 +7,11 @@ import pytest -from modelgauge.data_schema import DEFAULT_PROMPT_SCHEMA as PROMPT_SCHEMA, SchemaValidationError +from modelgauge.data_schema import ( + DEFAULT_PROMPT_RESPONSE_SCHEMA as PROMPT_RESPONSE_SCHEMA, + DEFAULT_PROMPT_SCHEMA as PROMPT_SCHEMA, + SchemaValidationError, +) from modelgauge.pipeline import Pipeline, PipelineSegment from modelgauge.prompt import TextPrompt from modelgauge.prompt_pipeline import ( @@ -114,11 +118,15 @@ def test_csv_prompt_output(tmp_path, suts): with open(file_path, "r", newline="") as f: # noinspection PyTypeChecker items: list[dict] = [i for i in (DictReader(f))] - assert len(items) == 1 - assert items[0][PROMPT_SCHEMA.prompt_uid] == "1" - assert items[0][PROMPT_SCHEMA.prompt_text] == "a" - assert items[0]["fake1"] == "a1" - assert items[0]["fake2"] == "a2" + assert len(items) == 2 + assert items[0][PROMPT_RESPONSE_SCHEMA.prompt_uid] == "1" + assert items[0][PROMPT_RESPONSE_SCHEMA.prompt_text] == "a" + assert items[0][PROMPT_RESPONSE_SCHEMA.sut_uid] == "fake1" + assert items[0][PROMPT_RESPONSE_SCHEMA.sut_response] == "a1" + assert items[1][PROMPT_RESPONSE_SCHEMA.prompt_uid] == "1" + assert items[1][PROMPT_RESPONSE_SCHEMA.prompt_text] == "a" + assert items[1][PROMPT_RESPONSE_SCHEMA.sut_uid] == "fake2" + assert items[1][PROMPT_RESPONSE_SCHEMA.sut_response] == "a2" @pytest.mark.parametrize("output_fname", ["output.jsonl", "output"]) From a0a5bfa9c6e82551a96d1fc7f66a0de895629442 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Mon, 30 Jun 2025 13:05:47 -0700 
Subject: [PATCH 04/15] annotation data schema --- src/modelgauge/data_schema.py | 49 ++++++++++++++++++- tests/modelgauge_tests/test_data_schema.py | 55 ++++++++++++++++++++++ 2 files changed, 102 insertions(+), 2 deletions(-) diff --git a/src/modelgauge/data_schema.py b/src/modelgauge/data_schema.py index c02d16e52..33c1085d4 100644 --- a/src/modelgauge/data_schema.py +++ b/src/modelgauge/data_schema.py @@ -3,6 +3,8 @@ PROMPT_TEXT_COLS = ["prompt_text"] SUT_UID_COLS = ["sut_uid", "sut"] SUT_RESPONSE_COLS = ["sut_response", "response_text", "response"] +ANNOTATOR_UID_COLS = ["annotator_uid"] +ANNOTATION_COLS = ["is_safe"] class SchemaValidationError(ValueError): @@ -26,7 +28,7 @@ def __str__(self): class PromptSchema: - """A case-insensitive schema for a prompts file. + """A case-insensitive schema for a prompts file that is used as input to get SUT responses. Attributes: prompt_uid: The column name for the prompt uid. @@ -34,6 +36,7 @@ class PromptSchema: """ def __init__(self, header: list[str]): + self.header = header self.prompt_uid = self._find_column(header, PROMPT_UID_COLS) self.prompt_text = self._find_column(header, PROMPT_TEXT_COLS) self._validate() @@ -58,7 +61,7 @@ def _validate(self): class PromptResponseSchema(PromptSchema): - """A schema for a prompt + response file that is used as annotation input. + """A schema for a prompt + response file that is used as prompt-response output or annotation input. Attributes: prompt_uid: The column name for the prompt uid. (same as PromptSchema) prompt_text: The column name for the prompt text. (same as PromptSchema) @@ -87,8 +90,50 @@ def _validate(self): raise SchemaValidationError(missing) +class AnnotationSchema(PromptResponseSchema): + """A schema for a prompt + response + annotation file that is used as annotation output. + Attributes: + prompt_uid: The column name for the prompt uid. (same as PromptSchema) + prompt_text: The column name for the prompt text. 
(same as PromptSchema) + sut_uid: The column name for the SUT uid. (same as PromptResponseSchema) + sut_response: The column name for the SUT response. (same as PromptResponseSchema) + annotator_uid: The column name for the annotator uid. + annotation: The column name for the text annotation. + """ + + def __init__(self, header: list[str]): + self.annotator_uid = self._find_column(header, ANNOTATOR_UID_COLS) + self.annotation = self._find_column(header, ANNOTATION_COLS) + super().__init__(header) # Iniitalize the prompt schema columns and then validate. + + def _validate(self): + missing = [] + # Validate that the prompt schema is valid + try: + super()._validate() + except SchemaValidationError as e: + missing.extend(e.missing_columns) + # Validate that the SUT uid and response columns are present + if self.annotator_uid is None: + missing.append(ANNOTATOR_UID_COLS) + if self.annotation is None: + missing.append(ANNOTATION_COLS) + if missing: + raise SchemaValidationError(missing) + + # Schemas with preferred names. 
DEFAULT_PROMPT_SCHEMA = PromptSchema([PROMPT_UID_COLS[0], PROMPT_TEXT_COLS[0]]) DEFAULT_PROMPT_RESPONSE_SCHEMA = PromptResponseSchema( [PROMPT_UID_COLS[0], PROMPT_TEXT_COLS[0], SUT_UID_COLS[0], SUT_RESPONSE_COLS[0]] ) +DEFAULT_ANNOTATION_SCHEMA = AnnotationSchema( + [ + PROMPT_UID_COLS[0], + PROMPT_TEXT_COLS[0], + SUT_UID_COLS[0], + SUT_RESPONSE_COLS[0], + ANNOTATOR_UID_COLS[0], + ANNOTATION_COLS[0], + ] +) diff --git a/tests/modelgauge_tests/test_data_schema.py b/tests/modelgauge_tests/test_data_schema.py index 99ac9d5fc..7b3d4eac3 100644 --- a/tests/modelgauge_tests/test_data_schema.py +++ b/tests/modelgauge_tests/test_data_schema.py @@ -1,10 +1,16 @@ import pytest from modelgauge.data_schema import ( + ANNOTATOR_UID_COLS, + ANNOTATION_COLS, + DEFAULT_ANNOTATION_SCHEMA, DEFAULT_PROMPT_RESPONSE_SCHEMA, DEFAULT_PROMPT_SCHEMA, PROMPT_TEXT_COLS, PROMPT_UID_COLS, + SUT_UID_COLS, + SUT_RESPONSE_COLS, + AnnotationSchema, PromptResponseSchema, PromptSchema, SchemaValidationError, @@ -85,3 +91,52 @@ def test_default_prompt_response_schema(): assert DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_text == "prompt_text" assert DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid == "sut_uid" assert DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response == "sut_response" + + +@pytest.mark.parametrize( + "header", + [ + # Preferred names + ["prompt_uid", "prompt_text", "sut_uid", "sut_response", "annotator_uid", "is_safe"], + # Case-insensitive + ["prompt_UID", "Prompt_Text", "SUT_UID", "SUT_Response", "Annotator_UID", "Is_Safe"], + # Extra columns are allowed + ["prompt_uid", "prompt_text", "sut_uid", "sut_response", "annotator_uid", "is_safe", "extra_col"], + ], +) +def test_valid_annotation_schema(header): + schema = AnnotationSchema(header) + assert schema.prompt_uid == header[0] + assert schema.prompt_text == header[1] + assert schema.sut_uid == header[2] + assert schema.sut_response == header[3] + assert schema.annotator_uid == header[4] + assert schema.annotation == header[5] + + +def 
test_valid_prompt_response_invalid_annotation_schema(): + header = ["prompt_uid", "prompt_text", "sut_uid", "sut_response", "random_column", "random_column_2"] + with pytest.raises(SchemaValidationError) as e: + schema = AnnotationSchema(header) + assert set(e.missing_columns) == {ANNOTATOR_UID_COLS, ANNOTATION_COLS} + + +def test_invalid_prompt_response_valid_annotation_schema(): + header = ["random_1", "random_2", "random_3", "random_4", "annotator_uid", "is_safe"] + with pytest.raises(SchemaValidationError) as e: + schema = AnnotationSchema(header) + assert set(e.missing_columns) == { + PROMPT_UID_COLS, + PROMPT_TEXT_COLS, + SUT_UID_COLS, + SUT_RESPONSE_COLS, + } + + +def test_default_annotation_schema(): + assert DEFAULT_ANNOTATION_SCHEMA.prompt_uid == "prompt_uid" + assert DEFAULT_ANNOTATION_SCHEMA.prompt_text == "prompt_text" + assert DEFAULT_ANNOTATION_SCHEMA.sut_uid == "sut_uid" + assert DEFAULT_ANNOTATION_SCHEMA.sut_response == "sut_response" + assert DEFAULT_ANNOTATION_SCHEMA.annotator_uid == "annotator_uid" + assert DEFAULT_ANNOTATION_SCHEMA.annotation == "is_safe" From faf6fa6c77d65cc71eb3373e76e887a6bf0b9036 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Tue, 1 Jul 2025 12:44:35 -0700 Subject: [PATCH 05/15] New dataset objects --- src/modelgauge/annotation_pipeline.py | 16 +- src/modelgauge/dataset.py | 212 ++++++++++ src/modelgauge/pipeline_runner.py | 1 + src/modelgauge/prompt_pipeline.py | 24 +- src/modelgauge/single_turn_prompt_response.py | 11 + .../test_annotation_pipeline.py | 9 +- tests/modelgauge_tests/test_dataset.py | 392 ++++++++++++++++++ .../modelgauge_tests/test_prompt_pipeline.py | 11 +- 8 files changed, 639 insertions(+), 37 deletions(-) create mode 100644 src/modelgauge/dataset.py create mode 100644 tests/modelgauge_tests/test_dataset.py diff --git a/src/modelgauge/annotation_pipeline.py b/src/modelgauge/annotation_pipeline.py index 8750914f1..bc3bbd166 100644 --- a/src/modelgauge/annotation_pipeline.py +++ 
b/src/modelgauge/annotation_pipeline.py @@ -13,8 +13,8 @@ from modelgauge.data_schema import DEFAULT_PROMPT_RESPONSE_SCHEMA, PromptResponseSchema from modelgauge.pipeline import CachingPipe, Pipe, Sink, Source from modelgauge.prompt import TextPrompt -from modelgauge.prompt_pipeline import PromptOutput, SutInteraction -from modelgauge.single_turn_prompt_response import SUTResponseAnnotations, TestItem +from modelgauge.prompt_pipeline import PromptOutput +from modelgauge.single_turn_prompt_response import SUTResponseAnnotations, SUTInteraction, TestItem from modelgauge.sut import PromptResponseSUT, SUTResponse logger = logging.getLogger(__name__) @@ -22,7 +22,7 @@ class AnnotatorInput(metaclass=ABCMeta): @abstractmethod - def __iter__(self) -> Iterable[SutInteraction]: + def __iter__(self) -> Iterable[SUTInteraction]: pass def __len__(self): @@ -43,7 +43,7 @@ def _header(self) -> list[str]: csvreader = csv.reader(f) return next(csvreader) - def __iter__(self) -> Iterable[SutInteraction]: + def __iter__(self) -> Iterable[SUTInteraction]: with open(self.path, newline="") as f: csvreader = csv.DictReader(f) for row in csvreader: @@ -55,7 +55,7 @@ def __iter__(self) -> Iterable[SutInteraction]: context=row, ) response = SUTResponse(text=row[self.schema.sut_response]) - yield SutInteraction(prompt, row[self.schema.sut_uid], response) + yield SUTInteraction(prompt, row[self.schema.sut_uid], response) class JsonlAnnotatorOutput(PromptOutput): @@ -76,7 +76,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.writer.close() self.file.close() - def write(self, item: SutInteraction, results): + def write(self, item: SUTInteraction, results): if not isinstance(item.prompt.prompt, TextPrompt): raise Exception(f"Error handling {item}. Can only handle TextPrompts.") # TODO: Standardize annotation schema. 
@@ -104,7 +104,7 @@ def __init__(self, annotators: dict[str, Annotator]): super().__init__() self.annotators = annotators - def handle_item(self, item: SutInteraction): + def handle_item(self, item: SUTInteraction): for annotator_uid in self.annotators: self.downstream_put((item, annotator_uid)) @@ -172,7 +172,7 @@ def handle_item(self, item): class AnnotatorSink(Sink): - unfinished: defaultdict[SutInteraction, dict[str, str]] + unfinished: defaultdict[SUTInteraction, dict[str, str]] def __init__(self, annotators: dict[str, Annotator], writer: JsonlAnnotatorOutput, ensemble: bool = False): super().__init__() diff --git a/src/modelgauge/dataset.py b/src/modelgauge/dataset.py new file mode 100644 index 000000000..02ad4416a --- /dev/null +++ b/src/modelgauge/dataset.py @@ -0,0 +1,212 @@ +import csv +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Dict, Iterable, Optional, Union, Any, Sequence + +from modelgauge.data_schema import ( + DEFAULT_ANNOTATION_SCHEMA, + DEFAULT_PROMPT_RESPONSE_SCHEMA, + DEFAULT_PROMPT_SCHEMA, + AnnotationSchema, + PromptResponseSchema, + PromptSchema, +) +from modelgauge.prompt import TextPrompt +from modelgauge.single_turn_prompt_response import SutInteraction, TestItem +from modelgauge.sut import SUTResponse + + +class BaseDataset(ABC): + """This class provides common functionality for CSV file handling and context management.""" + + def __init__(self, path: Union[str, Path], mode: str): + """Args: + path: Path to the dataset file + mode: Mode to open the file in ('r' for read, 'w' for write) + """ + self.path = Path(path) + self.mode = mode + assert mode in ["r", "w"], f"Invalid dataset mode {mode}. Must be 'r' or 'w'." 
+ if self.mode == "r" and not self.path.exists(): + raise FileNotFoundError(f"File {self.path} does not exist.") + if self.mode == "w" and self.path.exists(): + raise FileExistsError(f"File {self.path} already exists.") + + self.file = None + self.writer = None + self.reader = None + self.schema = None + self._init_schema() # Initialized by subclass. + + def __enter__(self): + """Context manager entry. Opens the file and sets the reader or writer.""" + if self.file is not None: + raise RuntimeError("Cannot enter context manager twice before exiting.") + if self.mode == "w" and not self.path.exists(): + # New file, need to write header. + self.file = open(self.path, mode=self.mode, newline="") + self.writer = csv.writer(self.file, quoting=csv.QUOTE_MINIMAL) + self.writer.writerow(self.header_columns()) + elif self.mode == "w": + # Append to existing file. + self.file = open(self.path, mode="a", newline="") + self.writer = csv.writer(self.file, quoting=csv.QUOTE_MINIMAL) + elif self.mode == "r": + self.file = open(self.path, mode=self.mode, newline="") + self.reader = csv.DictReader(self.file) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit. Closes the file and unsets the reader and writer.""" + if self.file: + self.file.close() + self.file = None + self.reader = None + self.writer = None + + def __iter__(self): + """Base iterator implementation that ensures proper context management. + Will enter the context if not already open. 
+ """ + if self.mode != "r": + raise RuntimeError("Can only iterate over dataset in read mode.") + + # If we're not already in a context, create one for this iteration + if self.file is None: + with self: + for row in self.reader: + yield self.row_to_item(row) + else: + # We're already in a context, just yield + for row in self.reader: + yield self.row_to_item(row) + + def __len__(self) -> int: + if self.mode != "r": + raise NotImplementedError("Length not supported in write mode") + count = 0 + with open(self.path, newline="") as f: + csvreader = csv.reader(f) + next(csvreader) # Skip header row + for row in csvreader: + count += 1 + return count + + @abstractmethod + def _init_schema(self): + """Initialize dataset schema `self.schema`. To be implemented by subclasses.""" + pass + + def _read_header(self) -> list[str]: + """Read the header row from a CSV file.""" + if self.mode != "r": + raise RuntimeError("Can only read header in read mode.") + if self.file is None: + with self: + header = self.reader.fieldnames + else: + header = self.reader.fieldnames + return header + + def header_columns(self) -> Sequence[str]: + return self.schema.header + + def write(self, item: Any): + """Write an item to the csv file.""" + if self.mode != "w": + raise RuntimeError("Cannot write to dataset in read mode") + if not self.writer: + raise RuntimeError("Must be in a context to write.") + self.writer.writerow(self.item_to_row(item)) + + def row_to_item(self, row: dict): + """Transform a single dict-row from the csv file into a dataset object.""" + raise NotImplementedError("Subclasses that enable reading must implement this method.") + + def item_to_row(self, item: Any) -> list[str]: + """Transform a dataset object into a list of strings that can be written to a csv file.""" + raise NotImplementedError("Subclasses that enable writing must implement this method.") + + +class PromptDataset(BaseDataset): + """Dataset for reading prompts as TestItems from a CSV file. 
Read only.""" + + def __init__(self, path: Union[str, Path]): + super().__init__(path, "r") + + def _init_schema(self): + self.schema = PromptSchema(self._read_header()) + + def row_to_item(self, row: dict) -> TestItem: + """Convert a single prompt row to a TestItem.""" + return TestItem( + prompt=TextPrompt(text=row[self.schema.prompt_text]), + source_id=row[self.schema.prompt_uid], + context=row, + ) + + +class PromptResponseDataset(BaseDataset): + """Dataset for prompt-response CSV data. Read or write.""" + + def _init_schema(self): + if self.mode == "r": + self.schema = PromptResponseSchema(self._read_header()) + else: + self.schema = DEFAULT_PROMPT_RESPONSE_SCHEMA + + def row_to_item(self, row: dict) -> SutInteraction: + prompt = TestItem( + prompt=TextPrompt(text=row[self.schema.prompt_text]), + source_id=row[self.schema.prompt_uid], + context=row, + ) + response = SUTResponse(text=row[self.schema.sut_response]) + return SutInteraction(prompt, row[self.schema.sut_uid], response) + + def item_to_row(self, item: SutInteraction) -> list[str]: + if not isinstance(item.prompt.prompt, TextPrompt): + raise ValueError(f"Error handling {item}. Can only handle TextPrompts.") + + return [ + item.prompt.source_id, + item.prompt.prompt.text, + item.sut_uid, + item.response.text, + ] + + +class AnnotationDataset(BaseDataset): + """Dataset for annotated prompt-response CSV data. 
Read or write.""" + + def _init_schema(self): + if self.mode == "r": + self.schema = AnnotationSchema(self._read_header()) + else: + self.schema = DEFAULT_ANNOTATION_SCHEMA + + # TODO: New annotation object + def row_to_item(self, row: dict) -> tuple[SutInteraction, Optional[Dict[str, Any]]]: + prompt = TestItem( + prompt=TextPrompt(text=row[self.schema.prompt_text]), + source_id=row[self.schema.prompt_uid], + context=row, + ) + response = SUTResponse(text=row[self.schema.sut_response]) + interaction = SutInteraction(prompt, row[self.schema.sut_uid], response) + + # Extract annotations if present + annotations = row.get(self.schema.annotation) + return interaction, annotations + + def item_to_row(self, item: SutInteraction, annotations: Optional[Dict[str, Any]] = None) -> list[str]: + if not isinstance(item.prompt.prompt, TextPrompt): + raise ValueError(f"Error handling {item}. Can only handle TextPrompts.") + return [ + item.prompt.source_id, + item.prompt.prompt.text, + item.sut_uid, + item.response.text, + annotations.is_safe, + ] diff --git a/src/modelgauge/pipeline_runner.py b/src/modelgauge/pipeline_runner.py index 01e76e409..93a707482 100644 --- a/src/modelgauge/pipeline_runner.py +++ b/src/modelgauge/pipeline_runner.py @@ -12,6 +12,7 @@ EnsembleVoter, JsonlAnnotatorOutput, ) +from modelgauge.dataset import PromptDataset from modelgauge.pipeline import Pipeline from modelgauge.prompt_pipeline import ( PromptSource, diff --git a/src/modelgauge/prompt_pipeline.py b/src/modelgauge/prompt_pipeline.py index 2c1a191ff..c64b4ff0f 100644 --- a/src/modelgauge/prompt_pipeline.py +++ b/src/modelgauge/prompt_pipeline.py @@ -3,28 +3,18 @@ import time from abc import ABCMeta, abstractmethod from collections import defaultdict -from dataclasses import dataclass from typing import Iterable, Optional +from modelgauge.dataset import PromptDataset, PromptResponseDataset from modelgauge.data_schema import DEFAULT_PROMPT_RESPONSE_SCHEMA, DEFAULT_PROMPT_SCHEMA, PromptSchema from 
modelgauge.pipeline import CachingPipe, Pipe, Sink, Source from modelgauge.prompt import TextPrompt -from modelgauge.single_turn_prompt_response import TestItem +from modelgauge.single_turn_prompt_response import SUTInteraction, TestItem from modelgauge.sut import PromptResponseSUT, SUT, SUTOptions, SUTResponse logger = logging.getLogger(__name__) -@dataclass -class SutInteraction: - prompt: TestItem - sut_uid: str - response: SUTResponse - - def __hash__(self): - return hash(self.prompt.source_id + self.sut_uid) - - class PromptInput(metaclass=ABCMeta): """ Your subclass should implement __iter__ such that it yields TestItem objects. @@ -42,6 +32,7 @@ def __len__(self): return count +# TODO: Delete: replace with PromptDataset. class CsvPromptInput(PromptInput): def __init__(self, path): super().__init__() @@ -109,12 +100,9 @@ def write(self, item: TestItem, results): if sut in results: self.writer.writerow(base_row + [sut, results[sut]]) - def launder_the_type_problem(self, item) -> str: - return item.prompt.text - class PromptSource(Source): - def __init__(self, input: PromptInput): + def __init__(self, input: PromptDataset): super().__init__() self.input = input @@ -151,7 +139,7 @@ def handle_uncached_item(self, item): prompt_item: TestItem prompt_item, sut_uid = item response = self.call_sut(prompt_item.prompt, self.suts[sut_uid]) - return SutInteraction(prompt_item, sut_uid, response) + return SUTInteraction(prompt_item, sut_uid, response) def call_sut(self, prompt_text: TextPrompt, sut: PromptResponseSUT) -> SUTResponse: request = sut.translate_text_prompt(prompt_text, self.sut_options) @@ -182,7 +170,7 @@ def run(self): with self.writer: super().run() - def handle_item(self, item: SutInteraction): + def handle_item(self, item: SUTInteraction): self.unfinished[item.prompt][item.sut_uid] = item.response.text if len(self.unfinished[item.prompt]) == len(self.suts): self.writer.write(item.prompt, self.unfinished[item.prompt]) diff --git 
a/src/modelgauge/single_turn_prompt_response.py b/src/modelgauge/single_turn_prompt_response.py index 6554618e7..e679fe2ff 100644 --- a/src/modelgauge/single_turn_prompt_response.py +++ b/src/modelgauge/single_turn_prompt_response.py @@ -1,4 +1,5 @@ from typing import Dict, Mapping, Optional, Type, TypeVar +from dataclasses import dataclass from pydantic import BaseModel, Field @@ -72,3 +73,13 @@ class MeasuredTestItem(BaseModel): test_item: TestItem measurements: Dict[str, float] + + +@dataclass +class SUTInteraction: + prompt: TestItem + sut_uid: str + response: SUTResponse + + def __hash__(self): + return hash(self.prompt.source_id + self.sut_uid) diff --git a/tests/modelgauge_tests/test_annotation_pipeline.py b/tests/modelgauge_tests/test_annotation_pipeline.py index df1477554..ebf2226fc 100644 --- a/tests/modelgauge_tests/test_annotation_pipeline.py +++ b/tests/modelgauge_tests/test_annotation_pipeline.py @@ -5,7 +5,6 @@ from unittest.mock import MagicMock from modelgauge.annotation_pipeline import ( - SutInteraction, AnnotatorInput, AnnotatorSource, AnnotatorAssigner, @@ -25,7 +24,7 @@ PromptSutAssigner, PromptSutWorkers, ) -from modelgauge.single_turn_prompt_response import TestItem +from modelgauge.single_turn_prompt_response import SUTInteraction, TestItem from modelgauge.sut import SUTResponse from modelgauge_tests.fake_annotator import ( FakeAnnotation, @@ -51,7 +50,7 @@ def __iter__(self): context=row, ) response = SUTResponse(text=row[PROMPT_RESPONSE_SCHEMA.sut_response]) - yield SutInteraction(prompt, row[PROMPT_RESPONSE_SCHEMA.sut_uid], response) + yield SUTInteraction(prompt, row[PROMPT_RESPONSE_SCHEMA.sut_uid], response) class FakeAnnotatorOutput(PromptOutput): @@ -63,7 +62,7 @@ def write(self, item, annotations): def make_sut_interaction(source_id, prompt, sut_uid, response): - return SutInteraction( + return SUTInteraction( TestItem(source_id=source_id, prompt=TextPrompt(text=prompt)), sut_uid, SUTResponse(text=response), @@ -88,7 +87,7 @@ def 
test_csv_annotator_input(tmp_path): input = CsvAnnotatorInput(file_path) assert len(input) == 1 - item: SutInteraction = next(iter(input)) + item: SUTInteraction = next(iter(input)) assert sut_interactions_is_equal(item, make_sut_interaction("1", "a", "s", "b")) diff --git a/tests/modelgauge_tests/test_dataset.py b/tests/modelgauge_tests/test_dataset.py new file mode 100644 index 000000000..bfcd615ea --- /dev/null +++ b/tests/modelgauge_tests/test_dataset.py @@ -0,0 +1,392 @@ +import pytest +from pathlib import Path +from typing import Iterable + +from modelgauge.data_schema import ( + DEFAULT_PROMPT_RESPONSE_SCHEMA, + DEFAULT_PROMPT_SCHEMA, + SchemaValidationError, +) +from modelgauge.dataset import ( + AnnotationDataset, + BaseDataset, + PromptDataset, + PromptResponseDataset, + SutInteraction, +) +from modelgauge.prompt import TextPrompt +from modelgauge.single_turn_prompt_response import TestItem +from modelgauge.sut import SUTResponse + + +class TestBaseDataset: + """Tests for the base dataset functionality.""" + + class DummySchema: + def __init__(self): + self.header = ["col_a"] + self.col_a = "col_a" + + class DummyDataset(BaseDataset): + def __init__(self, path: Path, mode: str): + super().__init__(path, mode) + self.row_to_item_called = False + self.write_called = False + self.write_item = None + + def _init_schema(self): + self.schema = TestBaseDataset.DummySchema() + + def row_to_item(self, row: dict) -> str: + """Convert a row to a dummy item.""" + self.row_to_item_called = True + return row[self.schema.col_a] + + def item_to_row(self, item: str) -> list[str]: + """Write a dummy item.""" + self.write_called = True + self.write_item = item + return [item] + + @pytest.fixture + def dummy_csv(self, tmp_path): + """Create a dummy CSV file.""" + file_path = tmp_path / "dummy.csv" + file_path.write_text("col_a\ndata1") + return file_path + + @pytest.fixture + def dummy_read_dataset(self, dummy_csv): + return self.DummyDataset(dummy_csv, mode="r") + + 
@pytest.fixture + def dummy_write_dataset(self, tmp_path): + return self.DummyDataset(tmp_path / "dummy.csv", mode="w") + + def test_invalid_mode(self, tmp_path): + with pytest.raises(AssertionError, match="Invalid dataset mode"): + self.DummyDataset(tmp_path / "data.csv", mode="x") + + def test_file_not_found(self): + with pytest.raises(FileNotFoundError): + self.DummyDataset("nonexistent.csv", mode="r") + + def test_cannot_use_existing_file_for_write(self, dummy_csv): + with pytest.raises(FileExistsError): + self.DummyDataset(dummy_csv, mode="w") + + def test_len(self, dummy_read_dataset): + assert len(dummy_read_dataset) == 1 + + def test_len_in_write_mode(self, dummy_write_dataset): + with pytest.raises(NotImplementedError, match="Length not supported in write mode"): + len(dummy_write_dataset) + + def test_schema_is_set_in_initialization(self, dummy_read_dataset): + assert isinstance(dummy_read_dataset.schema, self.DummySchema) + + def test_header_columns(self, dummy_read_dataset): + assert dummy_read_dataset.header_columns() == ["col_a"] + + def test_context_manager_read(self, dummy_read_dataset): + """Test that context manager properly opens and closes files in read mode.""" + assert dummy_read_dataset.file is None + + with dummy_read_dataset: + assert dummy_read_dataset.file is not None + assert not dummy_read_dataset.file.closed + assert dummy_read_dataset.reader is not None + + assert dummy_read_dataset.file is None + assert dummy_read_dataset.reader is None + + def test_context_manager_write(self, dummy_write_dataset): + """Test that context manager properly opens and closes files in write mode.""" + assert dummy_write_dataset.file is None + + with dummy_write_dataset: + assert dummy_write_dataset.file is not None + assert not dummy_write_dataset.file.closed + assert dummy_write_dataset.writer is not None + + assert dummy_write_dataset.file is None + assert dummy_write_dataset.writer is None + + def test_read_in_write_mode(self, dummy_write_dataset): + 
"""Test that reading in write mode raises an error.""" + with pytest.raises(RuntimeError, match="Can only iterate over dataset in read mode."): + for _ in dummy_write_dataset: + break # Iteration forces read. + + def test_write_in_read_mode(self, dummy_read_dataset): + """Test that writing in read mode raises an error.""" + with pytest.raises(RuntimeError, match="Cannot write to dataset in read mode"): + with dummy_read_dataset: + dummy_read_dataset.write("test") + + def test_row_to_item_called(self, dummy_read_dataset): + """Test that row_to_item is called when iterating.""" + assert dummy_read_dataset.row_to_item_called is False + with dummy_read_dataset: + for obj in dummy_read_dataset: + assert dummy_read_dataset.row_to_item_called is True + assert obj == "data1" + + def test_write_operation(self, dummy_write_dataset): + """Test that write operation works in write mode.""" + with dummy_write_dataset: + dummy_write_dataset.write("test_data") + assert dummy_write_dataset.write_called is True + assert dummy_write_dataset.write_item == "test_data" + + def test_header_gets_written_to_new_file(self, dummy_write_dataset): + """Test that header is written to a new file.""" + with dummy_write_dataset: + pass + assert dummy_write_dataset.path.read_text() == "col_a\n" + + def test_append_to_existing_file(self, dummy_write_dataset): + """Test that header does not get re-written if context manager is entered twice.""" + with dummy_write_dataset: + pass + with dummy_write_dataset: + pass + assert dummy_write_dataset.path.read_text() == "col_a\n" + + def test_iteration_enters_exits_context(self, dummy_read_dataset): + """Test that iteration enters and exits the context.""" + assert dummy_read_dataset.file is None + for obj in dummy_read_dataset: + assert dummy_read_dataset.file is not None + assert not dummy_read_dataset.file.closed + assert obj == "data1" + break + assert dummy_read_dataset.file is None + + +class TestPromptDataset: + @pytest.fixture + def 
sample_prompts_csv(self, tmp_path): + """Create a sample CSV file with prompts only.""" + file_path = tmp_path / "prompts.csv" + content = ( + f"{DEFAULT_PROMPT_SCHEMA.prompt_uid},{DEFAULT_PROMPT_SCHEMA.prompt_text}\n" + "p1,Say hello\n" + "p2,Say goodbye" + ) + file_path.write_text(content) + return file_path + + @pytest.fixture + def prompts_dataset(self, sample_prompts_csv): + return PromptDataset(sample_prompts_csv) + + def test_header_columns(self, prompts_dataset): + assert prompts_dataset.header_columns() == DEFAULT_PROMPT_SCHEMA.header + + def test_iterate_explicit_context(self, prompts_dataset): + with prompts_dataset as dataset: + items = [] + for item in dataset: + items.append(item) + assert len(items) == 2 + assert all(isinstance(item, TestItem) for item in items) + + assert items[0].source_id == "p1" + assert items[0].prompt.text == "Say hello" + + assert items[1].source_id == "p2" + assert items[1].prompt.text == "Say goodbye" + + def test_iterate_implicit_context(self, prompts_dataset): + items = list(prompts_dataset) + assert len(items) == 2 + assert all(isinstance(item, TestItem) for item in items) + + def test_invalid_schema(self, tmp_path): + """Test that reading a CSV file with invalid schema raises an error.""" + file_path = tmp_path / "invalid.csv" + file_path.write_text("column1,column2\na,b\n") + + with pytest.raises(SchemaValidationError): + PromptDataset(file_path) + + +class TestPromptResponseDataset: + @pytest.fixture + def sample_responses_csv(self, tmp_path): + """Create a sample CSV file with prompt-response data.""" + file_path = tmp_path / "responses.csv" + content = ( + f"{DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_uid},{DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_text}," + f"{DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid},{DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response}\n" + "p1,Say hello,sut1,Hello world\n" + "p2,Say goodbye,sut1,Goodbye world" + ) + file_path.write_text(content) + return file_path + + def test_schema_read(self, 
sample_responses_csv): + dataset = PromptResponseDataset(sample_responses_csv, mode="r") + assert dataset.schema.header == DEFAULT_PROMPT_RESPONSE_SCHEMA.header + + def test_schema_write(self, tmp_path): + dataset = PromptResponseDataset(tmp_path / "responses.csv", mode="w") + assert dataset.schema == DEFAULT_PROMPT_RESPONSE_SCHEMA + + def test_read_csv(self, sample_responses_csv): + with PromptResponseDataset(sample_responses_csv, mode="r") as dataset: + interactions = list(dataset) + assert len(interactions) == 2 + assert all(isinstance(interaction, SutInteraction) for interaction in interactions) + + # Check first interaction + assert interactions[0].prompt.source_id == "p1" + assert interactions[0].prompt.prompt.text == "Say hello" + assert interactions[0].sut_uid == "sut1" + assert interactions[0].response.text == "Hello world" + + # Check second interaction + assert interactions[1].prompt.source_id == "p2" + assert interactions[1].prompt.prompt.text == "Say goodbye" + assert interactions[1].sut_uid == "sut1" + assert interactions[1].response.text == "Goodbye world" + + def test_write_csv(self, tmp_path): + output_file = tmp_path / "output.csv" + + # Create test data + interaction = SutInteraction( + prompt=TestItem(prompt=TextPrompt(text="Test prompt"), source_id="test1", context={}), + sut_uid="sut1", + response=SUTResponse(text="Test response"), + ) + + # Write data + with PromptResponseDataset(output_file, mode="w") as dataset: + dataset.write(interaction) + + # Verify written data + assert output_file.exists() + content = output_file.read_text() + expected_header = ( + f"{DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_uid},{DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_text}," + f"{DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid},{DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response}\n" + ) + expected_data = "test1,Test prompt,sut1,Test response\n" + assert content == expected_header + expected_data + + +# class TestAnnotationDataset: +# @pytest.fixture +# def 
sample_annotations_jsonl(tmp_path): +# """Create a sample JSONL file with annotated prompt-response data.""" +# file_path = tmp_path / "annotations.jsonl" +# content = [ +# { +# DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_uid: "p1", +# DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_text: "Say hello", +# DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid: "sut1", +# DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response: "Hello world", +# "Annotations": {"toxicity": 0.1} +# }, +# { +# DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_uid: "p2", +# DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_text: "Say goodbye", +# DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid: "sut1", +# DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response: "Goodbye world", +# "Annotations": {"toxicity": 0.0} +# } +# ] +# import jsonlines +# with jsonlines.open(file_path, mode='w') as writer: +# writer.write_all(content) +# return file_path +# +# def test_read_jsonl(self, sample_annotations_jsonl): +# """Test reading annotated prompt-response pairs from a JSONL file.""" +# with AnnotationDataset(sample_annotations_jsonl, mode='r') as dataset: +# interactions = list(dataset) +# assert len(interactions) == 2 + +# # Check first interaction +# interaction, annotations = interactions[0] +# assert interaction.prompt.source_id == "p1" +# assert interaction.prompt.prompt.text == "Say hello" +# assert interaction.sut_uid == "sut1" +# assert interaction.response.text == "Hello world" +# assert annotations == {"toxicity": 0.1} + +# # Check second interaction +# interaction, annotations = interactions[1] +# assert interaction.prompt.source_id == "p2" +# assert annotations == {"toxicity": 0.0} + +# def test_read_csv(self, sample_responses_csv): +# """Test reading from a CSV file (should have no annotations).""" +# with AnnotationDataset(sample_responses_csv, mode='r') as dataset: +# interactions = list(dataset) +# assert len(interactions) == 2 + +# # Check that annotations are None +# for interaction, annotations in interactions: +# assert annotations is None + +# def 
test_write_jsonl(self, tmp_path): +# """Test writing annotated prompt-response pairs to a JSONL file.""" +# output_file = tmp_path / "output.jsonl" + +# # Create test data +# interaction = SutInteraction( +# prompt=TestItem( +# prompt=TextPrompt(text="Test prompt"), +# source_id="test1", +# context={} +# ), +# sut_uid="sut1", +# response=SUTResponse(text="Test response") +# ) +# annotations = {"toxicity": 0.5} + +# # Write data +# with AnnotationDataset(output_file, mode='w') as dataset: +# dataset.write(interaction, annotations) + +# # Verify written data +# import jsonlines +# with jsonlines.open(output_file) as reader: +# data = list(reader) +# assert len(data) == 1 +# assert data[0][DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_uid] == "test1" +# assert data[0][DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_text] == "Test prompt" +# assert data[0][DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid] == "sut1" +# assert data[0][DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response] == "Test response" +# assert data[0]["Annotations"] == {"toxicity": 0.5} + +# def test_write_csv(self, tmp_path): +# """Test writing to a CSV file (should ignore annotations).""" +# output_file = tmp_path / "output.csv" + +# # Create test data +# interaction = SutInteraction( +# prompt=TestItem( +# prompt=TextPrompt(text="Test prompt"), +# source_id="test1", +# context={} +# ), +# sut_uid="sut1", +# response=SUTResponse(text="Test response") +# ) +# annotations = {"toxicity": 0.5} + +# # Write data +# with AnnotationDataset(output_file, mode='w') as dataset: +# dataset.write(interaction, annotations) + +# # Verify written data - should not include annotations +# assert output_file.exists() +# content = output_file.read_text() +# expected_header = f"{DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_uid},{DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_text}," \ +# f"{DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid},{DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response}\n" +# expected_data = '"test1","Test prompt","sut1","Test response"\n' +# assert content == 
expected_header + expected_data diff --git a/tests/modelgauge_tests/test_prompt_pipeline.py b/tests/modelgauge_tests/test_prompt_pipeline.py index 5b1c7bdc5..4a4e3f99c 100644 --- a/tests/modelgauge_tests/test_prompt_pipeline.py +++ b/tests/modelgauge_tests/test_prompt_pipeline.py @@ -23,9 +23,8 @@ PromptSource, PromptSutAssigner, PromptSutWorkers, - SutInteraction, ) -from modelgauge.single_turn_prompt_response import TestItem +from modelgauge.single_turn_prompt_response import SUTInteraction,TestItem from modelgauge.sut import SUTOptions, SUTResponse from modelgauge_tests.fake_sut import FakeSUT, FakeSUTRequest, FakeSUTResponse @@ -145,7 +144,7 @@ def test_prompt_sut_worker_normal(suts): w = PromptSutWorkers(suts) result = w.handle_item((prompt_with_context, "fake1")) - assert result == SutInteraction(prompt_with_context, "fake1", SUTResponse(text="a response")) + assert result == SUTInteraction(prompt_with_context, "fake1", SUTResponse(text="a response")) def test_prompt_sut_worker_sends_prompt_options(suts): @@ -170,11 +169,11 @@ def test_prompt_sut_worker_cache(suts, tmp_path): w = PromptSutWorkers(suts, cache_path=tmp_path) result = w.handle_item((prompt_with_context, "fake1")) - assert result == SutInteraction(prompt_with_context, "fake1", SUTResponse(text="a response")) + assert result == SUTInteraction(prompt_with_context, "fake1", SUTResponse(text="a response")) assert mock.call_count == 1 result = w.handle_item((prompt_with_context, "fake1")) - assert result == SutInteraction(prompt_with_context, "fake1", SUTResponse(text="a response")) + assert result == SUTInteraction(prompt_with_context, "fake1", SUTResponse(text="a response")) assert mock.call_count == 1 @@ -189,7 +188,7 @@ def test_prompt_sut_worker_retries_until_success(suts): w = PromptSutWorkers(suts) w.sleep_time = 0.01 result = w.handle_item((prompt_with_context, "fake1")) - assert result == SutInteraction(prompt_with_context, "fake1", SUTResponse(text="a response")) + assert result == 
SUTInteraction(prompt_with_context, "fake1", SUTResponse(text="a response")) assert mock.call_count == num_exceptions + 1 From ea68e93384d6ac50b6cc0e958a61f86ac9ac9b22 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Tue, 1 Jul 2025 12:54:52 -0700 Subject: [PATCH 06/15] Use PromptDataset as input to PromptRunner. Delete CsvPromptInput --- src/modelgauge/dataset.py | 18 ++++---- src/modelgauge/pipeline_runner.py | 3 +- src/modelgauge/prompt_pipeline.py | 42 ------------------- tests/modelgauge_tests/test_dataset.py | 2 +- .../modelgauge_tests/test_pipeline_runner.py | 6 +-- .../modelgauge_tests/test_prompt_pipeline.py | 11 +++-- 6 files changed, 19 insertions(+), 63 deletions(-) diff --git a/src/modelgauge/dataset.py b/src/modelgauge/dataset.py index 02ad4416a..8802b9ec4 100644 --- a/src/modelgauge/dataset.py +++ b/src/modelgauge/dataset.py @@ -12,7 +12,7 @@ PromptSchema, ) from modelgauge.prompt import TextPrompt -from modelgauge.single_turn_prompt_response import SutInteraction, TestItem +from modelgauge.single_turn_prompt_response import SUTInteraction, TestItem from modelgauge.sut import SUTResponse @@ -21,8 +21,8 @@ class BaseDataset(ABC): def __init__(self, path: Union[str, Path], mode: str): """Args: - path: Path to the dataset file - mode: Mode to open the file in ('r' for read, 'w' for write) + path: Path to the dataset file + mode: Mode to open the file in ('r' for read, 'w' for write) """ self.path = Path(path) self.mode = mode @@ -36,7 +36,7 @@ def __init__(self, path: Union[str, Path], mode: str): self.writer = None self.reader = None self.schema = None - self._init_schema() # Initialized by subclass. + self._init_schema() # Initialized by subclass. def __enter__(self): """Context manager entry. 
Opens the file and sets the reader or writer.""" @@ -156,16 +156,16 @@ def _init_schema(self): else: self.schema = DEFAULT_PROMPT_RESPONSE_SCHEMA - def row_to_item(self, row: dict) -> SutInteraction: + def row_to_item(self, row: dict) -> SUTInteraction: prompt = TestItem( prompt=TextPrompt(text=row[self.schema.prompt_text]), source_id=row[self.schema.prompt_uid], context=row, ) response = SUTResponse(text=row[self.schema.sut_response]) - return SutInteraction(prompt, row[self.schema.sut_uid], response) + return SUTInteraction(prompt, row[self.schema.sut_uid], response) - def item_to_row(self, item: SutInteraction) -> list[str]: + def item_to_row(self, item: SUTInteraction) -> list[str]: if not isinstance(item.prompt.prompt, TextPrompt): raise ValueError(f"Error handling {item}. Can only handle TextPrompts.") @@ -187,7 +187,7 @@ def _init_schema(self): self.schema = DEFAULT_ANNOTATION_SCHEMA # TODO: New annotation object - def row_to_item(self, row: dict) -> tuple[SutInteraction, Optional[Dict[str, Any]]]: + def row_to_item(self, row: dict) -> tuple[SUTInteraction, Optional[Dict[str, Any]]]: prompt = TestItem( prompt=TextPrompt(text=row[self.schema.prompt_text]), source_id=row[self.schema.prompt_uid], @@ -200,7 +200,7 @@ def row_to_item(self, row: dict) -> tuple[SutInteraction, Optional[Dict[str, Any annotations = row.get(self.schema.annotation) return interaction, annotations - def item_to_row(self, item: SutInteraction, annotations: Optional[Dict[str, Any]] = None) -> list[str]: + def item_to_row(self, item: SUTInteraction, annotations: Optional[Dict[str, Any]] = None) -> list[str]: if not isinstance(item.prompt.prompt, TextPrompt): raise ValueError(f"Error handling {item}. 
Can only handle TextPrompts.") return [ diff --git a/src/modelgauge/pipeline_runner.py b/src/modelgauge/pipeline_runner.py index 93a707482..b2168d2ff 100644 --- a/src/modelgauge/pipeline_runner.py +++ b/src/modelgauge/pipeline_runner.py @@ -19,7 +19,6 @@ PromptSutAssigner, PromptSutWorkers, PromptSink, - CsvPromptInput, CsvPromptOutput, ) from modelgauge.sut import SUTOptions @@ -137,7 +136,7 @@ def metadata(self): return {**super().metadata(), **self._sut_metadata()} def _add_prompt_segments(self, include_sink=True): - input = CsvPromptInput(self.input_path) + input = PromptDataset(self.input_path) self.pipeline_segments.append(PromptSource(input)) self.pipeline_segments.append(PromptSutAssigner(self.suts)) self.sut_worker = PromptSutWorkers( diff --git a/src/modelgauge/prompt_pipeline.py b/src/modelgauge/prompt_pipeline.py index c64b4ff0f..a1495835a 100644 --- a/src/modelgauge/prompt_pipeline.py +++ b/src/modelgauge/prompt_pipeline.py @@ -15,48 +15,6 @@ logger = logging.getLogger(__name__) -class PromptInput(metaclass=ABCMeta): - """ - Your subclass should implement __iter__ such that it yields TestItem objects. - Note that the source_id field must be set. - """ - - @abstractmethod - def __iter__(self) -> Iterable[TestItem]: - pass - - def __len__(self): - count = 0 - for prompt in self: - count += 1 - return count - - -# TODO: Delete: replace with PromptDataset. -class CsvPromptInput(PromptInput): - def __init__(self, path): - super().__init__() - self.path = path - self.schema = PromptSchema(self._header()) # Validate header and store the schema. - - def _header(self) -> list[str]: - with open(self.path, newline="") as f: - csvreader = csv.reader(f) - return next(csvreader) - - def __iter__(self) -> Iterable[TestItem]: - with open(self.path, newline="") as f: - csvreader = csv.DictReader(f) - for row in csvreader: - yield TestItem( - prompt=TextPrompt(text=row[self.schema.prompt_text]), - # Forward the underlying id to help make data tracking easier. 
- source_id=row[self.schema.prompt_uid], - # Context can be any type you want. - context=row, - ) - - class PromptOutput(metaclass=ABCMeta): def __enter__(self): return self diff --git a/tests/modelgauge_tests/test_dataset.py b/tests/modelgauge_tests/test_dataset.py index bfcd615ea..402d0825a 100644 --- a/tests/modelgauge_tests/test_dataset.py +++ b/tests/modelgauge_tests/test_dataset.py @@ -77,7 +77,7 @@ def test_cannot_use_existing_file_for_write(self, dummy_csv): def test_len(self, dummy_read_dataset): assert len(dummy_read_dataset) == 1 - + def test_len_in_write_mode(self, dummy_write_dataset): with pytest.raises(NotImplementedError, match="Length not supported in write mode"): len(dummy_write_dataset) diff --git a/tests/modelgauge_tests/test_pipeline_runner.py b/tests/modelgauge_tests/test_pipeline_runner.py index 22aabab84..203f74317 100644 --- a/tests/modelgauge_tests/test_pipeline_runner.py +++ b/tests/modelgauge_tests/test_pipeline_runner.py @@ -9,6 +9,7 @@ CsvAnnotatorInput, ) from modelgauge.annotator_set import AnnotatorSet +from modelgauge.dataset import PromptDataset from modelgauge.data_schema import ( DEFAULT_PROMPT_RESPONSE_SCHEMA as PROMPT_RESPONSE_SCHEMA, DEFAULT_PROMPT_SCHEMA as PROMPT_SCHEMA, @@ -25,7 +26,6 @@ PromptSutAssigner, PromptSutWorkers, PromptSink, - CsvPromptInput, CsvPromptOutput, ) from modelgauge.sut import SUTOptions @@ -142,7 +142,7 @@ def test_pipeline_segments(self, tmp_path, prompts_file, suts): source, sut_assigner, sut_workers, sink = runner.pipeline_segments assert isinstance(source, PromptSource) - assert isinstance(source.input, CsvPromptInput) + assert isinstance(source.input, PromptDataset) assert source.input.path == prompts_file assert isinstance(sut_assigner, PromptSutAssigner) @@ -250,7 +250,7 @@ def test_pipeline_segments(self, tmp_path, prompts_file, suts, annotators): source, sut_assigner, sut_workers, annotator_assigner, annotator_workers, sink = runner.pipeline_segments assert isinstance(source, PromptSource) 
- assert isinstance(source.input, CsvPromptInput) + assert isinstance(source.input, PromptDataset) assert source.input.path == prompts_file assert isinstance(sut_assigner, PromptSutAssigner) diff --git a/tests/modelgauge_tests/test_prompt_pipeline.py b/tests/modelgauge_tests/test_prompt_pipeline.py index 4a4e3f99c..3ecdaeacd 100644 --- a/tests/modelgauge_tests/test_prompt_pipeline.py +++ b/tests/modelgauge_tests/test_prompt_pipeline.py @@ -7,6 +7,7 @@ import pytest +from modelgauge.dataset import PromptDataset from modelgauge.data_schema import ( DEFAULT_PROMPT_RESPONSE_SCHEMA as PROMPT_RESPONSE_SCHEMA, DEFAULT_PROMPT_SCHEMA as PROMPT_SCHEMA, @@ -15,16 +16,14 @@ from modelgauge.pipeline import Pipeline, PipelineSegment from modelgauge.prompt import TextPrompt from modelgauge.prompt_pipeline import ( - CsvPromptInput, CsvPromptOutput, - PromptInput, PromptOutput, PromptSink, PromptSource, PromptSutAssigner, PromptSutWorkers, ) -from modelgauge.single_turn_prompt_response import SUTInteraction,TestItem +from modelgauge.single_turn_prompt_response import SUTInteraction, TestItem from modelgauge.sut import SUTOptions, SUTResponse from modelgauge_tests.fake_sut import FakeSUT, FakeSUTRequest, FakeSUTResponse @@ -45,7 +44,7 @@ def __exit__(self, type, value, traceback): signal.alarm(0) -class FakePromptInput(PromptInput): +class FakePromptInput: def __init__(self, items: list[dict], delay=None): super().__init__() self.items = items @@ -88,7 +87,7 @@ def suts(): def test_csv_prompt_input(tmp_path): file_path = tmp_path / "input.csv" file_path.write_text(f'{PROMPT_SCHEMA.prompt_uid},{PROMPT_SCHEMA.prompt_text}\n"1","a"') - input = CsvPromptInput(file_path) + input = PromptDataset(file_path) assert len(input) == 1 items: List[TestItem] = [i for i in input] @@ -102,7 +101,7 @@ def test_csv_prompt_input_invalid_columns(tmp_path, header): file_path = tmp_path / "input.csv" file_path.write_text(header) with pytest.raises(SchemaValidationError): - CsvPromptInput(file_path) + 
PromptDataset(file_path) def test_csv_prompt_output(tmp_path, suts): From 5f8bc470e7a89501a1c1752789d4c52dbd99a76c Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Tue, 1 Jul 2025 14:01:24 -0700 Subject: [PATCH 07/15] Use PromptResponseDataset instead of CSVPromptOutput in prompt runner --- src/modelgauge/pipeline_runner.py | 7 +- src/modelgauge/prompt_pipeline.py | 45 +------------ tests/modelgauge_tests/test_dataset.py | 7 +- .../modelgauge_tests/test_pipeline_runner.py | 7 +- .../modelgauge_tests/test_prompt_pipeline.py | 67 +++++++------------ 5 files changed, 36 insertions(+), 97 deletions(-) diff --git a/src/modelgauge/pipeline_runner.py b/src/modelgauge/pipeline_runner.py index b2168d2ff..f40d23ffd 100644 --- a/src/modelgauge/pipeline_runner.py +++ b/src/modelgauge/pipeline_runner.py @@ -12,14 +12,13 @@ EnsembleVoter, JsonlAnnotatorOutput, ) -from modelgauge.dataset import PromptDataset +from modelgauge.dataset import PromptDataset, PromptResponseDataset from modelgauge.pipeline import Pipeline from modelgauge.prompt_pipeline import ( PromptSource, PromptSutAssigner, PromptSutWorkers, PromptSink, - CsvPromptOutput, ) from modelgauge.sut import SUTOptions @@ -144,8 +143,8 @@ def _add_prompt_segments(self, include_sink=True): ) self.pipeline_segments.append(self.sut_worker) if include_sink: - output = CsvPromptOutput(self.output_dir() / self.output_file_name, self.suts) - self.pipeline_segments.append(PromptSink(self.suts, output)) + output = PromptResponseDataset(self.output_dir() / self.output_file_name, "w") + self.pipeline_segments.append(PromptSink(output)) def _sut_metadata(self): counts = self.sut_worker.sut_response_counts diff --git a/src/modelgauge/prompt_pipeline.py b/src/modelgauge/prompt_pipeline.py index a1495835a..cf6cf44db 100644 --- a/src/modelgauge/prompt_pipeline.py +++ b/src/modelgauge/prompt_pipeline.py @@ -15,6 +15,7 @@ logger = logging.getLogger(__name__) +# TODO: Delete. 
class PromptOutput(metaclass=ABCMeta): def __enter__(self): return self @@ -27,38 +28,6 @@ def write(self, item, results): pass -class CsvPromptOutput(PromptOutput): - """Outputs a CSV file where each row represents one SUT's response to a prompt.""" - - schema = DEFAULT_PROMPT_RESPONSE_SCHEMA - - def __init__(self, path, suts): - super().__init__() - assert path.suffix.lower() == ".csv", f"Invalid output file {path}. Must be of type CSV." - - self.path = path - self.suts = suts - self.file = None - self.writer = None - - def __enter__(self): - self.file = open(self.path, "w", newline="") - self.writer = csv.writer(self.file, quoting=csv.QUOTE_ALL) - self.writer.writerow( - [self.schema.prompt_uid, self.schema.prompt_text, self.schema.sut_uid, self.schema.sut_response] - ) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.file.close() - - def write(self, item: TestItem, results): - base_row = [item.source_id, item.prompt.text] # type: ignore - for sut in self.suts: - if sut in results: - self.writer.writerow(base_row + [sut, results[sut]]) - - class PromptSource(Source): def __init__(self, input: PromptDataset): super().__init__() @@ -116,21 +85,13 @@ def call_sut(self, prompt_text: TextPrompt, sut: PromptResponseSUT) -> SUTRespon class PromptSink(Sink): - unfinished: defaultdict[TestItem, dict[str, str]] - - def __init__(self, suts: dict[str, SUT], writer: PromptOutput): + def __init__(self, writer: PromptResponseDataset): super().__init__() - self.suts = suts self.writer = writer - self.unfinished = defaultdict(lambda: dict()) def run(self): with self.writer: super().run() def handle_item(self, item: SUTInteraction): - self.unfinished[item.prompt][item.sut_uid] = item.response.text - if len(self.unfinished[item.prompt]) == len(self.suts): - self.writer.write(item.prompt, self.unfinished[item.prompt]) - self._debug(f"wrote {item.prompt}") - del self.unfinished[item.prompt] + self.writer.write(item) diff --git 
a/tests/modelgauge_tests/test_dataset.py b/tests/modelgauge_tests/test_dataset.py index 402d0825a..54e1e602a 100644 --- a/tests/modelgauge_tests/test_dataset.py +++ b/tests/modelgauge_tests/test_dataset.py @@ -12,10 +12,9 @@ BaseDataset, PromptDataset, PromptResponseDataset, - SutInteraction, ) from modelgauge.prompt import TextPrompt -from modelgauge.single_turn_prompt_response import TestItem +from modelgauge.single_turn_prompt_response import SUTInteraction, TestItem from modelgauge.sut import SUTResponse @@ -238,7 +237,7 @@ def test_read_csv(self, sample_responses_csv): with PromptResponseDataset(sample_responses_csv, mode="r") as dataset: interactions = list(dataset) assert len(interactions) == 2 - assert all(isinstance(interaction, SutInteraction) for interaction in interactions) + assert all(isinstance(interaction, SUTInteraction) for interaction in interactions) # Check first interaction assert interactions[0].prompt.source_id == "p1" @@ -256,7 +255,7 @@ def test_write_csv(self, tmp_path): output_file = tmp_path / "output.csv" # Create test data - interaction = SutInteraction( + interaction = SUTInteraction( prompt=TestItem(prompt=TextPrompt(text="Test prompt"), source_id="test1", context={}), sut_uid="sut1", response=SUTResponse(text="Test response"), diff --git a/tests/modelgauge_tests/test_pipeline_runner.py b/tests/modelgauge_tests/test_pipeline_runner.py index 203f74317..a3db8f2f3 100644 --- a/tests/modelgauge_tests/test_pipeline_runner.py +++ b/tests/modelgauge_tests/test_pipeline_runner.py @@ -9,7 +9,7 @@ CsvAnnotatorInput, ) from modelgauge.annotator_set import AnnotatorSet -from modelgauge.dataset import PromptDataset +from modelgauge.dataset import PromptDataset, PromptResponseDataset from modelgauge.data_schema import ( DEFAULT_PROMPT_RESPONSE_SCHEMA as PROMPT_RESPONSE_SCHEMA, DEFAULT_PROMPT_SCHEMA as PROMPT_SCHEMA, @@ -26,7 +26,6 @@ PromptSutAssigner, PromptSutWorkers, PromptSink, - CsvPromptOutput, ) from modelgauge.sut import SUTOptions from 
modelgauge_tests.fake_annotator import FakeAnnotator @@ -154,9 +153,7 @@ def test_pipeline_segments(self, tmp_path, prompts_file, suts): assert sut_workers.thread_count == 20 assert isinstance(sink, PromptSink) - assert sink.suts == suts - assert isinstance(sink.writer, CsvPromptOutput) - assert sink.writer.suts == suts + assert isinstance(sink.writer, PromptResponseDataset) def test_prompt_runner_num_input_items(self, runner_basic): assert runner_basic.num_input_items == NUM_PROMPTS diff --git a/tests/modelgauge_tests/test_prompt_pipeline.py b/tests/modelgauge_tests/test_prompt_pipeline.py index 3ecdaeacd..6d3aa4679 100644 --- a/tests/modelgauge_tests/test_prompt_pipeline.py +++ b/tests/modelgauge_tests/test_prompt_pipeline.py @@ -7,7 +7,7 @@ import pytest -from modelgauge.dataset import PromptDataset +from modelgauge.dataset import PromptDataset, PromptResponseDataset from modelgauge.data_schema import ( DEFAULT_PROMPT_RESPONSE_SCHEMA as PROMPT_RESPONSE_SCHEMA, DEFAULT_PROMPT_SCHEMA as PROMPT_SCHEMA, @@ -16,8 +16,6 @@ from modelgauge.pipeline import Pipeline, PipelineSegment from modelgauge.prompt import TextPrompt from modelgauge.prompt_pipeline import ( - CsvPromptOutput, - PromptOutput, PromptSink, PromptSource, PromptSutAssigner, @@ -60,12 +58,13 @@ def __iter__(self): ) -class FakePromptOutput(PromptOutput): - def __init__(self): +class FakePromptOutput(PromptResponseDataset): + def __init__(self, path: str): self.output = [] + super().__init__(path, "w") - def write(self, item, results): - self.output.append({"item": item, "results": results}) + def write(self, item): + self.output.append(item) class FakeSUTWithDelay(FakeSUT): @@ -107,31 +106,19 @@ def test_csv_prompt_input_invalid_columns(tmp_path, header): def test_csv_prompt_output(tmp_path, suts): file_path = tmp_path / "output.csv" - with CsvPromptOutput(file_path, suts) as output: - output.write( - TestItem(source_id="1", prompt=TextPrompt(text="a")), - {"fake1": "a1", "fake2": "a2"}, - ) + with 
PromptResponseDataset(file_path, "w") as output: + test_item = TestItem(source_id="1", prompt=TextPrompt(text="a")) + sut_interaction = SUTInteraction(test_item, "fake1", SUTResponse(text="a1")) + output.write(sut_interaction) with open(file_path, "r", newline="") as f: # noinspection PyTypeChecker items: list[dict] = [i for i in (DictReader(f))] - assert len(items) == 2 + assert len(items) == 1 assert items[0][PROMPT_RESPONSE_SCHEMA.prompt_uid] == "1" assert items[0][PROMPT_RESPONSE_SCHEMA.prompt_text] == "a" assert items[0][PROMPT_RESPONSE_SCHEMA.sut_uid] == "fake1" assert items[0][PROMPT_RESPONSE_SCHEMA.sut_response] == "a1" - assert items[1][PROMPT_RESPONSE_SCHEMA.prompt_uid] == "1" - assert items[1][PROMPT_RESPONSE_SCHEMA.prompt_text] == "a" - assert items[1][PROMPT_RESPONSE_SCHEMA.sut_uid] == "fake2" - assert items[1][PROMPT_RESPONSE_SCHEMA.sut_response] == "a2" - - -@pytest.mark.parametrize("output_fname", ["output.jsonl", "output"]) -def test_csv_prompt_output_invalid(tmp_path, suts, output_fname): - file_path = tmp_path / output_fname - with pytest.raises(AssertionError, match=f"Invalid output file {file_path}. 
Must be of type CSV."): - CsvPromptOutput(file_path, suts) def test_prompt_sut_worker_normal(suts): @@ -191,37 +178,33 @@ def test_prompt_sut_worker_retries_until_success(suts): assert mock.call_count == num_exceptions + 1 -def test_full_run(suts): +def test_full_run(suts, tmp_path): input = FakePromptInput( [ {PROMPT_SCHEMA.prompt_uid: "1", PROMPT_SCHEMA.prompt_text: "a"}, {PROMPT_SCHEMA.prompt_uid: "2", PROMPT_SCHEMA.prompt_text: "b"}, ] ) - output = FakePromptOutput() + output = FakePromptOutput(tmp_path / "output.csv") p = Pipeline( PromptSource(input), PromptSutAssigner(suts), PromptSutWorkers(suts, workers=1), - PromptSink(suts, output), + PromptSink(output), debug=True, ) p.run() - assert len(output.output) == len(input.items) - assert sorted([r["item"].source_id for r in output.output]) == [i[PROMPT_SCHEMA.prompt_uid] for i in input.items] - row1 = output.output[0] - assert "fake1" in row1["results"] - assert "fake2" in row1["results"] - row2 = output.output[1] - assert "fake1" in row2["results"] - assert "fake2" in row2["results"] + assert len(output.output) == len(input.items) * len(suts) # One row per prompt per SUT + # Every sut uid and prompt uid should be present + assert set(row.sut_uid for row in output.output) == set(suts.keys()) + assert set(row.prompt.source_id for row in output.output) == {"1", "2"} @pytest.mark.parametrize("worker_count", [1, 2, 4, 8]) -def test_concurrency_with_delays(suts, worker_count): +def test_concurrency_with_delays(suts, worker_count, tmp_path): PipelineSegment.default_timeout = 0.001 # burn some CPU to make the tests run faster prompt_count = worker_count * 4 @@ -235,13 +218,13 @@ def test_concurrency_with_delays(suts, worker_count): [{PROMPT_SCHEMA.prompt_uid: str(i), PROMPT_SCHEMA.prompt_text: "text" + str(i)} for i in range(prompt_count)], delay=prompt_delays, ) - output = FakePromptOutput() + output = FakePromptOutput(tmp_path / "output.csv") p = Pipeline( PromptSource(input), PromptSutAssigner(suts), 
PromptSutWorkers(suts, workers=worker_count), - PromptSink(suts, output), + PromptSink(output), ) average_delay_per_prompt = sum(sut_delays) / len(sut_delays) + sum(prompt_delays) / len(sut_delays) @@ -249,17 +232,17 @@ def test_concurrency_with_delays(suts, worker_count): with timeout(5 + int(prompt_count * average_delay_per_prompt / worker_count)): p.run() - assert len(output.output) == len(input.items) + assert len(output.output) == len(input.items) * len(suts) -def test_progress(suts): +def test_progress(suts, tmp_path): input = FakePromptInput( [ {PROMPT_SCHEMA.prompt_uid: "1", PROMPT_SCHEMA.prompt_text: "a"}, {PROMPT_SCHEMA.prompt_uid: "2", PROMPT_SCHEMA.prompt_text: "b"}, ] ) - output = FakePromptOutput() + output = FakePromptOutput(tmp_path / "output.csv") def track_progress(data): progress_items.append(data.copy()) @@ -268,7 +251,7 @@ def track_progress(data): PromptSource(input), PromptSutAssigner(suts), PromptSutWorkers(suts, workers=2), - PromptSink(suts, output), + PromptSink(output), progress_callback=track_progress, ) progress_items = [] From afe604a5bc9576b214f2b60135b41a6726826141 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Tue, 1 Jul 2025 14:09:52 -0700 Subject: [PATCH 08/15] Quote all + only accept csv files in datasets --- src/modelgauge/dataset.py | 9 +++++++-- tests/modelgauge_tests/test_dataset.py | 14 +++++++++----- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/src/modelgauge/dataset.py b/src/modelgauge/dataset.py index 8802b9ec4..631086bab 100644 --- a/src/modelgauge/dataset.py +++ b/src/modelgauge/dataset.py @@ -19,12 +19,17 @@ class BaseDataset(ABC): """This class provides common functionality for CSV file handling and context management.""" + quoting = csv.QUOTE_ALL + def __init__(self, path: Union[str, Path], mode: str): """Args: path: Path to the dataset file mode: Mode to open the file in ('r' for read, 'w' for write) """ self.path = Path(path) + if self.path.suffix.lower() != ".csv": + raise 
ValueError(f"Invalid dataset file {path}. Must be a CSV file.") + self.mode = mode assert mode in ["r", "w"], f"Invalid dataset mode {mode}. Must be 'r' or 'w'." if self.mode == "r" and not self.path.exists(): @@ -45,12 +50,12 @@ def __enter__(self): if self.mode == "w" and not self.path.exists(): # New file, need to write header. self.file = open(self.path, mode=self.mode, newline="") - self.writer = csv.writer(self.file, quoting=csv.QUOTE_MINIMAL) + self.writer = csv.writer(self.file, quoting=self.quoting) self.writer.writerow(self.header_columns()) elif self.mode == "w": # Append to existing file. self.file = open(self.path, mode="a", newline="") - self.writer = csv.writer(self.file, quoting=csv.QUOTE_MINIMAL) + self.writer = csv.writer(self.file, quoting=self.quoting) elif self.mode == "r": self.file = open(self.path, mode=self.mode, newline="") self.reader = csv.DictReader(self.file) diff --git a/tests/modelgauge_tests/test_dataset.py b/tests/modelgauge_tests/test_dataset.py index 54e1e602a..267f66b08 100644 --- a/tests/modelgauge_tests/test_dataset.py +++ b/tests/modelgauge_tests/test_dataset.py @@ -51,7 +51,7 @@ def item_to_row(self, item: str) -> list[str]: def dummy_csv(self, tmp_path): """Create a dummy CSV file.""" file_path = tmp_path / "dummy.csv" - file_path.write_text("col_a\ndata1") + file_path.write_text('"col_a"\n"data1"') return file_path @pytest.fixture @@ -66,6 +66,10 @@ def test_invalid_mode(self, tmp_path): with pytest.raises(AssertionError, match="Invalid dataset mode"): self.DummyDataset(tmp_path / "data.csv", mode="x") + def test_invalid_file_extension(self, tmp_path): + with pytest.raises(ValueError, match="Invalid dataset file"): + self.DummyDataset(tmp_path / "data.txt", mode="r") + def test_file_not_found(self): with pytest.raises(FileNotFoundError): self.DummyDataset("nonexistent.csv", mode="r") @@ -142,7 +146,7 @@ def test_header_gets_written_to_new_file(self, dummy_write_dataset): """Test that header is written to a new file.""" 
with dummy_write_dataset: pass - assert dummy_write_dataset.path.read_text() == "col_a\n" + assert dummy_write_dataset.path.read_text() == '"col_a"\n' def test_append_to_existing_file(self, dummy_write_dataset): """Test that header does not get re-written if context manager is entered twice.""" @@ -150,7 +154,7 @@ def test_append_to_existing_file(self, dummy_write_dataset): pass with dummy_write_dataset: pass - assert dummy_write_dataset.path.read_text() == "col_a\n" + assert dummy_write_dataset.path.read_text() == '"col_a"\n' def test_iteration_enters_exits_context(self, dummy_read_dataset): """Test that iteration enters and exits the context.""" @@ -205,7 +209,7 @@ def test_iterate_implicit_context(self, prompts_dataset): def test_invalid_schema(self, tmp_path): """Test that reading a CSV file with invalid schema raises an error.""" file_path = tmp_path / "invalid.csv" - file_path.write_text("column1,column2\na,b\n") + file_path.write_text('"column1","column2"\n"a","b"\n') with pytest.raises(SchemaValidationError): PromptDataset(file_path) @@ -273,7 +277,7 @@ def test_write_csv(self, tmp_path): f"{DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid},{DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response}\n" ) expected_data = "test1,Test prompt,sut1,Test response\n" - assert content == expected_header + expected_data + assert content.replace('"', "") == expected_header + expected_data # class TestAnnotationDataset: From 679efea2c8b050131f661aa8243184c28d7e11f1 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Tue, 1 Jul 2025 14:23:58 -0700 Subject: [PATCH 09/15] Replace annotator input objects with PromptResponseDataset --- src/modelgauge/annotation_pipeline.py | 41 +------------------ src/modelgauge/pipeline_runner.py | 3 +- .../test_annotation_pipeline.py | 7 ++-- .../modelgauge_tests/test_pipeline_runner.py | 3 +- 4 files changed, 7 insertions(+), 47 deletions(-) diff --git a/src/modelgauge/annotation_pipeline.py b/src/modelgauge/annotation_pipeline.py index bc3bbd166..6e2e1feed 
100644 --- a/src/modelgauge/annotation_pipeline.py +++ b/src/modelgauge/annotation_pipeline.py @@ -10,6 +10,7 @@ from modelgauge.annotation import Annotation from modelgauge.annotator import Annotator from modelgauge.annotator_set import AnnotatorSet +from modelgauge.dataset import PromptResponseDataset from modelgauge.data_schema import DEFAULT_PROMPT_RESPONSE_SCHEMA, PromptResponseSchema from modelgauge.pipeline import CachingPipe, Pipe, Sink, Source from modelgauge.prompt import TextPrompt @@ -20,44 +21,6 @@ logger = logging.getLogger(__name__) -class AnnotatorInput(metaclass=ABCMeta): - @abstractmethod - def __iter__(self) -> Iterable[SUTInteraction]: - pass - - def __len__(self): - count = 0 - for prompt in self: - count += 1 - return count - - -class CsvAnnotatorInput(AnnotatorInput): - def __init__(self, path): - super().__init__() - self.path = path - self.schema = PromptResponseSchema(self._header()) # Validate header and store the schema. - - def _header(self) -> list[str]: - with open(self.path, newline="") as f: - csvreader = csv.reader(f) - return next(csvreader) - - def __iter__(self) -> Iterable[SUTInteraction]: - with open(self.path, newline="") as f: - csvreader = csv.DictReader(f) - for row in csvreader: - prompt = TestItem( - prompt=TextPrompt(text=row[self.schema.prompt_text]), - # Forward the underlying id to help make data tracking easier. - source_id=row[self.schema.prompt_uid], - # Context can be any type you want. 
- context=row, - ) - response = SUTResponse(text=row[self.schema.sut_response]) - yield SUTInteraction(prompt, row[self.schema.sut_uid], response) - - class JsonlAnnotatorOutput(PromptOutput): def __init__(self, path): super().__init__() @@ -91,7 +54,7 @@ def write(self, item: SUTInteraction, results): class AnnotatorSource(Source): - def __init__(self, input: AnnotatorInput): + def __init__(self, input: PromptResponseDataset): super().__init__() self.input = input diff --git a/src/modelgauge/pipeline_runner.py b/src/modelgauge/pipeline_runner.py index f40d23ffd..d036d65b9 100644 --- a/src/modelgauge/pipeline_runner.py +++ b/src/modelgauge/pipeline_runner.py @@ -8,7 +8,6 @@ AnnotatorSink, AnnotatorSource, AnnotatorWorkers, - CsvAnnotatorInput, EnsembleVoter, JsonlAnnotatorOutput, ) @@ -192,7 +191,7 @@ def metadata(self): def _add_annotator_segments(self, include_source=True, include_sink=True): if include_source: - input = CsvAnnotatorInput(self.input_path) + input = PromptResponseDataset(self.input_path, mode="r") self.pipeline_segments.append(AnnotatorSource(input)) self.pipeline_segments.append(AnnotatorAssigner(self.annotators)) self.annotator_workers = AnnotatorWorkers(self.annotators, self.num_workers) diff --git a/tests/modelgauge_tests/test_annotation_pipeline.py b/tests/modelgauge_tests/test_annotation_pipeline.py index ebf2226fc..2446ae685 100644 --- a/tests/modelgauge_tests/test_annotation_pipeline.py +++ b/tests/modelgauge_tests/test_annotation_pipeline.py @@ -5,16 +5,15 @@ from unittest.mock import MagicMock from modelgauge.annotation_pipeline import ( - AnnotatorInput, AnnotatorSource, AnnotatorAssigner, AnnotatorWorkers, AnnotatorSink, - CsvAnnotatorInput, EnsembleVoter, JsonlAnnotatorOutput, ) from modelgauge.annotator_set import AnnotatorSet +from modelgauge.dataset import PromptResponseDataset from modelgauge.data_schema import DEFAULT_PROMPT_RESPONSE_SCHEMA as PROMPT_RESPONSE_SCHEMA from modelgauge.pipeline import Pipeline from modelgauge.prompt 
import TextPrompt @@ -35,7 +34,7 @@ from modelgauge_tests.test_prompt_pipeline import FakePromptInput -class FakeAnnotatorInput(AnnotatorInput): +class FakeAnnotatorInput: def __init__(self, items: list[dict], delay=None): super().__init__() self.items = items @@ -84,7 +83,7 @@ def test_csv_annotator_input(tmp_path): file_path.write_text( f'{PROMPT_RESPONSE_SCHEMA.prompt_uid},{PROMPT_RESPONSE_SCHEMA.prompt_text},{PROMPT_RESPONSE_SCHEMA.sut_uid},{PROMPT_RESPONSE_SCHEMA.sut_response}\n"1","a","s","b"' ) - input = CsvAnnotatorInput(file_path) + input = PromptResponseDataset(file_path, mode="r") assert len(input) == 1 item: SUTInteraction = next(iter(input)) diff --git a/tests/modelgauge_tests/test_pipeline_runner.py b/tests/modelgauge_tests/test_pipeline_runner.py index a3db8f2f3..474364a8e 100644 --- a/tests/modelgauge_tests/test_pipeline_runner.py +++ b/tests/modelgauge_tests/test_pipeline_runner.py @@ -6,7 +6,6 @@ AnnotatorSink, AnnotatorSource, AnnotatorWorkers, - CsvAnnotatorInput, ) from modelgauge.annotator_set import AnnotatorSet from modelgauge.dataset import PromptDataset, PromptResponseDataset @@ -409,7 +408,7 @@ def test_pipeline_segments(self, tmp_path, prompt_responses_file, annotators): source, annotator_assigner, annotator_workers, sink = runner.pipeline_segments assert isinstance(source, AnnotatorSource) - assert isinstance(source.input, CsvAnnotatorInput) + assert isinstance(source.input, PromptResponseDataset) assert source.input.path == prompt_responses_file assert isinstance(annotator_assigner, AnnotatorAssigner) From ee4f4c92416795a49670b483e47bcec823c8538f Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Tue, 1 Jul 2025 16:30:52 -0700 Subject: [PATCH 10/15] New AnnotatedSUTInteraction object. 
--- src/modelgauge/annotation_pipeline.py | 49 ++++++++++++------- src/modelgauge/single_turn_prompt_response.py | 10 ++++ .../test_annotation_pipeline.py | 46 ++++++++++------- 3 files changed, 68 insertions(+), 37 deletions(-) diff --git a/src/modelgauge/annotation_pipeline.py b/src/modelgauge/annotation_pipeline.py index 6e2e1feed..adb3c299a 100644 --- a/src/modelgauge/annotation_pipeline.py +++ b/src/modelgauge/annotation_pipeline.py @@ -15,7 +15,12 @@ from modelgauge.pipeline import CachingPipe, Pipe, Sink, Source from modelgauge.prompt import TextPrompt from modelgauge.prompt_pipeline import PromptOutput -from modelgauge.single_turn_prompt_response import SUTResponseAnnotations, SUTInteraction, TestItem +from modelgauge.single_turn_prompt_response import ( + AnnotatedSUTInteraction, + SUTResponseAnnotations, + SUTInteraction, + TestItem, +) from modelgauge.sut import PromptResponseSUT, SUTResponse logger = logging.getLogger(__name__) @@ -106,7 +111,7 @@ def handle_uncached_item(self, item): time.sleep(self.sleep_time) result = annotator.translate_response(request, response) self.annotation_counts[annotator_uid] += 1 - return sut_interaction, annotator_uid, result + return AnnotatedSUTInteraction(annotator_uid=annotator_uid, annotation=result, sut_interaction=sut_interaction) class EnsembleVoter(Pipe): @@ -119,18 +124,25 @@ def __init__(self, ensemble: AnnotatorSet): def handle_item(self, item): # Always pass the original item through self.downstream_put(item) - sut_interaction, annotator_uid, annotation = item - if annotator_uid in self.ensemble.annotators: - self.annotations[sut_interaction][annotator_uid] = annotation - if len(self.annotations[sut_interaction]) == len(self.ensemble.annotators): + if item.annotator_uid in self.ensemble.annotators: + self.annotations[item.sut_interaction][item.annotator_uid] = item.annotation + if len(self.annotations[item.sut_interaction]) == len(self.ensemble.annotators): # All annotators have responded, so we can compute 
the ensemble response. - annotations = {k: Annotation.from_instance(v) for k, v in self.annotations[sut_interaction].items()} + annotations = { + k: Annotation.from_instance(v) for k, v in self.annotations[item.sut_interaction].items() + } result = self.ensemble.evaluate( SUTResponseAnnotations( - test_item=sut_interaction.prompt, sut_response=sut_interaction.response, annotations=annotations + test_item=item.sut_interaction.prompt, + sut_response=item.sut_interaction.response, + annotations=annotations, + ) + ) + self.downstream_put( + AnnotatedSUTInteraction( + annotator_uid="ensemble", annotation=result, sut_interaction=item.sut_interaction ) ) - self.downstream_put((sut_interaction, "ensemble", result)) self.num_ensemble_votes += 1 @@ -148,18 +160,17 @@ def run(self): with self.writer: super().run() - def interaction_is_complete(self, sut_interaction) -> bool: + def interaction_is_complete(self, sut_interaction: SUTInteraction) -> bool: num_expected_annotations = len(self.annotators) if self.ensemble: num_expected_annotations += 1 return len(self.unfinished[sut_interaction]) == num_expected_annotations - def handle_item(self, item): - sut_interaction, annotator_uid, annotation = item - if isinstance(annotation, BaseModel): - annotation = annotation.model_dump() - self.unfinished[sut_interaction][annotator_uid] = annotation - if self.interaction_is_complete(sut_interaction): - self.writer.write(sut_interaction, self.unfinished[sut_interaction]) - self._debug(f"wrote {sut_interaction}") - del self.unfinished[sut_interaction] + def handle_item(self, item: AnnotatedSUTInteraction): + # Convert Pydantic model to dict if needed + annotation = item.annotation.model_dump() if isinstance(item.annotation, BaseModel) else item.annotation + self.unfinished[item.sut_interaction][item.annotator_uid] = annotation + if self.interaction_is_complete(item.sut_interaction): + self.writer.write(item.sut_interaction, self.unfinished[item.sut_interaction]) + self._debug(f"wrote 
{item.sut_interaction}") + del self.unfinished[item.sut_interaction] diff --git a/src/modelgauge/single_turn_prompt_response.py b/src/modelgauge/single_turn_prompt_response.py index e679fe2ff..e7f3f6366 100644 --- a/src/modelgauge/single_turn_prompt_response.py +++ b/src/modelgauge/single_turn_prompt_response.py @@ -83,3 +83,13 @@ class SUTInteraction: def __hash__(self): return hash(self.prompt.source_id + self.sut_uid) + + +@dataclass +class AnnotatedSUTInteraction: + annotator_uid: str + annotation: Annotation + sut_interaction: SUTInteraction + + def __hash__(self): + return hash(self.prompt.source_id + self.sut_uid + self.annotator_uid) diff --git a/tests/modelgauge_tests/test_annotation_pipeline.py b/tests/modelgauge_tests/test_annotation_pipeline.py index 2446ae685..570024a07 100644 --- a/tests/modelgauge_tests/test_annotation_pipeline.py +++ b/tests/modelgauge_tests/test_annotation_pipeline.py @@ -23,7 +23,7 @@ PromptSutAssigner, PromptSutWorkers, ) -from modelgauge.single_turn_prompt_response import SUTInteraction, TestItem +from modelgauge.single_turn_prompt_response import AnnotatedSUTInteraction, SUTInteraction, TestItem from modelgauge.sut import SUTResponse from modelgauge_tests.fake_annotator import ( FakeAnnotation, @@ -161,10 +161,10 @@ def test_annotator_worker_normal(annotators, annotator_uid, annotation): sut_interaction = make_sut_interaction("1", "prompt", "sut", "response") w = AnnotatorWorkers(annotators) result = w.handle_item((sut_interaction, annotator_uid)) - - assert result[0] == sut_interaction - assert result[1] == annotator_uid - assert result[2] == annotation + assert isinstance(result, AnnotatedSUTInteraction) + assert result.sut_interaction == sut_interaction + assert result.annotator_uid == annotator_uid + assert result.annotation == annotation def test_annotator_worker_cache_simple(annotators, tmp_path): @@ -174,8 +174,8 @@ def test_annotator_worker_cache_simple(annotators, tmp_path): # Tests that first call invokes the 
annotator and the second call uses the cache. assert annotators["annotator_pydantic"].annotate_calls == 0 for _ in range(2): - _, _, annotation = w.handle_item((sut_interaction, "annotator_pydantic")) - assert annotation == FakeAnnotation(sut_text="response") + result = w.handle_item((sut_interaction, "annotator_pydantic")) + assert result.annotation == FakeAnnotation(sut_text="response") assert annotators["annotator_pydantic"].annotate_calls == 1 @@ -241,7 +241,9 @@ def test_annotator_worker_retries_until_success(): result = w.handle_item((sut_interaction, "fake-annotator")) assert mock.call_count == num_exceptions + 1 - assert (sut_interaction, "fake-annotator", FakeAnnotation(sut_text="response")) == result + assert result.sut_interaction == sut_interaction + assert result.annotator_uid == "fake-annotator" + assert result.annotation == FakeAnnotation(sut_text="response") class FakeEnsemble(AnnotatorSet): @@ -262,14 +264,16 @@ def test_ensemble_worker_puts_all_items(annotator_uid): sut_interaction = make_sut_interaction("1", "prompt", "sut", "response") annotation = FakeAnnotation(sut_text="response") - w.handle_item((sut_interaction, annotator_uid, annotation)) + w.handle_item( + AnnotatedSUTInteraction(sut_interaction=sut_interaction, annotator_uid=annotator_uid, annotation=annotation) + ) assert w._queue.qsize() > 0 item = w._queue.get() - assert item[0] == sut_interaction - assert item[1] == annotator_uid - assert item[2] == annotation + assert item == AnnotatedSUTInteraction( + annotator_uid=annotator_uid, annotation=annotation, sut_interaction=sut_interaction + ) def test_ensemble_worker_computes_ensemble_with_all_annotators(): @@ -279,18 +283,25 @@ def test_ensemble_worker_computes_ensemble_with_all_annotators(): sut_interaction = make_sut_interaction("1", "prompt", "sut", "response") annotation = FakeAnnotation(sut_text="response") - w.handle_item((sut_interaction, "annotator_pydantic", annotation)) + w.handle_item( + AnnotatedSUTInteraction( + 
sut_interaction=sut_interaction, annotator_uid="annotator_pydantic", annotation=annotation + ) + ) assert w._queue.qsize() == 1 # Should just pass the first annotation through - w.handle_item((sut_interaction, "dummy", annotation)) + w.handle_item( + AnnotatedSUTInteraction(sut_interaction=sut_interaction, annotator_uid="dummy", annotation=annotation) + ) assert w._queue.qsize() == 3 # Should pass second annotation + final ensemble annotation w._queue.get() w._queue.get() item = w._queue.get() - assert item[0] == sut_interaction - assert item[1] == "ensemble" - assert item[2] == {"ensemble_vote": 1.0} + expected = AnnotatedSUTInteraction( + annotator_uid="ensemble", annotation={"ensemble_vote": 1.0}, sut_interaction=sut_interaction + ) + assert item == expected def test_full_run(annotators): @@ -316,7 +327,6 @@ def test_full_run(annotators): AnnotatorAssigner(annotators), AnnotatorWorkers(annotators, workers=1), AnnotatorSink(annotators, output), - debug=True, ) p.run() From 8a36bde2766bca3db97a0dece1d969d6c007f171 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Tue, 1 Jul 2025 16:53:51 -0700 Subject: [PATCH 11/15] Annotation column is a json dict + annotation dataset tests --- src/modelgauge/data_schema.py | 2 +- src/modelgauge/dataset.py | 33 ++-- tests/modelgauge_tests/test_data_schema.py | 10 +- tests/modelgauge_tests/test_dataset.py | 220 ++++++++++----------- 4 files changed, 130 insertions(+), 135 deletions(-) diff --git a/src/modelgauge/data_schema.py b/src/modelgauge/data_schema.py index 33c1085d4..960f00bd6 100644 --- a/src/modelgauge/data_schema.py +++ b/src/modelgauge/data_schema.py @@ -4,7 +4,7 @@ SUT_UID_COLS = ["sut_uid", "sut"] SUT_RESPONSE_COLS = ["sut_response", "response_text", "response"] ANNOTATOR_UID_COLS = ["annotator_uid"] -ANNOTATION_COLS = ["is_safe"] +ANNOTATION_COLS = ["annotation_json"] class SchemaValidationError(ValueError): diff --git a/src/modelgauge/dataset.py b/src/modelgauge/dataset.py index 631086bab..0c13611cd 100644 --- 
a/src/modelgauge/dataset.py +++ b/src/modelgauge/dataset.py @@ -3,6 +3,8 @@ from pathlib import Path from typing import Dict, Iterable, Optional, Union, Any, Sequence +from pydantic import BaseModel + from modelgauge.data_schema import ( DEFAULT_ANNOTATION_SCHEMA, DEFAULT_PROMPT_RESPONSE_SCHEMA, @@ -12,7 +14,7 @@ PromptSchema, ) from modelgauge.prompt import TextPrompt -from modelgauge.single_turn_prompt_response import SUTInteraction, TestItem +from modelgauge.single_turn_prompt_response import AnnotatedSUTInteraction, SUTInteraction, TestItem from modelgauge.sut import SUTResponse @@ -191,27 +193,28 @@ def _init_schema(self): else: self.schema = DEFAULT_ANNOTATION_SCHEMA - # TODO: New annotation object - def row_to_item(self, row: dict) -> tuple[SUTInteraction, Optional[Dict[str, Any]]]: + def row_to_item(self, row: dict) -> AnnotatedSUTInteraction: prompt = TestItem( prompt=TextPrompt(text=row[self.schema.prompt_text]), source_id=row[self.schema.prompt_uid], context=row, ) response = SUTResponse(text=row[self.schema.sut_response]) - interaction = SutInteraction(prompt, row[self.schema.sut_uid], response) - - # Extract annotations if present - annotations = row.get(self.schema.annotation) - return interaction, annotations + interaction = SUTInteraction(prompt, row[self.schema.sut_uid], response) + annotation = row.get(self.schema.annotation) + return AnnotatedSUTInteraction( + sut_interaction=interaction, annotator_uid=row[self.schema.annotator_uid], annotation=annotation + ) - def item_to_row(self, item: SUTInteraction, annotations: Optional[Dict[str, Any]] = None) -> list[str]: - if not isinstance(item.prompt.prompt, TextPrompt): + def item_to_row(self, item: AnnotatedSUTInteraction) -> list[str]: + if not isinstance(item.sut_interaction.prompt.prompt, TextPrompt): raise ValueError(f"Error handling {item}. 
Can only handle TextPrompts.") + annotation = item.annotation.model_dump() if isinstance(item.annotation, BaseModel) else item.annotation return [ - item.prompt.source_id, - item.prompt.prompt.text, - item.sut_uid, - item.response.text, - annotations.is_safe, + item.sut_interaction.prompt.source_id, + item.sut_interaction.prompt.prompt.text, + item.sut_interaction.sut_uid, + item.sut_interaction.response.text, + item.annotator_uid, + annotation, ] diff --git a/tests/modelgauge_tests/test_data_schema.py b/tests/modelgauge_tests/test_data_schema.py index 7b3d4eac3..9d4cfea0a 100644 --- a/tests/modelgauge_tests/test_data_schema.py +++ b/tests/modelgauge_tests/test_data_schema.py @@ -97,11 +97,11 @@ def test_default_prompt_response_schema(): "header", [ # Preferred names - ["prompt_uid", "prompt_text", "sut_uid", "sut_response", "annotator_uid", "is_safe"], + ["prompt_uid", "prompt_text", "sut_uid", "sut_response", "annotator_uid", "annotation_json"], # Case-insensitive - ["prompt_UID", "Prompt_Text", "SUT_UID", "SUT_Response", "Annotator_UID", "Is_Safe"], + ["prompt_UID", "Prompt_Text", "SUT_UID", "SUT_Response", "Annotator_UID", "annotation_JSON"], # Extra columns are allowed - ["prompt_uid", "prompt_text", "sut_uid", "sut_response", "annotator_uid", "is_safe", "extra_col"], + ["prompt_uid", "prompt_text", "sut_uid", "sut_response", "annotator_uid", "annotation_json", "extra_col"], ], ) def test_valid_annotation_schema(header): @@ -122,7 +122,7 @@ def test_valid_prompt_response_invalid_annotation_schema(): def test_invalid_prompt_response_valid_annotation_schema(): - header = ["random_1", "random_2", "random_3", "random_4", "annotator_uid", "is_safe"] + header = ["random_1", "random_2", "random_3", "random_4", "annotator_uid", "annotation_json"] with pytest.raises(SchemaValidationError) as e: schema = AnnotationSchema(header) assert set(e.missing_columns) == { @@ -139,4 +139,4 @@ def test_default_annotation_schema(): assert DEFAULT_ANNOTATION_SCHEMA.sut_uid == 
"sut_uid" assert DEFAULT_ANNOTATION_SCHEMA.sut_response == "sut_response" assert DEFAULT_ANNOTATION_SCHEMA.annotator_uid == "annotator_uid" - assert DEFAULT_ANNOTATION_SCHEMA.annotation == "is_safe" + assert DEFAULT_ANNOTATION_SCHEMA.annotation == "annotation_json" diff --git a/tests/modelgauge_tests/test_dataset.py b/tests/modelgauge_tests/test_dataset.py index 267f66b08..0405b5cdb 100644 --- a/tests/modelgauge_tests/test_dataset.py +++ b/tests/modelgauge_tests/test_dataset.py @@ -2,7 +2,10 @@ from pathlib import Path from typing import Iterable +from pydantic import BaseModel + from modelgauge.data_schema import ( + DEFAULT_ANNOTATION_SCHEMA, DEFAULT_PROMPT_RESPONSE_SCHEMA, DEFAULT_PROMPT_SCHEMA, SchemaValidationError, @@ -14,7 +17,7 @@ PromptResponseDataset, ) from modelgauge.prompt import TextPrompt -from modelgauge.single_turn_prompt_response import SUTInteraction, TestItem +from modelgauge.single_turn_prompt_response import AnnotatedSUTInteraction, SUTInteraction, TestItem from modelgauge.sut import SUTResponse @@ -280,116 +283,105 @@ def test_write_csv(self, tmp_path): assert content.replace('"', "") == expected_header + expected_data -# class TestAnnotationDataset: -# @pytest.fixture -# def sample_annotations_jsonl(tmp_path): -# """Create a sample JSONL file with annotated prompt-response data.""" -# file_path = tmp_path / "annotations.jsonl" -# content = [ -# { -# DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_uid: "p1", -# DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_text: "Say hello", -# DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid: "sut1", -# DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response: "Hello world", -# "Annotations": {"toxicity": 0.1} -# }, -# { -# DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_uid: "p2", -# DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_text: "Say goodbye", -# DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid: "sut1", -# DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response: "Goodbye world", -# "Annotations": {"toxicity": 0.0} -# } -# ] -# import jsonlines -# with jsonlines.open(file_path, 
mode='w') as writer: -# writer.write_all(content) -# return file_path -# -# def test_read_jsonl(self, sample_annotations_jsonl): -# """Test reading annotated prompt-response pairs from a JSONL file.""" -# with AnnotationDataset(sample_annotations_jsonl, mode='r') as dataset: -# interactions = list(dataset) -# assert len(interactions) == 2 - -# # Check first interaction -# interaction, annotations = interactions[0] -# assert interaction.prompt.source_id == "p1" -# assert interaction.prompt.prompt.text == "Say hello" -# assert interaction.sut_uid == "sut1" -# assert interaction.response.text == "Hello world" -# assert annotations == {"toxicity": 0.1} - -# # Check second interaction -# interaction, annotations = interactions[1] -# assert interaction.prompt.source_id == "p2" -# assert annotations == {"toxicity": 0.0} - -# def test_read_csv(self, sample_responses_csv): -# """Test reading from a CSV file (should have no annotations).""" -# with AnnotationDataset(sample_responses_csv, mode='r') as dataset: -# interactions = list(dataset) -# assert len(interactions) == 2 - -# # Check that annotations are None -# for interaction, annotations in interactions: -# assert annotations is None - -# def test_write_jsonl(self, tmp_path): -# """Test writing annotated prompt-response pairs to a JSONL file.""" -# output_file = tmp_path / "output.jsonl" - -# # Create test data -# interaction = SutInteraction( -# prompt=TestItem( -# prompt=TextPrompt(text="Test prompt"), -# source_id="test1", -# context={} -# ), -# sut_uid="sut1", -# response=SUTResponse(text="Test response") -# ) -# annotations = {"toxicity": 0.5} - -# # Write data -# with AnnotationDataset(output_file, mode='w') as dataset: -# dataset.write(interaction, annotations) - -# # Verify written data -# import jsonlines -# with jsonlines.open(output_file) as reader: -# data = list(reader) -# assert len(data) == 1 -# assert data[0][DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_uid] == "test1" -# assert 
data[0][DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_text] == "Test prompt" -# assert data[0][DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid] == "sut1" -# assert data[0][DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response] == "Test response" -# assert data[0]["Annotations"] == {"toxicity": 0.5} - -# def test_write_csv(self, tmp_path): -# """Test writing to a CSV file (should ignore annotations).""" -# output_file = tmp_path / "output.csv" - -# # Create test data -# interaction = SutInteraction( -# prompt=TestItem( -# prompt=TextPrompt(text="Test prompt"), -# source_id="test1", -# context={} -# ), -# sut_uid="sut1", -# response=SUTResponse(text="Test response") -# ) -# annotations = {"toxicity": 0.5} - -# # Write data -# with AnnotationDataset(output_file, mode='w') as dataset: -# dataset.write(interaction, annotations) - -# # Verify written data - should not include annotations -# assert output_file.exists() -# content = output_file.read_text() -# expected_header = f"{DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_uid},{DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_text}," \ -# f"{DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid},{DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response}\n" -# expected_data = '"test1","Test prompt","sut1","Test response"\n' -# assert content == expected_header + expected_data +class TestAnnotationDataset: + @pytest.fixture + def sample_annotations_csv(self, tmp_path): + file_path = tmp_path / "annotations.csv" + content = ( + f"{DEFAULT_ANNOTATION_SCHEMA.prompt_uid},{DEFAULT_ANNOTATION_SCHEMA.prompt_text}," + f"{DEFAULT_ANNOTATION_SCHEMA.sut_uid},{DEFAULT_ANNOTATION_SCHEMA.sut_response}," + f"{DEFAULT_ANNOTATION_SCHEMA.annotator_uid},{DEFAULT_ANNOTATION_SCHEMA.annotation}\n" + "p1,Say hello,sut1,Hello world,annotator1,1\n" + "p2,Say goodbye,sut1,Goodbye world,annotator1,0" + ) + file_path.write_text(content) + return file_path + + def test_schema_read(self, sample_annotations_csv): + dataset = AnnotationDataset(sample_annotations_csv, mode="r") + assert dataset.schema.header == 
DEFAULT_ANNOTATION_SCHEMA.header + + def test_schema_write(self, tmp_path): + dataset = AnnotationDataset(tmp_path / "annotations.csv", mode="w") + assert dataset.schema == DEFAULT_ANNOTATION_SCHEMA + + def test_read_csv(self, sample_annotations_csv): + with AnnotationDataset(sample_annotations_csv, mode="r") as dataset: + annotations = list(dataset) + assert len(annotations) == 2 + assert all(isinstance(annotation, AnnotatedSUTInteraction) for annotation in annotations) + + # Check first annotation + assert annotations[0].sut_interaction.prompt.source_id == "p1" + assert annotations[0].sut_interaction.prompt.prompt.text == "Say hello" + assert annotations[0].sut_interaction.sut_uid == "sut1" + assert annotations[0].sut_interaction.response.text == "Hello world" + assert annotations[0].annotator_uid == "annotator1" + assert annotations[0].annotation == "1" + + # Check second annotation + assert annotations[1].sut_interaction.prompt.source_id == "p2" + assert annotations[1].sut_interaction.prompt.prompt.text == "Say goodbye" + assert annotations[1].sut_interaction.sut_uid == "sut1" + assert annotations[1].sut_interaction.response.text == "Goodbye world" + assert annotations[1].annotator_uid == "annotator1" + assert annotations[1].annotation == "0" + + def test_write_csv_dict_annotation(self, tmp_path): + output_file = tmp_path / "output.csv" + + # Create test data + sut_interaction = SUTInteraction( + prompt=TestItem(prompt=TextPrompt(text="Test prompt"), source_id="test1", context={}), + sut_uid="sut1", + response=SUTResponse(text="Test response"), + ) + annotated_interaction = AnnotatedSUTInteraction( + sut_interaction=sut_interaction, annotator_uid="annotator1", annotation={"is_safe": True} + ) + + # Write data + with AnnotationDataset(output_file, mode="w") as dataset: + dataset.write(annotated_interaction) + + # Verify written data + assert output_file.exists() + content = output_file.read_text() + expected_header = ( + 
f"{DEFAULT_ANNOTATION_SCHEMA.prompt_uid},{DEFAULT_ANNOTATION_SCHEMA.prompt_text}," + f"{DEFAULT_ANNOTATION_SCHEMA.sut_uid},{DEFAULT_ANNOTATION_SCHEMA.sut_response}," + f"{DEFAULT_ANNOTATION_SCHEMA.annotator_uid},{DEFAULT_ANNOTATION_SCHEMA.annotation}\n" + ) + expected_data = "test1,Test prompt,sut1,Test response,annotator1,{'is_safe': True}\n" + assert content.replace('"', "") == expected_header + expected_data + + def test_write_csv_pydantic_annotation(self, tmp_path): + output_file = tmp_path / "output.csv" + + # Create test data + sut_interaction = SUTInteraction( + prompt=TestItem(prompt=TextPrompt(text="Test prompt"), source_id="test1", context={}), + sut_uid="sut1", + response=SUTResponse(text="Test response"), + ) + + class MyAnnotation(BaseModel): + is_safe: bool + + annotated_interaction = AnnotatedSUTInteraction( + sut_interaction=sut_interaction, annotator_uid="annotator1", annotation=MyAnnotation(is_safe=True) + ) + + with AnnotationDataset(output_file, mode="w") as dataset: + dataset.write(annotated_interaction) + + # Verify written data + assert output_file.exists() + content = output_file.read_text() + expected_header = ( + f"{DEFAULT_ANNOTATION_SCHEMA.prompt_uid},{DEFAULT_ANNOTATION_SCHEMA.prompt_text}," + f"{DEFAULT_ANNOTATION_SCHEMA.sut_uid},{DEFAULT_ANNOTATION_SCHEMA.sut_response}," + f"{DEFAULT_ANNOTATION_SCHEMA.annotator_uid},{DEFAULT_ANNOTATION_SCHEMA.annotation}\n" + ) + expected_data = "test1,Test prompt,sut1,Test response,annotator1,{'is_safe': True}\n" + assert content.replace('"', "") == expected_header + expected_data From 160253ba9935116762d869745a4c298b33647080 Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Wed, 2 Jul 2025 13:11:18 -0700 Subject: [PATCH 12/15] Annotation dataset dumps annotations as json strings and reads them back as dictionaries --- src/modelgauge/dataset.py | 11 ++++- tests/modelgauge_tests/test_dataset.py | 56 ++++++++++++++------------ 2 files changed, 40 insertions(+), 27 deletions(-) diff --git 
a/src/modelgauge/dataset.py b/src/modelgauge/dataset.py index 0c13611cd..86072d7c1 100644 --- a/src/modelgauge/dataset.py +++ b/src/modelgauge/dataset.py @@ -1,4 +1,5 @@ import csv +import json from abc import ABC, abstractmethod from pathlib import Path from typing import Dict, Iterable, Optional, Union, Any, Sequence @@ -201,15 +202,21 @@ def row_to_item(self, row: dict) -> AnnotatedSUTInteraction: ) response = SUTResponse(text=row[self.schema.sut_response]) interaction = SUTInteraction(prompt, row[self.schema.sut_uid], response) - annotation = row.get(self.schema.annotation) + print(row.get(self.schema.annotation)) + print(type(row.get(self.schema.annotation))) + annotation = json.loads(row.get(self.schema.annotation)) return AnnotatedSUTInteraction( sut_interaction=interaction, annotator_uid=row[self.schema.annotator_uid], annotation=annotation ) def item_to_row(self, item: AnnotatedSUTInteraction) -> list[str]: + """Write an AnnotatedSUTInteraction to a csv row. The last column is the annotation, which is a json string that can be deserialized with json.loads.""" if not isinstance(item.sut_interaction.prompt.prompt, TextPrompt): raise ValueError(f"Error handling {item}. 
Can only handle TextPrompts.") - annotation = item.annotation.model_dump() if isinstance(item.annotation, BaseModel) else item.annotation + annotation = item.annotation + if isinstance(annotation, BaseModel): + annotation = annotation.model_dump() + annotation = json.dumps(annotation) return [ item.sut_interaction.prompt.source_id, item.sut_interaction.prompt.prompt.text, diff --git a/tests/modelgauge_tests/test_dataset.py b/tests/modelgauge_tests/test_dataset.py index 0405b5cdb..bf56cc5fc 100644 --- a/tests/modelgauge_tests/test_dataset.py +++ b/tests/modelgauge_tests/test_dataset.py @@ -1,3 +1,4 @@ +import json import pytest from pathlib import Path from typing import Iterable @@ -284,15 +285,26 @@ def test_write_csv(self, tmp_path): class TestAnnotationDataset: + @pytest.fixture + def expected_content(self): + expected_header = ( + f'"{DEFAULT_ANNOTATION_SCHEMA.prompt_uid}","{DEFAULT_ANNOTATION_SCHEMA.prompt_text}",' + f'"{DEFAULT_ANNOTATION_SCHEMA.sut_uid}","{DEFAULT_ANNOTATION_SCHEMA.sut_response}",' + f'"{DEFAULT_ANNOTATION_SCHEMA.annotator_uid}","{DEFAULT_ANNOTATION_SCHEMA.annotation}"\n' + ) + expected_data = '"test1","Test prompt","sut1","Test response","annotator1","{""is_safe"": true}"\n' + return expected_header + expected_data + @pytest.fixture def sample_annotations_csv(self, tmp_path): + """Sample CSV with QUOTE_ALL quoting.""" file_path = tmp_path / "annotations.csv" content = ( - f"{DEFAULT_ANNOTATION_SCHEMA.prompt_uid},{DEFAULT_ANNOTATION_SCHEMA.prompt_text}," - f"{DEFAULT_ANNOTATION_SCHEMA.sut_uid},{DEFAULT_ANNOTATION_SCHEMA.sut_response}," - f"{DEFAULT_ANNOTATION_SCHEMA.annotator_uid},{DEFAULT_ANNOTATION_SCHEMA.annotation}\n" - "p1,Say hello,sut1,Hello world,annotator1,1\n" - "p2,Say goodbye,sut1,Goodbye world,annotator1,0" + f'"{DEFAULT_ANNOTATION_SCHEMA.prompt_uid}","{DEFAULT_ANNOTATION_SCHEMA.prompt_text}",' + f'"{DEFAULT_ANNOTATION_SCHEMA.sut_uid}","{DEFAULT_ANNOTATION_SCHEMA.sut_response}",' + 
f'"{DEFAULT_ANNOTATION_SCHEMA.annotator_uid}","{DEFAULT_ANNOTATION_SCHEMA.annotation}"\n' + '"p1","Say hello","sut1","Hello world","annotator1","{""is_safe"": true}"\n' + '"p2","Say goodbye","sut1","Goodbye world","annotator1","{""is_safe"": false}"' ) file_path.write_text(content) return file_path @@ -317,7 +329,7 @@ def test_read_csv(self, sample_annotations_csv): assert annotations[0].sut_interaction.sut_uid == "sut1" assert annotations[0].sut_interaction.response.text == "Hello world" assert annotations[0].annotator_uid == "annotator1" - assert annotations[0].annotation == "1" + assert annotations[0].annotation == {"is_safe": True} # Annotation jsonstrings are deserialized to dicts # Check second annotation assert annotations[1].sut_interaction.prompt.source_id == "p2" @@ -325,9 +337,9 @@ def test_read_csv(self, sample_annotations_csv): assert annotations[1].sut_interaction.sut_uid == "sut1" assert annotations[1].sut_interaction.response.text == "Goodbye world" assert annotations[1].annotator_uid == "annotator1" - assert annotations[1].annotation == "0" + assert annotations[1].annotation == {"is_safe": False} - def test_write_csv_dict_annotation(self, tmp_path): + def test_write_csv_dict_annotation(self, tmp_path, expected_content): output_file = tmp_path / "output.csv" # Create test data @@ -346,16 +358,11 @@ def test_write_csv_dict_annotation(self, tmp_path): # Verify written data assert output_file.exists() - content = output_file.read_text() - expected_header = ( - f"{DEFAULT_ANNOTATION_SCHEMA.prompt_uid},{DEFAULT_ANNOTATION_SCHEMA.prompt_text}," - f"{DEFAULT_ANNOTATION_SCHEMA.sut_uid},{DEFAULT_ANNOTATION_SCHEMA.sut_response}," - f"{DEFAULT_ANNOTATION_SCHEMA.annotator_uid},{DEFAULT_ANNOTATION_SCHEMA.annotation}\n" - ) - expected_data = "test1,Test prompt,sut1,Test response,annotator1,{'is_safe': True}\n" - assert content.replace('"', "") == expected_header + expected_data + read_content = output_file.read_text() - def 
test_write_csv_pydantic_annotation(self, tmp_path): + assert read_content == expected_content + + def test_write_csv_pydantic_annotation(self, expected_content, tmp_path): output_file = tmp_path / "output.csv" # Create test data @@ -377,11 +384,10 @@ class MyAnnotation(BaseModel): # Verify written data assert output_file.exists() - content = output_file.read_text() - expected_header = ( - f"{DEFAULT_ANNOTATION_SCHEMA.prompt_uid},{DEFAULT_ANNOTATION_SCHEMA.prompt_text}," - f"{DEFAULT_ANNOTATION_SCHEMA.sut_uid},{DEFAULT_ANNOTATION_SCHEMA.sut_response}," - f"{DEFAULT_ANNOTATION_SCHEMA.annotator_uid},{DEFAULT_ANNOTATION_SCHEMA.annotation}\n" - ) - expected_data = "test1,Test prompt,sut1,Test response,annotator1,{'is_safe': True}\n" - assert content.replace('"', "") == expected_header + expected_data + read_content = output_file.read_text() + assert read_content == expected_content + + def test_annotation_is_deserialized_on_read(self, sample_annotations_csv): + with AnnotationDataset(sample_annotations_csv, mode="r") as dataset: + row_1 = list(dataset)[0] + row_1.annotation == {"is_safe": True} From a20d5ec0cdb6845fc19180573d7721021f19739c Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Wed, 2 Jul 2025 13:34:39 -0700 Subject: [PATCH 13/15] Use AnnotationDataset object in annotation runner --- src/modelgauge/annotation_pipeline.py | 56 +----- src/modelgauge/pipeline_runner.py | 17 +- src/modelgauge/prompt_pipeline.py | 14 -- .../test_annotation_pipeline.py | 189 ++++++++---------- tests/modelgauge_tests/test_cli.py | 78 +------- .../modelgauge_tests/test_pipeline_runner.py | 8 - 6 files changed, 104 insertions(+), 258 deletions(-) diff --git a/src/modelgauge/annotation_pipeline.py b/src/modelgauge/annotation_pipeline.py index adb3c299a..8abb102ec 100644 --- a/src/modelgauge/annotation_pipeline.py +++ b/src/modelgauge/annotation_pipeline.py @@ -10,11 +10,10 @@ from modelgauge.annotation import Annotation from modelgauge.annotator import Annotator from 
modelgauge.annotator_set import AnnotatorSet -from modelgauge.dataset import PromptResponseDataset +from modelgauge.dataset import AnnotationDataset, PromptResponseDataset from modelgauge.data_schema import DEFAULT_PROMPT_RESPONSE_SCHEMA, PromptResponseSchema from modelgauge.pipeline import CachingPipe, Pipe, Sink, Source from modelgauge.prompt import TextPrompt -from modelgauge.prompt_pipeline import PromptOutput from modelgauge.single_turn_prompt_response import ( AnnotatedSUTInteraction, SUTResponseAnnotations, @@ -26,38 +25,6 @@ logger = logging.getLogger(__name__) -class JsonlAnnotatorOutput(PromptOutput): - def __init__(self, path): - super().__init__() - assert path.suffix.lower() == ".jsonl", f"Invalid output file {path}. Must be of type JSONL." - - self.path = path - self.file = None - self.writer = None - - def __enter__(self): - self.file = open(self.path, "w", newline="") - self.writer = jsonlines.Writer(self.file) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.writer.close() - self.file.close() - - def write(self, item: SUTInteraction, results): - if not isinstance(item.prompt.prompt, TextPrompt): - raise Exception(f"Error handling {item}. Can only handle TextPrompts.") - # TODO: Standardize annotation schema. 
- output_obj = { - DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_uid: item.prompt.source_id, - DEFAULT_PROMPT_RESPONSE_SCHEMA.prompt_text: item.prompt.prompt.text, - DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_uid: item.sut_uid, - DEFAULT_PROMPT_RESPONSE_SCHEMA.sut_response: item.response.text, - "Annotations": results, - } - self.writer.write(output_obj) - - class AnnotatorSource(Source): def __init__(self, input: PromptResponseDataset): super().__init__() @@ -147,30 +114,13 @@ def handle_item(self, item): class AnnotatorSink(Sink): - unfinished: defaultdict[SUTInteraction, dict[str, str]] - - def __init__(self, annotators: dict[str, Annotator], writer: JsonlAnnotatorOutput, ensemble: bool = False): + def __init__(self, writer: AnnotationDataset): super().__init__() - self.annotators = annotators - self.ensemble = ensemble self.writer = writer - self.unfinished = defaultdict(lambda: dict()) def run(self): with self.writer: super().run() - def interaction_is_complete(self, sut_interaction: SUTInteraction) -> bool: - num_expected_annotations = len(self.annotators) - if self.ensemble: - num_expected_annotations += 1 - return len(self.unfinished[sut_interaction]) == num_expected_annotations - def handle_item(self, item: AnnotatedSUTInteraction): - # Convert Pydantic model to dict if needed - annotation = item.annotation.model_dump() if isinstance(item.annotation, BaseModel) else item.annotation - self.unfinished[item.sut_interaction][item.annotator_uid] = annotation - if self.interaction_is_complete(item.sut_interaction): - self.writer.write(item.sut_interaction, self.unfinished[item.sut_interaction]) - self._debug(f"wrote {item.sut_interaction}") - del self.unfinished[item.sut_interaction] + self.writer.write(item) diff --git a/src/modelgauge/pipeline_runner.py b/src/modelgauge/pipeline_runner.py index d036d65b9..3f59d6907 100644 --- a/src/modelgauge/pipeline_runner.py +++ b/src/modelgauge/pipeline_runner.py @@ -9,9 +9,8 @@ AnnotatorSource, AnnotatorWorkers, EnsembleVoter, - 
JsonlAnnotatorOutput, ) -from modelgauge.dataset import PromptDataset, PromptResponseDataset +from modelgauge.dataset import AnnotationDataset, PromptDataset, PromptResponseDataset from modelgauge.pipeline import Pipeline from modelgauge.prompt_pipeline import ( PromptSource, @@ -178,7 +177,7 @@ def num_total_items(self): @property def output_file_name(self): - return "annotations.jsonl" + return "annotations.csv" @property def run_id(self): @@ -197,8 +196,8 @@ def _add_annotator_segments(self, include_source=True, include_sink=True): self.annotator_workers = AnnotatorWorkers(self.annotators, self.num_workers) self.pipeline_segments.append(self.annotator_workers) if include_sink: - output = JsonlAnnotatorOutput(self.output_dir() / self.output_file_name) - self.pipeline_segments.append(AnnotatorSink(self.annotators, output, ensemble=False)) + output = AnnotationDataset(self.output_dir() / self.output_file_name, "w") + self.pipeline_segments.append(AnnotatorSink(output)) def _annotator_metadata(self): counts = self.annotator_workers.annotation_counts @@ -255,8 +254,8 @@ def _add_ensemble_segments(self): """Adds ensemble worker plus annotator sink.""" self.ensemble_voter = EnsembleVoter(self.ensemble) self.pipeline_segments.append(self.ensemble_voter) - output = JsonlAnnotatorOutput(self.output_dir() / self.output_file_name) - self.pipeline_segments.append(AnnotatorSink(self.annotators, output, ensemble=True)) + output = AnnotationDataset(self.output_dir() / self.output_file_name, "w") + self.pipeline_segments.append(AnnotatorSink(output)) def _initialize_segments(self): # Add regular annotator segments @@ -274,7 +273,7 @@ def num_total_items(self): @property def output_file_name(self): - return "prompt-responses-annotated.jsonl" + return "prompt-responses-annotated.csv" @property def run_id(self): @@ -301,7 +300,7 @@ def num_total_items(self): @property def output_file_name(self): - return "prompt-responses-annotated.jsonl" + return "prompt-responses-annotated.csv" 
@property def run_id(self): diff --git a/src/modelgauge/prompt_pipeline.py b/src/modelgauge/prompt_pipeline.py index cf6cf44db..be7d51348 100644 --- a/src/modelgauge/prompt_pipeline.py +++ b/src/modelgauge/prompt_pipeline.py @@ -1,7 +1,6 @@ import csv import logging import time -from abc import ABCMeta, abstractmethod from collections import defaultdict from typing import Iterable, Optional @@ -15,19 +14,6 @@ logger = logging.getLogger(__name__) -# TODO: Delete. -class PromptOutput(metaclass=ABCMeta): - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - pass - - @abstractmethod - def write(self, item, results): - pass - - class PromptSource(Source): def __init__(self, input: PromptDataset): super().__init__() diff --git a/tests/modelgauge_tests/test_annotation_pipeline.py b/tests/modelgauge_tests/test_annotation_pipeline.py index 570024a07..c7811fa9c 100644 --- a/tests/modelgauge_tests/test_annotation_pipeline.py +++ b/tests/modelgauge_tests/test_annotation_pipeline.py @@ -1,4 +1,5 @@ import itertools +import json import jsonlines import pytest import time @@ -10,15 +11,13 @@ AnnotatorWorkers, AnnotatorSink, EnsembleVoter, - JsonlAnnotatorOutput, ) from modelgauge.annotator_set import AnnotatorSet -from modelgauge.dataset import PromptResponseDataset +from modelgauge.dataset import AnnotationDataset, PromptResponseDataset from modelgauge.data_schema import DEFAULT_PROMPT_RESPONSE_SCHEMA as PROMPT_RESPONSE_SCHEMA from modelgauge.pipeline import Pipeline from modelgauge.prompt import TextPrompt from modelgauge.prompt_pipeline import ( - PromptOutput, PromptSource, PromptSutAssigner, PromptSutWorkers, @@ -52,12 +51,13 @@ def __iter__(self): yield SUTInteraction(prompt, row[PROMPT_RESPONSE_SCHEMA.sut_uid], response) -class FakeAnnotatorOutput(PromptOutput): - def __init__(self): - self.output = {} +class FakeAnnotatorOutput(AnnotationDataset): + def __init__(self, path: str): + self.output = [] + super().__init__(path, "w") - def 
write(self, item, annotations): - self.output[item] = annotations + def write(self, item): + self.output.append(self.item_to_row(item)) def make_sut_interaction(source_id, prompt, sut_uid, response): @@ -90,52 +90,6 @@ def test_csv_annotator_input(tmp_path): assert sut_interactions_is_equal(item, make_sut_interaction("1", "a", "s", "b")) -def test_json_annotator_output(tmp_path): - file_path = tmp_path / "output.jsonl" - with JsonlAnnotatorOutput(file_path) as output: - output.write(make_sut_interaction("1", "a", "sut1", "b"), {"fake": "x"}) - output.write(make_sut_interaction("2", "c", "sut2", "d"), {"fake": "y"}) - - with jsonlines.open(file_path) as reader: - items: list[dict] = [i for i in reader] - assert len(items) == 2 - assert items[0] == { - PROMPT_RESPONSE_SCHEMA.prompt_uid: "1", - PROMPT_RESPONSE_SCHEMA.prompt_text: "a", - PROMPT_RESPONSE_SCHEMA.sut_uid: "sut1", - PROMPT_RESPONSE_SCHEMA.sut_response: "b", - "Annotations": {"fake": "x"}, - } - assert items[1] == { - PROMPT_RESPONSE_SCHEMA.prompt_uid: "2", - PROMPT_RESPONSE_SCHEMA.prompt_text: "c", - PROMPT_RESPONSE_SCHEMA.sut_uid: "sut2", - PROMPT_RESPONSE_SCHEMA.sut_response: "d", - "Annotations": {"fake": "y"}, - } - - -def test_json_annotator_output_different_annotation_types(tmp_path): - file_path = tmp_path / "output.jsonl" - annotations = { - "fake1": {"sut_text": "a"}, - "fake2": {"sut_text": "b", "num": 0}, - "fake3": "c", - } - with JsonlAnnotatorOutput(file_path) as output: - output.write(make_sut_interaction("1", "a", "s", "b"), annotations) - - with jsonlines.open(file_path) as reader: - assert reader.read()["Annotations"] == annotations - - -@pytest.mark.parametrize("output_fname", ["output.csv", "output.json"]) -def test_csv_annotator_output_invalid(tmp_path, output_fname): - file_path = tmp_path / output_fname - with pytest.raises(AssertionError, match=f"Invalid output file {file_path}. 
Must be of type JSONL."): - JsonlAnnotatorOutput(file_path) - - @pytest.fixture def annotators(): annotator_pydantic = FakeAnnotator("annotator_pydantic") @@ -304,7 +258,7 @@ def test_ensemble_worker_computes_ensemble_with_all_annotators(): assert item == expected -def test_full_run(annotators): +def test_full_run(annotators, tmp_path): input = FakeAnnotatorInput( [ { @@ -321,32 +275,42 @@ def test_full_run(annotators): }, ] ) - output = FakeAnnotatorOutput() + output = FakeAnnotatorOutput(tmp_path / "output.csv") p = Pipeline( AnnotatorSource(input), AnnotatorAssigner(annotators), AnnotatorWorkers(annotators, workers=1), - AnnotatorSink(annotators, output), + AnnotatorSink(output), ) p.run() - assert len(output.output) == len(input.items) - interactions = sorted(list(output.output.keys()), key=lambda o: o.prompt.source_id) - assert sut_interactions_is_equal(interactions[0], make_sut_interaction("1", "a", "s", "b")) - assert output.output[interactions[0]] == { - "annotator_pydantic": {"sut_text": "b"}, - "annotator_dict": {"sut_text": "b"}, - "dummy": "d", - } - assert sut_interactions_is_equal(interactions[1], make_sut_interaction("2", "c", "s", "d")) - assert output.output[interactions[1]] == { - "annotator_pydantic": {"sut_text": "d"}, - "annotator_dict": {"sut_text": "d"}, - "dummy": "d", - } - - -def test_full_run_with_ensemble(annotators): + assert len(output.output) == len(input.items) * len(annotators) + items = sorted(output.output, key=lambda o: (o[0], o[4])) # Sort by prompt_uid, annotator_uid + + # First 3 items are same sut interaction + assert sut_interactions_is_equal(make_sut_interaction(*items[0][:4]), make_sut_interaction("1", "a", "s", "b")) + assert sut_interactions_is_equal(make_sut_interaction(*items[1][:4]), make_sut_interaction("1", "a", "s", "b")) + assert sut_interactions_is_equal(make_sut_interaction(*items[2][:4]), make_sut_interaction("1", "a", "s", "b")) + assert items[0][4] == "annotator_dict" + assert items[1][4] == 
"annotator_pydantic" + assert items[2][4] == "dummy" + assert items[0][5] == '{"sut_text": "b"}' + assert items[1][5] == '{"sut_text": "b"}' + assert items[2][5] == '"d"' + + # Second 3 items are same sut interaction + assert sut_interactions_is_equal(make_sut_interaction(*items[3][:4]), make_sut_interaction("2", "c", "s", "d")) + assert sut_interactions_is_equal(make_sut_interaction(*items[4][:4]), make_sut_interaction("2", "c", "s", "d")) + assert sut_interactions_is_equal(make_sut_interaction(*items[5][:4]), make_sut_interaction("2", "c", "s", "d")) + assert items[3][4] == "annotator_dict" + assert items[4][4] == "annotator_pydantic" + assert items[5][4] == "dummy" + assert items[3][5] == '{"sut_text": "d"}' + assert items[4][5] == '{"sut_text": "d"}' + assert items[5][5] == '"d"' + + +def test_full_run_with_ensemble(annotators, tmp_path): input = FakeAnnotatorInput( [ { @@ -363,39 +327,42 @@ def test_full_run_with_ensemble(annotators): }, ] ) - output = FakeAnnotatorOutput() + output = FakeAnnotatorOutput(tmp_path / "output.csv") p = Pipeline( AnnotatorSource(input), AnnotatorAssigner(annotators), AnnotatorWorkers(annotators, workers=1), EnsembleVoter(FakeEnsemble(["annotator_pydantic", "annotator_dict"])), - AnnotatorSink(annotators, output, ensemble=True), + AnnotatorSink(output), debug=False, ) p.run() - assert len(output.output) == len(input.items) - interactions = sorted(list(output.output.keys()), key=lambda o: o.prompt.source_id) - assert output.output[interactions[0]] == { - "annotator_pydantic": {"sut_text": "b"}, - "annotator_dict": {"sut_text": "b"}, - "dummy": "d", - "ensemble": {"ensemble_vote": 1.0}, - } + assert len(output.output) == len(input.items) * (len(annotators) + 1) # +1 for ensemble + items = sorted(output.output, key=lambda o: (o[4], o[0])) # Sort by annotator_uid, prompt_uid + + assert items[0][4] == "annotator_dict" + assert items[0][5] == '{"sut_text": "b"}' + assert items[2][4] == "annotator_pydantic" + assert items[2][5] == 
'{"sut_text": "b"}' + assert items[4][4] == "dummy" + assert items[4][5] == '"d"' + assert items[6][4] == "ensemble" + assert items[6][5] == '{"ensemble_vote": 1.0}' @pytest.mark.parametrize( "sut_worker_count,annotator_worker_count", - [(1, 1), (2, 2), (8, 8), (1, 5), (5, 1), (3, 9), (9, 3)], + [(1, 1), (2, 2), (8, 8), (1, 5), (5, 1)], ) -def test_prompt_response_annotation_pipeline(annotators, sut_worker_count, annotator_worker_count): +def test_prompt_response_annotation_pipeline(annotators, sut_worker_count, annotator_worker_count, tmp_path): input = FakePromptInput( [ {PROMPT_RESPONSE_SCHEMA.prompt_uid: "1", PROMPT_RESPONSE_SCHEMA.prompt_text: "a"}, {PROMPT_RESPONSE_SCHEMA.prompt_uid: "2", PROMPT_RESPONSE_SCHEMA.prompt_text: "b"}, ] ) - output = FakeAnnotatorOutput() + output = FakeAnnotatorOutput(tmp_path / "output.csv") suts = {"sut1": FakeSUT("sut1"), "sut2": FakeSUT("sut2")} p = Pipeline( @@ -404,26 +371,34 @@ def test_prompt_response_annotation_pipeline(annotators, sut_worker_count, annot PromptSutWorkers(suts, workers=sut_worker_count), AnnotatorAssigner(annotators), AnnotatorWorkers(annotators, workers=annotator_worker_count), - AnnotatorSink(annotators, output), + AnnotatorSink(output), ) p.run() - assert len(output.output) == len(input.items) * len(suts) - interactions = sorted(list(output.output.keys()), key=lambda o: (o.prompt.source_id, o.sut_uid)) - for interaction, prompt_sut in zip(interactions, itertools.product(input.items, suts)): - prompt, sut = prompt_sut - assert sut_interactions_is_equal( - interaction, - make_sut_interaction( - prompt[PROMPT_RESPONSE_SCHEMA.prompt_uid], - prompt[PROMPT_RESPONSE_SCHEMA.prompt_text], - sut, - prompt[PROMPT_RESPONSE_SCHEMA.prompt_text], - ), - ) - annotation = {"sut_text": prompt[PROMPT_RESPONSE_SCHEMA.prompt_text]} - assert output.output[interaction] == { - "annotator_pydantic": annotation, - "annotator_dict": annotation, - "dummy": "d", - } + assert len(output.output) == len(input.items) * len(suts) * 
len(annotators) + + rows = sorted(output.output, key=lambda row: (row[0], row[2], row[4])) # Sort by prompt_uid, sut_uid, annotator_uid + + # Group rows by prompt and sut + current_idx = 0 + for prompt in input.items: + for sut in suts: + # For each prompt-sut combination, we should have one row per annotator + for annotator_name in ["annotator_dict", "annotator_pydantic", "dummy"]: + row = rows[current_idx] + # Check prompt fields + assert row[0] == prompt[PROMPT_RESPONSE_SCHEMA.prompt_uid] # prompt_uid + assert row[1] == prompt[PROMPT_RESPONSE_SCHEMA.prompt_text] # prompt_text + # Check SUT fields + assert row[2] == sut # sut_uid + assert row[3] == prompt[PROMPT_RESPONSE_SCHEMA.prompt_text] # sut_response (FakeSUT echoes prompt) + # Check annotator fields + assert row[4] == annotator_name # annotator_uid + # Check annotation content + if annotator_name == "dummy": + assert row[5] == '"d"' # dummy annotator returns "d" + else: + # Both dict and pydantic annotators return the same structure + expected_annotation = {"sut_text": prompt[PROMPT_RESPONSE_SCHEMA.prompt_text]} + assert row[5] == json.dumps(expected_annotation) + current_idx += 1 diff --git a/tests/modelgauge_tests/test_cli.py b/tests/modelgauge_tests/test_cli.py index 3e1b560bd..d1834795a 100644 --- a/tests/modelgauge_tests/test_cli.py +++ b/tests/modelgauge_tests/test_cli.py @@ -1,5 +1,3 @@ -import csv -import json import logging import re import sys @@ -8,9 +6,6 @@ from unittest.mock import patch import click - -import jsonlines - import pytest from click.testing import CliRunner, Result @@ -178,27 +173,7 @@ def test_run_prompts_normal(caplog, tmp_path, prompts_file): assert result.exit_code == 0 out_path = re.findall(r"\S+\.csv", caplog.text)[0] - with open(out_path, "r") as f: - reader = csv.DictReader(f) - - rows = (next(reader), next(reader)) - rows = sorted(rows, key=lambda row: row[PROMPT_RESPONSE_SCHEMA.prompt_uid]) - expected = ( - { - PROMPT_RESPONSE_SCHEMA.prompt_uid: "p1", - 
PROMPT_RESPONSE_SCHEMA.prompt_text: "Say yes", - PROMPT_RESPONSE_SCHEMA.sut_uid: "demo_yes_no", - PROMPT_RESPONSE_SCHEMA.sut_response: "Yes", - }, - { - PROMPT_RESPONSE_SCHEMA.prompt_uid: "p2", - PROMPT_RESPONSE_SCHEMA.prompt_text: "Refuse", - PROMPT_RESPONSE_SCHEMA.sut_uid: "demo_yes_no", - PROMPT_RESPONSE_SCHEMA.sut_response: "No", - }, - ) - assert rows[0] == expected[0] - assert rows[1] == expected[1] + assert Path(out_path).exists() @pytest.mark.parametrize("arg_name", ["--sut", "-s"]) @@ -255,25 +230,8 @@ def test_run_prompts_with_annotators(caplog, tmp_path, prompts_file): ) assert result.exit_code == 0 - out_path = re.findall(r"\S+\.jsonl", caplog.text)[0] - output = [] - with jsonlines.open(out_path) as reader: - output.append(reader.read()) - output.append(reader.read()) - assert { - PROMPT_RESPONSE_SCHEMA.prompt_uid: "p1", - PROMPT_RESPONSE_SCHEMA.prompt_text: "Say yes", - PROMPT_RESPONSE_SCHEMA.sut_uid: "demo_yes_no", - PROMPT_RESPONSE_SCHEMA.sut_response: "Yes", - "Annotations": {"demo_annotator": {"badness": 1.0}}, - } in output - assert { - PROMPT_RESPONSE_SCHEMA.prompt_uid: "p2", - PROMPT_RESPONSE_SCHEMA.prompt_text: "Refuse", - PROMPT_RESPONSE_SCHEMA.sut_uid: "demo_yes_no", - PROMPT_RESPONSE_SCHEMA.sut_response: "No", - "Annotations": {"demo_annotator": {"badness": 0.0}}, - } in output + out_path = re.findall(r"\S+\.csv", caplog.text)[0] + assert Path(out_path).exists() @patch("modelgauge.suts.demo_01_yes_no_sut.DemoYesNoSUT.translate_text_prompt") @@ -341,22 +299,8 @@ def test_run_annotators(caplog, tmp_path, prompt_responses_file): ) assert result.exit_code == 0 - out_path = re.findall(r"\S+\.jsonl", caplog.text)[0] - with jsonlines.open(out_path) as reader: - assert reader.read() == { - PROMPT_RESPONSE_SCHEMA.prompt_uid: "p1", - PROMPT_RESPONSE_SCHEMA.prompt_text: "Say yes", - PROMPT_RESPONSE_SCHEMA.sut_uid: "demo_yes_no", - PROMPT_RESPONSE_SCHEMA.sut_response: "Yes", - "Annotations": {"demo_annotator": {"badness": 1.0}}, - } - assert 
reader.read() == { - PROMPT_RESPONSE_SCHEMA.prompt_uid: "p2", - PROMPT_RESPONSE_SCHEMA.prompt_text: "Refuse", - PROMPT_RESPONSE_SCHEMA.sut_uid: "demo_yes_no", - PROMPT_RESPONSE_SCHEMA.sut_response: "No", - "Annotations": {"demo_annotator": {"badness": 0.0}}, - } + out_path = re.findall(r"\S+\.csv", caplog.text)[0] + assert Path(out_path).exists() @pytest.mark.parametrize( @@ -460,10 +404,10 @@ def test_run_job_sut_and_annotator_output_name(caplog, tmp_path, prompts_file): assert result.exit_code == 0 - out_path = Path(re.findall(r"\S+\.jsonl", caplog.text)[0]) + out_path = Path(re.findall(r"\S+\.csv", caplog.text)[0]) assert out_path.exists() - assert out_path.name == "prompt-responses-annotated.jsonl" # File name + assert out_path.name == "prompt-responses-annotated.csv" # File name assert re.match(r"\d{8}-\d{6}-demo_yes_no-demo_annotator", out_path.parent.name) # Subdir name assert out_path.parent.parent == tmp_path # Parent dir @@ -482,10 +426,10 @@ def test_run_job_annotators_only_output_name(caplog, tmp_path, prompt_responses_ assert result.exit_code == 0 - out_path = Path(re.findall(r"\S+\.jsonl", caplog.text)[0]) + out_path = Path(re.findall(r"\S+\.csv", caplog.text)[0]) assert out_path.exists() - assert out_path.name == "annotations.jsonl" # File name + assert out_path.name == "annotations.csv" # File name assert re.match(r"\d{8}-\d{6}-demo_annotator", out_path.parent.name) # Subdir name assert out_path.parent.parent == tmp_path # Parent dir @@ -523,10 +467,10 @@ def evaluate(self, item): assert result.exit_code == 0 - out_path = Path(re.findall(r"\S+\.jsonl", caplog.text)[0]) + out_path = Path(re.findall(r"\S+\.csv", caplog.text)[0]) assert out_path.exists() - assert out_path.name == "annotations.jsonl" # File name + assert out_path.name == "annotations.csv" # File name assert re.match(r"\d{8}-\d{6}-ensemble", out_path.parent.name) # Subdir name assert out_path.parent.parent == tmp_path # Parent dir diff --git 
a/tests/modelgauge_tests/test_pipeline_runner.py b/tests/modelgauge_tests/test_pipeline_runner.py index 474364a8e..b3d42aa6c 100644 --- a/tests/modelgauge_tests/test_pipeline_runner.py +++ b/tests/modelgauge_tests/test_pipeline_runner.py @@ -265,8 +265,6 @@ def test_pipeline_segments(self, tmp_path, prompts_file, suts, annotators): assert annotator_workers.thread_count == 20 assert isinstance(sink, AnnotatorSink) - assert sink.annotators == annotators - assert sink.ensemble == False def test_pipeline_segments_ensemble(self, runner_ensemble, annotators, ensemble): source, sut_assigner, sut_workers, annotator_assigner, annotator_workers, ensemble_worker, sink = ( @@ -279,8 +277,6 @@ def test_pipeline_segments_ensemble(self, runner_ensemble, annotators, ensemble) assert ensemble_worker.ensemble == ensemble assert isinstance(sink, AnnotatorSink) - assert sink.annotators == annotators - assert sink.ensemble == True def test_runner_num_input_items(self, runner_basic): assert runner_basic.num_input_items == NUM_PROMPTS @@ -419,8 +415,6 @@ def test_pipeline_segments(self, tmp_path, prompt_responses_file, annotators): assert annotator_workers.thread_count == 20 assert isinstance(sink, AnnotatorSink) - assert sink.annotators == annotators - assert sink.ensemble == False def test_pipeline_segments_ensemble(self, runner_ensemble, annotators, ensemble): source, annotator_assigner, annotator_workers, ensemble_worker, sink = runner_ensemble.pipeline_segments @@ -431,8 +425,6 @@ def test_pipeline_segments_ensemble(self, runner_ensemble, annotators, ensemble) assert ensemble_worker.ensemble == ensemble assert isinstance(sink, AnnotatorSink) - assert sink.annotators == annotators - assert sink.ensemble is True def test_missing_ensemble_annotators_raises_error(self, tmp_path, prompt_responses_file, ensemble): incomplete_annotators = {"annotator1": FakeAnnotator("annotator1"), "annotator2": FakeAnnotator("annotator2")} From d8e9a8db9781de28ccd9f472c80e139bfd0d6c84 Mon Sep 17 00:00:00 
2001 From: Barbara Korycki Date: Wed, 2 Jul 2025 13:54:17 -0700 Subject: [PATCH 14/15] mypy --- src/modelgauge/dataset.py | 42 ++++++++++++++------------ tests/modelgauge_tests/test_dataset.py | 4 +-- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/src/modelgauge/dataset.py b/src/modelgauge/dataset.py index 86072d7c1..586716e00 100644 --- a/src/modelgauge/dataset.py +++ b/src/modelgauge/dataset.py @@ -43,8 +43,7 @@ def __init__(self, path: Union[str, Path], mode: str): self.file = None self.writer = None self.reader = None - self.schema = None - self._init_schema() # Initialized by subclass. + self.schema = self._get_schema() # Implemented by subclass. def __enter__(self): """Context manager entry. Opens the file and sets the reader or writer.""" @@ -102,8 +101,8 @@ def __len__(self) -> int: return count @abstractmethod - def _init_schema(self): - """Initialize dataset schema `self.schema`. To be implemented by subclasses.""" + def _get_schema(self): + """Return dataset schema. To be implemented by subclasses.""" pass def _read_header(self) -> list[str]: @@ -112,12 +111,13 @@ def _read_header(self) -> list[str]: raise RuntimeError("Can only read header in read mode.") if self.file is None: with self: - header = self.reader.fieldnames + header = self.reader.fieldnames # type: ignore else: - header = self.reader.fieldnames + header = self.reader.fieldnames # type: ignore return header def header_columns(self) -> Sequence[str]: + assert self.schema is not None, "Sub-classes must initialized schema." 
return self.schema.header def write(self, item: Any): @@ -143,8 +143,8 @@ class PromptDataset(BaseDataset): def __init__(self, path: Union[str, Path]): super().__init__(path, "r") - def _init_schema(self): - self.schema = PromptSchema(self._read_header()) + def _get_schema(self): + return PromptSchema(self._read_header()) def row_to_item(self, row: dict) -> TestItem: """Convert a single prompt row to a TestItem.""" @@ -158,11 +158,11 @@ def row_to_item(self, row: dict) -> TestItem: class PromptResponseDataset(BaseDataset): """Dataset for prompt-response CSV data. Read or write.""" - def _init_schema(self): + def _get_schema(self): if self.mode == "r": - self.schema = PromptResponseSchema(self._read_header()) + return PromptResponseSchema(self._read_header()) else: - self.schema = DEFAULT_PROMPT_RESPONSE_SCHEMA + return DEFAULT_PROMPT_RESPONSE_SCHEMA def row_to_item(self, row: dict) -> SUTInteraction: prompt = TestItem( @@ -176,6 +176,7 @@ def row_to_item(self, row: dict) -> SUTInteraction: def item_to_row(self, item: SUTInteraction) -> list[str]: if not isinstance(item.prompt.prompt, TextPrompt): raise ValueError(f"Error handling {item}. Can only handle TextPrompts.") + assert item.prompt.source_id is not None, "Prompt source_id is required." return [ item.prompt.source_id, @@ -188,11 +189,11 @@ def item_to_row(self, item: SUTInteraction) -> list[str]: class AnnotationDataset(BaseDataset): """Dataset for annotated prompt-response CSV data. 
Read or write.""" - def _init_schema(self): + def _get_schema(self): if self.mode == "r": - self.schema = AnnotationSchema(self._read_header()) + return AnnotationSchema(self._read_header()) else: - self.schema = DEFAULT_ANNOTATION_SCHEMA + return DEFAULT_ANNOTATION_SCHEMA def row_to_item(self, row: dict) -> AnnotatedSUTInteraction: prompt = TestItem( @@ -202,9 +203,9 @@ def row_to_item(self, row: dict) -> AnnotatedSUTInteraction: ) response = SUTResponse(text=row[self.schema.sut_response]) interaction = SUTInteraction(prompt, row[self.schema.sut_uid], response) - print(row.get(self.schema.annotation)) - print(type(row.get(self.schema.annotation))) - annotation = json.loads(row.get(self.schema.annotation)) + print(row[self.schema.annotation]) + print(type(row[self.schema.annotation])) + annotation = json.loads(row[self.schema.annotation]) return AnnotatedSUTInteraction( sut_interaction=interaction, annotator_uid=row[self.schema.annotator_uid], annotation=annotation ) @@ -213,15 +214,16 @@ def item_to_row(self, item: AnnotatedSUTInteraction) -> list[str]: """Write an AnnotatedSUTInteraction to a csv row. The last column is the annotation, which is a json string that can be deserialized with json.loads.""" if not isinstance(item.sut_interaction.prompt.prompt, TextPrompt): raise ValueError(f"Error handling {item}. Can only handle TextPrompts.") + assert item.sut_interaction.prompt.source_id is not None, "Prompt source_id is required." 
annotation = item.annotation if isinstance(annotation, BaseModel): - annotation = annotation.model_dump() - annotation = json.dumps(annotation) + annotation = annotation.model_dump() # type: ignore + annotation_json_str = json.dumps(annotation) return [ item.sut_interaction.prompt.source_id, item.sut_interaction.prompt.prompt.text, item.sut_interaction.sut_uid, item.sut_interaction.response.text, item.annotator_uid, - annotation, + annotation_json_str, ] diff --git a/tests/modelgauge_tests/test_dataset.py b/tests/modelgauge_tests/test_dataset.py index bf56cc5fc..4d8944152 100644 --- a/tests/modelgauge_tests/test_dataset.py +++ b/tests/modelgauge_tests/test_dataset.py @@ -37,8 +37,8 @@ def __init__(self, path: Path, mode: str): self.write_called = False self.write_item = None - def _init_schema(self): - self.schema = TestBaseDataset.DummySchema() + def _get_schema(self): + return TestBaseDataset.DummySchema() def row_to_item(self, row: dict) -> str: """Convert a row to a dummy item.""" From 9194c8be6be8bfc7727b7441eede036c0d0b76fe Mon Sep 17 00:00:00 2001 From: Barbara Korycki Date: Thu, 3 Jul 2025 10:56:09 -0700 Subject: [PATCH 15/15] remove prints --- src/modelgauge/dataset.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/modelgauge/dataset.py b/src/modelgauge/dataset.py index 586716e00..5c394f2e6 100644 --- a/src/modelgauge/dataset.py +++ b/src/modelgauge/dataset.py @@ -203,8 +203,6 @@ def row_to_item(self, row: dict) -> AnnotatedSUTInteraction: ) response = SUTResponse(text=row[self.schema.sut_response]) interaction = SUTInteraction(prompt, row[self.schema.sut_uid], response) - print(row[self.schema.annotation]) - print(type(row[self.schema.annotation])) annotation = json.loads(row[self.schema.annotation]) return AnnotatedSUTInteraction( sut_interaction=interaction, annotator_uid=row[self.schema.annotator_uid], annotation=annotation