|
9 | 9 | import pytest |
10 | 10 | from pytest import LogCaptureFixture |
11 | 11 | from skeletoken import TokenizerModel |
12 | | -from transformers import BertTokenizerFast |
| 12 | +from transformers import BertTokenizer |
13 | 13 | from transformers.modeling_utils import PreTrainedModel |
14 | | -from transformers.tokenization_utils_fast import PreTrainedTokenizerFast |
| 14 | +from transformers.tokenization_utils_fast import PreTrainedTokenizerFast |
15 | 15 |
|
16 | 16 | from model2vec.distill.distillation import distill, distill_from_model |
17 | 17 | from model2vec.distill.inference import PoolingMode, create_embeddings, post_process_embeddings |
|
38 | 38 | (None, None, 1e-4), # No PCA, SIF on |
39 | 39 | (None, 0.9, 1e-4), # PCA as float (variance), SIF on |
40 | 40 | (["star wars"], 8, None), # Multiword vocabulary |
| 41 | + (["..."], 8, None), # Crashing multiword vocabulary |
41 | 42 | ], |
42 | 43 | ) |
43 | 44 | @patch.object(import_module("model2vec.distill.distillation"), "model_info") |
@@ -92,7 +93,7 @@ def test_distill_from_model( |
92 | 93 | def test_distill_removal_pattern_all_tokens( |
93 | 94 | mock_auto_model: MagicMock, |
94 | 95 | mock_model_info: MagicMock, |
95 | | - mock_berttokenizer: BertTokenizerFast, |
| 96 | + mock_berttokenizer: BertTokenizer, |
96 | 97 | mock_transformer: PreTrainedModel, |
97 | 98 | ) -> None: |
98 | 99 | """Test the removal pattern.""" |
|
0 commit comments