Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,20 @@ def change_type_map(
self.reinit_pair_exclude(
map_pair_exclude_types(self.pair_exclude_types, remap_index)
)
if has_new_type:
xp = array_api_compat.array_namespace(self.out_bias)
extend_shape = [
self.out_bias.shape[0],
len(type_map),
*list(self.out_bias.shape[2:]),
]
device = array_api_compat.device(self.out_bias)
extend_bias = xp.zeros(
extend_shape, dtype=self.out_bias.dtype, device=device
)
self.out_bias = xp.concat([self.out_bias, extend_bias], axis=1)
extend_std = xp.ones(extend_shape, dtype=self.out_std.dtype, device=device)
self.out_std = xp.concat([self.out_std, extend_std], axis=1)
self.out_bias = self.out_bias[:, remap_index, :]
self.out_std = self.out_std[:, remap_index, :]

Expand Down
52 changes: 50 additions & 2 deletions deepmd/pt_expt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def get_trainer(
config: dict[str, Any],
init_model: str | None = None,
restart_model: str | None = None,
finetune_model: str | None = None,
finetune_links: dict | None = None,
) -> training.Trainer:
"""Build a :class:`training.Trainer` from a normalised config."""
model_params = config["model"]
Expand Down Expand Up @@ -94,6 +96,8 @@ def get_trainer(
validation_data=validation_data,
init_model=init_model,
restart_model=restart_model,
finetune_model=finetune_model,
finetune_links=finetune_links,
)
return trainer

Expand All @@ -102,6 +106,9 @@ def train(
input_file: str,
init_model: str | None = None,
restart: str | None = None,
finetune: str | None = None,
model_branch: str = "",
use_pretrain_script: bool = False,
skip_neighbor_stat: bool = False,
output: str = "out.json",
) -> None:
Expand All @@ -115,14 +122,25 @@ def train(
Path to a checkpoint to initialise weights from.
restart : str or None
Path to a checkpoint to restart training from.
finetune : str or None
Path to a pretrained checkpoint to fine-tune from.
model_branch : str
Branch to select from a multi-task pretrained model.
use_pretrain_script : bool
If True, copy descriptor/fitting params from the pretrained model.
skip_neighbor_stat : bool
Skip neighbour statistics calculation.
output : str
Where to dump the normalised config.
"""
import torch

from deepmd.common import (
j_loader,
)
from deepmd.pt_expt.utils.env import (
DEVICE,
)

log.info("Configuration path: %s", input_file)
config = j_loader(input_file)
Expand All @@ -133,6 +151,27 @@ def train(
if restart is not None and not restart.endswith(".pt"):
restart += ".pt"

# update fine-tuning config
finetune_links = None
if finetune is not None:
from deepmd.pt_expt.utils.finetune import (
get_finetune_rules,
)

config["model"], finetune_links = get_finetune_rules(
finetune,
config["model"],
model_branch=model_branch,
change_model_params=use_pretrain_script,
)

# update init_model config if --use-pretrain-script
if init_model is not None and use_pretrain_script:
init_state_dict = torch.load(init_model, map_location=DEVICE, weights_only=True)
if "model" in init_state_dict:
init_state_dict = init_state_dict["model"]
config["model"] = init_state_dict["_extra_state"]["model_params"]

# argcheck
config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json")
config = normalize(config)
Expand All @@ -156,7 +195,13 @@ def train(
with open(output, "w") as fp:
json.dump(config, fp, indent=4)

trainer = get_trainer(config, init_model, restart)
trainer = get_trainer(
config,
init_model,
restart,
finetune_model=finetune,
finetune_links=finetune_links,
)
trainer.run()


Expand Down Expand Up @@ -214,7 +259,7 @@ def freeze(
m.eval()

model_dict = m.serialize()
deserialize_to_file(output, {"model": model_dict})
deserialize_to_file(output, {"model": model_dict}, model_params=model_params)
log.info("Saved frozen model to %s", output)


Expand Down Expand Up @@ -250,6 +295,9 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None:
input_file=FLAGS.INPUT,
init_model=FLAGS.init_model,
restart=FLAGS.restart,
finetune=FLAGS.finetune,
model_branch=FLAGS.model_branch,
use_pretrain_script=FLAGS.use_pretrain_script,
skip_neighbor_stat=FLAGS.skip_neighbor_stat,
output=FLAGS.output,
)
Expand Down
158 changes: 142 additions & 16 deletions deepmd/pt_expt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
import numpy as np
import torch

from deepmd.dpmodel.common import (
to_numpy_array,
)
from deepmd.dpmodel.utils.batch import (
normalize_batch,
split_batch,
Expand Down Expand Up @@ -380,8 +383,16 @@ def __init__(
validation_data: DeepmdDataSystem | None = None,
init_model: str | None = None,
restart_model: str | None = None,
finetune_model: str | None = None,
finetune_links: dict | None = None,
) -> None:
resume_model = init_model or restart_model
if finetune_model is not None and (
init_model is not None or restart_model is not None
):
raise ValueError(
"finetune_model cannot be combined with init_model or restart_model."
)
resume_model = init_model or restart_model or finetune_model
resuming = resume_model is not None
Comment thread
coderabbitai[bot] marked this conversation as resolved.
self.restart_training = restart_model is not None

Expand Down Expand Up @@ -429,7 +440,12 @@ def __init__(
def get_sample() -> list[dict[str, np.ndarray]]:
return make_stat_input(training_data, data_stat_nbatch)

if not resuming:
finetune_has_new_type = (
finetune_model is not None
and finetune_links is not None
and finetune_links["Default"].get_has_new_type()
)
if not resuming or finetune_has_new_type:
self.model.compute_or_load_stat(
sampled_func=get_sample,
stat_file_path=stat_file_path,
Expand Down Expand Up @@ -472,23 +488,98 @@ def get_sample() -> list[dict[str, np.ndarray]]:
# Resume --------------------------------------------------------------
if resuming:
log.info(f"Resuming from {resume_model}.")
state_dict = torch.load(
resume_model, map_location=DEVICE, weights_only=True
)
if "model" in state_dict:
optimizer_state_dict = (
state_dict["optimizer"] if self.restart_training else None
is_pte = resume_model.endswith((".pte", ".pt2"))

if is_pte:
# .pte frozen model: no optimizer state, no step counter
optimizer_state_dict = None
self.start_step = 0
else:
state_dict = torch.load(
resume_model, map_location=DEVICE, weights_only=True
)
if "model" in state_dict:
optimizer_state_dict = (
state_dict["optimizer"]
if self.restart_training and finetune_model is None
else None
)
state_dict = state_dict["model"]
else:
optimizer_state_dict = None
self.start_step = (
state_dict["_extra_state"]["train_infos"]["step"]
if self.restart_training
else 0
)

if finetune_model is not None and finetune_links is not None:
# --- Finetune: selective weight loading -----------------------
finetune_rule = finetune_links["Default"]

# Build pretrained model and load weights
if is_pte:
from deepmd.pt_expt.model import (
BaseModel,
)
from deepmd.pt_expt.utils.serialization import (
serialize_from_file,
)

data = serialize_from_file(finetune_model)
pretrained_model = BaseModel.deserialize(data["model"]).to(DEVICE)
else:
pretrained_model = get_model(
deepcopy(state_dict["_extra_state"]["model_params"])
).to(DEVICE)
pretrained_wrapper = ModelWrapper(pretrained_model)
if not is_pte:
pretrained_wrapper.load_state_dict(state_dict)

# Change type map if needed
if (
finetune_rule.get_finetune_tmap()
!= pretrained_wrapper.model.get_type_map()
):
model_with_new_type_stat = (
self.wrapper.model if finetune_rule.get_has_new_type() else None
)
pretrained_wrapper.model.change_type_map(
finetune_rule.get_finetune_tmap(),
model_with_new_type_stat=model_with_new_type_stat,
)

# Selectively copy weights: descriptor always from pretrained,
# fitting from pretrained unless random_fitting is True
pretrained_state = pretrained_wrapper.state_dict()
target_state = self.wrapper.state_dict()
new_state = {}
for key in target_state:
if key == "_extra_state":
new_state[key] = target_state[key]
elif (
finetune_rule.get_random_fitting() and ".descriptor." not in key
):
new_state[key] = target_state[key] # keep random init
elif key in pretrained_state:
new_state[key] = pretrained_state[key] # from pretrained
else:
new_state[key] = target_state[key] # fallback
self.wrapper.load_state_dict(new_state)

# Adjust output bias
bias_mode = (
"change-by-statistic"
if not finetune_rule.get_random_fitting()
else "set-by-statistic"
)
self.model = model_change_out_bias(
self.model, get_sample, _bias_adjust_mode=bias_mode
)
state_dict = state_dict["model"]
else:
optimizer_state_dict = None
# --- Normal resume (init_model / restart) --------------------
self.wrapper.load_state_dict(state_dict)

self.start_step = (
state_dict["_extra_state"]["train_infos"]["step"]
if self.restart_training
else 0
)
self.wrapper.load_state_dict(state_dict)
if optimizer_state_dict is not None:
self.optimizer.load_state_dict(optimizer_state_dict)
# rebuild scheduler from the resumed step.
Expand Down Expand Up @@ -910,3 +1001,38 @@ def print_on_training(
line += f" {cur_lr:8.1e}\n"
fout.write(line)
fout.flush()


def model_change_out_bias(
    _model: Any,
    _sample_func: Any,
    _bias_adjust_mode: str = "change-by-statistic",
) -> Any:
    """Adjust a model's output bias using statistics from sampled data.

    Parameters
    ----------
    _model
        The model whose output bias is updated in place.
    _sample_func
        Callable returning the sampled data used to compute the statistics.
    _bias_adjust_mode
        Either ``"change-by-statistic"`` or ``"set-by-statistic"``.

    Returns
    -------
    The same model object, with its output bias updated.
    """
    # Snapshot the bias before and after the update so the change is logged.
    previous_bias = deepcopy(_model.get_out_bias())
    _model.change_out_bias(_sample_func, bias_adjust_mode=_bias_adjust_mode)
    updated_bias = deepcopy(_model.get_out_bias())

    type_map = _model.get_type_map()
    n_types = len(type_map)
    # Only the first len(type_map) entries of the flattened bias are reported;
    # trailing entries (if any) are padding beyond the mapped types.
    before_vals = to_numpy_array(previous_bias).reshape(-1)[:n_types]
    after_vals = to_numpy_array(updated_bias).reshape(-1)[:n_types]
    log.info(
        f"Change output bias of {type_map!s} "
        f"from {before_vals!s} "
        f"to {after_vals!s}."
    )
    return _model
Loading
Loading