Skip to content

Commit 6ae50db

Browse files
author
Han Wang
committed
fix(pt): wrap fparam/aparam reshape with descriptive ValueError
Match the dpmodel try/except pattern so shape mismatches produce a clear error instead of a raw RuntimeError from torch.view.
1 parent 6158d9c commit 6ae50db

File tree

1 file changed

+14
-2
lines changed

1 file changed

+14
-2
lines changed

deepmd/pt/model/task/fitting.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -779,7 +779,13 @@ def _forward_common(
779779
assert fparam is not None, "fparam should not be None"
780780
assert self.fparam_avg is not None
781781
assert self.fparam_inv_std is not None
782-
fparam = fparam.view([nf, self.numb_fparam])
782+
try:
783+
fparam = fparam.view([nf, self.numb_fparam])
784+
except RuntimeError as e:
785+
raise ValueError(
786+
f"input fparam: cannot reshape {list(fparam.shape)} "
787+
f"into ({nf}, {self.numb_fparam})."
788+
) from e
783789
nb, _ = fparam.shape
784790
t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb)
785791
t_fparam_inv_std = self._extend_f_avg_std(self.fparam_inv_std, nb)
@@ -799,7 +805,13 @@ def _forward_common(
799805
assert aparam is not None, "aparam should not be None"
800806
assert self.aparam_avg is not None
801807
assert self.aparam_inv_std is not None
802-
aparam = aparam.view([nf, -1, self.numb_aparam])
808+
try:
809+
aparam = aparam.view([nf, -1, self.numb_aparam])
810+
except RuntimeError as e:
811+
raise ValueError(
812+
f"input aparam: cannot reshape {list(aparam.shape)} "
813+
f"into ({nf}, nloc, {self.numb_aparam})."
814+
) from e
803815
nb, nloc, _ = aparam.shape
804816
t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc)
805817
t_aparam_inv_std = self._extend_a_avg_std(self.aparam_inv_std, nb, nloc)

0 commit comments

Comments
 (0)