3030
3131
3232class DPTest :
33- def test_dp_test_1_frame (self ) -> None :
33+ def _run_dp_test (
34+ self , use_input_json : bool , numb_test : int = 0 , use_train : bool = False
35+ ) -> None :
3436 trainer = get_trainer (deepcopy (self .config ))
3537 with torch .device ("cpu" ):
3638 input_dict , label_dict , _ = trainer .get_data (is_train = False )
@@ -44,12 +46,17 @@ def test_dp_test_1_frame(self) -> None:
4446 model = torch .jit .script (trainer .model )
4547 tmp_model = tempfile .NamedTemporaryFile (delete = False , suffix = ".pth" )
4648 torch .jit .save (model , tmp_model .name )
49+ val_sys = self .config ["training" ]["validation_data" ]["systems" ]
50+ if isinstance (val_sys , list ):
51+ val_sys = val_sys [0 ]
4752 dp_test (
4853 model = tmp_model .name ,
49- system = self . config [ "training" ][ "validation_data" ][ "systems" ][ 0 ] ,
54+ system = val_sys ,
5055 datafile = None ,
56+ input_json = self .input_json if use_input_json else None ,
57+ use_train = use_train ,
5158 set_prefix = "set" ,
52- numb_test = 0 ,
59+ numb_test = numb_test ,
5360 rand_seed = None ,
5461 shuffle_test = False ,
5562 detail_file = self .detail_file ,
@@ -93,6 +100,20 @@ def test_dp_test_1_frame(self) -> None:
93100 ).reshape (- 1 , 3 ),
94101 )
95102
103+ def test_dp_test_1_frame (self ) -> None :
104+ self ._run_dp_test (False )
105+
106+ def test_dp_test_input_json (self ) -> None :
107+ self ._run_dp_test (True )
108+
109+ def test_dp_test_input_json_train (self ) -> None :
110+ with open (self .input_json ) as f :
111+ cfg = json .load (f )
112+ cfg ["training" ]["validation_data" ]["systems" ] = ["non-existent" ]
113+ with open (self .input_json , "w" ) as f :
114+ json .dump (cfg , f , indent = 4 )
115+ self ._run_dp_test (True , use_train = True )
116+
96117 def tearDown (self ) -> None :
97118 for f in os .listdir ("." ):
98119 if f .startswith ("model" ) and f .endswith (".pt" ):
@@ -140,6 +161,117 @@ def setUp(self) -> None:
140161 json .dump (self .config , fp , indent = 4 )
141162
142163
164+ class TestDPTestSeARglob (unittest .TestCase ):
165+ def setUp (self ) -> None :
166+ self .detail_file = "test_dp_test_ener_rglob_detail"
167+ input_json = str (Path (__file__ ).parent / "water/se_atten.json" )
168+ with open (input_json ) as f :
169+ self .config = json .load (f )
170+ self .config ["training" ]["numb_steps" ] = 1
171+ self .config ["training" ]["save_freq" ] = 1
172+ data_file = [str (Path (__file__ ).parent / "water/data/single" )]
173+ self .config ["training" ]["training_data" ]["systems" ] = data_file
174+ root_dir = str (Path (__file__ ).parent )
175+ self .config ["training" ]["validation_data" ]["systems" ] = root_dir
176+ self .config ["training" ]["validation_data" ]["rglob_patterns" ] = [
177+ "water/data/single"
178+ ]
179+ self .config ["model" ] = deepcopy (model_se_e2_a )
180+ self .input_json = "test_dp_test_rglob.json"
181+ with open (self .input_json , "w" ) as fp :
182+ json .dump (self .config , fp , indent = 4 )
183+
184+ def test_dp_test_input_json_rglob (self ) -> None :
185+ trainer = get_trainer (deepcopy (self .config ))
186+ with torch .device ("cpu" ):
187+ input_dict , _ , _ = trainer .get_data (is_train = False )
188+ input_dict .pop ("spin" , None )
189+ model = torch .jit .script (trainer .model )
190+ tmp_model = tempfile .NamedTemporaryFile (delete = False , suffix = ".pth" )
191+ torch .jit .save (model , tmp_model .name )
192+ dp_test (
193+ model = tmp_model .name ,
194+ system = self .config ["training" ]["validation_data" ]["systems" ],
195+ datafile = None ,
196+ input_json = self .input_json ,
197+ set_prefix = "set" ,
198+ numb_test = 1 ,
199+ rand_seed = None ,
200+ shuffle_test = False ,
201+ detail_file = self .detail_file ,
202+ atomic = False ,
203+ )
204+ os .unlink (tmp_model .name )
205+ self .assertTrue (os .path .exists (self .detail_file + ".e.out" ))
206+
207+ def tearDown (self ) -> None :
208+ for f in os .listdir ("." ):
209+ if f .startswith ("model" ) and f .endswith (".pt" ):
210+ os .remove (f )
211+ if f .startswith (self .detail_file ):
212+ os .remove (f )
213+ if f in ["lcurve.out" , self .input_json ]:
214+ os .remove (f )
215+ if f in ["stat_files" ]:
216+ shutil .rmtree (f )
217+
218+
219+ class TestDPTestSeARglobTrain (unittest .TestCase ):
220+ def setUp (self ) -> None :
221+ self .detail_file = "test_dp_test_ener_rglob_train_detail"
222+ input_json = str (Path (__file__ ).parent / "water/se_atten.json" )
223+ with open (input_json ) as f :
224+ self .config = json .load (f )
225+ self .config ["training" ]["numb_steps" ] = 1
226+ self .config ["training" ]["save_freq" ] = 1
227+ root_dir = str (Path (__file__ ).parent )
228+ self .config ["training" ]["training_data" ]["systems" ] = root_dir
229+ self .config ["training" ]["training_data" ]["rglob_patterns" ] = [
230+ "water/data/single"
231+ ]
232+ data_file = [str (Path (__file__ ).parent / "water/data/single" )]
233+ self .config ["training" ]["validation_data" ]["systems" ] = data_file
234+ self .config ["model" ] = deepcopy (model_se_e2_a )
235+ self .input_json = "test_dp_test_rglob_train.json"
236+ with open (self .input_json , "w" ) as fp :
237+ json .dump (self .config , fp , indent = 4 )
238+
239+ def test_dp_test_input_json_rglob_train (self ) -> None :
240+ trainer = get_trainer (deepcopy (self .config ))
241+ with torch .device ("cpu" ):
242+ input_dict , _ , _ = trainer .get_data (is_train = False )
243+ input_dict .pop ("spin" , None )
244+ model = torch .jit .script (trainer .model )
245+ tmp_model = tempfile .NamedTemporaryFile (delete = False , suffix = ".pth" )
246+ torch .jit .save (model , tmp_model .name )
247+ dp_test (
248+ model = tmp_model .name ,
249+ system = self .config ["training" ]["validation_data" ]["systems" ],
250+ datafile = None ,
251+ input_json = self .input_json ,
252+ use_train = True ,
253+ set_prefix = "set" ,
254+ numb_test = 1 ,
255+ rand_seed = None ,
256+ shuffle_test = False ,
257+ detail_file = self .detail_file ,
258+ atomic = False ,
259+ )
260+ os .unlink (tmp_model .name )
261+ self .assertTrue (os .path .exists (self .detail_file + ".e.out" ))
262+
263+ def tearDown (self ) -> None :
264+ for f in os .listdir ("." ):
265+ if f .startswith ("model" ) and f .endswith (".pt" ):
266+ os .remove (f )
267+ if f .startswith (self .detail_file ):
268+ os .remove (f )
269+ if f in ["lcurve.out" , self .input_json ]:
270+ os .remove (f )
271+ if f in ["stat_files" ]:
272+ shutil .rmtree (f )
273+
274+
143275class TestDPTestPropertySeA (unittest .TestCase ):
144276 def setUp (self ) -> None :
145277 self .detail_file = "test_dp_test_property_detail"
0 commit comments