Skip to content

Commit 8ab20b2

Browse files
committed
fix: filter loss-related keys from state dict in inference and ignore tests directory
1 parent 250168b commit 8ab20b2

1 file changed

Lines changed: 3 additions & 0 deletions

File tree

deepmd/pt/infer/inference.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,7 @@ def __init__(
7373
self.wrapper = ModelWrapper(self.model) # inference only
7474
if JIT:
7575
self.wrapper = torch.jit.script(self.wrapper)
76+
# Drop loss-related keys (e.g. loss buffers like XASLoss.e_ref) that
77+
# are not part of the inference-only wrapper.
78+
state_dict = {k: v for k, v in state_dict.items() if not k.startswith("loss.")}
7679
self.wrapper.load_state_dict(state_dict)

0 commit comments

Comments
 (0)