Skip to content

Commit 35d5cb2

Browse files
author
Han Wang
committed
test: replace hardcoded C++ expected values with sidecar reference files
Each `gen_*.py` script writes a plain-text `.expected` file alongside the .pt2/.pth model files; C++ tests load the arrays at SetUp() via a new `ExpectedRef` helper instead of carrying ~3000 magic-number floats baked into the test sources. This removes the manual paste-from-stdout sync that broke whenever dpmodel numerics shifted (e.g., the bias/idt init alignment in ff6a931), and keeps reference values in lockstep with the gen scripts that produce them. - New `source/api_cc/tests/expected_ref.h` (header-only loader) - New `gen_common.write_expected_ref()` writer; `print_cpp_values` and `print_cpp_spin_values` retained as ad-hoc debug helpers - 7 gen scripts (dpa1/2/3, spin, fparam_aparam, model_devi) emit sidecars - 15 C++ tests load arrays from sidecars (no test cases removed) - `test_deeppot_ptexpt.cc` keeps its TF-derived hardcoded refs (its model weights flow through `gen_fparam_aparam.py`'s state_dict load path, not random init, so they're stable) - Bump float32 atol from 1e-5 to 4e-5 in `test_compressed_forward` for DPA1 / SeAttenV2: the new bias init scale widens the embedding-net output range, observed compression-tabulation error ~3.2e-5 - `.expected` files are gitignored — regenerated at CI time
1 parent ff6a931 commit 35d5cb2

26 files changed

Lines changed: 664 additions & 1274 deletions

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,4 @@ frozen_model.*
7373

7474
# Test system directories
7575
system/
76+
*.expected

source/api_cc/tests/expected_ref.h

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
// SPDX-License-Identifier: LGPL-3.0-or-later
2+
#pragma once
3+
4+
#include <fstream>
5+
#include <map>
6+
#include <sstream>
7+
#include <stdexcept>
8+
#include <string>
9+
#include <vector>
10+
11+
namespace deepmd_test {
12+
13+
// Loader for sidecar reference files written by
14+
// `gen_common.write_expected_ref`.
15+
//
16+
// File format:
17+
// # auto-generated -- do not edit
18+
// [case_name_1]
19+
// array_name_1 N
20+
// v0
21+
// v1
22+
// ...
23+
// array_name_2 M
24+
// ...
25+
//
26+
// [case_name_2]
27+
// ...
28+
//
29+
// Lines beginning with '#' or empty lines are ignored.
30+
class ExpectedRef {
31+
public:
32+
// Parse `path`. Throws std::runtime_error on malformed input.
33+
void load(const std::string& path) {
34+
std::ifstream in(path);
35+
if (!in) {
36+
throw std::runtime_error("ExpectedRef: cannot open " + path);
37+
}
38+
sections_.clear();
39+
std::string line;
40+
std::string current_section;
41+
while (std::getline(in, line)) {
42+
if (line.empty() || line[0] == '#') {
43+
continue;
44+
}
45+
if (line.front() == '[' && line.back() == ']') {
46+
current_section = line.substr(1, line.size() - 2);
47+
continue;
48+
}
49+
// "<key> <count>" header — followed by `count` numeric lines.
50+
if (current_section.empty()) {
51+
throw std::runtime_error("ExpectedRef: array '" + line +
52+
"' before any [section]");
53+
}
54+
std::istringstream iss(line);
55+
std::string key;
56+
std::size_t n = 0;
57+
if (!(iss >> key >> n)) {
58+
throw std::runtime_error("ExpectedRef: bad header line: " + line);
59+
}
60+
std::vector<double> values;
61+
values.reserve(n);
62+
for (std::size_t i = 0; i < n; ++i) {
63+
if (!std::getline(in, line)) {
64+
throw std::runtime_error("ExpectedRef: unexpected EOF in '" + key +
65+
"'");
66+
}
67+
values.push_back(std::stod(line));
68+
}
69+
sections_[current_section][key] = std::move(values);
70+
}
71+
}
72+
73+
// Get array of `key` from `case_name`. Throws if missing.
74+
template <typename T = double>
75+
std::vector<T> get(const std::string& case_name,
76+
const std::string& key) const {
77+
auto sit = sections_.find(case_name);
78+
if (sit == sections_.end()) {
79+
throw std::runtime_error("ExpectedRef: missing case '" + case_name + "'");
80+
}
81+
auto kit = sit->second.find(key);
82+
if (kit == sit->second.end()) {
83+
throw std::runtime_error("ExpectedRef: missing array '" + key +
84+
"' in case '" + case_name + "'");
85+
}
86+
return std::vector<T>(kit->second.begin(), kit->second.end());
87+
}
88+
89+
private:
90+
std::map<std::string, std::map<std::string, std::vector<double>>> sections_;
91+
};
92+
93+
} // namespace deepmd_test

