2929 dtype = jnp .float64
3030
3131
32- # @unittest.skipIf(
33- # sys.version_info < (3, 10),
34- # "JAX requires Python 3.10 or later",
35- # )
32+ @unittest .skipIf (
33+ sys .version_info < (3 , 10 ),
34+ "JAX requires Python 3.10 or later" ,
35+ )
3636class TestCaseSingleFrameWithoutNlist :
3737 def setUp (self ) -> None :
38- # nloc == 3, nall == 4
38+ # nf=2, nloc == 3
3939 self .nloc = 3
40- self .nf , self . nt = 1 , 2
40+ self .nt = 2
4141 self .coord = np .array (
4242 [
43- [0 , 0 , 0 ],
44- [0 , 1 , 0 ],
45- [0 , 0 , 1 ],
43+ [
44+ [0 , 0 , 0 ],
45+ [0 , 1 , 0 ],
46+ [0 , 0 , 1 ],
47+ ],
48+ [
49+ [1 , 0 , 1 ],
50+ [0 , 1 , 1 ],
51+ [1 , 1 , 0 ],
52+ ]
4653 ],
4754 dtype = np .float64 ,
48- ). reshape ([ 1 , self . nloc * 3 ])
49- self .atype = np .array ([0 , 0 , 1 ], dtype = int ).reshape ([1 , self .nloc ])
55+ )
56+ self .atype = np .array ([[ 0 , 0 , 1 ],[ 1 , 1 , 0 ]], dtype = int ).reshape ([2 , self .nloc ])
5057 self .cell = 2.0 * np .eye (3 ).reshape ([1 , 9 ])
51- # sel = [5, 2]
58+ self . cell = np . array ([ self . cell , self . cell ]). reshape ( 2 , 9 )
5259 self .sel = [16 , 8 ]
5360 self .rcut = 2.2
5461 self .rcut_smth = 0.4
5562 self .atol = 1e-12
5663
5764
58- # @unittest.skipIf(
59- # sys.version_info < (3, 10),
60- # "JAX requires Python 3.10 or later",
61- # )
65+ @unittest .skipIf (
66+ sys .version_info < (3 , 10 ),
67+ "JAX requires Python 3.10 or later" ,
68+ )
6269class TestPaddingAtoms (unittest .TestCase , TestCaseSingleFrameWithoutNlist ):
6370 def setUp (self ):
6471 TestCaseSingleFrameWithoutNlist .setUp (self )
@@ -77,12 +84,40 @@ def test_padding_atoms_consistency(self):
7784 )
7885 type_map = ["foo" , "bar" ]
7986 model = PropertyModel (ds , ft , type_map = type_map )
87+ var_name = model .get_var_name ()
8088 args = [to_jax_array (ii ) for ii in [self .coord , self .atype , self .cell ]]
81- ret_base = model .call (* args )
89+ result = model .call (* args )
90+ # test intensive
91+ np .testing .assert_allclose (
92+ to_numpy_array (result [f"{ var_name } _redu" ]),
93+ np .mean (to_numpy_array (result [f"{ var_name } " ]),axis = 1 ),
94+ atol = self .atol ,
95+ )
96+ # test padding atoms
97+ padding_atoms_list = [1 , 5 , 10 ]
98+ for padding_atoms in padding_atoms_list :
99+ coord = deepcopy (self .coord )
100+ atype = deepcopy (self .atype )
101+ atype_padding = np .pad (
102+ atype ,
103+ pad_width = ((0 , 0 ), (0 , padding_atoms )),
104+ mode = 'constant' ,
105+ constant_values = - 1
106+ )
107+ coord_padding = np .pad (
108+ coord ,
109+ pad_width = ((0 , 0 ), (0 , padding_atoms ), (0 , 0 )),
110+ mode = 'constant' ,
111+ constant_values = 0
112+ )
113+ args = [to_jax_array (ii ) for ii in [coord_padding , atype_padding , self .cell ]]
114+ result_padding = model .call (* args )
115+ np .testing .assert_allclose (
116+ to_numpy_array (result [f"{ var_name } _redu" ]),
117+ to_numpy_array (result_padding [f"{ var_name } _redu" ]),
118+ atol = self .atol ,
119+ )
82120
83121
84- #np.testing.assert_allclose(
85- # to_numpy_array(ret0[model.get_var_name()]),
86- # to_numpy_array(ret1[md1.get_var_name()]),
87- # atol=self.atol,
88- #)
122+ if __name__ == "__main__" :
123+ unittest .main ()
0 commit comments