Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions deepmd/dpmodel/model/spin_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import functools
from collections.abc import (
Callable,
)
from copy import (
deepcopy,
)
Expand Down Expand Up @@ -332,6 +336,88 @@ def model_output_def(self) -> ModelOutputDef:
backbone_model_atomic_output_def[var_name].magnetic = True
return ModelOutputDef(backbone_model_atomic_output_def)

def _get_spin_sampled_func(
self, sampled_func: Callable[[], list[dict]]
) -> Callable[[], list[dict]]:
"""Get a spin-aware sampled function that transforms spin data for the backbone model.

Parameters
----------
sampled_func
A callable that returns a list of data dicts containing 'coord', 'atype', 'spin', etc.

Returns
-------
Callable
A cached callable that returns spin-preprocessed data dicts.
"""

@functools.lru_cache
def spin_sampled_func() -> list[dict]:
sampled = sampled_func()
spin_sampled = []
for sys in sampled:
coord_updated, atype_updated = self.process_spin_input(
sys["coord"], sys["atype"], sys["spin"]
)
tmp_dict = {
"coord": coord_updated,
"atype": atype_updated,
}
if "natoms" in sys:
natoms = sys["natoms"]
tmp_dict["natoms"] = np.concatenate(
[2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], axis=-1
)
for item_key in sys.keys():
if item_key not in ["coord", "atype", "spin", "natoms"]:
tmp_dict[item_key] = sys[item_key]
spin_sampled.append(tmp_dict)
return spin_sampled

return self.backbone_model.atomic_model._make_wrapped_sampler(spin_sampled_func)
Comment thread
iProzd marked this conversation as resolved.

def change_out_bias(
self,
merged: Callable[[], list[dict]] | list[dict],
bias_adjust_mode: str = "change-by-statistic",
) -> None:
"""Change the output bias of atomic model according to the input data and the pretrained model.

Parameters
----------
merged : Union[Callable[[], list[dict]], list[dict]]
- list[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `np.ndarray`
originating from the `i`-th data system.
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
bias_adjust_mode : str
The mode for changing output bias : ['change-by-statistic', 'set-by-statistic']
'change-by-statistic' : perform predictions on labels of target dataset,
and do least square on the errors to obtain the target shift as bias.
'set-by-statistic' : directly use the statistic output bias in the target dataset.
"""
spin_sampled_func = self._get_spin_sampled_func(
merged if callable(merged) else lambda: merged
)
self.backbone_model.change_out_bias(
spin_sampled_func,
bias_adjust_mode=bias_adjust_mode,
)

def change_type_map(
    self, type_map: list[str], model_with_new_type_stat: Any = None
) -> None:
    """Change the type related params to new ones, according to `type_map` and the original one in the model.
    If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
    """
    # Each real type gets a virtual "_spin" companion type in the backbone.
    spin_types = [f"{name}_spin" for name in type_map]
    self.backbone_model.change_type_map(
        type_map + spin_types, model_with_new_type_stat
    )
Comment thread
iProzd marked this conversation as resolved.

