Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions model2vec/distill/distillation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from model2vec.distill.tokenizer import replace_vocabulary
from model2vec.distill.utils import select_optimal_device
from model2vec.model import StaticModel
from model2vec.quantization import DType, quantize_embeddings

try:
# For huggingface_hub>=0.25.0
Expand All @@ -40,6 +41,7 @@ def distill_from_model(
sif_coefficient: float | None = 1e-4,
use_subword: bool = True,
token_remove_pattern: str | None = r"\[unused\d+\]",
quantize_to: DType | str = DType.Float16,
) -> StaticModel:
"""
Distill a staticmodel from a sentence transformer.
Expand All @@ -64,9 +66,11 @@ def distill_from_model(
:param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
If the pattern is so general that it removes all tokens, we throw an error. If the pattern can't be compiled into a valid regex, we also throw an error.
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:return: A StaticModel

"""
quantize_to = DType(quantize_to)
backend_tokenizer = tokenizer.backend_tokenizer
sif_coefficient, token_remove_regex = _validate_parameters(
vocabulary, apply_zipf, sif_coefficient, use_subword, token_remove_pattern
Expand Down Expand Up @@ -106,6 +110,9 @@ def distill_from_model(
# Post process the embeddings by applying PCA and Zipf weighting.
embeddings = _post_process_embeddings(np.asarray(embeddings), pca_dims, sif_coefficient=sif_coefficient)

# Quantize the embeddings.
embeddings = quantize_embeddings(embeddings, quantize_to)

model_name = getattr(model, "name_or_path", "")

config = {
Expand Down Expand Up @@ -209,6 +216,7 @@ def distill(
use_subword: bool = True,
token_remove_pattern: str | None = r"\[unused\d+\]",
trust_remote_code: bool = False,
quantize_to: DType | str = DType.Float16,
) -> StaticModel:
"""
Distill a staticmodel from a sentence transformer.
Expand All @@ -232,6 +240,7 @@ def distill(
:param use_subword: Whether to keep subword tokens in the vocabulary. If this is False, you must pass a vocabulary, and the returned tokenizer will only detect full words.
:param token_remove_pattern: If this is set to a string, we compile this into a regex. Any tokens that conform to this regex pattern will be removed from the vocabulary.
:param trust_remote_code: Whether to trust the remote code. If this is False, we will only load components coming from `transformers`. If this is True, we will load all components.
:param quantize_to: The data type to quantize to. Can be any of the DType enum members or their string equivalents.
:return: A StaticModel

"""
Expand All @@ -248,6 +257,7 @@ def distill(
use_subword=use_subword,
token_remove_pattern=token_remove_pattern,
sif_coefficient=sif_coefficient,
quantize_to=quantize_to,
)


Expand Down
8 changes: 8 additions & 0 deletions model2vec/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tokenizers import Encoding, Tokenizer
from tqdm import tqdm

from model2vec.quantization import DType, quantize_embeddings
from model2vec.utils import ProgressParallel, load_local_model

PathLike = Union[Path, str]
Expand Down Expand Up @@ -150,6 +151,7 @@ def from_pretrained(
path: PathLike,
token: str | None = None,
normalize: bool | None = None,
quantize_to: str | DType | None = None,
) -> StaticModel:
"""
Load a StaticModel from a local path or huggingface hub path.
Expand All @@ -159,12 +161,18 @@ def from_pretrained(
:param path: The path to load your static model from.
:param token: The huggingface token to use.
:param normalize: Whether to normalize the embeddings.
:param quantize_to: The dtype to quantize the model to. If None, no quantization is done.
If a string is passed, it is converted to a DType.
:return: A StaticModel
"""
from model2vec.hf_utils import load_pretrained

embeddings, tokenizer, config, metadata = load_pretrained(path, token=token, from_sentence_transformers=False)

if quantize_to is not None:
quantize_to = DType(quantize_to)
embeddings = quantize_embeddings(embeddings, quantize_to)

return cls(
embeddings,
tokenizer,
Expand Down
34 changes: 34 additions & 0 deletions model2vec/quantization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from enum import Enum

import numpy as np


class DType(str, Enum):
Float16 = "float16"
Float32 = "float32"
Float64 = "float64"
Int8 = "int8"


def quantize_embeddings(embeddings: np.ndarray, quantize_to: DType) -> np.ndarray:
"""
Quantize embeddings to a specified data type to reduce memory usage.

:param embeddings: The embeddings to quantize, as a numpy array.
:param quantize_to: The data type to quantize to.
:return: The quantized embeddings.
:raises ValueError: If the quantization type is not valid.
"""
if quantize_to == DType.Float16:
return embeddings.astype(np.float16)
elif quantize_to == DType.Float32:
return embeddings.astype(np.float32)
elif quantize_to == DType.Float64:
return embeddings.astype(np.float64)
elif quantize_to == DType.Int8:
# Normalize to [-127, 127] range for int8
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this not be [-128, 127] (the range of an 8-bit signed integer)? Not sure if it's relevant for the code though since it doesn't change the division.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the symmetry is more important than making sure the 1 extra value is used. I updated the comment.

scale = np.max(np.abs(embeddings)) / 127.0
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this ever be 0 (zero division issues?)

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only if all embeddings are 0

quantized = np.round(embeddings / scale).astype(np.int8)
return quantized
else:
raise ValueError("Not a valid enum member of DType.")

Check warning on line 34 in model2vec/quantization.py

View check run for this annotation

Codecov / codecov/patch

model2vec/quantization.py#L34

Added line #L34 was not covered by tests
4 changes: 3 additions & 1 deletion scripts/export_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ def __init__(self, model: StaticModel) -> None:
"""Initialize the TorchStaticModel with a StaticModel instance."""
super().__init__()
# Convert NumPy embeddings to a torch.nn.EmbeddingBag
embeddings = torch.tensor(model.embedding, dtype=torch.float32)
embeddings = torch.from_numpy(model.embedding)
if embeddings.dtype in {torch.int8, torch.uint8}:
embeddings = embeddings.to(torch.float16)
self.embedding_bag = torch.nn.EmbeddingBag.from_pretrained(embeddings, mode="mean", freeze=True)
self.normalize = model.normalize
# Save tokenizer attributes
Expand Down
31 changes: 31 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,37 @@ def test_load_pretrained(
assert loaded_model.config == mock_config


def test_load_pretrained_quantized(
tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
) -> None:
"""Test loading a pretrained model after saving it."""
# Save the model to a temporary path
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
save_path = tmp_path / "saved_model"
model.save_pretrained(save_path)

# Load the model back from the same path
loaded_model = StaticModel.from_pretrained(save_path, quantize_to="int8")

# Assert that the loaded model has the same properties as the original one
assert loaded_model.embedding.dtype == np.int8
assert loaded_model.embedding.shape == mock_vectors.shape

# Load the model back from the same path
loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float16")

# Assert that the loaded model has the same properties as the original one
assert loaded_model.embedding.dtype == np.float16
assert loaded_model.embedding.shape == mock_vectors.shape

# Load the model back from the same path
loaded_model = StaticModel.from_pretrained(save_path, quantize_to="float32")

# Assert that the loaded model has the same properties as the original one
assert loaded_model.embedding.dtype == np.float32
assert loaded_model.embedding.shape == mock_vectors.shape


def test_initialize_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
"""Tests whether the normalization initialization is correct."""
model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=None)
Expand Down
30 changes: 30 additions & 0 deletions tests/test_quantization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import numpy as np
import pytest

from model2vec.quantization import DType, quantize_embeddings


@pytest.mark.parametrize(
"input_dtype,target_dtype,expected_dtype",
[
(np.float32, DType.Float16, np.float16),
(np.float16, DType.Float32, np.float32),
(np.float32, DType.Float64, np.float64),
(np.float32, DType.Int8, np.int8),
],
)
def test_quantize_embeddings(input_dtype: DType, target_dtype: DType, expected_dtype: DType) -> None:
"""Test quantization to different dtypes."""
embeddings = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=input_dtype)
# Use negative values for int8 test case
if target_dtype == DType.Int8:
embeddings = np.array([[-1.0, 2.0], [-3.0, 4.0]], dtype=input_dtype)

quantized = quantize_embeddings(embeddings, target_dtype)
assert quantized.dtype == expected_dtype

if target_dtype == DType.Int8:
# Check if the values are in the range [-127, 127]
assert np.all(quantized >= -127) and np.all(quantized <= 127)
else:
assert np.allclose(quantized, embeddings.astype(expected_dtype))