@@ -55,7 +55,6 @@ def test_consistency(
5555 nme ,
5656 prec ,
5757 ect ,
58- use_ext_ebd ,
5958 ) in itertools .product (
6059 [True , False ], # update_angle
6160 ["res_residual" ], # update_style
@@ -66,7 +65,6 @@ def test_consistency(
6665 [1 , 2 ], # n_multi_edge_message
6766 ["float64" ], # precision
6867 [False ], # use_econf_tebd
69- [False , True ], # use_ext_ebd
7068 ):
7169 dtype = PRECISION_DICT [prec ]
7270 rtol , atol = get_tols (prec )
@@ -105,7 +103,6 @@ def test_consistency(
105103 use_econf_tebd = ect ,
106104 type_map = ["O" , "H" ] if ect else None ,
107105 seed = GLOBAL_SEED ,
108- use_ext_ebd = use_ext_ebd ,
109106 ).to (env .DEVICE )
110107
111108 dd0 .repflows .mean = torch .tensor (davg , dtype = dtype , device = env .DEVICE )
@@ -142,100 +139,6 @@ def test_consistency(
142139 atol = atol ,
143140 )
144141
145- def test_use_ext_ebd (
146- self ,
147- ) -> None :
148- rtol , atol = get_tols ("float32" )
149-
150- nf , nloc , nnei = self .nlist .shape
151- repflow0 = RepFlowArgs (
152- n_dim = 20 ,
153- e_dim = 10 ,
154- a_dim = 8 ,
155- nlayers = 3 ,
156- e_rcut = self .rcut ,
157- e_rcut_smth = self .rcut_smth ,
158- e_sel = nnei ,
159- a_rcut = self .rcut - 0.1 ,
160- a_rcut_smth = self .rcut_smth ,
161- a_sel = nnei - 1 ,
162- a_compress_rate = 0 ,
163- a_compress_e_rate = 2 ,
164- a_compress_use_split = False ,
165- n_multi_edge_message = 1 ,
166- axis_neuron = 4 ,
167- update_angle = True ,
168- update_style = "res_residual" ,
169- update_residual_init = "const" ,
170- smooth_edge_update = True ,
171- use_ext_ebd = True ,
172- )
173- # dpa3 with use_ext_ebd=True
174- dd0 = DescrptDPA3 (
175- self .nt ,
176- repflow = repflow0 ,
177- # kwargs for descriptor
178- exclude_types = [],
179- precision = "float32" ,
180- use_econf_tebd = True ,
181- type_map = ["O" , "H" ],
182- seed = GLOBAL_SEED ,
183- use_ext_ebd = True ,
184- ).to (env .DEVICE )
185- rd0 , _ , _ , _ , _ = dd0 (
186- torch .tensor (self .coord_ext , dtype = dtype , device = env .DEVICE ),
187- torch .tensor (self .atype_ext , dtype = int , device = env .DEVICE ),
188- torch .tensor (self .nlist , dtype = int , device = env .DEVICE ),
189- torch .tensor (self .mapping , dtype = int , device = env .DEVICE ),
190- )
191-
192- # dpa3 with use_ext_ebd=False
193- repflow1 = RepFlowArgs (
194- n_dim = 20 ,
195- e_dim = 10 ,
196- a_dim = 8 ,
197- nlayers = 3 ,
198- e_rcut = self .rcut ,
199- e_rcut_smth = self .rcut_smth ,
200- e_sel = nnei ,
201- a_rcut = self .rcut - 0.1 ,
202- a_rcut_smth = self .rcut_smth ,
203- a_sel = nnei - 1 ,
204- a_compress_rate = 0 ,
205- a_compress_e_rate = 2 ,
206- a_compress_use_split = False ,
207- n_multi_edge_message = 1 ,
208- axis_neuron = 4 ,
209- update_angle = True ,
210- update_style = "res_residual" ,
211- update_residual_init = "const" ,
212- smooth_edge_update = True ,
213- use_ext_ebd = False ,
214- )
215- dd1 = DescrptDPA3 (
216- self .nt ,
217- repflow = repflow1 ,
218- # kwargs for descriptor
219- exclude_types = [],
220- precision = "float32" ,
221- use_econf_tebd = True ,
222- type_map = ["O" , "H" ],
223- seed = GLOBAL_SEED ,
224- use_ext_ebd = False ,
225- ).to (env .DEVICE )
226- rd1 , _ , _ , _ , _ = dd1 (
227- torch .tensor (self .coord_ext , dtype = dtype , device = env .DEVICE ),
228- torch .tensor (self .atype_ext , dtype = int , device = env .DEVICE ),
229- torch .tensor (self .nlist , dtype = int , device = env .DEVICE ),
230- torch .tensor (self .mapping , dtype = int , device = env .DEVICE ),
231- )
232- np .testing .assert_allclose (
233- rd0 .detach ().cpu ().numpy (),
234- rd1 .detach ().cpu ().numpy (),
235- rtol = rtol ,
236- atol = atol ,
237- )
238-
239142 def test_jit (
240143 self ,
241144 ) -> None :
0 commit comments