Skip to content

Commit 62cba14

Browse files
authored
Add support for safetensors (#36)
1 parent c5393ad commit 62cba14

1 file changed

Lines changed: 22 additions & 6 deletions

File tree

model2vec/utils.py

Lines changed: 22 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66

77
import click
88
import huggingface_hub
9+
import huggingface_hub.errors
910
import numpy as np
1011
import safetensors
1112
from huggingface_hub import ModelCard, ModelCardData
@@ -53,7 +54,7 @@ def save_pretrained(
5354
:param **kwargs: Any additional arguments.
5455
"""
5556
folder_path.mkdir(exist_ok=True, parents=True)
56-
save_file({"embeddings": embeddings}, folder_path / "embeddings.safetensors")
57+
save_file({"embeddings": embeddings}, folder_path / "model.safetensors")
5758
tokenizer.save(str(folder_path / "tokenizer.json"))
5859
json.dump(config, open(folder_path / "config.json", "w"))
5960

@@ -115,9 +116,14 @@ def load_pretrained(
115116
"""
116117
folder_or_repo_path = Path(folder_or_repo_path)
117118
if folder_or_repo_path.exists():
118-
embeddings_path = folder_or_repo_path / "embeddings.safetensors"
119+
embeddings_path = folder_or_repo_path / "model.safetensors"
119120
if not embeddings_path.exists():
120-
raise FileNotFoundError(f"Embeddings file does not exist in {folder_or_repo_path}")
121+
old_embeddings_path = folder_or_repo_path / "embeddings.safetensors"
122+
if old_embeddings_path.exists():
123+
logger.warning("Old embeddings file found. Please rename to `model.safetensors` and re-save.")
124+
embeddings_path = old_embeddings_path
125+
else:
126+
raise FileNotFoundError(f"Embeddings file does not exist in {folder_or_repo_path}")
121127

122128
config_path = folder_or_repo_path / "config.json"
123129
if not config_path.exists():
@@ -129,9 +135,19 @@ def load_pretrained(
129135

130136
else:
131137
logger.info("Folder does not exist locally, attempting to use huggingface hub.")
132-
embeddings_path = huggingface_hub.hf_hub_download(
133-
folder_or_repo_path.as_posix(), "embeddings.safetensors", token=token
134-
)
138+
try:
139+
embeddings_path = huggingface_hub.hf_hub_download(
140+
folder_or_repo_path.as_posix(), "model.safetensors", token=token
141+
)
142+
except huggingface_hub.utils.EntryNotFoundError as e:
143+
try:
144+
embeddings_path = huggingface_hub.hf_hub_download(
145+
folder_or_repo_path.as_posix(), "embeddings.safetensors", token=token
146+
)
147+
except huggingface_hub.utils.EntryNotFoundError:
148+
# Raise original exception.
149+
raise e
150+
135151
config_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "config.json", token=token)
136152
tokenizer_path = huggingface_hub.hf_hub_download(folder_or_repo_path.as_posix(), "tokenizer.json", token=token)
137153

0 commit comments

Comments
 (0)