Skip to content

Commit c2efbf1

Browse files
author
Han Wang
committed
fix(pt): wrap fparam/aparam reshape with descriptive ValueError
Match the dpmodel try/except pattern so shape mismatches produce a clear error instead of a raw RuntimeError from torch.view.
1 parent 6158d9c commit c2efbf1

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

deepmd/pt/model/task/fitting.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,11 @@ def _forward_common(
779779
assert fparam is not None, "fparam should not be None"
780780
assert self.fparam_avg is not None
781781
assert self.fparam_inv_std is not None
782+
if fparam.numel() != nf * self.numb_fparam:
783+
raise ValueError(
784+
f"input fparam: cannot reshape {list(fparam.shape)} "
785+
f"into ({nf}, {self.numb_fparam})."
786+
)
782787
fparam = fparam.view([nf, self.numb_fparam])
783788
nb, _ = fparam.shape
784789
t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb)
@@ -799,6 +804,11 @@ def _forward_common(
799804
assert aparam is not None, "aparam should not be None"
800805
assert self.aparam_avg is not None
801806
assert self.aparam_inv_std is not None
807+
if aparam.numel() % (nf * self.numb_aparam) != 0:
808+
raise ValueError(
809+
f"input aparam: cannot reshape {list(aparam.shape)} "
810+
f"into ({nf}, nloc, {self.numb_aparam})."
811+
)
802812
aparam = aparam.view([nf, -1, self.numb_aparam])
803813
nb, nloc, _ = aparam.shape
804814
t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc)

0 commit comments

Comments
 (0)