@@ -244,12 +244,18 @@ def test_dynamic_shapes(self) -> None:
244244 exported_mod = exported .module ()
245245
246246 for nloc in [2 , 5 , 10 ]:
247- ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam = (
247+ ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam , charge_spin = (
248248 _make_sample_inputs (self .model , nloc = nloc )
249249 )
250250
251251 pte_ret = exported_mod (
252- ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam
252+ ext_coord ,
253+ ext_atype ,
254+ nlist_t ,
255+ mapping_t ,
256+ fparam ,
257+ aparam ,
258+ charge_spin ,
253259 )
254260
255261 ec = ext_coord .detach ().requires_grad_ (True )
@@ -261,6 +267,7 @@ def test_dynamic_shapes(self) -> None:
261267 fparam = fparam ,
262268 aparam = aparam ,
263269 do_atomic_virial = True ,
270+ charge_spin = charge_spin ,
264271 )
265272
266273 for key in ("energy" , "energy_redu" , "energy_derv_r" , "energy_derv_c" ):
@@ -296,8 +303,8 @@ def test_oversized_nlist(self) -> None:
296303
297304 nnei = sum (self .sel ) # model's expected neighbor count
298305 nloc = 5
299- ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam = _make_sample_inputs (
300- self .model , nloc = nloc
306+ ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam , charge_spin = (
307+ _make_sample_inputs ( self .model , nloc = nloc )
301308 )
302309
303310 # Pad nlist with -1 columns, then shuffle column order so real
@@ -331,11 +338,18 @@ def test_oversized_nlist(self) -> None:
331338 fparam = fparam ,
332339 aparam = aparam ,
333340 do_atomic_virial = True ,
341+ charge_spin = charge_spin ,
334342 )
335343
336344 # Exported model with same shuffled oversized nlist
337345 pte_ret = exported_mod (
338- ext_coord , ext_atype , nlist_shuffled , mapping_t , fparam , aparam
346+ ext_coord ,
347+ ext_atype ,
348+ nlist_shuffled ,
349+ mapping_t ,
350+ fparam ,
351+ aparam ,
352+ charge_spin ,
339353 )
340354
341355 for key in ("energy" , "energy_redu" , "energy_derv_r" , "energy_derv_c" ):
@@ -362,6 +376,7 @@ def test_oversized_nlist(self) -> None:
362376 fparam = fparam ,
363377 aparam = aparam ,
364378 do_atomic_virial = True ,
379+ charge_spin = charge_spin ,
365380 )
366381 # The truncated result MUST differ from the correctly sorted result,
367382 # proving that naive truncation discards real neighbors.
@@ -382,7 +397,7 @@ def test_serialize_round_trip(self) -> None:
382397 model2 .eval ()
383398
384399 for nloc in [3 , 7 ]:
385- ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam = (
400+ ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam , charge_spin = (
386401 _make_sample_inputs (self .model , nloc = nloc )
387402 )
388403 ec1 = ext_coord .detach ().requires_grad_ (True )
@@ -396,6 +411,7 @@ def test_serialize_round_trip(self) -> None:
396411 fparam = fparam ,
397412 aparam = aparam ,
398413 do_atomic_virial = True ,
414+ charge_spin = charge_spin ,
399415 )
400416 ret2 = model2 .forward_common_lower (
401417 ec2 ,
@@ -405,6 +421,7 @@ def test_serialize_round_trip(self) -> None:
405421 fparam = fparam ,
406422 aparam = aparam ,
407423 do_atomic_virial = True ,
424+ charge_spin = charge_spin ,
408425 )
409426
410427 for key in ("energy" , "energy_redu" , "energy_derv_r" , "energy_derv_c" ):
@@ -943,8 +960,8 @@ def test_oversized_nlist(self) -> None:
943960
944961 nnei = sum (self .sel ) # model's expected neighbor count
945962 nloc = 5
946- ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam = _make_sample_inputs (
947- self .model , nloc = nloc
963+ ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam , charge_spin = (
964+ _make_sample_inputs ( self .model , nloc = nloc )
948965 )
949966
950967 # Pad nlist with -1 columns, then shuffle column order so real
@@ -977,10 +994,17 @@ def test_oversized_nlist(self) -> None:
977994 fparam = fparam ,
978995 aparam = aparam ,
979996 do_atomic_virial = True ,
997+ charge_spin = charge_spin ,
980998 )
981999
9821000 pte_ret = exported_mod (
983- ext_coord , ext_atype , nlist_shuffled , mapping_t , fparam , aparam
1001+ ext_coord ,
1002+ ext_atype ,
1003+ nlist_shuffled ,
1004+ mapping_t ,
1005+ fparam ,
1006+ aparam ,
1007+ charge_spin ,
9841008 )
9851009
9861010 for key in ("energy" , "energy_redu" , "energy_derv_r" , "energy_derv_c" ):
@@ -1004,6 +1028,7 @@ def test_oversized_nlist(self) -> None:
10041028 fparam = fparam ,
10051029 aparam = aparam ,
10061030 do_atomic_virial = True ,
1031+ charge_spin = charge_spin ,
10071032 )
10081033 e_ref = ref_ret ["energy_redu" ].detach ().cpu ().numpy ()
10091034 e_trunc = trunc_ret ["energy_redu" ].detach ().cpu ().numpy ()
@@ -1022,7 +1047,7 @@ def test_serialize_round_trip(self) -> None:
10221047 model2 .eval ()
10231048
10241049 for nloc in [3 , 7 ]:
1025- ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam = (
1050+ ext_coord , ext_atype , nlist_t , mapping_t , fparam , aparam , charge_spin = (
10261051 _make_sample_inputs (self .model , nloc = nloc )
10271052 )
10281053 ec1 = ext_coord .detach ().requires_grad_ (True )
@@ -1036,6 +1061,7 @@ def test_serialize_round_trip(self) -> None:
10361061 fparam = fparam ,
10371062 aparam = aparam ,
10381063 do_atomic_virial = True ,
1064+ charge_spin = charge_spin ,
10391065 )
10401066 ret2 = model2 .forward_common_lower (
10411067 ec2 ,
@@ -1045,6 +1071,7 @@ def test_serialize_round_trip(self) -> None:
10451071 fparam = fparam ,
10461072 aparam = aparam ,
10471073 do_atomic_virial = True ,
1074+ charge_spin = charge_spin ,
10481075 )
10491076
10501077 for key in ("energy" , "energy_redu" , "energy_derv_r" , "energy_derv_c" ):
0 commit comments