Skip to content

Commit 448baea

Browse files
committed
fix loss dp test
1 parent bd7a3cd commit 448baea

1 file changed

Lines changed: 8 additions & 1 deletion

File tree

deepmd/pt/infer/deep_eval.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,14 @@ def __init__(
168168
if not self.input_param.get("hessian_mode") and not no_jit:
169169
model = torch.jit.script(model)
170170
self.dp = ModelWrapper(model)
171-
self.dp.load_state_dict(state_dict)
171+
try:
172+
self.dp.load_state_dict(state_dict)
173+
except RuntimeError:
174+
state_dict_wo_loss = {"_extra_state": state_dict["_extra_state"]}
175+
for item in state_dict:
176+
if "loss.Default." not in item and "_extra_state" not in item:
177+
state_dict_wo_loss[item] = state_dict[item].clone()
178+
self.dp.load_state_dict(state_dict_wo_loss)
172179
elif str(self.model_path).endswith(".pth"):
173180
model = torch.jit.load(model_file, map_location=env.DEVICE)
174181
self.dp = ModelWrapper(model)

0 commit comments

Comments
 (0)