Skip to content

Commit a735ca7

Browse files
committed
Add ZBL model fine-tuning from standard models
Enhanced the Trainer to support fine-tuning ZBL models from standard models by handling key mapping and random-state initialization. Added corresponding tests in test_training.py to verify ZBL fine-tuning behavior and to ensure correct state-dict transfer.
1 parent a9b41c3 commit a735ca7

2 files changed

Lines changed: 61 additions & 5 deletions

File tree

deepmd/pt/train/training.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -510,17 +510,31 @@ def collect_single_finetune_params(
510510
if i != "_extra_state" and f".{_model_key}." in i
511511
]
512512
for item_key in target_keys:
513+
new_key = item_key.replace(
514+
f".{_model_key}.", f".{_model_key_from}."
515+
)
516+
use_random_state = _new_fitting and (
517+
".descriptor." not in item_key
518+
)
513519
if (
514-
_new_fitting and (".descriptor." not in item_key)
515-
) or ".models.1." in item_key:
520+
not use_random_state
521+
and new_key not in _origin_state_dict
522+
):
523+
# for ZBL models finetuning from standard models
524+
if ".models.0." in new_key:
525+
new_key = new_key.replace(".models.0.", ".")
526+
elif ".models.1." in new_key:
527+
use_random_state = True
528+
else:
529+
raise KeyError(
530+
f"Key {new_key} not found in pretrained model."
531+
)
532+
if use_random_state:
516533
# print(f'Keep {item_key} in old model!')
517534
_new_state_dict[item_key] = (
518535
_random_state_dict[item_key].clone().detach()
519536
)
520537
else:
521-
new_key = item_key.replace(
522-
f".{_model_key}.", f".{_model_key_from}."
523-
).replace(".models.0.", ".") # for ZBL models
524538
# print(f'Replace {item_key} with {new_key} in pretrained_model!')
525539
_new_state_dict[item_key] = (
526540
_origin_state_dict[new_key].clone().detach()

source/tests/pt/test_training.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@
3030

3131

3232
class DPTrainTest:
33+
test_zbl_from_standard: bool = False
34+
3335
def test_dp_train(self) -> None:
3436
# test training from scratch
3537
trainer = get_trainer(deepcopy(self.config))
@@ -95,6 +97,34 @@ def test_dp_train(self) -> None:
9597
state_dict_finetuned_random[state_key],
9698
)
9799

100+
if self.test_zbl_from_standard:
101+
# test fine-tuning using zbl from standard model
102+
finetune_model = (
103+
self.config["training"].get("save_ckpt", "model.ckpt") + ".pt"
104+
)
105+
self.config_zbl["model"], finetune_links = get_finetune_rules(
106+
finetune_model,
107+
self.config_zbl["model"],
108+
)
109+
trainer_finetune_zbl = get_trainer(
110+
deepcopy(self.config_zbl),
111+
finetune_model=finetune_model,
112+
finetune_links=finetune_links,
113+
)
114+
state_dict_finetuned_zbl = trainer_finetune_zbl.wrapper.model.state_dict()
115+
for state_key in state_dict_finetuned_zbl:
116+
if "out_bias" not in state_key and "out_std" not in state_key:
117+
original_key = state_key
118+
if ".models.0." in state_key:
119+
original_key = state_key.replace(".models.0.", ".")
120+
if ".models.1." not in state_key:
121+
torch.testing.assert_close(
122+
state_dict_trained[original_key],
123+
state_dict_finetuned_zbl[state_key],
124+
)
125+
# check running
126+
trainer_finetune_zbl.run()
127+
98128
# check running
99129
trainer_finetune.run()
100130
trainer_finetune_empty.run()
@@ -222,6 +252,18 @@ def setUp(self) -> None:
222252
self.config["training"]["numb_steps"] = 1
223253
self.config["training"]["save_freq"] = 1
224254

255+
self.test_zbl_from_standard = True
256+
257+
input_json_zbl = str(Path(__file__).parent / "water/zbl.json")
258+
with open(input_json_zbl) as f:
259+
self.config_zbl = json.load(f)
260+
data_file = [str(Path(__file__).parent / "water/data/data_0")]
261+
self.config_zbl["training"]["training_data"]["systems"] = data_file
262+
self.config_zbl["training"]["validation_data"]["systems"] = data_file
263+
self.config_zbl["model"] = deepcopy(model_zbl)
264+
self.config_zbl["training"]["numb_steps"] = 1
265+
self.config_zbl["training"]["save_freq"] = 1
266+
225267
def tearDown(self) -> None:
226268
DPTrainTest.tearDown(self)
227269

0 commit comments

Comments (0)