diff --git a/asmtransformers/scripts/tokenize_dataset.py b/asmtransformers/scripts/tokenize_dataset.py
index 53d652e..b07ff5b 100644
--- a/asmtransformers/scripts/tokenize_dataset.py
+++ b/asmtransformers/scripts/tokenize_dataset.py
@@ -1,4 +1,4 @@
-import sys
+import argparse
 
 import datasets
 
@@ -17,19 +17,53 @@ def do_tokenize(function):
     return dataset.map(do_tokenize, **map_kwargs)
 
 
-if __name__ == '__main__':
-    # expect 3 arguments from cli
-    tokenizer, data_in, data_out = sys.argv[1:]
-
+def main(tokenizer, input_data, output_folder, split=None):
+    """Tokenize a dataset stored on disk and save the result.
+
+    :param tokenizer: folder to load the ASMTokenizer from
+    :param input_data: folder to load the dataset from
+    :param output_folder: folder to save the tokenized dataset to
+    :param split: optional test_size passed to train_test_split before tokenizing
+    :raises ValueError: if split is given but the loaded data is not a single Dataset
+    """
     tokenizer = ASMTokenizer.from_pretrained(tokenizer)
-    dataset = datasets.load_from_disk(data_in)
+    print('loading dataset')
+    dataset = datasets.load_from_disk(input_data)
+    if split:
+        if not isinstance(dataset, datasets.Dataset):
+            # only a single Dataset offers train_test_split; fail with a clear message
+            raise ValueError('--split requires a single Dataset, but the input data is a DatasetDict')
+        dataset = dataset.train_test_split(test_size=split)
-    # let the tokenizer preprocess data from data_in, write the result to data_out
-
+    # let the tokenizer preprocess data from input_data, write the result to output_folder
+    print('tokenizing dataset')
     if isinstance(dataset, datasets.Dataset):
-        # datasets.load_from_disk either a Dataset or DatasetDict type
+        # datasets.load_from_disk returns either a Dataset or DatasetDict type
         dataset = tokenize(tokenizer, dataset)
     else:
         for subset in dataset:
             dataset[subset] = tokenize(tokenizer, dataset[subset])
 
-    dataset.save_to_disk(data_out)
+    dataset.save_to_disk(output_folder)
+
+
+def get_parser():
+    """Build the command line argument parser for this script."""
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('--tokenizer', type=str, required=True, help='folder with tokenizer')
+    parser.add_argument('--input-data', type=str, required=True, help='data to be used for training')
+    parser.add_argument('--output-folder', type=str, required=True, help='folder to leave the tokenized data')
+    parser.add_argument(
+        '--split',
+        type=float,
+        required=False,
+        help='split between train and test; define percentage of test data as number between 0 and 1',
+    )
+
+    return parser
+
+
+if __name__ == '__main__':
+    parser = get_parser()
+    args = parser.parse_args()
+    main(**vars(args))