
Commit e025b9f

fix: separate tokenizers
1 parent: 5833c1c

1 file changed: 44 additions & 27 deletions

File tree

model2vec/distill/tokenizer.py

@@ -107,6 +107,45 @@ def _make_new_merges_from_vocab(
     return new_merges
 
 
+def _process_wordpiece(
+    tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
+) -> dict[str, Any]:
+    """Process the WordPiece tokenizer JSON."""
+    unk_token = unk_token or tokenizer_json["model"]["unk_token"]
+    tokenizer_json["model"]["unk_token"] = "[UNK]" if unk_token else None
+    tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
+
+    return tokenizer_json
+
+
+def _process_bpe(tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str]) -> dict[str, Any]:
+    """Process the BPE tokenizer JSON."""
+    tokenizer_json = _process_wordpiece(tokenizer_json, pre_tokenized_tokens, None)
+    merges = tokenizer_json["model"]["merges"]
+    merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, {"[UNK]", "[PAD]"})
+    tokenizer_json["model"]["merges"] = merges
+
+    return tokenizer_json
+
+
+def _process_unigram(
+    tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
+) -> dict[str, Any]:
+    """Process the Unigram tokenizer JSON."""
+    unk_id = tokenizer_json["model"]["unk_id"]
+    vocab = tokenizer_json["model"]["vocab"]
+    unk_token = vocab[unk_id][0] if unk_id is not None else None
+    current_probas = dict(tokenizer_json["model"]["vocab"])
+    avg_proba = sum(current_probas.values()) / len(current_probas)
+    new_probas = {word: current_probas.get(word, avg_proba) for word in pre_tokenized_tokens}
+    tokenizer_json["model"]["vocab"] = sorted(new_probas.items(), key=lambda x: x[1], reverse=True)
+
+    tokens, _ = zip(*tokenizer_json["model"]["vocab"])
+    tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token) if unk_token in tokens else None
+
+    return tokenizer_json
+
+
 def replace_vocabulary(
     tokenizer: Tokenizer, new_vocabulary: list[Token], unk_token: str | None, pad_token: str | None
 ) -> Tokenizer:
@@ -121,41 +160,19 @@ def replace_vocabulary(
     model_type = tokenizer_json["model"]["type"]
     added_tokens: list[dict[str, Any]] = tokenizer_json["added_tokens"]
 
-    # NOTE: all added tokens but the unk and pad tokens are removed already.
-    # We only need this for BPE.
-    added_token_forms = {x["content"] for x in added_tokens} | {"[UNK]", "[PAD]"}
     # We need to remove the added tokens but keep [UNK] and [PAD] tokens.
     added_tokens = _rename_added_token(unk_token, "[UNK]", added_tokens, pre_tokenized_tokens)
     added_tokens = _rename_added_token(pad_token, "[PAD]", added_tokens, pre_tokenized_tokens)
 
     # Remove old added tokens from added tokens
     tokenizer_json["added_tokens"] = [x for x in added_tokens if x["content"] in {"[UNK]", "[PAD]"}]
 
-    if model_type in {"WordPiece", "BPE"}:
-        # Easiest, just add the new vocab
-        unk_token = unk_token or tokenizer_json["model"]["unk_token"]
-        tokenizer_json["model"]["unk_token"] = "[UNK]" if unk_token else None
-        tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
-
-        if model_type == "BPE":
-            # Bit more difficult, we need to take into account merges.
-            merges = tokenizer_json["model"]["merges"]
-            merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, added_token_forms)
-            tokenizer_json["model"]["merges"] = merges
-
+    if model_type == "WordPiece":
+        tokenizer_json = _process_wordpiece(tokenizer_json, pre_tokenized_tokens, unk_token)
+    elif model_type == "BPE":
+        tokenizer_json = _process_bpe(tokenizer_json, pre_tokenized_tokens)
     elif model_type == "Unigram":
-        # Bit more difficult, we need to take into account probas.
-        unk_id = tokenizer_json["model"]["unk_id"]
-        vocab = tokenizer_json["model"]["vocab"]
-        unk_token = vocab[unk_id][0] if unk_id is not None else None
-        current_probas = dict(tokenizer_json["model"]["vocab"])
-        avg_proba = sum(current_probas.values()) / len(current_probas)
-        new_probas = {word: current_probas.get(word, avg_proba) for word in pre_tokenized_tokens}
-        tokenizer_json["model"]["vocab"] = sorted(new_probas.items(), key=lambda x: x[1], reverse=True)
-
-        tokens, _ = zip(*tokenizer_json["model"]["vocab"])
-        tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token) if unk_token in tokens else None
-
+        tokenizer_json = _process_unigram(tokenizer_json, pre_tokenized_tokens, unk_token)
     else:
         raise ValueError(f"Unknown model type {model_type}")
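For readers skimming the diff, the heart of the Unigram branch is the fallback scoring: tokens carried over from the old vocabulary keep their log-probabilities, while tokens that are new to it receive the average log-probability, after which the vocabulary is re-sorted by score. A minimal standalone sketch of that step, with made-up tokens and scores:

# Sketch of the Unigram re-scoring performed by _process_unigram above.
# The vocab format mirrors the `tokenizers` Unigram JSON: a list of
# [token, log_probability] pairs. Tokens and scores here are illustrative.
vocab = [["<unk>", -6.0], ["hello", -2.0], ["world", -3.0]]
new_tokens = ["<unk>", "hello", "brave", "world"]

current_probas = dict(vocab)
avg_proba = sum(current_probas.values()) / len(current_probas)  # (-6 - 2 - 3) / 3 ~= -3.67

# Unseen tokens ("brave") fall back to the average log-probability,
# then everything is re-sorted best-first, as in the diff.
new_probas = {word: current_probas.get(word, avg_proba) for word in new_tokens}
new_vocab = sorted(new_probas.items(), key=lambda x: x[1], reverse=True)
# -> [("hello", -2.0), ("world", -3.0), ("brave", -3.67), ("<unk>", -6.0)]

Because the sort reorders the vocabulary, the unknown token's index changes, which is why _process_unigram recomputes unk_id by looking the token up in the new ordering.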
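As for how these helpers are driven end to end: a hedged sketch, assuming replace_vocabulary round-trips the model through its JSON representation with the `tokenizers` library's Tokenizer.to_str and Tokenizer.from_str (the model name and replacement vocabulary below are illustrative, not part of the commit):

import json

from tokenizers import Tokenizer

# Load any WordPiece-based tokenizer and dump its JSON representation.
tokenizer = Tokenizer.from_pretrained("bert-base-uncased")
tokenizer_json = json.loads(tokenizer.to_str())
assert tokenizer_json["model"]["type"] == "WordPiece"

# Swap in a new vocabulary via the helper added in this commit.
new_tokens = ["[PAD]", "[UNK]", "hello", "world"]
tokenizer_json = _process_wordpiece(tokenizer_json, new_tokens, "[UNK]")

# Rebuild a live tokenizer from the edited JSON.
new_tokenizer = Tokenizer.from_str(json.dumps(tokenizer_json))

Dispatching on tokenizer_json["model"]["type"], as replace_vocabulary now does, keeps each model family's quirks in one place: merges for BPE, scores and unk_id bookkeeping for Unigram.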
