
Commit 810d25e

clean up post-processing code (#323)
* clean up post-processing code
* add additional test
1 parent 7569d0d · commit 810d25e

3 files changed: 56 additions & 22 deletions

model2vec/distill/distillation.py

Lines changed: 7 additions & 7 deletions
```diff
@@ -5,14 +5,13 @@
 import re
 from typing import cast
 
-import numpy as np
 from huggingface_hub.hf_api import model_info
 from skeletoken import TokenizerModel
 from skeletoken.external.transformers import reshape_embeddings
 from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast
 from transformers.modeling_utils import PreTrainedModel
 
-from model2vec.distill.inference import PCADimType, PoolingMode, create_embeddings, post_process_embeddings
+from model2vec.distill.inference import PCADimType, PoolingMode, apply_pca, compute_weights, create_embeddings
 from model2vec.distill.utils import select_optimal_device
 from model2vec.model import StaticModel
 from model2vec.quantization import DType, quantize_embeddings
@@ -108,16 +107,17 @@ def distill_from_model(
         pooling=pooling,
     )
 
-    # Maybe apply quantization
+    # Apply quantization
     if vocabulary_quantization is not None:
-        _, weights = post_process_embeddings(np.asarray(embeddings), None, sif_coefficient=sif_coefficient)
+        weights = compute_weights(len(embeddings), sif_coefficient=sif_coefficient)
         embeddings, token_mapping, weights = quantize_vocabulary(
-            n_clusters=vocabulary_quantization, weights=weights, embeddings=np.asarray(embeddings)
+            n_clusters=vocabulary_quantization, weights=weights, embeddings=embeddings
         )
-        embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient=sif_coefficient)
+        embeddings = apply_pca(embeddings, pca_dims)
     else:
         # Post-process the embeddings.
-        embeddings, weights = post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)
+        weights = compute_weights(len(embeddings), sif_coefficient=sif_coefficient)
+        embeddings = apply_pca(embeddings, pca_dims)
         embeddings = embeddings * weights[:, None]
         weights = None
         token_mapping = None
```
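
The refactor untangles the two concerns that post_process_embeddings previously bundled: SIF weighting now only needs the vocabulary size, so it can run before vocabulary quantization, while PCA operates on whatever vectors come out. A minimal sketch of the non-quantized path, with toy shapes (the random embeddings and the 256-dim target are illustrative, not from the commit):

```python
import numpy as np

from model2vec.distill.inference import apply_pca, compute_weights

# Toy stand-in for the distilled token embeddings: 1000 tokens, 768 dims.
embeddings = np.random.randn(1000, 768).astype(np.float32)

# SIF weights depend only on token rank, so the vocabulary size is enough.
weights = compute_weights(len(embeddings), sif_coefficient=1e-4)

# PCA is applied independently of the weighting (256 dims chosen arbitrarily).
embeddings = apply_pca(embeddings, 256)

# Fold the per-token weights into the vectors, as distill_from_model does.
embeddings = embeddings * weights[:, None]
```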

model2vec/distill/inference.py

Lines changed: 15 additions & 13 deletions
```diff
@@ -206,10 +206,20 @@ def _encode_pooler_with_model(model: PreTrainedModel, encodings: dict[str, torch
     return pooler.cpu()
 
 
-def post_process_embeddings(
-    embeddings: np.ndarray, pca_dims: PCADimType, sif_coefficient: float | None = 1e-4
-) -> tuple[np.ndarray, np.ndarray]:
-    """Post process embeddings by applying PCA and SIF weighting by estimating the frequencies through Zipf's law."""
+def compute_weights(n_embeddings: int, sif_coefficient: float | None) -> np.ndarray:
+    """Compute the weights based on Zipf's law and a SIF coefficient."""
+    if sif_coefficient is None:
+        return np.ones(n_embeddings)
+    logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
+    inv_rank = 1 / (np.arange(2, n_embeddings + 2))
+    proba = inv_rank / np.sum(inv_rank)
+    weight = sif_coefficient / (sif_coefficient + proba)
+
+    return weight
+
+
+def apply_pca(embeddings: np.ndarray, pca_dims: PCADimType) -> np.ndarray:
+    """Apply PCA to the embeddings."""
     if pca_dims is not None:
         if pca_dims == "auto":
             pca_dims = embeddings.shape[1]
@@ -241,12 +251,4 @@ def post_process_embeddings(
         logger.info(f"Explained variance ratio: {explained_variance_ratio:.3f}.")
         logger.info(f"Explained variance: {explained_variance:.3f}.")
 
-    if sif_coefficient is not None:
-        logger.info("Estimating word frequencies using Zipf's law, and then applying SIF.")
-        inv_rank = 1 / (np.arange(2, embeddings.shape[0] + 2))
-        proba = inv_rank / np.sum(inv_rank)
-        weight = sif_coefficient / (sif_coefficient + proba)
-    else:
-        weight = np.ones(embeddings.shape[0])
-
-    return embeddings, weight
+    return embeddings
```
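
compute_weights estimates token frequencies from rank alone: Zipf's law says the token at rank r occurs with frequency proportional to 1/r (ranks start at 2 here), and the SIF weight a / (a + p) then shrinks the weight of frequent tokens toward 0 while leaving rare tokens near 1. A standalone re-derivation of the same arithmetic (zipf_sif_weights is a local helper for illustration, not a library function):

```python
import numpy as np

def zipf_sif_weights(n_embeddings: int, sif_coefficient: float = 1e-4) -> np.ndarray:
    # Zipf's law: the token at rank r occurs with frequency proportional to 1/r;
    # ranks start at 2 here, matching compute_weights in inference.py.
    inv_rank = 1 / np.arange(2, n_embeddings + 2)
    # Normalize to a probability distribution over the vocabulary.
    proba = inv_rank / inv_rank.sum()
    # SIF: weight = a / (a + p); frequent tokens (large p) are pushed toward 0,
    # rare tokens (small p) keep weights close to 1.
    return sif_coefficient / (sif_coefficient + proba)

print(zipf_sif_weights(5))
# Weights increase with rank: earlier (more frequent) tokens are down-weighted.
```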

tests/test_distillation.py

Lines changed: 34 additions & 2 deletions
```diff
@@ -14,7 +14,7 @@
 from transformers.tokenization_utils_tokenizers import PreTrainedTokenizerFast
 
 from model2vec.distill.distillation import distill, distill_from_model
-from model2vec.distill.inference import PoolingMode, create_embeddings, post_process_embeddings
+from model2vec.distill.inference import PoolingMode, apply_pca, compute_weights, create_embeddings
 from model2vec.model import StaticModel
 from model2vec.tokenizer import clean_and_create_vocabulary
 
@@ -88,6 +88,36 @@ def test_distill_from_model(
         assert token in static_model.tokens or normalized in static_model.tokens
 
 
+@patch.object(import_module("model2vec.distill.distillation"), "model_info")
+@patch("transformers.AutoModel.from_pretrained")
+def test_distill_quantization(
+    mock_auto_model: MagicMock,
+    mock_model_info: MagicMock,
+    mock_berttokenizer: PreTrainedTokenizerFast,
+    mock_transformer: PreTrainedModel,
+) -> None:
+    """Test distill function with different parameters."""
+    # Mock the return value of model_info to avoid calling the Hugging Face API
+    mock_model_info.return_value = type("ModelInfo", (object,), {"cardData": {"language": "en"}})
+    mock_auto_model.return_value = mock_transformer
+
+    static_model = distill_from_model(
+        model=mock_transformer,
+        tokenizer=mock_berttokenizer,
+        vocabulary=None,
+        device="cpu",
+        pca_dims="auto",
+        sif_coefficient=1e-4,
+        token_remove_pattern=None,
+        vocabulary_quantization=3,
+    )
+
+    assert static_model.embedding.shape == (3, 768)
+    assert static_model.weights is not None
+    assert static_model.token_mapping is not None
+    assert len(static_model.weights) == static_model.tokenizer.get_vocab_size()
+
+
 @patch.object(import_module("model2vec.distill.distillation"), "model_info")
 @patch("transformers.AutoModel.from_pretrained")
 def test_distill_removal_pattern_all_tokens(
@@ -259,7 +289,9 @@ def test__post_process_embeddings(
         # The implementation logs a warning and skips reduction; no exception expected.
         pass
 
-    processed_embeddings, _ = post_process_embeddings(embeddings, pca_dims, sif_coefficient)
+    processed_embeddings = apply_pca(embeddings, pca_dims)
+    weights = compute_weights(len(processed_embeddings), sif_coefficient=sif_coefficient)
+    processed_embeddings = processed_embeddings * weights[:, None]
 
     # Assert the shape is correct
     assert processed_embeddings.shape == expected_shape
```
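
For context, a hedged sketch of the code path the new test exercises through the public API. The model name and cluster count are illustrative, and it assumes distill forwards vocabulary_quantization to distill_from_model the way this module does:

```python
from model2vec.distill import distill

# With vocabulary_quantization set, the static model stores one vector per
# cluster plus a token -> cluster mapping and per-token SIF weights, which is
# what test_distill_quantization asserts on the mocked model.
model = distill(
    model_name="baai/bge-base-en-v1.5",  # illustrative base model
    pca_dims=256,
    vocabulary_quantization=256,  # assumed keyword, mirroring distill_from_model
)
model.save_pretrained("my-quantized-model")
```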
