14 | 14 | from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
15 | 15 |
16 | 16 | from model2vec.distill.distillation import distill, distill_from_model |
17 | | -from model2vec.distill.inference import PoolingMode, create_embeddings, post_process_embeddings |
| 17 | +from model2vec.distill.inference import PoolingMode, apply_pca, compute_weights, create_embeddings |
18 | 18 | from model2vec.model import StaticModel |
19 | 19 | from model2vec.tokenizer import clean_and_create_vocabulary |
20 | 20 |
@@ -88,6 +88,36 @@ def test_distill_from_model( |
88 | 88 | assert token in static_model.tokens or normalized in static_model.tokens |
89 | 89 |
90 | 90 |
| 91 | +@patch.object(import_module("model2vec.distill.distillation"), "model_info") |
| 92 | +@patch("transformers.AutoModel.from_pretrained") |
| 93 | +def test_distill_quantization( |
| 94 | + mock_auto_model: MagicMock, |
| 95 | + mock_model_info: MagicMock, |
| 96 | + mock_berttokenizer: PreTrainedTokenizerFast, |
| 97 | + mock_transformer: PreTrainedModel, |
| 98 | +) -> None: |
| 99 | + """Test distill function with different parameters.""" |
| 100 | + # Mock the return value of model_info to avoid calling the Hugging Face API |
| 101 | + mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}}) |
| 102 | + mock_auto_model.return_value = mock_transformer |
| 103 | + |
| 104 | + static_model = distill_from_model( |
| 105 | + model=mock_transformer, |
| 106 | + tokenizer=mock_berttokenizer, |
| 107 | + vocabulary=None, |
| 108 | + device="cpu", |
| 109 | + pca_dims="auto", |
| 110 | + sif_coefficient=1e-4, |
| 111 | + token_remove_pattern=None, |
| 112 | + vocabulary_quantization=3, |
| 113 | + ) |
| 114 | + |
| 115 | + assert static_model.embedding.shape == (3, 768) |
| 116 | + assert static_model.weights is not None |
| 117 | + assert static_model.token_mapping is not None |
| 118 | + assert len(static_model.weights) == static_model.tokenizer.get_vocab_size() |
| 119 | + |
| 120 | + |
91 | 121 | @patch.object(import_module("model2vec.distill.distillation"), "model_info") |
92 | 122 | @patch("transformers.AutoModel.from_pretrained") |
93 | 123 | def test_distill_removal_pattern_all_tokens( |
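As an aside on the new `test_distill_quantization` above: `vocabulary_quantization=3` collapses the full vocabulary into three centroid vectors plus a token-to-centroid mapping, which is exactly what the shape and `token_mapping` assertions check. Below is a minimal sketch of that idea, assuming a k-means-style clustering; the `quantize_vocabulary` helper is hypothetical, not model2vec's API.

```python
import numpy as np


def quantize_vocabulary(
    embeddings: np.ndarray, n_centroids: int, seed: int = 0
) -> tuple[np.ndarray, np.ndarray]:
    """Cluster token embeddings into n_centroids vectors (hypothetical sketch).

    Returns the (n_centroids, dim) centroid matrix and a (vocab_size,) array
    mapping each token to its nearest centroid, mirroring the `embedding` /
    `token_mapping` pair asserted in the test above.
    """
    rng = np.random.default_rng(seed)
    # Initialize centroids from random tokens, then run a few Lloyd iterations.
    centroids = embeddings[rng.choice(len(embeddings), n_centroids, replace=False)]
    for _ in range(10):
        # Assign every token to its nearest centroid (squared Euclidean distance).
        distances = ((embeddings[:, None, :] - centroids[None, :, :]) ** 2).sum(-1)
        mapping = distances.argmin(axis=1)
        # Move each centroid to the mean of its assigned tokens.
        for k in range(n_centroids):
            if (mapping == k).any():
                centroids[k] = embeddings[mapping == k].mean(axis=0)
    return centroids, mapping
```

With `n_centroids=3` and 768-dimensional token embeddings, the centroid matrix comes out as `(3, 768)`, matching the `static_model.embedding.shape` assertion, while the mapping keeps one entry per vocabulary token.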
@@ -259,7 +289,9 @@ def test__post_process_embeddings( |
259 | 289 | # The implementation logs a warning and skips reduction; no exception expected. |
260 | 290 | pass |
261 | 291 |
|
262 | | - processed_embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient) |
| 292 | + processed_embeddings = apply_pca(embeddings, pca_dims) |
| 293 | + weights = compute_weights(len(processed_embeddings), sif_coefficient=sif_coefficient) |
| 294 | + processed_embeddings = processed_embeddings * weights[:, None] |
263 | 295 |
|
264 | 296 | # Assert the shape is correct |
265 | 297 | assert processed_embeddings.shape == expected_shape |
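The refactor exercised here splits the old `post_process_embeddings` into `apply_pca` followed by `compute_weights`. For reference, here is a self-contained sketch of the same two-step pipeline, assuming the standard SIF formulation `w_i = a / (a + p_i)` with Zipf-ranked token probabilities; the exact rank scheme model2vec uses is an assumption, and `sif_weights` is a hypothetical stand-in for `compute_weights`.

```python
import numpy as np
from sklearn.decomposition import PCA


def sif_weights(n_tokens: int, sif_coefficient: float = 1e-4) -> np.ndarray:
    """SIF weights w_i = a / (a + p_i), with p_i from a Zipf rank distribution.

    A sketch of what compute_weights plausibly does; the exact rank offset
    is an assumption here.
    """
    inv_rank = 1.0 / np.arange(1, n_tokens + 1)  # Zipf: p(rank) proportional to 1/rank
    proba = inv_rank / inv_rank.sum()
    return sif_coefficient / (sif_coefficient + proba)


# Usage mirroring the test body above (hypothetical 1000-token vocabulary):
embeddings = np.random.randn(1000, 768).astype(np.float32)
reduced = PCA(n_components=256).fit_transform(embeddings)  # cf. apply_pca
weighted = reduced * sif_weights(len(reduced))[:, None]    # cf. compute_weights
```

Keeping the weights separate from the embedding matrix is what lets the quantization test above assert on `static_model.weights` directly instead of recovering them from re-scaled vectors.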