@@ -279,6 +279,21 @@ def load_openorca_dataset_pkl(
279279 return [(prompt , output ) for prompt , output in zip (prompts , outputs )]
280280
281281
282+ def load_longcontext_dataset_pkl (
283+ dataset_path : str ,
284+ ) -> list [tuple [Any , Any ]]:
285+ assert os .path .isfile (dataset_path )
286+
287+ # read pickle file
288+ data = pandas .read_pickle (dataset_path )
289+
290+ samples = []
291+ for _ , row in data .iterrows ():
292+ samples .append ((row ["input" ], row ["ref_output" ]))
293+
294+ return samples
295+
296+
282297def load_mmlu_dataset_csv (dataset_path : str ) -> tuple [Any , dict [str , str ]]:
283298 assert dataset_path != ""
284299 dataset = []
@@ -837,7 +852,14 @@ def parse_args() -> argparse.Namespace:
837852 "--dataset" ,
838853 type = str ,
839854 default = "test" ,
840- choices = ["test" , "sharegpt" , "openorca" , "mmlu" , "math500" ],
855+ choices = [
856+ "test" ,
857+ "sharegpt" ,
858+ "openorca" ,
859+ "mmlu" ,
860+ "math500" ,
861+ "longcontext" ,
862+ ],
841863 help = "The dataset name." ,
842864 )
843865 parser .add_argument ("--dataset-path" , type = str , help = "Path to the dataset." )
@@ -1057,6 +1079,10 @@ def main(args: argparse.Namespace):
10571079 dataset = load_math500_dataset (
10581080 args .dataset_path ,
10591081 )
1082+ elif args .dataset == "longcontext" :
1083+ dataset = load_longcontext_dataset_pkl (
1084+ args .dataset_path ,
1085+ )
10601086 else :
10611087 raise ValueError (
10621088 f"Fatal Error: Uncognized input parameters: { args .dataset } "
0 commit comments