
Commit ad17768 (1 parent: 77f16df)

fix: issues with unk and pad

1 file changed: model2vec/distill/tokenizer.py (35 additions, 9 deletions)
@@ -30,7 +30,6 @@ def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[Token]) -> list[
     :param tokens: The tokens to pre-tokenize.
     :return: The pre-tokenized tokens.
     """
-    current_tokenizer_vocab = set(tokenizer.get_vocab())
     pre_tokenized_tokens = []

     if tokenizer.pre_tokenizer is not None:
@@ -50,7 +49,7 @@ def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[Token]) -> list[
 def _remap_added_tokens(
     special_tokens: list[dict[str, Any]],
     vocabulary: list[str],
-) -> list[dict[str, int]]:
+) -> list[dict[str, Any]]:
     """
     Remap special tokens in the tokenizer.

@@ -119,25 +118,33 @@ def replace_vocabulary(
     pre_tokenized_tokens = _pre_tokenize_vocabulary(tokenizer, new_vocabulary)

     model_type = tokenizer_json["model"]["type"]
-    special_tokens = {unk_token, pad_token}
+    added_tokens: list[dict[str, Any]] = tokenizer_json["added_tokens"]
+    original_added_tokens = {x["content"] for x in added_tokens} - {"[UNK]", "[PAD]"}
+    added_tokens = _rename_added_token(unk_token, "[UNK]", added_tokens, pre_tokenized_tokens)
+    added_tokens = _rename_added_token(pad_token, "[PAD]", added_tokens, pre_tokenized_tokens)
+
+    # Remove old special tokens
+    added_tokens = [x for x in added_tokens if x["content"] not in original_added_tokens]
+    # Remove other special tokens from the vocabulary
+    pre_tokenized_tokens = [x for x in pre_tokenized_tokens if x not in original_added_tokens]
+    tokenizer_json["added_tokens"] = added_tokens
+    all_added_tokens = {x["content"] for x in added_tokens} | original_added_tokens

     if model_type in {"WordPiece", "BPE"}:
         # Easiest, just add the new vocab
         unk_token = unk_token or tokenizer_json["model"]["unk_token"]
-        tokenizer_json["model"]["unk_token"] = unk_token
-        tokenizer_json["added_tokens"] = [x for x in tokenizer_json["added_tokens"] if x["content"] in special_tokens]
+        tokenizer_json["model"]["unk_token"] = "[UNK]"
         tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}

         if model_type == "BPE":
             # Bit more difficult, we need to take into account merges.
             merges = tokenizer_json["model"]["merges"]
-            merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, special_tokens)
+            merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, all_added_tokens)
             tokenizer_json["model"]["merges"] = merges

     elif model_type == "Unigram":
         # Bit more difficult, we need to take into account probas.
         unk_id = tokenizer_json["model"]["unk_id"]
-        tokenizer_json["added_tokens"] = [x for x in tokenizer_json["added_tokens"] if x["content"] in special_tokens]
         vocab = tokenizer_json["model"]["vocab"]
         unk_token = vocab[unk_id][0] if unk_id is not None else None
         current_probas = dict(tokenizer_json["model"]["vocab"])
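
To make the new bookkeeping concrete, here is a trace of the logic above on toy data. This is a minimal sketch, not code from the commit: the token dicts are trimmed to the two fields the diff touches (real added_tokens entries also carry flags such as special, lstrip, and rstrip), and the two _rename_added_token calls are inlined.

# Toy stand-ins for tokenizer_json["added_tokens"] and the pre-tokenized vocabulary.
added_tokens = [
    {"id": 0, "content": "<unk>"},
    {"id": 1, "content": "<pad>"},
    {"id": 2, "content": "<mask>"},
]
pre_tokenized_tokens = ["<unk>", "<pad>", "<mask>", "hello", "world"]

# Everything that was special before, minus the canonical names about to be introduced.
original_added_tokens = {x["content"] for x in added_tokens} - {"[UNK]", "[PAD]"}

# Inlined stand-in for the two _rename_added_token calls:
# rename <unk> -> [UNK] and <pad> -> [PAD] in both structures.
for old, new in [("<unk>", "[UNK]"), ("<pad>", "[PAD]")]:
    pre_tokenized_tokens[pre_tokenized_tokens.index(old)] = new
    next(x for x in added_tokens if x["content"] == old)["content"] = new

# "<mask>" still matches its old name, so both filters drop it;
# the renamed "[UNK]" and "[PAD]" no longer match, so they survive.
added_tokens = [x for x in added_tokens if x["content"] not in original_added_tokens]
pre_tokenized_tokens = [x for x in pre_tokenized_tokens if x not in original_added_tokens]

assert [x["content"] for x in added_tokens] == ["[UNK]", "[PAD]"]
assert pre_tokenized_tokens == ["[UNK]", "[PAD]", "hello", "world"]
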
@@ -152,8 +159,27 @@ def replace_vocabulary(
         raise ValueError(f"Unknown model type {model_type}")

     # Remap special tokens
-    added_tokens = tokenizer_json["added_tokens"]
-    tokenizer_json["added_tokens"] = _remap_added_tokens(added_tokens, pre_tokenized_tokens)
+    tokenizer_json["added_tokens"] = _remap_added_tokens(
+        special_tokens=tokenizer_json["added_tokens"],
+        vocabulary=pre_tokenized_tokens,
+    )
     tokenizer_json["post_processor"] = _DEFAULT_POST_PROCESSOR_TEMPLATE

     return Tokenizer.from_str(json.dumps(tokenizer_json))
+
+
+def _rename_added_token(
+    form: str | None, new_form: str, added_tokens: list[dict[str, Any]], vocabulary: list[str]
+) -> list[dict[str, Any]]:
+    """Rename special tokens in the tokenizer."""
+    if form is None:
+        return added_tokens
+
+    idx = vocabulary.index(form)
+    added_token = [x for x in added_tokens if x["content"] == form]
+    if added_token:
+        added_token[0]["id"] = idx
+        added_token[0]["content"] = new_form
+        vocabulary[idx] = new_form
+
+    return added_tokens
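
And here is how the new _rename_added_token helper behaves on toy inputs; the function is copied verbatim from the hunk above, while the sample vocabulary and added-token entry are invented for illustration. Two behaviors worth noting: the helper mutates the passed-in vocabulary in place (the return value only covers added_tokens), and vocabulary.index(form) raises ValueError if the token is missing from the vocabulary.

from typing import Any

def _rename_added_token(
    form: str | None, new_form: str, added_tokens: list[dict[str, Any]], vocabulary: list[str]
) -> list[dict[str, Any]]:
    """Rename special tokens in the tokenizer."""
    if form is None:
        return added_tokens

    idx = vocabulary.index(form)
    added_token = [x for x in added_tokens if x["content"] == form]
    if added_token:
        added_token[0]["id"] = idx
        added_token[0]["content"] = new_form
        vocabulary[idx] = new_form

    return added_tokens

vocabulary = ["<unk>", "hello", "world"]
added_tokens = [{"id": 100, "content": "<unk>"}]  # stale id on purpose

added_tokens = _rename_added_token("<unk>", "[UNK]", added_tokens, vocabulary)
assert added_tokens == [{"id": 0, "content": "[UNK]"}]  # id remapped to the vocab index
assert vocabulary == ["[UNK]", "hello", "world"]        # renamed in place

# Passing form=None (no unk/pad token configured) is a no-op.
assert _rename_added_token(None, "[PAD]", added_tokens, vocabulary) == added_tokens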
