Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 10 additions & 21 deletions model2vec/distill/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
pca_dims: PCADimType = 256,
apply_zipf: bool | None = None,
sif_coefficient: float | None = 1e-4,
use_subword: bool = True,
token_remove_pattern: str | None = r"\[unused\d+\]",
quantize_to: DType | str = DType.Float16,
use_subword: bool | None = None,
) -> StaticModel:
"""
Distill a staticmodel from a sentence transformer.
Expand All @@ -63,18 +63,20 @@
Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
:param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
:param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
:return: A StaticModel

"""
if use_subword is not None:
logger.warning(

Check warning on line 74 in model2vec/distill/distillation.py

View check run for this annotation

Codecov / codecov/patch

model2vec/distill/distillation.py#L74

Added line #L74 was not covered by tests
"The `use_subword` parameter is deprecated and will be removed in the next release. It doesn't do anything."
)
quantize_to = DType(quantize_to)
backend_tokenizer = tokenizer.backend_tokenizer
sif_coefficient, token_remove_regex = _validate_parameters(
vocabulary, apply_zipf, sif_coefficient, use_subword, token_remove_pattern
)
sif_coefficient, token_remove_regex = _validate_parameters(apply_zipf, sif_coefficient, token_remove_pattern)

if vocabulary is None:
vocabulary = []
Expand All @@ -98,7 +100,6 @@
tokenizer=tokenizer,
tokens=cleaned_vocabulary,
device=device,
use_subword=use_subword,
token_remove_regex=token_remove_regex,
)

Expand Down Expand Up @@ -151,27 +152,20 @@


def _validate_parameters(
vocabulary: list[str] | None,
apply_zipf: bool | None,
sif_coefficient: float | None,
use_subword: bool,
token_remove_pattern: str | None,
) -> tuple[float | None, re.Pattern | None]:
"""
Validate the parameters passed to the distillation function.

:param vocabulary: The vocabulary to use.
:param apply_zipf: DEPRECATED: This parameter used to control whether Zipf is applied.
Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
:param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
:param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
:return: The SIF coefficient to use.
:raises: ValueError if the PCA dimension is larger than the number of dimensions in the embeddings.
:raises: ValueError if the vocabulary contains duplicate tokens.
:raises: ValueError if the regex can't be compiled.
:raises: ValueError if the vocabulary is empty after token removal.

"""
if apply_zipf is not None:
Expand All @@ -191,11 +185,6 @@
if not 0 < sif_coefficient < 1.0:
raise ValueError("SIF coefficient must be a value > 0 and < 1.0.")

if not use_subword and vocabulary is None:
raise ValueError(
"You must pass a vocabulary if you don't use subword tokens. Either pass a vocabulary, or set use_subword to True."
)

token_remove_regex: re.Pattern | None = None
if token_remove_pattern is not None:
try:
Expand All @@ -213,10 +202,10 @@
pca_dims: PCADimType = 256,
apply_zipf: bool | None = None,
sif_coefficient: float | None = 1e-4,
use_subword: bool = True,
token_remove_pattern: str | None = r"\[unused\d+\]",
trust_remote_code: bool = False,
quantize_to: DType | str = DType.Float16,
use_subword: bool | None = None,
) -> StaticModel:
"""
Distill a staticmodel from a sentence transformer.
Expand All @@ -237,10 +226,10 @@
Zipf weighting is now controlled by the sif_coefficient parameter. If this is set to None, no weighting is applied.
:param sif_coefficient: The SIF coefficient to use. If this is None, no weighting is applied.
Should be a value > 0 and < 1.0. A value of 1e-4 is a good default.
:param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
:param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:param use_subword: DEPRECATED: If this is not set to None, we show a warning. It doesn't do anything.
:return: A StaticModel

"""
Expand All @@ -254,10 +243,10 @@
device=device,
pca_dims=pca_dims,
apply_zipf=apply_zipf,
use_subword=use_subword,
token_remove_pattern=token_remove_pattern,
sif_coefficient=sif_coefficient,
quantize_to=quantize_to,
use_subword=use_subword,
)


