
Commit 6731674

feat: add quantization (#217)
* feat: add quantization
* add comment
1 parent 3f5786a commit 6731674

6 files changed

Lines changed: 115 additions & 2 deletions

model2vec/distill/distillation.py

Lines changed: 10 additions & 0 deletions
@@ -15,6 +15,7 @@
 from model2vec.distill.tokenizer import replace_vocabulary
 from model2vec.distill.utils import select_optimal_device
 from model2vec.model import StaticModel
+from model2vec.quantization import DType, quantize_embeddings

 try:
     # For huggingface_hub>=0.25.0
@@ -40,6 +41,7 @@ def distill_from_model(
     sif_coefficient: float | None = 1e-4,
     use_subword: bool = True,
     token_remove_pattern: str | None = r"\[unused\d+\]",
+    quantize_to: DType | str = DType.Float16,
 ) -> StaticModel:
     """
     Distill a staticmodel from a sentence transformer.
@@ -64,9 +66,11 @@ def distill_from_model(
     :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
     :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
         If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
+    :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
     :return: A StaticModel

     """
+    quantize_to = DType(quantize_to)
     backend_tokenizer = tokenizer.backend_tokenizer
     sif_coefficient, token_remove_regex = _validate_parameters(
         vocabulary, apply_zipf, sif_coefficient, use_subword, token_remove_pattern
@@ -106,6 +110,9 @@ def distill_from_model(
     # Post process the embeddings by applying PCA and Zipf weighting.
     embeddings = _post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)

+    # Quantize the embeddings.
+    embeddings = quantize_embeddings(embeddings, quantize_to)
+
     model_name = getattr(model, "name_or_path", "")

     config = {
@@ -209,6 +216,7 @@ def distill(
     use_subword: bool = True,
     token_remove_pattern: str | None = r"\[unused\d+\]",
     trust_remote_code: bool = False,
+    quantize_to: DType | str = DType.Float16,
 ) -> StaticModel:
     """
     Distill a staticmodel from a sentence transformer.
@@ -232,6 +240,7 @@ def distill(
     :param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
     :param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
     :param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
+    :param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
     :return: A StaticModel

     """
@@ -248,6 +257,7 @@ def distill(
         use_subword=use_subword,
         token_remove_pattern=token_remove_pattern,
         sif_coefficient=sif_coefficient,
+        quantize_to=quantize_to,
     )
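For orientation, a minimal sketch of how the new quantize_to argument can be passed through distill; the source model name is only an illustrative placeholder, and float16 remains the default per the signature above.

from model2vec.distill import distill

# Distill a static model and quantize its embeddings to int8 on the way out.
# "BAAI/bge-base-en-v1.5" is just an example source model, not prescribed by this commit.
m2v_model = distill(model_name="BAAI/bge-base-en-v1.5", quantize_to="int8")
m2v_model.save_pretrained("m2v_model_int8")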

model2vec/model.py

Lines changed: 7 additions & 0 deletions
@@ -12,6 +12,7 @@
 from tokenizers import Encoding, Tokenizer
 from tqdm import tqdm

+from model2vec.quantization import DType, quantize_embeddings
 from model2vec.utils import ProgressParallel, load_local_model

 PathLike = Union[Path, str]
@@ -150,6 +151,7 @@ def from_pretrained(
         path: PathLike,
         token: str | None = None,
         normalize: bool | None = None,
+        quantize_to: str | DType | None = None,
         dimensionality: int | None = None,
     ) -> StaticModel:
         """
@@ -160,6 +162,8 @@ def from_pretrained(
         :param path: The path to load your static model from.
         :param token: The huggingface token to use.
         :param normalize: Whether to normalize the embeddings.
+        :param quantize_to: The dtype to quantize the model to. If None, no quantization is done.
+            If a string is passed, it is converted to a DType.
         :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
             This is useful if you want to load a model with a lower dimensionality.
             Note that this only applies if you have trained your model using mrl or PCA.
@@ -170,6 +174,9 @@

         embeddings, tokenizer, config, metadata = load_pretrained(path, token=token, from_sentence_transformers=False)

+        if quantize_to is not None:
+            quantize_to = DType(quantize_to)
+            embeddings = quantize_embeddings(embeddings, quantize_to)
         if dimensionality is not None:
             if dimensionality > embeddings.shape[1]:
                 raise ValueError(
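As the tests further down also exercise, the same quantization can be applied when loading an existing model; the model id here is only an example.

from model2vec import StaticModel

# Load a pretrained static model and quantize its embeddings to int8 at load time.
# A plain string is accepted because from_pretrained converts it to a DType member.
model = StaticModel.from_pretrained("minishlab/potion-base-8M", quantize_to="int8")
print(model.embedding.dtype)  # int8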

model2vec/quantization.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+from enum import Enum
+
+import numpy as np
+
+
+class DType(str, Enum):
+    Float16 = "float16"
+    Float32 = "float32"
+    Float64 = "float64"
+    Int8 = "int8"
+
+
+def quantize_embeddings(embeddings: np.ndarray, quantize_to: DType) -> np.ndarray:
+    """
+    Quantize embeddings to a specified data type to reduce memory usage.
+
+    :param embeddings: The embeddings to quantize, as a numpy array.
+    :param quantize_to: The data type to quantize to.
+    :return: The quantized embeddings.
+    :raises ValueError: If the quantization type is not valid.
+    """
+    if quantize_to == DType.Float16:
+        return embeddings.astype(np.float16)
+    elif quantize_to == DType.Float32:
+        return embeddings.astype(np.float32)
+    elif quantize_to == DType.Float64:
+        return embeddings.astype(np.float64)
+    elif quantize_to == DType.Int8:
+        # Normalize to [-128, 127] range for int8
+        # We normalize to -127 to 127 to keep symmetry.
+        scale = np.max(np.abs(embeddings)) / 127.0
+        quantized = np.round(embeddings / scale).astype(np.int8)
+        return quantized
+    else:
+        raise ValueError("Not a valid enum member of DType.")

scripts/export_to_onnx.py

Lines changed: 3 additions & 1 deletion
@@ -27,7 +27,9 @@ def __init__(self, model: StaticModel) -> None:
         """Initialize the TorchStaticModel with a StaticModel instance."""
         super().__init__()
         # Convert NumPy embeddings to a torch.nn.EmbeddingBag
-        embeddings = torch.tensor(model.embedding, dtype=torch.float32)
+        embeddings = torch.from_numpy(model.embedding)
+        if embeddings.dtype in {torch.int8, torch.uint8}:
+            embeddings = embeddings.to(torch.float16)
         self.embedding_bag = torch.nn.EmbeddingBag.from_pretrained(embeddings, mode="mean", freeze=True)
         self.normalize = model.normalize
         # Save tokenizer attributes
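A brief note on the export change: embeddings may now be stored as int8, so they are promoted to float16 before building the EmbeddingBag, presumably because mean pooling over integer weights is not supported. A standalone sketch of that promotion, with a zero matrix standing in for model.embedding:

import numpy as np
import torch

# Stand-in for a quantized StaticModel.embedding array.
vectors = np.zeros((10, 4), dtype=np.int8)

embeddings = torch.from_numpy(vectors)
if embeddings.dtype in {torch.int8, torch.uint8}:
    # Integer embeddings cannot be averaged directly, so cast up to float16.
    embeddings = embeddings.to(torch.float16)
embedding_bag = torch.nn.EmbeddingBag.from_pretrained(embeddings, mode="mean", freeze=True)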

tests/test_model.py

Lines changed: 30 additions & 1 deletion
@@ -182,7 +182,7 @@ def test_load_pretrained(
     assert loaded_model.config == mock_config


-def test_load_pretrained_dim(
+def test_load_pretrained_quantized(
     tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
 ) -> None:
     """Test loading a pretrained model after saving it."""
@@ -192,6 +192,35 @@ def test_load_pretrained_dim(
     model.save_pretrained(save_path)

     # Load the model back from the same path
+    loaded_model = StaticModel.from_pretrained(save_path, quantize_to="int8")
+
+    # Assert that the loaded model has the same properties as the original one
+    assert loaded_model.embedding.dtype == np.int8
+    assert loaded_model.embedding.shape == mock_vectors.shape
+
+    # Load the model back from the same path
+    loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float16")
+
+    # Assert that the loaded model has the same properties as the original one
+    assert loaded_model.embedding.dtype == np.float16
+    assert loaded_model.embedding.shape == mock_vectors.shape
+
+    # Load the model back from the same path
+    loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float32")
+    # Assert that the loaded model has the same properties as the original one
+    assert loaded_model.embedding.dtype == np.float32
+    assert loaded_model.embedding.shape == mock_vectors.shape
+
+
+def test_load_pretrained_dim(
+    tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
+) -> None:
+    """Test loading a pretrained model with dimensionality."""
+    # Save the model to a temporary path
+    model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
+    save_path = tmp_path / "saved_model"
+    model.save_pretrained(save_path)
+
     loaded_model = StaticModel.from_pretrained(save_path, dimensionality=2)

     # Assert that the loaded model has the same properties as the original one

tests/test_quantization.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+import numpy as np
+import pytest
+
+from model2vec.quantization import DType, quantize_embeddings
+
+
+@pytest.mark.parametrize(
+    "input_dtype,target_dtype,expected_dtype",
+    [
+        (np.float32, DType.Float16, np.float16),
+        (np.float16, DType.Float32, np.float32),
+        (np.float32, DType.Float64, np.float64),
+        (np.float32, DType.Int8, np.int8),
+    ],
+)
+def test_quantize_embeddings(input_dtype: DType, target_dtype: DType, expected_dtype: DType) -> None:
+    """Test quantization to different dtypes."""
+    embeddings = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=input_dtype)
+    # Use negative values for int8 test case
+    if target_dtype == DType.Int8:
+        embeddings = np.array([[-1.0, 2.0], [-3.0, 4.0]], dtype=input_dtype)
+
+    quantized = quantize_embeddings(embeddings, target_dtype)
+    assert quantized.dtype == expected_dtype
+
+    if target_dtype == DType.Int8:
+        # Check if the values are in the range [-127, 127]
+        assert np.all(quantized >= -127) and np.all(quantized <= 127)
+    else:
+        assert np.allclose(quantized, embeddings.astype(expected_dtype))
