diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index 1d3ff5aa4a..1120078bb2 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -206,6 +206,20 @@ def change_type_map( self.reinit_pair_exclude( map_pair_exclude_types(self.pair_exclude_types, remap_index) ) + if has_new_type: + xp = array_api_compat.array_namespace(self.out_bias) + extend_shape = [ + self.out_bias.shape[0], + len(type_map), + *list(self.out_bias.shape[2:]), + ] + device = array_api_compat.device(self.out_bias) + extend_bias = xp.zeros( + extend_shape, dtype=self.out_bias.dtype, device=device + ) + self.out_bias = xp.concat([self.out_bias, extend_bias], axis=1) + extend_std = xp.ones(extend_shape, dtype=self.out_std.dtype, device=device) + self.out_std = xp.concat([self.out_std, extend_std], axis=1) self.out_bias = self.out_bias[:, remap_index, :] self.out_std = self.out_std[:, remap_index, :] diff --git a/deepmd/pt_expt/entrypoints/main.py b/deepmd/pt_expt/entrypoints/main.py index 18e9e7dc8f..0bf70dd93b 100644 --- a/deepmd/pt_expt/entrypoints/main.py +++ b/deepmd/pt_expt/entrypoints/main.py @@ -38,6 +38,8 @@ def get_trainer( config: dict[str, Any], init_model: str | None = None, restart_model: str | None = None, + finetune_model: str | None = None, + finetune_links: dict | None = None, ) -> training.Trainer: """Build a :class:`training.Trainer` from a normalised config.""" model_params = config["model"] @@ -94,6 +96,8 @@ def get_trainer( validation_data=validation_data, init_model=init_model, restart_model=restart_model, + finetune_model=finetune_model, + finetune_links=finetune_links, ) return trainer @@ -102,6 +106,9 @@ def train( input_file: str, init_model: str | None = None, restart: str | None = None, + finetune: str | None = None, + model_branch: str = "", + use_pretrain_script: bool = False, skip_neighbor_stat: bool = False, output: str = "out.json", ) -> None: @@ -115,14 +122,25 @@ def train( Path to a checkpoint to initialise weights from. restart : str or None Path to a checkpoint to restart training from. + finetune : str or None + Path to a pretrained checkpoint to fine-tune from. + model_branch : str + Branch to select from a multi-task pretrained model. + use_pretrain_script : bool + If True, copy descriptor/fitting params from the pretrained model. skip_neighbor_stat : bool Skip neighbour statistics calculation. output : str Where to dump the normalised config. """ + import torch + from deepmd.common import ( j_loader, ) + from deepmd.pt_expt.utils.env import ( + DEVICE, + ) log.info("Configuration path: %s", input_file) config = j_loader(input_file) @@ -133,6 +151,27 @@ def train( if restart is not None and not restart.endswith(".pt"): restart += ".pt" + # update fine-tuning config + finetune_links = None + if finetune is not None: + from deepmd.pt_expt.utils.finetune import ( + get_finetune_rules, + ) + + config["model"], finetune_links = get_finetune_rules( + finetune, + config["model"], + model_branch=model_branch, + change_model_params=use_pretrain_script, + ) + + # update init_model config if --use-pretrain-script + if init_model is not None and use_pretrain_script: + init_state_dict = torch.load(init_model, map_location=DEVICE, weights_only=True) + if "model" in init_state_dict: + init_state_dict = init_state_dict["model"] + config["model"] = init_state_dict["_extra_state"]["model_params"] + # argcheck config = update_deepmd_input(config, warning=True, dump="input_v2_compat.json") config = normalize(config) @@ -156,7 +195,13 @@ def train( with open(output, "w") as fp: json.dump(config, fp, indent=4) - trainer = get_trainer(config, init_model, restart) + trainer = get_trainer( + config, + init_model, + restart, + finetune_model=finetune, + finetune_links=finetune_links, + ) trainer.run() @@ -214,7 +259,7 @@ def freeze( m.eval() model_dict = m.serialize() - deserialize_to_file(output, {"model": model_dict}) + deserialize_to_file(output, {"model": model_dict}, model_params=model_params) log.info("Saved frozen model to %s", output) @@ -250,6 +295,9 @@ def main(args: list[str] | argparse.Namespace | None = None) -> None: input_file=FLAGS.INPUT, init_model=FLAGS.init_model, restart=FLAGS.restart, + finetune=FLAGS.finetune, + model_branch=FLAGS.model_branch, + use_pretrain_script=FLAGS.use_pretrain_script, skip_neighbor_stat=FLAGS.skip_neighbor_stat, output=FLAGS.output, ) diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 2e3bb75abf..a206302818 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -23,6 +23,9 @@ import numpy as np import torch +from deepmd.dpmodel.common import ( + to_numpy_array, +) from deepmd.dpmodel.utils.batch import ( normalize_batch, split_batch, @@ -380,8 +383,16 @@ def __init__( validation_data: DeepmdDataSystem | None = None, init_model: str | None = None, restart_model: str | None = None, + finetune_model: str | None = None, + finetune_links: dict | None = None, ) -> None: - resume_model = init_model or restart_model + if finetune_model is not None and ( + init_model is not None or restart_model is not None + ): + raise ValueError( + "finetune_model cannot be combined with init_model or restart_model." + ) + resume_model = init_model or restart_model or finetune_model resuming = resume_model is not None self.restart_training = restart_model is not None @@ -429,7 +440,12 @@ def __init__( def get_sample() -> list[dict[str, np.ndarray]]: return make_stat_input(training_data, data_stat_nbatch) - if not resuming: + finetune_has_new_type = ( + finetune_model is not None + and finetune_links is not None + and finetune_links["Default"].get_has_new_type() + ) + if not resuming or finetune_has_new_type: self.model.compute_or_load_stat( sampled_func=get_sample, stat_file_path=stat_file_path, @@ -472,23 +488,98 @@ def get_sample() -> list[dict[str, np.ndarray]]: # Resume -------------------------------------------------------------- if resuming: log.info(f"Resuming from {resume_model}.") - state_dict = torch.load( - resume_model, map_location=DEVICE, weights_only=True - ) - if "model" in state_dict: - optimizer_state_dict = ( - state_dict["optimizer"] if self.restart_training else None + is_pte = resume_model.endswith((".pte", ".pt2")) + + if is_pte: + # .pte frozen model: no optimizer state, no step counter + optimizer_state_dict = None + self.start_step = 0 + else: + state_dict = torch.load( + resume_model, map_location=DEVICE, weights_only=True + ) + if "model" in state_dict: + optimizer_state_dict = ( + state_dict["optimizer"] + if self.restart_training and finetune_model is None + else None + ) + state_dict = state_dict["model"] + else: + optimizer_state_dict = None + self.start_step = ( + state_dict["_extra_state"]["train_infos"]["step"] + if self.restart_training + else 0 + ) + + if finetune_model is not None and finetune_links is not None: + # --- Finetune: selective weight loading ----------------------- + finetune_rule = finetune_links["Default"] + + # Build pretrained model and load weights + if is_pte: + from deepmd.pt_expt.model import ( + BaseModel, + ) + from deepmd.pt_expt.utils.serialization import ( + serialize_from_file, + ) + + data = serialize_from_file(finetune_model) + pretrained_model = BaseModel.deserialize(data["model"]).to(DEVICE) + else: + pretrained_model = get_model( + deepcopy(state_dict["_extra_state"]["model_params"]) + ).to(DEVICE) + pretrained_wrapper = ModelWrapper(pretrained_model) + if not is_pte: + pretrained_wrapper.load_state_dict(state_dict) + + # Change type map if needed + if ( + finetune_rule.get_finetune_tmap() + != pretrained_wrapper.model.get_type_map() + ): + model_with_new_type_stat = ( + self.wrapper.model if finetune_rule.get_has_new_type() else None + ) + pretrained_wrapper.model.change_type_map( + finetune_rule.get_finetune_tmap(), + model_with_new_type_stat=model_with_new_type_stat, + ) + + # Selectively copy weights: descriptor always from pretrained, + # fitting from pretrained unless random_fitting is True + pretrained_state = pretrained_wrapper.state_dict() + target_state = self.wrapper.state_dict() + new_state = {} + for key in target_state: + if key == "_extra_state": + new_state[key] = target_state[key] + elif ( + finetune_rule.get_random_fitting() and ".descriptor." not in key + ): + new_state[key] = target_state[key] # keep random init + elif key in pretrained_state: + new_state[key] = pretrained_state[key] # from pretrained + else: + new_state[key] = target_state[key] # fallback + self.wrapper.load_state_dict(new_state) + + # Adjust output bias + bias_mode = ( + "change-by-statistic" + if not finetune_rule.get_random_fitting() + else "set-by-statistic" + ) + self.model = model_change_out_bias( + self.model, get_sample, _bias_adjust_mode=bias_mode ) - state_dict = state_dict["model"] else: - optimizer_state_dict = None + # --- Normal resume (init_model / restart) -------------------- + self.wrapper.load_state_dict(state_dict) - self.start_step = ( - state_dict["_extra_state"]["train_infos"]["step"] - if self.restart_training - else 0 - ) - self.wrapper.load_state_dict(state_dict) if optimizer_state_dict is not None: self.optimizer.load_state_dict(optimizer_state_dict) # rebuild scheduler from the resumed step. @@ -910,3 +1001,38 @@ def print_on_training( line += f" {cur_lr:8.1e}\n" fout.write(line) fout.flush() + + +def model_change_out_bias( + _model: Any, + _sample_func: Any, + _bias_adjust_mode: str = "change-by-statistic", +) -> Any: + """Change the output bias of a model based on sampled data. + + Parameters + ---------- + _model + The model whose bias should be adjusted. + _sample_func + Callable that returns sampled data for bias computation. + _bias_adjust_mode + ``"change-by-statistic"`` or ``"set-by-statistic"``. + + Returns + ------- + The model with updated bias. + """ + old_bias = deepcopy(_model.get_out_bias()) + _model.change_out_bias( + _sample_func, + bias_adjust_mode=_bias_adjust_mode, + ) + new_bias = deepcopy(_model.get_out_bias()) + model_type_map = _model.get_type_map() + log.info( + f"Change output bias of {model_type_map!s} " + f"from {to_numpy_array(old_bias).reshape(-1)[: len(model_type_map)]!s} " + f"to {to_numpy_array(new_bias).reshape(-1)[: len(model_type_map)]!s}." + ) + return _model diff --git a/deepmd/pt_expt/utils/finetune.py b/deepmd/pt_expt/utils/finetune.py new file mode 100644 index 0000000000..cc24aec219 --- /dev/null +++ b/deepmd/pt_expt/utils/finetune.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Finetune utilities for the pt_expt backend. + +Supports finetuning from both ``.pt`` checkpoints and frozen ``.pte`` models. +""" + +from typing import ( + Any, +) + +import torch + +from deepmd.pt.utils.finetune import ( + get_finetune_rule_single, +) +from deepmd.pt_expt.utils.env import ( + DEVICE, +) +from deepmd.utils.finetune import ( + FinetuneRuleItem, +) + + +def _is_pte(path: str) -> bool: + return path.endswith((".pte", ".pt2")) + + +def _load_model_params(finetune_model: str) -> dict[str, Any]: + """Extract model_params from a ``.pt`` checkpoint or ``.pte`` frozen model.""" + if _is_pte(finetune_model): + from deepmd.pt_expt.utils.serialization import ( + serialize_from_file, + ) + + data = serialize_from_file(finetune_model) + # Prefer embedded model_params (full config); fall back to + # a minimal dict with just type_map for older .pte files. + if "model_params" in data: + return data["model_params"] + return {"type_map": data["model"]["type_map"]} + else: + state_dict = torch.load(finetune_model, map_location=DEVICE, weights_only=True) + if "model" in state_dict: + state_dict = state_dict["model"] + return state_dict["_extra_state"]["model_params"] + + +def get_finetune_rules( + finetune_model: str, + model_config: dict[str, Any], + model_branch: str = "", + change_model_params: bool = True, +) -> tuple[dict[str, Any], dict[str, FinetuneRuleItem]]: + """Get fine-tuning rules for a single-task pt_expt model. + + Loads a pretrained ``.pt`` checkpoint or ``.pte`` frozen model and + builds ``FinetuneRuleItem`` objects describing how to map types and + weights from the pretrained model to the new model. + + Parameters + ---------- + finetune_model : str + Path to the pretrained model (``.pt`` or ``.pte``). + model_config : dict + The model section of the fine-tuning config. + model_branch : str + Branch to select from a multi-task pretrained model (command-line). + change_model_params : bool + Whether to overwrite descriptor/fitting params from the pretrained + model. Not supported for ``.pte`` sources. + + Returns + ------- + model_config : dict + Possibly updated model config. + finetune_links : dict[str, FinetuneRuleItem] + Fine-tuning rules keyed by ``"Default"``. + """ + last_model_params = _load_model_params(finetune_model) + + if change_model_params and "descriptor" not in last_model_params: + raise ValueError( + "Cannot use --use-pretrain-script: the pretrained model does not " + "contain full model params. If finetuning from a .pte file, " + "re-freeze it with the latest code so that model_params is embedded." + ) + + finetune_from_multi_task = "model_dict" in last_model_params + + # pt_expt is single-task only + if model_branch == "" and "finetune_head" in model_config: + model_branch = model_config["finetune_head"] + model_config, finetune_rule = get_finetune_rule_single( + model_config, + last_model_params, + from_multitask=finetune_from_multi_task, + model_branch="Default", + model_branch_from=model_branch, + change_model_params=change_model_params, + ) + finetune_links: dict[str, FinetuneRuleItem] = {"Default": finetune_rule} + return model_config, finetune_links diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index 4587bc7931..b9914e2ac9 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -211,18 +211,23 @@ def serialize_from_file(model_file: str) -> dict: Returns ------- dict - The serialized model data. + The serialized model data. If the archive contains + ``model_params.json``, it is included under the + ``"model_params"`` key. """ - extra_files = {"model.json": ""} + extra_files = {"model.json": "", "model_params.json": ""} torch.export.load(model_file, extra_files=extra_files) model_dict = json.loads(extra_files["model.json"]) model_dict = _json_to_numpy(model_dict) + if extra_files["model_params.json"]: + model_dict["model_params"] = json.loads(extra_files["model_params.json"]) return model_dict def deserialize_to_file( model_file: str, data: dict, + model_params: dict | None = None, model_json_override: dict | None = None, ) -> None: """Deserialize a dictionary to a .pte model file. @@ -237,6 +242,10 @@ def deserialize_to_file( data : dict The dictionary to be deserialized (same format as dpmodel's serialize output, with "model" and optionally "model_def_script" keys). + model_params : dict or None + Original model config (the dict passed to ``get_model``). + If provided, embedded in the .pte so that ``--use-pretrain-script`` + can extract descriptor/fitting params at finetune time. model_json_override : dict or None If provided, this dict is stored in model.json instead of ``data``. Used by ``dp compress`` to store the compressed model dict while @@ -299,6 +308,8 @@ def deserialize_to_file( "model_def_script.json": json.dumps(metadata), "model.json": json.dumps(data_for_json, separators=(",", ":")), } + if model_params is not None: + extra_files["model_params.json"] = json.dumps(model_params) # 7. Save torch.export.save(exported, model_file, extra_files=extra_files) diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 4c914bff41..def9f67f32 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -1283,6 +1283,160 @@ def test_change_type_map_extend_stat(self) -> None: atol=1e-10, ) + def test_change_type_map_new_type(self) -> None: + """change_type_map with new types should extend out_bias/out_std consistently across dp, pt, and pt_expt. + + When the new type_map introduces types not present in the original + type_map, the model's out_bias must be extended (zeros for bias, + ones for std) before remapping. This test verifies the fix in + dpmodel's base_atomic_model.change_type_map. + """ + from deepmd.utils.argcheck import model_args as model_args_fn + + small_tm = ["O", "H"] + large_tm = ["H", "X1", "X2", "O", "B"] + + data = model_args_fn().normalize_value( + { + "type_map": small_tm, + "descriptor": { + "type": "se_atten", + "sel": 20, + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [3, 6], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "seed": 1, + "attn": 6, + "attn_layer": 0, + }, + "fitting_net": { + "neuron": [5, 5], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + }, + trim_pattern="_*", + ) + dp_model = get_model_dp(data) + + # Set non-zero out_bias so the remap is non-trivial + dp_bias_orig = to_numpy_array(dp_model.get_out_bias()).copy() + new_bias = dp_bias_orig.copy() + new_bias[:, 0, :] = 1.5 # type 0 ("O") + new_bias[:, 1, :] = -3.7 # type 1 ("H") + dp_model.set_out_bias(new_bias) + + # Snapshot out_std before change_type_map for remap verification + dp_std_before = to_numpy_array(dp_model.atomic_model.out_std).copy() + + # Build pt and pt_expt models from dp serialization + pt_model = EnergyModelPT.deserialize(dp_model.serialize()) + pt_expt_model = EnergyModelPTExpt.deserialize(dp_model.serialize()) + + # Extend type map with new types (no model_with_new_type_stat) + dp_model.change_type_map(large_tm) + pt_model.change_type_map(large_tm) + pt_expt_model.change_type_map(large_tm) + + # All should have the new type_map + self.assertEqual(dp_model.get_type_map(), large_tm) + self.assertEqual(pt_model.get_type_map(), large_tm) + self.assertEqual(pt_expt_model.get_type_map(), large_tm) + + # Out_bias should be consistent across all backends + dp_bias_new = to_numpy_array(dp_model.get_out_bias()) + pt_bias_new = torch_to_numpy(pt_model.get_out_bias()) + pt_expt_bias_new = to_numpy_array(pt_expt_model.get_out_bias()) + + np.testing.assert_allclose( + dp_bias_new, + pt_bias_new, + rtol=1e-10, + atol=1e-10, + err_msg="dp vs pt out_bias mismatch after change_type_map with new types", + ) + np.testing.assert_allclose( + dp_bias_new, + pt_expt_bias_new, + rtol=1e-10, + atol=1e-10, + err_msg="dp vs pt_expt out_bias mismatch after change_type_map with new types", + ) + + # Verify remap correctness: + # large_tm = ["H", "X1", "X2", "O", "B"] + # old "O" (index 0) -> new index 3 + # old "H" (index 1) -> new index 0 + # new types X1(1), X2(2), B(4) -> bias should be 0 + np.testing.assert_allclose( + dp_bias_new[:, 3, :], + new_bias[:, 0, :], # O + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + dp_bias_new[:, 0, :], + new_bias[:, 1, :], # H + rtol=1e-10, + atol=1e-10, + ) + for idx in [1, 2, 4]: # X1, X2, B + np.testing.assert_allclose( + dp_bias_new[:, idx, :], + 0.0, + rtol=1e-10, + atol=1e-10, + err_msg=f"new type at index {idx} should have zero bias", + ) + + # Out_std for new types should be 1.0 (default) + dp_std_new = to_numpy_array(dp_model.atomic_model.out_std) + pt_std_new = torch_to_numpy(pt_model.atomic_model.out_std) + pt_expt_std_new = to_numpy_array(pt_expt_model.atomic_model.out_std) + + np.testing.assert_allclose( + dp_std_new, + pt_std_new, + rtol=1e-10, + atol=1e-10, + err_msg="dp vs pt out_std mismatch after change_type_map with new types", + ) + np.testing.assert_allclose( + dp_std_new, + pt_expt_std_new, + rtol=1e-10, + atol=1e-10, + err_msg="dp vs pt_expt out_std mismatch after change_type_map with new types", + ) + # Verify old types' std was remapped correctly + # old "O" (index 0) -> new index 3, old "H" (index 1) -> new index 0 + np.testing.assert_allclose( + dp_std_new[:, 3, :], + dp_std_before[:, 0, :], # O + rtol=1e-10, + atol=1e-10, + err_msg="out_std for O should be remapped from old index 0", + ) + np.testing.assert_allclose( + dp_std_new[:, 0, :], + dp_std_before[:, 1, :], # H + rtol=1e-10, + atol=1e-10, + err_msg="out_std for H should be remapped from old index 1", + ) + for idx in [1, 2, 4]: # X1, X2, B + np.testing.assert_allclose( + dp_std_new[:, idx, :], + 1.0, + rtol=1e-10, + atol=1e-10, + err_msg=f"new type at index {idx} should have unit std", + ) + def test_update_sel(self) -> None: """update_sel should return the same result on dp and pt.""" from unittest.mock import ( diff --git a/source/tests/pt_expt/test_finetune.py b/source/tests/pt_expt/test_finetune.py new file mode 100644 index 0000000000..250ba85d46 --- /dev/null +++ b/source/tests/pt_expt/test_finetune.py @@ -0,0 +1,882 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for ``dp finetune`` in the pt_expt backend. + +Part A: Model-level tests (FinetuneTest mixin) + - test_finetune_change_out_bias + - test_finetune_change_type + +Part B: CLI end-to-end tests (TestFinetuneCLI) + - test_finetune_cli + - test_finetune_cli_use_pretrain_script + - test_finetune_random_fitting +""" + +import json +import os +import shutil +import tempfile +import unittest +from copy import ( + deepcopy, +) + +import numpy as np +import torch + +from deepmd.dpmodel.common import ( + to_numpy_array, +) +from deepmd.pt_expt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt_expt.model import ( + get_model, +) +from deepmd.pt_expt.train.wrapper import ( + ModelWrapper, +) +from deepmd.pt_expt.utils.env import ( + DEVICE, +) +from deepmd.pt_expt.utils.finetune import ( + get_finetune_rules, +) +from deepmd.pt_expt.utils.stat import ( + make_stat_input, +) +from deepmd.utils.argcheck import ( + normalize, +) +from deepmd.utils.compat import ( + update_deepmd_input, +) +from deepmd.utils.data import ( + DataRequirementItem, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, + process_systems, +) + +EXAMPLE_DIR = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "..", + "examples", + "water", +) + +energy_data_requirement = [ + DataRequirementItem("energy", ndof=1, atomic=False, must=False, high_prec=True), + DataRequirementItem("force", ndof=3, atomic=True, must=False, high_prec=False), + DataRequirementItem("virial", ndof=9, atomic=False, must=False, high_prec=False), + DataRequirementItem("atom_ener", ndof=1, atomic=True, must=False, high_prec=False), + DataRequirementItem( + "atom_pref", ndof=1, atomic=True, must=False, high_prec=False, repeat=3 + ), +] + + +model_se_e2_a = { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [6, 12], + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "resnet_dt": False, + "axis_neuron": 4, + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, + }, + "data_stat_nbatch": 1, +} + +model_dpa1 = { + "type_map": ["O", "H"], + "descriptor": { + "type": "dpa1", + "sel": 18, + "rcut_smth": 0.50, + "rcut": 3.00, + "neuron": [8, 16], + "axis_neuron": 4, + "attn": 4, + "attn_layer": 2, + "attn_dotr": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [16, 16], + "resnet_dt": True, + "seed": 1, + }, + "data_stat_nbatch": 1, +} + + +def _subsample_data(src_dir: str, dst_dir: str, nframes: int = 2) -> None: + """Copy a data system, keeping only the first *nframes* frames.""" + import shutil as _shutil + + _shutil.copytree(src_dir, dst_dir, dirs_exist_ok=True) + set_dir = os.path.join(dst_dir, "set.000") + for name in os.listdir(set_dir): + if name.endswith(".npy"): + arr = np.load(os.path.join(set_dir, name)) + np.save(os.path.join(set_dir, name), arr[:nframes]) + + +def _make_config(data_dir: str, model_params: dict, numb_steps: int = 1) -> dict: + """Build a minimal config dict for finetune tests.""" + config = { + "model": deepcopy(model_params), + "learning_rate": { + "type": "exp", + "decay_steps": 500, + "start_lr": 0.001, + "stop_lr": 3.51e-8, + }, + "loss": { + "type": "ener", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_f": 1000, + "limit_pref_f": 1, + "start_pref_v": 0, + "limit_pref_v": 0, + }, + "training": { + "training_data": { + "systems": [os.path.join(data_dir, "data_0")], + "batch_size": 2, + }, + "validation_data": { + "systems": [os.path.join(data_dir, "data_0")], + "batch_size": 2, + "numb_btch": 1, + }, + "numb_steps": numb_steps, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 1, + "save_freq": numb_steps, + }, + } + return config + + +# --------------------------------------------------------------------------- +# Part A: Model-level tests +# --------------------------------------------------------------------------- + + +class FinetuneTest: + """Mixin with model-level finetune tests.""" + + def test_finetune_change_out_bias(self) -> None: + """Train model -> randomize bias -> change_out_bias -> verify shift.""" + # get data + type_map = self.config["model"]["type_map"] + data_systems = process_systems( + self.config["training"]["training_data"]["systems"] + ) + data = DeepmdDataSystem( + systems=data_systems, + batch_size=1, + test_size=1, + type_map=type_map, + trn_all_set=True, + ) + data.add_data_requirements(energy_data_requirement) + sampled = make_stat_input(data, nbatches=1) + + # make sampled of multiple frames with different atom numbs + numb_atom = sampled[0]["atype"].shape[1] + small_numb_atom = numb_atom // 2 + small_atom_data = deepcopy(sampled[0]) + # coord is (nframes, nloc*3) in dpmodel/pt_expt format: + # reshape to 3D, slice atoms, flatten back + nframes = small_atom_data["coord"].shape[0] + coord_3d = small_atom_data["coord"].reshape(nframes, numb_atom, 3) + small_atom_data["coord"] = coord_3d[:, :small_numb_atom, :].reshape( + nframes, small_numb_atom * 3 + ) + small_atom_data["atype"] = small_atom_data["atype"][:, :small_numb_atom] + scale_pref = float(small_numb_atom / numb_atom) + small_atom_data["energy"] *= scale_pref + small_atom_data["natoms"][:, :2] = small_numb_atom + # recount per-type atoms + atype_flat = small_atom_data["atype"][0] + for ii in range(len(type_map)): + small_atom_data["natoms"][:, 2 + ii] = np.sum(atype_flat == ii) + sampled = [sampled[0], small_atom_data] + + # get model and randomize bias + config = deepcopy(self.config) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + model = get_model(config["model"]).to(DEVICE) + + old_bias = model.get_out_bias() + rng = np.random.default_rng(42) + random_bias = rng.standard_normal(to_numpy_array(old_bias).shape).astype( + to_numpy_array(old_bias).dtype + ) + model.set_out_bias(random_bias) + energy_bias_before = to_numpy_array(model.get_out_bias())[0] + + # Run inference BEFORE bias change (need original model predictions) + model.eval() + origin_type_map = type_map + full_type_map = type_map + sorter = np.argsort(full_type_map) + idx_type_map = sorter[ + np.searchsorted(full_type_map, origin_type_map, sorter=sorter) + ] + ntest = 1 + + # model inference (coord needs requires_grad for force via autograd.grad) + coord0 = torch.tensor( + sampled[0]["coord"][:ntest], dtype=torch.float64, device=DEVICE + ).requires_grad_(True) + atype0 = torch.tensor( + sampled[0]["atype"][:ntest], dtype=torch.int64, device=DEVICE + ) + box0 = torch.tensor( + sampled[0]["box"][:ntest], dtype=torch.float64, device=DEVICE + ) + coord1 = torch.tensor( + sampled[1]["coord"][:ntest], dtype=torch.float64, device=DEVICE + ).requires_grad_(True) + atype1 = torch.tensor( + sampled[1]["atype"][:ntest], dtype=torch.int64, device=DEVICE + ) + box1 = torch.tensor( + sampled[1]["box"][:ntest], dtype=torch.float64, device=DEVICE + ) + + energy = model(coord0, atype0, box0)["energy"].detach().cpu().numpy() + energy_small = model(coord1, atype1, box1)["energy"].detach().cpu().numpy() + + # Now change energy bias + model.change_out_bias( + sampled, + bias_adjust_mode="change-by-statistic", + ) + energy_bias_after = to_numpy_array(model.get_out_bias())[0] + + # compute ground-truth bias change via least squares + atom_nums = np.tile( + np.bincount(sampled[0]["atype"][0].astype(int))[idx_type_map], + (ntest, 1), + ) + atom_nums_small = np.tile( + np.bincount(sampled[1]["atype"][0].astype(int))[idx_type_map], + (ntest, 1), + ) + atom_nums = np.concatenate([atom_nums, atom_nums_small], axis=0) + + energy_diff = sampled[0]["energy"][:ntest] - energy + energy_diff_small = sampled[1]["energy"][:ntest] - energy_small + energy_diff = np.concatenate([energy_diff, energy_diff_small], axis=0) + + finetune_shift = ( + energy_bias_after[idx_type_map] - energy_bias_before[idx_type_map] + ).ravel() + ground_truth_shift = np.linalg.lstsq(atom_nums, energy_diff, rcond=None)[ + 0 + ].reshape(-1) + + np.testing.assert_almost_equal(finetune_shift, ground_truth_shift, decimal=10) + + def test_finetune_change_type(self) -> None: + """Train with type_map A -> load with type_map B -> verify consistency. + + Tests that change_type_map + selective weight copy correctly remaps + weights so predictions are identical for the same atoms regardless + of type map ordering. Uses direct weight loading (bypassing Trainer + bias adjustment) to isolate the remapping logic, then verifies the + full Trainer finetune path (with bias adjustment) also works. + """ + if not self.mixed_types: + return + + from deepmd.utils.finetune import ( + get_index_between_two_maps, + ) + + config = deepcopy(self.config) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + + data_type_map = config["model"]["type_map"] + + for old_type_map, new_type_map in [ + [["H", "X1", "X2", "O", "B"], ["O", "H", "B"]], + [["O", "H", "B"], ["H", "X1", "X2", "O", "B"]], + ]: + old_type_map_index = np.array( + [old_type_map.index(i) for i in data_type_map], dtype=np.int32 + ) + new_type_map_index = np.array( + [new_type_map.index(i) for i in data_type_map], dtype=np.int32 + ) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_ft_type_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + # Train pretrained model with old type map + config_old = deepcopy(config) + config_old["model"]["type_map"] = old_type_map + trainer = get_trainer(config_old) + trainer.run() + finetune_ckpt = ( + config_old["training"].get("save_ckpt", "model.ckpt") + ".pt" + ) + + # Load pretrained checkpoint + state_dict = torch.load( + finetune_ckpt, map_location=DEVICE, weights_only=True + ) + if "model" in state_dict: + state_dict = state_dict["model"] + + # Build model_old: same type_map, direct weight load + model_old = get_model( + deepcopy(state_dict["_extra_state"]["model_params"]) + ).to(DEVICE) + wrapper_old = ModelWrapper(model_old) + wrapper_old.load_state_dict(state_dict) + + # Build model_new: change_type_map + selective weight copy + pretrained_model = get_model( + deepcopy(state_dict["_extra_state"]["model_params"]) + ).to(DEVICE) + pretrained_wrapper = ModelWrapper(pretrained_model) + pretrained_wrapper.load_state_dict(state_dict) + + config_new = deepcopy(config) + config_new["model"]["type_map"] = new_type_map + config_new = normalize(config_new) + model_new = get_model(config_new["model"]).to(DEVICE) + wrapper_new = ModelWrapper(model_new) + + _, has_new_type = get_index_between_two_maps(old_type_map, new_type_map) + model_with_new_type_stat = wrapper_new.model if has_new_type else None + pretrained_wrapper.model.change_type_map( + new_type_map, + model_with_new_type_stat=model_with_new_type_stat, + ) + + pre_state = pretrained_wrapper.state_dict() + tgt_state = wrapper_new.state_dict() + new_state = {} + for key in tgt_state: + if key == "_extra_state": + new_state[key] = tgt_state[key] + elif key in pre_state: + new_state[key] = pre_state[key] + else: + new_state[key] = tgt_state[key] + wrapper_new.load_state_dict(new_state) + + # Get sample data for comparison + data_systems = process_systems( + config["training"]["training_data"]["systems"] + ) + data = DeepmdDataSystem( + systems=data_systems, + batch_size=1, + test_size=1, + type_map=data_type_map, + trn_all_set=True, + ) + data.add_data_requirements(energy_data_requirement) + sampled = make_stat_input(data, nbatches=1) + + ntest = 1 + prec = 1e-10 + box = torch.tensor( + sampled[0]["box"][:ntest], dtype=torch.float64, device=DEVICE + ) + atype_raw = torch.tensor( + sampled[0]["atype"][:ntest], dtype=torch.int64, device=DEVICE + ) + + old_index = torch.tensor( + old_type_map_index, dtype=torch.int64, device=DEVICE + ) + new_index = torch.tensor( + new_type_map_index, dtype=torch.int64, device=DEVICE + ) + + model_old.eval() + model_new.eval() + + coord_old = torch.tensor( + sampled[0]["coord"][:ntest], + dtype=torch.float64, + device=DEVICE, + ).requires_grad_(True) + result_old = model_old(coord_old, old_index[atype_raw], box=box) + coord_new = torch.tensor( + sampled[0]["coord"][:ntest], + dtype=torch.float64, + device=DEVICE, + ).requires_grad_(True) + result_new = model_new(coord_new, new_index[atype_raw], box=box) + + for key in ["energy", "force", "virial"]: + torch.testing.assert_close( + result_old[key], + result_new[key], + rtol=prec, + atol=prec, + ) + + # Now verify full Trainer finetune path (with bias adjustment) + config_old_ft = deepcopy(config) + config_old_ft["model"]["type_map"] = old_type_map + config_old_ft["model"], finetune_links_old = get_finetune_rules( + finetune_ckpt, config_old_ft["model"] + ) + trainer_old = get_trainer( + config_old_ft, + finetune_model=finetune_ckpt, + finetune_links=finetune_links_old, + ) + + config_new_ft = deepcopy(config) + config_new_ft["model"]["type_map"] = new_type_map + config_new_ft["model"], finetune_links_new = get_finetune_rules( + finetune_ckpt, config_new_ft["model"] + ) + trainer_new = get_trainer( + config_new_ft, + finetune_model=finetune_ckpt, + finetune_links=finetune_links_new, + ) + + trainer_old.model.eval() + trainer_new.model.eval() + + coord_old2 = torch.tensor( + sampled[0]["coord"][:ntest], + dtype=torch.float64, + device=DEVICE, + ).requires_grad_(True) + result_old2 = trainer_old.model( + coord_old2, old_index[atype_raw], box=box + ) + coord_new2 = torch.tensor( + sampled[0]["coord"][:ntest], + dtype=torch.float64, + device=DEVICE, + ).requires_grad_(True) + result_new2 = trainer_new.model( + coord_new2, new_index[atype_raw], box=box + ) + + for key in ["energy", "force", "virial"]: + torch.testing.assert_close( + result_old2[key], + result_new2[key], + rtol=prec, + atol=prec, + ) + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + +class TestEnergyModelSeA(FinetuneTest, unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = data_dir + cls.config = _make_config(data_dir, model_se_e2_a) + cls.mixed_types = False + + +class TestEnergyModelDPA1(FinetuneTest, unittest.TestCase): + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls._tmpdir = tempfile.mkdtemp(prefix="pt_expt_ft_dpa1_data_") + _subsample_data( + os.path.join(data_dir, "data_0"), + os.path.join(cls._tmpdir, "data_0"), + ) + cls.data_dir = cls._tmpdir + cls.config = _make_config(cls._tmpdir, model_dpa1) + cls.mixed_types = True + + @classmethod + def tearDownClass(cls) -> None: + shutil.rmtree(cls._tmpdir, ignore_errors=True) + + +# --------------------------------------------------------------------------- +# Part B: CLI end-to-end tests +# --------------------------------------------------------------------------- + + +class TestFinetuneCLI(unittest.TestCase): + """End-to-end tests for the ``dp --pt-expt train --finetune`` CLI path.""" + + @classmethod + def setUpClass(cls) -> None: + data_dir = os.path.join(EXAMPLE_DIR, "data") + if not os.path.isdir(data_dir): + raise unittest.SkipTest(f"Example data not found: {data_dir}") + cls.data_dir = data_dir + + def _train_pretrained(self, config: dict, tmpdir: str) -> str: + """Train a 1-step model and return checkpoint path.""" + trainer = get_trainer(config) + trainer.run() + ckpt = os.path.join(tmpdir, "model.ckpt.pt") + self.assertTrue(os.path.exists(ckpt), "Pretrained checkpoint not found") + return ckpt + + def _assert_inherited_weights_match( + self, + ft_state: dict, + pre_state: dict, + random_fitting: bool = False, + ) -> None: + """Assert that inherited weights in finetuned model match pretrained. + + Descriptor weights must always match. Fitting weights must match + unless ``random_fitting`` is True. ``_extra_state`` and out_bias + (adjusted by bias computation) are skipped. + """ + for key in ft_state: + if key == "_extra_state": + continue + if key not in pre_state: + continue + if ".descriptor." in key: + torch.testing.assert_close( + ft_state[key], + pre_state[key], + msg=f"Descriptor weight {key} should match pretrained", + ) + elif ".fitting" in key: + if not random_fitting: + torch.testing.assert_close( + ft_state[key], + pre_state[key], + msg=f"Fitting weight {key} should match pretrained", + ) + else: + # random_fitting: network weights must differ + # (bias_atom_e is set by bias adjustment, not random init) + if ft_state[key].is_floating_point() and "bias_atom_e" not in key: + self.assertFalse( + torch.equal(ft_state[key], pre_state[key]), + msg=f"Fitting weight {key} should NOT match pretrained " + f"when random_fitting=True", + ) + + def test_finetune_cli(self) -> None: + """Train -> finetune via main() dispatcher -> verify checkpoint exists.""" + from deepmd.pt_expt.entrypoints.main import ( + main, + ) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_ft_cli_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + # Phase 1: train pretrained model + config = _make_config(self.data_dir, model_se_e2_a, numb_steps=1) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + ckpt_path = self._train_pretrained(config, tmpdir) + + # Save original bias + state = torch.load(ckpt_path, map_location=DEVICE, weights_only=True) + model_state = state["model"] if "model" in state else state + original_model = get_model(model_state["_extra_state"]["model_params"]).to( + DEVICE + ) + original_wrapper = ModelWrapper(original_model) + original_wrapper.load_state_dict(model_state) + original_bias = to_numpy_array(original_model.get_out_bias()).copy() + + # Phase 2: finetune via CLI (lr=0 so weights stay unchanged) + ft_config = _make_config(self.data_dir, model_se_e2_a, numb_steps=1) + ft_config["learning_rate"]["start_lr"] = 1e-30 + ft_config["learning_rate"]["stop_lr"] = 1e-30 + ft_config_file = os.path.join(tmpdir, "finetune_input.json") + with open(ft_config_file, "w") as f: + json.dump(ft_config, f) + + main( + [ + "train", + ft_config_file, + "--finetune", + ckpt_path, + "--skip-neighbor-stat", + ] + ) + + # Verify new checkpoint exists + ft_ckpt = os.path.join(tmpdir, "model.ckpt.pt") + self.assertTrue(os.path.exists(ft_ckpt), "Finetune checkpoint not found") + + # Load finetuned model and verify bias changed + ft_state = torch.load(ft_ckpt, map_location=DEVICE, weights_only=True) + ft_model_state = ft_state["model"] if "model" in ft_state else ft_state + ft_model = get_model(ft_model_state["_extra_state"]["model_params"]).to( + DEVICE + ) + ft_wrapper = ModelWrapper(ft_model) + ft_wrapper.load_state_dict(ft_model_state) + ft_bias = to_numpy_array(ft_model.get_out_bias()) + + # Bias should have been adjusted (may or may not differ depending + # on data, but the checkpoint should at least be valid) + self.assertEqual(original_bias.shape, ft_bias.shape) + + # Inherited weights (descriptor + fitting) must match pretrained. + # lr=0 so training step doesn't modify weights. + self._assert_inherited_weights_match( + ft_model_state, model_state, random_fitting=False + ) + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_finetune_cli_use_pretrain_script(self) -> None: + """Finetune with --use-pretrain-script -> config copied from pretrained.""" + from deepmd.pt_expt.entrypoints.main import ( + main, + ) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_ft_pretrain_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + # Phase 1: train pretrained model + config = _make_config(self.data_dir, model_se_e2_a, numb_steps=1) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + ckpt_path = self._train_pretrained(config, tmpdir) + + # Phase 2: finetune with --use-pretrain-script + # Use a config with different descriptor neuron sizes + ft_model_params = deepcopy(model_se_e2_a) + ft_model_params["descriptor"]["neuron"] = [4, 8] # different + ft_config = _make_config(self.data_dir, ft_model_params, numb_steps=1) + ft_config_file = os.path.join(tmpdir, "finetune_input.json") + with open(ft_config_file, "w") as f: + json.dump(ft_config, f) + + main( + [ + "train", + ft_config_file, + "--finetune", + ckpt_path, + "--use-pretrain-script", + "--skip-neighbor-stat", + ] + ) + + # Verify the output config was updated from pretrained + with open(os.path.join(tmpdir, "out.json")) as f: + output_config = json.load(f) + # Descriptor neuron should be from pretrained, not from ft_config + self.assertEqual( + output_config["model"]["descriptor"]["neuron"], + model_se_e2_a["descriptor"]["neuron"], + ) + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_finetune_random_fitting(self) -> None: + """Finetune with --model-branch RANDOM -> descriptor from pretrained, fitting random.""" + tmpdir = tempfile.mkdtemp(prefix="pt_expt_ft_random_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + # Phase 1: train pretrained model + config = _make_config(self.data_dir, model_se_e2_a, numb_steps=1) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + ckpt_path = self._train_pretrained(config, tmpdir) + + # Phase 2: finetune with RANDOM (random fitting) + ft_config = _make_config(self.data_dir, model_se_e2_a, numb_steps=1) + ft_config = update_deepmd_input(ft_config, warning=False) + ft_config = normalize(ft_config) + ft_config["model"], finetune_links = get_finetune_rules( + ckpt_path, + ft_config["model"], + model_branch="RANDOM", + ) + + # Verify finetune rule has random_fitting=True + self.assertTrue(finetune_links["Default"].get_random_fitting()) + + trainer_ft = get_trainer( + ft_config, + finetune_model=ckpt_path, + finetune_links=finetune_links, + ) + + # Load pretrained weights for comparison + pretrained_state = torch.load( + ckpt_path, map_location=DEVICE, weights_only=True + ) + if "model" in pretrained_state: + pretrained_state = pretrained_state["model"] + pretrained_model = get_model( + pretrained_state["_extra_state"]["model_params"] + ).to(DEVICE) + pretrained_wrapper = ModelWrapper(pretrained_model) + pretrained_wrapper.load_state_dict(pretrained_state) + + # Descriptor weights should match; fitting should NOT + ft_state = trainer_ft.wrapper.state_dict() + pre_state = pretrained_wrapper.state_dict() + self._assert_inherited_weights_match( + ft_state, pre_state, random_fitting=True + ) + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_finetune_from_pte(self) -> None: + """Train -> freeze to .pte -> finetune from .pte -> verify checkpoint.""" + from deepmd.pt_expt.entrypoints.main import ( + freeze, + main, + ) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_ft_pte_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + # Phase 1: train pretrained model + config = _make_config(self.data_dir, model_se_e2_a, numb_steps=1) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + ckpt_path = self._train_pretrained(config, tmpdir) + + # Phase 2: freeze to .pte + pte_path = os.path.join(tmpdir, "frozen.pte") + freeze(model=ckpt_path, output=pte_path) + self.assertTrue(os.path.exists(pte_path)) + + # Phase 3: finetune from .pte via CLI (lr=0 so weights stay unchanged) + ft_config = _make_config(self.data_dir, model_se_e2_a, numb_steps=1) + ft_config["learning_rate"]["start_lr"] = 1e-30 + ft_config["learning_rate"]["stop_lr"] = 1e-30 + ft_config_file = os.path.join(tmpdir, "finetune_input.json") + with open(ft_config_file, "w") as f: + json.dump(ft_config, f) + + main( + [ + "train", + ft_config_file, + "--finetune", + pte_path, + "--skip-neighbor-stat", + ] + ) + + # Verify new checkpoint exists + ft_ckpt = os.path.join(tmpdir, "model.ckpt.pt") + self.assertTrue(os.path.exists(ft_ckpt), "Finetune checkpoint not found") + + # Load finetuned model and verify it's valid + ft_state = torch.load(ft_ckpt, map_location=DEVICE, weights_only=True) + ft_model_state = ft_state["model"] if "model" in ft_state else ft_state + self.assertIn("_extra_state", ft_model_state) + + # Load pretrained from .pt for weight comparison + pre_state = torch.load(ckpt_path, map_location=DEVICE, weights_only=True) + pre_model_state = pre_state["model"] if "model" in pre_state else pre_state + + # Inherited weights must match pretrained + self._assert_inherited_weights_match( + ft_model_state, pre_model_state, random_fitting=False + ) + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + def test_finetune_from_pte_use_pretrain_script(self) -> None: + """Train -> freeze to .pte -> finetune with --use-pretrain-script.""" + from deepmd.pt_expt.entrypoints.main import ( + freeze, + main, + ) + + tmpdir = tempfile.mkdtemp(prefix="pt_expt_ft_pte_ups_") + old_cwd = os.getcwd() + os.chdir(tmpdir) + try: + # Phase 1: train pretrained model + config = _make_config(self.data_dir, model_se_e2_a, numb_steps=1) + config = update_deepmd_input(config, warning=False) + config = normalize(config) + ckpt_path = self._train_pretrained(config, tmpdir) + + # Phase 2: freeze to .pte (embeds model_params) + pte_path = os.path.join(tmpdir, "frozen.pte") + freeze(model=ckpt_path, output=pte_path) + + # Phase 3: finetune from .pte with --use-pretrain-script + ft_model_params = deepcopy(model_se_e2_a) + ft_model_params["descriptor"]["neuron"] = [4, 8] # different + ft_config = _make_config(self.data_dir, ft_model_params, numb_steps=1) + ft_config_file = os.path.join(tmpdir, "finetune_input.json") + with open(ft_config_file, "w") as f: + json.dump(ft_config, f) + + main( + [ + "train", + ft_config_file, + "--finetune", + pte_path, + "--use-pretrain-script", + "--skip-neighbor-stat", + ] + ) + + # Verify the output config was updated from pretrained + with open(os.path.join(tmpdir, "out.json")) as f: + output_config = json.load(f) + # Descriptor neuron should be from pretrained, not from ft_config + self.assertEqual( + output_config["model"]["descriptor"]["neuron"], + model_se_e2_a["descriptor"]["neuron"], + ) + finally: + os.chdir(old_cwd) + shutil.rmtree(tmpdir, ignore_errors=True) + + +if __name__ == "__main__": + unittest.main()