 from tokenizers import Encoding, Tokenizer
 from tqdm import tqdm
 
-from model2vec.quantization import DType, quantize_embeddings
+from model2vec.quantization import DType, quantize_and_reduce_dim
 from model2vec.utils import ProgressParallel, load_local_model
 
 PathLike = Union[Path, str]
@@ -147,51 +147,6 @@ def tokenize(self, sentences: list[str], max_length: int | None = None) -> list[
 
         return encodings_ids
 
-    @classmethod
-    def _load_model(
-        cls: type[StaticModel],
-        path: PathLike,
-        token: str | None,
-        normalize: bool | None,
-        quantize_to: str | DType | None,
-        dimensionality: int | None,
-        from_sentence_transformers: bool,
-        subfolder: str | None = None,
-    ) -> StaticModel:
-        """Helper function to load a model from a path and optionally quantize it/reduce the dimensionality."""
-        from model2vec.hf_utils import load_pretrained
-
-        embeddings, tokenizer, config, metadata = load_pretrained(
-            folder_or_repo_path=path,
-            token=token,
-            from_sentence_transformers=from_sentence_transformers,
-            subfolder=subfolder,
-        )
-
-        if quantize_to is not None:
-            quantize_to = DType(quantize_to)
-            embeddings = quantize_embeddings(embeddings, quantize_to)
-
-        if dimensionality is not None:
-            if dimensionality > embeddings.shape[1]:
-                raise ValueError(
-                    f"Dimensionality {dimensionality} is greater than the model dimensionality {embeddings.shape[1]}"
-                )
-            embeddings = embeddings[:, :dimensionality]
-            if config.get("apply_pca", None) is None:
-                logger.warning(
-                    "You are reducing the dimensionality of the model, but we can't find a pca key in the model config. This might not work as expected."
-                )
-
-        return cls(
-            embeddings,
-            tokenizer,
-            config,
-            normalize=normalize,
-            base_model_name=metadata.get("base_model"),
-            language=metadata.get("language"),
-        )
-
     @classmethod
     def from_pretrained(
         cls: type[StaticModel],
@@ -218,16 +173,30 @@ def from_pretrained(
             Note that this only applies if you have trained your model using mrl or PCA.
         :return: A StaticModel.
         """
-        return cls._load_model(
-            path=path,
+        from model2vec.hf_utils import load_pretrained
+
+        embeddings, tokenizer, config, metadata = load_pretrained(
+            folder_or_repo_path=path,
             token=token,
-            normalize=normalize,
-            quantize_to=quantize_to,
-            dimensionality=dimensionality,
             from_sentence_transformers=False,
             subfolder=subfolder,
         )
 
+        embeddings = quantize_and_reduce_dim(
+            embeddings=embeddings,
+            quantize_to=quantize_to,
+            dimensionality=dimensionality,
+        )
+
+        return cls(
+            embeddings,
+            tokenizer,
+            config,
+            normalize=normalize,
+            base_model_name=metadata.get("base_model"),
+            language=metadata.get("language"),
+        )
+
     @classmethod
     def from_sentence_transformers(
         cls: type[StaticModel],
@@ -252,13 +221,28 @@ def from_sentence_transformers(
             Note that this only applies if you have trained your model using mrl or PCA.
         :return: A StaticModel.
         """
-        return cls._load_model(
-            path=path,
+        from model2vec.hf_utils import load_pretrained
+
+        embeddings, tokenizer, config, metadata = load_pretrained(
+            folder_or_repo_path=path,
             token=token,
-            normalize=normalize,
+            from_sentence_transformers=True,
+            subfolder=None,
+        )
+
+        embeddings = quantize_and_reduce_dim(
+            embeddings=embeddings,
             quantize_to=quantize_to,
             dimensionality=dimensionality,
-            from_sentence_transformers=True,
+        )
+
+        return cls(
+            embeddings,
+            tokenizer,
+            config,
+            normalize=normalize,
+            base_model_name=metadata.get("base_model"),
+            language=metadata.get("language"),
+        )
 
     def encode_as_sequence(
0 commit comments