
Commit 0a07cd4

fix: issues with unk and pad (#225)
* fix: issues with unk and pad
* fix tests
* upper case deprecated
* Clearify code
* fix: separate tokenizers
1 parent 77f16df

5 files changed

Lines changed: 140 additions & 161 deletions


model2vec/distill/distillation.py

Lines changed: 10 additions & 21 deletions
@@ -39,9 +39,9 @@ def distill_from_model(
     pca_dims: PCADimType = 256,
     apply_zipf: bool | None = None,
     sif_coefficient: float | None = 1e-4,
-    use_subword: bool = True,
     token_remove_pattern: str | None = r"\[unused\d+\]",
     quantize_to: DType | str = DType.Float16,
+    use_subword: bool | None = None,
 ) -> StaticModel:
     """
     Distill a staticmodel from a sentence transformer.
@@ -63,18 +63,20 @@ def distill_from_model(
         Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
     :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
         Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
-    :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
     :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
         If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
     :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
+    :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
     :return: A StaticModel
 
     """
+    if use_subword is not None:
+        logger.warning(
+            "The `use_subword` parameter is deprecated and will be removed in the next release. It doesn't do anything."
+        )
     quantize_to = DType(quantize_to)
     backend_tokenizer = tokenizer.backend_tokenizer
-    sif_coefficient, token_remove_regex = _validate_parameters(
-        vocabulary, apply_zipf, sif_coefficient, use_subword, token_remove_pattern
-    )
+    sif_coefficient, token_remove_regex = _validate_parameters(apply_zipf, sif_coefficient, token_remove_pattern)
 
     if vocabulary is None:
         vocabulary = []
@@ -98,7 +100,6 @@ def distill_from_model(
         tokenizer=tokenizer,
         tokens=cleaned_vocabulary,
         device=device,
-        use_subword=use_subword,
         token_remove_regex=token_remove_regex,
     )
 
@@ -151,27 +152,20 @@ def distill_from_model(
 
 
 def _validate_parameters(
-    vocabulary: list[str] | None,
     apply_zipf: bool | None,
     sif_coefficient: float | None,
-    use_subword: bool,
     token_remove_pattern: str | None,
 ) -> tuple[float | None, re.Pattern | None]:
     """
     Validate the parameters passed to the distillation function.
 
-    :param vocabulary: The vocabulary to use.
     :param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
         Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
     :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
         Should be a value >= 0 and < 1.0. A value of 1e-4 is a good default.
-    :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
     :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
     :return: The SIF coefficient to use.
-    :raises: ValueError if the PCA dimension is larger than the number of dimensions in the embeddings.
-    :raises: ValueError if the vocabulary contains duplicate tokens.
     :raises: ValueError if the regex can't be compiled.
-    :raises: ValueError if the vocabulary is empty after token removal.
 
     """
     if apply_zipf is not None:
@@ -191,11 +185,6 @@ def _validate_parameters(
         if not 0 < sif_coefficient < 1.0:
             raise ValueError("SIF coefficient must be a value > 0 and < 1.0.")
 
-    if not use_subword and vocabulary is None:
-        raise ValueError(
-            "You must pass a vocabulary if you don't use subword tokens. Either pass a vocabulary, or set use_subword to True."
-        )
-
     token_remove_regex: re.Pattern | None = None
     if token_remove_pattern is not None:
         try:
@@ -213,10 +202,10 @@ def distill(
     pca_dims: PCADimType = 256,
     apply_zipf: bool | None = None,
     sif_coefficient: float | None = 1e-4,
-    use_subword: bool = True,
     token_remove_pattern: str | None = r"\[unused\d+\]",
     trust_remote_code: bool = False,
     quantize_to: DType | str = DType.Float16,
+    use_subword: bool | None = None,
 ) -> StaticModel:
     """
     Distill a staticmodel from a sentence transformer.
@@ -237,10 +226,10 @@ def distill(
         Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
     :param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
         Should be a value >= 0 and < 1.0. A value of 1e-4 is a good default.
-    :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
     :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
     :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
     :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
+    :param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
     :return: A StaticModel
 
     """
@@ -254,10 +243,10 @@ def distill(
         device=device,
         pca_dims=pca_dims,
         apply_zipf=apply_zipf,
-        use_subword=use_subword,
        token_remove_pattern=token_remove_pattern,
         sif_coefficient=sif_coefficient,
         quantize_to=quantize_to,
+        use_subword=use_subword,
     )
 
 
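Net effect of these hunks: `use_subword` becomes a deprecated no-op keyword that only logs a warning. A minimal usage sketch, not part of the commit, assuming the package's top-level `distill` entry point accepts the model name and the parameters shown above (the checkpoint name and output path are illustrative):

from model2vec.distill import distill

# After this commit, `use_subword` is accepted but ignored; passing it only logs
# a deprecation warning, so the call below behaves exactly as if it were omitted.
m2v_model = distill(
    model_name="baai/bge-base-en-v1.5",      # illustrative sentence-transformer checkpoint
    pca_dims=256,
    sif_coefficient=1e-4,
    token_remove_pattern=r"\[unused\d+\]",
    use_subword=True,                        # triggers the deprecation warning, nothing else
)
m2v_model.save_pretrained("bge-base-static")  # persist the resulting StaticModel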

model2vec/distill/inference.py

Lines changed: 12 additions & 19 deletions
@@ -31,7 +31,6 @@ class ModulewithWeights(Protocol):
 def create_embeddings(
     model: PreTrainedModel,
     tokenizer: PreTrainedTokenizerFast,
-    use_subword: bool,
     tokens: list[str],
     device: str,
     token_remove_regex: re.Pattern | None,
@@ -44,7 +43,6 @@ def create_embeddings(
     :param model: The model to use.
         This should be a transformers model.
     :param tokenizer: The tokenizer to use.
-    :param use_subword: Whether to include subword tokens in the output.
     :param tokens: The tokens to use.
     :param device: The torch device to use.
     :param token_remove_regex: A regex pattern to remove tokens from the vocabulary.
@@ -58,28 +56,23 @@ def create_embeddings(
     out_tokens: list[Token] = []
     tokenized: list[torch.Tensor] = []
     pad_token = tokenizer.special_tokens_map.get("pad_token")
+    # We need to use the pad token id for padding below.
     pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
     unk_token = tokenizer.special_tokens_map.get("unk_token")
 
-    tokens_to_keep = {pad_token, unk_token}
+    # Empty set if no pad or unk token is set.
+    tokens_to_keep = {pad_token, unk_token} - {None}
 
-    if use_subword:
-        if token_remove_regex is not None:
-            # Sort the vocabulary by id, important for zipf.
-            sorted_vocab = sorted(tokenizer.get_vocab().items(), key=lambda x: x[1])
-            id_list = filter_vocabulary_by_regex(token_remove_regex, sorted_vocab)
-        else:
-            # If the token remove regex is None, just use all tokens.
-            id_list = list(range(len(tokenizer.get_vocab())))
-
-        added_tokens_ids = [id for token, id in tokenizer.added_tokens_encoder.items() if token not in tokens_to_keep]
-        ids = torch.Tensor(sorted(set(id_list) - set(added_tokens_ids))).long()
-
-    elif unk_token:
-        # Include unk token. This is necessary for some models.
-        ids = torch.Tensor(tokenizer.convert_tokens_to_ids([unk_token, pad_token])).long()
+    if token_remove_regex is not None:
+        # Sort the vocabulary by id, important for zipf.
+        sorted_vocab = sorted(tokenizer.get_vocab().items(), key=lambda x: x[1])
+        id_list = filter_vocabulary_by_regex(token_remove_regex, sorted_vocab)
     else:
-        ids = None
+        # If the token remove regex is None, just use all tokens.
+        id_list = list(range(len(tokenizer.get_vocab())))
+
+    added_tokens_ids = [id for token, id in tokenizer.added_tokens_encoder.items() if token not in tokens_to_keep]
+    ids = torch.Tensor(sorted(set(id_list) - set(added_tokens_ids))).long()
 
     if ids is not None:
         dummy_encoding = tokenizer.encode("A")
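The `- {None}` subtraction is the heart of the pad/unk fix: `special_tokens_map.get(...)` returns None for tokenizers that define no pad or unk token, and before this change that None landed inside `tokens_to_keep`. A tiny self-contained sketch of the new behaviour, where the two dicts are illustrative stand-ins for real `tokenizer.special_tokens_map` objects:

# Stand-ins for tokenizer.special_tokens_map: one BERT-like map, one with no special tokens.
with_specials = {"pad_token": "[PAD]", "unk_token": "[UNK]"}
without_specials: dict[str, str] = {}

for special_tokens_map in (with_specials, without_specials):
    pad_token = special_tokens_map.get("pad_token")  # None when the token is missing
    unk_token = special_tokens_map.get("unk_token")
    # Subtracting {None} keeps the set clean: real tokens survive, missing ones vanish.
    tokens_to_keep = {pad_token, unk_token} - {None}
    print(tokens_to_keep)
# Prints {'[PAD]', '[UNK]'} (order may vary), then set()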

model2vec/distill/tokenizer.py

Lines changed: 75 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,12 @@ def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[Token]) -> list[
3030
:param tokens: The tokens to pre-tokenize.
3131
:return: The pre-tokenized tokens.
3232
"""
33-
current_tokenizer_vocab = set(tokenizer.get_vocab())
3433
pre_tokenized_tokens = []
3534

3635
if tokenizer.pre_tokenizer is not None:
3736
for token in tokens:
38-
if token.is_subword:
37+
if token.is_original:
38+
# Original tokens do not need to be pre-tokenized.
3939
pre_tokenized_tokens.append(token.form)
4040
else:
4141
# We know 100% sure that all pretokenized tokens will have length 1.
@@ -50,7 +50,7 @@ def _pre_tokenize_vocabulary(tokenizer: Tokenizer, tokens: list[Token]) -> list[
5050
def _remap_added_tokens(
5151
special_tokens: list[dict[str, Any]],
5252
vocabulary: list[str],
53-
) -> list[dict[str, int]]:
53+
) -> list[dict[str, Any]]:
5454
"""
5555
Remap special tokens in the tokenizer.
5656
@@ -107,6 +107,45 @@ def _make_new_merges_from_vocab(
107107
return new_merges
108108

109109

110+
def _process_wordpiece(
111+
tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
112+
) -> dict[str, Any]:
113+
"""Process the WordPiece tokenizer JSON."""
114+
unk_token = unk_token or tokenizer_json["model"]["unk_token"]
115+
tokenizer_json["model"]["unk_token"] = "[UNK]" if unk_token else None
116+
tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
117+
118+
return tokenizer_json
119+
120+
121+
def _process_bpe(tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str]) -> dict[str, Any]:
122+
"""Process the BPE tokenizer JSON."""
123+
tokenizer_json = _process_wordpiece(tokenizer_json, pre_tokenized_tokens, None)
124+
merges = tokenizer_json["model"]["merges"]
125+
merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, {"[UNK]", "[PAD]"})
126+
tokenizer_json["model"]["merges"] = merges
127+
128+
return tokenizer_json
129+
130+
131+
def _process_unigram(
132+
tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
133+
) -> dict[str, Any]:
134+
"""Process the Unigram tokenizer JSON."""
135+
unk_id = tokenizer_json["model"]["unk_id"]
136+
vocab = tokenizer_json["model"]["vocab"]
137+
unk_token = vocab[unk_id][0] if unk_id is not None else None
138+
current_probas = dict(tokenizer_json["model"]["vocab"])
139+
avg_proba = sum(current_probas.values()) / len(current_probas)
140+
new_probas = {word: current_probas.get(word, avg_proba) for word in pre_tokenized_tokens}
141+
tokenizer_json["model"]["vocab"] = sorted(new_probas.items(), key=lambda x: x[1], reverse=True)
142+
143+
tokens, _ = zip(*tokenizer_json["model"]["vocab"])
144+
tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token) if unk_token in tokens else None
145+
146+
return tokenizer_json
147+
148+
110149
def replace_vocabulary(
111150
tokenizer: Tokenizer, new_vocabulary: list[Token], unk_token: str | None, pad_token: str | None
112151
) -> Tokenizer:
@@ -115,45 +154,50 @@ def replace_vocabulary(
115154

116155
# NOTE: all tokens have been normalized before.
117156
# Very careful, we need to pretokenize words before adding them to the vocabulary.
118-
# But only if they are not subword tokens.
157+
# But only if they are not part of the original vocabulary.
119158
pre_tokenized_tokens = _pre_tokenize_vocabulary(tokenizer, new_vocabulary)
120159

121160
model_type = tokenizer_json["model"]["type"]
122-
special_tokens = {unk_token, pad_token}
161+
added_tokens: list[dict[str, Any]] = tokenizer_json["added_tokens"]
123162

124-
if model_type in {"WordPiece", "BPE"}:
125-
# Easiest, just add the new vocab
126-
unk_token = unk_token or tokenizer_json["model"]["unk_token"]
127-
tokenizer_json["model"]["unk_token"] = unk_token
128-
tokenizer_json["added_tokens"] = [x for x in tokenizer_json["added_tokens"] if x["content"] in special_tokens]
129-
tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
163+
# We need to remove the added tokens but keep [UNK] and [PAD] tokens.
164+
added_tokens = _rename_added_token(unk_token, "[UNK]", added_tokens, pre_tokenized_tokens)
165+
added_tokens = _rename_added_token(pad_token, "[PAD]", added_tokens, pre_tokenized_tokens)
130166

131-
if model_type == "BPE":
132-
# Bit more difficult, we need to take into account merges.
133-
merges = tokenizer_json["model"]["merges"]
134-
merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, special_tokens)
135-
tokenizer_json["model"]["merges"] = merges
167+
# Remove old added tokens from added tokens
168+
tokenizer_json["added_tokens"] = [x for x in added_tokens if x["content"] in {"[UNK]", "[PAD]"}]
136169

170+
if model_type == "WordPiece":
171+
tokenizer_json = _process_wordpiece(tokenizer_json, pre_tokenized_tokens, unk_token)
172+
elif model_type == "BPE":
173+
tokenizer_json = _process_bpe(tokenizer_json, pre_tokenized_tokens)
137174
elif model_type == "Unigram":
138-
# Bit more difficult, we need to take into account probas.
139-
unk_id = tokenizer_json["model"]["unk_id"]
140-
tokenizer_json["added_tokens"] = [x for x in tokenizer_json["added_tokens"] if x["content"] in special_tokens]
141-
vocab = tokenizer_json["model"]["vocab"]
142-
unk_token = vocab[unk_id][0] if unk_id is not None else None
143-
current_probas = dict(tokenizer_json["model"]["vocab"])
144-
avg_proba = sum(current_probas.values()) / len(current_probas)
145-
new_probas = {word: current_probas.get(word, avg_proba) for word in pre_tokenized_tokens}
146-
tokenizer_json["model"]["vocab"] = sorted(new_probas.items(), key=lambda x: x[1], reverse=True)
147-
148-
tokens, _ = zip(*tokenizer_json["model"]["vocab"])
149-
tokenizer_json["model"]["unk_id"] = list(tokens).index(unk_token) if unk_token in tokens else None
150-
175+
tokenizer_json = _process_unigram(tokenizer_json, pre_tokenized_tokens, unk_token)
151176
else:
152177
raise ValueError(f"Unknown model type {model_type}")
153178

154179
# Remap special tokens
155-
added_tokens = tokenizer_json["added_tokens"]
156-
tokenizer_json["added_tokens"] = _remap_added_tokens(added_tokens, pre_tokenized_tokens)
180+
tokenizer_json["added_tokens"] = _remap_added_tokens(
181+
special_tokens=tokenizer_json["added_tokens"],
182+
vocabulary=pre_tokenized_tokens,
183+
)
157184
tokenizer_json["post_processor"] = _DEFAULT_POST_PROCESSOR_TEMPLATE
158185

159186
return Tokenizer.from_str(json.dumps(tokenizer_json))
187+
188+
189+
def _rename_added_token(
190+
form: str | None, new_form: str, added_tokens: list[dict[str, Any]], vocabulary: list[str]
191+
) -> list[dict[str, Any]]:
192+
"""Rename special tokens in the tokenizer."""
193+
if form is None:
194+
return added_tokens
195+
196+
idx = vocabulary.index(form)
197+
added_token = [x for x in added_tokens if x["content"] == form]
198+
if added_token:
199+
added_token[0]["id"] = idx
200+
added_token[0]["content"] = new_form
201+
vocabulary[idx] = new_form
202+
203+
return added_tokens
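To see what the new `_process_wordpiece` helper does in isolation, here is a toy run on a hand-written fragment of a tokenizer JSON. The fragment is illustrative and far smaller than a real `tokenizers` serialization; the helper body is copied verbatim from the hunk above:

from typing import Any


def _process_wordpiece(
    tokenizer_json: dict[str, Any], pre_tokenized_tokens: list[str], unk_token: str | None
) -> dict[str, Any]:
    """Process the WordPiece tokenizer JSON (copied from the diff above)."""
    unk_token = unk_token or tokenizer_json["model"]["unk_token"]
    tokenizer_json["model"]["unk_token"] = "[UNK]" if unk_token else None
    tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}
    return tokenizer_json


# Toy WordPiece model fragment and a replacement vocabulary.
toy_json = {"model": {"type": "WordPiece", "unk_token": "<unk>", "vocab": {"<unk>": 0, "old": 1}}}
new_vocab = ["[UNK]", "[PAD]", "hello", "world"]

print(_process_wordpiece(toy_json, new_vocab, "<unk>")["model"])
# {'type': 'WordPiece', 'unk_token': '[UNK]', 'vocab': {'[UNK]': 0, '[PAD]': 1, 'hello': 2, 'world': 3}}

The unk token is normalized to the literal "[UNK]" form, which is why `replace_vocabulary` renames the added tokens to "[UNK]" and "[PAD]" before dispatching on the model type.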

model2vec/distill/utils.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ class Token:
     """A class to represent a token."""
 
     form: str
-    is_subword: bool
+    is_original: bool
 
 
 def select_optimal_device(device: str | None) -> str:
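The rename flips the flag's meaning: `is_original` marks tokens that already exist in the backing tokenizer's vocabulary, and those skip pre-tokenization in `_pre_tokenize_vocabulary`. A minimal sketch of the class after the rename; the `@dataclass` decorator and the example values are assumptions for illustration:

from dataclasses import dataclass


@dataclass
class Token:
    """A class to represent a token."""

    form: str
    # True when the form was already part of the original tokenizer vocabulary.
    is_original: bool


subword_piece = Token(form="##ing", is_original=True)    # came from the source tokenizer
new_word = Token(form="hello world", is_original=False)   # user-supplied, will be pre-tokenized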