def __getattr__(self, name: str) -> Any:
"""Get attribute from the wrapped model."""
if "backbone_model" not in self.__dict__:
Expand Down
53 changes: 53 additions & 0 deletions deepmd/pt/model/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import functools
import logging
from collections.abc import (
Callable,
Expand Down Expand Up @@ -184,6 +185,58 @@ def has_default_fparam(self) -> bool:
"""Check if the model has default frame parameters."""
return False

def get_default_fparam(self) -> torch.Tensor | None:
    """Get the default frame parameters.

    Returns
    -------
    torch.Tensor | None
        The default frame-parameter tensor, or ``None`` when no default
        exists. This base implementation always returns ``None``;
        subclasses with default fparam support are expected to override it.
    """
    return None

def _make_wrapped_sampler(
self,
sampled_func: Callable[[], list[dict]],
) -> Callable[[], list[dict]]:
"""Wrap the sampled function with exclusion types and default fparam.
The returned callable is cached so that the sampling (which may be
expensive) is performed at most once.
Parameters
----------
sampled_func
The lazy sampled function to get data frames from different data
systems.
Returns
-------
Callable[[], list[dict]]
A cached wrapper around *sampled_func* that additionally sets
``pair_exclude_types``, ``atom_exclude_types`` and default
``fparam`` on every sample dict when applicable.
"""

@functools.lru_cache
def wrapped_sampler() -> list[dict]:
sampled = sampled_func()
if self.pair_excl is not None:
pair_exclude_types = self.pair_excl.get_exclude_types()
for sample in sampled:
sample["pair_exclude_types"] = list(pair_exclude_types)
if self.atom_excl is not None:
atom_exclude_types = self.atom_excl.get_exclude_types()
for sample in sampled:
sample["atom_exclude_types"] = list(atom_exclude_types)
if (
"find_fparam" not in sampled[0]
and "fparam" not in sampled[0]
and self.has_default_fparam()
):
default_fparam = self.get_default_fparam()
if default_fparam is not None:
for sample in sampled:
nframe = sample["atype"].shape[0]
sample["fparam"] = default_fparam.repeat(nframe, 1)
Comment thread
iProzd marked this conversation as resolved.
return sampled

return wrapped_sampler

def reinit_atom_exclude(
self,
exclude_types: list[int] = [],
Expand Down
24 changes: 1 addition & 23 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import functools
import logging
from collections.abc import (
Callable,
Expand Down Expand Up @@ -329,28 +328,7 @@ def compute_or_load_stat(
# should not share the same parameters
stat_file_path /= " ".join(self.type_map)

@functools.lru_cache
def wrapped_sampler() -> list[dict]:
sampled = sampled_func()
if self.pair_excl is not None:
pair_exclude_types = self.pair_excl.get_exclude_types()
for sample in sampled:
sample["pair_exclude_types"] = list(pair_exclude_types)
if self.atom_excl is not None:
atom_exclude_types = self.atom_excl.get_exclude_types()
for sample in sampled:
sample["atom_exclude_types"] = list(atom_exclude_types)
if (
"find_fparam" not in sampled[0]
and "fparam" not in sampled[0]
and self.has_default_fparam()
):
default_fparam = self.get_default_fparam()
for sample in sampled:
nframe = sample["atype"].shape[0]
sample["fparam"] = default_fparam.repeat(nframe, 1)
return sampled

wrapped_sampler = self._make_wrapped_sampler(sampled_func)
self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path)
self.compute_fitting_input_stat(wrapped_sampler, stat_file_path)
if compute_or_load_out_stat:
Expand Down
15 changes: 1 addition & 14 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import functools
from collections.abc import (
Callable,
)
Expand Down Expand Up @@ -518,19 +517,7 @@ def compute_or_load_stat(
# should not share the same parameters
stat_file_path /= " ".join(self.type_map)

@functools.lru_cache
def wrapped_sampler() -> list[dict[str, Any]]:
sampled = sampled_func()
if self.pair_excl is not None:
pair_exclude_types = self.pair_excl.get_exclude_types()
for sample in sampled:
sample["pair_exclude_types"] = list(pair_exclude_types)
if self.atom_excl is not None:
atom_exclude_types = self.atom_excl.get_exclude_types()
for sample in sampled:
sample["atom_exclude_types"] = list(atom_exclude_types)
return sampled

wrapped_sampler = self._make_wrapped_sampler(sampled_func)
self.compute_or_load_out_stat(wrapped_sampler, stat_file_path)


Expand Down
69 changes: 69 additions & 0 deletions deepmd/pt/model/model/spin_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,75 @@ def __getattr__(self, name: str) -> Any:
else:
return getattr(self.backbone_model, name)

def _get_spin_sampled_func(
self, sampled_func: Callable[[], list[dict]]
) -> Callable[[], list[dict]]:
@functools.lru_cache
def spin_sampled_func() -> list[dict]:
sampled = sampled_func()
spin_sampled = []
for sys in sampled:
coord_updated, atype_updated, _ = self.process_spin_input(
sys["coord"], sys["atype"], sys["spin"]
)
tmp_dict = {
"coord": coord_updated,
"atype": atype_updated,
}
if "natoms" in sys:
natoms = sys["natoms"]
tmp_dict["natoms"] = torch.cat(
[2 * natoms[:, :2], natoms[:, 2:], natoms[:, 2:]], dim=-1
)
for item_key in sys.keys():
if item_key not in ["coord", "atype", "spin", "natoms"]:
tmp_dict[item_key] = sys[item_key]
spin_sampled.append(tmp_dict)
return spin_sampled

return self.backbone_model.atomic_model._make_wrapped_sampler(spin_sampled_func)

def change_out_bias(
self,
merged: Callable[[], list[dict]] | list[dict],
bias_adjust_mode: str = "change-by-statistic",
) -> None:
"""Change the output bias of atomic model according to the input data and the pretrained model.

Parameters
----------
merged : Union[Callable[[], list[dict]], list[dict]]
- list[dict]: A list of data samples from various data systems.
Each element, `merged[i]`, is a data dictionary containing `keys`: `torch.Tensor`
originating from the `i`-th data system.
- Callable[[], list[dict]]: A lazy function that returns data samples in the above format
only when needed. Since the sampling process can be slow and memory-intensive,
the lazy function helps by only sampling once.
bias_adjust_mode : str
The mode for changing output bias : ['change-by-statistic', 'set-by-statistic']
'change-by-statistic' : perform predictions on labels of target dataset,
and do least square on the errors to obtain the target shift as bias.
'set-by-statistic' : directly use the statistic output bias in the target dataset.
"""
spin_sampled_func = self._get_spin_sampled_func(
merged if callable(merged) else lambda: merged
)
self.backbone_model.change_out_bias(
spin_sampled_func,
bias_adjust_mode=bias_adjust_mode,
)

def change_type_map(
    self, type_map: list[str], model_with_new_type_stat: Any = None
) -> None:
    """Change the type related params to new ones, according to `type_map` and the original one in the model.
    If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types.
    """
    # Each real type gets a virtual "_spin" companion type in the backbone.
    spin_types = [f"{name}_spin" for name in type_map]
    self.backbone_model.change_type_map(
        type_map + spin_types, model_with_new_type_stat
    )
Comment thread
iProzd marked this conversation as resolved.
Comment thread
iProzd marked this conversation as resolved.

def compute_or_load_stat(
self,
sampled_func: Callable[[], list[dict[str, Any]]],
Expand Down
4 changes: 3 additions & 1 deletion deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1833,6 +1833,8 @@ def model_change_out_bias(

model_type_map = _model.get_type_map()
log.info(
f"Change output bias of {model_type_map!s} from {to_numpy_array(old_bias).reshape(-1)!s} to {to_numpy_array(new_bias).reshape(-1)!s}."
f"Change output bias of {model_type_map!s} "
f"from {to_numpy_array(old_bias).reshape(-1)[: len(model_type_map)]!s} "
f"to {to_numpy_array(new_bias).reshape(-1)[: len(model_type_map)]!s}."
Comment thread
iProzd marked this conversation as resolved.
)
return _model
Loading