@@ -107,6 +107,45 @@ def _make_new_merges_from_vocab(
     return new_merges
 
 
+def _process_wordpiece(
+    tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
+) -> dict[str, Any]:
+    """Process the WordPiece tokenizer JSON."""
+    unk_token = unk_token or tokenizer_json["model"]["unk_token"]
+    tokenizer_json["model"]["unk_token"] = "[UNK]" if unk_token else None
+    tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
+
+    return tokenizer_json
+
+
+def _process_bpe(tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str]) -> dict[str, Any]:
+    """Process the BPE tokenizer JSON."""
+    tokenizer_json = _process_wordpiece(tokenizer_json, pre_tokenized_tokens, None)
+    merges = tokenizer_json["model"]["merges"]
+    merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, {"[UNK]", "[PAD]"})
+    tokenizer_json["model"]["merges"] = merges
+
+    return tokenizer_json
+
+
+def _process_unigram(
+    tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
+) -> dict[str, Any]:
+    """Process the Unigram tokenizer JSON."""
+    unk_id = tokenizer_json["model"]["unk_id"]
+    vocab = tokenizer_json["model"]["vocab"]
+    unk_token = vocab[unk_id][0] if unk_id is not None else None
+    current_probas = dict(tokenizer_json["model"]["vocab"])
+    avg_proba = sum(current_probas.values()) / len(current_probas)
+    new_probas = {word: current_probas.get(word, avg_proba) for word in pre_tokenized_tokens}
+    tokenizer_json["model"]["vocab"] = sorted(new_probas.items(), key=lambda x: x[1], reverse=True)
+
+    tokens, _ = zip(*tokenizer_json["model"]["vocab"])
+    tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token) if unk_token in tokens else None
+
+    return tokenizer_json
+
+
 def replace_vocabulary(
     tokenizer: Tokenizer, new_vocabulary: list[Token], unk_token: str | None, pad_token: str | None
 ) -> Tokenizer:
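For illustration, here is a minimal sketch of the Unigram path on a toy tokenizer JSON. The data below is invented for the example; it assumes the HuggingFace `tokenizers` serialization (a Unigram `vocab` is a list of `[token, log_probability]` pairs) and that `_process_unigram` from this commit is in scope:

```python
# Toy Unigram model JSON (hypothetical data, not from the repo's tests).
tokenizer_json = {
    "model": {
        "type": "Unigram",
        "unk_id": 0,
        "vocab": [["[UNK]", -3.0], ["hello", -1.0], ["world", -2.0]],
    }
}
new_tokens = ["[UNK]", "hello", "brave"]  # "brave" is not in the old vocab

result = _process_unigram(tokenizer_json, new_tokens, None)

# Unseen tokens receive the average log-probability:
# (-3.0 + -1.0 + -2.0) / 3 == -2.0, so "brave" gets -2.0.
# The vocab is re-sorted by descending probability and unk_id is re-derived.
print(result["model"]["vocab"])   # [('hello', -1.0), ('brave', -2.0), ('[UNK]', -3.0)]
print(result["model"]["unk_id"])  # 2
```

Note that the helper re-derives the unk token from `unk_id` internally, so passing `None` as the third argument is fine here.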
@@ -121,41 +160,19 @@ def replace_vocabulary(
     model_type = tokenizer_json["model"]["type"]
     added_tokens: list[dict[str, Any]] = tokenizer_json["added_tokens"]
 
-    # NOTE: all added tokens but the unk and pad tokens are removed already.
-    # We only need this for BPE.
-    added_token_forms = {x["content"] for x in added_tokens} | {"[UNK]", "[PAD]"}
     # We need to remove the added tokens but keep [UNK] and [PAD] tokens.
     added_tokens = _rename_added_token(unk_token, "[UNK]", added_tokens, pre_tokenized_tokens)
     added_tokens = _rename_added_token(pad_token, "[PAD]", added_tokens, pre_tokenized_tokens)
 
     # Remove old added tokens from added tokens
     tokenizer_json["added_tokens"] = [x for x in added_tokens if x["content"] in {"[UNK]", "[PAD]"}]
 
-    if model_type in {"WordPiece", "BPE"}:
-        # Easiest, just add the new vocab
-        unk_token = unk_token or tokenizer_json["model"]["unk_token"]
-        tokenizer_json["model"]["unk_token"] = "[UNK]" if unk_token else None
-        tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
-
-        if model_type == "BPE":
-            # Bit more difficult, we need to take into account merges.
-            merges = tokenizer_json["model"]["merges"]
-            merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, added_token_forms)
-            tokenizer_json["model"]["merges"] = merges
-
+    if model_type == "WordPiece":
+        tokenizer_json = _process_wordpiece(tokenizer_json, pre_tokenized_tokens, unk_token)
+    elif model_type == "BPE":
+        tokenizer_json = _process_bpe(tokenizer_json, pre_tokenized_tokens)
     elif model_type == "Unigram":
-        # Bit more difficult, we need to take into account probas.
-        unk_id = tokenizer_json["model"]["unk_id"]
-        vocab = tokenizer_json["model"]["vocab"]
-        unk_token = vocab[unk_id][0] if unk_id is not None else None
-        current_probas = dict(tokenizer_json["model"]["vocab"])
-        avg_proba = sum(current_probas.values()) / len(current_probas)
-        new_probas = {word: current_probas.get(word, avg_proba) for word in pre_tokenized_tokens}
-        tokenizer_json["model"]["vocab"] = sorted(new_probas.items(), key=lambda x: x[1], reverse=True)
-
-        tokens, _ = zip(*tokenizer_json["model"]["vocab"])
-        tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token) if unk_token in tokens else None
-
+        tokenizer_json = _process_unigram(tokenizer_json, pre_tokenized_tokens, unk_token)
     else:
         raise ValueError(f"Unknown model type {model_type}")
 
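Similarly, a minimal sketch of the WordPiece path, which the new dispatch reaches for `model_type == "WordPiece"`. The toy JSON is invented for the example and assumes `_process_wordpiece` from this commit is in scope:

```python
# Toy WordPiece model JSON (hypothetical data; follows the HuggingFace
# `tokenizers` serialization, where a WordPiece vocab maps token -> id).
tokenizer_json = {
    "model": {"type": "WordPiece", "unk_token": "<unk>", "vocab": {"old": 0}},
}
new_tokens = ["[UNK]", "[PAD]", "hello", "##lo"]

result = _process_wordpiece(tokenizer_json, new_tokens, unk_token=None)

# With unk_token=None, the helper falls back to the old "<unk>" entry;
# since that is truthy, the model's unk token is renamed to "[UNK]",
# and the vocab is rebuilt as token -> index over the new token list.
print(result["model"]["unk_token"])  # [UNK]
print(result["model"]["vocab"])      # {'[UNK]': 0, '[PAD]': 1, 'hello': 2, '##lo': 3}
```

The BPE path reuses this same helper and then rewrites the merges via `_make_new_merges_from_vocab`, keeping only merges whose parts and results survive in the new vocabulary.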