Skip to content

Commit 26013cb

Browse files
authored
feat(pt): allow --init-frz-model for pt model converted from tf model (#5091)
- Add try-catch block for loading frozen models - Use strict=False when state_dict keys don't match - Log warnings for model state_dict mismatches - Prevent crashes when loading models with different architectures. Fix bug in #5090 <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Frozen model loading now uses non-strict mode; missing or unexpected parameters are logged as warnings instead of causing failure. * **Tests** * Added an end-to-end test covering conversion and loading of frozen models from external formats. * Test I/O standardized to UTF-8; teardown extended to remove generated artifacts. * **New Files** * Added a JSON test model configuration for validation and training workflows. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent 89a3180 commit 26013cb

4 files changed

Lines changed: 17166 additions & 3 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -613,7 +613,13 @@ def single_model_finetune(
613613

614614
if init_frz_model is not None:
615615
frz_model = torch.jit.load(init_frz_model, map_location=DEVICE)
616-
self.model.load_state_dict(frz_model.state_dict())
616+
state = frz_model.state_dict()
617+
missing, unexpected = self.model.load_state_dict(state, strict=False)
618+
if missing or unexpected:
619+
log.warning(
620+
"Checkpoint loaded non-strictly. "
621+
f"Missing keys: {missing}, Unexpected keys: {unexpected}"
622+
)
617623

618624
# Get model prob for multi-task
619625
if self.multi_task:
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
{
2+
"model": {
3+
"type_map": [
4+
"O",
5+
"H"
6+
],
7+
"descriptor": {
8+
"type": "se_e2_a",
9+
"sel": [
10+
23,
11+
46
12+
],
13+
"rcut_smth": 0.50,
14+
"rcut": 4.00,
15+
"neuron": [
16+
2,
17+
4
18+
],
19+
"resnet_dt": false,
20+
"axis_neuron": 4,
21+
"type_one_side": true,
22+
"seed": 1,
23+
"_comment": " that's all"
24+
},
25+
"fitting_net": {
26+
"neuron": [
27+
120
28+
],
29+
"resnet_dt": true,
30+
"seed": 1,
31+
"_comment": " that's all"
32+
},
33+
"data_stat_nbatch": 20,
34+
"_comment": " that's all"
35+
},
36+
"learning_rate": {
37+
"type": "exp",
38+
"decay_steps": 5000,
39+
"start_lr": 0.001,
40+
"stop_lr": 3.51e-8,
41+
"_comment": "that's all"
42+
},
43+
"loss": {
44+
"type": "ener",
45+
"start_pref_e": 0.02,
46+
"limit_pref_e": 1,
47+
"start_pref_f": 1000,
48+
"limit_pref_f": 1,
49+
"_comment": " that's all"
50+
},
51+
"training": {
52+
"training_data": {
53+
"systems": [
54+
"../data/data_0",
55+
"../data/data_1",
56+
"../data/data_2"
57+
],
58+
"batch_size": 1,
59+
"_comment": "that's all"
60+
},
61+
"validation_data": {
62+
"systems": [
63+
"../data/data_3"
64+
],
65+
"batch_size": 1,
66+
"numb_btch": 3,
67+
"_comment": "that's all"
68+
},
69+
"numb_steps": 100000,
70+
"seed": 10,
71+
"disp_file": "lcurve.out",
72+
"disp_freq": 100,
73+
"save_freq": 10000,
74+
"_comment": "that's all"
75+
},
76+
"_comment": "that's all"
77+
}

0 commit comments

Comments
 (0)