Skip to content

Commit 5e2a9e6

Browse files
committed
modify ut
1 parent dd0d642 commit 5e2a9e6

1 file changed

Lines changed: 0 additions & 97 deletions

File tree

source/tests/pt/model/test_dpa3.py

Lines changed: 0 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ def test_consistency(
5555
nme,
5656
prec,
5757
ect,
58-
use_ext_ebd,
5958
) in itertools.product(
6059
[True, False], # update_angle
6160
["res_residual"], # update_style
@@ -66,7 +65,6 @@ def test_consistency(
6665
[1, 2], # n_multi_edge_message
6766
["float64"], # precision
6867
[False], # use_econf_tebd
69-
[False, True], # use_ext_ebd
7068
):
7169
dtype = PRECISION_DICT[prec]
7270
rtol, atol = get_tols(prec)
@@ -105,7 +103,6 @@ def test_consistency(
105103
use_econf_tebd=ect,
106104
type_map=["O", "H"] if ect else None,
107105
seed=GLOBAL_SEED,
108-
use_ext_ebd=use_ext_ebd,
109106
).to(env.DEVICE)
110107

111108
dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=env.DEVICE)
@@ -142,100 +139,6 @@ def test_consistency(
142139
atol=atol,
143140
)
144141

145-
def test_use_ext_ebd(
146-
self,
147-
) -> None:
148-
rtol, atol = get_tols("float32")
149-
150-
nf, nloc, nnei = self.nlist.shape
151-
repflow0 = RepFlowArgs(
152-
n_dim=20,
153-
e_dim=10,
154-
a_dim=8,
155-
nlayers=3,
156-
e_rcut=self.rcut,
157-
e_rcut_smth=self.rcut_smth,
158-
e_sel=nnei,
159-
a_rcut=self.rcut - 0.1,
160-
a_rcut_smth=self.rcut_smth,
161-
a_sel=nnei - 1,
162-
a_compress_rate=0,
163-
a_compress_e_rate=2,
164-
a_compress_use_split=False,
165-
n_multi_edge_message=1,
166-
axis_neuron=4,
167-
update_angle=True,
168-
update_style="res_residual",
169-
update_residual_init="const",
170-
smooth_edge_update=True,
171-
use_ext_ebd=True,
172-
)
173-
# dpa3 with use_ext_ebd=True
174-
dd0 = DescrptDPA3(
175-
self.nt,
176-
repflow=repflow0,
177-
# kwargs for descriptor
178-
exclude_types=[],
179-
precision="float32",
180-
use_econf_tebd=True,
181-
type_map=["O", "H"],
182-
seed=GLOBAL_SEED,
183-
use_ext_ebd=True,
184-
).to(env.DEVICE)
185-
rd0, _, _, _, _ = dd0(
186-
torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE),
187-
torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE),
188-
torch.tensor(self.nlist, dtype=int, device=env.DEVICE),
189-
torch.tensor(self.mapping, dtype=int, device=env.DEVICE),
190-
)
191-
192-
# dpa3 with use_ext_ebd=False
193-
repflow1 = RepFlowArgs(
194-
n_dim=20,
195-
e_dim=10,
196-
a_dim=8,
197-
nlayers=3,
198-
e_rcut=self.rcut,
199-
e_rcut_smth=self.rcut_smth,
200-
e_sel=nnei,
201-
a_rcut=self.rcut - 0.1,
202-
a_rcut_smth=self.rcut_smth,
203-
a_sel=nnei - 1,
204-
a_compress_rate=0,
205-
a_compress_e_rate=2,
206-
a_compress_use_split=False,
207-
n_multi_edge_message=1,
208-
axis_neuron=4,
209-
update_angle=True,
210-
update_style="res_residual",
211-
update_residual_init="const",
212-
smooth_edge_update=True,
213-
use_ext_ebd=False,
214-
)
215-
dd1 = DescrptDPA3(
216-
self.nt,
217-
repflow=repflow1,
218-
# kwargs for descriptor
219-
exclude_types=[],
220-
precision="float32",
221-
use_econf_tebd=True,
222-
type_map=["O", "H"],
223-
seed=GLOBAL_SEED,
224-
use_ext_ebd=False,
225-
).to(env.DEVICE)
226-
rd1, _, _, _, _ = dd1(
227-
torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE),
228-
torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE),
229-
torch.tensor(self.nlist, dtype=int, device=env.DEVICE),
230-
torch.tensor(self.mapping, dtype=int, device=env.DEVICE),
231-
)
232-
np.testing.assert_allclose(
233-
rd0.detach().cpu().numpy(),
234-
rd1.detach().cpu().numpy(),
235-
rtol=rtol,
236-
atol=atol,
237-
)
238-
239142
def test_jit(
240143
self,
241144
) -> None:

0 commit comments

Comments
 (0)