@@ -88,6 +88,39 @@ def test_distill_from_model(
         assert token in static_model.tokens or normalized in static_model.tokens
 
 
+@patch.object(import_module("model2vec.distill.distillation"), "model_info")
+@patch("transformers.AutoModel.from_pretrained")
+def test_distill_quantization(
+    mock_auto_model: MagicMock,
+    mock_model_info: MagicMock,
+    mock_berttokenizer: PreTrainedTokenizerFast,
+    mock_transformer: PreTrainedModel,
+) -> None:
99+ """Test distill function with different parameters."""
+    # Mock the return value of model_info to avoid calling the Hugging Face API
+    mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
+    mock_auto_model.return_value = mock_transformer
+
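+    # Quantize the distilled vocabulary into 3 centroid embeddings.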
+    static_model = distill_from_model(
+        model=mock_transformer,
+        tokenizer=mock_berttokenizer,
+        vocabulary=None,
+        device="cpu",
+        pca_dims="auto",
+        sif_coefficient=1e-4,
+        token_remove_pattern=None,
+        vocabulary_quantization=3,
+    )
+
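+    # The embedding matrix should hold one 768-dim row per centroid, while the
+    # per-token weights and token_mapping still cover the full tokenizer vocabulary.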
+    assert static_model.embedding.shape == (3, 768)
+    assert static_model.weights is not None
+    assert static_model.token_mapping is not None
+    assert len(static_model.weights) == static_model.tokenizer.get_vocab_size()
+
+
 @patch.object(import_module("model2vec.distill.distillation"), "model_info")
 @patch("transformers.AutoModel.from_pretrained")
 def test_distill_removal_pattern_all_tokens(