Skip to content

Commit 8d43143

Browse files
committed
Make staticmodel a module
1 parent 5d32041 commit 8d43143

1 file changed

Lines changed: 49 additions & 32 deletions

File tree

model2vec/model.py

Lines changed: 49 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import torch
1010
from tokenizers import Encoding, Tokenizer
11+
from torch import nn
1112
from torch.nn import EmbeddingBag
1213
from tqdm import tqdm
1314

@@ -19,18 +20,21 @@
1920
logger = getLogger(__name__)
2021

2122

22-
class StaticModel:
23-
def __init__(self, vectors: np.ndarray, tokenizer: Tokenizer, config: dict[str, Any]) -> None:
23+
class StaticModel(nn.Module):
24+
def __init__(
25+
self, vectors: np.ndarray, tokenizer: Tokenizer, config: dict[str, Any], normalize: bool = False
26+
) -> None:
2427
"""
2528
Initialize the StaticModel.
2629
2730
:param vectors: The vectors to use.
2831
:param tokenizer: The Transformers tokenizer to use.
2932
:param config: Any metadata config.
33+
:param normalize: Whether to normalize the embeddings returned by the forward pass.
3034
:raises: ValueError if the number of tokens does not match the number of vectors.
3135
"""
36+
super().__init__()
3237
tokens, _ = zip(*sorted(tokenizer.get_vocab().items(), key=lambda x: x[1]))
33-
self.vectors = vectors
3438
self.tokens = tokens
3539
self.embedding = EmbeddingBag.from_pretrained(torch.from_numpy(vectors))
3640

@@ -45,14 +49,52 @@ def __init__(self, vectors: np.ndarray, tokenizer: Tokenizer, config: dict[str,
4549
self.unk_token_id = None
4650

4751
self.config = config
52+
self.normalize = normalize
4853

4954
def save_pretrained(self, path: PathLike) -> None:
5055
"""
5156
Save the pretrained model.
5257
5358
:param path: The path to save to.
5459
"""
55-
save_pretrained(Path(path), self.vectors, self.tokenizer, self.config)
60+
save_pretrained(Path(path), self.embedding.weight.numpy(), self.tokenizer, self.config)
61+
62+
def forward(self, ids: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """
    Embed a flattened batch of token ids into one vector per bag.

    :param ids: The token ids of all bags, concatenated into a single tensor.
    :param offsets: The start position of each bag within ``ids``.
    :return: The output of ``self.embedding`` for every bag, passed through
        ``torch.nn.functional.normalize`` when ``self.normalize`` is set.
    """
    pooled = self.embedding(ids, offsets)
    # Normalization is opt-in: it is controlled by the flag stored on the instance.
    return torch.nn.functional.normalize(pooled) if self.normalize else pooled
74+
75+
def tokenize(self, sentences: list[str], max_length: int | None = None) -> tuple[torch.Tensor, torch.Tensor]:
76+
"""
77+
Tokenize a sentence.
78+
79+
:param sentences: The sentence to tokenize.
80+
:param max_length: The maximum length of the sentence.
81+
:return: The tokens.
82+
"""
83+
encodings: list[Encoding] = self.tokenizer.encode_batch(sentences, add_special_tokens=False)
84+
encodings_ids = [encoding.ids for encoding in encodings]
85+
86+
if self.unk_token_id is not None:
87+
# NOTE: Remove the unknown token: necessary for word-level models.
88+
encodings_ids = [
89+
[token_id for token_id in token_ids if token_id != self.unk_token_id] for token_ids in encodings_ids
90+
]
91+
if max_length is not None:
92+
encodings_ids = [token_ids[:max_length] for token_ids in encodings_ids]
93+
94+
offsets = torch.from_numpy(np.cumsum([0] + [len(token_ids) for token_ids in encodings_ids[:-1]]))
95+
ids = torch.tensor([token_id for token_ids in encodings_ids for token_id in token_ids], dtype=torch.long)
96+
97+
return ids, offsets
5698

5799
@classmethod
58100
def from_pretrained(
@@ -80,20 +122,11 @@ def dim(self) -> int:
80122
"""
81123
return self.vectors.shape[1]
82124

83-
@staticmethod
84-
def normalize(X: np.ndarray) -> np.ndarray:
85-
"""Normalize an array to unit length."""
86-
norms = np.linalg.norm(X, axis=1, keepdims=True)
87-
norms[norms == 0] = 1
88-
89-
return X / norms
90-
91125
def encode(
92126
self,
93127
sentences: list[str] | str,
94128
show_progressbar: bool = False,
95129
max_length: int | None = 512,
96-
norm: bool = False,
97130
batch_size: int = 1024,
98131
**kwargs: Any,
99132
) -> np.ndarray:
@@ -107,7 +140,6 @@ def encode(
107140
:param show_progressbar: Whether to show the progress bar.
108141
:param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
109142
If this is None, no truncation is done.
110-
:param norm: Whether to normalize the embeddings to unit length.
111143
:param batch_size: The batch size to use.
112144
:param **kwargs: Any additional arguments. These are ignored.
113145
:return: The encoded sentences. If a single sentence was passed, a vector is returned.
@@ -125,31 +157,16 @@ def encode(
125157

126158
out_array = np.concatenate(out_arrays, axis=0)
127159

128-
if norm:
129-
out_array = self.normalize(out_array)
130-
131160
if was_single:
132161
return out_array[0]
133162

134163
return out_array
135164

165+
@torch.no_grad()
136166
def _encode_batch(self, sentences: list[str], max_length: int | None) -> np.ndarray:
137167
"""Encode a batch of sentences."""
138-
encodings: list[Encoding] = self.tokenizer.encode_batch(sentences, add_special_tokens=False)
139-
encodings_ids = [encoding.ids for encoding in encodings]
140-
141-
if self.unk_token_id is not None:
142-
# NOTE: Remove the unknown token: necessary for word-level models.
143-
encodings_ids = [
144-
[token_id for token_id in token_ids if token_id != self.unk_token_id] for token_ids in encodings_ids
145-
]
146-
if max_length is not None:
147-
encodings_ids = [token_ids[:max_length] for token_ids in encodings_ids]
148-
149-
offsets = np.cumsum([0] + [len(token_ids) for token_ids in encodings_ids[:-1]])
150-
ids = torch.tensor([token_id for token_ids in encodings_ids for token_id in token_ids], dtype=torch.long)
151-
152-
return self.embedding(ids, torch.tensor(offsets, dtype=torch.long)).detach().numpy()
168+
ids, offsets = self.tokenize(sentences, max_length)
169+
return self.forward(ids, offsets).numpy()
153170

154171
@staticmethod
155172
def _batch(sentences: list[str], batch_size: int) -> Iterator[list[str]]:

0 commit comments

Comments
 (0)