Skip to content

Commit 6e17ed1

Browse files
committed
revise ckpt func verbose and device from new model
1 parent b22531e commit 6e17ed1

1 file changed

Lines changed: 6 additions & 2 deletions

File tree

torch_molecule/utils/checkpoint.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def load_model_from_local(model_instance, path: str) -> None:
5353
except Exception as e:
5454
raise ValueError(f"Error loading model from {path}: {str(e)}")
5555

56-
verbose = model_instance.get_params().get("verbose", False)
56+
verbose = model_instance.get_params().get("verbose", 'none')
5757

5858
required_keys = {"model_state_dict", "hyperparameters", "model_name"}
5959
if not all(key in checkpoint for key in required_keys):
@@ -62,6 +62,8 @@ def load_model_from_local(model_instance, path: str) -> None:
6262

6363
parameter_status = []
6464
for key, new_value in checkpoint["hyperparameters"].items():
65+
if key in ['device']:
66+
continue
6567
if hasattr(model_instance, key):
6668
old_value = getattr(model_instance, key)
6769
is_changed = (old_value != new_value)
@@ -74,7 +76,7 @@ def load_model_from_local(model_instance, path: str) -> None:
7476
if is_changed:
7577
setattr(model_instance, key, new_value)
7678

77-
if parameter_status and verbose:
79+
if parameter_status and verbose != 'none':
7880
print("\nHyperparameter Status:")
7981
print("-" * 80)
8082
print(f"{'Parameter':<20} {'Old Value':<20} {'New Value':<20} {'Status':<10}")
@@ -142,6 +144,8 @@ def load_model_from_hf(model_instance, repo_id: str, path: str, config_filename:
142144

143145
parameter_status = []
144146
for key, new_value in checkpoint["hyperparameters"].items():
147+
if key in ['device']:
148+
continue
145149
if hasattr(model_instance, key):
146150
old_value = getattr(model_instance, key)
147151
is_changed = (old_value != new_value)

0 commit comments

Comments
 (0)