Skip to content

Commit 6dbba50

Browse files
authored
Merge pull request #76 from NetherlandsForensicInstitute/feature/split-when-tokenize
traintestsplit + added argparse
2 parents ae6d3bb + 511e055 commit 6dbba50

1 file changed

Lines changed: 30 additions & 8 deletions

File tree

asmtransformers/scripts/tokenize_dataset.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import sys
1+
import argparse
22

33
import datasets
44

@@ -17,19 +17,41 @@ def do_tokenize(function):
1717
return dataset.map(do_tokenize, **map_kwargs)
1818

1919

20-
if __name__ == '__main__':
21-
# expect 3 arguments from cli
22-
tokenizer, data_in, data_out = sys.argv[1:]
23-
20+
def main(tokenizer, input_data, output_folder, split):
2421
tokenizer = ASMTokenizer.from_pretrained(tokenizer)
25-
dataset = datasets.load_from_disk(data_in)
22+
print('loading dataset')
23+
dataset = datasets.load_from_disk(input_data)
24+
if split:
25+
dataset = dataset.train_test_split(test_size=split)
2626

2727
# let the tokenizer preprocess data from data_in, write the result to data_out
28-
28+
print('tokenizing dataset')
2929
if isinstance(dataset, datasets.Dataset): # datasets.load_from_disk either a Dataset or DatasetDict type
3030
dataset = tokenize(tokenizer, dataset)
3131
else:
3232
for subset in dataset:
3333
dataset[subset] = tokenize(tokenizer, dataset[subset])
3434

35-
dataset.save_to_disk(data_out)
35+
dataset.save_to_disk(output_folder)
36+
37+
38+
def get_parser():
39+
parser = argparse.ArgumentParser()
40+
41+
parser.add_argument('--tokenizer', type=str, required=True, help='folder with tokenizer')
42+
parser.add_argument('--input-data', type=str, required=True, help='data to be used for training')
43+
parser.add_argument('--output-folder', type=str, required=True, help='folder to leave the tokenized data')
44+
parser.add_argument(
45+
'--split',
46+
type=float,
47+
required=False,
48+
help='split between train and test; define percentage of test data as number between 0 and 1',
49+
)
50+
51+
return parser
52+
53+
54+
if __name__ == '__main__':
55+
parser = get_parser()
56+
args = parser.parse_args()
57+
main(**vars(args))

0 commit comments

Comments
 (0)