We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent bd7a3cd commit 448baeaCopy full SHA for 448baea
1 file changed
deepmd/pt/infer/deep_eval.py
@@ -168,7 +168,14 @@ def __init__(
168
if not self.input_param.get("hessian_mode") and not no_jit:
169
model = torch.jit.script(model)
170
self.dp = ModelWrapper(model)
171
- self.dp.load_state_dict(state_dict)
+ try:
172
+ self.dp.load_state_dict(state_dict)
173
+ except RuntimeError:
174
+ state_dict_wo_loss = {"_extra_state": state_dict["_extra_state"]}
175
+ for item in state_dict:
176
+ if "loss.Default." not in item and "_extra_state" not in item:
177
+ state_dict_wo_loss[item] = state_dict[item].clone()
178
+ self.dp.load_state_dict(state_dict_wo_loss)
179
elif str(self.model_path).endswith(".pth"):
180
model = torch.jit.load(model_file, map_location=env.DEVICE)
181
0 commit comments