# train_tokenizer.py
import json
import copy
import argparse
from datasets import load_dataset
from transformers import AutoTokenizer
from tokenizers.models import BPE
def normalize_merge(merge):
    """Return a BPE merge rule as a ``(left, right)`` tuple.

    Serialized tokenizers store merges either as two-item sequences
    (``["a", "b"]``) or as a single space-separated string (``"a b"``).
    Both forms are accepted here so the rest of the script can work with
    one hashable representation.

    Args:
        merge: A list, tuple or string describing one merge rule.

    Returns:
        A 2-tuple ``(left, right)``, or ``None`` when ``merge`` is of an
        unsupported type or does not contain exactly two parts.
    """
    if isinstance(merge, str):
        parts = tuple(merge.split())
    elif isinstance(merge, (list, tuple)):
        parts = tuple(merge)
    else:
        return None
    return parts if len(parts) == 2 else None


def combine_bpe_models(
    vocab, merges, aux_merges, max_new_token=32000, max_token_len=15
):
    """Extend a BPE vocabulary and merge list with rules from a second model.

    The auxiliary merges are walked in order. While the budget allows, the
    two parts of each merge and their concatenation are added to the
    vocabulary if missing. A merge is kept when it already exists in the
    source merges, or when all three of its tokens are in the combined
    vocabulary. Kept auxiliary merges come first in the result, followed by
    the source merges that were not reused, in their original order.

    Args:
        vocab: Mapping of token string to integer id for the source model.
            It is not modified.
        merges: Merge rules of the source model (see ``normalize_merge``
            for the accepted forms). Malformed entries are skipped.
        aux_merges: Merge rules of the auxiliary model, same forms.
            Malformed or repeated entries are skipped.
        max_new_token: Budget of new vocabulary entries. It is checked once
            per merge, and one merge can add up to three tokens, so the
            final count may exceed the budget by at most two.
        max_token_len: Merges whose left or right part is longer than this
            many characters are ignored entirely.

    Returns:
        A tuple ``(combined_vocab, combined_merges, num_new_token)`` where
        ``combined_merges`` is a list of ``(left, right)`` tuples.
    """
    source_merges = [
        rule for rule in map(normalize_merge, merges) if rule is not None
    ]
    # Set lookup instead of scanning the (potentially huge) merge list for
    # every auxiliary rule.
    known_merges = set(source_merges)

    combined_vocab = dict(vocab)
    # Start after the highest id in use so new ids can never collide with
    # existing ones; for a dense vocabulary this equals len(vocab).
    first_new_id = max(combined_vocab.values(), default=-1) + 1
    num_new_token = 0

    kept_merges = []
    kept_set = set()

    for raw_merge in aux_merges:
        rule = normalize_merge(raw_merge)
        if rule is None or rule in kept_set:
            continue
        token_1, token_2 = rule
        if len(token_1) > max_token_len or len(token_2) > max_token_len:
            continue
        token = token_1 + token_2

        # vocab: register whichever of the three tokens is still unknown.
        if num_new_token < max_new_token:
            for candidate in (token_1, token_2, token):
                if candidate not in combined_vocab:
                    combined_vocab[candidate] = first_new_id + num_new_token
                    num_new_token += 1

        # merge: reuse source rules as-is; accept new rules only when every
        # token they mention made it into the vocabulary (it may not have,
        # once the budget is exhausted).
        if rule in known_merges or (
            token in combined_vocab
            and token_1 in combined_vocab
            and token_2 in combined_vocab
        ):
            kept_merges.append(rule)
            kept_set.add(rule)

    remaining_source = [rule for rule in source_merges if rule not in kept_set]
    return combined_vocab, kept_merges + remaining_source, num_new_token


def main():
    """Train an auxiliary tokenizer and fold it into the source tokenizer.

    Loads the source tokenizer, trains a new tokenizer of the same kind on
    the ``text`` column of a parquet file, combines both vocabularies and
    merge lists, installs the result as the source tokenizer's BPE model
    and saves it to the output directory.
    """
    parser = argparse.ArgumentParser(
        description="Train Tokenizer for Multi-lingual Text"
    )
    parser.add_argument(
        "--source_tokenizer",
        type=str,
        default="LingoIITGN/Ganga-2-1B",
        help="Source tokenizer",
    )
    parser.add_argument(
        "--file_path",
        type=str,
        default="./Data/training_data_parquet/all_languages_merged.parquet",
        help="File path for training.",
    )
    parser.add_argument(
        "--output_dir", type=str, default="Multilingual_Ganga", help="Saving Dir"
    )
    parser.add_argument(
        "--aux_vocab_size",
        type=int,
        default=64000,
        help="Vocabulary size of the auxiliary tokenizer trained on the data.",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=32000,
        help="Budget of new tokens added to the source vocabulary.",
    )
    parser.add_argument(
        "--max_token_length",
        type=int,
        default=15,
        help="Ignore merges whose parts are longer than this many characters.",
    )
    args = parser.parse_args()

    tokenizer = AutoTokenizer.from_pretrained(args.source_tokenizer)
    vocab = tokenizer.get_vocab()
    tokenizer_json = json.loads(tokenizer.backend_tokenizer.to_str())
    merges = tokenizer_json["model"]["merges"]

    dataset = load_dataset("parquet", data_files=args.file_path)
    aux_tokenizer = tokenizer.train_new_from_iterator(
        dataset["train"]["text"],
        args.aux_vocab_size,
    )
    aux_tokenizer_json = json.loads(aux_tokenizer.backend_tokenizer.to_str())
    aux_merges = aux_tokenizer_json["model"]["merges"]

    # merge the tokenizers
    ret_vocab, final_merges_tuples, num_new_token = combine_bpe_models(
        vocab,
        merges,
        aux_merges,
        max_new_token=args.max_new_tokens,
        max_token_len=args.max_token_length,
    )

    print(f"Total vocabulary size: {len(ret_vocab)}")
    print(f"Total merges: {len(final_merges_tuples)}")
    print(f"Added {num_new_token} new tokens")

    # Replace the model of the source tokenizer with the combined one.
    # NOTE(review): fuse_unk/byte_fallback are fixed here rather than read
    # from the source model's configuration - confirm they match it.
    tokenizer.backend_tokenizer.model = BPE(
        vocab=ret_vocab,
        merges=final_merges_tuples,  # Use tuples instead of lists
        fuse_unk=False,
        byte_fallback=True,
    )
    tokenizer.save_pretrained(args.output_dir)


if __name__ == "__main__":
    main()