Skip to content

Commit 3d7348e

Browse files
committed
refine save and load funcs
1 parent 5c0b7db commit 3d7348e

5 files changed

Lines changed: 48 additions & 32 deletions

File tree

docs/source/api/encoder.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ inherited from :class:`torch_molecule.base.base.BaseModel`
1818

1919
- ``save_to_local(path)``: Save the trained model to a local file
2020
- ``load_from_local(path)``: Load a trained model from a local file
21-
- ``push_to_huggingface(repo_id)``: Push the model to Hugging Face Hub
22-
- ``load_from_huggingface(repo_id)``: Load a model from Hugging Face Hub
23-
- ``save``: Save the model to either local storage or Hugging Face
24-
- ``load``: Load a model from either local storage or Hugging Face
21+
- ``save_to_hf(repo_id)``: Push the model to Hugging Face Hub
22+
- ``load_from_hf(repo_id, local_cache)``: Load a model from Hugging Face Hub and save it to a local file
23+
- ``save(path, repo_id)``: Save the model to either local storage or Hugging Face
24+
- ``load(path, repo_id)``: Load a model from either local storage or Hugging Face
2525

2626

2727
Self-supervised Molecular Representation Learning

docs/source/api/generator.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,18 @@ inherited from :class:`torch_molecule.base.base.BaseModel`
1414

1515
- ``save_to_local(path)``: Save the trained model to a local file
1616
- ``load_from_local(path)``: Load a trained model from a local file
17-
- ``push_to_huggingface(repo_id)``: Push the model to Hugging Face Hub
17+
- ``save_to_hf(repo_id)``: Push the model to Hugging Face Hub
1818

1919
Not implemented for:
2020
- :class:`torch_molecule.generator.graph_ga.modeling_graph_ga.GraphGAMolecularGenerator`
2121

22-
- ``load_from_huggingface(repo_id)``: Load a model from Hugging Face Hub
22+
- ``load_from_hf(repo_id, local_cache)``: Load a model from Hugging Face Hub and save it to a local file
2323

2424
Not implemented for:
2525
- :class:`torch_molecule.generator.graph_ga.modeling_graph_ga.GraphGAMolecularGenerator`
2626

27-
- ``save``: Save the model to either local storage or Hugging Face
28-
- ``load``: Load a model from either local storage or Hugging Face
27+
- ``save(path, repo_id)``: Save the model to either local storage or Hugging Face
28+
- ``load(path, repo_id)``: Load a model from either local storage or Hugging Face
2929

3030
Modeling Molecules as Graphs with GNN / Transformer-based Generators
3131
---------------------------------------------------------------------

docs/source/api/predictor.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ inherited from :class:`torch_molecule.base.base.BaseModel`
1515

1616
- ``save_to_local(path)``: Save the trained model to a local file
1717
- ``load_from_local(path)``: Load a trained model from a local file
18-
- ``push_to_huggingface(repo_id)``: Push the model to Hugging Face Hub
19-
- ``load_from_huggingface(repo_id)``: Load a model from Hugging Face Hub
20-
- ``save``: Save the model to either local storage or Hugging Face
21-
- ``load``: Load a model from either local storage or Hugging Face
18+
- ``save_to_hf(repo_id)``: Push the model to Hugging Face Hub
19+
- ``load_from_hf(repo_id, local_cache)``: Load a model from Hugging Face Hub and save it to a local file
20+
- ``save(path, repo_id)``: Save the model to either local storage or Hugging Face
21+
- ``load(path, repo_id)``: Load a model from either local storage or Hugging Face
2222

2323
Modeling Molecules as Graphs with Graph Neural Networks
2424
-------------------------------------------------------

