Skip to content

Commit 3217022

Browse files
committed
update model
1 parent 9b4e743 commit 3217022

File tree

10 files changed

+232
-310
lines changed

10 files changed

+232
-310
lines changed

deepmd/pt/model/model/transform_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def fit_output_to_model_output(
196196
if extended_coord_corr is not None:
197197
dc_corr = (
198198
dr.squeeze(-2).unsqueeze(-1)
199-
@ extended_coord_corr.unsqueeze(-2)
199+
@ extended_coord_corr.unsqueeze(-2).to(dr.dtype)
200200
).view(list(dc.shape[:-2]) + [1, 9]) # noqa: RUF005
201201
dc = dc + dc_corr
202202
model_ret[kk_derv_c] = dc

source/api_c/tests/test_deepspin_a.cc

Lines changed: 38 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -19,35 +19,25 @@ class TestInferDeepSpinA : public ::testing::Test {
1919
int atype[6] = {0, 1, 1, 0, 1, 1};
2020
double box[9] = {13., 0., 0., 0., 13., 0., 0., 0., 13.};
2121
float boxf[9] = {13., 0., 0., 0., 13., 0., 0., 0., 13.};
22-
std::vector<double> expected_e = {-5.835211567762678, -5.071189078159807,
23-
-5.044361601406714, -5.582324154346981,
24-
-5.059906899269188, -5.074135576182056};
22+
std::vector<double> expected_e = {
23+
-1.8626545229251095e+00, -2.3502165071948093e+00, -2.3500944968573521e+00,
24+
-2.0688274735854710e+00, -2.3485113271625320e+00, -2.3489022338537353e+00,
25+
};
2526
std::vector<double> expected_f = {
26-
-0.0619881702551019, 0.0646720543680939, 0.2137632336140025,
27-
0.037800173877136, -0.096327623008356, -0.1531911892384847,
28-
-0.112204927558682, 0.0299145670766557, -0.0589474826303666,
29-
0.2278904556868233, 0.0382061907026398, 0.0888060647788163,
30-
-0.0078898845686437, 0.0019385598635839, -0.0791616129664364,
31-
-0.083607647181527, -0.0384037490026167, -0.0112690135575317};
27+
3.7989110974834261e-02, -6.8203560994098300e-02, 3.1554995279414300e-02,
28+
-6.0769407958790114e-02, 5.6658432967656878e-03, 2.1814741358389407e-02,
29+
1.5027739412753049e-02, 6.2090755323245192e-02, -5.3346442187326704e-02,
30+
-5.2134406995188787e-02, 4.0990812807417676e-02, -1.6987454510304811e-02,
31+
-6.7153786204261134e-03, -5.3801784772022326e-02, 5.6707773168242034e-02,
32+
6.6602343186817375e-02, 1.3257934338691726e-02, -3.9743613108414025e-02,
33+
};
3234
std::vector<double> expected_fm = {
33-
-3.0778301386623275,
34-
-1.3135930534661662,
35-
-0.8332043979367366,
36-
0.0,
37-
0.0,
38-
0.0,
39-
0.0,
40-
0.0,
41-
0.0,
42-
-0.5452347545527696,
43-
-0.2051506559632127,
44-
-0.4908015055951312,
45-
0.0,
46-
0.0,
47-
0.0,
48-
0.0,
49-
0.0,
50-
0.0,
35+
4.8385521455777196e+00, 5.3158441514550137e-01, 1.0855626815019124e+00,
36+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
37+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
38+
1.2140862110260138e+00, 9.6823434985033552e-01, 1.0689000529371890e+00,
39+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
40+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
5141
};
5242
int natoms;
5343
double expected_tot_e;
@@ -173,12 +163,12 @@ TEST_F(TestInferDeepSpinA, float_infer) {
173163

174164
TEST_F(TestInferDeepSpinA, cutoff) {
175165
double cutoff = DP_DeepSpinGetCutoff(dp);
176-
EXPECT_EQ(cutoff, 6.0);
166+
EXPECT_EQ(cutoff, 4.0);
177167
}
178168

179169
TEST_F(TestInferDeepSpinA, numb_types) {
180170
int numb_types = DP_DeepSpinGetNumbTypes(dp);
181-
EXPECT_EQ(numb_types, 2);
171+
EXPECT_EQ(numb_types, 3);
182172
}
183173

184174
TEST_F(TestInferDeepSpinA, numb_types_spin) {
@@ -188,7 +178,7 @@ TEST_F(TestInferDeepSpinA, numb_types_spin) {
188178

189179
TEST_F(TestInferDeepSpinA, type_map) {
190180
const char* type_map = DP_DeepSpinGetTypeMap(dp);
191-
char expected_type_map[] = "Ni O";
181+
char expected_type_map[] = "Ni O H";
192182
EXPECT_EQ(strcmp(type_map, expected_type_map), 0);
193183
DP_DeleteChar(type_map);
194184
}
@@ -204,34 +194,25 @@ class TestInferDeepSpinANoPBC : public ::testing::Test {
204194
float spinf[18] = {0.13, 0.02, 0.03, 0., 0., 0., 0., 0., 0.,
205195
0.14, 0.10, 0.12, 0., 0., 0., 0., 0., 0.};
206196
int atype[6] = {0, 1, 1, 0, 1, 1};
207-
std::vector<double> expected_e = {-5.921669893870771, -5.1676693791758685,
208-
-5.205933794558385, -5.58688965168251,
209-
-5.080322972018686, -5.08213772482076};
197+
std::vector<double> expected_e = {
198+
-1.9136796509970209e+00, -2.3532121417832528e+00,
199+
-2.3589759416772553e+00, -2.0689533840218703e+00,
200+
-2.3485273598793084e+00, -2.3489022338537353e+00};
210201
std::vector<double> expected_f = {
211-
-0.2929142244191496, 0.0801070990501456, 0.148216178514704,
212-
0.2929142244191503, -0.0801070990501454, -0.1482161785147037,
213-
-0.2094984819251435, 0.0241594118950041, -0.0215199116994508,
214-
0.3068843038300324, -0.001620530344866, 0.1508093841389746,
215-
-0.0122719879278721, 0.0186341247897136, -0.1137104245023705,
216-
-0.0851138339770169, -0.0411730063398516, -0.0155790479371533};
217-
std::vector<double> expected_fm = {-1.5298530476860008,
218-
0.0071315024546899,
219-
0.0650492472558729,
220-
0.,
221-
0.,
222-
0.,
223-
0.,
224-
0.,
225-
0.,
226-
-0.6212052813442365,
227-
-0.2290265978320395,
228-
-0.5101405083352206,
229-
0.,
230-
0.,
231-
0.,
232-
0.,
233-
0.,
234-
0.};
202+
5.2440246818294511e-02, -8.2643189092284075e-03, -1.6057110078610215e-02,
203+
-5.2440246818295698e-02, 8.2643189092281334e-03, 1.6057110078610277e-02,
204+
-1.6724663644564395e-03, 7.9346065821642349e-05, -2.5251632397208987e-04,
205+
-5.6934098675373246e-02, 4.0398593044712161e-02, -1.6520316500527876e-02,
206+
-7.9878577602028808e-03, -5.3736758888210570e-02, 5.6516778947603999e-02,
207+
6.6594422800032166e-02, 1.3258819777676990e-02, -3.9743946123104140e-02,
208+
};
209+
std::vector<double> expected_fm = {
210+
4.5904360179010135e+00, 6.2821415259365443e-01, 9.2483695213043082e-01,
211+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
212+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
213+
1.2125967529512662e+00, 9.6807902483755459e-01, 1.0691011858092361e+00,
214+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
215+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00};
235216
int natoms;
236217
double expected_tot_e;
237218
// std::vector<double> expected_tot_v;

source/api_c/tests/test_deepspin_a_hpp.cc

Lines changed: 36 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -19,35 +19,25 @@ class TestInferDeepSpinAHPP : public ::testing::Test {
1919
0.14, 0.10, 0.12, 0., 0., 0., 0., 0., 0.};
2020
std::vector<int> atype = {0, 1, 1, 0, 1, 1};
2121
std::vector<VALUETYPE> box = {13., 0., 0., 0., 13., 0., 0., 0., 13.};
22-
std::vector<VALUETYPE> expected_e = {-5.835211567762678, -5.071189078159807,
23-
-5.044361601406714, -5.582324154346981,
24-
-5.059906899269188, -5.074135576182056};
22+
std::vector<VALUETYPE> expected_e = {
23+
-1.8626545229251095e+00, -2.3502165071948093e+00, -2.3500944968573521e+00,
24+
-2.0688274735854710e+00, -2.3485113271625320e+00, -2.3489022338537353e+00,
25+
};
2526
std::vector<VALUETYPE> expected_f = {
26-
-0.0619881702551019, 0.0646720543680939, 0.2137632336140025,
27-
0.037800173877136, -0.096327623008356, -0.1531911892384847,
28-
-0.112204927558682, 0.0299145670766557, -0.0589474826303666,
29-
0.2278904556868233, 0.0382061907026398, 0.0888060647788163,
30-
-0.0078898845686437, 0.0019385598635839, -0.0791616129664364,
31-
-0.083607647181527, -0.0384037490026167, -0.0112690135575317};
27+
3.7989110974834261e-02, -6.8203560994098300e-02, 3.1554995279414300e-02,
28+
-6.0769407958790114e-02, 5.6658432967656878e-03, 2.1814741358389407e-02,
29+
1.5027739412753049e-02, 6.2090755323245192e-02, -5.3346442187326704e-02,
30+
-5.2134406995188787e-02, 4.0990812807417676e-02, -1.6987454510304811e-02,
31+
-6.7153786204261134e-03, -5.3801784772022326e-02, 5.6707773168242034e-02,
32+
6.6602343186817375e-02, 1.3257934338691726e-02, -3.9743613108414025e-02,
33+
};
3234
std::vector<VALUETYPE> expected_fm = {
33-
-3.0778301386623275,
34-
-1.3135930534661662,
35-
-0.8332043979367366,
36-
0.0,
37-
0.0,
38-
0.0,
39-
0.0,
40-
0.0,
41-
0.0,
42-
-0.5452347545527696,
43-
-0.2051506559632127,
44-
-0.4908015055951312,
45-
0.0,
46-
0.0,
47-
0.0,
48-
0.0,
49-
0.0,
50-
0.0,
35+
4.8385521455777196e+00, 5.3158441514550137e-01, 1.0855626815019124e+00,
36+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
37+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
38+
1.2140862110260138e+00, 9.6823434985033552e-01, 1.0689000529371890e+00,
39+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
40+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
5141
};
5242
unsigned int natoms;
5343
double expected_tot_e;
@@ -176,34 +166,26 @@ class TestInferDeepSpinANoPbcHPP : public ::testing::Test {
176166
0.14, 0.10, 0.12, 0., 0., 0., 0., 0., 0.};
177167
std::vector<int> atype = {0, 1, 1, 0, 1, 1};
178168
std::vector<VALUETYPE> box = {};
179-
std::vector<VALUETYPE> expected_e = {-5.921669893870771, -5.1676693791758685,
180-
-5.205933794558385, -5.58688965168251,
181-
-5.080322972018686, -5.08213772482076};
169+
std::vector<VALUETYPE> expected_e = {
170+
-1.9136796509970209e+00, -2.3532121417832528e+00,
171+
-2.3589759416772553e+00, -2.0689533840218703e+00,
172+
-2.3485273598793084e+00, -2.3489022338537353e+00};
182173
std::vector<VALUETYPE> expected_f = {
183-
-0.2929142244191496, 0.0801070990501456, 0.148216178514704,
184-
0.2929142244191503, -0.0801070990501454, -0.1482161785147037,
185-
-0.2094984819251435, 0.0241594118950041, -0.0215199116994508,
186-
0.3068843038300324, -0.001620530344866, 0.1508093841389746,
187-
-0.0122719879278721, 0.0186341247897136, -0.1137104245023705,
188-
-0.0851138339770169, -0.0411730063398516, -0.0155790479371533};
189-
std::vector<VALUETYPE> expected_fm = {-1.5298530476860008,
190-
0.0071315024546899,
191-
0.0650492472558729,
192-
0.,
193-
0.,
194-
0.,
195-
0.,
196-
0.,
197-
0.,
198-
-0.6212052813442365,
199-
-0.2290265978320395,
200-
-0.5101405083352206,
201-
0.,
202-
0.,
203-
0.,
204-
0.,
205-
0.,
206-
0.};
174+
5.2440246818294511e-02, -8.2643189092284075e-03, -1.6057110078610215e-02,
175+
-5.2440246818295698e-02, 8.2643189092281334e-03, 1.6057110078610277e-02,
176+
-1.6724663644564395e-03, 7.9346065821642349e-05, -2.5251632397208987e-04,
177+
-5.6934098675373246e-02, 4.0398593044712161e-02, -1.6520316500527876e-02,
178+
-7.9878577602028808e-03, -5.3736758888210570e-02, 5.6516778947603999e-02,
179+
6.6594422800032166e-02, 1.3258819777676990e-02, -3.9743946123104140e-02,
180+
};
181+
std::vector<VALUETYPE> expected_fm = {
182+
4.5904360179010135e+00, 6.2821415259365443e-01, 9.2483695213043082e-01,
183+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
184+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
185+
1.2125967529512662e+00, 9.6807902483755459e-01, 1.0691011858092361e+00,
186+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
187+
0.0000000000000000e+00, 0.0000000000000000e+00, 0.0000000000000000e+00,
188+
};
207189
unsigned int natoms;
208190
double expected_tot_e;
209191
// std::vector<VALUETYPE> expected_tot_v;

source/api_cc/src/DeepSpinPT.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
251251
c10::IValue energy_ = outputs.at("energy");
252252
c10::IValue force_ = outputs.at("extended_force");
253253
c10::IValue force_mag_ = outputs.at("extended_force_mag");
254-
bool has_virial = outputs.contains("virial");
254+
bool has_virial = outputs.contains(c10::IValue("virial"));
255255
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
256256
torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);
257257
ener.assign(cpu_energy_.data_ptr<ENERGYTYPE>(),
@@ -297,7 +297,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
297297
atom_energy.resize(static_cast<size_t>(nframes) * fwd_map.size());
298298
select_map<VALUETYPE>(atom_energy, datom_energy, bkw_map, 1, nframes,
299299
fwd_map.size(), nall_real);
300-
if (outputs.contains("extended_virial")) {
300+
if (outputs.contains(c10::IValue("extended_virial"))) {
301301
c10::IValue atom_virial_ = outputs.at("extended_virial");
302302
torch::Tensor flat_atom_virial_ =
303303
atom_virial_.toTensor().view({-1}).to(floatType);
@@ -421,7 +421,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
421421
c10::IValue energy_ = outputs.at("energy");
422422
c10::IValue force_ = outputs.at("force");
423423
c10::IValue force_mag_ = outputs.at("force_mag");
424-
bool has_virial = outputs.contains("virial");
424+
bool has_virial = outputs.contains(c10::IValue("virial"));
425425
torch::Tensor flat_energy_ = energy_.toTensor().view({-1});
426426
torch::Tensor cpu_energy_ = flat_energy_.to(torch::kCPU);
427427
ener.assign(cpu_energy_.data_ptr<ENERGYTYPE>(),
@@ -453,7 +453,7 @@ void DeepSpinPT::compute(ENERGYVTYPE& ener,
453453
atom_energy.assign(
454454
cpu_atom_energy_.data_ptr<VALUETYPE>(),
455455
cpu_atom_energy_.data_ptr<VALUETYPE>() + cpu_atom_energy_.numel());
456-
if (outputs.contains("atom_virial")) {
456+
if (outputs.contains(c10::IValue("atom_virial"))) {
457457
c10::IValue atom_virial_ = outputs.at("atom_virial");
458458
torch::Tensor flat_atom_virial_ =
459459
atom_virial_.toTensor().view({-1}).to(floatType);

0 commit comments

Comments
 (0)