@@ -225,6 +225,7 @@ def save_to_hf(
225225 commit_message : str = "Update model" ,
226226 hf_token : Optional [str ] = None ,
227227 private : bool = False ,
228+ config_filename : Optional [str ] = 'config.json' ,
228229 ) -> None :
229230 """Save model to Hugging Face Hub.
230231
@@ -244,7 +245,8 @@ def save_to_hf(
244245 Hugging Face authentication token
245246 private : bool, default=False
246247 Whether the repository should be private
247-
248+ config_filename : Optional[str], default='config.json'
249+ Name of the configuration file to save to the repository
248250 Raises
249251 ------
250252 ValueError
@@ -261,19 +263,24 @@ def save_to_hf(
261263 commit_message = commit_message ,
262264 token = hf_token ,
263265 private = private ,
266+ config_filename = config_filename ,
264267 )
265268
266- def load_from_hf (self , repo_id : str , path : str ) -> None :
269+ def load_from_hf (self , repo_id : str , local_cache : Optional [ str ] = None , config_filename : Optional [ str ] = 'config.json' ) -> None :
267270 """Load model from Hugging Face Hub.
268271
269272 Parameters
270273 ----------
271274 repo_id : str
272275 Hugging Face repository ID
273- path : str
274- Path within the repository to load the model from
276+ local_cache : str, default=None
277+ Local path to save the model
278+ config_filename : str, default='config.json'
279+ Name of the configuration file to load from the repository
275280 """
276- HuggingFaceCheckpointManager .load_model_from_hf (self , repo_id , path )
281+ if local_cache is None :
282+ local_cache = 'model.pt'
283+ HuggingFaceCheckpointManager .load_model_from_hf (self , repo_id , local_cache , config_filename )
277284
278285 def save (self , path : Optional [str ] = None , repo_id : Optional [str ] = None , ** kwargs ) -> None :
279286 """Automatic save to either local disk or Hugging Face Hub.
@@ -292,34 +299,42 @@ def save(self, path: Optional[str] = None, repo_id: Optional[str] = None, **kwar
292299 ValueError
293300 If path is None when repo_id is None
294301 """
302+ # if both path and repo_id are None, raise an error
303+ if path is None and repo_id is None :
304+ raise ValueError ("path must be provided if repo_id is not given." )
305+
295306 if repo_id is not None :
296307 self .save_to_hf (repo_id = repo_id , ** kwargs )
297- else :
298- if path is None :
299- raise ValueError ("path must be provided if repo_id is not given." )
308+
309+ if path is not None :
300310 self .save_to_local (path )
301311
302- def load (self , path : str , repo_id : Optional [str ] = None ) -> None :
312+ def load (self , path : Optional [ str ] = None , repo_id : Optional [str ] = None , ** kwargs ) -> None :
303313 """Automatic load from either local disk or Hugging Face Hub.
304314
305315 Parameters
306316 ----------
307- path : str
308- File path for local loading or path within the repository
317+ path : Optional[ str], default=None
318+ File path for local loading.
309319 repo_id : Optional[str], default=None
310- Hugging Face repository ID for remote loading
320+ Hugging Face repository ID for remote loading. If path is provided, repo_id is ignored.
321+ **kwargs
322+ Additional arguments passed to load_from_hf
311323
312324 Raises
313325 ------
314326 FileNotFoundError
315327 If no local file is found and no repo_id is provided
316328 """
317- if os .path .exists (path ):
318- self .load_from_local (path )
329+ if path is not None :
330+ if os .path .exists (path ):
331+ self .load_from_local (path )
332+ else :
333+ raise FileNotFoundError (f"No local file found at '{ path } '." )
319334 else :
320335 if repo_id is None :
321- raise FileNotFoundError ( f"No local file found at ' { path } ' and no repo_id provided ." )
322- self .load_from_hf (repo_id , path )
336+ raise ValueError ( "repo_id must be provided if path is not given ." )
337+ self .load_from_hf (repo_id , ** kwargs )
323338
324339 def _check_is_fitted (self ) -> None :
325340 """Check if the model is fitted.
0 commit comments