Skip to content

Commit ecde199

Browse files
authored
Merge pull request #18 from MinishLab/finetuning
Turn model into module
2 parents 8e0be62 + d1d432a commit ecde199

2 files changed

Lines changed: 118 additions & 50 deletions

File tree

model2vec/model.py

Lines changed: 72 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import torch
1010
from tokenizers import Encoding, Tokenizer
11+
from torch import nn
1112
from torch.nn import EmbeddingBag
1213
from tqdm import tqdm
1314

@@ -19,18 +20,21 @@
1920
logger = getLogger(__name__)
2021

2122

22-
class StaticModel:
23-
def __init__(self, vectors: np.ndarray, tokenizer: Tokenizer, config: dict[str, Any]) -> None:
23+
class StaticModel(nn.Module):
24+
def __init__(
25+
self, vectors: np.ndarray, tokenizer: Tokenizer, config: dict[str, Any], normalize: bool | None = None
26+
) -> None:
2427
"""
2528
Initialize the StaticModel.
2629
2730
:param vectors: The vectors to use.
2831
:param tokenizer: The Transformers tokenizer to use.
2932
:param config: Any metadata config.
33+
:param normalize: Whether to normalize.
3034
:raises: ValueError if the number of tokens does not match the number of vectors.
3135
"""
36+
super().__init__()
3237
tokens, _ = zip(*sorted(tokenizer.get_vocab().items(), key=lambda x: x[1]))
33-
self.vectors = vectors
3438
self.tokens = tokens
3539
self.embedding = EmbeddingBag.from_pretrained(torch.from_numpy(vectors))
3640