torch_molecule/base/base.py

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def save_to_hf(
225225
commit_message: str = "Update model",
226226
hf_token: Optional[str] = None,
227227
private: bool = False,
228+
config_filename: Optional[str] = 'config.json',
228229
) -> None:
229230
"""Save model to Hugging Face Hub.
230231
@@ -244,7 +245,8 @@ def save_to_hf(
244245
Hugging Face authentication token
245246
private : bool, default=False
246247
Whether the repository should be private
247-
248+
config_filename : Optional[str], default='config.json'
249+
Name of the configuration file to save to the repository
248250
Raises
249251
------
250252
ValueError
@@ -261,19 +263,24 @@ def save_to_hf(
261263
commit_message=commit_message,
262264
token=hf_token,
263265
private=private,
266+
config_filename=config_filename,
264267
)
265268

266-
def load_from_hf(self, repo_id: str, path: str) -> None:
269+
def load_from_hf(self, repo_id: str, local_cache: Optional[str] = None, config_filename: Optional[str] = 'config.json') -> None:
267270
"""Load model from Hugging Face Hub.
268271
269272
Parameters
270273
----------
271274
repo_id : str
272275
Hugging Face repository ID
273-
path : str
274-
Path within the repository to load the model from
276+
local_cache : str, default=None
277+
Local path to save the model
278+
config_filename : str, default='config.json'
279+
Name of the configuration file to load from the repository
275280
"""
276-
HuggingFaceCheckpointManager.load_model_from_hf(self, repo_id, path)
281+
if local_cache is None:
282+
local_cache = 'model.pt'
283+
HuggingFaceCheckpointManager.load_model_from_hf(self, repo_id, local_cache, config_filename)
277284

278285
def save(self, path: Optional[str] = None, repo_id: Optional[str] = None, **kwargs) -> None:
279286
"""Automatic save to either local disk or Hugging Face Hub.
@@ -292,34 +299,42 @@ def save(self, path: Optional[str] = None, repo_id: Optional[str] = None, **kwar
292299
ValueError
293300
If path is None when repo_id is None
294301
"""
302+
# if both path and repo_id are None, raise an error
303+
if path is None and repo_id is None:
304+
raise ValueError("path must be provided if repo_id is not given.")
305+
295306
if repo_id is not None:
296307
self.save_to_hf(repo_id=repo_id, **kwargs)
297-
else:
298-
if path is None:
299-
raise ValueError("path must be provided if repo_id is not given.")
308+
309+
if path is not None:
300310
self.save_to_local(path)
301311

302-
def load(self, path: str, repo_id: Optional[str] = None) -> None:
312+
def load(self, path: Optional[str] = None, repo_id: Optional[str] = None, **kwargs) -> None:
303313
"""Automatic load from either local disk or Hugging Face Hub.
304314
305315
Parameters
306316
----------
307-
path : str
308-
File path for local loading or path within the repository
317+
path : Optional[str], default=None
318+
File path for local loading.
309319
repo_id : Optional[str], default=None
310-
Hugging Face repository ID for remote loading
320+
Hugging Face repository ID for remote loading. If path is provided, repo_id is ignored.
321+
**kwargs
322+
Additional arguments passed to load_from_hf
311323
312324
Raises
313325
------
314326
FileNotFoundError
315327
If no local file is found and no repo_id is provided
316328
"""
317-
if os.path.exists(path):
318-
self.load_from_local(path)
329+
if path is not None:
330+
if os.path.exists(path):
331+
self.load_from_local(path)
332+
else:
333+
raise FileNotFoundError(f"No local file found at '{path}'.")
319334
else:
320335
if repo_id is None:
321-
raise FileNotFoundError(f"No local file found at '{path}' and no repo_id provided.")
322-
self.load_from_hf(repo_id, path)
336+
raise ValueError("repo_id must be provided if path is not given.")
337+
self.load_from_hf(repo_id, **kwargs)
323338

324339
def _check_is_fitted(self) -> None:
325340
"""Check if the model is fitted.

torch_molecule/utils/checkpoint.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ class HuggingFaceCheckpointManager:
107107
"""Handles saving and loading of models to and from the Hugging Face Hub."""
108108

109109
@staticmethod
110-
def load_model_from_hf(model_instance, repo_id: str, path: str) -> None:
110+
def load_model_from_hf(model_instance, repo_id: str, path: str, config_filename: str = "config.json") -> None:
111111
"""Load model from Hugging Face Hub, saving locally to `path` first."""
112112
try:
113113
from huggingface_hub import hf_hub_download
@@ -129,7 +129,7 @@ def load_model_from_hf(model_instance, repo_id: str, path: str) -> None:
129129

130130
hf_hub_download(
131131
repo_id=repo_id,
132-
filename="config.json",
132+
filename=config_filename,
133133
local_dir=os.path.dirname(path),
134134
)
135135

@@ -196,6 +196,7 @@ def push_to_huggingface(
196196
commit_message: str = "Update model",
197197
token: Optional[str] = None,
198198
private: bool = False,
199+
config_filename: str = "config.json",
199200
) -> None:
200201
"""Push a task-specific model checkpoint to Hugging Face Hub."""
201202
try:
@@ -252,8 +253,8 @@ def push_to_huggingface(
252253
num_params=num_params,
253254
)
254255

255-
# Save config.json
256-
config_path = os.path.join(tmp_dir, "config.json")
256+
# Save config file
257+
config_path = os.path.join(tmp_dir, config_filename)
257258
with open(config_path, "w") as f:
258259
json.dump(final_config, f, indent=2)
259260

0 commit comments

Comments
 (0)