99
1010#include " DeepPot.h"
1111#include " DeepPotPTExpt.h"
12+ #include " expected_ref.h"
1213#include " test_utils.h"
1314
15+ namespace {
16+ constexpr const char * kRefPath = " ../../tests/infer/fparam_aparam.expected" ;
17+ constexpr const char * kModelPath = " ../../tests/infer/fparam_aparam.pt2" ;
18+ } // namespace
19+
1420template <class VALUETYPE >
1521class TestInferDeepPotAFparamAparamNFramesPtExpt : public ::testing::Test {
1622 protected:
@@ -27,88 +33,9 @@ class TestInferDeepPotAFparamAparamNFramesPtExpt : public ::testing::Test {
2733 std::vector<VALUETYPE> aparam = {
2834 0.25852028 , 0.25852028 , 0.25852028 , 0.25852028 , 0.25852028 , 0.25852028 ,
2935 0.25852028 , 0.25852028 , 0.25852028 , 0.25852028 , 0.25852028 , 0.25852028 };
30- // Same reference values as single-frame, duplicated for 2 frames
31- std::vector<VALUETYPE> expected_e = {
32- -1.038271223729636539e-01 , -7.285433579124989123e-02 ,
33- -9.467600492266425860e-02 , -1.467050207422957442e-01 ,
34- -7.660561676973243195e-02 , -7.277296000253175023e-02 ,
35- -1.038271223729636539e-01 , -7.285433579124989123e-02 ,
36- -9.467600492266425860e-02 , -1.467050207422957442e-01 ,
37- -7.660561676973243195e-02 , -7.277296000253175023e-02 };
38- std::vector<VALUETYPE> expected_f = {
39- 6.622266941151369601e-02 , 5.278739714221529489e-02 ,
40- 2.265728009692277028e-02 , -2.606048291367509331e-02 ,
41- -4.538812303131847109e-02 , 1.058247419681241676e-02 ,
42- 1.679392617013223121e-01 , -2.257826240741929533e-03 ,
43- -4.490146347357203138e-02 , -1.148364179422036724e-01 ,
44- -1.169790528013799069e-02 , 6.140403441496700837e-02 ,
45- -8.078778123309421355e-02 , -5.838879041789352825e-02 ,
46- 6.773641084621376263e-02 , -1.247724902386305318e-02 ,
47- 6.494524782787665373e-02 , -1.174787360813439457e-01 ,
48- 6.622266941151369601e-02 , 5.278739714221529489e-02 ,
49- 2.265728009692277028e-02 , -2.606048291367509331e-02 ,
50- -4.538812303131847109e-02 , 1.058247419681241676e-02 ,
51- 1.679392617013223121e-01 , -2.257826240741929533e-03 ,
52- -4.490146347357203138e-02 , -1.148364179422036724e-01 ,
53- -1.169790528013799069e-02 , 6.140403441496700837e-02 ,
54- -8.078778123309421355e-02 , -5.838879041789352825e-02 ,
55- 6.773641084621376263e-02 , -1.247724902386305318e-02 ,
56- 6.494524782787665373e-02 , -1.174787360813439457e-01 };
57- std::vector<VALUETYPE> expected_v = {
58- -1.589185601903579381e-01 , 2.586167090689088510e-03 ,
59- -1.575150812458056548e-04 , -1.855360549216640564e-02 ,
60- 1.949822308966445150e-02 , -1.006552178977542650e-02 ,
61- 3.177030388421490936e-02 , 1.714350280402104215e-03 ,
62- -1.290389705296313833e-03 , -8.553511587973079699e-02 ,
63- -5.654638208496251539e-03 , -1.286955066237439882e-02 ,
64- 2.464156699303176462e-02 , -2.398203243424212178e-02 ,
65- -1.957110698882909630e-02 , 2.233493653505165544e-02 ,
66- 6.107843889444162372e-03 , 1.707076397717688723e-03 ,
67- -1.653994136896924094e-01 , 3.894358809712639147e-02 ,
68- -2.169596032233910010e-02 , 6.819702786556020371e-03 ,
69- -5.018240707559744503e-03 , 2.640663592968431426e-03 ,
70- -1.985295554050418160e-03 , -3.638422207618969423e-02 ,
71- 2.342932709960221863e-02 , -8.501331666888653493e-02 ,
72- -2.181253119706856591e-03 , 4.311299629418858387e-03 ,
73- -1.910329576491436726e-03 , -1.808810428459609043e-03 ,
74- -1.540075460017477360e-03 , -1.173703527688202929e-02 ,
75- -2.596307050960845741e-03 , 6.705026635782097323e-03 ,
76- -9.038454847872562370e-02 , 3.011717694088476838e-02 ,
77- -5.083053967307901710e-02 , -2.951212926932282599e-03 ,
78- 2.342446057919112673e-02 , -4.091208178777860222e-02 ,
79- -1.648470670751139844e-02 , -2.872262362355524484e-02 ,
80- 4.763925761561256522e-02 , -8.300037376164930147e-02 ,
81- 1.020429200603871836e-03 , -1.026734257188876599e-03 ,
82- 5.678534821710372327e-02 , 1.273635858276599142e-02 ,
83- -1.530143401888291177e-02 , -1.061672032476311256e-01 ,
84- -2.486859787145567074e-02 , 2.875323543588798395e-02 ,
85- -1.589185601903579381e-01 , 2.586167090689088510e-03 ,
86- -1.575150812458056548e-04 , -1.855360549216640564e-02 ,
87- 1.949822308966445150e-02 , -1.006552178977542650e-02 ,
88- 3.177030388421490936e-02 , 1.714350280402104215e-03 ,
89- -1.290389705296313833e-03 , -8.553511587973079699e-02 ,
90- -5.654638208496251539e-03 , -1.286955066237439882e-02 ,
91- 2.464156699303176462e-02 , -2.398203243424212178e-02 ,
92- -1.957110698882909630e-02 , 2.233493653505165544e-02 ,
93- 6.107843889444162372e-03 , 1.707076397717688723e-03 ,
94- -1.653994136896924094e-01 , 3.894358809712639147e-02 ,
95- -2.169596032233910010e-02 , 6.819702786556020371e-03 ,
96- -5.018240707559744503e-03 , 2.640663592968431426e-03 ,
97- -1.985295554050418160e-03 , -3.638422207618969423e-02 ,
98- 2.342932709960221863e-02 , -8.501331666888653493e-02 ,
99- -2.181253119706856591e-03 , 4.311299629418858387e-03 ,
100- -1.910329576491436726e-03 , -1.808810428459609043e-03 ,
101- -1.540075460017477360e-03 , -1.173703527688202929e-02 ,
102- -2.596307050960845741e-03 , 6.705026635782097323e-03 ,
103- -9.038454847872562370e-02 , 3.011717694088476838e-02 ,
104- -5.083053967307901710e-02 , -2.951212926932282599e-03 ,
105- 2.342446057919112673e-02 , -4.091208178777860222e-02 ,
106- -1.648470670751139844e-02 , -2.872262362355524484e-02 ,
107- 4.763925761561256522e-02 , -8.300037376164930147e-02 ,
108- 1.020429200603871836e-03 , -1.026734257188876599e-03 ,
109- 5.678534821710372327e-02 , 1.273635858276599142e-02 ,
110- -1.530143401888291177e-02 , -1.061672032476311256e-01 ,
111- -2.486859787145567074e-02 , 2.875323543588798395e-02 };
36+ std::vector<VALUETYPE> expected_e;
37+ std::vector<VALUETYPE> expected_f;
38+ std::vector<VALUETYPE> expected_v;
11239 int natoms;
11340 int nframes = 2 ;
11441 std::vector<double > expected_tot_e;
@@ -118,22 +45,34 @@ class TestInferDeepPotAFparamAparamNFramesPtExpt : public ::testing::Test {
11845
11946 static void SetUpTestSuite () {
12047#if defined(BUILD_PYTORCH) && BUILD_PT_EXPT
121- dp.init (" ../../tests/infer/fparam_aparam.pt2 " );
48+ dp.init (kModelPath );
12249#endif
12350 }
12451
12552 void SetUp () override {
12653#if !defined(BUILD_PYTORCH) || !BUILD_PT_EXPT
12754 GTEST_SKIP () << " Skip because PyTorch support is not enabled." ;
12855#endif
56+ deepmd_test::ExpectedRef ref;
57+ ref.load (kRefPath );
58+ auto e_single = ref.get <VALUETYPE>(" default" , " expected_e" );
59+ auto f_single = ref.get <VALUETYPE>(" default" , " expected_f" );
60+ auto v_single = ref.get <VALUETYPE>(" default" , " expected_v" );
61+ // Replicate single-frame reference for nframes batched inference.
62+ expected_e.reserve (nframes * e_single.size ());
63+ expected_f.reserve (nframes * f_single.size ());
64+ expected_v.reserve (nframes * v_single.size ());
65+ for (int kk = 0 ; kk < nframes; ++kk) {
66+ expected_e.insert (expected_e.end (), e_single.begin (), e_single.end ());
67+ expected_f.insert (expected_f.end (), f_single.begin (), f_single.end ());
68+ expected_v.insert (expected_v.end (), v_single.begin (), v_single.end ());
69+ }
12970
13071 natoms = expected_e.size () / nframes;
13172 EXPECT_EQ (nframes * natoms * 3 , expected_f.size ());
13273 EXPECT_EQ (nframes * natoms * 9 , expected_v.size ());
133- expected_tot_e.resize (nframes);
134- expected_tot_v.resize (static_cast <size_t >(nframes) * 9 );
135- std::fill (expected_tot_e.begin (), expected_tot_e.end (), 0 .);
136- std::fill (expected_tot_v.begin (), expected_tot_v.end (), 0 .);
74+ expected_tot_e.assign (nframes, 0 .);
75+ expected_tot_v.assign (static_cast <size_t >(nframes) * 9 , 0 .);
13776 for (int kk = 0 ; kk < nframes; ++kk) {
13877 for (int ii = 0 ; ii < natoms; ++ii) {
13978 expected_tot_e[kk] += expected_e[kk * natoms + ii];
0 commit comments