Skip to content

Commit a5727a7

Browse files
committed
Update normalization
1 parent 8d43143 commit a5727a7

2 files changed

Lines changed: 41 additions & 20 deletions

File tree

model2vec/model.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
class StaticModel(nn.Module):
2424
def __init__(
25-
self, vectors: np.ndarray, tokenizer: Tokenizer, config: dict[str, Any], normalize: bool = False
25+
self, vectors: np.ndarray, tokenizer: Tokenizer, config: dict[str, Any], normalize: bool | None = None
2626
) -> None:
2727
"""
2828
Initialize the StaticModel.
@@ -49,7 +49,27 @@ def __init__(
4949
self.unk_token_id = None
5050

5151
self.config = config
52-
self.normalize = normalize
52+
self.normalize = config.get("normalize", normalize if normalize is not None else False)
53+
54+
@property
55+
def normalize(self) -> bool:
56+
"""
57+
Get the normalize value.
58+
59+
:return: The normalize value.
60+
"""
61+
return self._normalize
62+
63+
@normalize.setter
64+
def normalize(self, value: bool) -> None:
65+
"""Update the config if the value of normalize changes."""
66+
config_normalize = self.config.get("normalize", False)
67+
self._normalize = value
68+
if config_normalize is not None and value != config_normalize:
69+
logger.warning(
70+
f"Set normalization to `{value}`, which does not match config value `{config_normalize}`. Updating config."
71+
)
72+
self.config["normalize"] = value
5373

5474
def save_pretrained(self, path: PathLike) -> None:
5575
"""

tests/test_model.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,29 @@
22

33
import numpy as np
44
import pytest
5-
from transformers import PreTrainedTokenizerFast
5+
from tokenizers import Tokenizer
66

77
from model2vec import StaticModel
88

99

10-
def test_initialization(
11-
mock_vectors: np.ndarray, mock_tokenizer: PreTrainedTokenizerFast, mock_config: dict[str, str]
12-
) -> None:
10+
def test_initialization(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
1311
"""Test successful initialization of StaticModel."""
1412
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
15-
assert model.vectors.shape == (5, 2)
13+
assert model.embedding.weight.shape == (5, 2)
1614
assert len(model.tokens) == 5
1715
assert model.tokenizer == mock_tokenizer
1816
assert model.config == mock_config
1917

2018

21-
def test_initialization_token_vector_mismatch(
22-
mock_tokenizer: PreTrainedTokenizerFast, mock_config: dict[str, str]
23-
) -> None:
19+
def test_initialization_token_vector_mismatch(mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
2420
"""Test if error is raised when number of tokens and vectors don't match."""
2521
mock_vectors = np.array([[0.1, 0.2], [0.2, 0.3]])
2622
with pytest.raises(ValueError):
2723
StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
2824

2925

3026
def test_encode_single_sentence(
31-
mock_vectors: np.ndarray, mock_tokenizer: PreTrainedTokenizerFast, mock_config: dict[str, str]
27+
mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
3228
) -> None:
3329
"""Test encoding of a single sentence."""
3430
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
@@ -37,7 +33,7 @@ def test_encode_single_sentence(
3733

3834

3935
def test_encode_multiple_sentences(
40-
mock_vectors: np.ndarray, mock_tokenizer: PreTrainedTokenizerFast, mock_config: dict[str, str]
36+
mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
4137
) -> None:
4238
"""Test encoding of multiple sentences."""
4339
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
@@ -46,24 +42,29 @@ def test_encode_multiple_sentences(
4642

4743

4844
def test_encode_empty_sentence(
49-
mock_vectors: np.ndarray, mock_tokenizer: PreTrainedTokenizerFast, mock_config: dict[str, str]
45+
mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
5046
) -> None:
5147
"""Test encoding with an empty sentence."""
5248
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
5349
encoded = model.encode("")
5450
assert np.array_equal(encoded, np.zeros((2,)))
5551

5652

57-
def test_normalize() -> None:
53+
def test_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
5854
"""Test normalization of vectors."""
59-
X = np.array([[3, 4], [1, 2], [0, 0]])
60-
normalized = StaticModel.normalize(X)
61-
expected = np.array([[0.6, 0.8], [0.4472136, 0.89442719], [0, 0]])
55+
s = "word1 word2 word3"
56+
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config, normalize=False)
57+
X = model.encode(s)
58+
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config, normalize=True)
59+
normalized = model.encode(s)
60+
61+
expected = X / np.linalg.norm(X)
62+
6263
np.testing.assert_almost_equal(normalized, expected)
6364

6465

6566
def test_save_pretrained(
66-
tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: PreTrainedTokenizerFast, mock_config: dict[str, str]
67+
tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
6768
) -> None:
6869
"""Test saving a pretrained model."""
6970
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
@@ -80,7 +81,7 @@ def test_save_pretrained(
8081

8182

8283
def test_load_pretrained(
83-
tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: PreTrainedTokenizerFast, mock_config: dict[str, str]
84+
tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
8485
) -> None:
8586
"""Test loading a pretrained model after saving it."""
8687
# Save the model to a temporary path
@@ -92,6 +93,6 @@ def test_load_pretrained(
9293
loaded_model = StaticModel.from_pretrained(save_path)
9394

9495
# Assert that the loaded model has the same properties as the original one
95-
np.testing.assert_array_equal(loaded_model.vectors, mock_vectors)
96+
np.testing.assert_array_equal(loaded_model.embedding.weight.numpy(), mock_vectors)
9697
assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
9798
assert loaded_model.config == mock_config

0 commit comments

Comments
 (0)