Skip to content

Commit ef4d618

Browse files
committed
feat: remove flag argument
1 parent 23bb90b commit ef4d618

2 files changed

Lines changed: 65 additions & 55 deletions

File tree

model2vec/model.py

Lines changed: 39 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from tokenizers import Encoding, Tokenizer
1313
from tqdm import tqdm
1414

15-
from model2vec.quantization import DType, quantize_embeddings
15+
from model2vec.quantization import DType, quantize_and_reduce_dim
1616
from model2vec.utils import ProgressParallel, load_local_model
1717

1818
PathLike = Union[Path, str]
@@ -147,51 +147,6 @@ def tokenize(self, sentences: list[str], max_length: int | None = None) -> list[
147147

148148
return encodings_ids
149149

150-
@classmethod
151-
def _load_model(
152-
cls: type[StaticModel],
153-
path: PathLike,
154-
token: str | None,
155-
normalize: bool | None,
156-
quantize_to: str | DType | None,
157-
dimensionality: int | None,
158-
from_sentence_transformers: bool,
159-
subfolder: str | None = None,
160-
) -> StaticModel:
161-
"""Helper function to load a model from a path and optionally quantize it/reduce the dimensionality."""
162-
from model2vec.hf_utils import load_pretrained
163-
164-
embeddings, tokenizer, config, metadata = load_pretrained(
165-
folder_or_repo_path=path,
166-
token=token,
167-
from_sentence_transformers=from_sentence_transformers,
168-
subfolder=subfolder,
169-
)
170-
171-
if quantize_to is not None:
172-
quantize_to = DType(quantize_to)
173-
embeddings = quantize_embeddings(embeddings, quantize_to)
174-
175-
if dimensionality is not None:
176-
if dimensionality > embeddings.shape[1]:
177-
raise ValueError(
178-
f"Dimensionality {dimensionality} is greater than the model dimensionality {embeddings.shape[1]}"
179-
)
180-
embeddings = embeddings[:, :dimensionality]
181-
if config.get("apply_pca", None) is None:
182-
logger.warning(
183-
"You are reducing the dimensionality of the model, but we can't find a pca key in the model config. This might not work as expected."
184-
)
185-
186-
return cls(
187-
embeddings,
188-
tokenizer,
189-
config,
190-
normalize=normalize,
191-
base_model_name=metadata.get("base_model"),
192-
language=metadata.get("language"),
193-
)
194-
195150
@classmethod
196151
def from_pretrained(
197152
cls: type[StaticModel],
@@ -218,16 +173,30 @@ def from_pretrained(
218173
Note that this only applies if you have trained your model using mrl or PCA.
219174
:return: A StaticModel.
220175
"""
221-
return cls._load_model(
222-
path=path,
176+
from model2vec.hf_utils import load_pretrained
177+
178+
embeddings, tokenizer, config, metadata = load_pretrained(
179+
folder_or_repo_path=path,
223180
token=token,
224-
normalize=normalize,
225-
quantize_to=quantize_to,
226-
dimensionality=dimensionality,
227181
from_sentence_transformers=False,
228182
subfolder=subfolder,
229183
)
230184

185+
embeddings = quantize_and_reduce_dim(
186+
embeddings=embeddings,
187+
quantize_to=quantize_to,
188+
dimensionality=dimensionality,
189+
)
190+
191+
return cls(
192+
embeddings,
193+
tokenizer,
194+
config,
195+
normalize=normalize,
196+
base_model_name=metadata.get("base_model"),
197+
language=metadata.get("language"),
198+
)
199+
231200
@classmethod
232201
def from_sentence_transformers(
233202
cls: type[StaticModel],
@@ -252,13 +221,28 @@ def from_sentence_transformers(
252221
Note that this only applies if you have trained your model using mrl or PCA.
253222
:return: A StaticModel.
254223
"""
255-
return cls._load_model(
256-
path=path,
224+
from model2vec.hf_utils import load_pretrained
225+
226+
embeddings, tokenizer, config, metadata = load_pretrained(
227+
folder_or_repo_path=path,
257228
token=token,
258-
normalize=normalize,
229+
from_sentence_transformers=True,
230+
subfolder=None,
231+
)
232+
233+
embeddings = quantize_and_reduce_dim(
234+
embeddings=embeddings,
259235
quantize_to=quantize_to,
260236
dimensionality=dimensionality,
261-
from_sentence_transformers=True,
237+
)
238+
239+
return cls(
240+
embeddings,
241+
tokenizer,
242+
config,
243+
normalize=normalize,
244+
base_model_name=metadata.get("base_model"),
245+
language=metadata.get("language"),
262246
)
263247

264248
def encode_as_sequence(

model2vec/quantization.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,29 @@ def quantize_embeddings(embeddings: np.ndarray, quantize_to: DType) -> np.ndarra
3333
return quantized
3434
else:
3535
raise ValueError("Not a valid enum member of DType.")
36+
37+
38+
def quantize_and_reduce_dim(
39+
embeddings: np.ndarray, quantize_to: DType | str | None, dimensionality: int | None
40+
) -> np.ndarray:
41+
"""
42+
Quantize embeddings to a datatype and reduce dimensionality.
43+
44+
:param embeddings: The embeddings to quantize and reduce, as a numpy array.
45+
:param quantize_to: The data type to quantize to. If None, no quantization is performed.
46+
:param dimensionality: The number of dimensions to keep. If None, no dimensionality reduction is performed.
47+
:return: The quantized and reduced embeddings.
48+
:raises ValueError: If the passed dimensionality is not None and greater than the model dimensionality.
49+
"""
50+
if quantize_to is not None:
51+
quantize_to = DType(quantize_to)
52+
embeddings = quantize_embeddings(embeddings, quantize_to)
53+
54+
if dimensionality is not None:
55+
if dimensionality > embeddings.shape[1]:
56+
raise ValueError(
57+
f"Dimensionality {dimensionality} is greater than the model dimensionality {embeddings.shape[1]}"
58+
)
59+
embeddings = embeddings[:, :dimensionality]
60+
61+
return embeddings

0 commit comments

Comments (0)