diff --git a/deepmd/dpmodel/model/spin_model.py b/deepmd/dpmodel/model/spin_model.py index 510ca5dce2..eda8318abc 100644 --- a/deepmd/dpmodel/model/spin_model.py +++ b/deepmd/dpmodel/model/spin_model.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import functools +from collections.abc import ( + Callable, +) from copy import ( deepcopy, ) @@ -332,6 +336,88 @@ def model_output_def(self) -> ModelOutputDef: backbone_model_atomic_output_def[var_name].magnetic = True return ModelOutputDef(backbone_model_atomic_output_def) + def _get_spin_sampled_func( + self, sampled_func: Callable[[], list[dict]] + ) -> Callable[[], list[dict]]: + """Get a spin-aware sampled function that transforms spin data for the backbone model. + + Parameters + ---------- + sampled_func + A callable that returns a list of data dicts containing 'coord', 'atype', 'spin', etc. + + Returns + ------- + Callable + A cached callable that returns spin-preprocessed data dicts. + """ + + @functools.lru_cache + def spin_sampled_func() -> list[dict]: + sampled = sampled_func() + spin_sampled = [] + for sys in sampled: + coord_updated, atype_updated = self.process_spin_input( + sys["coord"], sys["atype"], sys["spin"] + ) + tmp_dict = { + "coord": coord_updated, + "atype": atype_updated, + } + if "natoms" in sys: + natoms = sys["natoms"] + tmp_dict["natoms"] = np.concatenate( + [2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], axis=-1 + ) + for item_key in sys.keys(): + if item_key not in ["coord", "atype", "spin", "natoms"]: + tmp_dict[item_key] = sys[item_key] + spin_sampled.append(tmp_dict) + return spin_sampled + + return self.backbone_model.atomic_model._make_wrapped_sampler(spin_sampled_func) + + def change_out_bias( + self, + merged: Callable[[], list[dict]] | list[dict], + bias_adjust_mode: str = "change-by-statistic", + ) -> None: + """Change the output bias of atomic model according to the input data and the pretrained model. 
+ + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `np.ndarray` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + bias_adjust_mode : str + The mode for changing output bias : ['change-by-statistic', 'set-by-statistic'] + 'change-by-statistic' : perform predictions on labels of target dataset, + and do least square on the errors to obtain the target shift as bias. + 'set-by-statistic' : directly use the statistic output bias in the target dataset. + """ + spin_sampled_func = self._get_spin_sampled_func( + merged if callable(merged) else lambda: merged + ) + self.backbone_model.change_out_bias( + spin_sampled_func, + bias_adjust_mode=bias_adjust_mode, + ) + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat: Any = None + ) -> None: + """Change the type-related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated according to `model_with_new_type_stat` for these new types. 
+ """ + type_map_with_spin = type_map + [item + "_spin" for item in type_map] + self.backbone_model.change_type_map( + type_map_with_spin, model_with_new_type_stat + ) + def __getattr__(self, name: str) -> Any: """Get attribute from the wrapped model.""" if "backbone_model" not in self.__dict__: diff --git a/deepmd/pt/model/atomic_model/base_atomic_model.py b/deepmd/pt/model/atomic_model/base_atomic_model.py index 596c4c07fe..bfc67cf82b 100644 --- a/deepmd/pt/model/atomic_model/base_atomic_model.py +++ b/deepmd/pt/model/atomic_model/base_atomic_model.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import functools import logging from collections.abc import ( Callable, @@ -184,6 +185,58 @@ def has_default_fparam(self) -> bool: """Check if the model has default frame parameters.""" return False + def get_default_fparam(self) -> torch.Tensor | None: + """Get the default frame parameters.""" + return None + + def _make_wrapped_sampler( + self, + sampled_func: Callable[[], list[dict]], + ) -> Callable[[], list[dict]]: + """Wrap the sampled function with exclusion types and default fparam. + + The returned callable is cached so that the sampling (which may be + expensive) is performed at most once. + + Parameters + ---------- + sampled_func + The lazy sampled function to get data frames from different data + systems. + + Returns + ------- + Callable[[], list[dict]] + A cached wrapper around *sampled_func* that additionally sets + ``pair_exclude_types``, ``atom_exclude_types`` and default + ``fparam`` on every sample dict when applicable. 
+ """ + + @functools.lru_cache + def wrapped_sampler() -> list[dict]: + sampled = sampled_func() + if self.pair_excl is not None: + pair_exclude_types = self.pair_excl.get_exclude_types() + for sample in sampled: + sample["pair_exclude_types"] = list(pair_exclude_types) + if self.atom_excl is not None: + atom_exclude_types = self.atom_excl.get_exclude_types() + for sample in sampled: + sample["atom_exclude_types"] = list(atom_exclude_types) + if ( + "find_fparam" not in sampled[0] + and "fparam" not in sampled[0] + and self.has_default_fparam() + ): + default_fparam = self.get_default_fparam() + if default_fparam is not None: + for sample in sampled: + nframe = sample["atype"].shape[0] + sample["fparam"] = default_fparam.repeat(nframe, 1) + return sampled + + return wrapped_sampler + def reinit_atom_exclude( self, exclude_types: list[int] = [], diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 3fd501df13..84cb158ca2 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import functools import logging from collections.abc import ( Callable, @@ -329,28 +328,7 @@ def compute_or_load_stat( # should not share the same parameters stat_file_path /= " ".join(self.type_map) - @functools.lru_cache - def wrapped_sampler() -> list[dict]: - sampled = sampled_func() - if self.pair_excl is not None: - pair_exclude_types = self.pair_excl.get_exclude_types() - for sample in sampled: - sample["pair_exclude_types"] = list(pair_exclude_types) - if self.atom_excl is not None: - atom_exclude_types = self.atom_excl.get_exclude_types() - for sample in sampled: - sample["atom_exclude_types"] = list(atom_exclude_types) - if ( - "find_fparam" not in sampled[0] - and "fparam" not in sampled[0] - and self.has_default_fparam() - ): - default_fparam = self.get_default_fparam() - for sample in sampled: - 
nframe = sample["atype"].shape[0] - sample["fparam"] = default_fparam.repeat(nframe, 1) - return sampled - + wrapped_sampler = self._make_wrapped_sampler(sampled_func) self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path) self.compute_fitting_input_stat(wrapped_sampler, stat_file_path) if compute_or_load_out_stat: diff --git a/deepmd/pt/model/atomic_model/linear_atomic_model.py b/deepmd/pt/model/atomic_model/linear_atomic_model.py index b00393c0ff..55090422be 100644 --- a/deepmd/pt/model/atomic_model/linear_atomic_model.py +++ b/deepmd/pt/model/atomic_model/linear_atomic_model.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import functools from collections.abc import ( Callable, ) @@ -518,19 +517,7 @@ def compute_or_load_stat( # should not share the same parameters stat_file_path /= " ".join(self.type_map) - @functools.lru_cache - def wrapped_sampler() -> list[dict[str, Any]]: - sampled = sampled_func() - if self.pair_excl is not None: - pair_exclude_types = self.pair_excl.get_exclude_types() - for sample in sampled: - sample["pair_exclude_types"] = list(pair_exclude_types) - if self.atom_excl is not None: - atom_exclude_types = self.atom_excl.get_exclude_types() - for sample in sampled: - sample["atom_exclude_types"] = list(atom_exclude_types) - return sampled - + wrapped_sampler = self._make_wrapped_sampler(sampled_func) self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) diff --git a/deepmd/pt/model/model/spin_model.py b/deepmd/pt/model/model/spin_model.py index 6d864c3205..cc6404675e 100644 --- a/deepmd/pt/model/model/spin_model.py +++ b/deepmd/pt/model/model/spin_model.py @@ -388,6 +388,75 @@ def __getattr__(self, name: str) -> Any: else: return getattr(self.backbone_model, name) + def _get_spin_sampled_func( + self, sampled_func: Callable[[], list[dict]] + ) -> Callable[[], list[dict]]: + @functools.lru_cache + def spin_sampled_func() -> list[dict]: + sampled = sampled_func() + spin_sampled = [] + for sys in 
sampled: + coord_updated, atype_updated, _ = self.process_spin_input( + sys["coord"], sys["atype"], sys["spin"] + ) + tmp_dict = { + "coord": coord_updated, + "atype": atype_updated, + } + if "natoms" in sys: + natoms = sys["natoms"] + tmp_dict["natoms"] = torch.cat( + [2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], dim=-1 + ) + for item_key in sys.keys(): + if item_key not in ["coord", "atype", "spin", "natoms"]: + tmp_dict[item_key] = sys[item_key] + spin_sampled.append(tmp_dict) + return spin_sampled + + return self.backbone_model.atomic_model._make_wrapped_sampler(spin_sampled_func) + + def change_out_bias( + self, + merged: Callable[[], list[dict]] | list[dict], + bias_adjust_mode: str = "change-by-statistic", + ) -> None: + """Change the output bias of atomic model according to the input data and the pretrained model. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + bias_adjust_mode : str + The mode for changing output bias : ['change-by-statistic', 'set-by-statistic'] + 'change-by-statistic' : perform predictions on labels of target dataset, + and do least square on the errors to obtain the target shift as bias. + 'set-by-statistic' : directly use the statistic output bias in the target dataset. 
+ """ + spin_sampled_func = self._get_spin_sampled_func( + merged if callable(merged) else lambda: merged + ) + self.backbone_model.change_out_bias( + spin_sampled_func, + bias_adjust_mode=bias_adjust_mode, + ) + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat: Any = None + ) -> None: + """Change the type-related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated according to `model_with_new_type_stat` for these new types. + """ + type_map_with_spin = type_map + [item + "_spin" for item in type_map] + self.backbone_model.change_type_map( + type_map_with_spin, model_with_new_type_stat + ) + def compute_or_load_stat( self, sampled_func: Callable[[], list[dict[str, Any]]], diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index c823ade109..c13846cbc9 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -1833,6 +1833,8 @@ def model_change_out_bias( model_type_map = _model.get_type_map() log.info( - f"Change output bias of {model_type_map!s} from {to_numpy_array(old_bias).reshape(-1)!s} to {to_numpy_array(new_bias).reshape(-1)!s}." + f"Change output bias of {model_type_map!s} " + f"from {to_numpy_array(old_bias).reshape(-1)[: len(model_type_map)]!s} " + f"to {to_numpy_array(new_bias).reshape(-1)[: len(model_type_map)]!s}." 
) return _model diff --git a/source/tests/common/dpmodel/test_finetune_spin.py b/source/tests/common/dpmodel/test_finetune_spin.py new file mode 100644 index 0000000000..b98f93cb39 --- /dev/null +++ b/source/tests/common/dpmodel/test_finetune_spin.py @@ -0,0 +1,266 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for dpmodel spin model finetune: _get_spin_sampled_func, change_out_bias, change_type_map.""" + +import copy +import unittest + +import numpy as np + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.dpmodel.fitting import ( + InvarFitting, +) +from deepmd.dpmodel.model.spin_model import ( + SpinModel, +) +from deepmd.utils.spin import ( + Spin, +) + +from ...seed import ( + GLOBAL_SEED, +) + + +def _make_spin_model( + type_map: list[str], + use_spin: list[bool], + virtual_scale: list[float], + sel: list[int], + rcut: float = 4.0, + rcut_smth: float = 0.5, + numb_fparam: int = 0, + default_fparam: list[float] | None = None, +) -> SpinModel: + """Create a dpmodel SpinModel for testing.""" + # The backbone model sees both real and virtual types + ntypes_real = len(type_map) + type_map_backbone = type_map + [t + "_spin" for t in type_map] + # sel needs to be doubled for virtual types + sel_backbone = sel + sel + + descriptor = DescrptSeA( + rcut=rcut, + rcut_smth=rcut_smth, + sel=sel_backbone, + seed=GLOBAL_SEED, + ) + fitting = InvarFitting( + "energy", + ntypes_real * 2, # backbone sees real + virtual types + descriptor.get_dim_out(), + 1, + mixed_types=descriptor.mixed_types(), + seed=GLOBAL_SEED, + numb_fparam=numb_fparam, + default_fparam=default_fparam, + ) + + from deepmd.dpmodel.atomic_model.dp_atomic_model import ( + DPAtomicModel, + ) + from deepmd.dpmodel.common import ( + NativeOP, + ) + from deepmd.dpmodel.model.base_model import ( + BaseModel, + ) + from deepmd.dpmodel.model.make_model import ( + make_model, + ) + + CM = make_model(DPAtomicModel, T_Bases=(NativeOP, BaseModel)) + backbone = CM(descriptor, fitting, 
type_map=type_map_backbone) + + spin = Spin(use_spin=use_spin, virtual_scale=virtual_scale) + return SpinModel(backbone_model=backbone, spin=spin) + + +def _make_sample_data( + nframes: int, + nloc: int, + ntypes: int, + rng: np.random.RandomState, +) -> list[dict]: + """Create fake sample data for testing.""" + atype = rng.randint(0, ntypes, size=(nframes, nloc)).astype(np.int64) + coord = rng.randn(nframes, nloc, 3).astype(np.float64) + spin = 0.5 * rng.randn(nframes, nloc, 3).astype(np.float64) + energy = rng.randn(nframes, 1).astype(np.float64) + natoms_count = np.zeros((nframes, 2 + ntypes), dtype=np.int32) + natoms_count[:, 0] = nloc + natoms_count[:, 1] = nloc + for i in range(nframes): + for t in range(ntypes): + natoms_count[i, 2 + t] = np.sum(atype[i] == t) + return [ + { + "coord": coord, + "atype": atype, + "spin": spin, + "energy": energy, + "natoms": natoms_count, + "find_energy": np.float32(1.0), + "find_fparam": np.float32(0.0), + } + ] + + +class TestSpinModelGetSpinSampledFunc(unittest.TestCase): + """Test _get_spin_sampled_func correctly transforms spin data (dpmodel).""" + + def setUp(self) -> None: + self.type_map = ["Ni", "O"] + self.model = _make_spin_model( + type_map=self.type_map, + use_spin=[True, False], + virtual_scale=[0.3140], + sel=[10, 10], + ) + self.rng = np.random.RandomState(GLOBAL_SEED) + + def test_spin_data_transformation(self) -> None: + nframes, nloc, ntypes = 2, 6, 2 + sampled = _make_sample_data(nframes, nloc, ntypes, self.rng) + + def sampled_func() -> list[dict]: + return sampled + + spin_sampled_func = self.model._get_spin_sampled_func(sampled_func) + spin_sampled = spin_sampled_func() + + for i, sys_data in enumerate(spin_sampled): + original = sampled[i] + # coord should be doubled (real + virtual) + assert sys_data["coord"].shape[1] == 2 * nloc + # atype should be doubled + assert sys_data["atype"].shape[1] == 2 * nloc + # spin should not be in the transformed data + assert "spin" not in sys_data + # energy should 
be preserved + np.testing.assert_array_equal(sys_data["energy"], original["energy"]) + # natoms should be transformed correctly + natoms = original["natoms"] + expected_natoms = np.concatenate( + [2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], axis=-1 + ) + np.testing.assert_array_equal(sys_data["natoms"], expected_natoms) + + def test_coord_values(self) -> None: + """Verify virtual coordinates = real + spin * virtual_scale.""" + nframes, nloc, ntypes = 1, 4, 2 + sampled = _make_sample_data(nframes, nloc, ntypes, self.rng) + + spin_sampled = self.model._get_spin_sampled_func(lambda: sampled)() + + original = sampled[0] + transformed = spin_sampled[0] + coord = original["coord"] # (nframes, nloc, 3) + spin = original["spin"] + atype = original["atype"] + virtual_scale_mask = self.model.virtual_scale_mask + + # Real coords should be unchanged + np.testing.assert_array_equal(transformed["coord"][:, :nloc], coord) + # Virtual coords = real + spin * scale + expected_virtual = coord + spin * virtual_scale_mask[atype].reshape( + nframes, nloc, 1 + ) + np.testing.assert_allclose( + transformed["coord"][:, nloc:], expected_virtual, atol=1e-12 + ) + + +class TestSpinModelChangeOutBias(unittest.TestCase): + """Test change_out_bias for dpmodel SpinModel.""" + + def setUp(self) -> None: + self.type_map = ["Ni", "O"] + self.model = _make_spin_model( + type_map=self.type_map, + use_spin=[True, False], + virtual_scale=[0.3140], + sel=[10, 10], + ) + self.rng = np.random.RandomState(GLOBAL_SEED) + + def test_change_out_bias_runs(self) -> None: + """Test that change_out_bias does not raise with spin model.""" + sampled = _make_sample_data(2, 6, 2, self.rng) + old_bias = copy.deepcopy(self.model.backbone_model.get_out_bias()) + self.model.change_out_bias(sampled, bias_adjust_mode="set-by-statistic") + new_bias = self.model.backbone_model.get_out_bias() + # Bias should have changed + assert not np.allclose(old_bias, new_bias), "Bias was not updated" + + def 
test_change_out_bias_with_callable(self) -> None: + """Test change_out_bias with a callable (lazy sampled func).""" + sampled = _make_sample_data(2, 6, 2, self.rng) + old_bias = copy.deepcopy(self.model.backbone_model.get_out_bias()) + self.model.change_out_bias(lambda: sampled, bias_adjust_mode="set-by-statistic") + new_bias = self.model.backbone_model.get_out_bias() + assert not np.allclose(old_bias, new_bias), "Bias was not updated" + + +class TestSpinModelChangeTypeMap(unittest.TestCase): + """Test change_type_map for dpmodel SpinModel.""" + + def setUp(self) -> None: + self.type_map = ["Ni", "O"] + self.model = _make_spin_model( + type_map=self.type_map, + use_spin=[True, False], + virtual_scale=[0.3140], + sel=[10, 10], + ) + + def test_change_type_map(self) -> None: + """Test that change_type_map delegates to backbone with _spin suffixes. + + DescrptSeA does not support change_type_map, so we verify that the + SpinModel correctly constructs the suffixed type map and delegates + to the backbone (which raises NotImplementedError from se_e2_a). + The full change_type_map workflow is tested via PT backend with + mixed-types descriptors. 
+ """ + new_type_map = ["O", "Ni"] + with self.assertRaises(NotImplementedError): + self.model.change_type_map(new_type_map) + + +class TestSpinModelWithDefaultFparam(unittest.TestCase): + """Test _get_spin_sampled_func injects default fparam (dpmodel).""" + + def setUp(self) -> None: + self.type_map = ["Ni", "O"] + self.default_fparam = [0.5, 1.0] + self.model = _make_spin_model( + type_map=self.type_map, + use_spin=[True, False], + virtual_scale=[0.3140], + sel=[10, 10], + numb_fparam=2, + default_fparam=self.default_fparam, + ) + self.rng = np.random.RandomState(GLOBAL_SEED) + + def test_fparam_injected(self) -> None: + """Test that _get_spin_sampled_func + _make_wrapped_sampler injects fparam.""" + sampled = _make_sample_data(2, 6, 2, self.rng) + assert "fparam" not in sampled[0] + + spin_sampled = self.model._get_spin_sampled_func(lambda: sampled)() + + for sys_data in spin_sampled: + assert "fparam" in sys_data, ( + "_make_wrapped_sampler did not inject default fparam" + ) + nframe = sys_data["atype"].shape[0] + assert sys_data["fparam"].shape == (nframe, 2) + np.testing.assert_allclose(sys_data["fparam"][0], self.default_fparam) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/test_finetune_spin.py b/source/tests/pt/test_finetune_spin.py new file mode 100644 index 0000000000..b1bb770ac7 --- /dev/null +++ b/source/tests/pt/test_finetune_spin.py @@ -0,0 +1,556 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for spin model finetune: change_out_bias, change_type_map, and e2e finetune.""" + +import json +import os +import shutil +import tempfile +import unittest +from copy import ( + deepcopy, +) +from pathlib import ( + Path, +) + +import numpy as np +import torch + +from deepmd.infer.deep_eval import ( + DeepEval, +) +from deepmd.pt.entrypoints.main import ( + get_trainer, +) +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.dataloader import ( + 
DpLoaderSet, +) +from deepmd.pt.utils.finetune import ( + get_finetune_rules, +) +from deepmd.pt.utils.stat import ( + make_stat_input, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) +from deepmd.utils.data import ( + DataRequirementItem, +) + +from .model.test_permutation import ( + model_spin, +) + +spin_data_requirement = [ + DataRequirementItem( + "energy", + ndof=1, + atomic=False, + must=False, + high_prec=True, + ), + DataRequirementItem( + "force", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "force_mag", + ndof=3, + atomic=True, + must=False, + high_prec=False, + ), + DataRequirementItem( + "virial", + ndof=9, + atomic=False, + must=False, + high_prec=False, + ), + DataRequirementItem( + "spin", + ndof=3, + atomic=True, + must=True, + high_prec=False, + ), +] + + +class SpinFinetuneTest: + """Mixin test class for spin model finetune operations.""" + + def test_change_out_bias(self) -> None: + """Test that change_out_bias correctly adjusts energy bias for spin model.""" + # get data + data = DpLoaderSet( + self.data_file, + batch_size=1, + type_map=self.config["model"]["type_map"], + ) + data.add_data_requirement(spin_data_requirement) + sampled = make_stat_input( + data.systems, + data.dataloaders, + nbatches=1, + ) + + # get model + model = get_model(self.config["model"]).to(env.DEVICE) + + # set random bias + atomic_model = model.backbone_model.atomic_model + atomic_model["out_bias"] = torch.rand_like(atomic_model["out_bias"]) + energy_bias_before = to_numpy_array(atomic_model["out_bias"])[0] + + # prepare original model for prediction + dp = torch.jit.script(model) + tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth") + torch.jit.save(dp, tmp_model.name) + dp = DeepEval(tmp_model.name) + + origin_type_map = self.config["model"]["type_map"][:2] + full_type_map = self.config["model"]["type_map"] + + # change energy bias via spin model's change_out_bias + model.change_out_bias( + 
sampled, + bias_adjust_mode="change-by-statistic", + ) + energy_bias_after = to_numpy_array(atomic_model["out_bias"])[0] + + # get ground-truth energy bias change via least squares + sorter = np.argsort(full_type_map) + idx_type_map = sorter[ + np.searchsorted(full_type_map, origin_type_map, sorter=sorter) + ] + ntest = 1 + atom_nums = np.tile( + np.bincount(to_numpy_array(sampled[0]["atype"][0]))[idx_type_map], + (ntest, 1), + ) + energy = dp.eval( + to_numpy_array(sampled[0]["coord"][:ntest]), + to_numpy_array(sampled[0]["box"][:ntest]), + to_numpy_array(sampled[0]["atype"][0]), + spin=to_numpy_array(sampled[0]["spin"][:ntest]), + )[0] + + energy_diff = to_numpy_array(sampled[0]["energy"][:ntest]) - energy + finetune_shift = ( + energy_bias_after[idx_type_map] - energy_bias_before[idx_type_map] + ).ravel() + ground_truth_shift = np.linalg.lstsq(atom_nums, energy_diff, rcond=None)[ + 0 + ].reshape(-1) + + # check values + np.testing.assert_almost_equal(finetune_shift, ground_truth_shift, decimal=10) + os.unlink(tmp_model.name) + + def test__get_spin_sampled_func(self) -> None: + """Test that _get_spin_sampled_func correctly transforms spin data.""" + # get data + data = DpLoaderSet( + self.data_file, + batch_size=1, + type_map=self.config["model"]["type_map"], + ) + data.add_data_requirement(spin_data_requirement) + sampled = make_stat_input( + data.systems, + data.dataloaders, + nbatches=1, + ) + + model = get_model(self.config["model"]).to(env.DEVICE) + + # Create a sampled_func callable + def sampled_func(): + return sampled + + spin_sampled_func = model._get_spin_sampled_func(sampled_func) + spin_sampled = spin_sampled_func() + + # Verify the transformed data + for i, sys_data in enumerate(spin_sampled): + original = sampled[i] + nloc = original["atype"].shape[1] + # coord should be doubled (real + virtual) + assert sys_data["coord"].shape[1] == 2 * nloc + # atype should be doubled + assert sys_data["atype"].shape[1] == 2 * nloc + # spin should not be in the 
transformed data + assert "spin" not in sys_data + # energy should be preserved + if "energy" in original: + torch.testing.assert_close(sys_data["energy"], original["energy"]) + # natoms should be transformed correctly + if "natoms" in original: + natoms = original["natoms"] + expected_natoms = torch.cat( + [2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], dim=-1 + ) + torch.testing.assert_close(sys_data["natoms"], expected_natoms) + + def tearDown(self) -> None: + for f in os.listdir("."): + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out"]: + os.remove(f) + if f in ["stat_files"]: + shutil.rmtree(f) + + +class TestSpinFinetuneSeA(SpinFinetuneTest, unittest.TestCase): + def setUp(self) -> None: + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.data_file = [str(Path(__file__).parent / "NiO/data/single")] + self.config["training"]["training_data"]["systems"] = self.data_file + self.config["training"]["validation_data"]["systems"] = self.data_file + self.config["model"] = deepcopy(model_spin) + self.config["model"]["type_map"] = ["Ni", "O"] + self.config["model"]["descriptor"] = { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 0.50, + "rcut": 4.00, + "neuron": [25, 50, 100], + "resnet_dt": False, + "axis_neuron": 16, + "seed": 1, + } + self.config["model"]["fitting_net"] = { + "neuron": [24, 24, 24], + "resnet_dt": True, + "seed": 1, + } + self.config["model"]["spin"] = { + "use_spin": [True, False], + "virtual_scale": [0.3140], + } + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.config["loss"] = { + "type": "ener_spin", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_fr": 1000, + "limit_pref_fr": 1, + "start_pref_fm": 1000, + "limit_pref_fm": 1, + } + self.mixed_types = False + + +class SpinFinetuneE2ETest: + """End-to-end test mixin for spin model finetune workflow. 
+ + Tests the full workflow: train from scratch -> save -> finetune with change_out_bias. + """ + + def test_finetune_e2e(self) -> None: + """Test the full finetune workflow for a spin model.""" + # Step 1: Train from scratch + config_pretrain = deepcopy(self.config) + config_pretrain["training"]["save_ckpt"] = "model" + trainer = get_trainer(config_pretrain) + trainer.run() + finetune_model = ( + config_pretrain["training"].get("save_ckpt", "model.ckpt") + ".pt" + ) + + # Step 2: Finetune with the same type_map (should work after the fix) + config_finetune = deepcopy(self.config) + config_finetune["model"], finetune_links = get_finetune_rules( + finetune_model, + config_finetune["model"], + ) + # This should NOT raise an error after the fix + trainer_finetune = get_trainer( + config_finetune, + finetune_model=finetune_model, + finetune_links=finetune_links, + ) + + # Verify the model is functional after finetune loading + data = DpLoaderSet( + self.data_file, + batch_size=1, + type_map=self.config["model"]["type_map"], + ) + data.add_data_requirement(spin_data_requirement) + sampled = make_stat_input( + data.systems, + data.dataloaders, + nbatches=1, + ) + # Run inference to verify model works + ntest = 1 + result = trainer_finetune.model( + sampled[0]["coord"][:ntest], + sampled[0]["atype"][:ntest], + spin=sampled[0]["spin"][:ntest], + box=sampled[0]["box"][:ntest], + ) + # Basic checks - model should produce valid outputs + assert "energy" in result + assert "force" in result + assert result["energy"].shape == (ntest, 1) + nloc = sampled[0]["atype"].shape[1] + assert result["force"].shape == (ntest, nloc, 3) + + def test_finetune_change_type_map(self) -> None: + """Test change_type_map for spin model. + + Only runs for mixed_types descriptors since se_e2_a + does not support type map extension. 
+ """ + if not self.mixed_types: + return + # Train a pretrained model + config_pretrain = deepcopy(self.config) + config_pretrain["training"]["save_ckpt"] = "model" + trainer = get_trainer(config_pretrain) + trainer.run() + finetune_model = ( + config_pretrain["training"].get("save_ckpt", "model.ckpt") + ".pt" + ) + + # Finetune with a new type_map that has extra types + config_finetune = deepcopy(self.config) + config_finetune["model"]["type_map"] = self.config["model"]["type_map"] + ["Fe"] + # Extend spin config for the new type + config_finetune["model"]["spin"]["use_spin"] = self.config["model"]["spin"][ + "use_spin" + ] + [False] + config_finetune["model"], finetune_links = get_finetune_rules( + finetune_model, + config_finetune["model"], + ) + # This should NOT raise an error: the spin model should handle type_map change + trainer_finetune = get_trainer( + config_finetune, + finetune_model=finetune_model, + finetune_links=finetune_links, + ) + + # Verify the new type map is applied correctly + new_type_map = trainer_finetune.model.get_type_map() + assert "Fe" in new_type_map + + def tearDown(self) -> None: + for f in os.listdir("."): + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out", "checkpoint"]: + os.remove(f) + if f in ["stat_files"]: + shutil.rmtree(f) + + +class TestSpinFinetuneE2ESeA(SpinFinetuneE2ETest, unittest.TestCase): + def setUp(self) -> None: + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.data_file = [str(Path(__file__).parent / "NiO/data/single")] + self.config["training"]["training_data"]["systems"] = self.data_file + self.config["training"]["validation_data"]["systems"] = self.data_file + self.config["model"] = { + "type_map": ["Ni", "O"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 0.50, + "rcut": 4.00, + "neuron": [25, 50, 100], + "resnet_dt": False, + "axis_neuron": 16, + "seed": 1, + }, 
+ "fitting_net": { + "neuron": [24, 24, 24], + "resnet_dt": True, + "seed": 1, + }, + "spin": { + "use_spin": [True, False], + "virtual_scale": [0.3140], + }, + } + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.config["loss"] = { + "type": "ener_spin", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_fr": 1000, + "limit_pref_fr": 1, + "start_pref_fm": 1000, + "limit_pref_fm": 1, + } + self.mixed_types = False + + +class TestSpinFinetuneWithDefaultFparam(unittest.TestCase): + """Test spin model finetune with default_fparam enabled. + + Verifies that _make_wrapped_sampler correctly injects default fparam + into sampled data when spin preprocessing is also active. + """ + + def setUp(self) -> None: + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.data_file = [str(Path(__file__).parent / "NiO/data/single")] + self.config["training"]["training_data"]["systems"] = self.data_file + self.config["training"]["validation_data"]["systems"] = self.data_file + self.config["model"] = { + "type_map": ["Ni", "O"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 0.50, + "rcut": 4.00, + "neuron": [25, 50, 100], + "resnet_dt": False, + "axis_neuron": 16, + "seed": 1, + }, + "fitting_net": { + "neuron": [24, 24, 24], + "resnet_dt": True, + "seed": 1, + "numb_fparam": 2, + "default_fparam": [0.5, 1.0], + }, + "spin": { + "use_spin": [True, False], + "virtual_scale": [0.3140], + }, + } + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.config["loss"] = { + "type": "ener_spin", + "start_pref_e": 0.02, + "limit_pref_e": 1, + "start_pref_fr": 1000, + "limit_pref_fr": 1, + "start_pref_fm": 1000, + "limit_pref_fm": 1, + } + + def test_spin_sampled_func_with_default_fparam(self) -> None: + """Test that _get_spin_sampled_func + _make_wrapped_sampler injects fparam.""" + data = DpLoaderSet( + 
self.data_file, + batch_size=1, + type_map=self.config["model"]["type_map"], + ) + data.add_data_requirement(spin_data_requirement) + sampled = make_stat_input( + data.systems, + data.dataloaders, + nbatches=1, + ) + + model = get_model(self.config["model"]).to(env.DEVICE) + + # Verify model has default_fparam + assert model.backbone_model.atomic_model.has_default_fparam() + + # sampled should NOT have fparam yet + assert "fparam" not in sampled[0] + + def sampled_func(): + return sampled + + # _get_spin_sampled_func chains: spin preprocess -> _make_wrapped_sampler + spin_sampled_func = model._get_spin_sampled_func(sampled_func) + spin_sampled = spin_sampled_func() + + for sys_data in spin_sampled: + # fparam should be injected by _make_wrapped_sampler + assert "fparam" in sys_data, ( + "_make_wrapped_sampler did not inject default fparam" + ) + nframe = sys_data["atype"].shape[0] + assert sys_data["fparam"].shape == (nframe, 2) + # check values match default_fparam + np.testing.assert_allclose( + to_numpy_array(sys_data["fparam"][0]), + [0.5, 1.0], + ) + + def test_finetune_e2e_with_default_fparam(self) -> None: + """Test e2e finetune for spin model with default_fparam.""" + config_pretrain = deepcopy(self.config) + config_pretrain["training"]["save_ckpt"] = "model" + trainer = get_trainer(config_pretrain) + trainer.run() + finetune_model = ( + config_pretrain["training"].get("save_ckpt", "model.ckpt") + ".pt" + ) + + config_finetune = deepcopy(self.config) + config_finetune["model"], finetune_links = get_finetune_rules( + finetune_model, + config_finetune["model"], + ) + # Should not raise an error with spin + default_fparam + trainer_finetune = get_trainer( + config_finetune, + finetune_model=finetune_model, + finetune_links=finetune_links, + ) + + # Verify the model works + data = DpLoaderSet( + self.data_file, + batch_size=1, + type_map=self.config["model"]["type_map"], + ) + data.add_data_requirement(spin_data_requirement) + sampled = make_stat_input( + 
data.systems, + data.dataloaders, + nbatches=1, + ) + ntest = 1 + result = trainer_finetune.model( + sampled[0]["coord"][:ntest], + sampled[0]["atype"][:ntest], + spin=sampled[0]["spin"][:ntest], + box=sampled[0]["box"][:ntest], + ) + assert "energy" in result + assert result["energy"].shape == (ntest, 1) + + def tearDown(self) -> None: + for f in os.listdir("."): + if f.startswith("model") and f.endswith(".pt"): + os.remove(f) + if f in ["lcurve.out", "checkpoint"]: + os.remove(f) + if f in ["stat_files"]: + shutil.rmtree(f) + + +if __name__ == "__main__": + unittest.main()