Skip to content

Commit 3d3fd09

Browse files
author
Han Wang
committed
feat(c++): add DPA3 .pt2 export and C/C++ inference tests
Fix torch.export-incompatible slicing in dpmodel DPA3/repflows code ([:, :nloc] → xp_take_first_n, reshape with -1 → expand_dims), add gen_dpa3.py model generation script, and create C++ test files for both .pth and .pt2 backends with PBC and NoPbc fixtures.
1 parent ac78e07 commit 3d3fd09

6 files changed

Lines changed: 1110 additions & 5 deletions

File tree

deepmd/dpmodel/descriptor/dpa3.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111
from deepmd.dpmodel.array_api import (
1212
Array,
13+
xp_take_first_n,
1314
)
1415
from deepmd.dpmodel.common import (
1516
cast_precision,
@@ -617,15 +618,19 @@ def call(
617618
type_embedding = self.type_embedding.call()
618619
if self.use_loc_mapping:
619620
node_ebd_ext = xp.reshape(
620-
xp.take(type_embedding, xp.reshape(atype_ext[:, :nloc], (-1,)), axis=0),
621+
xp.take(
622+
type_embedding,
623+
xp.reshape(xp_take_first_n(atype_ext, 1, nloc), (-1,)),
624+
axis=0,
625+
),
621626
(nframes, nloc, self.tebd_dim),
622627
)
623628
else:
624629
node_ebd_ext = xp.reshape(
625630
xp.take(type_embedding, xp.reshape(atype_ext, (-1,)), axis=0),
626631
(nframes, nall, self.tebd_dim),
627632
)
628-
node_ebd_inp = node_ebd_ext[:, :nloc, :]
633+
node_ebd_inp = xp_take_first_n(node_ebd_ext, 1, nloc)
629634
# repflows
630635
node_ebd, edge_ebd, h2, rot_mat, sw = self.repflows(
631636
nlist,

deepmd/dpmodel/descriptor/repflows.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from deepmd.dpmodel.array_api import (
1414
Array,
1515
xp_take_along_axis,
16+
xp_take_first_n,
1617
)
1718
from deepmd.dpmodel.common import (
1819
to_numpy_array,
@@ -562,7 +563,7 @@ def call(
562563

563564
# get node embedding
564565
# nb x nloc x tebd_dim
565-
atype_embd = atype_embd_ext[:, :nloc, :]
566+
atype_embd = xp_take_first_n(atype_embd_ext, 1, nloc)
566567
assert list(atype_embd.shape) == [nframes, nloc, self.n_dim]
567568

568569
node_ebd = self.act(atype_embd)
@@ -641,7 +642,7 @@ def call(
641642
angle_ebd = self.angle_embd(angle_input)
642643

643644
# nb x nall x n_dim
644-
mapping = xp.tile(xp.reshape(mapping, (nframes, -1, 1)), (1, 1, self.n_dim))
645+
mapping = xp.tile(xp.expand_dims(mapping, axis=-1), (1, 1, self.n_dim))
645646
for idx, ll in enumerate(self.layers):
646647
# node_ebd: nb x nloc x n_dim
647648
# node_ebd_ext: nb x nall x n_dim
@@ -1421,7 +1422,7 @@ def call(
14211422
n_edge = (
14221423
int(xp.sum(xp.astype(nlist_mask, xp.int32))) if self.use_dynamic_sel else 0
14231424
)
1424-
node_ebd = node_ebd_ext[:, :nloc, :]
1425+
node_ebd = xp_take_first_n(node_ebd_ext, 1, nloc)
14251426
assert (nb, nloc) == node_ebd.shape[:2]
14261427
if not self.use_dynamic_sel:
14271428
assert (nb, nloc, nnei) == h2.shape[:3]
Lines changed: 308 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,308 @@
1+
// SPDX-License-Identifier: LGPL-3.0-or-later
2+
// Test C++ inference for PT (.pth) backend with DPA3 (mixed-type) model.
3+
// Reference values generated by source/tests/infer/gen_dpa3.py.
4+
#include <gtest/gtest.h>
5+
6+
#include <algorithm>
7+
#include <cmath>
8+
#include <fstream>
9+
#include <vector>
10+
11+
#include "DeepPot.h"
12+
#include "neighbor_list.h"
13+
#include "test_utils.h"
14+
15+
// DPA3 models need relaxed epsilon
16+
#undef EPSILON
17+
#define EPSILON (std::is_same<VALUETYPE, double>::value ? 1e-7 : 1e-1)
18+
19+
template <class VALUETYPE>
20+
class TestInferDeepPotDpa3Pt : public ::testing::Test {
21+
protected:
22+
std::vector<VALUETYPE> coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
23+
00.25, 3.32, 1.68, 3.36, 3.00, 1.81,
24+
3.51, 2.51, 2.60, 4.27, 3.22, 1.56};
25+
std::vector<int> atype = {0, 1, 1, 0, 1, 1};
26+
std::vector<VALUETYPE> box = {13., 0., 0., 0., 13., 0., 0., 0., 13.};
27+
// Generated by source/tests/infer/gen_dpa3.py (PBC)
28+
// Same weights as deeppot_dpa3.pt2
29+
std::vector<VALUETYPE> expected_e = {
30+
2.733142942358297023e-01, 2.768815473296480922e-01,
31+
2.781664369968356865e-01, 2.697839344989072519e-01,
32+
2.741210600049306945e-01, 2.752870928812235496e-01};
33+
std::vector<VALUETYPE> expected_f = {
34+
-1.962618723134541832e-02, 4.287158582278347702e-02,
35+
7.640666386947853050e-03, 5.554130248696588501e-02,
36+
-6.501206231527984977e-03, -4.524468847893595158e-02,
37+
-3.851051736663693714e-02, -3.620789238677154381e-02,
38+
3.756162244251591564e-02, 6.729090678104879264e-02,
39+
-2.430710555108604037e-02, 4.496058666120762021e-02,
40+
9.285825331084011924e-03, 5.623126339971108029e-02,
41+
-8.776072674283137698e-02, -7.398133000111631330e-02,
42+
-3.208664505310900028e-02, 4.284253973109593966e-02};
43+
std::vector<VALUETYPE> expected_v = {
44+
-2.519191242984861884e-02, -7.976296517418629550e-04,
45+
2.293255716383547221e-02, -1.129879902880513709e-04,
46+
-2.480533869648754441e-02, 5.147545203263749480e-03,
47+
2.250634701911344987e-02, 5.288887046140826331e-03,
48+
-2.010244267109611085e-02, -1.779331319768159489e-02,
49+
3.093850189397499839e-03, 1.469388965841003300e-02,
50+
-3.857294749719837688e-03, 1.122172669801067097e-03,
51+
3.015485878866499582e-03, 1.588838841470147090e-02,
52+
-2.814760933954751562e-03, -1.277216714527013713e-02,
53+
-8.763367643346370306e-03, -1.305889135368112908e-02,
54+
1.181350951828694096e-02, -6.506014073233991855e-03,
55+
-6.021216432246893902e-03, 6.406967309407277100e-03,
56+
1.054423249710041179e-02, 1.210616766999832172e-02,
57+
-1.127472660426425549e-02, -3.873334330831591787e-02,
58+
-3.620067664760272686e-03, 1.173198873109224322e-03,
59+
-3.979800321914496279e-03, -1.483777776121806245e-02,
60+
2.311848485249741111e-02, 1.659292900032220339e-03,
61+
2.315104663227764842e-02, -3.645194750481960122e-02,
62+
-1.668107738824501848e-04, -7.331929353596922626e-03,
63+
1.141573012886789966e-02, -1.498650485705460686e-03,
64+
-1.339178008942835431e-02, 2.104129816063767672e-02,
65+
2.247013447171188061e-03, 2.035538814221872148e-02,
66+
-3.195007182084359104e-02, -2.339460083073257798e-02,
67+
-1.001949167693141039e-02, 1.320033846426920537e-02,
68+
-1.577941189045228843e-02, -6.283307183655661120e-03,
69+
8.237968913765561507e-03, 2.238394952866012630e-02,
70+
8.881021761757389166e-03, -1.162377795308391741e-02};
71+
int natoms;
72+
double expected_tot_e;
73+
std::vector<VALUETYPE> expected_tot_v;
74+
75+
deepmd::DeepPot dp;
76+
77+
void SetUp() override {
78+
#ifndef BUILD_PYTORCH
79+
GTEST_SKIP() << "Skip because PyTorch support is not enabled.";
80+
#endif
81+
dp.init("../../tests/infer/deeppot_dpa3.pth");
82+
83+
natoms = expected_e.size();
84+
EXPECT_EQ(natoms * 3, expected_f.size());
85+
EXPECT_EQ(natoms * 9, expected_v.size());
86+
expected_tot_e = 0.;
87+
expected_tot_v.resize(9);
88+
std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.);
89+
for (int ii = 0; ii < natoms; ++ii) {
90+
expected_tot_e += expected_e[ii];
91+
}
92+
for (int ii = 0; ii < natoms; ++ii) {
93+
for (int dd = 0; dd < 9; ++dd) {
94+
expected_tot_v[dd] += expected_v[ii * 9 + dd];
95+
}
96+
}
97+
};
98+
99+
void TearDown() override {};
100+
};
101+
102+
TYPED_TEST_SUITE(TestInferDeepPotDpa3Pt, ValueTypes);
103+
104+
TYPED_TEST(TestInferDeepPotDpa3Pt, cpu_build_nlist) {
105+
using VALUETYPE = TypeParam;
106+
std::vector<VALUETYPE>& coord = this->coord;
107+
std::vector<int>& atype = this->atype;
108+
std::vector<VALUETYPE>& box = this->box;
109+
std::vector<VALUETYPE>& expected_f = this->expected_f;
110+
int& natoms = this->natoms;
111+
double& expected_tot_e = this->expected_tot_e;
112+
std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
113+
deepmd::DeepPot& dp = this->dp;
114+
double ener;
115+
std::vector<VALUETYPE> force, virial;
116+
dp.compute(ener, force, virial, coord, atype, box);
117+
118+
EXPECT_EQ(force.size(), natoms * 3);
119+
EXPECT_EQ(virial.size(), 9);
120+
121+
EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
122+
for (int ii = 0; ii < natoms * 3; ++ii) {
123+
EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
124+
}
125+
for (int ii = 0; ii < 3 * 3; ++ii) {
126+
EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
127+
}
128+
}
129+
130+
TYPED_TEST(TestInferDeepPotDpa3Pt, cpu_build_nlist_atomic) {
131+
using VALUETYPE = TypeParam;
132+
std::vector<VALUETYPE>& coord = this->coord;
133+
std::vector<int>& atype = this->atype;
134+
std::vector<VALUETYPE>& box = this->box;
135+
std::vector<VALUETYPE>& expected_e = this->expected_e;
136+
std::vector<VALUETYPE>& expected_f = this->expected_f;
137+
std::vector<VALUETYPE>& expected_v = this->expected_v;
138+
int& natoms = this->natoms;
139+
double& expected_tot_e = this->expected_tot_e;
140+
std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
141+
deepmd::DeepPot& dp = this->dp;
142+
double ener;
143+
std::vector<VALUETYPE> force, virial, atom_ener, atom_vir;
144+
dp.compute(ener, force, virial, atom_ener, atom_vir, coord, atype, box);
145+
146+
EXPECT_EQ(force.size(), natoms * 3);
147+
EXPECT_EQ(virial.size(), 9);
148+
EXPECT_EQ(atom_ener.size(), natoms);
149+
EXPECT_EQ(atom_vir.size(), natoms * 9);
150+
151+
EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
152+
for (int ii = 0; ii < natoms * 3; ++ii) {
153+
EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
154+
}
155+
for (int ii = 0; ii < 3 * 3; ++ii) {
156+
EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
157+
}
158+
for (int ii = 0; ii < natoms; ++ii) {
159+
EXPECT_LT(fabs(atom_ener[ii] - expected_e[ii]), EPSILON);
160+
}
161+
for (int ii = 0; ii < natoms * 9; ++ii) {
162+
EXPECT_LT(fabs(atom_vir[ii] - expected_v[ii]), EPSILON);
163+
}
164+
}
165+
166+
template <class VALUETYPE>
167+
class TestInferDeepPotDpa3PtNoPbc : public ::testing::Test {
168+
protected:
169+
std::vector<VALUETYPE> coord = {12.83, 2.56, 2.18, 12.09, 2.87, 2.74,
170+
00.25, 3.32, 1.68, 3.36, 3.00, 1.81,
171+
3.51, 2.51, 2.60, 4.27, 3.22, 1.56};
172+
std::vector<int> atype = {0, 1, 1, 0, 1, 1};
173+
std::vector<VALUETYPE> box = {};
174+
// Generated by source/tests/infer/gen_dpa3.py (NoPbc)
175+
std::vector<VALUETYPE> expected_e = {
176+
2.748896667984845887e-01, 2.803947322373078754e-01,
177+
2.865499847997139971e-01, 2.695555136277474895e-01,
178+
2.739584531066059925e-01, 2.752217127378932537e-01};
179+
std::vector<VALUETYPE> expected_f = {
180+
-4.469562373941994571e-02, 1.872384237732456838e-02,
181+
3.382371526226372882e-02, 4.469562373941994571e-02,
182+
-1.872384237732456838e-02, -3.382371526226372882e-02,
183+
-8.962417443747255821e-04, 6.973117535150641388e-05,
184+
3.708588577163370883e-05, 6.643516471939500678e-02,
185+
-2.418189932122343649e-02, 4.484243027251725439e-02,
186+
9.031619071676464522e-03, 5.637239343551967569e-02,
187+
-8.796029317613156262e-02, -7.457054204669674724e-02,
188+
-3.226022528964775371e-02, 4.308077701784267244e-02};
189+
std::vector<VALUETYPE> expected_v = {
190+
-1.634330450074628072e-02, 6.846519453015231793e-03,
191+
1.236790610867266604e-02, 6.846519453015259549e-03,
192+
-2.868136527614494058e-03, -5.181149856335852399e-03,
193+
1.236790610867266604e-02, -5.181149856335859338e-03,
194+
-9.359496514671244993e-03, -1.673145706642453767e-02,
195+
7.009123906204950405e-03, 1.266164318540249911e-02,
196+
7.009123906204922649e-03, -2.936254609356120371e-03,
197+
-5.304201874965906727e-03, 1.266164318540247136e-02,
198+
-5.304201874965899788e-03, -9.581784032196449807e-03,
199+
2.483905957089865488e-03, -1.710616363479115602e-04,
200+
-5.347582359011894028e-05, -1.996686279554130779e-04,
201+
1.446275632786597548e-05, 2.638112328458543858e-06,
202+
-1.197563523836930226e-04, 1.205600575305949503e-05,
203+
-4.593499883389132697e-06, -4.089897480719173473e-02,
204+
-3.495830205935246404e-03, 1.154978330068986980e-03,
205+
-3.627142383941225900e-03, -1.488475129792680290e-02,
206+
2.311785022979555293e-02, 1.347848716528848856e-03,
207+
2.315545736893441509e-02, -3.642400982788428221e-02,
208+
-1.119743233540158867e-03, -7.327254171127076110e-03,
209+
1.144439607350029517e-02, -1.403015516843159061e-03,
210+
-1.349644754565121341e-02, 2.117430870829728473e-02,
211+
2.103115217604090148e-03, 2.047643373328661420e-02,
212+
-3.212706064943796069e-02, -2.418232649309504101e-02,
213+
-1.012366394018440752e-02, 1.334822742508814941e-02,
214+
-1.588798342485496506e-02, -6.330672283764562924e-03,
215+
8.295385033255518736e-03, 2.256291842331806241e-02,
216+
8.946234975702738179e-03, -1.170798305154926999e-02};
217+
int natoms;
218+
double expected_tot_e;
219+
std::vector<VALUETYPE> expected_tot_v;
220+
221+
deepmd::DeepPot dp;
222+
223+
void SetUp() override {
224+
#ifndef BUILD_PYTORCH
225+
GTEST_SKIP() << "Skip because PyTorch support is not enabled.";
226+
#endif
227+
dp.init("../../tests/infer/deeppot_dpa3.pth");
228+
229+
natoms = expected_e.size();
230+
EXPECT_EQ(natoms * 3, expected_f.size());
231+
EXPECT_EQ(natoms * 9, expected_v.size());
232+
expected_tot_e = 0.;
233+
expected_tot_v.resize(9);
234+
std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.);
235+
for (int ii = 0; ii < natoms; ++ii) {
236+
expected_tot_e += expected_e[ii];
237+
}
238+
for (int ii = 0; ii < natoms; ++ii) {
239+
for (int dd = 0; dd < 9; ++dd) {
240+
expected_tot_v[dd] += expected_v[ii * 9 + dd];
241+
}
242+
}
243+
};
244+
245+
void TearDown() override {};
246+
};
247+
248+
TYPED_TEST_SUITE(TestInferDeepPotDpa3PtNoPbc, ValueTypes);
249+
250+
TYPED_TEST(TestInferDeepPotDpa3PtNoPbc, cpu_build_nlist) {
251+
using VALUETYPE = TypeParam;
252+
std::vector<VALUETYPE>& coord = this->coord;
253+
std::vector<int>& atype = this->atype;
254+
std::vector<VALUETYPE>& box = this->box;
255+
std::vector<VALUETYPE>& expected_f = this->expected_f;
256+
int& natoms = this->natoms;
257+
double& expected_tot_e = this->expected_tot_e;
258+
std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
259+
deepmd::DeepPot& dp = this->dp;
260+
double ener;
261+
std::vector<VALUETYPE> force, virial;
262+
dp.compute(ener, force, virial, coord, atype, box);
263+
264+
EXPECT_EQ(force.size(), natoms * 3);
265+
EXPECT_EQ(virial.size(), 9);
266+
267+
EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
268+
for (int ii = 0; ii < natoms * 3; ++ii) {
269+
EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
270+
}
271+
for (int ii = 0; ii < 3 * 3; ++ii) {
272+
EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
273+
}
274+
}
275+
276+
TYPED_TEST(TestInferDeepPotDpa3PtNoPbc, cpu_lmp_nlist) {
277+
using VALUETYPE = TypeParam;
278+
std::vector<VALUETYPE>& coord = this->coord;
279+
std::vector<int>& atype = this->atype;
280+
std::vector<VALUETYPE>& box = this->box;
281+
std::vector<VALUETYPE>& expected_f = this->expected_f;
282+
int& natoms = this->natoms;
283+
double& expected_tot_e = this->expected_tot_e;
284+
std::vector<VALUETYPE>& expected_tot_v = this->expected_tot_v;
285+
deepmd::DeepPot& dp = this->dp;
286+
double ener;
287+
std::vector<VALUETYPE> force, virial;
288+
289+
std::vector<std::vector<int> > nlist_data = {
290+
{1, 2, 3, 4, 5}, {0, 2, 3, 4, 5}, {0, 1, 3, 4, 5},
291+
{0, 1, 2, 4, 5}, {0, 1, 2, 3, 5}, {0, 1, 2, 3, 4}};
292+
std::vector<int> ilist(natoms), numneigh(natoms);
293+
std::vector<int*> firstneigh(natoms);
294+
deepmd::InputNlist inlist(natoms, &ilist[0], &numneigh[0], &firstneigh[0]);
295+
convert_nlist(inlist, nlist_data);
296+
dp.compute(ener, force, virial, coord, atype, box, 0, inlist, 0);
297+
298+
EXPECT_EQ(force.size(), natoms * 3);
299+
EXPECT_EQ(virial.size(), 9);
300+
301+
EXPECT_LT(fabs(ener - expected_tot_e), EPSILON);
302+
for (int ii = 0; ii < natoms * 3; ++ii) {
303+
EXPECT_LT(fabs(force[ii] - expected_f[ii]), EPSILON);
304+
}
305+
for (int ii = 0; ii < 3 * 3; ++ii) {
306+
EXPECT_LT(fabs(virial[ii] - expected_tot_v[ii]), EPSILON);
307+
}
308+
}

0 commit comments

Comments
 (0)