66
77import click
88import huggingface_hub
9+ import huggingface_hub .errors
910import numpy as np
1011import safetensors
1112from huggingface_hub import ModelCard , ModelCardData
@@ -53,7 +54,7 @@ def save_pretrained(
5354 :param **kwargs: Any additional arguments.
5455 """
5556 folder_path .mkdir (exist_ok = True , parents = True )
56- save_file ({"embeddings" : embeddings }, folder_path / "embeddings .safetensors" )
57+ save_file ({"embeddings" : embeddings }, folder_path / "model .safetensors" )
5758 tokenizer .save (str (folder_path / "tokenizer.json" ))
5859 json .dump (config , open (folder_path / "config.json" , "w" ))
5960
@@ -115,9 +116,14 @@ def load_pretrained(
115116 """
116117 folder_or_repo_path = Path (folder_or_repo_path )
117118 if folder_or_repo_path .exists ():
118- embeddings_path = folder_or_repo_path / "embeddings .safetensors"
119+ embeddings_path = folder_or_repo_path / "model .safetensors"
119120 if not embeddings_path .exists ():
120- raise FileNotFoundError (f"Embeddings file does not exist in { folder_or_repo_path } " )
121+ old_embeddings_path = folder_or_repo_path / "embeddings.safetensors"
122+ if old_embeddings_path .exists ():
123+ logger .warning ("Old embeddings file found. Please rename to `model.safetensors` and re-save." )
124+ embeddings_path = old_embeddings_path
125+ else :
126+ raise FileNotFoundError (f"Embeddings file does not exist in { folder_or_repo_path } " )
121127
122128 config_path = folder_or_repo_path / "config.json"
123129 if not config_path .exists ():
@@ -129,9 +135,19 @@ def load_pretrained(
129135
130136 else :
131137 logger .info ("Folder does not exist locally, attempting to use huggingface hub." )
132- embeddings_path = huggingface_hub .hf_hub_download (
133- folder_or_repo_path .as_posix (), "embeddings.safetensors" , token = token
134- )
138+ try :
139+ embeddings_path = huggingface_hub .hf_hub_download (
140+ folder_or_repo_path .as_posix (), "model.safetensors" , token = token
141+ )
142+ except huggingface_hub .utils .EntryNotFoundError as e :
143+ try :
144+ embeddings_path = huggingface_hub .hf_hub_download (
145+ folder_or_repo_path .as_posix (), "embeddings.safetensors" , token = token
146+ )
147+ except huggingface_hub .utils .EntryNotFoundError :
148+ # Raise original exception.
149+ raise e
150+
135151 config_path = huggingface_hub .hf_hub_download (folder_or_repo_path .as_posix (), "config.json" , token = token )
136152 tokenizer_path = huggingface_hub .hf_hub_download (folder_or_repo_path .as_posix (), "tokenizer.json" , token = token )
137153
0 commit comments