Skip to content

Commit d1799fd

Browse files
committed
add additional test
1 parent ad0fb8c commit d1799fd

1 file changed

Lines changed: 30 additions & 0 deletions

File tree

tests/test_distillation.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,36 @@ def test_distill_from_model(
8888
assert token in static_model.tokens or normalized in static_model.tokens
8989

9090

91+
@patch.object(import_module("model2vec.distill.distillation"), "model_info")
@patch("transformers.AutoModel.from_pretrained")
def test_distill_quantization(
    mock_auto_model: MagicMock,
    mock_model_info: MagicMock,
    mock_berttokenizer: PreTrainedTokenizerFast,
    mock_transformer: PreTrainedModel,
) -> None:
    """Test distill_from_model with vocabulary quantization enabled.

    Distills the mocked transformer with ``vocabulary_quantization=3`` and
    checks that the resulting static model has a (3, 768) embedding matrix,
    that both ``weights`` and ``token_mapping`` are populated, and that there
    is exactly one weight per entry in the tokenizer vocabulary.
    """
    # Stand-in for the object returned by model_info, so the Hugging Face API
    # is never called during the test.
    fake_model_info = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
    mock_model_info.return_value = fake_model_info
    # Any AutoModel.from_pretrained call hands back the mocked transformer.
    mock_auto_model.return_value = mock_transformer

    # The same value is passed to distillation and used for the expected
    # number of embedding rows below.
    quantization = 3
    distill_kwargs = {
        "model": mock_transformer,
        "tokenizer": mock_berttokenizer,
        "vocabulary": None,
        "device": "cpu",
        "pca_dims": "auto",
        "sif_coefficient": 1e-4,
        "token_remove_pattern": None,
        "vocabulary_quantization": quantization,
    }
    quantized_model = distill_from_model(**distill_kwargs)

    expected_shape = (quantization, 768)
    assert quantized_model.embedding.shape == expected_shape
    assert quantized_model.weights is not None
    assert quantized_model.token_mapping is not None
    # One weight per tokenizer vocabulary entry.
    vocab_size = quantized_model.tokenizer.get_vocab_size()
    assert len(quantized_model.weights) == vocab_size
119+
120+
91121
@patch.object(import_module("model2vec.distill.distillation"), "model_info")
92122
@patch("transformers.AutoModel.from_pretrained")
93123
def test_distill_removal_pattern_all_tokens(

0 commit comments

Comments
 (0)