@@ -187,27 +187,6 @@ def test_dp_test_1_frame(self) -> None:
187187 to_numpy_array (result [model .get_var_name ()])[0 ],
188188 )
189189
190- def test_dp_test_padding_atoms (self ) -> None :
191- trainer = get_trainer (deepcopy (self .config ))
192- with torch .device ("cpu" ):
193- input_dict , label_dict , _ = trainer .get_data (is_train = False )
194- input_dict .pop ("spin" , None )
195- result = trainer .model (** input_dict )
196- padding_atoms_list = [1 , 5 , 10 ]
197- for padding_atoms in padding_atoms_list :
198- input_dict_padding = deepcopy (input_dict )
199- input_dict_padding ["atype" ] = F .pad (
200- input_dict_padding ["atype" ], (0 , padding_atoms ), value = - 1
201- )
202- input_dict_padding ["coord" ] = F .pad (
203- input_dict_padding ["coord" ], (0 , 0 , 0 , padding_atoms , 0 , 0 ), value = 0
204- )
205- result_padding = trainer .model (** input_dict_padding )
206- np .testing .assert_almost_equal (
207- to_numpy_array (result [trainer .model .get_var_name ()])[0 ],
208- to_numpy_array (result_padding [trainer .model .get_var_name ()])[0 ],
209- )
210-
211190 def tearDown (self ) -> None :
212191 for f in os .listdir ("." ):
213192 if f .startswith ("model" ) and f .endswith (".pt" ):
0 commit comments