@@ -30,7 +30,6 @@ def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[Token]) -> list[
     :param tokens: The tokens to pre-tokenize.
     :return: The pre-tokenized tokens.
     """
-    current_tokenizer_vocab = set(tokenizer.get_vocab())
     pre_tokenized_tokens = []
 
     if tokenizer.pre_tokenizer is not None:
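For reference, this helper presumably builds on the `tokenizers` pre-tokenizer API, where `pre_tokenize_str` returns `(piece, (start, end))` pairs. A minimal sketch of that behaviour, with an illustrative model name and toy string not taken from this PR:

from tokenizers import Tokenizer

# Sketch: how a pre-tokenizer splits a candidate vocabulary entry.
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
if tokenizer.pre_tokenizer is not None:
    pieces = tokenizer.pre_tokenizer.pre_tokenize_str("hello world")
    # e.g. [("hello", (0, 5)), ("world", (6, 11))]
    tokens = [piece for piece, _ in pieces]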
@@ -50,7 +49,7 @@ def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[Token]) -> list[
 def _remap_added_tokens(
     special_tokens: list[dict[str, Any]],
     vocabulary: list[str],
-) -> list[dict[str, int]]:
+) -> list[dict[str, Any]]:
     """
     Remap special tokens in the tokenizer.
 
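The widened annotation matches the shape of entries in `tokenizer_json["added_tokens"]`, which mix ints, strings, and booleans. A representative serialized entry (values illustrative):

# One added-token entry from a serialized tokenizer.json; the mixed
# value types are why dict[str, Any] is more accurate than dict[str, int]:
entry = {
    "id": 0,
    "content": "[PAD]",
    "single_word": False,
    "lstrip": False,
    "rstrip": False,
    "normalized": False,
    "special": True,
}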
@@ -119,25 +118,33 @@ def replace_vocabulary(
     pre_tokenized_tokens = _pre_tokenize_vocabulary(tokenizer, new_vocabulary)
 
     model_type = tokenizer_json["model"]["type"]
-    special_tokens = {unk_token, pad_token}
+    added_tokens: list[dict[str, Any]] = tokenizer_json["added_tokens"]
+    original_added_tokens = {x["content"] for x in added_tokens} - {"[UNK]", "[PAD]"}
+    added_tokens = _rename_added_token(unk_token, "[UNK]", added_tokens, pre_tokenized_tokens)
+    added_tokens = _rename_added_token(pad_token, "[PAD]", added_tokens, pre_tokenized_tokens)
+
+    # Remove old special tokens
+    added_tokens = [x for x in added_tokens if x["content"] not in original_added_tokens]
+    # Remove other special tokens from the vocabulary
+    pre_tokenized_tokens = [x for x in pre_tokenized_tokens if x not in original_added_tokens]
+    tokenizer_json["added_tokens"] = added_tokens
+    all_added_tokens = {x["content"] for x in added_tokens} | original_added_tokens
 
     if model_type in {"WordPiece", "BPE"}:
         # Easiest, just add the new vocab
         unk_token = unk_token or tokenizer_json["model"]["unk_token"]
-        tokenizer_json["model"]["unk_token"] = unk_token
-        tokenizer_json["added_tokens"] = [x for x in tokenizer_json["added_tokens"] if x["content"] in special_tokens]
+        tokenizer_json["model"]["unk_token"] = "[UNK]"
         tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
 
         if model_type == "BPE":
             # Bit more difficult, we need to take into account merges.
             merges = tokenizer_json["model"]["merges"]
-            merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, special_tokens)
+            merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, all_added_tokens)
             tokenizer_json["model"]["merges"] = merges
 
     elif model_type == "Unigram":
         # Bit more difficult, we need to take into account probas.
         unk_id = tokenizer_json["model"]["unk_id"]
-        tokenizer_json["added_tokens"] = [x for x in tokenizer_json["added_tokens"] if x["content"] in special_tokens]
         vocab = tokenizer_json["model"]["vocab"]
         unk_token = vocab[unk_id][0] if unk_id is not None else None
         current_probas = dict(tokenizer_json["model"]["vocab"])
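To make the new added-token bookkeeping concrete, here is a self-contained sketch with toy values; the two rename calls are simulated inline, and nothing here is taken from the PR itself:

from typing import Any

added_tokens: list[dict[str, Any]] = [
    {"id": 0, "content": "<unk>", "special": True},
    {"id": 1, "content": "<pad>", "special": True},
    {"id": 2, "content": "<mask>", "special": True},
]
pre_tokenized_tokens = ["<unk>", "<pad>", "<mask>", "hello"]

# Specials present before renaming, minus the canonical names:
original_added_tokens = {x["content"] for x in added_tokens} - {"[UNK]", "[PAD]"}

# Simulate the two _rename_added_token calls:
for old, new in (("<unk>", "[UNK]"), ("<pad>", "[PAD]")):
    pre_tokenized_tokens[pre_tokenized_tokens.index(old)] = new
    next(x for x in added_tokens if x["content"] == old)["content"] = new

# The renamed tokens no longer match original_added_tokens, so only the
# leftover special ("<mask>") is dropped from both lists:
added_tokens = [x for x in added_tokens if x["content"] not in original_added_tokens]
pre_tokenized_tokens = [x for x in pre_tokenized_tokens if x not in original_added_tokens]
assert pre_tokenized_tokens == ["[UNK]", "[PAD]", "hello"]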
@@ -152,8 +159,27 @@ def replace_vocabulary(
         raise ValueError(f"Unknown model type {model_type}")
 
     # Remap special tokens
-    added_tokens = tokenizer_json["added_tokens"]
-    tokenizer_json["added_tokens"] = _remap_added_tokens(added_tokens, pre_tokenized_tokens)
+    tokenizer_json["added_tokens"] = _remap_added_tokens(
+        special_tokens=tokenizer_json["added_tokens"],
+        vocabulary=pre_tokenized_tokens,
+    )
     tokenizer_json["post_processor"] = _DEFAULT_POST_PROCESSOR_TEMPLATE
 
     return Tokenizer.from_str(json.dumps(tokenizer_json))
+
+
+def _rename_added_token(
+    form: str | None, new_form: str, added_tokens: list[dict[str, Any]], vocabulary: list[str]
+) -> list[dict[str, Any]]:
+    """Rename special tokens in the tokenizer."""
+    if form is None:
+        return added_tokens
+
+    idx = vocabulary.index(form)
+    added_token = [x for x in added_tokens if x["content"] == form]
+    if added_token:
+        added_token[0]["id"] = idx
+        added_token[0]["content"] = new_form
+        vocabulary[idx] = new_form
+
+    return added_tokens
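A quick usage sketch of the new helper with toy values. Note that it mutates `vocabulary` in place, and that `vocabulary.index(form)` raises ValueError when `form` is absent from the new vocabulary:

vocab = ["<unk>", "hello", "world"]
tokens = [{"id": 7, "content": "<unk>", "special": True}]
tokens = _rename_added_token("<unk>", "[UNK]", tokens, vocab)
# tokens[0] is now {"id": 0, "content": "[UNK]", "special": True}
# vocab is now ["[UNK]", "hello", "world"]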