Commit 034e613

wanghan-iapcm (Han Wang) authored
feat(pt_expt): add dp finetune support (#5331)
## Summary

- Add `--finetune`, `--model-branch`, and `--use-pretrain-script` support to `dp --pt-expt train`, mirroring the pt backend's finetune flow (load pretrained checkpoint, change type map, selective weight copy, output bias adjustment)
- Support finetuning from both `.pt` checkpoints and frozen `.pte` models (embed `model_params` in the `.pte` during freeze so `--use-pretrain-script` works)
- Fix a bug in dpmodel's `base_atomic_model.change_type_map` where `out_bias`/`out_std` were not extended before remapping when the new type map introduces unseen types, causing `IndexError` with negative remap indices

## Usage examples

```bash
# Finetune from a .pt checkpoint
dp --pt-expt train input.json --finetune pretrained.pt

# Finetune from a frozen .pte model
dp --pt-expt train input.json --finetune pretrained.pte

# Copy descriptor/fitting config from the pretrained model
dp --pt-expt train input.json --finetune pretrained.pt --use-pretrain-script

# Finetune from a multi-task pretrained model (select a branch)
dp --pt-expt train input.json --finetune pretrained.pt --model-branch Default

# Re-initialize the fitting net randomly (only keep descriptor weights)
dp --pt-expt train input.json --finetune pretrained.pt --model-branch RANDOM
```

## Files changed

| File | Change |
|------|--------|
| `deepmd/pt_expt/utils/finetune.py` | **New** — `get_finetune_rules()` for pt_expt, supports `.pt` and `.pte` |
| `deepmd/pt_expt/entrypoints/main.py` | Wire `--finetune`/`--model-branch`/`--use-pretrain-script` through `train()` → `get_trainer()` → `Trainer`; pass `model_params` to `.pte` during freeze |
| `deepmd/pt_expt/train/training.py` | Finetune weight loading in `Trainer.__init__` (`.pt` and `.pte`); `model_change_out_bias()` |
| `deepmd/pt_expt/utils/serialization.py` | Embed/extract `model_params.json` in the `.pte` archive |
| `deepmd/dpmodel/atomic_model/base_atomic_model.py` | Fix `change_type_map` to extend `out_bias`/`out_std` for new types (array-api compatible) |
| `source/tests/pt_expt/test_finetune.py` | **New** — 9 tests covering bias adjustment, type map change, CLI dispatch, `.pte` finetune, `--use-pretrain-script`, `random_fitting`, inherited weight consistency |
| `source/tests/consistent/model/test_ener.py` | Add `test_change_type_map_new_type` verifying `out_bias`/`out_std` extension across dp, pt, pt_expt |

## Test plan

- [x] `python -m pytest source/tests/pt_expt/test_finetune.py -v` (9 passed)
- [x] `python -m pytest source/tests/pt_expt/test_training.py -v` (11 passed, no regression)
- [x] `python -m pytest source/tests/consistent/model/test_ener.py -k change_type_map -v` (3 passed)
- [x] `python -m pytest source/tests/consistent/descriptor/test_se_e2_a.py -v` (351 passed, no regression)

## Summary by CodeRabbit

* **New Features**
  * Fine-tuning workflow: supply pretrained checkpoints, select a branch, and toggle pretrain-script behavior
  * Automatic expansion of atom type maps (new types get zero bias and unit std) while preserving existing mappings
  * Improved finetune resume: selective merging of pretrained descriptor/fitting weights and bias-adjustment modes
  * Export/import embeds/restores model metadata to/from artifacts
* **Tests**
  * Unit and end-to-end tests for finetuning, bias adjustment, type-map expansion, and frozen-artifact scenarios

Co-authored-by: Han Wang <wang_han@iapcm.ac.cn>
Parent: 6122d97

File tree: 7 files changed, +1357 −20 lines

deepmd/dpmodel/atomic_model/base_atomic_model.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -206,6 +206,20 @@ def change_type_map(
         self.reinit_pair_exclude(
             map_pair_exclude_types(self.pair_exclude_types, remap_index)
         )
+        if has_new_type:
+            xp = array_api_compat.array_namespace(self.out_bias)
+            extend_shape = [
+                self.out_bias.shape[0],
+                len(type_map),
+                *list(self.out_bias.shape[2:]),
+            ]
+            device = array_api_compat.device(self.out_bias)
+            extend_bias = xp.zeros(
+                extend_shape, dtype=self.out_bias.dtype, device=device
+            )
+            self.out_bias = xp.concat([self.out_bias, extend_bias], axis=1)
+            extend_std = xp.ones(extend_shape, dtype=self.out_std.dtype, device=device)
+            self.out_std = xp.concat([self.out_std, extend_std], axis=1)
         self.out_bias = self.out_bias[:, remap_index, :]
         self.out_std = self.out_std[:, remap_index, :]
```
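The fix can be reproduced outside the package. Below is a minimal NumPy sketch (the helper name `extend_and_remap` and the concrete `remap_index` values are illustrative, not from the codebase): unseen types are assigned negative remap indices, so the bias/std arrays must first be padded along the type axis so those indices land in the padding (zero bias, unit std) rather than reading existing types' statistics or raising `IndexError`.

```python
import numpy as np

# Hypothetical standalone sketch of the change_type_map fix: pad
# out_bias/out_std along the type axis before remapping so that the
# negative indices assigned to unseen types select the padding.
def extend_and_remap(out_bias, out_std, remap_index, new_ntypes, has_new_type):
    if has_new_type:
        extend_shape = (out_bias.shape[0], new_ntypes, *out_bias.shape[2:])
        out_bias = np.concatenate(
            [out_bias, np.zeros(extend_shape, dtype=out_bias.dtype)], axis=1
        )
        out_std = np.concatenate(
            [out_std, np.ones(extend_shape, dtype=out_std.dtype)], axis=1
        )
    # remap to the new type order; new types index into the padded tail
    return out_bias[:, remap_index, :], out_std[:, remap_index, :]


# old type map ["O", "H"]; new type map ["H", "C", "O"], where "C" is unseen
out_bias = np.array([[[1.0], [2.0]]])   # shape (nvar=1, ntypes=2, ndim=1)
out_std = np.array([[[0.5], [0.7]]])
remap_index = np.array([1, -1, 0])      # H -> old 1, C -> padding, O -> old 0
new_bias, new_std = extend_and_remap(out_bias, out_std, remap_index, 3, True)
print(new_bias.reshape(-1))  # [2. 0. 1.]
print(new_std.reshape(-1))   # [0.7 1.  0.5]
```

Without the padding step, `out_bias[:, remap_index, :]` with `-1` would silently return the last *existing* type's bias, which is the bug being fixed.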

deepmd/pt_expt/entrypoints/main.py

Lines changed: 50 additions & 2 deletions
```diff
@@ -38,6 +38,8 @@ def get_trainer(
     config: dict[str, Any],
     init_model: str | None = None,
     restart_model: str | None = None,
+    finetune_model: str | None = None,
+    finetune_links: dict | None = None,
 ) -> training.Trainer:
     """Build a :class:`training.Trainer` from a normalised config."""
     model_params = config["model"]
@@ -94,6 +96,8 @@ def get_trainer(
         validation_data=validation_data,
         init_model=init_model,
         restart_model=restart_model,
+        finetune_model=finetune_model,
+        finetune_links=finetune_links,
     )
     return trainer

@@ -102,6 +106,9 @@ def train(
     input_file: str,
     init_model: str | None = None,
     restart: str | None = None,
+    finetune: str | None = None,
+    model_branch: str = "",
+    use_pretrain_script: bool = False,
     skip_neighbor_stat: bool = False,
     output: str = "out.json",
 ) -> None:
@@ -115,14 +122,25 @@ def train(
         Path to a checkpoint to initialise weights from.
     restart : str or None
         Path to a checkpoint to restart training from.
+    finetune : str or None
+        Path to a pretrained checkpoint to fine-tune from.
+    model_branch : str
+        Branch to select from a multi-task pretrained model.
+    use_pretrain_script : bool
+        If True, copy descriptor/fitting params from the pretrained model.
     skip_neighbor_stat : bool
         Skip neighbour statistics calculation.
     output : str
         Where to dump the normalised config.
     """
+    import torch
+
     from deepmd.common import (
         j_loader,
     )
+    from deepmd.pt_expt.utils.env import (
+        DEVICE,
+    )

     log.info("Configuration path: %s", input_file)
     config = j_loader(input_file)
@@ -133,6 +151,27 @@ def train(
     if restart is not None and not restart.endswith(".pt"):
         restart += ".pt"

+    # update fine-tuning config
+    finetune_links = None
+    if finetune is not None:
+        from deepmd.pt_expt.utils.finetune import (
+            get_finetune_rules,
+        )
+
+        config["model"], finetune_links = get_finetune_rules(
+            finetune,
+            config["model"],
+            model_branch=model_branch,
+            change_model_params=use_pretrain_script,
+        )
+
+    # update init_model config if --use-pretrain-script
+    if init_model is not None and use_pretrain_script:
+        init_state_dict = torch.load(init_model, map_location=DEVICE, weights_only=True)
+        if "model" in init_state_dict:
+            init_state_dict = init_state_dict["model"]
+        config["model"] = init_state_dict["_extra_state"]["model_params"]
+
     # argcheck
     config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
     config = normalize(config)
@@ -156,7 +195,13 @@ def train(
     with open(output, "w") as fp:
         json.dump(config, fp, indent=4)

-    trainer = get_trainer(config, init_model, restart)
+    trainer = get_trainer(
+        config,
+        init_model,
+        restart,
+        finetune_model=finetune,
+        finetune_links=finetune_links,
+    )
     trainer.run()

@@ -214,7 +259,7 @@ def freeze(
     m.eval()

     model_dict = m.serialize()
-    deserialize_to_file(output, {"model": model_dict})
+    deserialize_to_file(output, {"model": model_dict}, model_params=model_params)
     log.info("Saved frozen model to %s", output)

@@ -250,6 +295,9 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
         input_file=FLAGS.INPUT,
         init_model=FLAGS.init_model,
         restart=FLAGS.restart,
+        finetune=FLAGS.finetune,
+        model_branch=FLAGS.model_branch,
+        use_pretrain_script=FLAGS.use_pretrain_script,
         skip_neighbor_stat=FLAGS.skip_neighbor_stat,
         output=FLAGS.output,
     )
```
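The `model_params` passed through `freeze` above ends up embedded in the frozen archive (the `serialization.py` change) so that `--use-pretrain-script` can recover the model config without the original checkpoint. The round trip can be sketched with the stdlib; here a plain zip file stands in for the real `.pte` container, and the helper names `embed_model_params`/`extract_model_params` are illustrative, not the package's actual API.

```python
import json
import os
import tempfile
import zipfile

# Hypothetical sketch of embedding model metadata in a frozen-model
# archive, in the spirit of the serialization.py change. The real
# .pte layout may differ; a zip container is assumed here.
def embed_model_params(path: str, model_params: dict) -> None:
    # "a" mode appends the JSON entry without rewriting existing members
    with zipfile.ZipFile(path, "a") as zf:
        zf.writestr("model_params.json", json.dumps(model_params))


def extract_model_params(path: str) -> dict:
    with zipfile.ZipFile(path) as zf:
        return json.loads(zf.read("model_params.json"))


# round trip: freeze writes the archive, finetune reads the config back
tmp = os.path.join(tempfile.mkdtemp(), "model.pte")
with zipfile.ZipFile(tmp, "w") as zf:
    zf.writestr("weights.bin", b"\x00" * 8)  # stand-in model payload
embed_model_params(tmp, {"type_map": ["O", "H"]})
print(extract_model_params(tmp))  # {'type_map': ['O', 'H']}
```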

deepmd/pt_expt/train/training.py

Lines changed: 142 additions & 16 deletions
```diff
@@ -23,6 +23,9 @@
 import numpy as np
 import torch

+from deepmd.dpmodel.common import (
+    to_numpy_array,
+)
 from deepmd.dpmodel.utils.batch import (
     normalize_batch,
     split_batch,
@@ -380,8 +383,16 @@ def __init__(
         validation_data: DeepmdDataSystem | None = None,
         init_model: str | None = None,
         restart_model: str | None = None,
+        finetune_model: str | None = None,
+        finetune_links: dict | None = None,
     ) -> None:
-        resume_model = init_model or restart_model
+        if finetune_model is not None and (
+            init_model is not None or restart_model is not None
+        ):
+            raise ValueError(
+                "finetune_model cannot be combined with init_model or restart_model."
+            )
+        resume_model = init_model or restart_model or finetune_model
         resuming = resume_model is not None
         self.restart_training = restart_model is not None

@@ -429,7 +440,12 @@ def __init__(
         def get_sample() -> list[dict[str, np.ndarray]]:
             return make_stat_input(training_data, data_stat_nbatch)

-        if not resuming:
+        finetune_has_new_type = (
+            finetune_model is not None
+            and finetune_links is not None
+            and finetune_links["Default"].get_has_new_type()
+        )
+        if not resuming or finetune_has_new_type:
             self.model.compute_or_load_stat(
                 sampled_func=get_sample,
                 stat_file_path=stat_file_path,
@@ -472,23 +488,98 @@ def get_sample() -> list[dict[str, np.ndarray]]:
         # Resume --------------------------------------------------------------
         if resuming:
             log.info(f"Resuming from {resume_model}.")
-            state_dict = torch.load(
-                resume_model, map_location=DEVICE, weights_only=True
-            )
-            if "model" in state_dict:
-                optimizer_state_dict = (
-                    state_dict["optimizer"] if self.restart_training else None
+            is_pte = resume_model.endswith((".pte", ".pt2"))
+
+            if is_pte:
+                # .pte frozen model: no optimizer state, no step counter
+                optimizer_state_dict = None
+                self.start_step = 0
+            else:
+                state_dict = torch.load(
+                    resume_model, map_location=DEVICE, weights_only=True
+                )
+                if "model" in state_dict:
+                    optimizer_state_dict = (
+                        state_dict["optimizer"]
+                        if self.restart_training and finetune_model is None
+                        else None
+                    )
+                    state_dict = state_dict["model"]
+                else:
+                    optimizer_state_dict = None
+                self.start_step = (
+                    state_dict["_extra_state"]["train_infos"]["step"]
+                    if self.restart_training
+                    else 0
+                )
+
+            if finetune_model is not None and finetune_links is not None:
+                # --- Finetune: selective weight loading -----------------------
+                finetune_rule = finetune_links["Default"]
+
+                # Build pretrained model and load weights
+                if is_pte:
+                    from deepmd.pt_expt.model import (
+                        BaseModel,
+                    )
+                    from deepmd.pt_expt.utils.serialization import (
+                        serialize_from_file,
+                    )
+
+                    data = serialize_from_file(finetune_model)
+                    pretrained_model = BaseModel.deserialize(data["model"]).to(DEVICE)
+                else:
+                    pretrained_model = get_model(
+                        deepcopy(state_dict["_extra_state"]["model_params"])
+                    ).to(DEVICE)
+                pretrained_wrapper = ModelWrapper(pretrained_model)
+                if not is_pte:
+                    pretrained_wrapper.load_state_dict(state_dict)
+
+                # Change type map if needed
+                if (
+                    finetune_rule.get_finetune_tmap()
+                    != pretrained_wrapper.model.get_type_map()
+                ):
+                    model_with_new_type_stat = (
+                        self.wrapper.model if finetune_rule.get_has_new_type() else None
+                    )
+                    pretrained_wrapper.model.change_type_map(
+                        finetune_rule.get_finetune_tmap(),
+                        model_with_new_type_stat=model_with_new_type_stat,
+                    )
+
+                # Selectively copy weights: descriptor always from pretrained,
+                # fitting from pretrained unless random_fitting is True
+                pretrained_state = pretrained_wrapper.state_dict()
+                target_state = self.wrapper.state_dict()
+                new_state = {}
+                for key in target_state:
+                    if key == "_extra_state":
+                        new_state[key] = target_state[key]
+                    elif (
+                        finetune_rule.get_random_fitting() and ".descriptor." not in key
+                    ):
+                        new_state[key] = target_state[key]  # keep random init
+                    elif key in pretrained_state:
+                        new_state[key] = pretrained_state[key]  # from pretrained
+                    else:
+                        new_state[key] = target_state[key]  # fallback
+                self.wrapper.load_state_dict(new_state)
+
+                # Adjust output bias
+                bias_mode = (
+                    "change-by-statistic"
+                    if not finetune_rule.get_random_fitting()
+                    else "set-by-statistic"
+                )
+                self.model = model_change_out_bias(
+                    self.model, get_sample, _bias_adjust_mode=bias_mode
                 )
-                state_dict = state_dict["model"]
             else:
-                optimizer_state_dict = None
+                # --- Normal resume (init_model / restart) --------------------
+                self.wrapper.load_state_dict(state_dict)

-            self.start_step = (
-                state_dict["_extra_state"]["train_infos"]["step"]
-                if self.restart_training
-                else 0
-            )
-            self.wrapper.load_state_dict(state_dict)
             if optimizer_state_dict is not None:
                 self.optimizer.load_state_dict(optimizer_state_dict)
             # rebuild scheduler from the resumed step.
@@ -910,3 +1001,38 @@ def print_on_training(
         line += f" {cur_lr:8.1e}\n"
         fout.write(line)
         fout.flush()
+
+
+def model_change_out_bias(
+    _model: Any,
+    _sample_func: Any,
+    _bias_adjust_mode: str = "change-by-statistic",
+) -> Any:
+    """Change the output bias of a model based on sampled data.
+
+    Parameters
+    ----------
+    _model
+        The model whose bias should be adjusted.
+    _sample_func
+        Callable that returns sampled data for bias computation.
+    _bias_adjust_mode
+        ``"change-by-statistic"`` or ``"set-by-statistic"``.
+
+    Returns
+    -------
+    The model with updated bias.
+    """
+    old_bias = deepcopy(_model.get_out_bias())
+    _model.change_out_bias(
+        _sample_func,
+        bias_adjust_mode=_bias_adjust_mode,
+    )
+    new_bias = deepcopy(_model.get_out_bias())
+    model_type_map = _model.get_type_map()
+    log.info(
+        f"Change output bias of {model_type_map!s} "
+        f"from {to_numpy_array(old_bias).reshape(-1)[: len(model_type_map)]!s} "
+        f"to {to_numpy_array(new_bias).reshape(-1)[: len(model_type_map)]!s}."
+    )
+    return _model
```
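The selective weight-copy loop in `Trainer.__init__` reduces to a small pure function over state-dict keys. Below is a sketch using plain dicts in place of torch state dicts (the helper name `merge_finetune_state` is illustrative): descriptor weights always come from the pretrained model; fitting weights do too, unless `random_fitting` is set, in which case the target's fresh random initialisation is kept.

```python
# Hypothetical sketch of the finetune weight-merge rule, with plain
# dicts standing in for torch state dicts.
def merge_finetune_state(target_state, pretrained_state, random_fitting):
    new_state = {}
    for key, value in target_state.items():
        if key == "_extra_state":
            new_state[key] = value                  # keep target metadata
        elif random_fitting and ".descriptor." not in key:
            new_state[key] = value                  # keep random init
        elif key in pretrained_state:
            new_state[key] = pretrained_state[key]  # inherit from pretrained
        else:
            new_state[key] = value                  # no pretrained counterpart
    return new_state


target = {
    "model.descriptor.w": "rand_d",
    "model.fitting.w": "rand_f",
    "_extra_state": {"model_params": {}},
}
pretrained = {"model.descriptor.w": "pre_d", "model.fitting.w": "pre_f"}

# --model-branch RANDOM: descriptor inherited, fitting re-initialised
merged = merge_finetune_state(target, pretrained, random_fitting=True)
print(merged["model.descriptor.w"], merged["model.fitting.w"])  # pre_d rand_f
```

With `random_fitting=False` both entries would be inherited from `pretrained`, matching the default finetune behaviour.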
