Skip to content

Commit 4f02616

Browse files
authored
fix: Incorrect added token can cause issues when adding token as multiword token (#319)
* fix tokenizer
* fix issue with reassignment
* update lock
* fix issue with import
1 parent 6a40e36 commit 4f02616

4 files changed

Lines changed: 12 additions & 9 deletions

File tree

model2vec/distill/distillation.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,8 @@
99
from huggingface_hub.hf_api import model_info
1010
from skeletoken import TokenizerModel
1111
from skeletoken.external.transformers import reshape_embeddings
12-
from transformers import AutoModel, AutoTokenizer
12+
from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast
1313
from transformers.modeling_utils import PreTrainedModel
14-
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
1514

1615
from model2vec.distill.inference import PCADimType, PoolingMode, create_embeddings, post_process_embeddings
1716
from model2vec.distill.utils import select_optimal_device

model2vec/tokenizer/tokenizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,9 @@ def clean_and_create_vocabulary(
5353
logger.warning(
5454
f"Token '{token}' was split into multiple tokens after preprocessing: [{split_into}], adding it as a multi-word token."
5555
)
56+
if token in model.vocabulary:
57+
# If the unprocessed token (incorrectly) is in the vocabulary, we should remove it.
58+
model = model.remove_token_from_vocabulary(token)
5659
added_tokens_to_add.append(token)
5760
continue
5861
token = preprocessed[0]

tests/test_distillation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@
99
import pytest
1010
from pytest import LogCaptureFixture
1111
from skeletoken import TokenizerModel
12-
from transformers import BertTokenizerFast
12+
from transformers import BertTokenizer
1313
from transformers.modeling_utils import PreTrainedModel
14-
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
14+
from transformers import PreTrainedTokenizerFast
1515

1616
from model2vec.distill.distillation import distill, distill_from_model
1717
from model2vec.distill.inference import PoolingMode, create_embeddings, post_process_embeddings
@@ -38,6 +38,7 @@
3838
(None, None, 1e-4), # No PCA, SIF on
3939
(None, 0.9, 1e-4), # PCA as float (variance), SIF on
4040
(["star wars"], 8, None), # Multiword vocabulary
41+
(["..."], 8, None), # Crashing multiword vocabulary
4142
],
4243
)
4344
@patch.object(import_module("model2vec.distill.distillation"), "model_info")
@@ -92,7 +93,7 @@ def test_distill_from_model(
9293
def test_distill_removal_pattern_all_tokens(
9394
mock_auto_model: MagicMock,
9495
mock_model_info: MagicMock,
95-
mock_berttokenizer: BertTokenizerFast,
96+
mock_berttokenizer: BertTokenizer,
9697
mock_transformer: PreTrainedModel,
9798
) -> None:
9899
"""Test the removal pattern."""

uv.lock

Lines changed: 4 additions & 4 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments (0)