Skip to content

Commit d057ca1

Browse files
author
Han Wang
committed
fix test
1 parent f5171f2 commit d057ca1

2 files changed

Lines changed: 14 additions & 6 deletions

File tree

source/tests/common/dpmodel/test_fitting_invar_fitting.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -142,19 +142,22 @@ def test_self_exception(
142142
iap = None
143143
with self.assertRaises(ValueError) as context:
144144
ret0 = ifn0(dd[0][:, :, :-2], atype, fparam=ifp, aparam=iap)
145-
self.assertIn("input descriptor", context.exception)
145+
self.assertIn("input descriptor", str(context.exception))
146146

147147
if nfp > 0:
148148
ifp = rng.normal(size=(self.nf, nfp - 1))
149149
with self.assertRaises(ValueError) as context:
150150
ret0 = ifn0(dd[0], atype, fparam=ifp, aparam=iap)
151-
self.assertIn("input fparam", context.exception)
151+
self.assertIn("input fparam", str(context.exception))
152152

153153
if nap > 0:
154+
# restore correct ifp before testing aparam
155+
if nfp > 0:
156+
ifp = rng.normal(size=(self.nf, nfp))
154157
iap = rng.normal(size=(self.nf, self.nloc, nap - 1))
155158
with self.assertRaises(ValueError) as context:
156159
ifn0(dd[0], atype, fparam=ifp, aparam=iap)
157-
self.assertIn("input aparam", context.exception)
160+
self.assertIn("input aparam", str(context.exception))
158161

159162
def test_get_set(self) -> None:
160163
ifn0 = InvarFitting(

source/tests/pt_expt/fitting/test_fitting_invar_fitting.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def test_self_exception(
172172
fparam=ifp,
173173
aparam=iap,
174174
)
175-
self.assertIn("input descriptor", str(context.exception))
175+
self.assertIn("input descriptor", str(context.exception))
176176

177177
if nfp > 0:
178178
ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp - 1))).to(
@@ -185,9 +185,14 @@ def test_self_exception(
185185
fparam=ifp,
186186
aparam=iap,
187187
)
188-
self.assertIn("input fparam", str(context.exception))
188+
self.assertIn("input fparam", str(context.exception))
189189

190190
if nap > 0:
191+
# restore correct ifp before testing aparam
192+
if nfp > 0:
193+
ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(
194+
self.device
195+
)
191196
iap = torch.from_numpy(
192197
rng.normal(size=(self.nf, self.nloc, nap - 1))
193198
).to(self.device)
@@ -198,7 +203,7 @@ def test_self_exception(
198203
fparam=ifp,
199204
aparam=iap,
200205
)
201-
self.assertIn("input aparam", str(context.exception))
206+
self.assertIn("input aparam", str(context.exception))
202207

203208
def test_get_set(self) -> None:
204209
ifn0 = InvarFitting(

0 commit comments

Comments
 (0)