3030
3131
3232class DPTest :
33- def test_dp_test_1_frame (self ) -> None :
33+ def _run_dp_test (self , use_input_json : bool , numb_test : int = 0 ) -> None :
3434 trainer = get_trainer (deepcopy (self .config ))
3535 with torch .device ("cpu" ):
3636 input_dict , label_dict , _ = trainer .get_data (is_train = False )
@@ -44,12 +44,16 @@ def test_dp_test_1_frame(self) -> None:
4444 model = torch .jit .script (trainer .model )
4545 tmp_model = tempfile .NamedTemporaryFile (delete = False , suffix = ".pth" )
4646 torch .jit .save (model , tmp_model .name )
47+ val_sys = self .config ["training" ]["validation_data" ]["systems" ]
48+ if isinstance (val_sys , list ):
49+ val_sys = val_sys [0 ]
4750 dp_test (
4851 model = tmp_model .name ,
49- system = self . config [ "training" ][ "validation_data" ][ "systems" ][ 0 ] ,
52+ system = val_sys ,
5053 datafile = None ,
54+ input_json = self .input_json if use_input_json else None ,
5155 set_prefix = "set" ,
52- numb_test = 0 ,
56+ numb_test = numb_test ,
5357 rand_seed = None ,
5458 shuffle_test = False ,
5559 detail_file = self .detail_file ,
@@ -93,6 +97,12 @@ def test_dp_test_1_frame(self) -> None:
9397 ).reshape (- 1 , 3 ),
9498 )
9599
100+ def test_dp_test_1_frame (self ) -> None :
101+ self ._run_dp_test (False )
102+
103+ def test_dp_test_input_json (self ) -> None :
104+ self ._run_dp_test (True )
105+
96106 def tearDown (self ) -> None :
97107 for f in os .listdir ("." ):
98108 if f .startswith ("model" ) and f .endswith (".pt" ):
@@ -140,6 +150,61 @@ def setUp(self) -> None:
140150 json .dump (self .config , fp , indent = 4 )
141151
142152
153+ class TestDPTestSeARglob (unittest .TestCase ):
154+ def setUp (self ) -> None :
155+ self .detail_file = "test_dp_test_ener_rglob_detail"
156+ input_json = str (Path (__file__ ).parent / "water/se_atten.json" )
157+ with open (input_json ) as f :
158+ self .config = json .load (f )
159+ self .config ["training" ]["numb_steps" ] = 1
160+ self .config ["training" ]["save_freq" ] = 1
161+ data_file = [str (Path (__file__ ).parent / "water/data/single" )]
162+ self .config ["training" ]["training_data" ]["systems" ] = data_file
163+ root_dir = str (Path (__file__ ).parent )
164+ self .config ["training" ]["validation_data" ]["systems" ] = root_dir
165+ self .config ["training" ]["validation_data" ]["rglob_patterns" ] = [
166+ "water/data/single"
167+ ]
168+ self .config ["model" ] = deepcopy (model_se_e2_a )
169+ self .input_json = "test_dp_test_rglob.json"
170+ with open (self .input_json , "w" ) as fp :
171+ json .dump (self .config , fp , indent = 4 )
172+
173+ def test_dp_test_input_json_rglob (self ) -> None :
174+ trainer = get_trainer (deepcopy (self .config ))
175+ with torch .device ("cpu" ):
176+ input_dict , _ , _ = trainer .get_data (is_train = False )
177+ input_dict .pop ("spin" , None )
178+ model = torch .jit .script (trainer .model )
179+ tmp_model = tempfile .NamedTemporaryFile (delete = False , suffix = ".pth" )
180+ torch .jit .save (model , tmp_model .name )
181+ dp_test (
182+ model = tmp_model .name ,
183+ system = self .config ["training" ]["validation_data" ]["systems" ],
184+ datafile = None ,
185+ input_json = self .input_json ,
186+ set_prefix = "set" ,
187+ numb_test = 1 ,
188+ rand_seed = None ,
189+ shuffle_test = False ,
190+ detail_file = self .detail_file ,
191+ atomic = False ,
192+ )
193+ os .unlink (tmp_model .name )
194+ self .assertTrue (os .path .exists (self .detail_file + ".e.out" ))
195+
196+ def tearDown (self ) -> None :
197+ for f in os .listdir ("." ):
198+ if f .startswith ("model" ) and f .endswith (".pt" ):
199+ os .remove (f )
200+ if f .startswith (self .detail_file ):
201+ os .remove (f )
202+ if f in ["lcurve.out" , self .input_json ]:
203+ os .remove (f )
204+ if f in ["stat_files" ]:
205+ shutil .rmtree (f )
206+
207+
143208class TestDPTestPropertySeA (unittest .TestCase ):
144209 def setUp (self ) -> None :
145210 self .detail_file = "test_dp_test_property_detail"
0 commit comments