source/api_cc/tests/test_deeppot_a_fparam_aparam_nframes_ptexpt.cc

Lines changed: 26 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,14 @@
99

1010
#include "DeepPot.h"
1111
#include "DeepPotPTExpt.h"
12+
#include "expected_ref.h"
1213
#include "test_utils.h"
1314

15+
namespace {
16+
constexpr const char* kRefPath = "../../tests/infer/fparam_aparam.expected";
17+
constexpr const char* kModelPath = "../../tests/infer/fparam_aparam.pt2";
18+
} // namespace
19+
1420
template <class VALUETYPE>
1521
class TestInferDeepPotAFparamAparamNFramesPtExpt : public ::testing::Test {
1622
protected:
@@ -27,88 +33,9 @@ class TestInferDeepPotAFparamAparamNFramesPtExpt : public ::testing::Test {
2733
std::vector<VALUETYPE> aparam = {
2834
0.25852028, 0.25852028, 0.25852028, 0.25852028, 0.25852028, 0.25852028,
2935
0.25852028, 0.25852028, 0.25852028, 0.25852028, 0.25852028, 0.25852028};
30-
// Same reference values as single-frame, duplicated for 2 frames
31-
std::vector<VALUETYPE> expected_e = {
32-
-1.038271223729636539e-01, -7.285433579124989123e-02,
33-
-9.467600492266425860e-02, -1.467050207422957442e-01,
34-
-7.660561676973243195e-02, -7.277296000253175023e-02,
35-
-1.038271223729636539e-01, -7.285433579124989123e-02,
36-
-9.467600492266425860e-02, -1.467050207422957442e-01,
37-
-7.660561676973243195e-02, -7.277296000253175023e-02};
38-
std::vector<VALUETYPE> expected_f = {
39-
6.622266941151369601e-02, 5.278739714221529489e-02,
40-
2.265728009692277028e-02, -2.606048291367509331e-02,
41-
-4.538812303131847109e-02, 1.058247419681241676e-02,
42-
1.679392617013223121e-01, -2.257826240741929533e-03,
43-
-4.490146347357203138e-02, -1.148364179422036724e-01,
44-
-1.169790528013799069e-02, 6.140403441496700837e-02,
45-
-8.078778123309421355e-02, -5.838879041789352825e-02,
46-
6.773641084621376263e-02, -1.247724902386305318e-02,
47-
6.494524782787665373e-02, -1.174787360813439457e-01,
48-
6.622266941151369601e-02, 5.278739714221529489e-02,
49-
2.265728009692277028e-02, -2.606048291367509331e-02,
50-
-4.538812303131847109e-02, 1.058247419681241676e-02,
51-
1.679392617013223121e-01, -2.257826240741929533e-03,
52-
-4.490146347357203138e-02, -1.148364179422036724e-01,
53-
-1.169790528013799069e-02, 6.140403441496700837e-02,
54-
-8.078778123309421355e-02, -5.838879041789352825e-02,
55-
6.773641084621376263e-02, -1.247724902386305318e-02,
56-
6.494524782787665373e-02, -1.174787360813439457e-01};
57-
std::vector<VALUETYPE> expected_v = {
58-
-1.589185601903579381e-01, 2.586167090689088510e-03,
59-
-1.575150812458056548e-04, -1.855360549216640564e-02,
60-
1.949822308966445150e-02, -1.006552178977542650e-02,
61-
3.177030388421490936e-02, 1.714350280402104215e-03,
62-
-1.290389705296313833e-03, -8.553511587973079699e-02,
63-
-5.654638208496251539e-03, -1.286955066237439882e-02,
64-
2.464156699303176462e-02, -2.398203243424212178e-02,
65-
-1.957110698882909630e-02, 2.233493653505165544e-02,
66-
6.107843889444162372e-03, 1.707076397717688723e-03,
67-
-1.653994136896924094e-01, 3.894358809712639147e-02,
68-
-2.169596032233910010e-02, 6.819702786556020371e-03,
69-
-5.018240707559744503e-03, 2.640663592968431426e-03,
70-
-1.985295554050418160e-03, -3.638422207618969423e-02,
71-
2.342932709960221863e-02, -8.501331666888653493e-02,
72-
-2.181253119706856591e-03, 4.311299629418858387e-03,
73-
-1.910329576491436726e-03, -1.808810428459609043e-03,
74-
-1.540075460017477360e-03, -1.173703527688202929e-02,
75-
-2.596307050960845741e-03, 6.705026635782097323e-03,
76-
-9.038454847872562370e-02, 3.011717694088476838e-02,
77-
-5.083053967307901710e-02, -2.951212926932282599e-03,
78-
2.342446057919112673e-02, -4.091208178777860222e-02,
79-
-1.648470670751139844e-02, -2.872262362355524484e-02,
80-
4.763925761561256522e-02, -8.300037376164930147e-02,
81-
1.020429200603871836e-03, -1.026734257188876599e-03,
82-
5.678534821710372327e-02, 1.273635858276599142e-02,
83-
-1.530143401888291177e-02, -1.061672032476311256e-01,
84-
-2.486859787145567074e-02, 2.875323543588798395e-02,
85-
-1.589185601903579381e-01, 2.586167090689088510e-03,
86-
-1.575150812458056548e-04, -1.855360549216640564e-02,
87-
1.949822308966445150e-02, -1.006552178977542650e-02,
88-
3.177030388421490936e-02, 1.714350280402104215e-03,
89-
-1.290389705296313833e-03, -8.553511587973079699e-02,
90-
-5.654638208496251539e-03, -1.286955066237439882e-02,
91-
2.464156699303176462e-02, -2.398203243424212178e-02,
92-
-1.957110698882909630e-02, 2.233493653505165544e-02,
93-
6.107843889444162372e-03, 1.707076397717688723e-03,
94-
-1.653994136896924094e-01, 3.894358809712639147e-02,
95-
-2.169596032233910010e-02, 6.819702786556020371e-03,
96-
-5.018240707559744503e-03, 2.640663592968431426e-03,
97-
-1.985295554050418160e-03, -3.638422207618969423e-02,
98-
2.342932709960221863e-02, -8.501331666888653493e-02,
99-
-2.181253119706856591e-03, 4.311299629418858387e-03,
100-
-1.910329576491436726e-03, -1.808810428459609043e-03,
101-
-1.540075460017477360e-03, -1.173703527688202929e-02,
102-
-2.596307050960845741e-03, 6.705026635782097323e-03,
103-
-9.038454847872562370e-02, 3.011717694088476838e-02,
104-
-5.083053967307901710e-02, -2.951212926932282599e-03,
105-
2.342446057919112673e-02, -4.091208178777860222e-02,
106-
-1.648470670751139844e-02, -2.872262362355524484e-02,
107-
4.763925761561256522e-02, -8.300037376164930147e-02,
108-
1.020429200603871836e-03, -1.026734257188876599e-03,
109-
5.678534821710372327e-02, 1.273635858276599142e-02,
110-
-1.530143401888291177e-02, -1.061672032476311256e-01,
111-
-2.486859787145567074e-02, 2.875323543588798395e-02};
36+
std::vector<VALUETYPE> expected_e;
37+
std::vector<VALUETYPE> expected_f;
38+
std::vector<VALUETYPE> expected_v;
11239
int natoms;
11340
int nframes = 2;
11441
std::vector<double> expected_tot_e;
@@ -118,22 +45,34 @@ class TestInferDeepPotAFparamAparamNFramesPtExpt : public ::testing::Test {
11845

11946
static void SetUpTestSuite() {
12047
#if defined(BUILD_PYTORCH) && BUILD_PT_EXPT
121-
dp.init("../../tests/infer/fparam_aparam.pt2");
48+
dp.init(kModelPath);
12249
#endif
12350
}
12451

12552
void SetUp() override {
12653
#if !defined(BUILD_PYTORCH) || !BUILD_PT_EXPT
12754
GTEST_SKIP() << "Skip because PyTorch support is not enabled.";
12855
#endif
56+
deepmd_test::ExpectedRef ref;
57+
ref.load(kRefPath);
58+
auto e_single = ref.get<VALUETYPE>("default", "expected_e");
59+
auto f_single = ref.get<VALUETYPE>("default", "expected_f");
60+
auto v_single = ref.get<VALUETYPE>("default", "expected_v");
61+
// Replicate single-frame reference for nframes batched inference.
62+
expected_e.reserve(nframes * e_single.size());
63+
expected_f.reserve(nframes * f_single.size());
64+
expected_v.reserve(nframes * v_single.size());
65+
for (int kk = 0; kk < nframes; ++kk) {
66+
expected_e.insert(expected_e.end(), e_single.begin(), e_single.end());
67+
expected_f.insert(expected_f.end(), f_single.begin(), f_single.end());
68+
expected_v.insert(expected_v.end(), v_single.begin(), v_single.end());
69+
}
12970

13071
natoms = expected_e.size() / nframes;
13172
EXPECT_EQ(nframes * natoms * 3, expected_f.size());
13273
EXPECT_EQ(nframes * natoms * 9, expected_v.size());
133-
expected_tot_e.resize(nframes);
134-
expected_tot_v.resize(static_cast<size_t>(nframes) * 9);
135-
std::fill(expected_tot_e.begin(), expected_tot_e.end(), 0.);
136-
std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.);
74+
expected_tot_e.assign(nframes, 0.);
75+
expected_tot_v.assign(static_cast<size_t>(nframes) * 9, 0.);
13776
for (int kk = 0; kk < nframes; ++kk) {
13877
for (int ii = 0; ii < natoms; ++ii) {
13978
expected_tot_e[kk] += expected_e[kk * natoms + ii];

source/api_cc/tests/test_deeppot_a_fparam_aparam_pt.cc

Lines changed: 17 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,19 @@
1010
#include <vector>
1111

1212
#include "DeepPot.h"
13+
#include "expected_ref.h"
1314
#include "neighbor_list.h"
1415
#include "test_utils.h"
1516

1617
// 1e-10 cannot pass; unclear bug or not
1718
#undef EPSILON
1819
#define EPSILON (std::is_same<VALUETYPE, double>::value ? 1e-7 : 1e-4)
1920

21+
namespace {
22+
constexpr const char* kRefPath = "../../tests/infer/fparam_aparam.expected";
23+
constexpr const char* kModelPath = "../../tests/infer/fparam_aparam.pth";
24+
} // namespace
25+
2026
template <class VALUETYPE>
2127
class TestInferDeepPotAFParamAParamPt : public ::testing::Test {
2228
protected:
@@ -28,50 +34,9 @@ class TestInferDeepPotAFParamAParamPt : public ::testing::Test {
2834
std::vector<VALUETYPE> fparam = {0.25852028};
2935
std::vector<VALUETYPE> aparam = {0.25852028, 0.25852028, 0.25852028,
3036
0.25852028, 0.25852028, 0.25852028};
31-
// Generated by source/tests/infer/gen_fparam_aparam.py
32-
// (from pre-committed fparam_aparam_default.pth, type_one_side=True)
33-
std::vector<VALUETYPE> expected_e = {
34-
-1.038271223729636539e-01, -7.285433579124989123e-02,
35-
-9.467600492266425860e-02, -1.467050207422957442e-01,
36-
-7.660561676973243195e-02, -7.277296000253175023e-02};
37-
std::vector<VALUETYPE> expected_f = {
38-
6.622266941151369601e-02, 5.278739714221529489e-02,
39-
2.265728009692277028e-02, -2.606048291367509331e-02,
40-
-4.538812303131847109e-02, 1.058247419681241676e-02,
41-
1.679392617013223121e-01, -2.257826240741929533e-03,
42-
-4.490146347357203138e-02, -1.148364179422036724e-01,
43-
-1.169790528013799069e-02, 6.140403441496700837e-02,
44-
-8.078778123309421355e-02, -5.838879041789352825e-02,
45-
6.773641084621376263e-02, -1.247724902386305318e-02,
46-
6.494524782787665373e-02, -1.174787360813439457e-01};
47-
std::vector<VALUETYPE> expected_v = {
48-
-1.589185601903579381e-01, 2.586167090689088510e-03,
49-
-1.575150812458056548e-04, -1.855360549216640564e-02,
50-
1.949822308966445150e-02, -1.006552178977542650e-02,
51-
3.177030388421490936e-02, 1.714350280402104215e-03,
52-
-1.290389705296313833e-03, -8.553511587973079699e-02,
53-
-5.654638208496251539e-03, -1.286955066237439882e-02,
54-
2.464156699303176462e-02, -2.398203243424212178e-02,
55-
-1.957110698882909630e-02, 2.233493653505165544e-02,
56-
6.107843889444162372e-03, 1.707076397717688723e-03,
57-
-1.653994136896924094e-01, 3.894358809712639147e-02,
58-
-2.169596032233910010e-02, 6.819702786556020371e-03,
59-
-5.018240707559744503e-03, 2.640663592968431426e-03,
60-
-1.985295554050418160e-03, -3.638422207618969423e-02,
61-
2.342932709960221863e-02, -8.501331666888653493e-02,
62-
-2.181253119706856591e-03, 4.311299629418858387e-03,
63-
-1.910329576491436726e-03, -1.808810428459609043e-03,
64-
-1.540075460017477360e-03, -1.173703527688202929e-02,
65-
-2.596307050960845741e-03, 6.705026635782097323e-03,
66-
-9.038454847872562370e-02, 3.011717694088476838e-02,
67-
-5.083053967307901710e-02, -2.951212926932282599e-03,
68-
2.342446057919112673e-02, -4.091208178777860222e-02,
69-
-1.648470670751139844e-02, -2.872262362355524484e-02,
70-
4.763925761561256522e-02, -8.300037376164930147e-02,
71-
1.020429200603871836e-03, -1.026734257188876599e-03,
72-
5.678534821710372327e-02, 1.273635858276599142e-02,
73-
-1.530143401888291177e-02, -1.061672032476311256e-01,
74-
-2.486859787145567074e-02, 2.875323543588798395e-02};
37+
std::vector<VALUETYPE> expected_e;
38+
std::vector<VALUETYPE> expected_f;
39+
std::vector<VALUETYPE> expected_v;
7540
int natoms;
7641
double expected_tot_e;
7742
std::vector<VALUETYPE> expected_tot_v;
@@ -82,14 +47,19 @@ class TestInferDeepPotAFParamAParamPt : public ::testing::Test {
8247
#ifndef BUILD_PYTORCH
8348
GTEST_SKIP() << "Skip because PyTorch support is not enabled.";
8449
#endif
85-
dp.init("../../tests/infer/fparam_aparam.pth");
50+
deepmd_test::ExpectedRef ref;
51+
ref.load(kRefPath);
52+
expected_e = ref.get<VALUETYPE>("default", "expected_e");
53+
expected_f = ref.get<VALUETYPE>("default", "expected_f");
54+
expected_v = ref.get<VALUETYPE>("default", "expected_v");
55+
56+
dp.init(kModelPath);
8657

8758
natoms = expected_e.size();
8859
EXPECT_EQ(natoms * 3, expected_f.size());
8960
EXPECT_EQ(natoms * 9, expected_v.size());
9061
expected_tot_e = 0.;
91-
expected_tot_v.resize(9);
92-
std::fill(expected_tot_v.begin(), expected_tot_v.end(), 0.);
62+
expected_tot_v.assign(9, 0.);
9363
for (int ii = 0; ii < natoms; ++ii) {
9464
expected_tot_e += expected_e[ii];
9565
}

0 commit comments

Comments
 (0)