@@ -24,7 +24,7 @@ class StaticModel:
2424 def __init__ (
2525 self ,
2626 vectors : np .ndarray ,
27- weights : np .ndarray ,
27+ weights : np .ndarray | None ,
2828 token_mapping : dict [int , int ],
2929 tokenizer : Tokenizer ,
3030 config : dict [str , Any ] | None = None ,
@@ -107,6 +107,8 @@ def save_pretrained(self, path: PathLike, model_name: str | None = None, subfold
107107 """
108108 from model2vec .hf_utils import save_pretrained
109109
110+ self .config ["token_mapping" ] = list (self .token_mapping .items ())
111+
110112 save_pretrained (
111113 folder_path = Path (path ),
112114 embeddings = self .embedding ,
@@ -116,6 +118,7 @@ def save_pretrained(self, path: PathLike, model_name: str | None = None, subfold
116118 language = self .language ,
117119 model_name = model_name ,
118120 subfolder = subfolder ,
121+ weights = self .weights ,
119122 )
120123
121124 def tokenize (self , sentences : Sequence [str ], max_length : int | None = None ) -> list [list [int ]]:
@@ -131,8 +134,6 @@ def tokenize(self, sentences: Sequence[str], max_length: int | None = None) -> l
131134 m = max_length * self .median_token_length
132135 sentences = [sentence [:m ] for sentence in sentences ]
133136
134- max_len = max ([len (sentence ) for sentence in sentences ])
135- # self.tokenizer.model.max_input_chars_per_word = max_len + 1
136137 if self ._can_encode_fast :
137138 encodings : list [Encoding ] = self .tokenizer .encode_batch_fast (sentences , add_special_tokens = False )
138139 else :
@@ -159,6 +160,7 @@ def from_pretrained(
159160 subfolder : str | None = None ,
160161 quantize_to : str | DType | None = None ,
161162 dimensionality : int | None = None ,
163+ vocabulary_quantization : int | None = None ,
162164 ) -> StaticModel :
163165 """
164166 Load a StaticModel from a local path or huggingface hub path.
@@ -178,36 +180,45 @@ def from_pretrained(
178180 """
179181 from model2vec .hf_utils import load_pretrained
180182
181- embeddings , tokenizer , config , metadata = load_pretrained (
183+ embeddings , tokenizer , config , metadata , weights = load_pretrained (
182184 folder_or_repo_path = path ,
183185 token = token ,
184186 from_sentence_transformers = False ,
185187 subfolder = subfolder ,
186188 )
187189
188- weights = np .linalg .norm (embeddings , axis = 1 , keepdims = True ) + 1e-32
189- embeddings = embeddings / weights
190-
191- """from sklearn.cluster import KMeans
192- from sklearn.decomposition import PCA
193- km = KMeans(n_clusters=4096, random_state=0)
194- km.fit(embeddings)
195- # Do PCA again?
196- assignments = km.predict(embeddings)
197- embeddings = km.cluster_centers_
198-
199- p = PCA(n_components=dimensionality)
200- embeddings = p.fit_transform(embeddings)
201-
202- token_mapping = {i: x for i, x in enumerate(assignments)}"""
203- token_mapping = {i : i for i in range (len (embeddings ))}
204-
205190 embeddings = quantize_and_reduce_dim (
206191 embeddings = embeddings ,
207192 quantize_to = quantize_to ,
208193 dimensionality = dimensionality ,
209194 )
210195
196+ if vocabulary_quantization is not None :
197+ if len (embeddings ) != len (tokenizer .get_vocab ()):
198+ raise ValueError (
199+ "Already quantized. "
200+ )
201+
202+ if weights is None :
203+ weights = np .linalg .norm (embeddings , axis = 1 , keepdims = True ) + 1e-32
204+ embeddings = embeddings / weights
205+
206+ # Quantize the vocabulary
207+ from sklearn .cluster import KMeans
208+ kmeans = KMeans (n_clusters = vocabulary_quantization , random_state = 42 )
209+ kmeans .fit (embeddings )
210+ token_mapping = {idx : x for idx , x in enumerate (kmeans .predict (embeddings ))}
211+ embeddings = kmeans .cluster_centers_
212+
213+ else :
214+ token_mapping = config .pop ("token_mapping" , None )
215+ if isinstance (token_mapping , list ):
216+ # If the token mapping is a list, convert it to a dict
217+ token_mapping = {int (k ): int (v ) for k , v in token_mapping }
218+ elif token_mapping is None :
219+ # If no token mapping is provided, use the default mapping
220+ token_mapping = {i : i for i in range (len (embeddings ))}
221+
211222 return cls (
212223 embeddings ,
213224 weights ,
@@ -245,7 +256,7 @@ def from_sentence_transformers(
245256 """
246257 from model2vec .hf_utils import load_pretrained
247258
248- embeddings , tokenizer , config , metadata = load_pretrained (
259+ embeddings , tokenizer , config , metadata , weights = load_pretrained (
249260 folder_or_repo_path = path ,
250261 token = token ,
251262 from_sentence_transformers = True ,
@@ -258,9 +269,10 @@ def from_sentence_transformers(
258269 dimensionality = dimensionality ,
259270 )
260271
261- weights = np .linalg .norm (embeddings , axis = 1 , keepdims = True ) + 1e-32
262- embeddings = embeddings / weights
263- token_mapping = {i : i for i in range (len (embeddings ))}
272+ token_mapping = config .pop ("token_mapping" , None )
273+ if token_mapping is None :
274+ # If no token mapping is provided, use the default mapping
275+ token_mapping = {i : i for i in range (len (embeddings ))}
264276
265277 return cls (
266278 embeddings ,
0 commit comments