diff --git a/deepmd/pt/model/atomic_model/__init__.py b/deepmd/pt/model/atomic_model/__init__.py index 4da9bf781b..93451d17c5 100644 --- a/deepmd/pt/model/atomic_model/__init__.py +++ b/deepmd/pt/model/atomic_model/__init__.py @@ -29,10 +29,16 @@ from .energy_atomic_model import ( DPEnergyAtomicModel, ) +from .les_atomic_model import ( + LESEnergyAtomicModel, +) from .linear_atomic_model import ( DPZBLLinearEnergyAtomicModel, LinearEnergyAtomicModel, ) +from .lr_energy_atomic_model import ( + LREnergyAtomicModel, +) from .pairtab_atomic_model import ( PairTabAtomicModel, ) @@ -42,6 +48,9 @@ from .property_atomic_model import ( DPPropertyAtomicModel, ) +from .sog_atomic_model import ( + SOGEnergyAtomicModel, +) __all__ = [ "BaseAtomicModel", @@ -52,6 +61,9 @@ "DPPolarAtomicModel", "DPPropertyAtomicModel", "DPZBLLinearEnergyAtomicModel", + "LESEnergyAtomicModel", + "LREnergyAtomicModel", "LinearEnergyAtomicModel", "PairTabAtomicModel", + "SOGEnergyAtomicModel", ] diff --git a/deepmd/pt/model/atomic_model/les_atomic_model.py b/deepmd/pt/model/atomic_model/les_atomic_model.py new file mode 100644 index 0000000000..e6a19f01fc --- /dev/null +++ b/deepmd/pt/model/atomic_model/les_atomic_model.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt.model.task.les_energy_fitting import ( + LESEnergyFittingNet, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_atomic_model import ( + BaseAtomicModel, +) + + +@BaseAtomicModel.register("energy_les") +class LESEnergyAtomicModel(BaseAtomicModel): + """Energy model using a dedicated LES energy fitting net. + + The LES energy fitting net combines a short-range invariant fitting + and a long-range correction derived from another invariant fitting. 
+ This avoids requiring a user-defined property name in the dataset. + """ + + def __init__( + self, + descriptor: Any, + type_map: list[str], + les_energy_fitting: LESEnergyFittingNet | None = None, + fitting: Any | None = None, + **kwargs: Any, + ) -> None: + super().__init__(type_map, **kwargs) + if les_energy_fitting is None: + les_energy_fitting = fitting + if not isinstance(les_energy_fitting, LESEnergyFittingNet): + raise TypeError( + "les_energy_fitting must be an instance of LESEnergyFittingNet" + ) + + self.descriptor = descriptor + self.fitting_net = les_energy_fitting + # self.les_energy_fitting = self.fitting_net + self.type_map = type_map + self.ntypes = len(type_map) + self.rcut = self.descriptor.get_rcut() + self.sel = self.descriptor.get_sel() + + super().init_out_stat() + + self.enable_eval_descriptor_hook = False + self.enable_eval_fitting_last_layer_hook = False + self.eval_descriptor_list: list[torch.Tensor] = [] + self.eval_fitting_last_layer_list: list[torch.Tensor] = [] + + @torch.jit.export + def fitting_output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + name="latent_charge", + shape=[self.fitting_net.dim_out_lr], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), + ] + ) + + def get_rcut(self) -> float: + return self.rcut + + def get_sel(self) -> list[int]: + return self.sel + + def mixed_types(self) -> bool: + return self.descriptor.mixed_types() + + def has_message_passing(self) -> bool: + return self.descriptor.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + return self.descriptor.need_sorted_nlist_for_lower() + + def set_case_embd(self, case_idx: int) -> None: + self.fitting_net.set_case_embd(case_idx) + + def get_dim_fparam(self) -> int: + return self.fitting_net.get_dim_fparam() + + def has_default_fparam(self) -> bool: 
+ return self.fitting_net.has_default_fparam() + + def get_default_fparam(self) -> torch.Tensor | None: + return self.fitting_net.get_default_fparam() + + def get_dim_aparam(self) -> int: + return self.fitting_net.get_dim_aparam() + + def get_sel_type(self) -> list[int]: + return self.fitting_net.get_sel_type() + + def is_aparam_nall(self) -> bool: + return False + + def set_eval_descriptor_hook(self, enable: bool) -> None: + self.enable_eval_descriptor_hook = enable + self.eval_descriptor_list.clear() + + def eval_descriptor(self) -> torch.Tensor: + return torch.concat(self.eval_descriptor_list) + + def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: + self.enable_eval_fitting_last_layer_hook = enable + self.fitting_net.set_return_middle_output(enable) + self.eval_fitting_last_layer_list.clear() + + def eval_fitting_last_layer(self) -> torch.Tensor: + return torch.concat(self.eval_fitting_last_layer_list) + + def forward_atomic( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + comm_dict: dict[str, torch.Tensor] | None = None, + ) -> dict[str, torch.Tensor]: + nframes, nloc, _ = nlist.shape + atype = extended_atype[:, :nloc] + if self.do_grad_r() or self.do_grad_c(): + extended_coord.requires_grad_(True) + + descriptor_comm_dict = comm_dict + if comm_dict is not None and "send_list" not in comm_dict: + descriptor_comm_dict = None + + descriptor, rot_mat, g2, h2, sw = self.descriptor( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + comm_dict=descriptor_comm_dict, + ) + assert descriptor is not None + if self.enable_eval_descriptor_hook: + self.eval_descriptor_list.append(descriptor.detach()) + + energy_ret = self.fitting_net( + descriptor, + atype, + gr=rot_mat, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + + if self.enable_eval_fitting_last_layer_hook and 
"middle_output" in energy_ret: + self.eval_fitting_last_layer_list.append( + energy_ret["middle_output"].detach() + ) + + ret = { + "energy": energy_ret["energy"], + "latent_charge": energy_ret["latent_charge"], + } + if "middle_output" in energy_ret: + ret["middle_output"] = energy_ret["middle_output"] + return ret + + def apply_out_stat( + self, + ret: dict[str, torch.Tensor], + atype: torch.Tensor, + ) -> dict[str, torch.Tensor]: + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + for kk in self.bias_keys: + ret[kk] = ret[kk] + out_bias[kk][atype] + return ret + + def compute_or_load_stat( + self, + sampled_func: Any, + stat_file_path: Any | None = None, + compute_or_load_out_stat: bool = True, + preset_observed_type: list[str] | None = None, + ) -> None: + if stat_file_path is not None and self.type_map is not None: + stat_file_path /= " ".join(self.type_map) + + def wrapped_sampler() -> list[dict]: + sampled = sampled_func() + if self.pair_excl is not None: + pair_exclude_types = self.pair_excl.get_exclude_types() + for sample in sampled: + sample["pair_exclude_types"] = list(pair_exclude_types) + if self.atom_excl is not None: + atom_exclude_types = self.atom_excl.get_exclude_types() + for sample in sampled: + sample["atom_exclude_types"] = list(atom_exclude_types) + if ( + "find_fparam" not in sampled[0] + and "fparam" not in sampled[0] + and self.has_default_fparam() + ): + default_fparam = self.get_default_fparam() + for sample in sampled: + nframe = sample["atype"].shape[0] + sample["fparam"] = default_fparam.repeat(nframe, 1) + return sampled + + self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path) + self.compute_fitting_input_stat(wrapped_sampler, stat_file_path) + if compute_or_load_out_stat: + self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) + + self._collect_and_set_observed_type( + wrapped_sampler, stat_file_path, preset_observed_type + ) + + def compute_fitting_input_stat( + self, + sample_merged: Any, + 
stat_file_path: Any | None = None, + ) -> None: + self.fitting_net.compute_input_stats( + sample_merged, + protection=self.data_stat_protect, + stat_file_path=stat_file_path, + ) + + def serialize(self) -> dict: + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "@version": 1, + "type": "energy_les", + "type_map": self.type_map, + "descriptor": self.descriptor.serialize(), + "les_energy_fitting": self.fitting_net.serialize(), + } + ) + return dd + + @classmethod + def deserialize(cls, data: dict) -> "LESEnergyAtomicModel": + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + data.pop("type", None) + descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) + les_energy_fitting_obj = LESEnergyFittingNet.deserialize( + data.pop("les_energy_fitting") + ) + obj = cls( + descriptor=descriptor_obj, + les_energy_fitting=les_energy_fitting_obj, + **data, + ) + return obj diff --git a/deepmd/pt/model/atomic_model/lr_energy_atomic_model.py b/deepmd/pt/model/atomic_model/lr_energy_atomic_model.py new file mode 100644 index 0000000000..a21127595c --- /dev/null +++ b/deepmd/pt/model/atomic_model/lr_energy_atomic_model.py @@ -0,0 +1,368 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from collections.abc import ( + Iterable, +) +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt.model.task.ener import ( + EnergyFittingNet, + EnergyFittingNetDirect, + InvarFitting, +) +from deepmd.pt.model.task.property import ( + PropertyFittingNet, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_atomic_model import ( + BaseAtomicModel, +) + + +@BaseAtomicModel.register("energy_lr") +class 
LREnergyAtomicModel(BaseAtomicModel): + """Energy model with an auxiliary property-driven correction. + + This model shares one descriptor with two fitting nets: + - ``energy_fitting`` predicts the primary atomic energy/force term. + - ``property_fitting`` predicts an auxiliary per-atom property ``q``. + The property is then passed through a small trainable correction head + to generate an additive energy (and resulting force) correction. + """ + + def __init__( + self, + descriptor: BaseDescriptor, + energy_fitting: InvarFitting, + property_fitting: PropertyFittingNet, + type_map: list[str], + correction_hidden: Iterable[int] | None = None, + correction_activation: str = "tanh", + **kwargs: Any, + ) -> None: + super().__init__(type_map, **kwargs) + if not ( + isinstance(energy_fitting, EnergyFittingNet) + or isinstance(energy_fitting, EnergyFittingNetDirect) + or isinstance(energy_fitting, InvarFitting) + ): + raise TypeError( + "energy_fitting must be an energy-like InvarFitting for LREnergyAtomicModel" + ) + if not isinstance(property_fitting, PropertyFittingNet): + raise TypeError( + "property_fitting must be an instance of PropertyFittingNet for LREnergyAtomicModel" + ) + + if energy_fitting.get_dim_fparam() != property_fitting.get_dim_fparam(): + raise ValueError( + "energy_fitting and property_fitting must share the same dim_fparam" + ) + if energy_fitting.get_dim_aparam() != property_fitting.get_dim_aparam(): + raise ValueError( + "energy_fitting and property_fitting must share the same dim_aparam" + ) + + self.descriptor = descriptor + self.energy_fitting = energy_fitting + self.property_fitting = property_fitting + self.property_name = property_fitting.var_name + self.type_map = type_map + self.ntypes = len(type_map) + self.rcut = self.descriptor.get_rcut() + self.sel = self.descriptor.get_sel() + self.correction_activation = correction_activation + hidden = ( + list(correction_hidden) + if correction_hidden is not None + else [property_fitting.dim_out] + 
) + self.correction_hidden = hidden + self.correction_head = self._build_correction_head( + property_fitting.dim_out, hidden, correction_activation + ) + super().init_out_stat() + + self.enable_eval_descriptor_hook = False + self.enable_eval_fitting_last_layer_hook = False + self.eval_descriptor_list: list[torch.Tensor] = [] + self.eval_fitting_last_layer_list: list[torch.Tensor] = [] + + @staticmethod + def _build_correction_head( + input_dim: int, hidden: list[int], activation: str + ) -> torch.nn.Module: + layers: list[torch.nn.Module] = [] + last = input_dim + act_factory = {"tanh": torch.nn.Tanh, "relu": torch.nn.ReLU, "gelu": torch.nn.GELU, "sigmoid": torch.nn.Sigmoid, "silu": torch.nn.SiLU}.get(activation.lower(), torch.nn.Tanh) + for width in hidden: + layers.append(torch.nn.Linear(last, width)) + layers.append(act_factory()) + last = width + layers.append(torch.nn.Linear(last, 1)) + return torch.nn.Sequential(*layers) + + def fitting_output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + name=self.property_name, + shape=[self.property_fitting.dim_out], + reducible=True, + r_differentiable=False, + c_differentiable=False, + intensive=self.property_fitting.get_intensive(), + ), + ] + ) + + def get_rcut(self) -> float: + return self.rcut + + def get_sel(self) -> list[int]: + return self.sel + + def mixed_types(self) -> bool: + return self.descriptor.mixed_types() + + def has_message_passing(self) -> bool: + return self.descriptor.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + return self.descriptor.need_sorted_nlist_for_lower() + + def set_case_embd(self, case_idx: int) -> None: + self.energy_fitting.set_case_embd(case_idx) + self.property_fitting.set_case_embd(case_idx) + + def get_dim_fparam(self) -> int: + return self.energy_fitting.get_dim_fparam() + + def has_default_fparam(self) -> bool: + return self.energy_fitting.has_default_fparam() + + def get_default_fparam(self) 
-> torch.Tensor | None: + return self.energy_fitting.get_default_fparam() + + def get_dim_aparam(self) -> int: + return self.energy_fitting.get_dim_aparam() + + def get_sel_type(self) -> list[int]: + return self.energy_fitting.get_sel_type() + + def is_aparam_nall(self) -> bool: + return False + + def set_eval_descriptor_hook(self, enable: bool) -> None: + self.enable_eval_descriptor_hook = enable + self.eval_descriptor_list.clear() + + def eval_descriptor(self) -> torch.Tensor: + return torch.concat(self.eval_descriptor_list) + + def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: + self.enable_eval_fitting_last_layer_hook = enable + self.energy_fitting.set_return_middle_output(enable) + self.eval_fitting_last_layer_list.clear() + + def eval_fitting_last_layer(self) -> torch.Tensor: + return torch.concat(self.eval_fitting_last_layer_list) + + def forward_atomic( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + comm_dict: dict[str, torch.Tensor] | None = None, + ) -> dict[str, torch.Tensor]: + nframes, nloc, _ = nlist.shape + atype = extended_atype[:, :nloc] + if self.do_grad_r() or self.do_grad_c(): + extended_coord.requires_grad_(True) + descriptor, rot_mat, g2, h2, sw = self.descriptor( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + comm_dict=comm_dict, + ) + assert descriptor is not None + if self.enable_eval_descriptor_hook: + self.eval_descriptor_list.append(descriptor.detach()) + + energy_ret = self.energy_fitting( + descriptor, + atype, + gr=rot_mat, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + prop_ret = self.property_fitting( + descriptor, + atype, + gr=rot_mat, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + + if self.enable_eval_fitting_last_layer_hook and "middle_output" in energy_ret: + self.eval_fitting_last_layer_list.append( + 
energy_ret["middle_output"].detach() + ) + + q_val = prop_ret[self.property_name] + corr_energy = self.correction_head(q_val) + total_energy = energy_ret["energy"] + corr_energy + + return { + "energy": total_energy, + self.property_name: q_val, + } + + def apply_out_stat( + self, + ret: dict[str, torch.Tensor], + atype: torch.Tensor, + ) -> dict[str, torch.Tensor]: + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + for kk in self.bias_keys: + if kk == self.property_name: + ret[kk] = ret[kk] * out_std[kk][atype] + out_bias[kk][atype] + else: + ret[kk] = ret[kk] + out_bias[kk][atype] + return ret + + def compute_or_load_stat( + self, + sampled_func: Any, + stat_file_path: Any | None = None, + compute_or_load_out_stat: bool = True, + ) -> None: + if stat_file_path is not None and self.type_map is not None: + stat_file_path /= " ".join(self.type_map) + + def wrapped_sampler() -> list[dict]: + sampled = sampled_func() + if self.pair_excl is not None: + pair_exclude_types = self.pair_excl.get_exclude_types() + for sample in sampled: + sample["pair_exclude_types"] = list(pair_exclude_types) + if self.atom_excl is not None: + atom_exclude_types = self.atom_excl.get_exclude_types() + for sample in sampled: + sample["atom_exclude_types"] = list(atom_exclude_types) + if ( + "find_fparam" not in sampled[0] + and "fparam" not in sampled[0] + and self.has_default_fparam() + ): + default_fparam = self.get_default_fparam() + for sample in sampled: + nframe = sample["atype"].shape[0] + sample["fparam"] = default_fparam.repeat(nframe, 1) + return sampled + + self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path) + self.compute_fitting_input_stat(wrapped_sampler, stat_file_path) + if compute_or_load_out_stat: + self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) + + def compute_fitting_input_stat( + self, + sample_merged: Any, + stat_file_path: Any | None = None, + ) -> None: + self.energy_fitting.compute_input_stats( + sample_merged, + 
protection=self.data_stat_protect, + stat_file_path=stat_file_path, + ) + self.property_fitting.compute_input_stats( + sample_merged, + protection=self.data_stat_protect, + stat_file_path=stat_file_path, + ) + + def serialize(self) -> dict: + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "@version": 1, + "type": "energy_lr", + "type_map": self.type_map, + "descriptor": self.descriptor.serialize(), + "energy_fitting": self.energy_fitting.serialize(), + "property_fitting": self.property_fitting.serialize(), + "correction_hidden": self.correction_hidden, + "correction_activation": self.correction_activation, + "@variables": { + **dd.get("@variables", {}), + "correction_head": { + k: to_numpy_array(v) + for k, v in self.correction_head.state_dict().items() + }, + }, + } + ) + return dd + + @classmethod + def deserialize(cls, data: dict) -> "LREnergyAtomicModel": + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + data.pop("type", None) + variables = data.pop("@variables", {}) + descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) + energy_fitting_obj = InvarFitting.deserialize(data.pop("energy_fitting")) + property_fitting_obj = PropertyFittingNet.deserialize( + data.pop("property_fitting") + ) + correction_hidden = data.pop("correction_hidden", None) + correction_activation = data.pop("correction_activation", "tanh") + obj = cls( + descriptor_obj, + energy_fitting_obj, + property_fitting_obj, + correction_hidden=correction_hidden, + correction_activation=correction_activation, + **data, + ) + correction_state = variables.get("correction_head", None) + if correction_state is not None: + obj.correction_head.load_state_dict( + {k: to_torch_tensor(v) for k, v in correction_state.items()} + ) + return obj diff --git a/deepmd/pt/model/atomic_model/sog_atomic_model.py b/deepmd/pt/model/atomic_model/sog_atomic_model.py new file mode 100644 index 0000000000..e23c0b671d
--- /dev/null +++ b/deepmd/pt/model/atomic_model/sog_atomic_model.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt.model.task.sog_energy_fitting import ( + SOGEnergyFittingNet, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_atomic_model import ( + BaseAtomicModel, +) + + +@BaseAtomicModel.register("energy_sog") +class SOGEnergyAtomicModel(BaseAtomicModel): + """Energy model using a dedicated SOG energy fitting net. + + The SOG energy fitting net combines a short-range invariant fitting + and a long-range correction derived from another invariant fitting. + This avoids requiring a user-defined property name in the dataset. + """ + + def __init__( + self, + descriptor: Any, + type_map: list[str], + sog_energy_fitting: SOGEnergyFittingNet | None = None, + fitting: Any | None = None, + **kwargs: Any, + ) -> None: + super().__init__(type_map, **kwargs) + if sog_energy_fitting is None: + sog_energy_fitting = fitting + if not isinstance(sog_energy_fitting, SOGEnergyFittingNet): + raise TypeError( + "sog_energy_fitting must be an instance of SOGEnergyFittingNet" + ) + + self.descriptor = descriptor + self.fitting_net = sog_energy_fitting + # self.sog_energy_fitting = self.fitting_net + self.type_map = type_map + self.ntypes = len(type_map) + self.rcut = self.descriptor.get_rcut() + self.sel = self.descriptor.get_sel() + + super().init_out_stat() + + self.enable_eval_descriptor_hook = False + self.enable_eval_fitting_last_layer_hook = False + self.eval_descriptor_list: list[torch.Tensor] = [] + self.eval_fitting_last_layer_list: list[torch.Tensor] = [] + + @torch.jit.export + def fitting_output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", + 
shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + name="latent_charge", + shape=[self.fitting_net.dim_out_lr], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), + ] + ) + + def get_rcut(self) -> float: + return self.rcut + + def get_sel(self) -> list[int]: + return self.sel + + def mixed_types(self) -> bool: + return self.descriptor.mixed_types() + + def has_message_passing(self) -> bool: + return self.descriptor.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + return self.descriptor.need_sorted_nlist_for_lower() + + def set_case_embd(self, case_idx: int) -> None: + self.fitting_net.set_case_embd(case_idx) + + def get_dim_fparam(self) -> int: + return self.fitting_net.get_dim_fparam() + + def has_default_fparam(self) -> bool: + return self.fitting_net.has_default_fparam() + + def get_default_fparam(self) -> torch.Tensor | None: + return self.fitting_net.get_default_fparam() + + def get_dim_aparam(self) -> int: + return self.fitting_net.get_dim_aparam() + + def get_sel_type(self) -> list[int]: + return self.fitting_net.get_sel_type() + + def is_aparam_nall(self) -> bool: + return False + + def set_eval_descriptor_hook(self, enable: bool) -> None: + self.enable_eval_descriptor_hook = enable + self.eval_descriptor_list.clear() + + def eval_descriptor(self) -> torch.Tensor: + return torch.concat(self.eval_descriptor_list) + + def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: + self.enable_eval_fitting_last_layer_hook = enable + self.fitting_net.set_return_middle_output(enable) + self.eval_fitting_last_layer_list.clear() + + def eval_fitting_last_layer(self) -> torch.Tensor: + return torch.concat(self.eval_fitting_last_layer_list) + + def forward_atomic( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | 
None = None, + comm_dict: dict[str, torch.Tensor] | None = None, + ) -> dict[str, torch.Tensor]: + nframes, nloc, _ = nlist.shape + atype = extended_atype[:, :nloc] + if self.do_grad_r() or self.do_grad_c(): + extended_coord.requires_grad_(True) + + descriptor_comm_dict = comm_dict + if comm_dict is not None and "send_list" not in comm_dict: + descriptor_comm_dict = None + + descriptor, rot_mat, g2, h2, sw = self.descriptor( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + comm_dict=descriptor_comm_dict, + ) + assert descriptor is not None + if self.enable_eval_descriptor_hook: + self.eval_descriptor_list.append(descriptor.detach()) + + energy_ret = self.fitting_net( + descriptor, + atype, + gr=rot_mat, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + + if self.enable_eval_fitting_last_layer_hook and "middle_output" in energy_ret: + self.eval_fitting_last_layer_list.append( + energy_ret["middle_output"].detach() + ) + + ret = { + "energy": energy_ret["energy"], + "latent_charge": energy_ret["latent_charge"], + } + if "middle_output" in energy_ret: + ret["middle_output"] = energy_ret["middle_output"] + return ret + + def apply_out_stat( + self, + ret: dict[str, torch.Tensor], + atype: torch.Tensor, + ) -> dict[str, torch.Tensor]: + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + for kk in self.bias_keys: + ret[kk] = ret[kk] + out_bias[kk][atype] + return ret + + def compute_or_load_stat( + self, + sampled_func: Any, + stat_file_path: Any | None = None, + compute_or_load_out_stat: bool = True, + preset_observed_type: list[str] | None = None, + ) -> None: + if stat_file_path is not None and self.type_map is not None: + stat_file_path /= " ".join(self.type_map) + + def wrapped_sampler() -> list[dict]: + sampled = sampled_func() + if self.pair_excl is not None: + pair_exclude_types = self.pair_excl.get_exclude_types() + for sample in sampled: + sample["pair_exclude_types"] = list(pair_exclude_types) + if self.atom_excl is not None: + 
atom_exclude_types = self.atom_excl.get_exclude_types() + for sample in sampled: + sample["atom_exclude_types"] = list(atom_exclude_types) + if ( + "find_fparam" not in sampled[0] + and "fparam" not in sampled[0] + and self.has_default_fparam() + ): + default_fparam = self.get_default_fparam() + for sample in sampled: + nframe = sample["atype"].shape[0] + sample["fparam"] = default_fparam.repeat(nframe, 1) + return sampled + + self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path) + self.compute_fitting_input_stat(wrapped_sampler, stat_file_path) + if compute_or_load_out_stat: + self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) + + self._collect_and_set_observed_type( + wrapped_sampler, stat_file_path, preset_observed_type + ) + + def compute_fitting_input_stat( + self, + sample_merged: Any, + stat_file_path: Any | None = None, + ) -> None: + self.fitting_net.compute_input_stats( + sample_merged, + protection=self.data_stat_protect, + stat_file_path=stat_file_path, + ) + + def serialize(self) -> dict: + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "@version": 1, + "type": "energy_sog", + "type_map": self.type_map, + "descriptor": self.descriptor.serialize(), + "sog_energy_fitting": self.fitting_net.serialize(), + } + ) + return dd + + @classmethod + def deserialize(cls, data: dict) -> "SOGEnergyAtomicModel": + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + data.pop("type", None) + descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) + sog_energy_fitting_obj = SOGEnergyFittingNet.deserialize( + data.pop("sog_energy_fitting") + ) + obj = cls( + descriptor=descriptor_obj, + sog_energy_fitting=sog_energy_fitting_obj, + **data, + ) + return obj diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index 24075412db..fca5239110 100644 --- a/deepmd/pt/model/model/__init__.py +++ 
b/deepmd/pt/model/model/__init__.py @@ -54,6 +54,9 @@ from .frozen import ( FrozenModel, ) +from .les_model import ( + LESEnergyModel, +) from .make_hessian_model import ( make_hessian_model, ) @@ -69,6 +72,9 @@ from .property_model import ( PropertyModel, ) +from .sog_model import ( + SOGEnergyModel, +) from .spin_model import ( SpinEnergyModel, SpinModel, @@ -270,6 +276,10 @@ def get_standard_model(model_params: dict) -> BaseModel: modelcls = EnergyModel elif fitting_net_type == "property": modelcls = PropertyModel + elif fitting_net_type == "sog_energy": + modelcls = SOGEnergyModel + elif fitting_net_type == "les_energy": + modelcls = LESEnergyModel else: raise RuntimeError(f"Unknown fitting type: {fitting_net_type}") @@ -311,8 +321,10 @@ def get_model(model_params: dict) -> Any: "DipoleModel", "EnergyModel", "FrozenModel", + "LESEnergyModel", "LinearEnergyModel", "PolarModel", + "SOGEnergyModel", "SpinEnergyModel", "SpinModel", "get_model", diff --git a/deepmd/pt/model/model/les_model.py b/deepmd/pt/model/model/les_model.py new file mode 100644 index 0000000000..c530871294 --- /dev/null +++ b/deepmd/pt/model/model/les_model.py @@ -0,0 +1,486 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import pytorch_finufft +import torch + +from deepmd.pt.model.atomic_model import ( + LESEnergyAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) +from deepmd.pt.model.model.transform_output import ( + communicate_extended_output, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + +from .dp_model import ( + DPModelCommon, +) +from .make_hessian_model import ( + make_hessian_model, +) +from .make_model import ( + make_model, +) + +LESEnergyModel_ = make_model(LESEnergyAtomicModel) + + +@BaseModel.register("les_ener") +class LESEnergyModel(DPModelCommon, LESEnergyModel_): + model_type = "les_ener" + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + 
DPModelCommon.__init__(self) + LESEnergyModel_.__init__(self, *args, **kwargs) + self._hessian_enabled = False + + def enable_hessian(self) -> None: + self.__class__ = make_hessian_model(type(self)) + self.hess_fitting_def = super(type(self), self).atomic_output_def() + self.requires_hessian("energy") + self._hessian_enabled = True + + @torch.jit.export + def get_observed_type_list(self) -> list[str]: + """Get observed types (elements) of the model during data statistics. + + Returns + ------- + observed_type_list: a list of the observed types in this model. + """ + type_map = self.get_type_map() + out_bias = self.atomic_model.get_out_bias()[0] + + assert out_bias is not None, "No out_bias found in the model." + assert out_bias.dim() == 2, "The supported out_bias should be a 2D tensor." + assert out_bias.size(0) == len(type_map), ( + "The out_bias shape does not match the type_map length." + ) + bias_mask = ( + torch.gt(torch.abs(out_bias), 1e-6).any(dim=-1).detach().cpu() + ) # 1e-6 for stability + + observed_type_list: list[str] = [] + for i in range(len(type_map)): + if bias_mask[i]: + observed_type_list.append(type_map[i]) + return observed_type_list + + def translated_output_def(self) -> dict[str, Any]: + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_energy": out_def_data["energy"], + "energy": out_def_data["energy_redu"], + } + if self.do_grad_r("energy"): + output_def["force"] = out_def_data["energy_derv_r"] + output_def["force"].squeeze(-2) + if self.do_grad_c("energy"): + output_def["virial"] = out_def_data["energy_derv_c_redu"] + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = out_def_data["energy_derv_c"] + output_def["atom_virial"].squeeze(-3) + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + if self._hessian_enabled: + output_def["hessian"] = out_def_data["energy_derv_r_derv_r"] + return output_def + + def _compute_les_frame_correction_bundle( + self, + coord: torch.Tensor, + 
latent_charge: torch.Tensor, + box: torch.Tensor, + *, + need_force: bool, + need_virial: bool, + ) -> dict[str, torch.Tensor]: + fitting = self.get_fitting_net() + runtime_device = coord.device + real_dtype = coord.dtype + complex_dtype = ( + torch.complex128 if real_dtype == torch.float64 else torch.complex64 + ) + latent_charge = latent_charge.to(device=runtime_device, dtype=real_dtype) + box = box.to(device=runtime_device, dtype=real_dtype) + + sigma_raw = getattr(fitting, "sigma", None) + if sigma_raw is None: + raise ValueError( + "LES fitting net should provide `sigma` for frame correction." + ) + sigma = torch.as_tensor( + sigma_raw, + dtype=real_dtype, + device=runtime_device, + ).reshape(-1)[0] + sigma = torch.clamp(sigma, min=torch.finfo(real_dtype).eps) + remove_self_interaction = bool(fitting.remove_self_interaction) + n_dl = int(fitting.n_dl) + pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) + two_pi = torch.tensor(2.0 * torch.pi, dtype=real_dtype, device=runtime_device) + + nf, nloc, _ = coord.shape + corr = torch.zeros((nf, 1), dtype=real_dtype, device=runtime_device) + force_local = ( + torch.zeros((nf, nloc, 3), dtype=real_dtype, device=runtime_device) + if need_force + else None + ) + virial_local = ( + torch.zeros((nf, nloc, 1, 9), dtype=real_dtype, device=runtime_device) + if need_virial + else None + ) + + for ff in range(nf): + r_raw = coord[ff] + q = latent_charge[ff] + box_frame = box[ff] + + volume = torch.det(box_frame) + if torch.abs(volume) <= torch.finfo(real_dtype).eps: + raise ValueError( + "`box` is singular (near-zero volume), cannot run NUFFT." 
+ ) + + cell_inv = torch.linalg.inv(box_frame) + r_frac = torch.matmul(r_raw, cell_inv) + r_frac = torch.remainder(r_frac + 0.5, 1.0) - 0.5 + point_limit = pi_tensor - 32.0 * torch.finfo(real_dtype).eps + r_in = torch.clamp( + 2.0 * pi_tensor * r_frac, + min=-point_limit, + max=point_limit, + ).contiguous() + nufft_points = r_in.transpose(0, 1).contiguous() + + norms = torch.norm(box_frame, dim=1) + nk = tuple(max(1, int(n.item() / n_dl)) for n in norms) + n1 = torch.arange( + -nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype + ) + n2 = torch.arange( + -nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype + ) + n3 = torch.arange( + -nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype + ) + + kx_grid, ky_grid, kz_grid = torch.meshgrid(n1, n2, n3, indexing="ij") + k_sq = kx_grid**2 + ky_grid**2 + kz_grid**2 + zero_mask = k_sq == 0 + + k_sq_safe = torch.where(zero_mask, torch.ones_like(k_sq), k_sq) + kfac = torch.exp(-0.5 * (sigma**2) * k_sq_safe) / k_sq_safe + kfac = kfac.to(dtype=real_dtype) + kfac[zero_mask] = 0.0 + + q_t = q.transpose(0, 1).contiguous() + charge = ( + torch.complex(q_t, torch.zeros_like(q_t)) + .to(dtype=complex_dtype) + .contiguous() + ) + output_shape = tuple(int(x) for x in kx_grid.shape) + recon = pytorch_finufft.functional.finufft_type1( + nufft_points, + charge, + output_shape=output_shape, + eps=1e-4, + isign=-1, + ) + + rho_sq = recon.real.square() + recon.imag.square() + corr[ff, 0] = (kfac.unsqueeze(0) * rho_sq).sum() / (2.0 * volume) + + conv = None + if need_force: + conv = kfac.unsqueeze(0).to(dtype=complex_dtype) * recon + + if need_force: + assert conv is not None + kk1 = torch.fft.ifftshift(kx_grid, dim=0) + kk2 = torch.fft.ifftshift(ky_grid, dim=1) + kk3 = torch.fft.ifftshift(kz_grid, dim=2) + k_grid = torch.stack((kk1, kk2, kk3), dim=0) + g_cart = two_pi * torch.einsum("ik,k...->i...", cell_inv, k_grid) + grad_conv = ( + 1j * g_cart.unsqueeze(1).to(dtype=complex_dtype) + ) * conv.unsqueeze(0) + grad_field = 
pytorch_finufft.functional.finufft_type2( + nufft_points, + grad_conv, + eps=1e-4, + isign=1, + ) + force_frame = ( + -(q_t.unsqueeze(0) * grad_field.real.to(dtype=real_dtype)) + .sum(dim=1) + .transpose(0, 1) + ) + force_frame = force_frame / volume + force_local[ff] = force_frame + + if need_virial: + virial_local[ff] = torch.einsum( + "ai,aj->aij", + force_frame, + r_raw, + ).reshape(nloc, 1, 9) + + if remove_self_interaction: + diag_sum = kfac.sum(dim=-1).sum(dim=-1).sum(dim=-1) / (2.0 * volume) + corr[ff, 0] -= torch.sum(q**2) * diag_sum + + out: dict[str, torch.Tensor] = {"corr_redu": corr} + if force_local is not None: + out["force_local"] = force_local + if virial_local is not None: + out["virial_local"] = virial_local + return out + + def _compute_les_frame_correction( + self, + coord: torch.Tensor, + latent_charge: torch.Tensor, + box: torch.Tensor, + ) -> torch.Tensor: + out = self._compute_les_frame_correction_bundle( + coord, + latent_charge, + box, + need_force=False, + need_virial=False, + ) + return out["corr_redu"] + + def _apply_frame_correction_lower( + self, + model_ret: dict[str, torch.Tensor], + extended_coord: torch.Tensor, + nlist: torch.Tensor, + box: torch.Tensor | None, + do_atomic_virial: bool, + ) -> dict[str, torch.Tensor]: + if box is None or "latent_charge" not in model_ret: + return model_ret + + nf, nloc, _ = nlist.shape + nall = extended_coord.shape[1] + coord_local = extended_coord[:, :nloc, :] + box_local = box.view(nf, 3, 3) + latent_charge = model_ret["latent_charge"] + need_force = self.do_grad_r("energy") or self.do_grad_c("energy") + need_virial = self.do_grad_c("energy") + latent_charge_runtime = ( + latent_charge if self.training else latent_charge.detach() + ) + corr_bundle = self._compute_les_frame_correction_bundle( + coord_local, + latent_charge_runtime, + box_local, + need_force=need_force, + need_virial=need_virial, + ) + corr_redu = corr_bundle["corr_redu"] + + model_ret["energy_redu"] = model_ret["energy_redu"] + 
corr_redu.to( + model_ret["energy_redu"].dtype + ) + + if need_force: + corr_force_local = corr_bundle["force_local"].to(coord_local.dtype) + + corr_force_ext = torch.zeros( + (nf, nall, 3), + dtype=corr_force_local.dtype, + device=corr_force_local.device, + ) + corr_force_ext[:, :nloc, :] = corr_force_local + if "energy_derv_r" in model_ret: + model_ret["energy_derv_r"] = model_ret[ + "energy_derv_r" + ] + corr_force_ext.unsqueeze(-2).to(model_ret["energy_derv_r"].dtype) + + if need_virial: + corr_virial_local = corr_bundle["virial_local"].to( + corr_force_local.dtype + ) + corr_virial_redu = corr_virial_local.sum(dim=1) + if "energy_derv_c_redu" in model_ret: + model_ret["energy_derv_c_redu"] = model_ret[ + "energy_derv_c_redu" + ] + corr_virial_redu.to(model_ret["energy_derv_c_redu"].dtype) + if do_atomic_virial and "energy_derv_c" in model_ret: + corr_atom_virial = torch.zeros( + (nf, nall, 1, 9), + dtype=corr_virial_local.dtype, + device=corr_virial_local.device, + ) + corr_atom_virial[:, :nloc, :, :] = corr_virial_local + model_ret["energy_derv_c"] = model_ret[ + "energy_derv_c" + ] + corr_atom_virial.to(model_ret["energy_derv_c"].dtype) + + return model_ret + + @torch.jit.export + def forward_common_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + comm_dict: dict[str, torch.Tensor] | None = None, + extra_nlist_sort: bool = False, + extended_coord_corr: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + if self.do_grad_r("energy") or self.do_grad_c("energy"): + extended_coord = extended_coord.requires_grad_(True) + model_ret = super().forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + extra_nlist_sort=extra_nlist_sort, + 
extended_coord_corr=extended_coord_corr, + ) + box = None + if comm_dict is not None and "box" in comm_dict: + box = comm_dict["box"] + return self._apply_frame_correction_lower( + model_ret, + extended_coord, + nlist, + box, + do_atomic_virial, + ) + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + cc, bb, fp, ap, input_prec = self._input_type_cast( + coord, box=box, fparam=fparam, aparam=aparam + ) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + cc, + atype, + self.get_rcut(), + self.get_sel(), + mixed_types=True, + box=bb, + ) + comm_dict: dict[str, torch.Tensor] | None = None + if bb is not None: + comm_dict = {"box": bb} + model_predict_lower = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fp, + aparam=ap, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + ) + model_ret = communicate_extended_output( + model_predict_lower, + self.model_output_def(), + mapping, + do_atomic_virial=do_atomic_virial, + ) + model_ret = self._output_type_cast(model_ret, input_prec) + if self.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze( + -3 + ) + else: + model_predict["force"] = model_ret["dforce"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + if self._hessian_enabled: + model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-2) + else: + 
model_predict = model_ret + model_predict["updated_coord"] += coord + return model_predict + + @torch.jit.export + def forward_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + comm_dict: dict[str, torch.Tensor] | None = None, + ) -> dict[str, torch.Tensor]: + model_ret = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + ) + if self.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy"): + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret[ + "energy_derv_c" + ].squeeze(-3) + else: + assert model_ret["dforce"] is not None + model_predict["dforce"] = model_ret["dforce"] + else: + model_predict = model_ret + return model_predict diff --git a/deepmd/pt/model/model/sog_model.py b/deepmd/pt/model/model/sog_model.py new file mode 100644 index 0000000000..44628737bc --- /dev/null +++ b/deepmd/pt/model/model/sog_model.py @@ -0,0 +1,559 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import pytorch_finufft +import torch + +from deepmd.pt.model.atomic_model import ( + SOGEnergyAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) +from deepmd.pt.model.model.transform_output import ( + communicate_extended_output, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + +from .dp_model 
import ( + DPModelCommon, +) +from .make_hessian_model import ( + make_hessian_model, +) +from .make_model import ( + make_model, +) + +SOGEnergyModel_ = make_model(SOGEnergyAtomicModel) + + +@BaseModel.register("sog_ener") +class SOGEnergyModel(DPModelCommon, SOGEnergyModel_): + model_type = "sog_ener" + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + DPModelCommon.__init__(self) + SOGEnergyModel_.__init__(self, *args, **kwargs) + self._hessian_enabled = False + # Runtime-only caches for NUFFT correction path. + self._sog_param_cache: dict[ + tuple[Any, ...], tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ] = {} + + @staticmethod + def _device_key(device: torch.device) -> str: + if device.index is None: + return device.type + return f"{device.type}:{device.index}" + + @staticmethod + def _trim_cache(cache: dict[Any, Any], max_size: int = 8) -> None: + if len(cache) > max_size: + oldest_key = next(iter(cache.keys())) + cache.pop(oldest_key, None) + + def _get_cached_sog_params( + self, + fitting: Any, + runtime_device: torch.device, + real_dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + wl_raw = fitting.wl + sl_raw = fitting.sl + grad_mode = torch.is_grad_enabled() and ( + wl_raw.requires_grad or sl_raw.requires_grad + ) + + wl = ( + wl_raw + if (wl_raw.device == runtime_device and wl_raw.dtype == real_dtype) + else wl_raw.to(dtype=real_dtype, device=runtime_device) + ) + sl = ( + sl_raw + if (sl_raw.device == runtime_device and sl_raw.dtype == real_dtype) + else sl_raw.to(dtype=real_dtype, device=runtime_device) + ) + min_term = -1.0 / torch.exp(-2.0 * sl) + + # Do not cache differentiable tensors across iterations. 
+ if grad_mode: + return wl, sl, min_term + + wl_version = int(getattr(fitting.wl, "_version", 0)) + sl_version = int(getattr(fitting.sl, "_version", 0)) + cache_key = ( + self._device_key(runtime_device), + str(real_dtype), + wl_version, + sl_version, + ) + cached = self._sog_param_cache.get(cache_key) + if cached is not None: + return cached + + self._sog_param_cache[cache_key] = (wl, sl, min_term) + self._trim_cache(self._sog_param_cache) + return wl, sl, min_term + + def enable_hessian(self) -> None: + self.__class__ = make_hessian_model(type(self)) + self.hess_fitting_def = super(type(self), self).atomic_output_def() + self.requires_hessian("energy") + self._hessian_enabled = True + + @torch.jit.export + def get_observed_type_list(self) -> list[str]: + """Get observed types (elements) of the model during data statistics. + + Returns + ------- + observed_type_list: a list of the observed types in this model. + """ + type_map = self.get_type_map() + out_bias = self.atomic_model.get_out_bias()[0] + + assert out_bias is not None, "No out_bias found in the model." + assert out_bias.dim() == 2, "The supported out_bias should be a 2D tensor." + assert out_bias.size(0) == len(type_map), ( + "The out_bias shape does not match the type_map length." 
+ ) + bias_mask = ( + torch.gt(torch.abs(out_bias), 1e-6).any(dim=-1).detach().cpu() + ) # 1e-6 for stability + + observed_type_list: list[str] = [] + for i in range(len(type_map)): + if bias_mask[i]: + observed_type_list.append(type_map[i]) + return observed_type_list + + def translated_output_def(self) -> dict[str, Any]: + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_energy": out_def_data["energy"], + "energy": out_def_data["energy_redu"], + } + if self.do_grad_r("energy"): + output_def["force"] = out_def_data["energy_derv_r"] + output_def["force"].squeeze(-2) + if self.do_grad_c("energy"): + output_def["virial"] = out_def_data["energy_derv_c_redu"] + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = out_def_data["energy_derv_c"] + output_def["atom_virial"].squeeze(-3) + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + if self._hessian_enabled: + output_def["hessian"] = out_def_data["energy_derv_r_derv_r"] + return output_def + + def _compute_sog_frame_correction_bundle( + self, + coord: torch.Tensor, + latent_charge: torch.Tensor, + box: torch.Tensor, + *, + need_force: bool, + need_virial: bool, + ) -> dict[str, torch.Tensor]: + if coord.dim() != 3: + raise ValueError( + f"`coord` should be [nf, nloc, 3], got shape {tuple(coord.shape)}" + ) + if latent_charge.dim() != 3: + raise ValueError( + f"`latent_charge` should be [nf, nloc, nq], got shape {tuple(latent_charge.shape)}" + ) + if coord.shape[:2] != latent_charge.shape[:2]: + raise ValueError( + "`coord` and `latent_charge` local dimensions mismatch: " + f"{tuple(coord.shape[:2])} vs {tuple(latent_charge.shape[:2])}" + ) + + fitting = self.get_fitting_net() + runtime_device = coord.device + real_dtype = coord.dtype + complex_dtype = ( + torch.complex128 if real_dtype == torch.float64 else torch.complex64 + ) + latent_charge = latent_charge.to(device=runtime_device, dtype=real_dtype) + box = box.to(device=runtime_device, dtype=real_dtype) + if 
box.dim() != 3 or box.shape[-2:] != (3, 3): + raise ValueError( + f"`box` should be [nf, 3, 3], got shape {tuple(box.shape)}" + ) + + wl, _sl, min_term = self._get_cached_sog_params( + fitting, + runtime_device, + real_dtype, + ) + remove_self_interaction = bool(fitting.remove_self_interaction) + n_dl = int(fitting.n_dl) + pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) + two_pi = torch.tensor(2.0 * torch.pi, dtype=real_dtype, device=runtime_device) + + nf, nloc, _ = coord.shape + corr = torch.zeros((nf, 1), dtype=real_dtype, device=runtime_device) + force_local = ( + torch.zeros((nf, nloc, 3), dtype=real_dtype, device=runtime_device) + if need_force + else None + ) + virial_local = ( + torch.zeros((nf, nloc, 1, 9), dtype=real_dtype, device=runtime_device) + if need_virial + else None + ) + + for ff in range(nf): + r_raw = coord[ff] + q = latent_charge[ff] + box_frame = box[ff] + + volume = torch.det(box_frame) + if torch.abs(volume) <= torch.finfo(real_dtype).eps: + raise ValueError( + "`box` is singular (near-zero volume), cannot run NUFFT." 
+ ) + + cell_inv = torch.linalg.inv(box_frame) + r_frac = torch.matmul(r_raw, cell_inv) + r_frac = torch.remainder(r_frac + 0.5, 1.0) - 0.5 + point_limit = pi_tensor - 32.0 * torch.finfo(real_dtype).eps + r_in = torch.clamp( + 2.0 * pi_tensor * r_frac, + min=-point_limit, + max=point_limit, + ).contiguous() + nufft_points = r_in.transpose(0, 1).contiguous() + + norms = torch.norm(box_frame, dim=1) + nk = tuple(max(1, int(n.item() / n_dl)) for n in norms) + n1 = torch.arange( + -nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype + ) + n2 = torch.arange( + -nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype + ) + n3 = torch.arange( + -nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype + ) + kx_grid, ky_grid, kz_grid = torch.meshgrid(n1, n2, n3, indexing="ij") + k_sq = kx_grid**2 + ky_grid**2 + kz_grid**2 + zero_mask = k_sq == 0 + + kfac = wl.view(1, 1, 1, -1) * torch.exp(k_sq.unsqueeze(-1) * min_term) + kfac = kfac.sum(dim=-1) + kfac = kfac.to(dtype=real_dtype) + kfac[zero_mask] = 0.0 + output_shape = tuple(int(x) for x in kx_grid.shape) + + q_t = q.transpose(0, 1).contiguous() + charge = ( + torch.complex(q_t, torch.zeros_like(q_t)) + .to(dtype=complex_dtype) + .contiguous() + ) + recon = pytorch_finufft.functional.finufft_type1( + nufft_points, + charge, + output_shape=output_shape, + eps=1e-4, + isign=-1, + ) + + rho_sq = recon.real.square() + recon.imag.square() + corr[ff, 0] = (kfac.unsqueeze(0) * rho_sq).sum() / (2.0 * volume) + + conv = None + if need_force: + conv = kfac.unsqueeze(0).to(dtype=complex_dtype) * recon + + if need_force: + assert conv is not None + # Reuse the already built k-grid and only reorder it to the FFT + # storage order required by type-2 inputs. 
+ kk1 = torch.fft.ifftshift(kx_grid, dim=0) + kk2 = torch.fft.ifftshift(ky_grid, dim=1) + kk3 = torch.fft.ifftshift(kz_grid, dim=2) + k_grid = torch.stack((kk1, kk2, kk3), dim=0) + g_cart = two_pi * torch.einsum("ik,k...->i...", cell_inv, k_grid) + grad_conv = ( + 1j * g_cart.unsqueeze(1).to(dtype=complex_dtype) + ) * conv.unsqueeze(0) + grad_field = pytorch_finufft.functional.finufft_type2( + nufft_points, + grad_conv, + eps=1e-4, + isign=1, + ) + force_frame = ( + -(q_t.unsqueeze(0) * grad_field.real.to(dtype=real_dtype)) + .sum(dim=1) + .transpose(0, 1) + ) + force_frame = force_frame / volume + force_local[ff] = force_frame + + if need_virial: + virial_local[ff] = torch.einsum( + "ai,aj->aij", + force_frame, + r_raw, + ).reshape(nloc, 1, 9) + + if remove_self_interaction: + diag_sum = kfac.sum(dim=-1).sum(dim=-1).sum(dim=-1) / (2.0 * volume) + corr[ff, 0] -= torch.sum(q**2) * diag_sum + + out: dict[str, torch.Tensor] = {"corr_redu": corr} + if force_local is not None: + out["force_local"] = force_local + if virial_local is not None: + out["virial_local"] = virial_local + return out + + def _compute_sog_frame_correction( + self, + coord: torch.Tensor, + latent_charge: torch.Tensor, + box: torch.Tensor, + ) -> torch.Tensor: + out = self._compute_sog_frame_correction_bundle( + coord, + latent_charge, + box, + need_force=False, + need_virial=False, + ) + return out["corr_redu"] + + def _apply_frame_correction_lower( + self, + model_ret: dict[str, torch.Tensor], + extended_coord: torch.Tensor, + nlist: torch.Tensor, + box: torch.Tensor | None, + do_atomic_virial: bool, + ) -> dict[str, torch.Tensor]: + if box is None or "latent_charge" not in model_ret: + return model_ret + + nf, nloc, _ = nlist.shape + nall = extended_coord.shape[1] + coord_local = extended_coord[:, :nloc, :] + box_local = box.view(nf, 3, 3) + latent_charge = model_ret["latent_charge"] + need_force = self.do_grad_r("energy") or self.do_grad_c("energy") + need_virial = self.do_grad_c("energy") + 
latent_charge_runtime = ( + latent_charge if self.training else latent_charge.detach() + ) + corr_bundle = self._compute_sog_frame_correction_bundle( + coord_local, + latent_charge_runtime, + box_local, + need_force=need_force, + need_virial=need_virial, + ) + corr_redu = corr_bundle["corr_redu"] + + model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to( + model_ret["energy_redu"].dtype + ) + + if need_force: + corr_force_local = corr_bundle["force_local"].to(coord_local.dtype) + + corr_force_ext = torch.zeros( + (nf, nall, 3), + dtype=corr_force_local.dtype, + device=corr_force_local.device, + ) + corr_force_ext[:, :nloc, :] = corr_force_local + if "energy_derv_r" in model_ret: + model_ret["energy_derv_r"] = model_ret[ + "energy_derv_r" + ] + corr_force_ext.unsqueeze(-2).to(model_ret["energy_derv_r"].dtype) + + if need_virial: + corr_virial_local = corr_bundle["virial_local"].to( + corr_force_local.dtype + ) + corr_virial_redu = corr_virial_local.sum(dim=1) + if "energy_derv_c_redu" in model_ret: + model_ret["energy_derv_c_redu"] = model_ret[ + "energy_derv_c_redu" + ] + corr_virial_redu.to(model_ret["energy_derv_c_redu"].dtype) + if do_atomic_virial and "energy_derv_c" in model_ret: + corr_atom_virial = torch.zeros( + (nf, nall, 1, 9), + dtype=corr_virial_local.dtype, + device=corr_virial_local.device, + ) + corr_atom_virial[:, :nloc, :, :] = corr_virial_local + model_ret["energy_derv_c"] = model_ret[ + "energy_derv_c" + ] + corr_atom_virial.to(model_ret["energy_derv_c"].dtype) + + return model_ret + + @torch.jit.export + def forward_common_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + comm_dict: dict[str, torch.Tensor] | None = None, + extra_nlist_sort: bool = False, + extended_coord_corr: torch.Tensor | None = None, + ) -> dict[str, 
torch.Tensor]: + if self.do_grad_r("energy") or self.do_grad_c("energy"): + extended_coord = extended_coord.requires_grad_(True) + model_ret = super().forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + extra_nlist_sort=extra_nlist_sort, + extended_coord_corr=extended_coord_corr, + ) + box = None + if comm_dict is not None and "box" in comm_dict: + box = comm_dict["box"] + return self._apply_frame_correction_lower( + model_ret, + extended_coord, + nlist, + box, + do_atomic_virial, + ) + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + cc, bb, fp, ap, input_prec = self._input_type_cast( + coord, box=box, fparam=fparam, aparam=aparam + ) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + cc, + atype, + self.get_rcut(), + self.get_sel(), + mixed_types=True, + box=bb, + ) + comm_dict: dict[str, torch.Tensor] | None = None + if bb is not None: + comm_dict = {"box": bb} + model_predict_lower = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fp, + aparam=ap, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + ) + model_ret = communicate_extended_output( + model_predict_lower, + self.model_output_def(), + mapping, + do_atomic_virial=do_atomic_virial, + ) + model_ret = self._output_type_cast(model_ret, input_prec) + if self.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = 
model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze( + -3 + ) + else: + model_predict["force"] = model_ret["dforce"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + if self._hessian_enabled: + model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-2) + else: + model_predict = model_ret + model_predict["updated_coord"] += coord + return model_predict + + @torch.jit.export + def forward_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + comm_dict: dict[str, torch.Tensor] | None = None, + ) -> dict[str, torch.Tensor]: + model_ret = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + ) + if self.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy"): + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret[ + "energy_derv_c" + ].squeeze(-3) + else: + assert model_ret["dforce"] is not None + model_predict["dforce"] = model_ret["dforce"] + else: + model_predict = model_ret + return model_predict diff --git a/deepmd/pt/model/task/__init__.py b/deepmd/pt/model/task/__init__.py index 37ffec2725..7cdfbd35a4 100644 --- a/deepmd/pt/model/task/__init__.py +++ b/deepmd/pt/model/task/__init__.py @@ -18,12 +18,21 @@ from .fitting import ( Fitting, ) +from 
.les_energy_fitting import ( + LESEnergyFittingNet, +) +from .lr_fitting import ( + LRFittingNet, +) from .polarizability import ( PolarFittingNet, ) from .property import ( PropertyFittingNet, ) +from .sog_energy_fitting import ( + SOGEnergyFittingNet, +) from .type_predict import ( TypePredictNet, ) @@ -36,7 +45,10 @@ "EnergyFittingNet", "EnergyFittingNetDirect", "Fitting", + "LESEnergyFittingNet", + "LRFittingNet", "PolarFittingNet", "PropertyFittingNet", + "SOGEnergyFittingNet", "TypePredictNet", ] diff --git a/deepmd/pt/model/task/les_energy_fitting.py b/deepmd/pt/model/task/les_energy_fitting.py new file mode 100644 index 0000000000..325355291d --- /dev/null +++ b/deepmd/pt/model/task/les_energy_fitting.py @@ -0,0 +1,258 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from typing import ( + Any, +) + +import numpy as np +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, + fitting_check_output, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION +device = env.DEVICE + +log = logging.getLogger(__name__) +from deepmd.pt.model.task.lr_fitting import ( + LRFittingNet, +) + +LES_DEFAULT_SIGMA = to_numpy_array(np.array(2.8 / np.sqrt(2.0))) + + +@LRFittingNet.register("les_energy") +@fitting_check_output +class LESEnergyFittingNet(LRFittingNet): + """Construct a LES sr+lr interactions fitting net. + + Parameters + ---------- + var_name : str + The atomic property to fit. + ntypes : int + Element count. + dim_descrpt : int + Embedding width per atom. + dim_out_sr : int + The output dimension of the sr fitting net. + dim_out_lr : int + The output dimension of the lr fitting net. + neuron_sr : list[int] + Number of neurons in each hidden layers of the sr fitting net. 
+ neuron_lr : list[int] + Number of neurons in each hidden layers of the lr fitting net. + bias_atom_e : torch.Tensor, optional + Average energy per atom for each element. + resnet_dt : bool + Using time-step in the ResNet construction. + numb_fparam : int + Number of frame parameters. + numb_aparam : int + Number of atomic parameters. + dim_case_embd : int + Dimension of case specific embedding. + activation_function : str + Activation function. + precision : str + Numerical precision. + mixed_types : bool + If true, use a uniform fitting net for all atom types, otherwise use + different fitting nets for different atom types. + rcond : float, optional + The condition number for the regression of atomic energy. + seed : int, optional + Random seed. + exclude_types: list[int] + Atomic contributions of the excluded atom types are set zero. + trainable : Union[list[bool], bool] + If the parameters in the fitting net are trainable. + Now this only supports setting all the parameters in the fitting net at one state. + When in list[bool], the trainable will be True only if all the boolean parameters are True. + remove_vaccum_contribution: list[bool], optional + Remove vacuum contribution before the bias is added. The list assigned each + type. For `mixed_types` provide `[True]`, otherwise it should be a list of the same + length as `ntypes` signaling if or not removing the vacuum contribution for the atom types in the list. + type_map: list[str], Optional + A list of strings. Give the name to each type of atoms. + use_aparam_as_mask: bool + If True, the aparam will not be used in fitting net for embedding. + default_fparam: list[float], optional + The default frame parameter. If set, when `fparam.npy` files are not included in the data system, + this value will be used as the default value for the frame parameter in the fitting net. + n_dl : int + NUFFT long-range grid density control factor. 
+ remove_self_interaction : bool + If True, remove self interaction term in long-range correction. + """ + + def __init__( + self, + var_name: str, + ntypes: int, + dim_descrpt: int, + dim_out_sr: int, + dim_out_lr: int, + neuron_sr: list[int] = [128, 128, 128], + neuron_lr: list[int] = [128, 128, 128], + bias_atom_e: torch.Tensor | None = None, + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + dim_case_embd: int = 0, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + mixed_types: bool = True, + rcond: float | None = None, + seed: int | list[int] | None = None, + exclude_types: list[int] = [], + trainable: bool | list[bool] = True, + remove_vaccum_contribution: list[bool] | None = None, + type_map: list[str] | None = None, + use_aparam_as_mask: bool = False, + default_fparam: list[float] | None = None, + sigma: float | list[float] | torch.Tensor | None = None, + n_dl: int = 1, + remove_self_interaction: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + var_name=var_name, + ntypes=ntypes, + dim_descrpt=dim_descrpt, + dim_out_sr=dim_out_sr, + dim_out_lr=dim_out_lr, + neuron_sr=neuron_sr, + neuron_lr=neuron_lr, + bias_atom_e=bias_atom_e, + resnet_dt=resnet_dt, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, + activation_function=activation_function, + precision=precision, + mixed_types=mixed_types, + rcond=rcond, + seed=seed, + exclude_types=exclude_types, + trainable=trainable, + remove_vaccum_contribution=remove_vaccum_contribution, + type_map=type_map, + use_aparam_as_mask=use_aparam_as_mask, + default_fparam=default_fparam, + **kwargs, + ) + if isinstance(sigma, (list, tuple)): + sigma = sigma[0] if len(sigma) > 0 else None + sigma_tensor = to_torch_tensor(sigma) + if sigma_tensor is None: + sigma_tensor = to_torch_tensor(LES_DEFAULT_SIGMA) + sigma_tensor = sigma_tensor.to(dtype=dtype, device=device).reshape(1) + sigma_tensor = torch.clamp( + sigma_tensor, + 
min=torch.finfo(sigma_tensor.dtype).eps, + ) + + self.n_dl = max(1, int(n_dl)) + self.sigma = torch.nn.Parameter( + sigma_tensor, + requires_grad=bool(self.trainable), + ) + self.remove_self_interaction = bool(remove_self_interaction) + self._nufft_fallback_warned = False + + def output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + name="latent_charge", + shape=[self.dim_out_lr], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), + ] + ) + + def serialize(self) -> dict: + data = super().serialize() + data["type"] = "les_energy" + data["@variables"]["sigma"] = to_numpy_array(self.sigma) + data["n_dl"] = self.n_dl + data["remove_self_interaction"] = bool(self.remove_self_interaction) + return data + + @classmethod + def deserialize(cls, data: dict) -> "LESEnergyFittingNet": + data = data.copy() + variables = data.get("@variables", {}).copy() + + sigma_tensor = to_torch_tensor(variables.pop("sigma", None)) + data["@variables"] = variables + + obj = super().deserialize(data) + + with torch.no_grad(): + if sigma_tensor is None: + raise ValueError( + "LES fitting net deserialize requires `sigma` in @variables." 
+ ) + obj.sigma.copy_( + sigma_tensor.to(dtype=obj.sigma.dtype, device=obj.sigma.device).reshape( + 1 + ) + ) + return obj + + def _kernel_params(self) -> tuple[torch.Tensor]: + return (self.sigma,) + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + out = self._forward_common( + descriptor=descriptor, + atype=atype, + gr=gr, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + result = { + "energy": out["sr"], + "latent_charge": out["lr"], + } + if "middle_output" in out: + result["middle_output"] = out["middle_output"] + return result + + # make jit happy with torch 2.0.0 + exclude_types: list[int] diff --git a/deepmd/pt/model/task/lr_fitting.py b/deepmd/pt/model/task/lr_fitting.py new file mode 100644 index 0000000000..acf1d7eee2 --- /dev/null +++ b/deepmd/pt/model/task/lr_fitting.py @@ -0,0 +1,630 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Optional, +) + +import numpy as np +import torch + +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.network.mlp import ( + FittingNet, + NetworkCollection, +) +from deepmd.pt.model.task.fitting import ( + Fitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, + PRECISION_DICT, +) +from deepmd.pt.utils.exclude_mask import ( + AtomExcludeMask, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.finetune import ( + get_index_between_two_maps, + map_atom_exclude_types, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION +device = env.DEVICE + + +@Fitting.register("lr") +class LRFittingNet(Fitting): + """Construct a general sr+lr interactions fitting net. + + Parameters + ---------- + var_name : str + The atomic property to fit. 
+ ntypes : int + Element count. + dim_descrpt : int + Embedding width per atom. + dim_out_sr : int + The output dimension of the sr fitting net. + dim_out_lr : int + The output dimension of the lr fitting net. + neuron_sr : list[int] + Number of neurons in each hidden layers of the sr fitting net. + neuron_lr : list[int] + Number of neurons in each hidden layers of the lr fitting net. + bias_atom_e : torch.Tensor, optional + Average energy per atom for each element. + resnet_dt : bool + Using time-step in the ResNet construction. + numb_fparam : int + Number of frame parameters. + numb_aparam : int + Number of atomic parameters. + dim_case_embd : int + Dimension of case specific embedding. + activation_function : str + Activation function. + precision : str + Numerical precision. + mixed_types : bool + If true, use a uniform fitting net for all atom types, otherwise use + different fitting nets for different atom types. + rcond : float, optional + The condition number for the regression of atomic energy. + seed : int, optional + Random seed. + exclude_types: list[int] + Atomic contributions of the excluded atom types are set zero. + trainable : Union[list[bool], bool] + If the parameters in the fitting net are trainable. + Now this only supports setting all the parameters in the fitting net at one state. + When in list[bool], the trainable will be True only if all the boolean parameters are True. + remove_vaccum_contribution: list[bool], optional + Remove vacuum contribution before the bias is added. The list assigned each + type. For `mixed_types` provide `[True]`, otherwise it should be a list of the same + length as `ntypes` signaling if or not removing the vacuum contribution for the atom types in the list. + type_map: list[str], Optional + A list of strings. Give the name to each type of atoms. + use_aparam_as_mask: bool + If True, the aparam will not be used in fitting net for embedding. + default_fparam: list[float], optional + The default frame parameter. 
If set, when `fparam.npy` files are not included in the data system, + this value will be used as the default value for the frame parameter in the fitting net. + """ + + def __init__( + self, + var_name: str, + ntypes: int, + dim_descrpt: int, + dim_out_sr: int, + dim_out_lr: int, + neuron_sr: list[int] = [128, 128, 128], + neuron_lr: list[int] = [128, 128, 128], + bias_atom_e: torch.Tensor | None = None, + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + dim_case_embd: int = 0, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + mixed_types: bool = True, + rcond: float | None = None, + seed: int | list[int] | None = None, + exclude_types: list[int] = [], + trainable: bool | list[bool] = True, + remove_vaccum_contribution: list[bool] | None = None, + type_map: list[str] | None = None, + use_aparam_as_mask: bool = False, + default_fparam: list[float] | None = None, + **kwargs: Any, + ) -> None: + super().__init__() + self.var_name = var_name + self.ntypes = ntypes + self.dim_descrpt = dim_descrpt + self.dim_out_sr = dim_out_sr + self.dim_out_lr = dim_out_lr + self.neuron_sr = neuron_sr + self.neuron_lr = neuron_lr + self.mixed_types = mixed_types + self.resnet_dt = resnet_dt + self.numb_fparam = numb_fparam + self.numb_aparam = numb_aparam + self.default_fparam = default_fparam + self.dim_case_embd = dim_case_embd + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.rcond = rcond + self.seed = seed + self.type_map = type_map + self.use_aparam_as_mask = use_aparam_as_mask + self.reinit_exclude(exclude_types) + self.trainable = trainable + # need support for each layer settings + self.trainable = ( + all(self.trainable) if isinstance(self.trainable, list) else self.trainable + ) + self.remove_vaccum_contribution = remove_vaccum_contribution + + self.sr_net_dim_out = self._sr_net_out_dim() + self.lr_net_dim_out = self._lr_net_out_dim() + # init 
constants + if bias_atom_e is None: + bias_atom_e = np.zeros([self.ntypes, self.sr_net_dim_out], dtype=np.float64) + bias_atom_e = torch.tensor( + bias_atom_e, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device + ) + bias_atom_e = bias_atom_e.view([self.ntypes, self.sr_net_dim_out]) + if not self.mixed_types: + assert self.ntypes == bias_atom_e.shape[0], "Element count mismatches!" + self.register_buffer("bias_atom_e", bias_atom_e) + + if self.numb_fparam > 0: + self.register_buffer( + "fparam_avg", + torch.zeros(self.numb_fparam, dtype=self.prec, device=env.DEVICE), + ) + self.register_buffer( + "fparam_inv_std", + torch.ones(self.numb_fparam, dtype=self.prec, device=env.DEVICE), + ) + else: + self.fparam_avg, self.fparam_inv_std = None, None + if self.numb_aparam > 0: + self.register_buffer( + "aparam_avg", + torch.zeros(self.numb_aparam, dtype=self.prec, device=env.DEVICE), + ) + self.register_buffer( + "aparam_inv_std", + torch.ones(self.numb_aparam, dtype=self.prec, device=env.DEVICE), + ) + else: + self.aparam_avg, self.aparam_inv_std = None, None + + if self.dim_case_embd > 0: + self.register_buffer( + "case_embd", + torch.zeros(self.dim_case_embd, dtype=self.prec, device=env.DEVICE), + ) + else: + self.case_embd = None + + if self.default_fparam is not None: + if self.numb_fparam > 0: + assert len(self.default_fparam) == self.numb_fparam, ( + "default_fparam length mismatch!" 
+ ) + self.register_buffer( + "default_fparam_tensor", + torch.tensor( + np.array(self.default_fparam), + dtype=self.prec, + device=env.DEVICE, + ), + ) + else: + self.default_fparam_tensor = None + + in_dim = ( + self.dim_descrpt + + self.numb_fparam + + (0 if self.use_aparam_as_mask else self.numb_aparam) + + self.dim_case_embd + ) + + net_count = self.ntypes if not self.mixed_types else 1 + self.filter_layers_lr = NetworkCollection( + 1 if not self.mixed_types else 0, + self.ntypes, + network_type="fitting_network", + networks=[ + FittingNet( + in_dim, + self.lr_net_dim_out, + self.neuron_lr, + self.activation_function, + self.resnet_dt, + self.precision, + bias_out=True, + seed=child_seed(self.seed, ii * 2), + trainable=self.trainable, + ) + for ii in range(net_count) + ], + ) + self.filter_layers_sr = NetworkCollection( + 1 if not self.mixed_types else 0, + self.ntypes, + network_type="fitting_network", + networks=[ + FittingNet( + in_dim, + self.sr_net_dim_out, + self.neuron_sr, + self.activation_function, + self.resnet_dt, + self.precision, + bias_out=True, + seed=child_seed(self.seed, ii * 2 + 1), + trainable=self.trainable, + ) + for ii in range(net_count) + ], + ) + + for param in self.parameters(): + param.requires_grad = self.trainable + + self.eval_return_middle_output = False + + def reinit_exclude( + self, + exclude_types: list[int] = [], + ) -> None: + self.exclude_types = exclude_types + self.emask = AtomExcludeMask(self.ntypes, self.exclude_types) + + def change_type_map( + self, + type_map: list[str], + model_with_new_type_stat: Optional["LRFittingNet"] = None, + ) -> None: + """Change the type related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. + """ + assert self.type_map is not None, ( + "'type_map' must be defined when performing type changing!" 
+ ) + assert self.mixed_types, "Only models in mixed types can perform type changing!" + remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map) + self.type_map = type_map + self.ntypes = len(type_map) + self.reinit_exclude(map_atom_exclude_types(self.exclude_types, remap_index)) + if has_new_type: + extend_shape = [len(type_map), *list(self.bias_atom_e.shape[1:])] + extend_bias_atom_e = torch.zeros( + extend_shape, + dtype=self.bias_atom_e.dtype, + device=self.bias_atom_e.device, + ) + self.bias_atom_e = torch.cat([self.bias_atom_e, extend_bias_atom_e], dim=0) + self.bias_atom_e = self.bias_atom_e[remap_index] + + def serialize(self) -> dict: + """Serialize the fitting to dict.""" + return { + "@class": "LRFitting", + "@version": 1, + "var_name": self.var_name, + "ntypes": self.ntypes, + "dim_descrpt": self.dim_descrpt, + "dim_out_sr": self.dim_out_sr, + "dim_out_lr": self.dim_out_lr, + "neuron_sr": self.neuron_sr, + "neuron_lr": self.neuron_lr, + "resnet_dt": self.resnet_dt, + "numb_fparam": self.numb_fparam, + "numb_aparam": self.numb_aparam, + "dim_case_embd": self.dim_case_embd, + "default_fparam": self.default_fparam, + "activation_function": self.activation_function, + "precision": self.precision, + "mixed_types": self.mixed_types, + "nets_sr": self.filter_layers_sr.serialize(), + "nets_lr": self.filter_layers_lr.serialize(), + "rcond": self.rcond, + "exclude_types": self.exclude_types, + "@variables": { + "bias_atom_e": to_numpy_array(self.bias_atom_e), + "case_embd": to_numpy_array(self.case_embd), + "fparam_avg": to_numpy_array(self.fparam_avg), + "fparam_inv_std": to_numpy_array(self.fparam_inv_std), + "aparam_avg": to_numpy_array(self.aparam_avg), + "aparam_inv_std": to_numpy_array(self.aparam_inv_std), + }, + "type_map": self.type_map, + # "tot_ener_zero": self.tot_ener_zero , + # "trainable": self.trainable , + # "atom_ener": self.atom_ener , + # "layer_name": self.layer_name , + # "spin": self.spin , + ## NOTICE: not supported 
by far + "tot_ener_zero": False, + "trainable_sr": [self.trainable] * (len(self.neuron_sr) + 1), + "trainable_lr": [self.trainable] * (len(self.neuron_lr) + 1), + "layer_name": None, + "use_aparam_as_mask": self.use_aparam_as_mask, + "spin": None, + } + + @classmethod + def deserialize(cls, data: dict) -> "LRFittingNet": + data = data.copy() + variables = data.pop("@variables") + nets_sr = data.pop("nets_sr") + nets_lr = data.pop("nets_lr") + obj = cls(**data) + for kk in variables.keys(): + obj[kk] = to_torch_tensor(variables[kk]) + obj.filter_layers_sr = NetworkCollection.deserialize(nets_sr) + obj.filter_layers_lr = NetworkCollection.deserialize(nets_lr) + return obj + + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this atomic model.""" + return self.numb_fparam + + def has_default_fparam(self) -> bool: + """Check if the fitting has default frame parameters.""" + return self.default_fparam is not None + + def get_default_fparam(self) -> torch.Tensor | None: + return self.default_fparam_tensor + + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this atomic model.""" + return self.numb_aparam + + # make jit happy + exclude_types: list[int] + + def get_sel_type(self) -> list[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + # make jit happy + sel_type: list[int] = [] + for ii in range(self.ntypes): + if ii not in self.exclude_types: + sel_type.append(ii) + return sel_type + + def get_type_map(self) -> list[str]: + """Get the name to each type of atoms.""" + return self.type_map + + def set_case_embd(self, case_idx: int) -> None: + """ + Set the case embedding of this fitting net by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. 
+ """ + self.case_embd = torch.eye(self.dim_case_embd, dtype=self.prec, device=device)[ + case_idx + ] + + def set_return_middle_output(self, return_middle_output: bool = True) -> None: + self.eval_return_middle_output = return_middle_output + + def __setitem__(self, key: str, value: torch.Tensor) -> None: + if key in ["bias_atom_e"]: + value = value.view([self.ntypes, self._sr_net_out_dim()]) + self.bias_atom_e = value + elif key in ["fparam_avg"]: + self.fparam_avg = value + elif key in ["fparam_inv_std"]: + self.fparam_inv_std = value + elif key in ["aparam_avg"]: + self.aparam_avg = value + elif key in ["aparam_inv_std"]: + self.aparam_inv_std = value + elif key in ["case_embd"]: + self.case_embd = value + elif key in ["scale"]: + self.scale = value + elif key in ["default_fparam_tensor"]: + self.default_fparam_tensor = value + else: + raise KeyError(key) + + def __getitem__(self, key: str) -> torch.Tensor: + if key in ["bias_atom_e"]: + return self.bias_atom_e + elif key in ["fparam_avg"]: + return self.fparam_avg + elif key in ["fparam_inv_std"]: + return self.fparam_inv_std + elif key in ["aparam_avg"]: + return self.aparam_avg + elif key in ["aparam_inv_std"]: + return self.aparam_inv_std + elif key in ["case_embd"]: + return self.case_embd + elif key in ["scale"]: + return self.scale + elif key in ["default_fparam_tensor"]: + return self.default_fparam_tensor + else: + raise KeyError(key) + + def _sr_net_out_dim(self) -> int: + """Set the SRFittingNet output dim.""" + return self.dim_out_sr + + def _lr_net_out_dim(self) -> int: + """Set the LR FittingNet output dim.""" + return self.dim_out_lr + + def _corr_head(self, lr_out: torch.Tensor) -> torch.Tensor: + # TODO: Add latent_charge correction logic after LR output is finalized. 
+        return lr_out
+
+    def _extend_f_avg_std(self, xx: torch.Tensor, nb: int) -> torch.Tensor:
+        """Tile per-frame fparam statistics to shape [nb, numb_fparam]."""
+        return torch.tile(xx.view([1, self.numb_fparam]), [nb, 1])
+
+    def _extend_a_avg_std(self, xx: torch.Tensor, nb: int, nloc: int) -> torch.Tensor:
+        """Tile per-atom aparam statistics to shape [nb, nloc, numb_aparam]."""
+        return torch.tile(xx.view([1, 1, self.numb_aparam]), [nb, nloc, 1])
+
+    def _forward_common(
+        self,
+        descriptor: torch.Tensor,
+        atype: torch.Tensor,
+        gr: torch.Tensor | None = None,
+        g2: torch.Tensor | None = None,
+        h2: torch.Tensor | None = None,
+        fparam: torch.Tensor | None = None,
+        aparam: torch.Tensor | None = None,
+    ) -> dict[str, torch.Tensor]:
+        """Run the SR and LR fitting branches on the descriptor.
+
+        Returns a dict with keys ``"sr"`` (short-range atomic property, with
+        per-type bias) and ``"lr"`` (long-range latent output after
+        ``_corr_head``), plus ``"middle_output"`` when enabled via
+        ``set_return_middle_output``.
+        """
+        xx = descriptor.to(self.prec)
+        nf, nloc, nd = xx.shape
+
+        # Fall back to the configured default fparam when none is supplied.
+        if self.numb_fparam > 0 and fparam is None:
+            assert self.default_fparam_tensor is not None
+            fparam = torch.tile(self.default_fparam_tensor.unsqueeze(0), [nf, 1])
+
+        fparam = fparam.to(self.prec) if fparam is not None else None
+        aparam = aparam.to(self.prec) if aparam is not None else None
+
+        # xx_zeros carries an all-zero descriptor through the same networks so
+        # the vacuum contribution can be subtracted per type.
+        if self.remove_vaccum_contribution is not None:
+            xx_zeros = torch.zeros_like(xx)
+        else:
+            xx_zeros = None
+
+        if nd != self.dim_descrpt:
+            raise ValueError(
+                f"get an input descriptor of dim {nd},"
+                f"which is not consistent with {self.dim_descrpt}."
+            )
+
+        if self.numb_fparam > 0:
+            assert fparam is not None, "fparam should not be None"
+            assert self.fparam_avg is not None
+            assert self.fparam_inv_std is not None
+            if fparam.shape[-1] != self.numb_fparam:
+                raise ValueError(
+                    f"get an input fparam of dim {fparam.shape[-1]}, "
+                    f"which is not consistent with {self.numb_fparam}."
+                )
+            # Standardize the frame parameters and broadcast them per atom.
+            fparam = fparam.view([nf, self.numb_fparam])
+            nb, _ = fparam.shape
+            t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb)
+            t_fparam_inv_std = self._extend_f_avg_std(self.fparam_inv_std, nb)
+            fparam = (fparam - t_fparam_avg) * t_fparam_inv_std
+            fparam = torch.tile(fparam.reshape([nf, 1, -1]), [1, nloc, 1])
+            xx = torch.cat([xx, fparam], dim=-1)
+            if xx_zeros is not None:
+                xx_zeros = torch.cat([xx_zeros, fparam], dim=-1)
+
+        if self.numb_aparam > 0 and not self.use_aparam_as_mask:
+            assert aparam is not None, "aparam should not be None"
+            assert self.aparam_avg is not None
+            assert self.aparam_inv_std is not None
+            if aparam.shape[-1] != self.numb_aparam:
+                raise ValueError(
+                    f"get an input aparam of dim {aparam.shape[-1]}, "
+                    f"which is not consistent with {self.numb_aparam}."
+                )
+            # Standardize the atomic parameters per atom.
+            aparam = aparam.view([nf, -1, self.numb_aparam])
+            nb, nloc, _ = aparam.shape
+            t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc)
+            t_aparam_inv_std = self._extend_a_avg_std(self.aparam_inv_std, nb, nloc)
+            aparam = (aparam - t_aparam_avg) * t_aparam_inv_std
+            xx = torch.cat([xx, aparam], dim=-1)
+            if xx_zeros is not None:
+                xx_zeros = torch.cat([xx_zeros, aparam], dim=-1)
+
+        if self.dim_case_embd > 0:
+            assert self.case_embd is not None
+            case_embd = torch.tile(self.case_embd.reshape([1, 1, -1]), [nf, nloc, 1])
+            xx = torch.cat([xx, case_embd], dim=-1)
+            if xx_zeros is not None:
+                xx_zeros = torch.cat([xx_zeros, case_embd], dim=-1)
+
+        results: dict[str, torch.Tensor] = {}
+        sr_out = self._apply_networks(
+            self.filter_layers_sr,
+            self.neuron_sr,
+            self.sr_net_dim_out,
+            xx,
+            xx_zeros,
+            atype,
+            middle_output=results,
+            bool_bias=True,
+        )
+        lr_out = self._apply_networks(
+            self.filter_layers_lr,
+            self.neuron_lr,
+            self.lr_net_dim_out,
+            xx,
+            xx_zeros,
+            atype,
+            middle_output=results,
+        )
+        # Zero out contributions of excluded atom types.
+        mask = self.emask(atype).to(torch.bool)
+        sr_out = torch.where(mask[:, :, None], sr_out, 0.0)
+        lr_out = torch.where(mask[:, :, None], lr_out, 0.0)
+        lr_out = self._corr_head(lr_out)
+        results.update({"sr": sr_out, "lr": lr_out})
+        return results
+
+    def _apply_networks(
+        self,
+        layers: NetworkCollection,
+        neuron: list[int],
+        dim_out: int,
+        xx: torch.Tensor,
+        xx_zeros: torch.Tensor | None,
+        atype: torch.Tensor,
+        middle_output: dict[str, torch.Tensor] | None,
+        bool_bias: bool = False,
+    ) -> torch.Tensor:
+        """Evaluate one fitting branch (SR or LR) on the augmented descriptor.
+
+        When ``bool_bias`` is True the per-type ``bias_atom_e`` is added to the
+        network output (used by the SR branch only).
+        """
+        nf, nloc, _ = xx.shape
+        outs = torch.zeros((nf, nloc, dim_out), dtype=self.prec, device=xx.device)
+        if self.mixed_types:
+            atom_property = layers.networks[0](xx)
+            if self.eval_return_middle_output and middle_output is not None:
+                middle_output["middle_output"] = layers.networks[0].call_until_last(xx)
+            if xx_zeros is not None:
+                atom_property -= layers.networks[0](xx_zeros)
+            if bool_bias:
+                # BUGFIX: the per-type branch below adds bias_atom_e, but the
+                # mixed-types branch previously dropped it, so mixed_types
+                # models silently lost the per-element energy shift.
+                atom_property = atom_property + self.bias_atom_e[atype].to(self.prec)
+            outs = outs + atom_property
+        else:
+            if self.eval_return_middle_output and middle_output is not None:
+                outs_middle = torch.zeros(
+                    (nf, nloc, neuron[-1]),
+                    dtype=self.prec,
+                    device=xx.device,
+                )
+                for type_i, ll in enumerate(layers.networks):
+                    # BUGFIX: tile the raw (nf, nloc, 1) mask to neuron[-1]
+                    # directly; the previous code tiled it to dim_out first and
+                    # then again by neuron[-1], which broke broadcasting for
+                    # dim_out > 1.
+                    type_mask = (atype == type_i).unsqueeze(-1)
+                    middle_output_type = ll.call_until_last(xx)
+                    middle_output_type = torch.where(
+                        torch.tile(type_mask, (1, 1, neuron[-1])),
+                        middle_output_type,
+                        0.0,
+                    )
+                    outs_middle = outs_middle + middle_output_type
+                middle_output["middle_output"] = outs_middle
+            for type_i, ll in enumerate(layers.networks):
+                mask = (atype == type_i).unsqueeze(-1)
+                mask = torch.tile(mask, (1, 1, dim_out))
+                atom_property = ll(xx)
+                if xx_zeros is not None:
+                    assert self.remove_vaccum_contribution is not None
+                    # Only subtract the vacuum term for types that request it.
+                    if not (
+                        len(self.remove_vaccum_contribution) > type_i
+                        and not self.remove_vaccum_contribution[type_i]
+                    ):
+                        atom_property -= ll(xx_zeros)
+                if bool_bias:
+                    atom_property = atom_property + self.bias_atom_e[type_i].to(
+                        self.prec
+                    )
+                atom_property = torch.where(mask, atom_property, 0.0)
+                outs = outs + atom_property
+        return outs
diff --git
a/deepmd/pt/model/task/sog_energy_fitting.py b/deepmd/pt/model/task/sog_energy_fitting.py new file mode 100644 index 0000000000..eff762eca7 --- /dev/null +++ b/deepmd/pt/model/task/sog_energy_fitting.py @@ -0,0 +1,356 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from typing import ( + Any, +) + +import numpy as np +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, + fitting_check_output, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION +device = env.DEVICE + +log = logging.getLogger(__name__) +from deepmd.pt.model.task.lr_fitting import ( + LRFittingNet, +) + +SOG_DEFAULT_AMPLITUDE = to_numpy_array( + np.array( + [ + 0.2750, + 0.1375, + 0.0688, + 0.0344, + 0.0172, + 0.0086, + 0.0043, + 0.0021, + 0.0011, + 0.0005, + 0.0003, + 0.0001, + ] + ) +) +SOG_DEFAULT_SHIFT = to_numpy_array( + np.array( + [ + 2.8, + 5.7, + 11.4, + 22.7, + 45.5, + 91.0, + 182.0, + 364.0, + 728.0, + 1456.0, + 2912.0, + 5823.9, + ] + ) +) + + +@LRFittingNet.register("sog_energy") +@fitting_check_output +class SOGEnergyFittingNet(LRFittingNet): + """Construct a SOG sr+lr interactions fitting net. + + Parameters + ---------- + var_name : str + The atomic property to fit. + ntypes : int + Element count. + dim_descrpt : int + Embedding width per atom. + dim_out_sr : int + The output dimension of the sr fitting net. + dim_out_lr : int + The output dimension of the lr fitting net. + neuron_sr : list[int] + Number of neurons in each hidden layers of the sr fitting net. + neuron_lr : list[int] + Number of neurons in each hidden layers of the lr fitting net. + bias_atom_e : torch.Tensor, optional + Average energy per atom for each element. + resnet_dt : bool + Using time-step in the ResNet construction. + numb_fparam : int + Number of frame parameters. 
+ numb_aparam : int + Number of atomic parameters. + dim_case_embd : int + Dimension of case specific embedding. + activation_function : str + Activation function. + precision : str + Numerical precision. + mixed_types : bool + If true, use a uniform fitting net for all atom types, otherwise use + different fitting nets for different atom types. + rcond : float, optional + The condition number for the regression of atomic energy. + seed : int, optional + Random seed. + exclude_types: list[int] + Atomic contributions of the excluded atom types are set zero. + trainable : Union[list[bool], bool] + If the parameters in the fitting net are trainable. + Now this only supports setting all the parameters in the fitting net at one state. + When in list[bool], the trainable will be True only if all the boolean parameters are True. + remove_vaccum_contribution: list[bool], optional + Remove vacuum contribution before the bias is added. The list assigned each + type. For `mixed_types` provide `[True]`, otherwise it should be a list of the same + length as `ntypes` signaling if or not removing the vacuum contribution for the atom types in the list. + type_map: list[str], Optional + A list of strings. Give the name to each type of atoms. + use_aparam_as_mask: bool + If True, the aparam will not be used in fitting net for embedding. + default_fparam: list[float], optional + The default frame parameter. If set, when `fparam.npy` files are not included in the data system, + this value will be used as the default value for the frame parameter in the fitting net. + n_dl : int + NUFFT long-range grid density control factor. + remove_self_interaction : bool + If True, remove self interaction term in long-range correction. 
+ """ + + def __init__( + self, + var_name: str, + ntypes: int, + dim_descrpt: int, + dim_out_sr: int, + dim_out_lr: int, + neuron_sr: list[int] = [128, 128, 128], + neuron_lr: list[int] = [128, 128, 128], + bias_atom_e: torch.Tensor | None = None, + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + dim_case_embd: int = 0, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + mixed_types: bool = True, + rcond: float | None = None, + seed: int | list[int] | None = None, + exclude_types: list[int] = [], + trainable: bool | list[bool] = True, + remove_vaccum_contribution: list[bool] | None = None, + type_map: list[str] | None = None, + use_aparam_as_mask: bool = False, + default_fparam: list[float] | None = None, + shift: list[float] | torch.Tensor | None = None, + amplitude: list[float] | torch.Tensor | None = None, + n_dl: int = 1, + remove_self_interaction: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + var_name=var_name, + ntypes=ntypes, + dim_descrpt=dim_descrpt, + dim_out_sr=dim_out_sr, + dim_out_lr=dim_out_lr, + neuron_sr=neuron_sr, + neuron_lr=neuron_lr, + bias_atom_e=bias_atom_e, + resnet_dt=resnet_dt, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + dim_case_embd=dim_case_embd, + activation_function=activation_function, + precision=precision, + mixed_types=mixed_types, + rcond=rcond, + seed=seed, + exclude_types=exclude_types, + trainable=trainable, + remove_vaccum_contribution=remove_vaccum_contribution, + type_map=type_map, + use_aparam_as_mask=use_aparam_as_mask, + default_fparam=default_fparam, + **kwargs, + ) + if isinstance(shift, (list, tuple)): + shift = to_numpy_array(np.array(shift)) + if isinstance(amplitude, (list, tuple)): + amplitude = to_numpy_array(np.array(amplitude)) + shift_tensor = to_torch_tensor(shift) + amplitude_tensor = to_torch_tensor(amplitude) + if shift_tensor is None: + shift_tensor = to_torch_tensor(SOG_DEFAULT_SHIFT) + if amplitude_tensor is None: + 
amplitude_tensor = to_torch_tensor(SOG_DEFAULT_AMPLITUDE) + + shift_tensor = shift_tensor.to(dtype=dtype, device=device) + amplitude_tensor = amplitude_tensor.to(dtype=dtype, device=device) + pi_tensor = torch.tensor(torch.pi, dtype=dtype, device=device) + sqr_pi_tensor = torch.sqrt(pi_tensor) + shift_safe = torch.clamp( + shift_tensor, + min=torch.finfo(shift_tensor.dtype).eps, + ) + wl_tensor = amplitude_tensor * (sqr_pi_tensor**3) * (shift_safe**3) + sl_tensor = -torch.log(2.0 / shift_safe) + + self.n_dl = max(1, int(n_dl)) + self.wl = torch.nn.Parameter( + wl_tensor, + requires_grad=bool(self.trainable), + ) + self.sl = torch.nn.Parameter( + sl_tensor, + requires_grad=bool(self.trainable), + ) + self.remove_self_interaction = bool(remove_self_interaction) + self._nufft_fallback_warned = False + + def output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + name="latent_charge", + shape=[self.dim_out_lr], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), + ] + ) + + @staticmethod + def _wl_sl_to_shift_amplitude( + wl_tensor: torch.Tensor, + sl_tensor: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + pi_tensor = torch.tensor( + torch.pi, + dtype=sl_tensor.dtype, + device=sl_tensor.device, + ) + sqr_pi_tensor = torch.sqrt(pi_tensor) + shift_tensor = 2.0 * torch.exp(sl_tensor) + amplitude_tensor = wl_tensor / ((sqr_pi_tensor**3) * (shift_tensor**3)) + return shift_tensor, amplitude_tensor + + @staticmethod + def _shift_amplitude_to_wl_sl( + shift_tensor: torch.Tensor, + amplitude_tensor: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + pi_tensor = torch.tensor( + torch.pi, + dtype=shift_tensor.dtype, + device=shift_tensor.device, + ) + sqr_pi_tensor = torch.sqrt(pi_tensor) + shift_safe = torch.clamp( + shift_tensor, + min=torch.finfo(shift_tensor.dtype).eps, + ) + 
wl_tensor = amplitude_tensor * (sqr_pi_tensor**3) * (shift_safe**3) + sl_tensor = -torch.log(2.0 / shift_safe) + return wl_tensor, sl_tensor + + def serialize(self) -> dict: + data = super().serialize() + data["type"] = "sog_energy" + variables = data["@variables"] + variables["wl"] = to_numpy_array(self.wl) + variables["sl"] = to_numpy_array(self.sl) + shift_tensor, amplitude_tensor = self._wl_sl_to_shift_amplitude( + self.wl, + self.sl, + ) + variables["shift"] = to_numpy_array(shift_tensor) + variables["amplitude"] = to_numpy_array(amplitude_tensor) + data["n_dl"] = self.n_dl + data["remove_self_interaction"] = bool(self.remove_self_interaction) + return data + + @classmethod + def deserialize(cls, data: dict) -> "SOGEnergyFittingNet": + data = data.copy() + + variables = data.get("@variables", {}).copy() + + wl_tensor = to_torch_tensor(variables.pop("wl", None)) + sl_tensor = to_torch_tensor(variables.pop("sl", None)) + shift_tensor = to_torch_tensor(variables.pop("shift", None)) + amplitude_tensor = to_torch_tensor(variables.pop("amplitude", None)) + data["@variables"] = variables + + obj = super().deserialize(data) + + with torch.no_grad(): + if wl_tensor is not None and sl_tensor is not None: + obj.wl.copy_(wl_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device)) + obj.sl.copy_(sl_tensor.to(dtype=obj.sl.dtype, device=obj.sl.device)) + elif shift_tensor is not None and amplitude_tensor is not None: + wl_tensor, sl_tensor = cls._shift_amplitude_to_wl_sl( + shift_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device), + amplitude_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device), + ) + obj.wl.copy_(wl_tensor) + obj.sl.copy_(sl_tensor) + return obj + + def _kernel_params(self) -> tuple[torch.Tensor, torch.Tensor]: + return self.wl, self.sl + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: 
torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + out = self._forward_common( + descriptor=descriptor, + atype=atype, + gr=gr, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + result = { + "energy": out["sr"], + "latent_charge": out["lr"], + } + if "middle_output" in out: + result["middle_output"] = out["middle_output"] + return result + + # make jit happy with torch 2.0.0 + exclude_types: list[int] diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index 8f32ca660c..51b2b8517e 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -441,6 +441,10 @@ def __init__( # Model --------------------------------------------------------------- self.model = get_model(deepcopy(model_params)).to(DEVICE) + for module in self.model.modules(): + set_freq = getattr(module, "set_debug_print_freq", None) + if callable(set_freq): + set_freq(self.disp_freq) # Loss ---------------------------------------------------------------- self.loss = get_loss( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 5fe1d4f3f1..fe4e63f387 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -57,6 +57,8 @@ doc_ener = "Fit an energy model (potential energy surface)." doc_dos = "Fit a density of states model. The total density of states / site-projected density of states labels should be provided by `dos.npy` or `atom_dos.npy` in each data system. The file has number of frames lines and number of energy grid columns (times number of atoms in `atom_dos.npy`). See `loss` parameter." doc_dipole = "Fit an atomic dipole model. Global dipole labels or atomic dipole labels for all the selected atoms (see `sel_type`) should be provided by `dipole.npy` in each data system. The file either has number of frames lines and 3 times of number of selected atoms columns, or has number of frames lines and 3 columns. See `loss` parameter." +doc_sog = "Fit a SOG-energy model (SR+LR)." 
+doc_les = "Fit a LES-energy model (SR+LR)." doc_polar = "Fit an atomic polarizability model. Global polarizazbility labels or atomic polarizability labels for all the selected atoms (see `sel_type`) should be provided by `polarizability.npy` in each data system. The file with has number of frames lines and 9 times of number of selected atoms columns, or has number of frames lines and 9 columns. See `loss` parameter." # modifier doc_dipole_charge = "Use WFCC to model the electronic structure of the system. Correct the long-range interaction." @@ -2076,6 +2078,396 @@ def fitting_polar() -> list[Argument]: # def fitting_global_polar(): # return fitting_polar() +@fitting_args_plugin.register("sog_energy", doc=doc_sog) +def fitting_sog_energy() -> list[Argument]: + doc_var_name = ( + "The atomic property name used by the LR fitting net. Usually set to `energy`." + ) + doc_dim_out_sr = "The output dimension of the short-range fitting branch." + doc_dim_out_lr = "The output dimension of the long-range fitting branch." + doc_neuron_sr = ( + "The number of neurons in each hidden layer of the short-range fitting net." + ) + doc_neuron_lr = ( + "The number of neurons in each hidden layer of the long-range fitting net." + ) + doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." + doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams." + doc_default_fparam = "The default frame parameter. If set, when `fparam.npy` files are not included in the data system, this value will be used as the default value for the frame parameter in the fitting net." + doc_dim_case_embd = "The dimension of the case embedding embedding. When training or fine-tuning a multitask model with case embedding embeddings, this number should be set to the number of model branches." 
+ doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.' + doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." + doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection' + doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be bool or list[bool]. When list[bool] is given, all values must be True to make parameters trainable." + doc_rcond = "The condition number used to determine the initial energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details." + doc_seed = "Random seed for parameter initialization of the fitting net" + doc_exclude_types = ( + "The excluded atom types whose atomic contributions are set to zero." + ) + doc_use_aparam_as_mask = ( + "Whether to use the aparam as a mask in input." + "If True, the aparam will not be used in fitting net for embedding." + ) + doc_shift = "Shift values of the SOG long-range correction kernels." + doc_amplitude = "Amplitude values of the SOG long-range correction kernels." + doc_n_dl = "NUFFT long-range grid density control factor." + doc_remove_self_interaction = ( + "Whether to remove self interaction term in long-range correction." 
+ ) + + return [ + Argument( + "var_name", + str, + optional=True, + default="energy", + doc=doc_only_pt_supported + doc_var_name, + ), + Argument( + "dim_out_sr", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_dim_out_sr, + ), + Argument( + "dim_out_lr", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_dim_out_lr, + ), + Argument( + "neuron_sr", + list[int], + optional=True, + default=[128, 128, 128], + doc=doc_only_pt_supported + doc_neuron_sr, + ), + Argument( + "neuron_lr", + list[int], + optional=True, + default=[128, 128, 128], + doc=doc_only_pt_supported + doc_neuron_lr, + ), + Argument( + "numb_fparam", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_numb_fparam, + ), + Argument( + "numb_aparam", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_numb_aparam, + ), + Argument( + "default_fparam", + list[float], + optional=True, + default=None, + doc=doc_only_pt_supported + doc_default_fparam, + ), + Argument( + "dim_case_embd", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_dim_case_embd, + ), + Argument( + "activation_function", + str, + optional=True, + default="tanh", + doc=doc_activation_function, + ), + Argument("precision", str, optional=True, default="default", doc=doc_precision), + Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt), + Argument( + "trainable", + [list[bool], bool], + optional=True, + default=True, + doc=doc_only_pt_supported + doc_trainable, + ), + Argument( + "rcond", + [float, type(None)], + optional=True, + default=None, + doc=doc_only_pt_supported + doc_rcond, + ), + Argument("seed", [int, None], optional=True, doc=doc_seed), + Argument( + "exclude_types", + list[int], + optional=True, + default=[], + doc=doc_only_pt_supported + doc_exclude_types, + ), + Argument( + "use_aparam_as_mask", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_use_aparam_as_mask, + ), + 
Argument( + "n_dl", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_n_dl, + ), + Argument( + "remove_self_interaction", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_remove_self_interaction, + ), + Argument( + "shift", + list[float], + optional=True, + default=[ + 0.2750, + 0.1375, + 0.0688, + 0.0344, + 0.0172, + 0.0086, + 0.0043, + 0.0021, + 0.0011, + 0.0005, + 0.0003, + 0.0001, + ], + doc=doc_only_pt_supported + doc_shift, + ), + Argument( + "amplitude", + list[float], + optional=True, + default=[ + 2.8, + 5.7, + 11.4, + 22.7, + 45.5, + 91.0, + 182.0, + 364.0, + 728.0, + 1456.0, + 2912.0, + 5823.9, + ], + doc=doc_only_pt_supported + doc_amplitude, + ), + ] + + +@fitting_args_plugin.register("les_energy", doc=doc_les) +def fitting_les_energy() -> list[Argument]: + doc_var_name = ( + "The atomic property name used by the LR fitting net. Usually set to `energy`." + ) + doc_dim_out_sr = "The output dimension of the short-range fitting branch." + doc_dim_out_lr = "The output dimension of the long-range fitting branch." + doc_neuron_sr = ( + "The number of neurons in each hidden layer of the short-range fitting net." + ) + doc_neuron_lr = ( + "The number of neurons in each hidden layer of the long-range fitting net." + ) + doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." + doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams." + doc_default_fparam = "The default frame parameter. If set, when `fparam.npy` files are not included in the data system, this value will be used as the default value for the frame parameter in the fitting net." + doc_dim_case_embd = "The dimension of the case embedding embedding. When training or fine-tuning a multitask model with case embedding embeddings, this number should be set to the number of model branches." 
+ doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.' + doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." + doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection' + doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be bool or list[bool]. When list[bool] is given, all values must be True to make parameters trainable." + doc_rcond = "The condition number used to determine the initial energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details." + doc_seed = "Random seed for parameter initialization of the fitting net" + doc_exclude_types = ( + "The excluded atom types whose atomic contributions are set to zero." + ) + doc_use_aparam_as_mask = ( + "Whether to use the aparam as a mask in input." + "If True, the aparam will not be used in fitting net for embedding." + ) + doc_shift = "Shift values of the LES long-range correction kernels." + doc_amplitude = "Amplitude values of the LES long-range correction kernels." + doc_n_dl = "NUFFT long-range grid density control factor." + doc_remove_self_interaction = ( + "Whether to remove self interaction term in long-range correction." 
+ ) + + return [ + Argument( + "var_name", + str, + optional=True, + default="energy", + doc=doc_only_pt_supported + doc_var_name, + ), + Argument( + "dim_out_sr", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_dim_out_sr, + ), + Argument( + "dim_out_lr", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_dim_out_lr, + ), + Argument( + "neuron_sr", + list[int], + optional=True, + default=[128, 128, 128], + doc=doc_only_pt_supported + doc_neuron_sr, + ), + Argument( + "neuron_lr", + list[int], + optional=True, + default=[128, 128, 128], + doc=doc_only_pt_supported + doc_neuron_lr, + ), + Argument( + "numb_fparam", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_numb_fparam, + ), + Argument( + "numb_aparam", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_numb_aparam, + ), + Argument( + "default_fparam", + list[float], + optional=True, + default=None, + doc=doc_only_pt_supported + doc_default_fparam, + ), + Argument( + "dim_case_embd", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_dim_case_embd, + ), + Argument( + "activation_function", + str, + optional=True, + default="tanh", + doc=doc_activation_function, + ), + Argument("precision", str, optional=True, default="default", doc=doc_precision), + Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt), + Argument( + "trainable", + [list[bool], bool], + optional=True, + default=True, + doc=doc_only_pt_supported + doc_trainable, + ), + Argument( + "rcond", + [float, type(None)], + optional=True, + default=None, + doc=doc_only_pt_supported + doc_rcond, + ), + Argument("seed", [int, None], optional=True, doc=doc_seed), + Argument( + "exclude_types", + list[int], + optional=True, + default=[], + doc=doc_only_pt_supported + doc_exclude_types, + ), + Argument( + "use_aparam_as_mask", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_use_aparam_as_mask, + ), + 
Argument( + "n_dl", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_n_dl, + ), + Argument( + "remove_self_interaction", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_remove_self_interaction, + ), + Argument( + "shift", + list[float], + optional=True, + default=[ + 0.2750, + 0.1375, + 0.0688, + 0.0344, + 0.0172, + 0.0086, + 0.0043, + 0.0021, + 0.0011, + 0.0005, + 0.0003, + 0.0001, + ], + doc=doc_only_pt_supported + doc_shift, + ), + Argument( + "amplitude", + list[float], + optional=True, + default=[ + 2.8, + 5.7, + 11.4, + 22.7, + 45.5, + 91.0, + 182.0, + 364.0, + 728.0, + 1456.0, + 2912.0, + 5823.9, + ], + doc=doc_only_pt_supported + doc_amplitude, + ), + ] @fitting_args_plugin.register("dipole", doc=doc_dipole) diff --git a/examples/water/dpa3/dpa3.hdf5 b/examples/water/dpa3/dpa3.hdf5 new file mode 100644 index 0000000000..90ae7bb1d6 Binary files /dev/null and b/examples/water/dpa3/dpa3.hdf5 differ diff --git a/examples/water/dpa3/input_torch_copy.json b/examples/water/dpa3/input_torch_copy.json new file mode 100644 index 0000000000..3c951b6235 --- /dev/null +++ b/examples/water/dpa3/input_torch_copy.json @@ -0,0 +1,103 @@ +{ + "_comment": "that's all", + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 128, + "e_dim": 64, + "a_dim": 32, + "nlayers": 6, + "e_rcut": 6.0, + "e_rcut_smth": 5.3, + "e_sel": 120, + "a_rcut": 4.0, + "a_rcut_smth": 3.5, + "a_sel": 30, + "axis_neuron": 4, + "fix_stat_std": 0.3, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": true, + "update_angle": true, + "smooth_edge_update": true, + "edge_init_use_dist": true, + "use_exp_switch": true, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const" + }, + "activation_function": "silut:10.0", + "use_tebd_bias": false, + "precision": "float32", + "concat_output_tebd": false, + "seed": 1 + }, + "fitting_net": { + "neuron": 
[ + 240, + 240, + 240 + ], + "resnet_dt": true, + "precision": "float32", + "activation_function": "silut:10.0", + "seed": 1, + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3e-5, + "_comment": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.2, + "limit_pref_e": 20, + "start_pref_f": 100, + "limit_pref_f": 60, + "start_pref_v": 0.02, + "limit_pref_v": 1, + "_comment": " that's all" + }, + "optimizer": { + "type": "AdamW", + "adam_beta1": 0.9, + "adam_beta2": 0.999, + "weight_decay": 0.001 + }, + "training": { + "stat_file": "./dpa3.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "numb_steps": 500000, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 2000, + "_comment": "that's all" + } +} diff --git a/examples/water/sog/README.md b/examples/water/sog/README.md new file mode 100644 index 0000000000..a70130de77 --- /dev/null +++ b/examples/water/sog/README.md @@ -0,0 +1,16 @@ +# Input for the SOG model + +This directory provides a SOG training example based on the same water data split and training layout used by `examples/water/dpa3`. + +## Run + +```bash +cd examples/water/sog +dp --pt train input_torch.json --skip-neighbor-stat +``` + +## Notes + +- Descriptor: DPA3 (`model.descriptor.type = "dpa3"`) +- Fitting: SOG energy (`model.fitting_net.type = "sog_energy"`) +- Data systems are reused from `examples/water/data/data_0` to `data_3`. 
diff --git a/examples/water/sog/ab_retain_graph.py b/examples/water/sog/ab_retain_graph.py new file mode 100644 index 0000000000..96b9044fa9 --- /dev/null +++ b/examples/water/sog/ab_retain_graph.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import ( + annotations, +) + +import json +import time +from pathlib import ( + Path, +) +from types import ( + MethodType, +) + +import torch + +from deepmd.pt.model.model import ( + get_model, +) + + +def sync(dev: torch.device) -> None: + if dev.type == "cuda": + torch.cuda.synchronize(dev) + + +def build_input( + nf: int = 1, + nloc: int = 192, + box_len: float = 20.0, + device: str = "cuda", + dtype: torch.dtype = torch.float32, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.device]: + dev = torch.device(device if torch.cuda.is_available() else "cpu") + g = torch.Generator(device=dev) + g.manual_seed(1234) + coord = torch.rand((nf, nloc, 3), device=dev, dtype=dtype, generator=g) * box_len + atype = torch.zeros((nf, nloc), device=dev, dtype=torch.long) + atype[:, 1::3] = 1 + atype[:, 2::3] = 1 + box = torch.zeros((nf, 3, 3), device=dev, dtype=dtype) + box[:, 0, 0] = box_len + box[:, 1, 1] = box_len + box[:, 2, 2] = box_len + return coord, atype, box, dev + + +def bench( + model, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor, + reps: int = 20, + warmup: int = 5, +) -> float: + model.eval() + for _ in range(warmup): + _ = model(coord, atype, box=box) + sync(coord.device) + t0 = time.perf_counter() + for _ in range(reps): + _ = model(coord, atype, box=box) + sync(coord.device) + return (time.perf_counter() - t0) / reps * 1000.0 + + +def make_patched_apply(): + def patched_apply(self, model_ret, extended_coord, nlist, box, do_atomic_virial): + if box is None or "latent_charge" not in model_ret: + return model_ret + nf, nloc, _ = nlist.shape + nall = extended_coord.shape[1] + coord_local = extended_coord[:, :nloc, :] + box_local = 
box.view(nf, 3, 3) + latent_charge = model_ret["latent_charge"] + corr_redu = self._compute_sog_frame_correction( + coord_local, latent_charge, box_local + ) + model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to( + model_ret["energy_redu"].dtype + ) + + if self.do_grad_r("energy") or self.do_grad_c("energy"): + corr_force_local = -torch.autograd.grad( + corr_redu.sum(), + coord_local, + create_graph=self.training, + retain_graph=False, + )[0].view(nf, nloc, 3) + + corr_force_ext = torch.zeros( + (nf, nall, 3), + dtype=corr_force_local.dtype, + device=corr_force_local.device, + ) + corr_force_ext[:, :nloc, :] = corr_force_local + if "energy_derv_r" in model_ret: + model_ret["energy_derv_r"] = model_ret[ + "energy_derv_r" + ] + corr_force_ext.unsqueeze(-2).to(model_ret["energy_derv_r"].dtype) + + if self.do_grad_c("energy"): + corr_virial_local = torch.einsum( + "fai,faj->faij", + corr_force_local, + coord_local, + ).reshape(nf, nloc, 1, 9) + corr_virial_redu = corr_virial_local.sum(dim=1) + if "energy_derv_c_redu" in model_ret: + model_ret["energy_derv_c_redu"] = model_ret[ + "energy_derv_c_redu" + ] + corr_virial_redu.to(model_ret["energy_derv_c_redu"].dtype) + if do_atomic_virial and "energy_derv_c" in model_ret: + corr_atom_virial = torch.zeros( + (nf, nall, 1, 9), + dtype=corr_virial_local.dtype, + device=corr_virial_local.device, + ) + corr_atom_virial[:, :nloc, :, :] = corr_virial_local + model_ret["energy_derv_c"] = model_ret[ + "energy_derv_c" + ] + corr_atom_virial.to(model_ret["energy_derv_c"].dtype) + + return model_ret + + return patched_apply + + +def main() -> None: + cfg = json.loads(Path("examples/water/sog/input_torch.json").read_text())["model"] + coord, atype, box, dev = build_input() + model = get_model(cfg).to(dev) + + base = bench(model, coord, atype, box) + + orig_apply = model._apply_frame_correction_lower + model._apply_frame_correction_lower = MethodType(make_patched_apply(), model) + patched = bench(model, coord, atype, box) 
+ + model._apply_frame_correction_lower = orig_apply + out0 = model(coord, atype, box=box) + model._apply_frame_correction_lower = MethodType(make_patched_apply(), model) + out1 = model(coord, atype, box=box) + + max_de = (out0["energy"] - out1["energy"]).abs().max().item() + max_df = (out0["force"] - out1["force"]).abs().max().item() + + print(f"baseline_ms={base:.3f}") + print(f"retain_graph_false_ms={patched:.3f}") + print(f"speedup={(base / patched):.3f}x") + print(f"max|dE|={max_de:.3e}") + print(f"max|dF|={max_df:.3e}") + + +if __name__ == "__main__": + main() diff --git a/examples/water/sog/check_sog_consistency_with_cace.py b/examples/water/sog/check_sog_consistency_with_cace.py new file mode 100644 index 0000000000..4efd9e3b96 --- /dev/null +++ b/examples/water/sog/check_sog_consistency_with_cace.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import ( + annotations, +) + +import importlib.util +import json +import pathlib + +import torch + +from deepmd.pt.model.model import ( + get_model, +) + + +def main() -> None: + cfg = json.loads(pathlib.Path("examples/water/sog/input_torch.json").read_text())[ + "model" + ] + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = get_model(cfg).to(dev).eval() + fit = model.get_fitting_net() + + sog_path = pathlib.Path("/data/zyjin/cace/SOG-Net/CACE-SOG/cace/modules/sog.py") + spec = importlib.util.spec_from_file_location("cace_sog_module", sog_path) + if spec is None or spec.loader is None: + raise RuntimeError("Failed to load cace sog.py module") + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + SOGPotential = mod.SOGPotential + + cace_sog = SOGPotential(N_dl=int(fit.n_dl), Periodic=True).to(dev).eval() + with torch.no_grad(): + cace_sog.wl.copy_(fit.wl.detach().to(cace_sog.wl.dtype).to(dev)) + cace_sog.sl.copy_(fit.sl.detach().to(cace_sog.sl.dtype).to(dev)) + + nf, nloc, nq = 3, 32, 1 + coord = 
torch.rand(nf, nloc, 3, device=dev, dtype=torch.float32) * 10.0 + box = torch.zeros(nf, 3, 3, device=dev, dtype=torch.float32) + box[:, 0, 0] = 10.0 + box[:, 1, 1] = 11.0 + box[:, 2, 2] = 12.0 + latent = torch.randn(nf, nloc, nq, device=dev, dtype=torch.float32) + + with torch.no_grad(): + corr_deepmd = model._compute_sog_frame_correction(coord, latent, box).squeeze( + -1 + ) + corr_cace = [] + for i in range(nf): + v = cace_sog.compute_potential_SOG_triclinic_NUFFT( + coord[i], latent[i], box[i] + ) + corr_cace.append(v.squeeze()) + corr_cace = torch.stack(corr_cace) + + abs_diff = (corr_deepmd - corr_cace).abs() + rel_diff = abs_diff / torch.clamp(corr_cace.abs(), min=1e-8) + + print("corr_deepmd", corr_deepmd.detach().cpu().numpy()) + print("corr_cace ", corr_cace.detach().cpu().numpy()) + print("max_abs_diff", abs_diff.max().item()) + print("mean_abs_diff", abs_diff.mean().item()) + print("max_rel_diff", rel_diff.max().item()) + + +if __name__ == "__main__": + main() diff --git a/examples/water/sog/compare_sog_dpa3_timing.py b/examples/water/sog/compare_sog_dpa3_timing.py new file mode 100644 index 0000000000..d19098fb7b --- /dev/null +++ b/examples/water/sog/compare_sog_dpa3_timing.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import ( + annotations, +) + +import json +import time +from pathlib import ( + Path, +) + +import torch + +from deepmd.pt.model.model import ( + get_model, +) + + +def sync(dev: torch.device) -> None: + if dev.type == "cuda": + torch.cuda.synchronize(dev) + + +def build_input( + nf: int = 1, + nloc: int = 192, + box_len: float = 20.0, + device: str = "cuda", + dtype: torch.dtype = torch.float32, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.device]: + dev = torch.device(device if torch.cuda.is_available() else "cpu") + coord = torch.rand((nf, nloc, 3), device=dev, dtype=dtype) * box_len + atype = torch.zeros((nf, nloc), device=dev, dtype=torch.long) + atype[:, 
1::3] = 1 + atype[:, 2::3] = 1 + box = torch.zeros((nf, 3, 3), device=dev, dtype=dtype) + box[:, 0, 0] = box_len + box[:, 1, 1] = box_len + box[:, 2, 2] = box_len + return coord, atype, box, dev + + +def bench_model( + model, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor, + warmup: int = 5, + reps: int = 30, +) -> float: + model.eval() + for _ in range(warmup): + _ = model(coord, atype, box=box) + sync(coord.device) + t0 = time.perf_counter() + for _ in range(reps): + _ = model(coord, atype, box=box) + sync(coord.device) + return (time.perf_counter() - t0) / reps * 1000.0 + + +def main() -> None: + sog_cfg = json.loads(Path("examples/water/sog/input_torch.json").read_text())[ + "model" + ] + dpa_cfg = json.loads(Path("examples/water/dpa3/input_torch_copy.json").read_text())[ + "model" + ] + + coord, atype, box, dev = build_input() + + sog = get_model(sog_cfg).to(dev) + dpa = get_model(dpa_cfg).to(dev) + if hasattr(sog.get_fitting_net(), "n_dl"): + sog.get_fitting_net().n_dl = 2 + + sog_ms = bench_model(sog, coord, atype, box) + + orig = sog._apply_frame_correction_lower + sog._apply_frame_correction_lower = lambda model_ret, *args, **kwargs: model_ret + sog_nocorr_ms = bench_model(sog, coord, atype, box) + sog._apply_frame_correction_lower = orig + + dpa_ms = bench_model(dpa, coord, atype, box) + + print(f"dpa3_total_ms={dpa_ms:.3f}") + print(f"sog_total_ms={sog_ms:.3f}") + print(f"sog_without_frame_corr_ms={sog_nocorr_ms:.3f}") + print(f"sog_extra_frame_corr_ms={sog_ms - sog_nocorr_ms:.3f}") + print(f"sog_vs_dpa3_delta_ms={sog_ms - dpa_ms:.3f}") + + +if __name__ == "__main__": + main() diff --git a/examples/water/sog/input_torch.json b/examples/water/sog/input_torch.json new file mode 100644 index 0000000000..91ee791564 --- /dev/null +++ b/examples/water/sog/input_torch.json @@ -0,0 +1,136 @@ +{ + "_comment": "SOG example based on water/dpa3 layout", + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa3", + "repflow": 
{ + "n_dim": 128, + "e_dim": 64, + "a_dim": 32, + "nlayers": 6, + "e_rcut": 6.0, + "e_rcut_smth": 5.3, + "e_sel": 120, + "a_rcut": 4.0, + "a_rcut_smth": 3.5, + "a_sel": 30, + "axis_neuron": 4, + "fix_stat_std": 0.3, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": true, + "update_angle": true, + "smooth_edge_update": true, + "edge_init_use_dist": true, + "use_exp_switch": true, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const" + }, + "activation_function": "silut:10.0", + "use_tebd_bias": false, + "precision": "float32", + "concat_output_tebd": false, + "seed": 1 + }, + "fitting_net": { + "type": "sog_energy", + "var_name": "energy", + "dim_out_sr": 1, + "dim_out_lr": 1, + "neuron_sr": [ + 240, + 240, + 240 + ], + "neuron_lr": [ + 120, + 120, + 120 + ], + "resnet_dt": true, + "precision": "float32", + "activation_function": "silut:10.0", + "seed": 1, + "n_dl": 1, + "remove_self_interaction": true, + "_comment": "SOG-specific long-range kernel parameters", + "shift": [ + 2.8, + 5.7, + 11.4, + 22.7, + 45.5, + 91.0, + 182.0, + 364.0, + 728.0, + 1456.0, + 2912.0, + 5823.9 + ], + "amplitude": [ + 0.275, + 0.1375, + 0.0688, + 0.0344, + 0.0172, + 0.0086, + 0.0043, + 0.0021, + 0.0011, + 0.0005, + 0.0003, + 0.0001 + ] + } + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3e-5 + }, + "loss": { + "type": "ener", + "start_pref_e": 0.2, + "limit_pref_e": 20, + "start_pref_f": 100, + "limit_pref_f": 60, + "start_pref_v": 0.02, + "limit_pref_v": 1 + }, + "optimizer": { + "type": "AdamW", + "adam_beta1": 0.9, + "adam_beta2": 0.999, + "weight_decay": 0.001 + }, + "training": { + "stat_file": "./sog.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1 + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1 + }, + "numb_steps": 500000, + "gradient_max_norm": 5.0, + 
"seed": 10, + "disp_file": "lcurve_remove_self.out", + "disp_freq": 100, + "save_freq": 2000 + } +} diff --git a/examples/water/sog/profile_sog_timing.py b/examples/water/sog/profile_sog_timing.py new file mode 100644 index 0000000000..78876f40ce --- /dev/null +++ b/examples/water/sog/profile_sog_timing.py @@ -0,0 +1,696 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Profile SOG model runtime breakdown using input_torch.json as model config.""" + +from __future__ import ( + annotations, +) + +import argparse +import json +import time +import types +from collections import ( + defaultdict, +) +from contextlib import ( + nullcontext, +) +from pathlib import ( + Path, +) +from typing import ( + Any, +) + +import pytorch_finufft +import torch + +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.model.model.sog_model import ( + SOGEnergyModel_, +) +from deepmd.pt.model.model.transform_output import ( + communicate_extended_output, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + + +def _sync_if_cuda(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.synchronize(device) + + +def _time_block(name: str, timings: dict[str, float], device: torch.device): + class _Timer: + def __enter__(self_inner): + _sync_if_cuda(device) + self_inner.t0 = time.perf_counter() + return self_inner + + def __exit__(self_inner, exc_type, exc, tb): + _sync_if_cuda(device) + timings[name] += time.perf_counter() - self_inner.t0 + + return _Timer() + + +def _build_synthetic_input( + nframes: int, + nloc: int, + box_len: float, + device: torch.device, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + coord = torch.rand((nframes, nloc, 3), device=device, dtype=dtype) * box_len + atype = torch.zeros((nframes, nloc), device=device, dtype=torch.long) + + # Water-like type ratio: O:H = 1:2 + atype[:, 1::3] = 1 + atype[:, 2::3] = 1 + + box = torch.zeros((nframes, 3, 3), 
device=device, dtype=dtype) + box[:, 0, 0] = box_len + box[:, 1, 1] = box_len + box[:, 2, 2] = box_len + return coord, atype, box + + +def _load_model(config_path: Path, device: torch.device) -> Any: + cfg = json.loads(config_path.read_text()) + model = get_model(cfg["model"]) + model = model.to(device) + model.eval() + return model + + +def _install_fine_frame_corr_profiler( + model: Any, + detail_times: dict[str, float], + device: torch.device, + collect_flag: dict[str, bool], +) -> tuple[Any, Any]: + orig_bundle = model._compute_sog_frame_correction_bundle + orig_apply = model._apply_frame_correction_lower + + def _timed_bundle( + self, + coord: torch.Tensor, + latent_charge: torch.Tensor, + box: torch.Tensor, + *, + need_force: bool, + need_virial: bool, + ) -> dict[str, torch.Tensor]: + if coord.dim() != 3: + raise ValueError( + f"`coord` should be [nf, nloc, 3], got shape {tuple(coord.shape)}" + ) + if latent_charge.dim() != 3: + raise ValueError( + f"`latent_charge` should be [nf, nloc, nq], got shape {tuple(latent_charge.shape)}" + ) + if coord.shape[:2] != latent_charge.shape[:2]: + raise ValueError( + "`coord` and `latent_charge` local dimensions mismatch: " + f"{tuple(coord.shape[:2])} vs {tuple(latent_charge.shape[:2])}" + ) + + fitting = self.get_fitting_net() + runtime_device = coord.device + real_dtype = coord.dtype + complex_dtype = ( + torch.complex128 if real_dtype == torch.float64 else torch.complex64 + ) + with ( + _time_block("fc_cast_inputs", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + latent_charge = latent_charge.to(device=runtime_device, dtype=real_dtype) + box = box.to(device=runtime_device, dtype=real_dtype) + if box.dim() != 3 or box.shape[-2:] != (3, 3): + raise ValueError( + f"`box` should be [nf, 3, 3], got shape {tuple(box.shape)}" + ) + + with ( + _time_block("fc_param_prepare", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + wl, _sl, min_term = self._get_cached_sog_params( + 
fitting, + runtime_device, + real_dtype, + ) + remove_self_interaction = bool(fitting.remove_self_interaction) + n_dl = int(fitting.n_dl) + pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) + two_pi = torch.tensor( + 2.0 * torch.pi, dtype=real_dtype, device=runtime_device + ) + + nf, nloc, _ = coord.shape + corr = torch.zeros((nf, 1), dtype=real_dtype, device=runtime_device) + force_local = ( + torch.zeros((nf, nloc, 3), dtype=real_dtype, device=runtime_device) + if need_force + else None + ) + virial_local = ( + torch.zeros((nf, nloc, 1, 9), dtype=real_dtype, device=runtime_device) + if need_virial + else None + ) + + for ff in range(nf): + r_raw = coord[ff] + q = latent_charge[ff] + box_frame = box[ff] + + with ( + _time_block("fc_geom_and_points", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + volume = torch.det(box_frame) + if torch.abs(volume) <= torch.finfo(real_dtype).eps: + raise ValueError( + "`box` is singular (near-zero volume), cannot run NUFFT." 
+ ) + + cell_inv = torch.linalg.inv(box_frame) + r_frac = torch.matmul(r_raw, cell_inv) + r_frac = torch.remainder(r_frac + 0.5, 1.0) - 0.5 + point_limit = pi_tensor - 32.0 * torch.finfo(real_dtype).eps + r_in = torch.clamp( + 2.0 * pi_tensor * r_frac, + min=-point_limit, + max=point_limit, + ).contiguous() + nufft_points = r_in.transpose(0, 1).contiguous() + + with ( + _time_block("fc_build_k_grid", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + norms = torch.norm(box_frame, dim=1) + nk = tuple(max(1, int(n.item() / n_dl)) for n in norms) + n1 = torch.arange( + -nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype + ) + n2 = torch.arange( + -nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype + ) + n3 = torch.arange( + -nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype + ) + kx_grid, ky_grid, kz_grid = torch.meshgrid(n1, n2, n3, indexing="ij") + k_sq = kx_grid**2 + ky_grid**2 + kz_grid**2 + zero_mask = k_sq == 0 + + with ( + _time_block("fc_build_kfac", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + kfac = wl.view(1, 1, 1, -1) * torch.exp(k_sq.unsqueeze(-1) * min_term) + kfac = kfac.sum(dim=-1) + kfac = kfac.to(dtype=real_dtype) + kfac[zero_mask] = 0.0 + + with ( + _time_block("fc_prepare_charge", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + q_t = q.transpose(0, 1).contiguous() + charge = ( + torch.complex(q_t, torch.zeros_like(q_t)) + .to(dtype=complex_dtype) + .contiguous() + ) + + with ( + _time_block("fc_nufft_type1", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + recon = pytorch_finufft.functional.finufft_type1( + nufft_points, + charge, + output_shape=tuple(int(x) for x in kx_grid.shape), + eps=1e-4, + isign=-1, + ) + + with ( + _time_block("fc_energy_reduce", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + rho_sq = recon.real.square() + recon.imag.square() + corr[ff, 0] = (kfac.unsqueeze(0) * 
rho_sq).sum() / (2.0 * volume) + + if need_force: + with ( + _time_block("fc_prepare_force_conv", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + conv = kfac.unsqueeze(0).to(dtype=complex_dtype) * recon + + with ( + _time_block("fc_prepare_force_kgrid", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + kk1 = torch.fft.ifftshift(kx_grid, dim=0) + kk2 = torch.fft.ifftshift(ky_grid, dim=1) + kk3 = torch.fft.ifftshift(kz_grid, dim=2) + k_grid = torch.stack((kk1, kk2, kk3), dim=0) + g_cart = two_pi * torch.einsum("ik,k...->i...", cell_inv, k_grid) + grad_conv = ( + 1j * g_cart.unsqueeze(1).to(dtype=complex_dtype) + ) * conv.unsqueeze(0) + + with ( + _time_block("fc_nufft_type2_force", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + grad_field = pytorch_finufft.functional.finufft_type2( + nufft_points, + grad_conv, + eps=1e-4, + isign=1, + ) + + with ( + _time_block("fc_force_reduce", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + force_frame = ( + -(q_t.unsqueeze(0) * grad_field.real.to(dtype=real_dtype)) + .sum(dim=1) + .transpose(0, 1) + ) + force_frame = force_frame / volume + force_local[ff] = force_frame + + if need_virial: + with ( + _time_block("fc_virial_local", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + virial_local[ff] = torch.einsum( + "ai,aj->aij", + force_frame, + r_raw, + ).reshape(nloc, 1, 9) + + if remove_self_interaction: + with ( + _time_block("fc_self_interaction", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + diag_sum = kfac.sum(dim=-1).sum(dim=-1).sum(dim=-1) / (2.0 * volume) + corr[ff, 0] -= torch.sum(q**2) * diag_sum + + out: dict[str, torch.Tensor] = {"corr_redu": corr} + if force_local is not None: + out["force_local"] = force_local + if virial_local is not None: + out["virial_local"] = virial_local + return out + + def _timed_apply( + self, + model_ret: dict[str, torch.Tensor], + 
extended_coord: torch.Tensor, + nlist: torch.Tensor, + box: torch.Tensor | None, + do_atomic_virial: bool, + ) -> dict[str, torch.Tensor]: + with ( + _time_block("fc_guard_and_slice", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + if box is None or "latent_charge" not in model_ret: + return model_ret + + nf, nloc, _ = nlist.shape + nall = extended_coord.shape[1] + coord_local = extended_coord[:, :nloc, :] + box_local = box.view(nf, 3, 3) + latent_charge = model_ret["latent_charge"] + need_force = self.do_grad_r("energy") or self.do_grad_c("energy") + need_virial = self.do_grad_c("energy") + latent_charge_runtime = ( + latent_charge if self.training else latent_charge.detach() + ) + + with ( + _time_block("fc_compute_corr_bundle", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + corr_bundle = self._compute_sog_frame_correction_bundle( + coord_local, + latent_charge_runtime, + box_local, + need_force=need_force, + need_virial=need_virial, + ) + corr_redu = corr_bundle["corr_redu"] + + with ( + _time_block("fc_add_energy", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to( + model_ret["energy_redu"].dtype + ) + + if need_force: + corr_force_local = corr_bundle["force_local"].to(coord_local.dtype) + + with ( + _time_block("fc_scatter_force", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + corr_force_ext = torch.zeros( + (nf, nall, 3), + dtype=corr_force_local.dtype, + device=corr_force_local.device, + ) + corr_force_ext[:, :nloc, :] = corr_force_local + if "energy_derv_r" in model_ret: + model_ret["energy_derv_r"] = model_ret[ + "energy_derv_r" + ] + corr_force_ext.unsqueeze(-2).to( + model_ret["energy_derv_r"].dtype + ) + + if need_virial: + corr_virial_local = corr_bundle["virial_local"].to( + corr_force_local.dtype + ) + with ( + _time_block("fc_virial_update", detail_times, device) + if 
collect_flag["on"] + else nullcontext() + ): + corr_virial_redu = corr_virial_local.sum(dim=1) + if "energy_derv_c_redu" in model_ret: + model_ret["energy_derv_c_redu"] = model_ret[ + "energy_derv_c_redu" + ] + corr_virial_redu.to(model_ret["energy_derv_c_redu"].dtype) + if do_atomic_virial and "energy_derv_c" in model_ret: + corr_atom_virial = torch.zeros( + (nf, nall, 1, 9), + dtype=corr_virial_local.dtype, + device=corr_virial_local.device, + ) + corr_atom_virial[:, :nloc, :, :] = corr_virial_local + model_ret["energy_derv_c"] = model_ret[ + "energy_derv_c" + ] + corr_atom_virial.to(model_ret["energy_derv_c"].dtype) + + return model_ret + + model._compute_sog_frame_correction_bundle = types.MethodType(_timed_bundle, model) + model._apply_frame_correction_lower = types.MethodType(_timed_apply, model) + return orig_bundle, orig_apply + + +def profile( + model: Any, + nframes: int, + nloc: int, + box_len: float, + repeats: int, + warmup: int, + do_atomic_virial: bool, + dtype: torch.dtype, + fine_frame_profile: bool = False, +) -> dict[str, float]: + device = next(model.parameters()).device + + # NUFFT fine-grained timers by monkeypatching function calls. 
+ nufft_times: dict[str, float] = defaultdict(float) + collect_nufft = False + orig_type1 = pytorch_finufft.functional.finufft_type1 + orig_type2 = pytorch_finufft.functional.finufft_type2 + + def timed_type1(*args, **kwargs): + if collect_nufft: + with _time_block("nufft_type1", nufft_times, device): + return orig_type1(*args, **kwargs) + return orig_type1(*args, **kwargs) + + def timed_type2(*args, **kwargs): + if collect_nufft: + with _time_block("nufft_type2", nufft_times, device): + return orig_type2(*args, **kwargs) + return orig_type2(*args, **kwargs) + + pytorch_finufft.functional.finufft_type1 = timed_type1 + pytorch_finufft.functional.finufft_type2 = timed_type2 + + timings: dict[str, float] = defaultdict(float) + detail_times: dict[str, float] = defaultdict(float) + collect_detail = {"on": False} + orig_bundle = None + orig_apply = None + + if fine_frame_profile: + orig_bundle, orig_apply = _install_fine_frame_corr_profiler( + model, + detail_times, + device, + collect_detail, + ) + + try: + for _ in range(warmup + repeats): + coord, atype, box = _build_synthetic_input( + nframes=nframes, + nloc=nloc, + box_len=box_len, + device=device, + dtype=dtype, + ) + + is_warmup = _ < warmup + collect_nufft = not is_warmup + collect_detail["on"] = not is_warmup + iter_times: dict[str, float] = defaultdict(float) + + with _time_block("input_cast", iter_times, device): + cc, bb, fp, ap, input_prec = model._input_type_cast( + coord, box=box, fparam=None, aparam=None + ) + + with _time_block("build_nlist", iter_times, device): + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + cc, + atype, + model.get_rcut(), + model.get_sel(), + mixed_types=True, + box=bb, + ) + + comm_dict: dict[str, torch.Tensor] | None = {"box": bb} + + if model.do_grad_r("energy") or model.do_grad_c("energy"): + extended_coord = extended_coord.requires_grad_(True) + + with _time_block("lower_super", iter_times, device): + model_ret = 
SOGEnergyModel_.forward_common_lower( + model, + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fp, + aparam=ap, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + extra_nlist_sort=False, + extended_coord_corr=None, + ) + + with _time_block("lower_frame_corr", iter_times, device): + model_ret = model._apply_frame_correction_lower( + model_ret, + extended_coord, + nlist, + bb, + do_atomic_virial, + ) + + with _time_block("communicate_output", iter_times, device): + model_ret = communicate_extended_output( + model_ret, + model.model_output_def(), + mapping, + do_atomic_virial=do_atomic_virial, + ) + + with _time_block("output_cast", iter_times, device): + model_ret = model._output_type_cast(model_ret, input_prec) + _ = model_ret["energy_redu"] + + if not is_warmup: + iter_times["total"] = sum( + iter_times[k] + for k in [ + "input_cast", + "build_nlist", + "lower_super", + "lower_frame_corr", + "communicate_output", + "output_cast", + ] + ) + for k, v in iter_times.items(): + timings[k] += v + + # Only keep averaged timings for measured iterations. 
+ for k in list(timings.keys()): + timings[k] /= repeats + for k, v in nufft_times.items(): + timings[k] = v / repeats + for k, v in detail_times.items(): + timings[k] = v / repeats + + return dict(timings) + finally: + pytorch_finufft.functional.finufft_type1 = orig_type1 + pytorch_finufft.functional.finufft_type2 = orig_type2 + if fine_frame_profile and orig_bundle is not None and orig_apply is not None: + model._compute_sog_frame_correction_bundle = orig_bundle + model._apply_frame_correction_lower = orig_apply + + +def _format_report(timings: dict[str, float]) -> str: + total = sum( + timings.get(k, 0.0) + for k in [ + "input_cast", + "build_nlist", + "lower_super", + "lower_frame_corr", + "communicate_output", + "output_cast", + ] + ) + + keys = [ + "input_cast", + "build_nlist", + "lower_super", + "lower_frame_corr", + "nufft_type1", + "nufft_type2", + "communicate_output", + "output_cast", + "fc_guard_and_slice", + "fc_compute_corr_bundle", + "fc_cast_inputs", + "fc_param_prepare", + "fc_geom_and_points", + "fc_build_k_grid", + "fc_build_kfac", + "fc_prepare_charge", + "fc_nufft_type1", + "fc_energy_reduce", + "fc_prepare_force_conv", + "fc_prepare_force_kgrid", + "fc_nufft_type2_force", + "fc_force_reduce", + "fc_virial_local", + "fc_self_interaction", + "fc_add_energy", + "fc_scatter_force", + "fc_virial_update", + ] + + lines = [] + lines.append("Timing breakdown (avg per iteration):") + for k in keys: + if k in timings: + ms = timings[k] * 1000.0 + ratio = (timings[k] / total * 100.0) if total > 0 else 0.0 + lines.append(f" - {k:20s}: {ms:10.3f} ms ({ratio:6.2f}%)") + + lines.append(f" - {'total(sum)':20s}: {total * 1000.0:10.3f} ms (100.00%)") + return "\n".join(lines) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + type=Path, + default=Path("examples/water/sog/input_torch.json"), + ) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument( + "--dtype", type=str, 
default="float32", choices=["float32", "float64"] + ) + parser.add_argument("--nframes", type=int, default=1) + parser.add_argument("--nloc", type=int, default=192) + parser.add_argument("--box-len", type=float, default=20.0) + parser.add_argument("--warmup", type=int, default=2) + parser.add_argument("--repeats", type=int, default=10) + parser.add_argument("--atomic-virial", action="store_true") + parser.add_argument("--n-dl-override", type=int, default=2) + parser.add_argument("--disable-energy-grad", action="store_true") + parser.add_argument("--fine-frame-profile", action="store_true") + args = parser.parse_args() + + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + dtype = torch.float64 if args.dtype == "float64" else torch.float32 + + model = _load_model(args.config, device) + if args.n_dl_override > 0 and hasattr(model.get_fitting_net(), "n_dl"): + model.get_fitting_net().n_dl = int(args.n_dl_override) + if args.disable_energy_grad: + model.do_grad_r = types.MethodType(lambda self, _name: False, model) + model.do_grad_c = types.MethodType(lambda self, _name: False, model) + timings = profile( + model=model, + nframes=args.nframes, + nloc=args.nloc, + box_len=args.box_len, + repeats=args.repeats, + warmup=args.warmup, + do_atomic_virial=args.atomic_virial, + dtype=dtype, + fine_frame_profile=args.fine_frame_profile, + ) + print(_format_report(timings)) + + +if __name__ == "__main__": + main() diff --git a/examples/water/sog/profile_sog_whatif.py b/examples/water/sog/profile_sog_whatif.py new file mode 100644 index 0000000000..d56e30091a --- /dev/null +++ b/examples/water/sog/profile_sog_whatif.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import ( + annotations, +) + +import copy +import json +from pathlib import ( + Path, +) + +import torch +from profile_sog_timing import ( + profile, +) + +from deepmd.pt.model.model import ( + get_model, +) + +CFG_PATH = 
Path("examples/water/sog/input_torch.json") + + +def run(tag: str, model_cfg: dict) -> None: + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = get_model(model_cfg).to(device) + model.eval() + t = profile( + model, + nframes=1, + nloc=192, + box_len=20.0, + repeats=6, + warmup=2, + do_atomic_virial=False, + dtype=torch.float32, + ) + total = ( + t["input_cast"] + + t["build_nlist"] + + t["lower_super"] + + t["lower_frame_corr"] + + t["communicate_output"] + + t["output_cast"] + ) * 1000.0 + print( + f"{tag}: total={total:.3f}ms, " + f"lower_super={t['lower_super'] * 1000.0:.3f}ms, " + f"lower_frame_corr={t['lower_frame_corr'] * 1000.0:.3f}ms, " + f"nufft1={t.get('nufft_type1', 0.0) * 1000.0:.3f}ms, " + f"nufft2={t.get('nufft_type2', 0.0) * 1000.0:.3f}ms" + ) + + +def main() -> None: + cfg = json.loads(CFG_PATH.read_text()) + base_model_cfg = cfg["model"] + + run("baseline", copy.deepcopy(base_model_cfg)) + + cfg_lr1 = copy.deepcopy(base_model_cfg) + cfg_lr1["fitting_net"]["dim_out_lr"] = 1 + run("dim_out_lr=1", cfg_lr1) + + cfg_small = copy.deepcopy(base_model_cfg) + cfg_small["fitting_net"]["neuron_sr"] = [128, 128, 128] + cfg_small["fitting_net"]["neuron_lr"] = [128, 128, 128] + run("neurons=128", cfg_small) + + cfg_both = copy.deepcopy(cfg_lr1) + cfg_both["fitting_net"]["neuron_sr"] = [128, 128, 128] + cfg_both["fitting_net"]["neuron_lr"] = [128, 128, 128] + run("dim_out_lr=1 + neurons=128", cfg_both) + + +if __name__ == "__main__": + main() diff --git a/examples/water/sog/sog.hdf5 b/examples/water/sog/sog.hdf5 new file mode 100644 index 0000000000..90ae7bb1d6 Binary files /dev/null and b/examples/water/sog/sog.hdf5 differ diff --git a/pyproject.toml b/pyproject.toml index 2f8f86a713..69506beb4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ build-backend = "backend.dp_backend" backend-path = ["."] [project] -name = "deepmd-kit" +name = "deepmd-kit-dev" dynamic = ["version", "optional-dependencies", "scripts", 
"readme"] description = "A deep learning package for many-body potential energy representation and molecular dynamics" authors = [ @@ -142,7 +142,7 @@ jax = [ ] [tool.deepmd_build_backend.scripts] -dp = "deepmd.main:main" +dp_dev = "deepmd.main:main" [dependency-groups] dev = [ @@ -171,6 +171,8 @@ pin_pytorch_cpu = [ ] pin_pytorch_gpu = [ "torch==2.10.0", + "pytorch-finufft>=0.1.0", + "cufinufft>=2.5.0; platform_system=='Linux' and platform_machine=='x86_64'", ] pin_jax = [ "jax==0.5.0;python_version>='3.10'", diff --git a/source/tests/pt/model/test_les_working_layer.py b/source/tests/pt/model/test_les_working_layer.py new file mode 100644 index 0000000000..b0ce487a1f --- /dev/null +++ b/source/tests/pt/model/test_les_working_layer.py @@ -0,0 +1,212 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import torch + +try: + import pytorch_finufft # noqa: F401 + + HAS_FINUFFT = True +except Exception: + HAS_FINUFFT = False + +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.model.les_model import ( + LESEnergyModel, +) +from deepmd.pt.model.task.les_energy_fitting import ( + LESEnergyFittingNet, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +def _reduce_extended_tensor( + extended_tensor: torch.Tensor, mapping: torch.Tensor, nloc: int +): + nframes = extended_tensor.shape[0] + ext_dims = extended_tensor.shape[2:] + reduced_tensor = torch.zeros( + [nframes, nloc, *ext_dims], + dtype=extended_tensor.dtype, + device=extended_tensor.device, + ) + mldims = list(mapping.shape) + mapping_exp = mapping.view(mldims + [1] * len(ext_dims)).expand( + [-1] * len(mldims) + list(ext_dims) + ) + reduced_tensor = torch.scatter_reduce( + reduced_tensor, + 1, + index=mapping_exp, + src=extended_tensor, + reduce="sum", + ) + return reduced_tensor + + +@unittest.skipIf(not HAS_FINUFFT, "pytorch_finufft is required 
for LES tests") +class TestLESWorkingLayer(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(2027) + self.nf = 2 + self.nloc = 4 + self.nt = 2 + self.rcut = 4.0 + self.rcut_smth = 3.5 + self.sel = [8, 8] + + self.descriptor = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + self.fitting = LESEnergyFittingNet( + var_name="energy", + ntypes=self.nt, + dim_descrpt=self.descriptor.get_dim_out(), + dim_out_sr=1, + dim_out_lr=1, + mixed_types=self.descriptor.mixed_types(), + n_dl=2, + ).to(env.DEVICE) + self.model = LESEnergyModel( + descriptor=self.descriptor, + fitting=self.fitting, + type_map=["A", "B"], + ).to(env.DEVICE) + + coord = torch.rand( + (self.nf, self.nloc, 3), + dtype=dtype, + device=env.DEVICE, + ) + cell = ( + torch.eye(3, dtype=dtype, device=env.DEVICE) + .unsqueeze(0) + .repeat(self.nf, 1, 1) + ) + cell = cell * 5.0 + self.coord = coord.reshape(self.nf, self.nloc * 3) + self.cell = cell.reshape(self.nf, 9) + self.atype = torch.tensor( + [[0, 0, 1, 1], [1, 0, 1, 0]], + dtype=torch.int64, + device=env.DEVICE, + ) + + def test_frame_correction_applies_once_per_frame(self) -> None: + coord3 = self.coord.view(self.nf, self.nloc, 3) + cell33 = self.cell.view(self.nf, 3, 3) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord3, + self.atype, + self.model.get_rcut(), + self.model.get_sel(), + mixed_types=True, + box=cell33, + ) + + lower_ret = self.model.forward_common_lower( + extended_coord=extended_coord, + extended_atype=extended_atype, + nlist=nlist, + mapping=mapping, + do_atomic_virial=False, + comm_dict={"box": cell33}, + ) + + frame_corr = self.model._compute_les_frame_correction( + extended_coord[:, : self.nloc, :], + lower_ret["latent_charge"], + cell33, + ).to(lower_ret["energy_redu"].dtype) + expected_energy_redu = lower_ret["energy"].sum(dim=1) + frame_corr + + torch.testing.assert_close( + lower_ret["energy_redu"], + expected_energy_redu, + 
rtol=1e-8, + atol=1e-8, + ) + + def test_forward_and_forward_lower_consistency(self) -> None: + fw = self.model.forward( + self.coord, + self.atype, + box=self.cell, + do_atomic_virial=True, + ) + + coord3 = self.coord.view(self.nf, self.nloc, 3) + cell33 = self.cell.view(self.nf, 3, 3) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord3, + self.atype, + self.model.get_rcut(), + self.model.get_sel(), + mixed_types=True, + box=cell33, + ) + + fw_lower = self.model.forward_lower( + extended_coord=extended_coord, + extended_atype=extended_atype, + nlist=nlist, + mapping=mapping, + do_atomic_virial=True, + comm_dict={"box": cell33}, + ) + + torch.testing.assert_close( + fw_lower["energy"], fw["energy"], rtol=1e-8, atol=1e-8 + ) + torch.testing.assert_close( + fw_lower["virial"], fw["virial"], rtol=1e-7, atol=1e-7 + ) + + reduced_force = _reduce_extended_tensor( + fw_lower["extended_force"], mapping, self.nloc + ) + torch.testing.assert_close(reduced_force, fw["force"], rtol=1e-7, atol=1e-7) + + def test_long_range_params_have_gradient_path(self) -> None: + self.model.zero_grad(set_to_none=True) + out = self.model.forward( + self.coord, + self.atype, + box=self.cell, + do_atomic_virial=False, + ) + grads = torch.autograd.grad( + out["energy"].sum(), + [self.fitting.wl, self.fitting.sl], + retain_graph=False, + create_graph=False, + allow_unused=False, + ) + for gg in grads: + self.assertIsNotNone(gg) + self.assertTrue(torch.isfinite(gg).all().item()) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_sog_working_layer.py b/source/tests/pt/model/test_sog_working_layer.py new file mode 100644 index 0000000000..7e4ed7d89b --- /dev/null +++ b/source/tests/pt/model/test_sog_working_layer.py @@ -0,0 +1,193 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import torch + +try: + import pytorch_finufft # noqa: F401 + + HAS_FINUFFT = True +except Exception: + 
HAS_FINUFFT = False + +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.model.sog_model import ( + SOGEnergyModel, +) +from deepmd.pt.model.task.sog_energy_fitting import ( + SOGEnergyFittingNet, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +def _reduce_extended_tensor( + extended_tensor: torch.Tensor, mapping: torch.Tensor, nloc: int +): + nframes = extended_tensor.shape[0] + ext_dims = extended_tensor.shape[2:] + reduced_tensor = torch.zeros( + [nframes, nloc, *ext_dims], + dtype=extended_tensor.dtype, + device=extended_tensor.device, + ) + mldims = list(mapping.shape) + mapping_exp = mapping.view(mldims + [1] * len(ext_dims)).expand( + [-1] * len(mldims) + list(ext_dims) + ) + reduced_tensor = torch.scatter_reduce( + reduced_tensor, + 1, + index=mapping_exp, + src=extended_tensor, + reduce="sum", + ) + return reduced_tensor + + +@unittest.skipIf(not HAS_FINUFFT, "pytorch_finufft is required for SOG tests") +class TestSOGWorkingLayer(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(2026) + self.nf = 2 + self.nloc = 4 + self.nt = 2 + self.rcut = 4.0 + self.rcut_smth = 3.5 + self.sel = [8, 8] + + self.descriptor = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + self.fitting = SOGEnergyFittingNet( + var_name="energy", + ntypes=self.nt, + dim_descrpt=self.descriptor.get_dim_out(), + dim_out_sr=1, + dim_out_lr=1, + mixed_types=self.descriptor.mixed_types(), + n_dl=2, + ).to(env.DEVICE) + self.model = SOGEnergyModel( + descriptor=self.descriptor, + fitting=self.fitting, + type_map=["A", "B"], + ).to(env.DEVICE) + + coord = torch.rand( + (self.nf, self.nloc, 3), + dtype=dtype, + device=env.DEVICE, + ) + cell = ( + torch.eye(3, dtype=dtype, device=env.DEVICE) + .unsqueeze(0) + .repeat(self.nf, 1, 1) + ) + cell = cell * 5.0 + self.coord = coord.reshape(self.nf, 
self.nloc * 3) + self.cell = cell.reshape(self.nf, 9) + self.atype = torch.tensor( + [[0, 0, 1, 1], [1, 0, 1, 0]], + dtype=torch.int64, + device=env.DEVICE, + ) + + def test_frame_correction_applies_once_per_frame(self) -> None: + coord3 = self.coord.view(self.nf, self.nloc, 3) + cell33 = self.cell.view(self.nf, 3, 3) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord3, + self.atype, + self.model.get_rcut(), + self.model.get_sel(), + mixed_types=True, + box=cell33, + ) + + lower_ret = self.model.forward_common_lower( + extended_coord=extended_coord, + extended_atype=extended_atype, + nlist=nlist, + mapping=mapping, + do_atomic_virial=False, + comm_dict={"box": cell33}, + ) + + frame_corr = self.model._compute_sog_frame_correction( + extended_coord[:, : self.nloc, :], + lower_ret["latent_charge"], + cell33, + ).to(lower_ret["energy_redu"].dtype) + expected_energy_redu = lower_ret["energy"].sum(dim=1) + frame_corr + + torch.testing.assert_close( + lower_ret["energy_redu"], + expected_energy_redu, + rtol=1e-8, + atol=1e-8, + ) + + def test_forward_and_forward_lower_consistency(self) -> None: + fw = self.model.forward( + self.coord, + self.atype, + box=self.cell, + do_atomic_virial=True, + ) + + coord3 = self.coord.view(self.nf, self.nloc, 3) + cell33 = self.cell.view(self.nf, 3, 3) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord3, + self.atype, + self.model.get_rcut(), + self.model.get_sel(), + mixed_types=True, + box=cell33, + ) + + fw_lower = self.model.forward_lower( + extended_coord=extended_coord, + extended_atype=extended_atype, + nlist=nlist, + mapping=mapping, + do_atomic_virial=True, + comm_dict={"box": cell33}, + ) + + torch.testing.assert_close( + fw_lower["energy"], fw["energy"], rtol=1e-8, atol=1e-8 + ) + torch.testing.assert_close( + fw_lower["virial"], fw["virial"], rtol=1e-7, atol=1e-7 + ) + + reduced_force = 
_reduce_extended_tensor( + fw_lower["extended_force"], mapping, self.nloc + ) + torch.testing.assert_close(reduced_force, fw["force"], rtol=1e-7, atol=1e-7) + + +if __name__ == "__main__": + unittest.main()