Expand Down
31 changes: 12 additions & 19 deletions model2vec/distill/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ class ModulewithWeights(Protocol):
def create_embeddings(
model: PreTrainedModel,
tokenizer: PreTrainedTokenizerFast,
use_subword: bool,
tokens: list[str],
device: str,
token_remove_regex: re.Pattern | None,
Expand All @@ -44,7 +43,6 @@ def create_embeddings(
:param model: The model to use.
This should be a transformers model.
:param tokenizer: The tokenizer to use.
:param use_subword: Whether to include subword tokens in the output.
:param tokens: The tokens to use.
:param device: The torch device to use.
:param token_remove_regex: A regex pattern to remove tokens from the vocabulary.
Expand All @@ -58,28 +56,23 @@ def create_embeddings(
out_tokens: list[Token] = []
tokenized: list[torch.Tensor] = []
pad_token = tokenizer.special_tokens_map.get("pad_token")
# We need to use the pad token id for padding below.
pad_token_id = tokenizer.convert_tokens_to_ids(pad_token)
unk_token = tokenizer.special_tokens_map.get("unk_token")

tokens_to_keep = {pad_token, unk_token}
# Empty set if no pad or unk token is set.
tokens_to_keep = {pad_token, unk_token} - {None}

if use_subword:
if token_remove_regex is not None:
# Sort the vocabulary by id, important for zipf.
sorted_vocab = sorted(tokenizer.get_vocab().items(), key=lambda x: x[1])
id_list = filter_vocabulary_by_regex(token_remove_regex, sorted_vocab)
else:
# If the token remove regex is None, just use all tokens.
id_list = list(range(len(tokenizer.get_vocab())))

added_tokens_ids = [id for token, id in tokenizer.added_tokens_encoder.items() if token not in tokens_to_keep]
ids = torch.Tensor(sorted(set(id_list) - set(added_tokens_ids))).long()

elif unk_token:
# Include unk token. This is necessary for some models.
ids = torch.Tensor(tokenizer.convert_tokens_to_ids([unk_token, pad_token])).long()
if token_remove_regex is not None:
# Sort the vocabulary by id, important for zipf.
sorted_vocab = sorted(tokenizer.get_vocab().items(), key=lambda x: x[1])
id_list = filter_vocabulary_by_regex(token_remove_regex, sorted_vocab)
else:
ids = None
# If the token remove regex is None, just use all tokens.
id_list = list(range(len(tokenizer.get_vocab())))

added_tokens_ids = [id for token, id in tokenizer.added_tokens_encoder.items() if token not in tokens_to_keep]
ids = torch.Tensor(sorted(set(id_list) - set(added_tokens_ids))).long()

if ids is not None:
dummy_encoding = tokenizer.encode("A")
Expand Down
49 changes: 38 additions & 11 deletions model2vec/distill/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@
:param tokens: The tokens to pre-tokenize.
:return: The pre-tokenized tokens.
"""
current_tokenizer_vocab = set(tokenizer.get_vocab())
pre_tokenized_tokens = []

if tokenizer.pre_tokenizer is not None:
for token in tokens:
if token.is_subword:
if token.is_original:
# Original tokens do not need to be pre-tokenized.
pre_tokenized_tokens.append(token.form)
else:
# We know 100% sure that all pretokenized tokens will have length 1.
Expand All @@ -50,7 +50,7 @@
def _remap_added_tokens(
special_tokens: list[dict[str, Any]],
vocabulary: list[str],
) -> list[dict[str, int]]:
) -> list[dict[str, Any]]:
"""
Remap special tokens in the tokenizer.

Expand Down Expand Up @@ -115,29 +115,37 @@

Comment thread
stephantul marked this conversation as resolved.
# NOTE: all tokens have been normalized before.
# Very careful, we need to pretokenize words before adding them to the vocabulary.
# But only if they are not subword tokens.
# But only if they are not part of the original vocabulary.
pre_tokenized_tokens = _pre_tokenize_vocabulary(tokenizer, new_vocabulary)

model_type = tokenizer_json["model"]["type"]
special_tokens = {unk_token, pad_token}
added_tokens: list[dict[str, Any]] = tokenizer_json["added_tokens"]

# NOTE: all added tokens but the unk and pad tokens are removed already.
# We only need this for BPE.
Comment thread
stephantul marked this conversation as resolved.
Outdated
added_token_forms = {x["content"] for x in added_tokens} | {"[UNK]", "[PAD]"}
# We need to remove the added tokens but keep [UNK] and [PAD] tokens.
added_tokens = _rename_added_token(unk_token, "[UNK]", added_tokens, pre_tokenized_tokens)
added_tokens = _rename_added_token(pad_token, "[PAD]", added_tokens, pre_tokenized_tokens)

# Remove old added tokens from added tokens
tokenizer_json["added_tokens"] = [x for x in added_tokens if x["content"] in {"[UNK]", "[PAD]"}]

if model_type in {"WordPiece", "BPE"}:
# Easiest, just add the new vocab
unk_token = unk_token or tokenizer_json["model"]["unk_token"]
tokenizer_json["model"]["unk_token"] = unk_token
tokenizer_json["added_tokens"] = [x for x in tokenizer_json["added_tokens"] if x["content"] in special_tokens]
tokenizer_json["model"]["unk_token"] = "[UNK]" if unk_token else None
tokenizer_json["model"]["vocab"] = {token: idx for idx, token in enumerate(pre_tokenized_tokens)}

if model_type == "BPE":
# Bit more difficult, we need to take into account merges.
merges = tokenizer_json["model"]["merges"]
merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, special_tokens)
merges = _make_new_merges_from_vocab(merges, pre_tokenized_tokens, added_token_forms)

Check warning on line 143 in model2vec/distill/tokenizer.py

View check run for this annotation

Codecov / codecov/patch

model2vec/distill/tokenizer.py#L143

Added line #L143 was not covered by tests
tokenizer_json["model"]["merges"] = merges

elif model_type == "Unigram":
# Bit more difficult, we need to take into account probas.
unk_id = tokenizer_json["model"]["unk_id"]
tokenizer_json["added_tokens"] = [x for x in tokenizer_json["added_tokens"] if x["content"] in special_tokens]
vocab = tokenizer_json["model"]["vocab"]
unk_token = vocab[unk_id][0] if unk_id is not None else None
current_probas = dict(tokenizer_json["model"]["vocab"])
Expand All @@ -152,8 +160,27 @@
raise ValueError(f"Unknown model type {model_type}")

# Remap special tokens
added_tokens = tokenizer_json["added_tokens"]
tokenizer_json["added_tokens"] = _remap_added_tokens(added_tokens, pre_tokenized_tokens)
tokenizer_json["added_tokens"] = _remap_added_tokens(
special_tokens=tokenizer_json["added_tokens"],
vocabulary=pre_tokenized_tokens,
)
tokenizer_json["post_processor"] = _DEFAULT_POST_PROCESSOR_TEMPLATE

return Tokenizer.from_str(json.dumps(tokenizer_json))


def _rename_added_token(
form: str | None, new_form: str, added_tokens: list[dict[str, Any]], vocabulary: list[str]
) -> list[dict[str, Any]]:
"""Rename special tokens in the tokenizer."""
if form is None:
return added_tokens

Check warning on line 177 in model2vec/distill/tokenizer.py

View check run for this annotation

Codecov / codecov/patch

model2vec/distill/tokenizer.py#L177

Added line #L177 was not covered by tests

idx = vocabulary.index(form)
added_token = [x for x in added_tokens if x["content"] == form]
if added_token:
added_token[0]["id"] = idx
added_token[0]["content"] = new_form
vocabulary[idx] = new_form

return added_tokens
2 changes: 1 addition & 1 deletion model2vec/distill/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class Token:
"""A class to represent a token."""

form: str
is_subword: bool
is_original: bool


def select_optimal_device(device: str | None) -> str:
Expand Down
Loading