99import numpy as np
1010import safetensors
1111from huggingface_hub import ModelCard , ModelCardData
12+ from huggingface_hub .constants import HF_HUB_CACHE
1213from safetensors .numpy import save_file
1314from tokenizers import Tokenizer
1415
@@ -107,9 +108,10 @@ def _create_model_card(
107108
108109def load_pretrained (
109110 folder_or_repo_path : str | Path ,
110- subfolder : str | None = None ,
111- token : str | None = None ,
112- from_sentence_transformers : bool = False ,
111+ subfolder : str | None ,
112+ token : str | None ,
113+ from_sentence_transformers : bool ,
114+ force_download : bool ,
113115) -> tuple [np .ndarray , Tokenizer , dict [str , Any ], dict [str , Any ], np .ndarray | None , np .ndarray | None ]:
114116 """
115117 Loads a pretrained model from a folder.
@@ -120,8 +122,10 @@ def load_pretrained(
120122 :param subfolder: The subfolder to load from.
121123 :param token: The huggingface token to use.
122124 :param from_sentence_transformers: Whether to load the model from a sentence transformers model.
125+ :param force_download: Whether to force the download of the model. If False, the model is only downloaded if it is not
126+ already present in the cache.
123127 :raises: FileNotFoundError if the folder exists, but the file does not exist locally.
124- :return: The embeddings, tokenizer, config, and metadata .
128+ :return: The embeddings, tokenizer, config, metadata, weights and mapping .
125129
126130 """
127131 if from_sentence_transformers :
@@ -133,7 +137,13 @@ def load_pretrained(
133137 tokenizer_file = "tokenizer.json"
134138 config_name = "config.json"
135139
136- folder_or_repo_path = Path (folder_or_repo_path )
140+ cached_folder = _get_latest_model_path (str (folder_or_repo_path ))
141+ if cached_folder and not force_download :
142+ logger .info (f"Found cached model at { cached_folder } , loading from cache." )
143+ folder_or_repo_path = cached_folder
144+ else :
145+ logger .info (f"No cached model found for { folder_or_repo_path } , loading from local or hub." )
146+ folder_or_repo_path = Path (folder_or_repo_path )
137147
138148 local_folder = folder_or_repo_path / subfolder if subfolder else folder_or_repo_path
139149
@@ -150,9 +160,7 @@ def load_pretrained(
150160 if not tokenizer_path .exists ():
151161 raise FileNotFoundError (f"Tokenizer file does not exist in { local_folder } " )
152162
153- # README is optional, so this is a bit finicky.
154163 readme_path = local_folder / "README.md"
155- metadata = _get_metadata_from_readme (readme_path )
156164
157165 else :
158166 logger .info ("Folder does not exist locally, attempting to use huggingface hub." )
@@ -161,18 +169,11 @@ def load_pretrained(
161169 folder_or_repo_path .as_posix (), model_file , token = token , subfolder = subfolder
162170 )
163171 )
164-
165- try :
166- readme_path = Path (
167- huggingface_hub .hf_hub_download (
168- folder_or_repo_path .as_posix (), "README.md" , token = token , subfolder = subfolder
169- )
172+ readme_path = Path (
173+ huggingface_hub .hf_hub_download (
174+ folder_or_repo_path .as_posix (), "README.md" , token = token , subfolder = subfolder
170175 )
171- metadata = _get_metadata_from_readme (Path (readme_path ))
172- except Exception as e :
173- # NOTE: we don't want to raise an error here, since the README is optional.
174- logger .info (f"No README found in the model folder: { e } No model card loaded." )
175- metadata = {}
176+ )
176177
177178 config_path = Path (
178179 huggingface_hub .hf_hub_download (
@@ -186,21 +187,22 @@ def load_pretrained(
186187 )
187188
188189 opened_tensor_file = cast (SafeOpenProtocol , safetensors .safe_open (embeddings_path , framework = "numpy" ))
189- if from_sentence_transformers :
190- embeddings = opened_tensor_file .get_tensor ("embedding.weight" )
190+ embedding_name = "embedding.weight" if from_sentence_transformers else "embeddings"
191+ embeddings = opened_tensor_file .get_tensor (embedding_name )
192+ try :
193+ weights = opened_tensor_file .get_tensor ("weights" )
194+ except Exception :
195+ # Bare except because safetensors does not export its own errors.
191196 weights = None
197+ try :
198+ mapping = opened_tensor_file .get_tensor ("mapping" )
199+ except Exception :
192200 mapping = None
201+
202+ if readme_path .exists ():
203+ metadata = _get_metadata_from_readme (readme_path )
193204 else :
194- embeddings = opened_tensor_file .get_tensor ("embeddings" )
195- try :
196- weights = opened_tensor_file .get_tensor ("weights" )
197- except Exception :
198- # Bare except because safetensors does not export its own errors.
199- weights = None
200- try :
201- mapping = opened_tensor_file .get_tensor ("mapping" )
202- except Exception :
203- mapping = None
205+ metadata = {}
204206
205207 tokenizer : Tokenizer = Tokenizer .from_file (str (tokenizer_path ))
206208 config = json .load (open (config_path ))
@@ -240,3 +242,28 @@ def push_folder_to_hub(
240242 huggingface_hub .upload_folder (repo_id = repo_id , folder_path = folder_path , token = token , path_in_repo = subfolder )
241243
242244 logger .info (f"Pushed model to { repo_id } " )
245+
246+
247+ def _get_latest_model_path (model_id : str ) -> Path | None :
248+ """
249+ Gets the latest model path for a given identifier from the hugging face hub cache.
250+
251+ Returns None if there is no cached model. In this case, the model will be downloaded.
252+ """
253+ # Make path object
254+ cache_dir = Path (HF_HUB_CACHE )
255+ # This is specific to how HF stores the files.
256+ normalized = model_id .replace ("/" , "--" )
257+ repo_dir = cache_dir / f"models--{ normalized } " / "snapshots"
258+
259+ if not repo_dir .exists ():
260+ return None
261+
262+ # Find all directories.
263+ snapshots = [p for p in repo_dir .iterdir () if p .is_dir ()]
264+ if not snapshots :
265+ return None
266+
267+ # Get the latest directory by modification time.
268+ latest_snapshot = max (snapshots , key = lambda p : p .stat ().st_mtime )
269+ return latest_snapshot
0 commit comments