Skip to content

Commit 669f732

Browse files
committed
fix: add force download and remove readme stuff
1 parent e666b4b commit 669f732

3 files changed

Lines changed: 1071 additions & 888 deletions

File tree

model2vec/hf_utils.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def load_pretrained(
100100
subfolder: str | None = None,
101101
token: str | None = None,
102102
from_sentence_transformers: bool = False,
103-
skip_metadata: bool = False,
103+
force_download: bool = False,
104104
) -> tuple[np.ndarray, Tokenizer, dict[str, Any], dict[str, Any]]:
105105
"""
106106
Loads a pretrained model from a folder.
@@ -111,7 +111,8 @@ def load_pretrained(
111111
:param subfolder: The subfolder to load from.
112112
:param token: The huggingface token to use.
113113
:param from_sentence_transformers: Whether to load the model from a sentence transformers model.
114-
:param skip_metadata: Whether to skip loading metadata. This is useful if you don't need the metadata.
114+
:param force_download: Whether to force the download of the model. If False, the model is only downloaded if it is not
115+
already present in the cache.
115116
:raises: FileNotFoundError if the folder exists, but the file does not exist locally.
116117
:return: The embeddings, tokenizer, config, and metadata.
117118
@@ -125,7 +126,8 @@ def load_pretrained(
125126
tokenizer_file = "tokenizer.json"
126127
config_name = "config.json"
127128

128-
if cached_folder := _get_latest_model_path(str(folder_or_repo_path)):
129+
cached_folder = _get_latest_model_path(str(folder_or_repo_path))
130+
if cached_folder and not force_download:
129131
logger.info(f"Found cached model at {cached_folder}, loading from cache.")
130132
folder_or_repo_path = cached_folder
131133
else:
@@ -177,7 +179,7 @@ def load_pretrained(
177179
embedding_key = "embedding.weight" if from_sentence_transformers else "embeddings"
178180
embeddings = opened_tensor_file.get_tensor(embedding_key)
179181

180-
if not skip_metadata and readme_path.exists():
182+
if readme_path.exists():
181183
metadata = _get_metadata_from_readme(readme_path)
182184
else:
183185
metadata = {}

model2vec/model.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ def from_pretrained(
156156
subfolder: str | None = None,
157157
quantize_to: str | DType | None = None,
158158
dimensionality: int | None = None,
159-
skip_metadata: bool = False,
159+
force_download: bool = False,
160160
) -> StaticModel:
161161
"""
162162
Load a StaticModel from a local path or huggingface hub path.
@@ -172,8 +172,8 @@ def from_pretrained(
172172
:param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
173173
This is useful if you want to load a model with a lower dimensionality.
174174
Note that this only applies if you have trained your model using mrl or PCA.
175-
:param skip_metadata: Whether to skip loading metadata. This is useful if you don't need the metadata.
176-
Loading metadata can be slow for models with lots of results in the README.md
175+
:param force_download: Whether to force the download of the model. If False, the model is only downloaded if it is not
176+
already present in the cache.
177177
:return: A StaticModel.
178178
"""
179179
from model2vec.hf_utils import load_pretrained
@@ -183,7 +183,7 @@ def from_pretrained(
183183
token=token,
184184
from_sentence_transformers=False,
185185
subfolder=subfolder,
186-
skip_metadata=skip_metadata,
186+
force_download=force_download,
187187
)
188188

189189
embeddings = quantize_and_reduce_dim(
@@ -209,7 +209,7 @@ def from_sentence_transformers(
209209
normalize: bool | None = None,
210210
quantize_to: str | DType | None = None,
211211
dimensionality: int | None = None,
212-
skip_metadata: bool = False,
212+
force_download: bool = False,
213213
) -> StaticModel:
214214
"""
215215
Load a StaticModel trained with sentence transformers from a local path or huggingface hub path.
@@ -224,8 +224,8 @@ def from_sentence_transformers(
224224
:param dimensionality: The dimensionality of the model. If this is None, use the dimensionality of the model.
225225
This is useful if you want to load a model with a lower dimensionality.
226226
Note that this only applies if you have trained your model using mrl or PCA.
227-
:param skip_metadata: Whether to skip loading metadata. This is useful if you don't need the metadata.
228-
Loading metadata can be slow for models with lots of results in the README.md
227+
:param force_download: Whether to force the download of the model. If False, the model is only downloaded if it is not
228+
already present in the cache.
229229
:return: A StaticModel.
230230
"""
231231
from model2vec.hf_utils import load_pretrained
@@ -235,7 +235,7 @@ def from_sentence_transformers(
235235
token=token,
236236
from_sentence_transformers=True,
237237
subfolder=None,
238-
skip_metadata=skip_metadata,
238+
force_download=force_download,
239239
)
240240

241241
embeddings = quantize_and_reduce_dim(

0 commit comments

Comments
 (0)