1212from tokenizers import Encoding , Tokenizer
1313from tqdm import tqdm
1414
15- from model2vec .quantization import DType , quantize_embeddings
15+ from model2vec .quantization import DType , quantize_and_reduce_dim
1616from model2vec .utils import ProgressParallel , load_local_model
1717
1818PathLike = Union [Path , str ]
@@ -171,28 +171,22 @@ def from_pretrained(
171171 :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
172172 This is useful if you want to load a model with a lower dimensionality.
173173 Note that this only applies if you have trained your model using mrl or PCA.
174- :return: A StaticModel
175- :raises: ValueError if the dimensionality is greater than the model dimensionality.
174+ :return: A StaticModel.
176175 """
177176 from model2vec .hf_utils import load_pretrained
178177
179178 embeddings , tokenizer , config , metadata = load_pretrained (
180- path , token = token , from_sentence_transformers = False , subfolder = subfolder
179+ folder_or_repo_path = path ,
180+ token = token ,
181+ from_sentence_transformers = False ,
182+ subfolder = subfolder ,
181183 )
182184
183- if quantize_to is not None :
184- quantize_to = DType (quantize_to )
185- embeddings = quantize_embeddings (embeddings , quantize_to )
186- if dimensionality is not None :
187- if dimensionality > embeddings .shape [1 ]:
188- raise ValueError (
189- f"Dimensionality { dimensionality } is greater than the model dimensionality { embeddings .shape [1 ]} "
190- )
191- embeddings = embeddings [:, :dimensionality ]
192- if config .get ("apply_pca" , None ) is None :
193- logger .warning (
194- "You are reducing the dimensionality of the model, but we can't find a pca key in the model config. This might not work as expected."
195- )
185+ embeddings = quantize_and_reduce_dim (
186+ embeddings = embeddings ,
187+ quantize_to = quantize_to ,
188+ dimensionality = dimensionality ,
189+ )
196190
197191 return cls (
198192 embeddings ,
@@ -209,6 +203,8 @@ def from_sentence_transformers(
209203 path : PathLike ,
210204 token : str | None = None ,
211205 normalize : bool | None = None ,
206+ quantize_to : str | DType | None = None ,
207+ dimensionality : int | None = None ,
212208 ) -> StaticModel :
213209 """
214210 Load a StaticModel trained with sentence transformers from a local path or huggingface hub path.
@@ -218,13 +214,36 @@ def from_sentence_transformers(
218214 :param path: The path to load your static model from.
219215 :param token: The huggingface token to use.
220216 :param normalize: Whether to normalize the embeddings.
221- :return: A StaticModel
217+ :param quantize_to: The dtype to quantize the model to. If None, no quantization is done.
218+ If a string is passed, it is converted to a DType.
219+ :param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
220+ This is useful if you want to load a model with a lower dimensionality.
221+ Note that this only applies if you have trained your model using mrl or PCA.
222+ :return: A StaticModel.
222223 """
223224 from model2vec .hf_utils import load_pretrained
224225
225- embeddings , tokenizer , config , _ = load_pretrained (path , token = token , from_sentence_transformers = True )
226+ embeddings , tokenizer , config , metadata = load_pretrained (
227+ folder_or_repo_path = path ,
228+ token = token ,
229+ from_sentence_transformers = True ,
230+ subfolder = None ,
231+ )
232+
233+ embeddings = quantize_and_reduce_dim (
234+ embeddings = embeddings ,
235+ quantize_to = quantize_to ,
236+ dimensionality = dimensionality ,
237+ )
226238
227- return cls (embeddings , tokenizer , config , normalize = normalize , base_model_name = None , language = None )
239+ return cls (
240+ embeddings ,
241+ tokenizer ,
242+ config ,
243+ normalize = normalize ,
244+ base_model_name = metadata .get ("base_model" ),
245+ language = metadata .get ("language" ),
246+ )
228247
229248 def encode_as_sequence (
230249 self ,
0 commit comments