88import numpy as np
99import torch
1010from tokenizers import Encoding , Tokenizer
11+ from torch import nn
1112from torch .nn import EmbeddingBag
1213from tqdm import tqdm
1314
1920logger = getLogger (__name__ )
2021
2122
22- class StaticModel :
23- def __init__ (self , vectors : np .ndarray , tokenizer : Tokenizer , config : dict [str , Any ]) -> None :
23+ class StaticModel (nn .Module ):
24+ def __init__ (
25+ self , vectors : np .ndarray , tokenizer : Tokenizer , config : dict [str , Any ], normalize : bool = False
26+ ) -> None :
2427 """
2528 Initialize the StaticModel.
2629
2730 :param vectors: The vectors to use.
2831 :param tokenizer: The Transformers tokenizer to use.
2932 :param config: Any metadata config.
33+ :param normalize: Whether to normalize.
3034 :raises: ValueError if the number of tokens does not match the number of vectors.
3135 """
36+ super ().__init__ ()
3237 tokens , _ = zip (* sorted (tokenizer .get_vocab ().items (), key = lambda x : x [1 ]))
33- self .vectors = vectors
3438 self .tokens = tokens
3539 self .embedding = EmbeddingBag .from_pretrained (torch .from_numpy (vectors ))
3640
@@ -45,14 +49,52 @@ def __init__(self, vectors: np.ndarray, tokenizer: Tokenizer, config: dict[str,
4549 self .unk_token_id = None
4650
4751 self .config = config
52+ self .normalize = normalize
4853
def save_pretrained(self, path: PathLike) -> None:
    """
    Save the pretrained model.

    The vectors are read back from the weight of the embedding layer and handed,
    together with the tokenizer and the config, to the module-level ``save_pretrained``.

    :param path: The path to save to.
    """
    # NOTE: Tensor.numpy() raises for tensors that require grad or live on a non-CPU device,
    # so detach from autograd and move to the CPU first (both are no-ops for a frozen CPU weight).
    vectors = self.embedding.weight.detach().cpu().numpy()
    save_pretrained(Path(path), vectors, self.tokenizer, self.config)
61+
def forward(self, ids: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """
    Run the embedding layer on a flattened batch and optionally normalize the result.

    :param ids: The token ids of all sentences, concatenated into a single tensor.
    :param offsets: The position in `ids` at which each sentence starts.
    :return: One pooled vector per offset. If `self.normalize` is set, the vectors are L2-normalized.
    """
    pooled = self.embedding(ids, offsets)
    if not self.normalize:
        return pooled
    # NOTE: p=2.0 and dim=1 are the defaults of torch.nn.functional.normalize, spelled out here.
    return torch.nn.functional.normalize(pooled, p=2.0, dim=1)
74+
75+ def tokenize (self , sentences : list [str ], max_length : int | None = None ) -> tuple [torch .Tensor , torch .Tensor ]:
76+ """
77+ Tokenize a sentence.
78+
79+ :param sentences: The sentence to tokenize.
80+ :param max_length: The maximum length of the sentence.
81+ :return: The tokens.
82+ """
83+ encodings : list [Encoding ] = self .tokenizer .encode_batch (sentences , add_special_tokens = False )
84+ encodings_ids = [encoding .ids for encoding in encodings ]
85+
86+ if self .unk_token_id is not None :
87+ # NOTE: Remove the unknown token: necessary for word-level models.
88+ encodings_ids = [
89+ [token_id for token_id in token_ids if token_id != self .unk_token_id ] for token_ids in encodings_ids
90+ ]
91+ if max_length is not None :
92+ encodings_ids = [token_ids [:max_length ] for token_ids in encodings_ids ]
93+
94+ offsets = torch .from_numpy (np .cumsum ([0 ] + [len (token_ids ) for token_ids in encodings_ids [:- 1 ]]))
95+ ids = torch .tensor ([token_id for token_ids in encodings_ids for token_id in token_ids ], dtype = torch .long )
96+
97+ return ids , offsets
5698
5799 @classmethod
58100 def from_pretrained (
@@ -80,20 +122,11 @@ def dim(self) -> int:
80122 """
81123 return self .vectors .shape [1 ]
82124
83- @staticmethod
84- def normalize (X : np .ndarray ) -> np .ndarray :
85- """Normalize an array to unit length."""
86- norms = np .linalg .norm (X , axis = 1 , keepdims = True )
87- norms [norms == 0 ] = 1
88-
89- return X / norms
90-
91125 def encode (
92126 self ,
93127 sentences : list [str ] | str ,
94128 show_progressbar : bool = False ,
95129 max_length : int | None = 512 ,
96- norm : bool = False ,
97130 batch_size : int = 1024 ,
98131 ** kwargs : Any ,
99132 ) -> np .ndarray :
@@ -107,7 +140,6 @@ def encode(
107140 :param show_progressbar: Whether to show the progress bar.
108141 :param max_length: The maximum length of the sentences. Any tokens beyond this length will be truncated.
109142 If this is None, no truncation is done.
110- :param norm: Whether to normalize the embeddings to unit length.
111143 :param batch_size: The batch size to use.
112144 :param **kwargs: Any additional arguments. These are ignored.
113145 :return: The encoded sentences. If a single sentence was passed, a vector is returned.
@@ -125,31 +157,16 @@ def encode(
125157
126158 out_array = np .concatenate (out_arrays , axis = 0 )
127159
128- if norm :
129- out_array = self .normalize (out_array )
130-
131160 if was_single :
132161 return out_array [0 ]
133162
134163 return out_array
135164
165+ @torch .no_grad ()
136166 def _encode_batch (self , sentences : list [str ], max_length : int | None ) -> np .ndarray :
137167 """Encode a batch of sentences."""
138- encodings : list [Encoding ] = self .tokenizer .encode_batch (sentences , add_special_tokens = False )
139- encodings_ids = [encoding .ids for encoding in encodings ]
140-
141- if self .unk_token_id is not None :
142- # NOTE: Remove the unknown token: necessary for word-level models.
143- encodings_ids = [
144- [token_id for token_id in token_ids if token_id != self .unk_token_id ] for token_ids in encodings_ids
145- ]
146- if max_length is not None :
147- encodings_ids = [token_ids [:max_length ] for token_ids in encodings_ids ]
148-
149- offsets = np .cumsum ([0 ] + [len (token_ids ) for token_ids in encodings_ids [:- 1 ]])
150- ids = torch .tensor ([token_id for token_ids in encodings_ids for token_id in token_ids ], dtype = torch .long )
151-
152- return self .embedding (ids , torch .tensor (offsets , dtype = torch .long )).detach ().numpy ()
168+ ids , offsets = self .tokenize (sentences , max_length )
169+ return self .forward (ids , offsets ).numpy ()
153170
154171 @staticmethod
155172 def _batch (sentences : list [str ], batch_size : int ) -> Iterator [list [str ]]:
0 commit comments