Skip to content

Commit e619fa8

Browse files
committed
Update test_deep_eval.py
1 parent c9101d1 commit e619fa8

1 file changed

Lines changed: 37 additions & 10 deletions

File tree

source/tests/pt_expt/infer/test_deep_eval.py

Lines changed: 37 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -244,12 +244,18 @@ def test_dynamic_shapes(self) -> None:
244244
exported_mod = exported.module()
245245

246246
for nloc in [2, 5, 10]:
247-
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = (
247+
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin = (
248248
_make_sample_inputs(self.model, nloc=nloc)
249249
)
250250

251251
pte_ret = exported_mod(
252-
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam
252+
ext_coord,
253+
ext_atype,
254+
nlist_t,
255+
mapping_t,
256+
fparam,
257+
aparam,
258+
charge_spin,
253259
)
254260

255261
ec = ext_coord.detach().requires_grad_(True)
@@ -261,6 +267,7 @@ def test_dynamic_shapes(self) -> None:
261267
fparam=fparam,
262268
aparam=aparam,
263269
do_atomic_virial=True,
270+
charge_spin=charge_spin,
264271
)
265272

266273
for key in ("energy", "energy_redu", "energy_derv_r", "energy_derv_c"):
@@ -296,8 +303,8 @@ def test_oversized_nlist(self) -> None:
296303

297304
nnei = sum(self.sel) # model's expected neighbor count
298305
nloc = 5
299-
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = _make_sample_inputs(
300-
self.model, nloc=nloc
306+
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin = (
307+
_make_sample_inputs(self.model, nloc=nloc)
301308
)
302309

303310
# Pad nlist with -1 columns, then shuffle column order so real
@@ -331,11 +338,18 @@ def test_oversized_nlist(self) -> None:
331338
fparam=fparam,
332339
aparam=aparam,
333340
do_atomic_virial=True,
341+
charge_spin=charge_spin,
334342
)
335343

336344
# Exported model with same shuffled oversized nlist
337345
pte_ret = exported_mod(
338-
ext_coord, ext_atype, nlist_shuffled, mapping_t, fparam, aparam
346+
ext_coord,
347+
ext_atype,
348+
nlist_shuffled,
349+
mapping_t,
350+
fparam,
351+
aparam,
352+
charge_spin,
339353
)
340354

341355
for key in ("energy", "energy_redu", "energy_derv_r", "energy_derv_c"):
@@ -362,6 +376,7 @@ def test_oversized_nlist(self) -> None:
362376
fparam=fparam,
363377
aparam=aparam,
364378
do_atomic_virial=True,
379+
charge_spin=charge_spin,
365380
)
366381
# The truncated result MUST differ from the correctly sorted result,
367382
# proving that naive truncation discards real neighbors.
@@ -382,7 +397,7 @@ def test_serialize_round_trip(self) -> None:
382397
model2.eval()
383398

384399
for nloc in [3, 7]:
385-
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = (
400+
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin = (
386401
_make_sample_inputs(self.model, nloc=nloc)
387402
)
388403
ec1 = ext_coord.detach().requires_grad_(True)
@@ -396,6 +411,7 @@ def test_serialize_round_trip(self) -> None:
396411
fparam=fparam,
397412
aparam=aparam,
398413
do_atomic_virial=True,
414+
charge_spin=charge_spin,
399415
)
400416
ret2 = model2.forward_common_lower(
401417
ec2,
@@ -405,6 +421,7 @@ def test_serialize_round_trip(self) -> None:
405421
fparam=fparam,
406422
aparam=aparam,
407423
do_atomic_virial=True,
424+
charge_spin=charge_spin,
408425
)
409426

410427
for key in ("energy", "energy_redu", "energy_derv_r", "energy_derv_c"):
@@ -943,8 +960,8 @@ def test_oversized_nlist(self) -> None:
943960

944961
nnei = sum(self.sel) # model's expected neighbor count
945962
nloc = 5
946-
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = _make_sample_inputs(
947-
self.model, nloc=nloc
963+
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin = (
964+
_make_sample_inputs(self.model, nloc=nloc)
948965
)
949966

950967
# Pad nlist with -1 columns, then shuffle column order so real
@@ -977,10 +994,17 @@ def test_oversized_nlist(self) -> None:
977994
fparam=fparam,
978995
aparam=aparam,
979996
do_atomic_virial=True,
997+
charge_spin=charge_spin,
980998
)
981999

9821000
pte_ret = exported_mod(
983-
ext_coord, ext_atype, nlist_shuffled, mapping_t, fparam, aparam
1001+
ext_coord,
1002+
ext_atype,
1003+
nlist_shuffled,
1004+
mapping_t,
1005+
fparam,
1006+
aparam,
1007+
charge_spin,
9841008
)
9851009

9861010
for key in ("energy", "energy_redu", "energy_derv_r", "energy_derv_c"):
@@ -1004,6 +1028,7 @@ def test_oversized_nlist(self) -> None:
10041028
fparam=fparam,
10051029
aparam=aparam,
10061030
do_atomic_virial=True,
1031+
charge_spin=charge_spin,
10071032
)
10081033
e_ref = ref_ret["energy_redu"].detach().cpu().numpy()
10091034
e_trunc = trunc_ret["energy_redu"].detach().cpu().numpy()
@@ -1022,7 +1047,7 @@ def test_serialize_round_trip(self) -> None:
10221047
model2.eval()
10231048

10241049
for nloc in [3, 7]:
1025-
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = (
1050+
ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam, charge_spin = (
10261051
_make_sample_inputs(self.model, nloc=nloc)
10271052
)
10281053
ec1 = ext_coord.detach().requires_grad_(True)
@@ -1036,6 +1061,7 @@ def test_serialize_round_trip(self) -> None:
10361061
fparam=fparam,
10371062
aparam=aparam,
10381063
do_atomic_virial=True,
1064+
charge_spin=charge_spin,
10391065
)
10401066
ret2 = model2.forward_common_lower(
10411067
ec2,
@@ -1045,6 +1071,7 @@ def test_serialize_round_trip(self) -> None:
10451071
fparam=fparam,
10461072
aparam=aparam,
10471073
do_atomic_virial=True,
1074+
charge_spin=charge_spin,
10481075
)
10491076

10501077
for key in ("energy", "energy_redu", "energy_derv_r", "energy_derv_c"):

0 commit comments

Comments
 (0)