88 AnnotatorWorkers ,
99)
1010from modelgauge .annotator_set import AnnotatorSet
11- from modelgauge .dataset import PromptDataset , PromptResponseDataset
11+ from modelgauge .dataset import AnnotationDataset , PromptDataset , PromptResponseDataset
1212from modelgauge .data_schema import (
1313 DEFAULT_PROMPT_RESPONSE_SCHEMA as PROMPT_RESPONSE_SCHEMA ,
1414 DEFAULT_PROMPT_SCHEMA as PROMPT_SCHEMA ,
@@ -54,7 +54,7 @@ def prompts_dataset(prompts_file):
5454
5555@pytest .fixture (scope = "session" )
5656def prompt_responses_file (tmp_path_factory ):
57- """Sample file with 2 prompts + responses from 2 SUTs for testing."""
57+ """Sample file with 3 prompts + responses from 2 SUTs for testing."""
5858 file = tmp_path_factory .mktemp ("data" ) / "prompt-responses.csv"
5959 with open (file , "w" ) as f :
6060 text = f"{ PROMPT_RESPONSE_SCHEMA .prompt_uid } ,{ PROMPT_RESPONSE_SCHEMA .prompt_text } ,{ PROMPT_RESPONSE_SCHEMA .sut_uid } ,{ PROMPT_RESPONSE_SCHEMA .sut_response } \n "
@@ -65,6 +65,21 @@ def prompt_responses_file(tmp_path_factory):
6565 return file
6666
6767
@pytest.fixture(scope="session")
def prompt_responses_file_with_duplicates(tmp_path_factory):
    """Write a prompt-response CSV that contains one duplicated prompt/response.

    The file starts with a header row built from PROMPT_RESPONSE_SCHEMA. For
    each of the NUM_PROMPTS prompts it then holds one row for ``sut1`` and one
    row for ``sut2``. A final row repeats the prompt text and response of
    prompt 0 / ``sut1`` under a different prompt uid (``q0``).

    Returns the path of the written file.
    """
    file = tmp_path_factory.mktemp("data") / "prompt-responses.csv"
    rows = [
        f"{PROMPT_RESPONSE_SCHEMA.prompt_uid},{PROMPT_RESPONSE_SCHEMA.prompt_text},"
        f"{PROMPT_RESPONSE_SCHEMA.sut_uid},{PROMPT_RESPONSE_SCHEMA.sut_response}"
    ]
    for i in range(NUM_PROMPTS):
        rows.append(f"p{i},Prompt {i},sut1,Response {i}")
        rows.append(f"p{i},Prompt {i},sut2,Response {i}")
    # Same prompt text, SUT and response as the first p0 row, but a new prompt uid.
    rows.append("q0,Prompt 0,sut1,Response 0")
    with open(file, "w", encoding="utf-8") as f:
        # Every row, including the last one, is newline-terminated.
        f.write("\n".join(rows) + "\n")
    return file
81+
82+
6883@pytest .fixture
6984def annotators ():
7085 return {
@@ -493,6 +508,23 @@ def test_metadata_ensemble(self, runner_ensemble):
493508 assert metadata ["ensemble" ]["annotators" ] == ["annotator1" , "annotator2" , "annotator3" ]
494509 assert metadata ["ensemble" ]["num_votes" ] == NUM_PROMPTS * self .NUM_SUTS
495510
511+ def test_cache_responses (self , prompt_responses_file_with_duplicates , annotators , tmp_path ):
512+ runner = AnnotatorRunner (
513+ annotators = annotators ,
514+ num_workers = 1 ,
515+ input_dataset = PromptResponseDataset (prompt_responses_file_with_duplicates , mode = "r" ),
516+ output_dir = tmp_path ,
517+ cache_dir = tmp_path / "cache" ,
518+ )
519+ runner .run (progress_callback = lambda _ : _ , debug = False )
520+ prompt_ids = set ()
521+ with PromptResponseDataset (prompt_responses_file_with_duplicates , "r" ) as prompts :
522+ prompt_ids .update (item .prompt .source_id for item in prompts )
523+ annotated_prompt_ids = set ()
524+ with AnnotationDataset (runner .output_dir () / runner .output_file_name , "r" ) as annotations :
525+ annotated_prompt_ids .update (item .sut_interaction .prompt .source_id for item in annotations )
526+ assert prompt_ids == annotated_prompt_ids
527+
496528
497529class TestBuildRunner :
498530 def test_build_prompt_runner (self , prompts_file , suts , tmp_path ):
0 commit comments