Skip to content

Commit a746f82

Browse files
committed
fix UT
1 parent 7b7cdfa commit a746f82

3 files changed

Lines changed: 97 additions & 1 deletion

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ def __init__(
150150
skip_stat: bool = False,
151151
optim_update: bool = True,
152152
smooth_edge_update: bool = False,
153+
use_ext_ebd: bool = False,
153154
) -> None:
154155
self.n_dim = n_dim
155156
self.e_dim = e_dim
@@ -176,6 +177,7 @@ def __init__(
176177
self.a_compress_use_split = a_compress_use_split
177178
self.optim_update = optim_update
178179
self.smooth_edge_update = smooth_edge_update
180+
self.use_ext_ebd = use_ext_ebd
179181

180182
def __getitem__(self, key):
181183
if hasattr(self, key):
@@ -207,6 +209,7 @@ def serialize(self) -> dict:
207209
"fix_stat_std": self.fix_stat_std,
208210
"optim_update": self.optim_update,
209211
"smooth_edge_update": self.smooth_edge_update,
212+
"use_ext_ebd": self.use_ext_ebd,
210213
}
211214

212215
@classmethod

deepmd/pt/model/descriptor/repflows.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -462,7 +462,7 @@ def forward(
462462
assert mapping is not None
463463
node_ebd_ext = None
464464
nlist = torch.gather(
465-
mapping,
465+
mapping.reshape(nframes, -1),
466466
1,
467467
nlist.reshape(nframes, -1),
468468
).reshape(nlist.shape)

source/tests/pt/model/test_dpa3.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,99 @@ def test_consistency(
141141
rtol=rtol,
142142
atol=atol,
143143
)
144+
def test_use_ext_ebd(
145+
self,
146+
) -> None:
147+
rtol, atol = get_tols("float32")
148+
149+
nf, nloc, nnei = self.nlist.shape
150+
repflow0 = RepFlowArgs(
151+
n_dim=20,
152+
e_dim=10,
153+
a_dim=8,
154+
nlayers=3,
155+
e_rcut=self.rcut,
156+
e_rcut_smth=self.rcut_smth,
157+
e_sel=nnei,
158+
a_rcut=self.rcut - 0.1,
159+
a_rcut_smth=self.rcut_smth,
160+
a_sel=nnei - 1,
161+
a_compress_rate=0,
162+
a_compress_e_rate=2,
163+
a_compress_use_split=False,
164+
n_multi_edge_message=1,
165+
axis_neuron=4,
166+
update_angle=True,
167+
update_style="res_residual",
168+
update_residual_init="const",
169+
smooth_edge_update=True,
170+
use_ext_ebd=True,
171+
)
172+
# dpa3 with use_ext_ebd=True
173+
dd0 = DescrptDPA3(
174+
self.nt,
175+
repflow=repflow0,
176+
# kwargs for descriptor
177+
exclude_types=[],
178+
precision="float32",
179+
use_econf_tebd=True,
180+
type_map=["O", "H"],
181+
seed=GLOBAL_SEED,
182+
use_ext_ebd=True,
183+
).to(env.DEVICE)
184+
rd0, _, _, _, _ = dd0(
185+
torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE),
186+
torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE),
187+
torch.tensor(self.nlist, dtype=int, device=env.DEVICE),
188+
torch.tensor(self.mapping, dtype=int, device=env.DEVICE),
189+
)
190+
191+
# dpa3 with use_ext_ebd=False
192+
repflow1 = RepFlowArgs(
193+
n_dim=20,
194+
e_dim=10,
195+
a_dim=8,
196+
nlayers=3,
197+
e_rcut=self.rcut,
198+
e_rcut_smth=self.rcut_smth,
199+
e_sel=nnei,
200+
a_rcut=self.rcut - 0.1,
201+
a_rcut_smth=self.rcut_smth,
202+
a_sel=nnei - 1,
203+
a_compress_rate=0,
204+
a_compress_e_rate=2,
205+
a_compress_use_split=False,
206+
n_multi_edge_message=1,
207+
axis_neuron=4,
208+
update_angle=True,
209+
update_style="res_residual",
210+
update_residual_init="const",
211+
smooth_edge_update=True,
212+
use_ext_ebd=False,
213+
)
214+
dd1 = DescrptDPA3(
215+
self.nt,
216+
repflow=repflow1,
217+
# kwargs for descriptor
218+
exclude_types=[],
219+
precision="float32",
220+
use_econf_tebd=True,
221+
type_map=["O", "H"],
222+
seed=GLOBAL_SEED,
223+
use_ext_ebd=False,
224+
).to(env.DEVICE)
225+
rd1, _, _, _, _ = dd1(
226+
torch.tensor(self.coord_ext, dtype=dtype, device=env.DEVICE),
227+
torch.tensor(self.atype_ext, dtype=int, device=env.DEVICE),
228+
torch.tensor(self.nlist, dtype=int, device=env.DEVICE),
229+
torch.tensor(self.mapping, dtype=int, device=env.DEVICE),
230+
)
231+
np.testing.assert_allclose(
232+
rd0.detach().cpu().numpy(),
233+
rd1.detach().cpu().numpy(),
234+
rtol=rtol,
235+
atol=atol,
236+
)
144237

145238
def test_jit(
146239
self,

0 commit comments

Comments
 (0)