Skip to content

Commit 1be559c

Browse files
committed
add unit test for pb2pth model
1 parent 335d53e commit 1be559c

4 files changed

Lines changed: 17154 additions & 7 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -616,13 +616,14 @@ def single_model_finetune(
616616
frz_model = torch.jit.load(init_frz_model, map_location=DEVICE)
617617
try:
618618
self.model.load_state_dict(frz_model.state_dict())
619-
except RuntimeError as e:
620-
if "Missing key(s) in state_dict" in str(e):
619+
except RuntimeError as err_msg:
620+
if "Missing key(s) in state_dict" in str(
621+
err_msg
622+
) or "Unexpected key(s) in state_dict" in str(err_msg):
621623
self.model.load_state_dict(frz_model.state_dict(), strict=False)
622-
log.warning("Use strict=False to ignore non-matching keys.")
623-
log.warning(f"Model state_dict mismatch detected: {e}")
624+
log.warning("Loaded with strict=False to ignore non-matching keys.")
624625
else:
625-
raise e
626+
raise
626627

627628
# Get model prob for multi-task
628629
if self.multi_task:
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
{
2+
"model": {
3+
"type_map": [
4+
"O",
5+
"H"
6+
],
7+
"descriptor": {
8+
"type": "se_e2_a",
9+
"sel": [
10+
23,
11+
46
12+
],
13+
"rcut_smth": 0.50,
14+
"rcut": 4.00,
15+
"neuron": [
16+
2,
17+
4
18+
],
19+
"resnet_dt": false,
20+
"axis_neuron": 4,
21+
"type_one_side": true,
22+
"seed": 1,
23+
"_comment": " that's all"
24+
},
25+
"fitting_net": {
26+
"neuron": [
27+
120
28+
],
29+
"resnet_dt": true,
30+
"seed": 1,
31+
"_comment": " that's all"
32+
},
33+
"data_stat_nbatch": 20,
34+
"_comment": " that's all"
35+
},
36+
"learning_rate": {
37+
"type": "exp",
38+
"decay_steps": 5000,
39+
"start_lr": 0.001,
40+
"stop_lr": 3.51e-8,
41+
"_comment": "that's all"
42+
},
43+
"loss": {
44+
"type": "ener",
45+
"start_pref_e": 0.02,
46+
"limit_pref_e": 1,
47+
"start_pref_f": 1000,
48+
"limit_pref_f": 1,
49+
"_comment": " that's all"
50+
},
51+
"training": {
52+
"training_data": {
53+
"systems": [
54+
"../data/data_0",
55+
"../data/data_1",
56+
"../data/data_2"
57+
],
58+
"batch_size": 1,
59+
"_comment": "that's all"
60+
},
61+
"validation_data": {
62+
"systems": [
63+
"../data/data_3"
64+
],
65+
"batch_size": 1,
66+
"numb_btch": 3,
67+
"_comment": "that's all"
68+
},
69+
"numb_steps": 100000,
70+
"seed": 10,
71+
"disp_file": "lcurve.out",
72+
"disp_freq": 100,
73+
"save_freq": 10000,
74+
"_comment": "that's all"
75+
},
76+
"_comment": "that's all"
77+
}

0 commit comments

Comments
 (0)