@@ -71,6 +71,7 @@
 
 from benchmarks.eval_accuracy import eval_accuracy
 from benchmarks.eval_accuracy_mmlu import eval_accuracy_mmlu
+from benchmarks.eval_accuracy_longcontext import eval_accuracy_longcontext
 from benchmarks.metrics import CounterMetric, EventMetric
 import grpc
 from jetstream.core.proto import jetstream_pb2
@@ -166,6 +167,7 @@ class InputRequest:
   output: str = ""
   output_len: int = 0
   sample_idx: int = -1
+  metric: str = ""
 
 
 @dataclass
@@ -187,10 +189,12 @@ def to_dict(self):
       prompt = self.input_request.prompt
       original_output = self.input_request.output
       sample_idx = self.input_request.sample_idx
+      metric = self.input_request.metric
     else:
       prompt = None
       original_output = None
       sample_idx = None
+      metric = None
     return {
         "prompt": prompt,
         "original_output": original_output,
@@ -201,6 +205,7 @@ def to_dict(self):
         "ttst_sec": self.ttst_sec,
         "prompt_len": self.prompt_len,
         "sample_idx": sample_idx,
+        "metric": metric,
     }
 
 
@@ -282,17 +287,19 @@ def load_openorca_dataset_pkl(
 
 def load_longcontext_dataset_pkl(
     dataset_path: str,
-) -> list[tuple[Any, Any]]:
+) -> tuple[list[tuple[Any, Any]], list]:
   assert os.path.isfile(dataset_path)
 
   # read pickle file
   data = pandas.read_pickle(dataset_path)
 
   samples = []
+  metrics = []
   for _, row in data.iterrows():
-    samples.append((row["input"], row["ref_output"]))
+    samples.append((row["input"], row["gt_output"]))
+    metrics.append(row["metric"])
 
-  return samples
+  return samples, metrics
 
 
 def load_mmlu_dataset_csv(dataset_path: str) -> tuple[Any, dict[str, str]]:
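Note on the loader change above: `load_longcontext_dataset_pkl` now reads the reference from a `gt_output` column (previously `ref_output`) and returns a parallel `metrics` list naming the accuracy metric for each row. A minimal sketch of a compatible pickle, assuming these three columns are all the loader touches (the values are made up):

```python
import pandas as pd

# Hypothetical rows; a real dataset supplies its own prompts,
# references, and metric names.
df = pd.DataFrame(
    {
        "input": ["<long document> ... final question?"],
        "gt_output": ["expected answer"],
        "metric": ["rouge"],  # assumed metric name, for illustration
    }
)
df.to_pickle("longcontext.pkl")

samples, metrics = load_longcontext_dataset_pkl("longcontext.pkl")
```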
@@ -421,7 +428,6 @@ def filter_dataset(
     tokenized_dataset: list[tuple[str, Any, str, int, int, int]],
     dataset_type: str,
     max_output_length: int = 0,
-    run_mmlu_dataset: bool = False,
     min_input_length: int = 4,
     max_input_length: int = 0,
     max_target_length: int = 0,
@@ -443,7 +449,8 @@
       sample_idx,
   ) in tokenized_dataset:
     if prompt_len < min_input_length or (
-        not (run_mmlu_dataset or dataset_type == "math500") and output_len < 4
+        not (dataset_type == "mmlu" or dataset_type == "math500")
+        and output_len < 4
     ):
       # Prune too short sequences.
       # This is because TGI causes errors when the input or output length
@@ -479,11 +486,11 @@ def sample_requests(
     dataset_type: str,
     max_output_length: int = 0,
     oversample_multiplier: float = 1.2,
-    run_mmlu_dataset: bool = False,
     min_input_length: int = 4,
     max_input_length: int = 0,
     max_target_length: int = 0,
     max_output_multiplier: int = 0,
+    metrics: Optional[list[str]] = None,
 ) -> list[InputRequest]:
 
   # Original dataset size
@@ -521,13 +528,16 @@ def sample_requests(
       tokenized_dataset,
       dataset_type,
       max_output_length,
-      run_mmlu_dataset,
      min_input_length,
       max_input_length,
       max_target_length,
       max_output_multiplier,
   )
 
+  if metrics is not None:
+    for request in input_requests:
+      request.metric = metrics[request.sample_idx]
+
   # Sample the requests.
   if len(input_requests) > num_requests:
     input_requests = random.sample(input_requests, num_requests)
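Note: the new `metrics` hookup indexes by `request.sample_idx`, so it assumes `sample_idx` records each request's row in the original dataset and survives both `filter_dataset` and the later `random.sample`. A self-contained illustration of that invariant, using a stand-in for `InputRequest`:

```python
from dataclasses import dataclass


@dataclass
class Req:  # stand-in for InputRequest; only the fields used here
    sample_idx: int = -1
    metric: str = ""


metrics = ["rouge", "em", "f1"]  # one entry per original dataset row
# Requests left after filtering, in arbitrary order:
surviving = [Req(sample_idx=2), Req(sample_idx=0)]
for request in surviving:
    request.metric = metrics[request.sample_idx]
assert [r.metric for r in surviving] == ["f1", "rouge"]
```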
@@ -1068,11 +1078,6 @@ def parse_args() -> argparse.Namespace:
       choices=["HELM", "Harness", ""],
       help="mmlu method/format to generate shots",
   )
-  parser.add_argument(
-      "--run-mmlu-dataset",
-      action="store_true",
-      help="specify if it's for mmlu dataset",
-  )
   return parser.parse_args()
 
 
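Note: with `--run-mmlu-dataset` removed, the dataset type now flows solely through `--dataset`. An invocation sketch; the flag spellings are inferred from the `args.*` attributes used in `main` (an assumption), and the path and count are placeholders:

```
python benchmark_serving.py \
  --dataset longcontext \
  --dataset-path /data/longcontext.pkl \
  --num-prompts 100 \
  --run-eval
```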
@@ -1094,6 +1099,7 @@ def main(args: argparse.Namespace):
   tokenizer = get_tokenizer(
       model_id, tokenizer_id, use_hf_tokenizer, hf_access_token
   )
+  metrics = None
   if tokenizer == "test" or args.dataset == "test":
     input_requests = mock_requests(
         args.total_mock_requests
@@ -1116,7 +1122,7 @@ def main(args: argparse.Namespace):
         args.dataset_path,
     )
   elif args.dataset == "longcontext":
-    dataset = load_longcontext_dataset_pkl(
+    dataset, metrics = load_longcontext_dataset_pkl(
         args.dataset_path,
     )
   else:
@@ -1134,11 +1140,11 @@ def main(args: argparse.Namespace):
       num_requests=args.num_prompts,
       dataset_type=args.dataset,
       max_output_length=args.max_output_length,
-      run_mmlu_dataset=args.run_mmlu_dataset,
       min_input_length=args.min_input_length,
       max_input_length=args.max_input_length,
       max_target_length=args.max_target_length,
       max_output_multiplier=args.max_output_multiplier,
+      metrics=metrics,
   )
 
   warmup_requests = None
@@ -1184,8 +1190,10 @@ def main(args: argparse.Namespace):
   # Process output
   output = [output.to_dict() for output in request_outputs]
   if args.run_eval:
-    if args.run_mmlu_dataset:
+    if args.dataset == "mmlu":
       eval_json = eval_accuracy_mmlu(output)
+    elif args.dataset == "longcontext":
+      eval_json = eval_accuracy_longcontext(output)
     else:
       eval_json = eval_accuracy(output, args.dataset[:4])
 
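Note: evaluation dispatch is now keyed off `args.dataset` rather than the removed flag, and `eval_accuracy_longcontext` receives the same per-request dicts produced by `to_dict`, each now carrying its `metric` name. A hypothetical sketch of metric-driven scoring in that spirit, not the actual implementation in `benchmarks/eval_accuracy_longcontext.py`; the `generated_text` key and the `"em"` metric name are assumptions (only `prompt`, `original_output`, and `metric` are confirmed by this diff):

```python
def eval_accuracy_longcontext_sketch(rows: list[dict]) -> dict:
    """Score each row with the metric named in its "metric" field."""

    def exact_match(pred: str, ref: str) -> float:
        return float(pred.strip() == ref.strip())

    scorers = {"em": exact_match}  # assumed metric-name -> scorer mapping
    scores = [
        scorers[row["metric"]](row["generated_text"], row["original_output"])
        for row in rows
        if row["metric"] in scorers
    ]
    return {"accuracy": sum(scores) / len(scores) if scores else 0.0}
```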