
Commit d24c387

add superbpe tokenizers
1 parent b172b6d commit d24c387

2 files changed

Lines changed: 77 additions & 8 deletions


model2vec/distill/distillation.py

Lines changed: 0 additions & 3 deletions
@@ -302,7 +302,6 @@ def _clean_vocabulary(tokenizer: Tokenizer, vocabulary: list[str], added_tokens:
     cleaned_vocabulary = []
     n_empty = 0
     n_duplicates = 0
-    n_multiword = 0
     for token in vocabulary:
         if tokenizer.normalizer is not None:
             token = tokenizer.normalizer.normalize_str(token)
@@ -321,7 +320,5 @@ def _clean_vocabulary(tokenizer: Tokenizer, vocabulary: list[str], added_tokens:
         logger.warning(f"Removed {n_duplicates} duplicate tokens.")
     if n_empty:
         logger.warning(f"Removed {n_empty} empty tokens.")
-    if n_multiword:
-        logger.warning(f"Removed {n_multiword} multiword tokens.")

     return cleaned_vocabulary

model2vec/distill/tokenizer.py

Lines changed: 77 additions & 5 deletions
@@ -5,6 +5,32 @@
 from typing import Any

 from tokenizers import Tokenizer
+from tokenizers.pre_tokenizers import (
+    BertPreTokenizer,
+    ByteLevel,
+    CharDelimiterSplit,
+    Digits,
+    Metaspace,
+    PreTokenizer,
+    Punctuation,
+    Sequence,
+    Split,
+    UnicodeScripts,
+    Whitespace,
+    WhitespaceSplit,
+)
+
+_FORBIDDEN_PRETOKENIZERS = (
+    BertPreTokenizer,
+    CharDelimiterSplit,
+    Metaspace,
+    Punctuation,
+    Split,
+    UnicodeScripts,
+    Whitespace,
+    WhitespaceSplit,
+)
+

 logger = logging.getLogger(__name__)

@@ -36,9 +62,9 @@ def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[str]) -> list[str]:
             if token in current_tokenizer_vocab:
                 pre_tokenized_tokens.append(token)
             else:
-                # We know 100% sure that all pretokenized tokens will have length 1.
-                pretokenized_tokens, _ = zip(*tokenizer.pre_tokenizer.pre_tokenize_str(f" {token}"))
-                pre_tokenized_tokens.append(pretokenized_tokens[-1])
+                # Join tokens just to be sure.
+                pretokenized_tokens, _ = zip(*tokenizer.pre_tokenizer.pre_tokenize_str(token))
+                pre_tokenized_tokens.append(" ".join(pretokenized_tokens))
     else:
         pre_tokenized_tokens = tokens

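For reference, a minimal standalone sketch (not part of the commit) of what the pre_tokenize_str call above returns and why the pieces are now re-joined with spaces instead of keeping only the last piece; the Whitespace pre-tokenizer and the example string are illustrative:

    from tokenizers.pre_tokenizers import Whitespace

    # pre_tokenize_str returns (piece, (start, end)) pairs; zip(*...) separates pieces from offsets.
    pretokenized_tokens, offsets = zip(*Whitespace().pre_tokenize_str("new york"))
    print(pretokenized_tokens)             # ('new', 'york')
    print(" ".join(pretokenized_tokens))   # 'new york' -- kept as a single vocabulary entry
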
@@ -67,6 +93,30 @@ def _remap_added_tokens(
     return special_tokens


+def _fix_single_pretokenizer(pretokenizer: PreTokenizer) -> PreTokenizer | None:
+    """Fixes a single pretokenizer to allow multiword units."""
+    if isinstance(pretokenizer, _FORBIDDEN_PRETOKENIZERS):
+        return Metaspace(split=False, replacement="Ġ")
+    elif isinstance(pretokenizer, ByteLevel):
+        pretokenizer.use_regex = False
+
+    return pretokenizer
+
+
+def _fix_pretokenizer_for_super(pre: PreTokenizer | None) -> Tokenizer:
+    """Fixes the pretokenizer to allow multiword units."""
+    if pre is None:
+        return pre
+
+    if isinstance(pre, Sequence):
+        new_pretokenizers = []
+        for pretokenizer in pre:
+            new_pretokenizers.append(_fix_single_pretokenizer(pretokenizer))
+        return Sequence(new_pretokenizers)
+
+    return _fix_single_pretokenizer(pre)
+
+
 def _make_new_merges_from_vocab(
     merges: list[tuple[str, str]], tokens: list[str], special_tokens: set[str | None]
 ) -> list[tuple[str, str]]:
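Illustrative only: a small sketch (not from the commit) of the effect of _fix_pretokenizer_for_super, assuming the helpers live in model2vec/distill/tokenizer.py as above. A splitting pre-tokenizer such as Whitespace is swapped for the non-splitting Metaspace, so a multiword unit survives as one piece:

    from tokenizers.pre_tokenizers import Metaspace, Whitespace
    from model2vec.distill.tokenizer import _fix_pretokenizer_for_super

    # Whitespace is in _FORBIDDEN_PRETOKENIZERS, so it is replaced by Metaspace(split=False, ...).
    fixed = _fix_pretokenizer_for_super(Whitespace())
    assert isinstance(fixed, Metaspace)

    # Whitespace would split "new york" into two pieces; the fixed pre-tokenizer keeps one.
    pieces = [piece for piece, _ in fixed.pre_tokenize_str("new york")]
    assert len(pieces) == 1
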
@@ -109,6 +159,7 @@ def replace_vocabulary(
     tokenizer: Tokenizer, new_vocabulary: list[str], unk_token: str | None, pad_token: str | None
 ) -> Tokenizer:
     """Replace the vocabulary of a tokenizer with a new one."""
+    tokenizer.pre_tokenizer = _fix_pretokenizer_for_super(tokenizer.pre_tokenizer)
     tokenizer_json: dict[str, Any] = json.loads(tokenizer.to_str())

     # NOTE: all tokens have been normalized before.
@@ -124,10 +175,30 @@ def replace_vocabulary(
         unk_token = unk_token or tokenizer_json["model"]["unk_token"]
         tokenizer_json["model"]["unk_token"] = unk_token
     tokenizer_json["added_tokens"] = [x for x in tokenizer_json["added_tokens"] if x["content"] in special_tokens]
-    tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
+
+    if model_type == "WordPiece":
+        subword_prefix = tokenizer_json["model"]["continuing_subword_prefix"]
+        new_vocab = {}
+        for idx, token in enumerate(pre_tokenized_tokens):
+            if token in special_tokens:
+                # We need to remove the prefix from the token
+                pass
+            elif token.startswith(subword_prefix):
+                # We need to remove the prefix from the token
+                token = token.removeprefix(subword_prefix)
+            elif token.startswith("Ġ"):
+                pass
+            else:
+                # We need to add the prefix to the token
+                token = f"Ġ{token}"
+            new_vocab[token] = idx
+        tokenizer_json["model"]["continuing_subword_prefix"] = ""
+        tokenizer_json["model"]["max_input_chars_per_word"] = 10_000
+        tokenizer_json["model"]["vocab"] = new_vocab

     if model_type == "BPE":
         # Bit more difficult, we need to take into account merges.
+        tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
         merges = tokenizer_json["model"]["merges"]
         merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, special_tokens)
         tokenizer_json["model"]["merges"] = merges
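Illustrative only: a standalone re-statement of the WordPiece remapping above with a toy vocabulary (the prefix handling mirrors the diff; the tokens and special tokens are hypothetical). Continuation pieces lose the continuing_subword_prefix, while word-initial pieces, including multiword units, get the Ġ marker:

    special_tokens = {"[PAD]", "[UNK]"}
    subword_prefix = "##"
    pre_tokenized_tokens = ["[PAD]", "[UNK]", "hello", "##ing", "Ġworld", "new york"]

    new_vocab = {}
    for idx, token in enumerate(pre_tokenized_tokens):
        if token in special_tokens:
            pass                                        # special tokens keep their surface form
        elif token.startswith(subword_prefix):
            token = token.removeprefix(subword_prefix)  # "##ing" -> "ing"
        elif token.startswith("Ġ"):
            pass                                        # already marked as word-initial
        else:
            token = f"Ġ{token}"                         # word-initial pieces, incl. multiword units
        new_vocab[token] = idx

    print(new_vocab)
    # {'[PAD]': 0, '[UNK]': 1, 'Ġhello': 2, 'ing': 3, 'Ġworld': 4, 'Ġnew york': 5}
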
@@ -154,4 +225,5 @@ def replace_vocabulary(
     tokenizer_json["added_tokens"] = _remap_added_tokens(added_tokens, pre_tokenized_tokens)
     tokenizer_json["post_processor"] = _DEFAULT_POST_PROCESSOR_TEMPLATE

-    return Tokenizer.from_str(json.dumps(tokenizer_json))
+    tokenizer = Tokenizer.from_str(json.dumps(tokenizer_json))
+    return tokenizer
