We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 250168b commit 8ab20b2Copy full SHA for 8ab20b2
1 file changed
deepmd/pt/infer/inference.py
@@ -73,4 +73,7 @@ def __init__(
73
self.wrapper = ModelWrapper(self.model) # inference only
74
if JIT:
75
self.wrapper = torch.jit.script(self.wrapper)
76
+ # Drop loss-related keys (e.g. loss buffers like XASLoss.e_ref) that
77
+ # are not part of the inference-only wrapper.
78
+ state_dict = {k: v for k, v in state_dict.items() if not k.startswith("loss.")}
79
self.wrapper.load_state_dict(state_dict)
0 commit comments