@@ -5,6 +5,32 @@
 from typing import Any
 
 from tokenizers import Tokenizer
+from tokenizers.pre_tokenizers import (
+    BertPreTokenizer,
+    ByteLevel,
+    CharDelimiterSplit,
+    Digits,
+    Metaspace,
+    PreTokenizer,
+    Punctuation,
+    Sequence,
+    Split,
+    UnicodeScripts,
+    Whitespace,
+    WhitespaceSplit,
+)
+
+_FORBIDDEN_PRETOKENIZERS = (
+    BertPreTokenizer,
+    CharDelimiterSplit,
+    Metaspace,
+    Punctuation,
+    Split,
+    UnicodeScripts,
+    Whitespace,
+    WhitespaceSplit,
+)
+
 
 logger = logging.getLogger(__name__)
 
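For reference, a minimal sketch (not part of the commit) of why these pre-tokenizers are forbidden: each of them splits on word, punctuation, or script boundaries, which destroys multiword vocabulary units, whereas Metaspace with split=False keeps a unit in one piece:

    from tokenizers.pre_tokenizers import Metaspace, Whitespace

    # Whitespace splits a multiword unit into separate pieces.
    print(Whitespace().pre_tokenize_str("new york"))
    # -> [('new', (0, 3)), ('york', (4, 8))]

    # Metaspace with split=False (the replacement used in this commit) keeps
    # the unit together, with spaces turned into the "Ġ" boundary marker.
    print(Metaspace(split=False, replacement="Ġ").pre_tokenize_str("new york"))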
@@ -36,9 +62,9 @@ def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[str]) -> list[str]:
         if token in current_tokenizer_vocab:
             pre_tokenized_tokens.append(token)
         else:
-            # We know 100% sure that all pretokenized tokens will have length 1.
-            pretokenized_tokens, _ = zip(*tokenizer.pre_tokenizer.pre_tokenize_str(f" {token}"))
-            pre_tokenized_tokens.append(pretokenized_tokens[-1])
+            # A token can still be split into multiple pieces; join them so nothing is dropped.
+            pretokenized_tokens, _ = zip(*tokenizer.pre_tokenizer.pre_tokenize_str(token))
+            pre_tokenized_tokens.append(" ".join(pretokenized_tokens))
     else:
         pre_tokenized_tokens = tokens
 
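A hedged illustration of the change above (assuming a ByteLevel-style pre-tokenizer): pre_tokenize_str returns (piece, offsets) pairs, and a single vocabulary entry can come back as several pieces, so keeping only the last piece, as the old code did, silently dropped the rest:

    from tokenizers.pre_tokenizers import ByteLevel

    pieces = ByteLevel().pre_tokenize_str("new york")
    tokens, _ = zip(*pieces)   # roughly ('Ġnew', 'Ġyork')
    print(" ".join(tokens))    # the whole unit survives: 'Ġnew Ġyork'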
@@ -67,6 +93,30 @@ def _remap_added_tokens(
     return special_tokens
 
 
+def _fix_single_pretokenizer(pretokenizer: PreTokenizer) -> PreTokenizer | None:
+    """Fixes a single pretokenizer to allow multiword units."""
+    if isinstance(pretokenizer, _FORBIDDEN_PRETOKENIZERS):
+        return Metaspace(split=False, replacement="Ġ")
+    elif isinstance(pretokenizer, ByteLevel):
+        pretokenizer.use_regex = False
+
+    return pretokenizer
+
+
+def _fix_pretokenizer_for_super(pre: PreTokenizer | None) -> PreTokenizer | None:
+    """Fixes the pretokenizer to allow multiword units."""
+    if pre is None:
+        return pre
+
+    if isinstance(pre, Sequence):
+        new_pretokenizers = []
+        for pretokenizer in pre:
+            new_pretokenizers.append(_fix_single_pretokenizer(pretokenizer))
+        return Sequence(new_pretokenizers)
+
+    return _fix_single_pretokenizer(pre)
+
+
 def _make_new_merges_from_vocab(
     merges: list[tuple[str, str]], tokens: list[str], special_tokens: set[str | None]
 ) -> list[tuple[str, str]]:
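As a usage sketch of the new helper (the model name is an assumption, not from this commit), it can be applied to any loaded tokenizer:

    from tokenizers import Tokenizer

    tok = Tokenizer.from_pretrained("bert-base-uncased")  # assumed example model
    # BertPreTokenizer is in the forbidden list, so it gets swapped for
    # Metaspace(split=False, replacement="Ġ") and multiword units survive.
    tok.pre_tokenizer = _fix_pretokenizer_for_super(tok.pre_tokenizer)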
@@ -109,6 +159,7 @@ def replace_vocabulary(
     tokenizer: Tokenizer, new_vocabulary: list[str], unk_token: str | None, pad_token: str | None
 ) -> Tokenizer:
     """Replace the vocabulary of a tokenizer with a new one."""
+    tokenizer.pre_tokenizer = _fix_pretokenizer_for_super(tokenizer.pre_tokenizer)
     tokenizer_json: dict[str, Any] = json.loads(tokenizer.to_str())
 
     # NOTE: all tokens have been normalized before.
@@ -124,10 +175,30 @@ def replace_vocabulary(
     unk_token = unk_token or tokenizer_json["model"]["unk_token"]
     tokenizer_json["model"]["unk_token"] = unk_token
     tokenizer_json["added_tokens"] = [x for x in tokenizer_json["added_tokens"] if x["content"] in special_tokens]
-    tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
+
+    if model_type == "WordPiece":
+        subword_prefix = tokenizer_json["model"]["continuing_subword_prefix"]
+        new_vocab = {}
+        for idx, token in enumerate(pre_tokenized_tokens):
+            if token in special_tokens:
+                # Special tokens are kept as-is.
+                pass
+            elif token.startswith(subword_prefix):
+                # We need to remove the subword prefix from the token.
+                token = token.removeprefix(subword_prefix)
+            elif token.startswith("Ġ"):
+                pass  # Already marked as a word start.
+            else:
+                # We need to add the word-start prefix to the token.
+                token = f"Ġ{token}"
+            new_vocab[token] = idx
+        tokenizer_json["model"]["continuing_subword_prefix"] = ""
+        tokenizer_json["model"]["max_input_chars_per_word"] = 10_000
+        tokenizer_json["model"]["vocab"] = new_vocab
 
     if model_type == "BPE":
         # Bit more difficult, we need to take into account merges.
+        tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
         merges = tokenizer_json["model"]["merges"]
         merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, special_tokens)
         tokenizer_json["model"]["merges"] = merges
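To see what the WordPiece branch does to individual tokens, here is a hypothetical standalone version of the loop body (the helper name and the "##" default are illustrative assumptions, not part of the commit):

    def _remap_wordpiece_token(token: str, special_tokens: set[str], subword_prefix: str = "##") -> str:
        """Hypothetical standalone version of the remapping loop above."""
        if token in special_tokens:
            return token  # special tokens stay untouched
        if token.startswith(subword_prefix):
            return token.removeprefix(subword_prefix)  # "##ing" -> "ing"
        if token.startswith("Ġ"):
            return token  # already marked as a word start
        return f"Ġ{token}"  # "walk" -> "Ġwalk"

    assert _remap_wordpiece_token("[UNK]", {"[UNK]"}) == "[UNK]"
    assert _remap_wordpiece_token("##ing", set()) == "ing"
    assert _remap_wordpiece_token("walk", set()) == "Ġwalk"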
@@ -154,4 +225,5 @@ def replace_vocabulary(
     tokenizer_json["added_tokens"] = _remap_added_tokens(added_tokens, pre_tokenized_tokens)
     tokenizer_json["post_processor"] = _DEFAULT_POST_PROCESSOR_TEMPLATE
 
-    return Tokenizer.from_str(json.dumps(tokenizer_json))
+    tokenizer = Tokenizer.from_str(json.dumps(tokenizer_json))
+    return tokenizer
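Finally, a hedged end-to-end sketch of how the updated replace_vocabulary might be called (the model name and vocabulary are assumptions; per the NOTE in the diff, tokens are expected to be normalized beforehand):

    from tokenizers import Tokenizer

    tok = Tokenizer.from_pretrained("bert-base-uncased")  # assumed example model
    new_tok = replace_vocabulary(
        tok,
        new_vocabulary=["[UNK]", "[PAD]", "new york", "walking"],  # assumed, pre-normalized
        unk_token="[UNK]",
        pad_token="[PAD]",
    )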