@@ -45,14 +49,75 @@ def __init__(self, vectors: np.ndarray, tokenizer: Tokenizer, config: dict[str,
4549
self.unk_token_id = None
4650

4751
self.config = config
52+
if normalize is not None:
53+
self.normalize = normalize
54+
else:
55+
self.normalize = config.get("normalize", False)
56+
57+
@property
58+
def normalize(self) -> bool:
59+
"""
60+
Get the normalize value.
61+
62+
:return: The normalize value.
63+
"""
64+
return self._normalize
65+
66+
@normalize.setter
67+
def normalize(self, value: bool) -> None:
68+
"""Update the config if the value of normalize changes."""
69+
config_normalize = self.config.get("normalize", False)
70+
self._normalize = value
71+
if value != config_normalize:
72+
logger.warning(
73+
f"Set normalization to `{value}`, which does not match config value `{config_normalize}`. Updating config."
74+
)
75+
self.config["normalize"] = value
4876

4977
def save_pretrained(self, path: PathLike) -> None:
5078
"""
5179
Save the pretrained model.
5280
5381
:param path: The path to save to.
5482
"""
55-
save_pretrained(Path(path), self.vectors, self.tokenizer, self.config)
83+
save_pretrained(Path(path), self.embedding.weight.numpy(), self.tokenizer, self.config)
84+
85+
def forward(self, ids: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
86+
"""
87+
Forward pass of the model.
88+
89+
:param ids: The input tensor.
90+
:param offsets: The offsets tensor.
91+
:return: The output tensor.
92+
"""
93+
means = self.embedding(ids, offsets)
94+
if self.normalize:
95+
return torch.nn.functional.normalize(means)
96+
return means
97+
98+
def tokenize(self, sentences: list[str], max_length: int | None = None) -> tuple[torch.Tensor, torch.Tensor]:
99+
"""
100+
Tokenize a sentence.
101+
102+
:param sentences: The sentence to tokenize.
103+
:param max_length: The maximum length of the sentence.
104+
:return: The tokens.
105+
"""
106+
encodings: list[Encoding] = self.tokenizer.encode_batch(sentences, add_special_tokens=False)
107+
encodings_ids = [encoding.ids for encoding in encodings]
108+
109+
if self.unk_token_id is not None:
110+
# NOTE: Remove the unknown token: necessary for word-level models.
111+
encodings_ids = [
112+
[token_id for token_id in token_ids if token_id != self.unk_token_id] for token_ids in encodings_ids
113+
]
114+
if max_length is not None:
115+
encodings_ids = [token_ids[:max_length] for token_ids in encodings_ids]
116+
117+
offsets = torch.from_numpy(np.cumsum([0] + [len(token_ids) for token_ids in encodings_ids[:-1]]))
118+
ids = torch.tensor([token_id for token_ids in encodings_ids for token_id in token_ids], dtype=torch.long)
119+
120+
return ids, offsets
56121

57122
@classmethod
58123
def from_pretrained(
@@ -80,20 +145,11 @@ def dim(self) -> int:
80145
"""
81146
return self.vectors.shape[1]
82147

83-
@staticmethod
84-
def normalize(X: np.ndarray) -> np.ndarray:
85-
"""Normalize an array to unit length."""
86-
norms = np.linalg.norm(X, axis=1, keepdims=True)
87-
norms[norms == 0] = 1
88-
89-
return X / norms
90-
91148
def encode(
92149
self,
93150
sentences: list[str] | str,
94151
show_progressbar: bool = False,
95152
max_length: int | None = 512,
96-
norm: bool = False,
97153
batch_size: int = 1024,
98154
**kwargs: Any,
99155
) -> np.ndarray:
@@ -107,7 +163,6 @@ def encode(
107163
:param show_progressbar: Whether to show the progress bar.
108164
:param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
109165
If this is None, no truncation is done.
110-
:param norm: Whether to normalize the embeddings to unit length.
111166
:param batch_size: The batch size to use.
112167
:param **kwargs: Any additional arguments. These are ignored.
113168
:return: The encoded sentences. If a single sentence was passed, a vector is returned.
@@ -125,31 +180,16 @@ def encode(
125180

126181
out_array = np.concatenate(out_arrays, axis=0)
127182

128-
if norm:
129-
out_array = self.normalize(out_array)
130-
131183
if was_single:
132184
return out_array[0]
133185

134186
return out_array
135187

188+
@torch.no_grad()
136189
def _encode_batch(self, sentences: list[str], max_length: int | None) -> np.ndarray:
137190
"""Encode a batch of sentences."""
138-
encodings: list[Encoding] = self.tokenizer.encode_batch(sentences, add_special_tokens=False)
139-
encodings_ids = [encoding.ids for encoding in encodings]
140-
141-
if self.unk_token_id is not None:
142-
# NOTE: Remove the unknown token: necessary for word-level models.
143-
encodings_ids = [
144-
[token_id for token_id in token_ids if token_id != self.unk_token_id] for token_ids in encodings_ids
145-
]
146-
if max_length is not None:
147-
encodings_ids = [token_ids[:max_length] for token_ids in encodings_ids]
148-
149-
offsets = np.cumsum([0] + [len(token_ids) for token_ids in encodings_ids[:-1]])
150-
ids = torch.tensor([token_id for token_ids in encodings_ids for token_id in token_ids], dtype=torch.long)
151-
152-
return self.embedding(ids, torch.tensor(offsets, dtype=torch.long)).detach().numpy()
191+
ids, offsets = self.tokenize(sentences, max_length)
192+
return self.forward(ids, offsets).numpy()
153193

154194
@staticmethod
155195
def _batch(sentences: list[str], batch_size: int) -> Iterator[list[str]]:

tests/test_model.py

Lines changed: 46 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,29 @@
22

33
import numpy as np
44
import pytest
5-
from transformers import PreTrainedTokenizerFast
5+
from tokenizers import Tokenizer
66

77
from model2vec import StaticModel
88

99

10-
def test_initialization(
11-
mock_vectors: np.ndarray, mock_tokenizer: PreTrainedTokenizerFast, mock_config: dict[str, str]
12-
) -> None:
10+
def test_initialization(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
1311
"""Test successful initialization of StaticModel."""
1412
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
15-
assert model.vectors.shape == (5, 2)
13+
assert model.embedding.weight.shape == (5, 2)
1614
assert len(model.tokens) == 5
1715
assert model.tokenizer == mock_tokenizer
1816
assert model.config == mock_config
1917

2018

21-
def test_initialization_token_vector_mismatch(
22-
mock_tokenizer: PreTrainedTokenizerFast, mock_config: dict[str, str]
23-
) -> None:
19+
def test_initialization_token_vector_mismatch(mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
2420
"""Test if error is raised when number of tokens and vectors don't match."""
2521
mock_vectors = np.array([[0.1, 0.2], [0.2, 0.3]])
2622
with pytest.raises(ValueError):
2723
StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
2824

2925

3026
def test_encode_single_sentence(
31-
mock_vectors: np.ndarray, mock_tokenizer: PreTrainedTokenizerFast, mock_config: dict[str, str]
27+
mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
3228
) -> None:
3329
"""Test encoding of a single sentence."""
3430
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
@@ -37,7 +33,7 @@ def test_encode_single_sentence(
3733

3834

3935
def test_encode_multiple_sentences(
40-
mock_vectors: np.ndarray, mock_tokenizer: PreTrainedTokenizerFast, mock_config: dict[str, str]
36+
mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
4137
) -> None:
4238
"""Test encoding of multiple sentences."""
4339
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
@@ -46,24 +42,29 @@ def test_encode_multiple_sentences(
4642

4743

4844
def test_encode_empty_sentence(
49-
mock_vectors: np.ndarray, mock_tokenizer: PreTrainedTokenizerFast, mock_config: dict[str, str]
45+
mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
5046
) -> None:
5147
"""Test encoding with an empty sentence."""
5248
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
5349
encoded = model.encode("")
5450
assert np.array_equal(encoded, np.zeros((2,)))
5551

5652

57-
def test_normalize() -> None:
53+
def test_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]) -> None:
5854
"""Test normalization of vectors."""
59-
X = np.array([[3, 4], [1, 2], [0, 0]])
60-
normalized = StaticModel.normalize(X)
61-
expected = np.array([[0.6, 0.8], [0.4472136, 0.89442719], [0, 0]])
55+
s = "word1 word2 word3"
56+
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config, normalize=False)
57+
X = model.encode(s)
58+
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config, normalize=True)
59+
normalized = model.encode(s)
60+
61+
expected = X / np.linalg.norm(X)
62+
6263
np.testing.assert_almost_equal(normalized, expected)
6364

6465

6566
def test_save_pretrained(
66-
tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: PreTrainedTokenizerFast, mock_config: dict[str, str]
67+
tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
6768
) -> None:
6869
"""Test saving a pretrained model."""
6970
model = StaticModel(vectors=mock_vectors, tokenizer=mock_tokenizer, config=mock_config)
@@ -80,7 +81,7 @@ def test_save_pretrained(
8081

8182

8283
def test_load_pretrained(
83-
tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: PreTrainedTokenizerFast, mock_config: dict[str, str]
84+
tmp_path: Path, mock_vectors: np.ndarray, mock_tokenizer: Tokenizer, mock_config: dict[str, str]
8485
) -> None:
8586
"""Test loading a pretrained model after saving it."""
8687
# Save the model to a temporary path
@@ -92,6 +93,33 @@ def test_load_pretrained(
9293
loaded_model = StaticModel.from_pretrained(save_path)
9394

9495
# Assert that the loaded model has the same properties as the original one
95-
np.testing.assert_array_equal(loaded_model.vectors, mock_vectors)
96+
np.testing.assert_array_equal(loaded_model.embedding.weight.numpy(), mock_vectors)
9697
assert loaded_model.tokenizer.get_vocab() == mock_tokenizer.get_vocab()
9798
assert loaded_model.config == mock_config
99+
100+
101+
def test_initialize_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
102+
"""Tests whether the normalization initialization is correct."""
103+
model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=None)
104+
assert not model.normalize
105+
106+
model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=False)
107+
assert not model.normalize
108+
109+
model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=True)
110+
assert model.normalize
111+
112+
model = StaticModel(mock_vectors, mock_tokenizer, {"normalize": False}, normalize=True)
113+
assert model.normalize
114+
115+
model = StaticModel(mock_vectors, mock_tokenizer, {"normalize": True}, normalize=False)
116+
assert not model.normalize
117+
118+
119+
def test_set_normalize(mock_vectors: np.ndarray, mock_tokenizer: Tokenizer) -> None:
120+
"""Tests whether the normalize is set correctly."""
121+
model = StaticModel(mock_vectors, mock_tokenizer, {}, normalize=True)
122+
model.normalize = False
123+
assert model.config == {"normalize": False}
124+
model.normalize = True
125+
assert model.config == {"normalize": True}

0 commit comments

Comments
 (0)