@@ -141,6 +141,99 @@ def test_consistency(
141141 rtol = rtol ,
142142 atol = atol ,
143143 )
144+ def test_use_ext_ebd (
145+ self ,
146+ ) -> None :
147+ rtol , atol = get_tols ("float32" )
148+
149+ nf , nloc , nnei = self .nlist .shape
150+ repflow0 = RepFlowArgs (
151+ n_dim = 20 ,
152+ e_dim = 10 ,
153+ a_dim = 8 ,
154+ nlayers = 3 ,
155+ e_rcut = self .rcut ,
156+ e_rcut_smth = self .rcut_smth ,
157+ e_sel = nnei ,
158+ a_rcut = self .rcut - 0.1 ,
159+ a_rcut_smth = self .rcut_smth ,
160+ a_sel = nnei - 1 ,
161+ a_compress_rate = 0 ,
162+ a_compress_e_rate = 2 ,
163+ a_compress_use_split = False ,
164+ n_multi_edge_message = 1 ,
165+ axis_neuron = 4 ,
166+ update_angle = True ,
167+ update_style = "res_residual" ,
168+ update_residual_init = "const" ,
169+ smooth_edge_update = True ,
170+ use_ext_ebd = True ,
171+ )
172+ # dpa3 with use_ext_ebd=True
173+ dd0 = DescrptDPA3 (
174+ self .nt ,
175+ repflow = repflow0 ,
176+ # kwargs for descriptor
177+ exclude_types = [],
178+ precision = "float32" ,
179+ use_econf_tebd = True ,
180+ type_map = ["O" , "H" ],
181+ seed = GLOBAL_SEED ,
182+ use_ext_ebd = True ,
183+ ).to (env .DEVICE )
184+ rd0 , _ , _ , _ , _ = dd0 (
185+ torch .tensor (self .coord_ext , dtype = dtype , device = env .DEVICE ),
186+ torch .tensor (self .atype_ext , dtype = int , device = env .DEVICE ),
187+ torch .tensor (self .nlist , dtype = int , device = env .DEVICE ),
188+ torch .tensor (self .mapping , dtype = int , device = env .DEVICE ),
189+ )
190+
191+ # dpa3 with use_ext_ebd=False
192+ repflow1 = RepFlowArgs (
193+ n_dim = 20 ,
194+ e_dim = 10 ,
195+ a_dim = 8 ,
196+ nlayers = 3 ,
197+ e_rcut = self .rcut ,
198+ e_rcut_smth = self .rcut_smth ,
199+ e_sel = nnei ,
200+ a_rcut = self .rcut - 0.1 ,
201+ a_rcut_smth = self .rcut_smth ,
202+ a_sel = nnei - 1 ,
203+ a_compress_rate = 0 ,
204+ a_compress_e_rate = 2 ,
205+ a_compress_use_split = False ,
206+ n_multi_edge_message = 1 ,
207+ axis_neuron = 4 ,
208+ update_angle = True ,
209+ update_style = "res_residual" ,
210+ update_residual_init = "const" ,
211+ smooth_edge_update = True ,
212+ use_ext_ebd = False ,
213+ )
214+ dd1 = DescrptDPA3 (
215+ self .nt ,
216+ repflow = repflow1 ,
217+ # kwargs for descriptor
218+ exclude_types = [],
219+ precision = "float32" ,
220+ use_econf_tebd = True ,
221+ type_map = ["O" , "H" ],
222+ seed = GLOBAL_SEED ,
223+ use_ext_ebd = False ,
224+ ).to (env .DEVICE )
225+ rd1 , _ , _ , _ , _ = dd1 (
226+ torch .tensor (self .coord_ext , dtype = dtype , device = env .DEVICE ),
227+ torch .tensor (self .atype_ext , dtype = int , device = env .DEVICE ),
228+ torch .tensor (self .nlist , dtype = int , device = env .DEVICE ),
229+ torch .tensor (self .mapping , dtype = int , device = env .DEVICE ),
230+ )
231+ np .testing .assert_allclose (
232+ rd0 .detach ().cpu ().numpy (),
233+ rd1 .detach ().cpu ().numpy (),
234+ rtol = rtol ,
235+ atol = atol ,
236+ )
144237
145238 def test_jit (
146239 self ,
0 commit comments