99import numpy as np
1010import safetensors
1111from huggingface_hub import ModelCard , ModelCardData
12+ from huggingface_hub .constants import HF_HUB_CACHE
1213from safetensors .numpy import save_file
1314from tokenizers import Tokenizer
1415
@@ -99,6 +100,7 @@ def load_pretrained(
99100 subfolder : str | None = None ,
100101 token : str | None = None ,
101102 from_sentence_transformers : bool = False ,
103+ skip_metadata : bool = False ,
102104) -> tuple [np .ndarray , Tokenizer , dict [str , Any ], dict [str , Any ]]:
103105 """
104106 Loads a pretrained model from a folder.
@@ -109,6 +111,7 @@ def load_pretrained(
109111 :param subfolder: The subfolder to load from.
110112 :param token: The huggingface token to use.
111113 :param from_sentence_transformers: Whether to load the model from a sentence transformers model.
114+ :param skip_metadata: Whether to skip loading metadata. This is useful if you don't need the metadata.
112115 :raises: FileNotFoundError if the folder exists, but the file does not exist locally.
113116 :return: The embeddings, tokenizer, config, and metadata.
114117
@@ -122,7 +125,12 @@ def load_pretrained(
122125 tokenizer_file = "tokenizer.json"
123126 config_name = "config.json"
124127
125- folder_or_repo_path = Path (folder_or_repo_path )
128+ if cached_folder := _get_latest_model_path (str (folder_or_repo_path )):
129+ logger .info (f"Found cached model at { cached_folder } , loading from cache." )
130+ folder_or_repo_path = cached_folder
131+ else :
132+ logger .info (f"No cached model found for { folder_or_repo_path } , loading from local or hub." )
133+ folder_or_repo_path = Path (folder_or_repo_path )
126134
127135 local_folder = folder_or_repo_path / subfolder if subfolder else folder_or_repo_path
128136
@@ -139,9 +147,7 @@ def load_pretrained(
139147 if not tokenizer_path .exists ():
140148 raise FileNotFoundError (f"Tokenizer file does not exist in { local_folder } " )
141149
142- # README is optional, so this is a bit finicky.
143150 readme_path = local_folder / "README.md"
144- metadata = _get_metadata_from_readme (readme_path )
145151
146152 else :
147153 logger .info ("Folder does not exist locally, attempting to use huggingface hub." )
@@ -150,18 +156,11 @@ def load_pretrained(
150156 folder_or_repo_path .as_posix (), model_file , token = token , subfolder = subfolder
151157 )
152158 )
153-
154- try :
155- readme_path = Path (
156- huggingface_hub .hf_hub_download (
157- folder_or_repo_path .as_posix (), "README.md" , token = token , subfolder = subfolder
158- )
159+ readme_path = Path (
160+ huggingface_hub .hf_hub_download (
161+ folder_or_repo_path .as_posix (), "README.md" , token = token , subfolder = subfolder
159162 )
160- metadata = _get_metadata_from_readme (Path (readme_path ))
161- except Exception as e :
162- # NOTE: we don't want to raise an error here, since the README is optional.
163- logger .info (f"No README found in the model folder: { e } No model card loaded." )
164- metadata = {}
163+ )
165164
166165 config_path = Path (
167166 huggingface_hub .hf_hub_download (
@@ -175,10 +174,13 @@ def load_pretrained(
175174 )
176175
177176 opened_tensor_file = cast (SafeOpenProtocol , safetensors .safe_open (embeddings_path , framework = "numpy" ))
178- if from_sentence_transformers :
179- embeddings = opened_tensor_file .get_tensor ("embedding.weight" )
177+ embedding_key = "embedding.weight" if from_sentence_transformers else "embeddings"
178+ embeddings = opened_tensor_file .get_tensor (embedding_key )
179+
180+ if not skip_metadata and readme_path .exists ():
181+ metadata = _get_metadata_from_readme (readme_path )
180182 else :
181- embeddings = opened_tensor_file . get_tensor ( "embeddings" )
183+ metadata = {}
182184
183185 tokenizer : Tokenizer = Tokenizer .from_file (str (tokenizer_path ))
184186 config = json .load (open (config_path ))
@@ -223,3 +225,28 @@ def push_folder_to_hub(
223225 huggingface_hub .upload_folder (repo_id = repo_id , folder_path = folder_path , token = token , path_in_repo = subfolder )
224226
225227 logger .info (f"Pushed model to { repo_id } " )
228+
229+
def _get_latest_model_path(model_id: str) -> Path | None:
    """
    Gets the latest model path for a given identifier from the hugging face hub cache.

    The hub cache stores a repository under `models--{org}--{name}`, with one directory per
    downloaded revision below `snapshots` and the commit hash of each branch below `refs`.
    The snapshot referenced by `refs/main` is preferred, because directory modification times
    change whenever a file is added to an older revision. If that reference is missing or its
    snapshot is not cached, the most recently modified snapshot is used instead.

    :param model_id: The model identifier, e.g. "organization/model-name".
    :return: The path to the cached snapshot, or None if there is no cached model.
        In this case, the model will be downloaded.
    """
    # This is specific to how HF stores the files.
    normalized = model_id.replace("/", "--")
    repo_dir = Path(HF_HUB_CACHE) / f"models--{normalized}"
    snapshots_dir = repo_dir / "snapshots"

    # `is_dir` instead of `exists`: a stray file with this name must not reach `iterdir`.
    if not snapshots_dir.is_dir():
        return None

    # Prefer the revision the default branch points to.
    try:
        main_revision = (repo_dir / "refs" / "main").read_text(encoding="utf-8").strip()
    except (OSError, UnicodeDecodeError):
        # No readable reference; fall back to modification times below.
        main_revision = ""

    if main_revision:
        main_snapshot = snapshots_dir / main_revision
        if main_snapshot.is_dir():
            return main_snapshot

    # Find all snapshot directories.
    snapshots = [path for path in snapshots_dir.iterdir() if path.is_dir()]
    if not snapshots:
        return None

    # Fall back to the latest directory by modification time.
    return max(snapshots, key=lambda path: path.stat().st_mtime)
0 commit comments