From 67173d5c0902ceaebc52110ab5f37dfd230178d4 Mon Sep 17 00:00:00 2001 From: zhenyuan5090 Date: Sat, 14 Mar 2026 21:37:44 +0800 Subject: [PATCH 1/8] 1st commit(test passed, but without lr corr) --- deepmd/pt/model/atomic_model/__init__.py | 8 + .../atomic_model/lr_energy_atomic_model.py | 348 ++++++++++ .../pt/model/atomic_model/sog_atomic_model.py | 267 ++++++++ deepmd/pt/model/model/__init__.py | 6 + deepmd/pt/model/model/sog_model.py | 174 +++++ deepmd/pt/model/task/__init__.py | 8 + deepmd/pt/model/task/lr_fitting.py | 629 ++++++++++++++++++ deepmd/pt/model/task/sog_energy_fitting.py | 282 ++++++++ deepmd/utils/argcheck.py | 170 +++++ examples/water/dpa3/dpa3.hdf5 | Bin 0 -> 4930 bytes examples/water/dpa3/input_torch_copy.json | 103 +++ examples/water/sog/README.md | 16 + examples/water/sog/input_torch.json | 134 ++++ examples/water/sog/sog.hdf5 | Bin 0 -> 4930 bytes pyproject.toml | 4 +- 15 files changed, 2147 insertions(+), 2 deletions(-) create mode 100644 deepmd/pt/model/atomic_model/lr_energy_atomic_model.py create mode 100644 deepmd/pt/model/atomic_model/sog_atomic_model.py create mode 100644 deepmd/pt/model/model/sog_model.py create mode 100644 deepmd/pt/model/task/lr_fitting.py create mode 100644 deepmd/pt/model/task/sog_energy_fitting.py create mode 100644 examples/water/dpa3/dpa3.hdf5 create mode 100644 examples/water/dpa3/input_torch_copy.json create mode 100644 examples/water/sog/README.md create mode 100644 examples/water/sog/input_torch.json create mode 100644 examples/water/sog/sog.hdf5 diff --git a/deepmd/pt/model/atomic_model/__init__.py b/deepmd/pt/model/atomic_model/__init__.py index 4da9bf781b..c008d75575 100644 --- a/deepmd/pt/model/atomic_model/__init__.py +++ b/deepmd/pt/model/atomic_model/__init__.py @@ -42,6 +42,12 @@ from .property_atomic_model import ( DPPropertyAtomicModel, ) +from .lr_energy_atomic_model import ( + LREnergyAtomicModel, +) +from .sog_atomic_model import ( + SOGEnergyAtomicModel, +) __all__ = [ "BaseAtomicModel", @@ 
-53,5 +59,7 @@ "DPPropertyAtomicModel", "DPZBLLinearEnergyAtomicModel", "LinearEnergyAtomicModel", + "LREnergyAtomicModel", "PairTabAtomicModel", + "SOGEnergyAtomicModel", ] diff --git a/deepmd/pt/model/atomic_model/lr_energy_atomic_model.py b/deepmd/pt/model/atomic_model/lr_energy_atomic_model.py new file mode 100644 index 0000000000..43cdfd4c30 --- /dev/null +++ b/deepmd/pt/model/atomic_model/lr_energy_atomic_model.py @@ -0,0 +1,348 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import Any, Iterable, Optional + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt.model.task.ener import ( + EnergyFittingNet, + EnergyFittingNetDirect, + InvarFitting, +) +from deepmd.pt.model.task.property import ( + PropertyFittingNet, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_atomic_model import ( + BaseAtomicModel, +) + + +@BaseAtomicModel.register("energy_lr") +class LREnergyAtomicModel(BaseAtomicModel): + """Energy model with an auxiliary property-driven correction. + + This model shares one descriptor with two fitting nets: + - ``energy_fitting`` predicts the primary atomic energy/force term. + - ``property_fitting`` predicts an auxiliary per-atom property ``q``. + The property is then passed through a small trainable correction head + to generate an additive energy (and resulting force) correction. 
+ """ + + def __init__( + self, + descriptor: BaseDescriptor, + energy_fitting: InvarFitting, + property_fitting: PropertyFittingNet, + type_map: list[str], + correction_hidden: Optional[Iterable[int]] = None, + correction_activation: str = "tanh", + **kwargs: Any, + ) -> None: + super().__init__(type_map, **kwargs) + if not ( + isinstance(energy_fitting, EnergyFittingNet) + or isinstance(energy_fitting, EnergyFittingNetDirect) + or isinstance(energy_fitting, InvarFitting) + ): + raise TypeError( + "energy_fitting must be an energy-like InvarFitting for LREnergyAtomicModel" + ) + if not isinstance(property_fitting, PropertyFittingNet): + raise TypeError( + "property_fitting must be an instance of PropertyFittingNet for LREnergyAtomicModel" + ) + + if energy_fitting.get_dim_fparam() != property_fitting.get_dim_fparam(): + raise ValueError("energy_fitting and property_fitting must share the same dim_fparam") + if energy_fitting.get_dim_aparam() != property_fitting.get_dim_aparam(): + raise ValueError("energy_fitting and property_fitting must share the same dim_aparam") + + self.descriptor = descriptor + self.energy_fitting = energy_fitting + self.property_fitting = property_fitting + self.property_name = property_fitting.var_name + self.type_map = type_map + self.ntypes = len(type_map) + self.rcut = self.descriptor.get_rcut() + self.sel = self.descriptor.get_sel() + self.correction_activation = correction_activation + hidden = list(correction_hidden) if correction_hidden is not None else [property_fitting.dim_out] + self.correction_hidden = hidden + self.correction_head = self._build_correction_head( + property_fitting.dim_out, hidden, correction_activation + ) + super().init_out_stat() + + self.enable_eval_descriptor_hook = False + self.enable_eval_fitting_last_layer_hook = False + self.eval_descriptor_list: list[torch.Tensor] = [] + self.eval_fitting_last_layer_list: list[torch.Tensor] = [] + + @staticmethod + def _build_correction_head( + input_dim: int, hidden: 
list[int], activation: str + ) -> torch.nn.Module: + layers: list[torch.nn.Module] = [] + last = input_dim + act_factory = getattr(torch.nn, activation.capitalize(), torch.nn.Tanh) + for width in hidden: + layers.append(torch.nn.Linear(last, width)) + layers.append(act_factory()) + last = width + layers.append(torch.nn.Linear(last, 1)) + return torch.nn.Sequential(*layers) + + def fitting_output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + name=self.property_name, + shape=[self.property_fitting.dim_out], + reducible=True, + r_differentiable=False, + c_differentiable=False, + intensive=self.property_fitting.get_intensive(), + ), + ] + ) + + def get_rcut(self) -> float: + return self.rcut + + def get_sel(self) -> list[int]: + return self.sel + + def mixed_types(self) -> bool: + return self.descriptor.mixed_types() + + def has_message_passing(self) -> bool: + return self.descriptor.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + return self.descriptor.need_sorted_nlist_for_lower() + + def set_case_embd(self, case_idx: int) -> None: + self.energy_fitting.set_case_embd(case_idx) + self.property_fitting.set_case_embd(case_idx) + + def get_dim_fparam(self) -> int: + return self.energy_fitting.get_dim_fparam() + + def has_default_fparam(self) -> bool: + return self.energy_fitting.has_default_fparam() + + def get_default_fparam(self) -> Optional[torch.Tensor]: + return self.energy_fitting.get_default_fparam() + + def get_dim_aparam(self) -> int: + return self.energy_fitting.get_dim_aparam() + + def get_sel_type(self) -> list[int]: + return self.energy_fitting.get_sel_type() + + def is_aparam_nall(self) -> bool: + return False + + def set_eval_descriptor_hook(self, enable: bool) -> None: + self.enable_eval_descriptor_hook = enable + self.eval_descriptor_list.clear() + + def 
eval_descriptor(self) -> torch.Tensor: + return torch.concat(self.eval_descriptor_list) + + def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: + self.enable_eval_fitting_last_layer_hook = enable + self.energy_fitting.set_return_middle_output(enable) + self.eval_fitting_last_layer_list.clear() + + def eval_fitting_last_layer(self) -> torch.Tensor: + return torch.concat(self.eval_fitting_last_layer_list) + + def forward_atomic( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + ) -> dict[str, torch.Tensor]: + nframes, nloc, _ = nlist.shape + atype = extended_atype[:, :nloc] + if self.do_grad_r() or self.do_grad_c(): + extended_coord.requires_grad_(True) + descriptor, rot_mat, g2, h2, sw = self.descriptor( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + comm_dict=comm_dict, + ) + assert descriptor is not None + if self.enable_eval_descriptor_hook: + self.eval_descriptor_list.append(descriptor.detach()) + + energy_ret = self.energy_fitting( + descriptor, + atype, + gr=rot_mat, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + prop_ret = self.property_fitting( + descriptor, + atype, + gr=rot_mat, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + + if self.enable_eval_fitting_last_layer_hook and "middle_output" in energy_ret: + self.eval_fitting_last_layer_list.append(energy_ret["middle_output"].detach()) + + q_val = prop_ret[self.property_name] + corr_energy = self.correction_head(q_val) + total_energy = energy_ret["energy"] + corr_energy + + return { + "energy": total_energy, + self.property_name: q_val, + } + + def apply_out_stat( + self, + ret: dict[str, torch.Tensor], + atype: torch.Tensor, + ) -> dict[str, torch.Tensor]: + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + for kk in 
self.bias_keys: + if kk == self.property_name: + ret[kk] = ret[kk] * out_std[kk][atype] + out_bias[kk][atype] + else: + ret[kk] = ret[kk] + out_bias[kk][atype] + return ret + + def compute_or_load_stat( + self, + sampled_func: Any, + stat_file_path: Optional[Any] = None, + compute_or_load_out_stat: bool = True, + ) -> None: + if stat_file_path is not None and self.type_map is not None: + stat_file_path /= " ".join(self.type_map) + + def wrapped_sampler() -> list[dict]: + sampled = sampled_func() + if self.pair_excl is not None: + pair_exclude_types = self.pair_excl.get_exclude_types() + for sample in sampled: + sample["pair_exclude_types"] = list(pair_exclude_types) + if self.atom_excl is not None: + atom_exclude_types = self.atom_excl.get_exclude_types() + for sample in sampled: + sample["atom_exclude_types"] = list(atom_exclude_types) + if ( + "find_fparam" not in sampled[0] + and "fparam" not in sampled[0] + and self.has_default_fparam() + ): + default_fparam = self.get_default_fparam() + for sample in sampled: + nframe = sample["atype"].shape[0] + sample["fparam"] = default_fparam.repeat(nframe, 1) + return sampled + + self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path) + self.compute_fitting_input_stat(wrapped_sampler, stat_file_path) + if compute_or_load_out_stat: + self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) + + def compute_fitting_input_stat( + self, + sample_merged: Any, + stat_file_path: Optional[Any] = None, + ) -> None: + self.energy_fitting.compute_input_stats( + sample_merged, + protection=self.data_stat_protect, + stat_file_path=stat_file_path, + ) + self.property_fitting.compute_input_stats( + sample_merged, + protection=self.data_stat_protect, + stat_file_path=stat_file_path, + ) + + def serialize(self) -> dict: + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "@version": 1, + "type": "energy_q_aug", + "type_map": self.type_map, + "descriptor": self.descriptor.serialize(), + 
"energy_fitting": self.energy_fitting.serialize(), + "property_fitting": self.property_fitting.serialize(), + "correction_hidden": self.correction_hidden, + "correction_activation": self.correction_activation, + "@variables": { + **dd.get("@variables", {}), + "correction_head": { + k: to_numpy_array(v) for k, v in self.correction_head.state_dict().items() + }, + }, + } + ) + return dd + + @classmethod + def deserialize(cls, data: dict) -> "LREnergyAtomicModel": + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + data.pop("type", None) + variables = data.pop("@variables", {}) + descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) + energy_fitting_obj = InvarFitting.deserialize(data.pop("energy_fitting")) + property_fitting_obj = PropertyFittingNet.deserialize(data.pop("property_fitting")) + correction_hidden = data.pop("correction_hidden", None) + correction_activation = data.pop("correction_activation", "tanh") + obj = cls( + descriptor_obj, + energy_fitting_obj, + property_fitting_obj, + correction_hidden=correction_hidden, + correction_activation=correction_activation, + **data, + ) + correction_state = variables.get("correction_head", None) + if correction_state is not None: + obj.correction_head.load_state_dict({k: to_torch_tensor(v) for k, v in correction_state.items()}) + return obj diff --git a/deepmd/pt/model/atomic_model/sog_atomic_model.py b/deepmd/pt/model/atomic_model/sog_atomic_model.py new file mode 100644 index 0000000000..d83acd6f75 --- /dev/null +++ b/deepmd/pt/model/atomic_model/sog_atomic_model.py @@ -0,0 +1,267 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import Any, Optional + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt.model.task.sog_energy_fitting import ( + SOGEnergyFittingNet, +) +from deepmd.utils.version import 
( + check_version_compatibility, +) + +from .base_atomic_model import ( + BaseAtomicModel, +) + + +@BaseAtomicModel.register("energy_sog") +class SOGEnergyAtomicModel(BaseAtomicModel): + """Energy model using a dedicated SOG energy fitting net. + + The SOG energy fitting net combines a short-range invariant fitting + and a long-range correction derived from another invariant fitting. + This avoids requiring a user-defined property name in the dataset. + """ + + def __init__( + self, + descriptor: Any, + type_map: list[str], + sog_energy_fitting: Optional[SOGEnergyFittingNet] = None, + fitting: Optional[Any] = None, + **kwargs: Any, + ) -> None: + super().__init__(type_map, **kwargs) + if sog_energy_fitting is None: + sog_energy_fitting = fitting + if not isinstance(sog_energy_fitting, SOGEnergyFittingNet): + raise TypeError( + "sog_energy_fitting must be an instance of SOGEnergyFittingNet" + ) + + self.descriptor = descriptor + self.fitting_net = sog_energy_fitting + # self.sog_energy_fitting = self.fitting_net + self.type_map = type_map + self.ntypes = len(type_map) + self.rcut = self.descriptor.get_rcut() + self.sel = self.descriptor.get_sel() + + super().init_out_stat() + + self.enable_eval_descriptor_hook = False + self.enable_eval_fitting_last_layer_hook = False + self.eval_descriptor_list: list[torch.Tensor] = [] + self.eval_fitting_last_layer_list: list[torch.Tensor] = [] + + @torch.jit.export + def fitting_output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + def get_rcut(self) -> float: + return self.rcut + + def get_sel(self) -> list[int]: + return self.sel + + def mixed_types(self) -> bool: + return self.descriptor.mixed_types() + + def has_message_passing(self) -> bool: + return self.descriptor.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + return 
self.descriptor.need_sorted_nlist_for_lower() + + def set_case_embd(self, case_idx: int) -> None: + self.fitting_net.set_case_embd(case_idx) + + def get_dim_fparam(self) -> int: + return self.fitting_net.get_dim_fparam() + + def has_default_fparam(self) -> bool: + return self.fitting_net.has_default_fparam() + + def get_default_fparam(self) -> Optional[torch.Tensor]: + return self.fitting_net.get_default_fparam() + + def get_dim_aparam(self) -> int: + return self.fitting_net.get_dim_aparam() + + def get_sel_type(self) -> list[int]: + return self.fitting_net.get_sel_type() + + def is_aparam_nall(self) -> bool: + return False + + def set_eval_descriptor_hook(self, enable: bool) -> None: + self.enable_eval_descriptor_hook = enable + self.eval_descriptor_list.clear() + + def eval_descriptor(self) -> torch.Tensor: + return torch.concat(self.eval_descriptor_list) + + def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: + self.enable_eval_fitting_last_layer_hook = enable + self.fitting_net.set_return_middle_output(enable) + self.eval_fitting_last_layer_list.clear() + + def eval_fitting_last_layer(self) -> torch.Tensor: + return torch.concat(self.eval_fitting_last_layer_list) + + def forward_atomic( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + ) -> dict[str, torch.Tensor]: + nframes, nloc, _ = nlist.shape + atype = extended_atype[:, :nloc] + if self.do_grad_r() or self.do_grad_c(): + extended_coord.requires_grad_(True) + descriptor, rot_mat, g2, h2, sw = self.descriptor( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + comm_dict=comm_dict, + ) + assert descriptor is not None + if self.enable_eval_descriptor_hook: + self.eval_descriptor_list.append(descriptor.detach()) + + energy_ret = self.fitting_net( + descriptor, 
+ atype, + gr=rot_mat, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + + if self.enable_eval_fitting_last_layer_hook and "middle_output" in energy_ret: + self.eval_fitting_last_layer_list.append( + energy_ret["middle_output"].detach() + ) + + ret = { + "energy": energy_ret["energy"], + } + if "middle_output" in energy_ret: + ret["middle_output"] = energy_ret["middle_output"] + return ret + + def apply_out_stat( + self, + ret: dict[str, torch.Tensor], + atype: torch.Tensor, + ) -> dict[str, torch.Tensor]: + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + for kk in self.bias_keys: + ret[kk] = ret[kk] + out_bias[kk][atype] + return ret + + def compute_or_load_stat( + self, + sampled_func: Any, + stat_file_path: Optional[Any] = None, + compute_or_load_out_stat: bool = True, + preset_observed_type: Optional[list[str]] = None, + ) -> None: + if stat_file_path is not None and self.type_map is not None: + stat_file_path /= " ".join(self.type_map) + + def wrapped_sampler() -> list[dict]: + sampled = sampled_func() + if self.pair_excl is not None: + pair_exclude_types = self.pair_excl.get_exclude_types() + for sample in sampled: + sample["pair_exclude_types"] = list(pair_exclude_types) + if self.atom_excl is not None: + atom_exclude_types = self.atom_excl.get_exclude_types() + for sample in sampled: + sample["atom_exclude_types"] = list(atom_exclude_types) + if ( + "find_fparam" not in sampled[0] + and "fparam" not in sampled[0] + and self.has_default_fparam() + ): + default_fparam = self.get_default_fparam() + for sample in sampled: + nframe = sample["atype"].shape[0] + sample["fparam"] = default_fparam.repeat(nframe, 1) + return sampled + + self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path) + self.compute_fitting_input_stat(wrapped_sampler, stat_file_path) + if compute_or_load_out_stat: + self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) + + self._collect_and_set_observed_type( + wrapped_sampler, stat_file_path, 
preset_observed_type + ) + + def compute_fitting_input_stat( + self, + sample_merged: Any, + stat_file_path: Optional[Any] = None, + ) -> None: + self.fitting_net.compute_input_stats( + sample_merged, + protection=self.data_stat_protect, + stat_file_path=stat_file_path, + ) + + def serialize(self) -> dict: + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "@version": 1, + "type": "energy_sog", + "type_map": self.type_map, + "descriptor": self.descriptor.serialize(), + "sog_energy_fitting": self.fitting_net.serialize(), + } + ) + return dd + + @classmethod + def deserialize(cls, data: dict) -> "SOGEnergyAtomicModel": + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + data.pop("type", None) + descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) + sog_energy_fitting_obj = SOGEnergyFittingNet.deserialize( + data.pop("sog_energy_fitting") + ) + obj = cls( + descriptor=descriptor_obj, + sog_energy_fitting=sog_energy_fitting_obj, + **data, + ) + return obj diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index b8f7b171d4..aea6b1be12 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -73,6 +73,9 @@ SpinEnergyModel, SpinModel, ) +from .sog_model import ( + SOGEnergyModel, +) def _get_standard_model_components(model_params: dict, ntypes: int) -> tuple: @@ -268,6 +271,8 @@ def get_standard_model(model_params: dict) -> BaseModel: modelcls = EnergyModel elif fitting_net_type == "property": modelcls = PropertyModel + elif fitting_net_type == "sog_energy": + modelcls = SOGEnergyModel else: raise RuntimeError(f"Unknown fitting type: {fitting_net_type}") @@ -311,6 +316,7 @@ def get_model(model_params: dict) -> Any: "FrozenModel", "LinearEnergyModel", "PolarModel", + "SOGEnergyModel", "SpinEnergyModel", "SpinModel", "get_model", diff --git a/deepmd/pt/model/model/sog_model.py 
b/deepmd/pt/model/model/sog_model.py new file mode 100644 index 0000000000..3d74a70d0b --- /dev/null +++ b/deepmd/pt/model/model/sog_model.py @@ -0,0 +1,174 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Optional, +) + +import torch + +from deepmd.pt.model.atomic_model import ( + SOGEnergyAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) + +from .dp_model import ( + DPModelCommon, +) +from .make_hessian_model import ( + make_hessian_model, +) +from .make_model import ( + make_model, +) + +SOGEnergyModel_ = make_model(SOGEnergyAtomicModel) + + +@BaseModel.register("sog_ener") +class SOGEnergyModel(DPModelCommon, SOGEnergyModel_): + model_type = "sog_ener" + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + DPModelCommon.__init__(self) + SOGEnergyModel_.__init__(self, *args, **kwargs) + self._hessian_enabled = False + + def enable_hessian(self) -> None: + self.__class__ = make_hessian_model(type(self)) + self.hess_fitting_def = super(type(self), self).atomic_output_def() + self.requires_hessian("energy") + self._hessian_enabled = True + + @torch.jit.export + def get_observed_type_list(self) -> list[str]: + """Get observed types (elements) of the model during data statistics. + + Returns + ------- + observed_type_list: a list of the observed types in this model. + """ + type_map = self.get_type_map() + out_bias = self.atomic_model.get_out_bias()[0] + + assert out_bias is not None, "No out_bias found in the model." + assert out_bias.dim() == 2, "The supported out_bias should be a 2D tensor." + assert out_bias.size(0) == len(type_map), ( + "The out_bias shape does not match the type_map length." 
+ ) + bias_mask = ( + torch.gt(torch.abs(out_bias), 1e-6).any(dim=-1).detach().cpu() + ) # 1e-6 for stability + + observed_type_list: list[str] = [] + for i in range(len(type_map)): + if bias_mask[i]: + observed_type_list.append(type_map[i]) + return observed_type_list + + def translated_output_def(self) -> dict[str, Any]: + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_energy": out_def_data["energy"], + "energy": out_def_data["energy_redu"], + } + if self.do_grad_r("energy"): + output_def["force"] = out_def_data["energy_derv_r"] + output_def["force"].squeeze(-2) + if self.do_grad_c("energy"): + output_def["virial"] = out_def_data["energy_derv_c_redu"] + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = out_def_data["energy_derv_c"] + output_def["atom_virial"].squeeze(-3) + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + if self._hessian_enabled: + output_def["hessian"] = out_def_data["energy_derv_r_derv_r"] + return output_def + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + model_ret = self.forward_common( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + if self.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze( + -3 + ) + else: + model_predict["force"] = model_ret["dforce"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + if 
self._hessian_enabled: + model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-2) + else: + model_predict = model_ret + model_predict["updated_coord"] += coord + return model_predict + + @torch.jit.export + def forward_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + ) -> dict[str, torch.Tensor]: + model_ret = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + ) + if self.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy"): + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret[ + "energy_derv_c" + ].squeeze(-3) + else: + assert model_ret["dforce"] is not None + model_predict["dforce"] = model_ret["dforce"] + else: + model_predict = model_ret + return model_predict diff --git a/deepmd/pt/model/task/__init__.py b/deepmd/pt/model/task/__init__.py index 37ffec2725..65d7983fd6 100644 --- a/deepmd/pt/model/task/__init__.py +++ b/deepmd/pt/model/task/__init__.py @@ -27,6 +27,12 @@ from .type_predict import ( TypePredictNet, ) +from .lr_fitting import ( + LRFittingNet, +) +from .sog_energy_fitting import ( + SOGEnergyFittingNet, +) __all__ = [ "BaseFitting", @@ -36,7 +42,9 @@ "EnergyFittingNet", "EnergyFittingNetDirect", "Fitting", + "LRFittingNet", "PolarFittingNet", "PropertyFittingNet", + 
"SOGEnergyFittingNet", "TypePredictNet", ] diff --git a/deepmd/pt/model/task/lr_fitting.py b/deepmd/pt/model/task/lr_fitting.py new file mode 100644 index 0000000000..aba8a5c238 --- /dev/null +++ b/deepmd/pt/model/task/lr_fitting.py @@ -0,0 +1,629 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import Any, Optional, Union +from abc import abstractmethod + +import numpy as np +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.dpmodel.utils.seed import ( + child_seed, +) +from deepmd.pt.model.network.mlp import ( + FittingNet, + NetworkCollection, +) +from deepmd.pt.model.task.fitting import ( + Fitting, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, + PRECISION_DICT, +) +from deepmd.pt.utils.exclude_mask import ( + AtomExcludeMask, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) +from deepmd.utils.finetune import ( + get_index_between_two_maps, + map_atom_exclude_types, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION +device = env.DEVICE + +@Fitting.register("lr") +class LRFittingNet(Fitting): + """Construct a general sr+lr interactions fitting net. + + Parameters + ---------- + var_name : str + The atomic property to fit. + ntypes : int + Element count. + dim_descrpt : int + Embedding width per atom. + dim_out_sr : int + The output dimension of the sr fitting net. + dim_out_lr : int + The output dimension of the lr fitting net. + neuron_sr : list[int] + Number of neurons in each hidden layers of the sr fitting net. + neuron_lr : list[int] + Number of neurons in each hidden layers of the lr fitting net. + bias_atom_e : torch.Tensor, optional + Average energy per atom for each element. + resnet_dt : bool + Using time-step in the ResNet construction. + numb_fparam : int + Number of frame parameters. + numb_aparam : int + Number of atomic parameters. 
+ dim_case_embd : int + Dimension of case specific embedding. + activation_function : str + Activation function. + precision : str + Numerical precision. + mixed_types : bool + If true, use a uniform fitting net for all atom types, otherwise use + different fitting nets for different atom types. + rcond : float, optional + The condition number for the regression of atomic energy. + seed : int, optional + Random seed. + exclude_types: list[int] + Atomic contributions of the excluded atom types are set zero. + trainable : Union[list[bool], bool] + If the parameters in the fitting net are trainable. + Now this only supports setting all the parameters in the fitting net at one state. + When in list[bool], the trainable will be True only if all the boolean parameters are True. + remove_vaccum_contribution: list[bool], optional + Remove vacuum contribution before the bias is added. The list assigned each + type. For `mixed_types` provide `[True]`, otherwise it should be a list of the same + length as `ntypes` signaling if or not removing the vacuum contribution for the atom types in the list. + type_map: list[str], Optional + A list of strings. Give the name to each type of atoms. + use_aparam_as_mask: bool + If True, the aparam will not be used in fitting net for embedding. + default_fparam: list[float], optional + The default frame parameter. If set, when `fparam.npy` files are not included in the data system, + this value will be used as the default value for the frame parameter in the fitting net. 
+ """ + + def __init__( + self, + var_name: str, + ntypes: int, + dim_descrpt: int, + dim_out_sr: int, + dim_out_lr: int, + neuron_sr: list[int] = [128, 128, 128], + neuron_lr: list[int] = [128, 128, 128], + bias_atom_e: Optional[torch.Tensor] = None, + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + dim_case_embd: int = 0, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + mixed_types: bool = True, + rcond: Optional[float] = None, + seed: Optional[Union[int, list[int]]] = None, + exclude_types: list[int] = [], + trainable: Union[bool, list[bool]] = True, + remove_vaccum_contribution: Optional[list[bool]] = None, + type_map: Optional[list[str]] = None, + use_aparam_as_mask: bool = False, + default_fparam: Optional[list[float]] = None, + **kwargs: Any, + ) -> None: + super().__init__() + self.var_name = var_name + self.ntypes = ntypes + self.dim_descrpt = dim_descrpt + self.dim_out_sr = dim_out_sr + self.dim_out_lr = dim_out_lr + self.neuron_sr = neuron_sr + self.neuron_lr = neuron_lr + self.mixed_types = mixed_types + self.resnet_dt = resnet_dt + self.numb_fparam = numb_fparam + self.numb_aparam = numb_aparam + self.default_fparam = default_fparam + self.dim_case_embd = dim_case_embd + self.activation_function = activation_function + self.precision = precision + self.prec = PRECISION_DICT[self.precision] + self.rcond = rcond + self.seed = seed + self.type_map = type_map + self.use_aparam_as_mask = use_aparam_as_mask + self.reinit_exclude(exclude_types) + self.trainable = trainable + # need support for each layer settings + self.trainable = ( + all(self.trainable) if isinstance(self.trainable, list) else self.trainable + ) + self.remove_vaccum_contribution = remove_vaccum_contribution + + self.sr_net_dim_out = self._sr_net_out_dim() + self.lr_net_dim_out = self._lr_net_out_dim() + # init constants + if bias_atom_e is None: + bias_atom_e = np.zeros([self.ntypes, self.sr_net_dim_out], dtype=np.float64) + bias_atom_e 
= torch.tensor( + bias_atom_e, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=device + ) + bias_atom_e = bias_atom_e.view([self.ntypes, self.sr_net_dim_out]) + if not self.mixed_types: + assert self.ntypes == bias_atom_e.shape[0], "Element count mismatches!" + self.register_buffer("bias_atom_e", bias_atom_e) + + if self.numb_fparam > 0: + self.register_buffer( + "fparam_avg", + torch.zeros(self.numb_fparam, dtype=self.prec, device=env.DEVICE), + ) + self.register_buffer( + "fparam_inv_std", + torch.ones(self.numb_fparam, dtype=self.prec, device=env.DEVICE), + ) + else: + self.fparam_avg, self.fparam_inv_std = None, None + if self.numb_aparam > 0: + self.register_buffer( + "aparam_avg", + torch.zeros(self.numb_aparam, dtype=self.prec, device=env.DEVICE), + ) + self.register_buffer( + "aparam_inv_std", + torch.ones(self.numb_aparam, dtype=self.prec, device=env.DEVICE), + ) + else: + self.aparam_avg, self.aparam_inv_std = None, None + + if self.dim_case_embd > 0: + self.register_buffer( + "case_embd", + torch.zeros(self.dim_case_embd, dtype=self.prec, device=env.DEVICE), + ) + else: + self.case_embd = None + + if self.default_fparam is not None: + if self.numb_fparam > 0: + assert len(self.default_fparam) == self.numb_fparam, ( + "default_fparam length mismatch!" 
+ ) + self.register_buffer( + "default_fparam_tensor", + torch.tensor( + np.array(self.default_fparam), + dtype=self.prec, + device=env.DEVICE, + ), + ) + else: + self.default_fparam_tensor = None + + in_dim = ( + self.dim_descrpt + + self.numb_fparam + + (0 if self.use_aparam_as_mask else self.numb_aparam) + + self.dim_case_embd + ) + + net_count = self.ntypes if not self.mixed_types else 1 + self.filter_layers_lr = NetworkCollection( + 1 if not self.mixed_types else 0, + self.ntypes, + network_type="fitting_network", + networks=[ + FittingNet( + in_dim, + 1, + self.neuron_lr, + self.activation_function, + self.resnet_dt, + self.precision, + bias_out=True, + seed=child_seed(self.seed, ii * 2), + trainable=self.trainable, + ) + for ii in range(net_count) + ], + ) + self.filter_layers_sr = NetworkCollection( + 1 if not self.mixed_types else 0, + self.ntypes, + network_type="fitting_network", + networks=[ + FittingNet( + in_dim, + 1, + self.neuron_sr, + self.activation_function, + self.resnet_dt, + self.precision, + bias_out=True, + seed=child_seed(self.seed, ii * 2 + 1), + trainable=self.trainable, + ) + for ii in range(net_count) + ], + ) + + for param in self.parameters(): + param.requires_grad = self.trainable + + self.eval_return_middle_output = False + + def reinit_exclude( + self, + exclude_types: list[int] = [], + ) -> None: + self.exclude_types = exclude_types + self.emask = AtomExcludeMask(self.ntypes, self.exclude_types) + + def change_type_map( + self, + type_map: list[str], + model_with_new_type_stat: Optional["LRFittingNet"] = None, + ) -> None: + """Change the type related params to new ones, according to `type_map` and the original one in the model. + If there are new types in `type_map`, statistics will be updated accordingly to `model_with_new_type_stat` for these new types. + """ + assert self.type_map is not None, ( + "'type_map' must be defined when performing type changing!" 
+ ) + assert self.mixed_types, "Only models in mixed types can perform type changing!" + remap_index, has_new_type = get_index_between_two_maps(self.type_map, type_map) + self.type_map = type_map + self.ntypes = len(type_map) + self.reinit_exclude(map_atom_exclude_types(self.exclude_types, remap_index)) + if has_new_type: + extend_shape = [len(type_map), *list(self.bias_atom_e.shape[1:])] + extend_bias_atom_e = torch.zeros( + extend_shape, + dtype=self.bias_atom_e.dtype, + device=self.bias_atom_e.device, + ) + self.bias_atom_e = torch.cat([self.bias_atom_e, extend_bias_atom_e], dim=0) + self.bias_atom_e = self.bias_atom_e[remap_index] + + def serialize(self) -> dict: + """Serialize the fitting to dict.""" + return { + "@class": "LRFitting", + "@version": 1, + "var_name": self.var_name, + "ntypes": self.ntypes, + "dim_descrpt": self.dim_descrpt, + "dim_out_sr": self.dim_out_sr, + "dim_out_lr": self.dim_out_lr, + "neuron_sr": self.neuron_sr, + "neuron_lr": self.neuron_lr, + "resnet_dt": self.resnet_dt, + "numb_fparam": self.numb_fparam, + "numb_aparam": self.numb_aparam, + "dim_case_embd": self.dim_case_embd, + "default_fparam": self.default_fparam, + "activation_function": self.activation_function, + "precision": self.precision, + "mixed_types": self.mixed_types, + "nets_sr": self.filter_layers_sr.serialize(), + "nets_lr": self.filter_layers_lr.serialize(), + "rcond": self.rcond, + "exclude_types": self.exclude_types, + "@variables": { + "bias_atom_e": to_numpy_array(self.bias_atom_e), + "case_embd": to_numpy_array(self.case_embd), + "fparam_avg": to_numpy_array(self.fparam_avg), + "fparam_inv_std": to_numpy_array(self.fparam_inv_std), + "aparam_avg": to_numpy_array(self.aparam_avg), + "aparam_inv_std": to_numpy_array(self.aparam_inv_std), + }, + "type_map": self.type_map, + # "tot_ener_zero": self.tot_ener_zero , + # "trainable": self.trainable , + # "atom_ener": self.atom_ener , + # "layer_name": self.layer_name , + # "spin": self.spin , + ## NOTICE: not supported 
by far + "tot_ener_zero": False, + "trainable_sr": [self.trainable] * (len(self.neuron_sr) + 1), + "trainable_lr": [self.trainable] * (len(self.neuron_lr) + 1), + "layer_name": None, + "use_aparam_as_mask": self.use_aparam_as_mask, + "spin": None, + } + + @classmethod + def deserialize(cls, data: dict) -> "LRFittingNet": + data = data.copy() + variables = data.pop("@variables") + nets_sr = data.pop("nets_sr") + nets_lr = data.pop("nets_lr") + obj = cls(**data) + for kk in variables.keys(): + obj[kk] = to_torch_tensor(variables[kk]) + obj.filter_layers_sr = NetworkCollection.deserialize(nets_sr) + obj.filter_layers_lr = NetworkCollection.deserialize(nets_lr) + return obj + + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this atomic model.""" + return self.numb_fparam + + def has_default_fparam(self) -> bool: + """Check if the fitting has default frame parameters.""" + return self.default_fparam is not None + + def get_default_fparam(self) -> Optional[torch.Tensor]: + return self.default_fparam_tensor + + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this atomic model.""" + return self.numb_aparam + + # make jit happy + exclude_types: list[int] + + def get_sel_type(self) -> list[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + # make jit happy + sel_type: list[int] = [] + for ii in range(self.ntypes): + if ii not in self.exclude_types: + sel_type.append(ii) + return sel_type + + def get_type_map(self) -> list[str]: + """Get the name to each type of atoms.""" + return self.type_map + + def set_case_embd(self, case_idx: int) -> None: + """ + Set the case embedding of this fitting net by the given case_idx, + typically concatenated with the output of the descriptor and fed into the fitting net. 
+ """ + self.case_embd = torch.eye(self.dim_case_embd, dtype=self.prec, device=device)[ + case_idx + ] + + def set_return_middle_output(self, return_middle_output: bool = True) -> None: + self.eval_return_middle_output = return_middle_output + + def __setitem__(self, key: str, value: torch.Tensor) -> None: + if key in ["bias_atom_e"]: + value = value.view([self.ntypes, self._net_out_dim()]) + self.bias_atom_e = value + elif key in ["fparam_avg"]: + self.fparam_avg = value + elif key in ["fparam_inv_std"]: + self.fparam_inv_std = value + elif key in ["aparam_avg"]: + self.aparam_avg = value + elif key in ["aparam_inv_std"]: + self.aparam_inv_std = value + elif key in ["case_embd"]: + self.case_embd = value + elif key in ["scale"]: + self.scale = value + elif key in ["default_fparam_tensor"]: + self.default_fparam_tensor = value + else: + raise KeyError(key) + + def __getitem__(self, key: str) -> torch.Tensor: + if key in ["bias_atom_e"]: + return self.bias_atom_e + elif key in ["fparam_avg"]: + return self.fparam_avg + elif key in ["fparam_inv_std"]: + return self.fparam_inv_std + elif key in ["aparam_avg"]: + return self.aparam_avg + elif key in ["aparam_inv_std"]: + return self.aparam_inv_std + elif key in ["case_embd"]: + return self.case_embd + elif key in ["scale"]: + return self.scale + elif key in ["default_fparam_tensor"]: + return self.default_fparam_tensor + else: + raise KeyError(key) + + def _sr_net_out_dim(self) -> int: + """Set the SRFittingNet output dim.""" + return self.dim_out_sr + + def _lr_net_out_dim(self) -> int: + """Set the LR FittingNet output dim.""" + return self.dim_out_lr + + def _extend_f_avg_std(self, xx: torch.Tensor, nb: int) -> torch.Tensor: + return torch.tile(xx.view([1, self.numb_fparam]), [nb, 1]) + + def _extend_a_avg_std(self, xx: torch.Tensor, nb: int, nloc: int) -> torch.Tensor: + return torch.tile(xx.view([1, 1, self.numb_aparam]), [nb, nloc, 1]) + + def _forward_common( + self, + descriptor: torch.Tensor, + atype: 
torch.Tensor, + gr: Optional[torch.Tensor] = None, + g2: Optional[torch.Tensor] = None, + h2: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ) -> dict[str, torch.Tensor]: + xx = descriptor.to(self.prec) + nf, nloc, nd = xx.shape + + if self.numb_fparam > 0 and fparam is None: + assert self.default_fparam_tensor is not None + fparam = torch.tile(self.default_fparam_tensor.unsqueeze(0), [nf, 1]) + + fparam = fparam.to(self.prec) if fparam is not None else None + aparam = aparam.to(self.prec) if aparam is not None else None + + if self.remove_vaccum_contribution is not None: + xx_zeros = torch.zeros_like(xx) + else: + xx_zeros = None + + if nd != self.dim_descrpt: + raise ValueError( + f"get an input descriptor of dim {nd}," + f"which is not consistent with {self.dim_descrpt}." + ) + + if self.numb_fparam > 0: + assert fparam is not None, "fparam should not be None" + assert self.fparam_avg is not None + assert self.fparam_inv_std is not None + if fparam.shape[-1] != self.numb_fparam: + raise ValueError( + f"get an input fparam of dim {fparam.shape[-1]}, " + f"which is not consistent with {self.numb_fparam}." + ) + fparam = fparam.view([nf, self.numb_fparam]) + nb, _ = fparam.shape + t_fparam_avg = self._extend_f_avg_std(self.fparam_avg, nb) + t_fparam_inv_std = self._extend_f_avg_std(self.fparam_inv_std, nb) + fparam = (fparam - t_fparam_avg) * t_fparam_inv_std + fparam = torch.tile(fparam.reshape([nf, 1, -1]), [1, nloc, 1]) + xx = torch.cat([xx, fparam], dim=-1) + if xx_zeros is not None: + xx_zeros = torch.cat([xx_zeros, fparam], dim=-1) + + if self.numb_aparam > 0 and not self.use_aparam_as_mask: + assert aparam is not None, "aparam should not be None" + assert self.aparam_avg is not None + assert self.aparam_inv_std is not None + if aparam.shape[-1] != self.numb_aparam: + raise ValueError( + f"get an input aparam of dim {aparam.shape[-1]}, " + f"which is not consistent with {self.numb_aparam}." 
+ ) + aparam = aparam.view([nf, -1, self.numb_aparam]) + nb, nloc, _ = aparam.shape + t_aparam_avg = self._extend_a_avg_std(self.aparam_avg, nb, nloc) + t_aparam_inv_std = self._extend_a_avg_std(self.aparam_inv_std, nb, nloc) + aparam = (aparam - t_aparam_avg) * t_aparam_inv_std + xx = torch.cat([xx, aparam], dim=-1) + if xx_zeros is not None: + xx_zeros = torch.cat([xx_zeros, aparam], dim=-1) + + if self.dim_case_embd > 0: + assert self.case_embd is not None + case_embd = torch.tile(self.case_embd.reshape([1, 1, -1]), [nf, nloc, 1]) + xx = torch.cat([xx, case_embd], dim=-1) + if xx_zeros is not None: + xx_zeros = torch.cat([xx_zeros, case_embd], dim=-1) + + results: dict[str, torch.Tensor] = {} + sr_out = self._apply_networks( + self.filter_layers_sr, + self.neuron_sr, + self.sr_net_dim_out, + xx, + xx_zeros, + atype, + middle_output=results, + bool_bias=True, + ) + lr_out = self._apply_networks( + self.filter_layers_lr, + self.neuron_lr, + self.lr_net_dim_out, + xx, + xx_zeros, + atype, + middle_output=results, + ) + mask = self.emask(atype).to(torch.bool) + sr_out = torch.where(mask[:, :, None], sr_out, 0.0) + lr_out = torch.where(mask[:, :, None], lr_out, 0.0) + results.update({"sr": sr_out, "lr": lr_out}) + return results + + def _apply_networks( + self, + layers: NetworkCollection, + neuron: int, + dim_out: int, + xx: torch.Tensor, + xx_zeros: Optional[torch.Tensor], + atype: torch.Tensor, + middle_output: Optional[dict[str, torch.Tensor]], + bool_bias: bool = False, + ) -> torch.Tensor: + nf, nloc, _ = xx.shape + outs = torch.zeros((nf, nloc, 1), dtype=self.prec, device=xx.device) + if self.mixed_types: + atom_property = layers.networks[0](xx) + if self.eval_return_middle_output and middle_output is not None: + middle_output["middle_output"] = layers.networks[0].call_until_last( + xx + ) + if xx_zeros is not None: + atom_property -= layers.networks[0](xx_zeros) + outs = outs + atom_property + else: + if self.eval_return_middle_output and middle_output is not 
None: + outs_middle = torch.zeros( + (nf, nloc, neuron[-1]), + dtype=self.prec, + device=xx.device, + ) + for type_i, ll in enumerate(layers.networks): + mask = (atype == type_i).unsqueeze(-1) + mask = torch.tile(mask, (1, 1, dim_out)) + middle_output_type = ll.call_until_last(xx) + middle_output_type = torch.where( + torch.tile(mask, (1, 1, neuron[-1])), + middle_output_type, + 0.0, + ) + outs_middle = outs_middle + middle_output_type + middle_output["middle_output"] = outs_middle + for type_i, ll in enumerate(layers.networks): + mask = (atype == type_i).unsqueeze(-1) + mask = torch.tile(mask, (1, 1, dim_out)) + atom_property = ll(xx) + if xx_zeros is not None: + assert self.remove_vaccum_contribution is not None + if not ( + len(self.remove_vaccum_contribution) > type_i + and not self.remove_vaccum_contribution[type_i] + ): + atom_property -= ll(xx_zeros) + if bool_bias: + atom_property = atom_property + self.bias_atom_e[type_i].to(self.prec) + else: + atom_property = atom_property + atom_property = torch.where(mask, atom_property, 0.0) + outs = outs + atom_property + return outs \ No newline at end of file diff --git a/deepmd/pt/model/task/sog_energy_fitting.py b/deepmd/pt/model/task/sog_energy_fitting.py new file mode 100644 index 0000000000..0b868aa710 --- /dev/null +++ b/deepmd/pt/model/task/sog_energy_fitting.py @@ -0,0 +1,282 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from typing import ( + Any, + Optional, + Union, +) + +import numpy as np +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, + fitting_check_output, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION +device = env.DEVICE + +log = logging.getLogger(__name__) +from deepmd.pt.model.task.lr_fitting import ( + LRFittingNet, +) + + +SOG_DEFAULT_SHIFT = np.array([ + 0.2750, + 
0.1375, + 0.0688, + 0.0344, + 0.0172, + 0.0086, + 0.0043, + 0.0021, + 0.0011, + 0.0005, + 0.0003, + 0.0001, +], dtype=env.GLOBAL_NP_FLOAT_PRECISION) +SOG_DEFAULT_AMPLITUDE = np.array([ + 2.8, + 5.7, + 11.4, + 22.7, + 45.5, + 91.0, + 182.0, + 364.0, + 728.0, + 1456.0, + 2912.0, + 5823.9, +], dtype=env.GLOBAL_NP_FLOAT_PRECISION) + + +@LRFittingNet.register("sog_energy") +@fitting_check_output +class SOGEnergyFittingNet(LRFittingNet): + """Construct a SOG sr+lr interactions fitting net. + + Parameters + ---------- + var_name : str + The atomic property to fit. + ntypes : int + Element count. + dim_descrpt : int + Embedding width per atom. + dim_out_sr : int + The output dimension of the sr fitting net. + dim_out_lr : int + The output dimension of the lr fitting net. + neuron_sr : list[int] + Number of neurons in each hidden layers of the sr fitting net. + neuron_lr : list[int] + Number of neurons in each hidden layers of the lr fitting net. + bias_atom_e : torch.Tensor, optional + Average energy per atom for each element. + resnet_dt : bool + Using time-step in the ResNet construction. + numb_fparam : int + Number of frame parameters. + numb_aparam : int + Number of atomic parameters. + dim_case_embd : int + Dimension of case specific embedding. + activation_function : str + Activation function. + precision : str + Numerical precision. + mixed_types : bool + If true, use a uniform fitting net for all atom types, otherwise use + different fitting nets for different atom types. + rcond : float, optional + The condition number for the regression of atomic energy. + seed : int, optional + Random seed. + exclude_types: list[int] + Atomic contributions of the excluded atom types are set zero. + trainable : Union[list[bool], bool] + If the parameters in the fitting net are trainable. + Now this only supports setting all the parameters in the fitting net at one state. + When in list[bool], the trainable will be True only if all the boolean parameters are True. 
+ remove_vaccum_contribution: list[bool], optional + Remove vacuum contribution before the bias is added. The list assigned each + type. For `mixed_types` provide `[True]`, otherwise it should be a list of the same + length as `ntypes` signaling if or not removing the vacuum contribution for the atom types in the list. + type_map: list[str], Optional + A list of strings. Give the name to each type of atoms. + use_aparam_as_mask: bool + If True, the aparam will not be used in fitting net for embedding. + default_fparam: list[float], optional + The default frame parameter. If set, when `fparam.npy` files are not included in the data system, + this value will be used as the default value for the frame parameter in the fitting net. + """ + + def __init__( + self, + var_name: str, + ntypes: int, + dim_descrpt: int, + dim_out_sr: int, + dim_out_lr: int, + neuron_sr: list[int] = [128, 128, 128], + neuron_lr: list[int] = [128, 128, 128], + bias_atom_e: Optional[torch.Tensor] = None, + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + dim_case_embd: int = 0, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + mixed_types: bool = True, + rcond: Optional[float] = None, + seed: Optional[Union[int, list[int]]] = None, + exclude_types: list[int] = [], + trainable: Union[bool, list[bool]] = True, + remove_vaccum_contribution: Optional[list[bool]] = None, + type_map: Optional[list[str]] = None, + use_aparam_as_mask: bool = False, + default_fparam: Optional[list[float]] = None, + shift: Optional[Union[list[float], torch.Tensor]] = None, + amplitude: Optional[Union[list[float], torch.Tensor]] = None, + **kwargs: Any, + ) -> None: + super().__init__( + var_name=var_name, + ntypes=ntypes, + dim_descrpt=dim_descrpt, + dim_out_sr=dim_out_sr, + dim_out_lr=dim_out_lr, + neuron_sr=neuron_sr, + neuron_lr=neuron_lr, + bias_atom_e=bias_atom_e, + resnet_dt=resnet_dt, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + 
dim_case_embd=dim_case_embd, + activation_function=activation_function, + precision=precision, + mixed_types=mixed_types, + rcond=rcond, + seed=seed, + exclude_types=exclude_types, + trainable=trainable, + remove_vaccum_contribution=remove_vaccum_contribution, + type_map=type_map, + use_aparam_as_mask=use_aparam_as_mask, + default_fparam=default_fparam, + **kwargs, + ) + if isinstance(shift, (list, tuple)): + shift = np.array(shift, dtype=env.GLOBAL_NP_FLOAT_PRECISION) + if isinstance(amplitude, (list, tuple)): + amplitude = np.array(amplitude, dtype=env.GLOBAL_NP_FLOAT_PRECISION) + shift_tensor = to_torch_tensor(shift) + amplitude_tensor = to_torch_tensor(amplitude) + if shift_tensor is None: + shift_tensor = to_torch_tensor(SOG_DEFAULT_SHIFT) + if amplitude_tensor is None: + amplitude_tensor = to_torch_tensor(SOG_DEFAULT_AMPLITUDE) + # Register as trainable parameters so they are optimized with the fitting net. + self.shift = torch.nn.Parameter( + shift_tensor.to(dtype=dtype, device=device), + requires_grad=bool(self.trainable), + ) + self.amplitude = torch.nn.Parameter( + amplitude_tensor.to(dtype=dtype, device=device), + requires_grad=bool(self.trainable), + ) + + + def output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ) + ] + ) + + def serialize(self) -> dict: + data = super().serialize() + data["type"] = "sog_energy" + data["@variables"]["shift"] = to_numpy_array(self.shift) + data["@variables"]["amplitude"] = to_numpy_array(self.amplitude) + return data + + @classmethod + def deserialize(cls, data: dict) -> "SOGEnergyFittingNet": + data = data.copy() + variables = data.get("@variables", {}) + obj = super().deserialize(data) + shift_tensor = to_torch_tensor(variables.get("shift", None)) + amplitude_tensor = to_torch_tensor(variables.get("amplitude", None)) + # Backward compatibility: if serialized variables miss 
shift/amplitude, + # keep defaults initialized in __init__. + if shift_tensor is not None: + obj.shift = torch.nn.Parameter( + shift_tensor.to(dtype=dtype, device=device), + requires_grad=bool(obj.trainable), + ) + if amplitude_tensor is not None: + obj.amplitude = torch.nn.Parameter( + amplitude_tensor.to(dtype=dtype, device=device), + requires_grad=bool(obj.trainable), + ) + return obj + + def corr_head( + self, + lr_val: torch.Tensor, + amplitude: torch.Tensor, + shift: torch.Tensor, + ) -> torch.Tensor: + # TODO: + # Long-range correction energy calculation + return torch.zeros_like(lr_val) + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: Optional[torch.Tensor] = None, + g2: Optional[torch.Tensor] = None, + h2: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + ) -> dict[str, torch.Tensor]: + out = self._forward_common( + descriptor=descriptor, + atype=atype, + gr=gr, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + short_energy = out["sr"] + corr_energy = self.corr_head(out["lr"], self.amplitude, self.shift) + result = {"energy": short_energy + corr_energy} + if "middle_output" in out: + result["middle_output"] = out["middle_output"] + return result + + # make jit happy with torch 2.0.0 + exclude_types: list[int] \ No newline at end of file diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 2d20319888..4f4bb4db4b 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -57,6 +57,7 @@ doc_ener = "Fit an energy model (potential energy surface)." doc_dos = "Fit a density of states model. The total density of states / site-projected density of states labels should be provided by `dos.npy` or `atom_dos.npy` in each data system. The file has number of frames lines and number of energy grid columns (times number of atoms in `atom_dos.npy`). See `loss` parameter." doc_dipole = "Fit an atomic dipole model. 
Global dipole labels or atomic dipole labels for all the selected atoms (see `sel_type`) should be provided by `dipole.npy` in each data system. The file either has number of frames lines and 3 times of number of selected atoms columns, or has number of frames lines and 3 columns. See `loss` parameter." +doc_sog = "Fit a SOG-energy model (SR+LR)." doc_polar = "Fit an atomic polarizability model. Global polarizazbility labels or atomic polarizability labels for all the selected atoms (see `sel_type`) should be provided by `polarizability.npy` in each data system. The file with has number of frames lines and 9 times of number of selected atoms columns, or has number of frames lines and 9 columns. See `loss` parameter." # modifier doc_dipole_charge = "Use WFCC to model the electronic structure of the system. Correct the long-range interaction." @@ -2054,6 +2055,175 @@ def fitting_polar() -> list[Argument]: # def fitting_global_polar(): # return fitting_polar() +@fitting_args_plugin.register("sog_energy", doc=doc_sog) +def fitting_sog_energy() -> list[Argument]: + doc_var_name = "The atomic property name used by the LR fitting net. Usually set to `energy`." + doc_dim_out_sr = "The output dimension of the short-range fitting branch." + doc_dim_out_lr = "The output dimension of the long-range fitting branch." + doc_neuron_sr = "The number of neurons in each hidden layer of the short-range fitting net." + doc_neuron_lr = "The number of neurons in each hidden layer of the long-range fitting net." + doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." + doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams." + doc_default_fparam = "The default frame parameter. 
If set, when `fparam.npy` files are not included in the data system, this value will be used as the default value for the frame parameter in the fitting net."
+    doc_dim_case_embd = "The dimension of the case embedding. When training or fine-tuning a multitask model with case embeddings, this number should be set to the number of model branches."
+    doc_activation_function = f'The activation function in the fitting net. Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.'
+    doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision."
+    doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection'
+    doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be bool or list[bool]. When list[bool] is given, all values must be True to make parameters trainable."
+    doc_rcond = "The condition number used to determine the initial energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details."
+    doc_seed = "Random seed for parameter initialization of the fitting net"
+    doc_exclude_types = "The excluded atom types whose atomic contributions are set to zero."
+    doc_use_aparam_as_mask = (
+        "Whether to use the aparam as a mask in input. "
+        "If True, the aparam will not be used in fitting net for embedding."
+    )
+    doc_shift = "Shift values of the SOG long-range correction kernels."
+    doc_amplitude = "Amplitude values of the SOG long-range correction kernels."
+ + return [ + Argument( + "var_name", + str, + optional=True, + default="energy", + doc=doc_only_pt_supported + doc_var_name, + ), + Argument( + "dim_out_sr", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_dim_out_sr, + ), + Argument( + "dim_out_lr", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_dim_out_lr, + ), + Argument( + "neuron_sr", + list[int], + optional=True, + default=[128, 128, 128], + doc=doc_only_pt_supported + doc_neuron_sr, + ), + Argument( + "neuron_lr", + list[int], + optional=True, + default=[128, 128, 128], + doc=doc_only_pt_supported + doc_neuron_lr, + ), + Argument( + "numb_fparam", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_numb_fparam, + ), + Argument( + "numb_aparam", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_numb_aparam, + ), + Argument( + "default_fparam", + list[float], + optional=True, + default=None, + doc=doc_only_pt_supported + doc_default_fparam, + ), + Argument( + "dim_case_embd", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_dim_case_embd, + ), + Argument( + "activation_function", + str, + optional=True, + default="tanh", + doc=doc_activation_function, + ), + Argument("precision", str, optional=True, default="default", doc=doc_precision), + Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt), + Argument( + "trainable", + [list[bool], bool], + optional=True, + default=True, + doc=doc_only_pt_supported + doc_trainable, + ), + Argument( + "rcond", + [float, type(None)], + optional=True, + default=None, + doc=doc_only_pt_supported + doc_rcond, + ), + Argument("seed", [int, None], optional=True, doc=doc_seed), + Argument( + "exclude_types", + list[int], + optional=True, + default=[], + doc=doc_only_pt_supported + doc_exclude_types, + ), + Argument( + "use_aparam_as_mask", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_use_aparam_as_mask, + ), + Argument( 
+ "shift", + list[float], + optional=True, + default=[ + 0.2750, + 0.1375, + 0.0688, + 0.0344, + 0.0172, + 0.0086, + 0.0043, + 0.0021, + 0.0011, + 0.0005, + 0.0003, + 0.0001, + ], + doc=doc_only_pt_supported + doc_shift, + ), + Argument( + "amplitude", + list[float], + optional=True, + default=[ + 2.8, + 5.7, + 11.4, + 22.7, + 45.5, + 91.0, + 182.0, + 364.0, + 728.0, + 1456.0, + 2912.0, + 5823.9, + ], + doc=doc_only_pt_supported + doc_amplitude, + ), + ] + @fitting_args_plugin.register("dipole", doc=doc_dipole) diff --git a/examples/water/dpa3/dpa3.hdf5 b/examples/water/dpa3/dpa3.hdf5 new file mode 100644 index 0000000000000000000000000000000000000000..90ae7bb1d6ae7ca6daec0d7ce02b99ac208ff5b6 GIT binary patch literal 4930 zcmeHJy-vbV6h60=SRp177UR^Bk%7UDYQS_7qb9mlNE>ue2qi`~M|=Z!7oWl~`Ub8( zf-hjOmvcT7YN}y^;4MkNbI;G^-tT@r?So%G*v#+bNnegbYqT{rahu4sUt$l&h+l}I z1V@6;78{E6k7!%`MVUXHcY54tP?CxHfl@~AIkp@)4E`$y{Knpq-tf8X8Y+Dioq>#| z>rYMSkw=Kmk6HTlZJz5=borNlcSX4F{7WzuE5JZnzn<)AB>KS?+gGhPXbrm8L39)K zJ45*6K{(^?w&SRO8->AO*o%@R=pB6)bti|-y4EoIrYpKdep<(2_1Q3+R82PO3?9mM05vZA;`M2 zzjak!|9u%1345}etZY{T_ZwLmtL0yut}08Zt7SXgx2N9cOL#VB?2LJRA2s05^WzSz v6n6sj0ib?jbflErG-f~do_ukmlK+wD5r#p-fMLKeU>GnA7zP%^fZzNAK5J{( literal 0 HcmV?d00001 diff --git a/examples/water/dpa3/input_torch_copy.json b/examples/water/dpa3/input_torch_copy.json new file mode 100644 index 0000000000..c2ed3891e3 --- /dev/null +++ b/examples/water/dpa3/input_torch_copy.json @@ -0,0 +1,103 @@ +{ + "_comment": "that's all", + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 128, + "e_dim": 64, + "a_dim": 32, + "nlayers": 6, + "e_rcut": 6.0, + "e_rcut_smth": 5.3, + "e_sel": 120, + "a_rcut": 4.0, + "a_rcut_smth": 3.5, + "a_sel": 30, + "axis_neuron": 4, + "fix_stat_std": 0.3, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": true, + "update_angle": true, + "smooth_edge_update": true, + "edge_init_use_dist": true, + "use_exp_switch": true, + 
"update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const" + }, + "activation_function": "silut:10.0", + "use_tebd_bias": false, + "precision": "float32", + "concat_output_tebd": false, + "seed": 1 + }, + "fitting_net": { + "neuron": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "precision": "float32", + "activation_function": "silut:10.0", + "seed": 1, + "_comment": " that's all" + }, + "_comment": " that's all" + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + "start_lr": 0.001, + "stop_lr": 3e-5, + "_comment": "that's all" + }, + "loss": { + "type": "ener", + "start_pref_e": 0.2, + "limit_pref_e": 20, + "start_pref_f": 100, + "limit_pref_f": 60, + "start_pref_v": 0.02, + "limit_pref_v": 1, + "_comment": " that's all" + }, + "optimizer": { + "type": "AdamW", + "adam_beta1": 0.9, + "adam_beta2": 0.999, + "weight_decay": 0.001 + }, + "training": { + "stat_file": "./dpa3.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1, + "_comment": "that's all" + }, + "numb_steps": 10000, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 2000, + "_comment": "that's all" + } +} diff --git a/examples/water/sog/README.md b/examples/water/sog/README.md new file mode 100644 index 0000000000..a70130de77 --- /dev/null +++ b/examples/water/sog/README.md @@ -0,0 +1,16 @@ +# Input for the SOG model + +This directory provides a SOG training example based on the same water data split and training layout used by `examples/water/dpa3`. 
+ +## Run + +```bash +cd examples/water/sog +dp --pt train input_torch.json --skip-neighbor-stat +``` + +## Notes + +- Descriptor: DPA3 (`model.descriptor.type = "dpa3"`) +- Fitting: SOG energy (`model.fitting_net.type = "sog_energy"`) +- Data systems are reused from `examples/water/data/data_0` to `data_3`. diff --git a/examples/water/sog/input_torch.json b/examples/water/sog/input_torch.json new file mode 100644 index 0000000000..d2ed256283 --- /dev/null +++ b/examples/water/sog/input_torch.json @@ -0,0 +1,134 @@ +{ + "_comment": "SOG example based on water/dpa3 layout", + "model": { + "type_map": [ + "O", + "H" + ], + "descriptor": { + "type": "dpa3", + "repflow": { + "n_dim": 128, + "e_dim": 64, + "a_dim": 32, + "nlayers": 6, + "e_rcut": 6.0, + "e_rcut_smth": 5.3, + "e_sel": 120, + "a_rcut": 4.0, + "a_rcut_smth": 3.5, + "a_sel": 30, + "axis_neuron": 4, + "fix_stat_std": 0.3, + "a_compress_rate": 1, + "a_compress_e_rate": 2, + "a_compress_use_split": true, + "update_angle": true, + "smooth_edge_update": true, + "edge_init_use_dist": true, + "use_exp_switch": true, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": "const" + }, + "activation_function": "silut:10.0", + "use_tebd_bias": false, + "precision": "float32", + "concat_output_tebd": false, + "seed": 1 + }, + "fitting_net": { + "type": "sog_energy", + "var_name": "energy", + "dim_out_sr": 1, + "dim_out_lr": 1, + "neuron_sr": [ + 240, + 240, + 240 + ], + "neuron_lr": [ + 240, + 240, + 240 + ], + "resnet_dt": true, + "precision": "float32", + "activation_function": "silut:10.0", + "seed": 1, + "_comment": "SOG-specific long-range kernel parameters", + "shift": [ + 0.275, + 0.1375, + 0.0688, + 0.0344, + 0.0172, + 0.0086, + 0.0043, + 0.0021, + 0.0011, + 0.0005, + 0.0003, + 0.0001 + ], + "amplitude": [ + 2.8, + 5.7, + 11.4, + 22.7, + 45.5, + 91.0, + 182.0, + 364.0, + 728.0, + 1456.0, + 2912.0, + 5823.9 + ] + } + }, + "learning_rate": { + "type": "exp", + "decay_steps": 5000, + 
"start_lr": 0.001, + "stop_lr": 3e-5 + }, + "loss": { + "type": "ener", + "start_pref_e": 0.2, + "limit_pref_e": 20, + "start_pref_f": 100, + "limit_pref_f": 60, + "start_pref_v": 0.02, + "limit_pref_v": 1 + }, + "optimizer": { + "type": "AdamW", + "adam_beta1": 0.9, + "adam_beta2": 0.999, + "weight_decay": 0.001 + }, + "training": { + "stat_file": "./sog.hdf5", + "training_data": { + "systems": [ + "../data/data_0", + "../data/data_1", + "../data/data_2" + ], + "batch_size": 1 + }, + "validation_data": { + "systems": [ + "../data/data_3" + ], + "batch_size": 1 + }, + "numb_steps": 10000, + "gradient_max_norm": 5.0, + "seed": 10, + "disp_file": "lcurve.out", + "disp_freq": 100, + "save_freq": 2000 + } +} diff --git a/examples/water/sog/sog.hdf5 b/examples/water/sog/sog.hdf5 new file mode 100644 index 0000000000000000000000000000000000000000..90ae7bb1d6ae7ca6daec0d7ce02b99ac208ff5b6 GIT binary patch literal 4930 zcmeHJy-vbV6h60=SRp177UR^Bk%7UDYQS_7qb9mlNE>ue2qi`~M|=Z!7oWl~`Ub8( zf-hjOmvcT7YN}y^;4MkNbI;G^-tT@r?So%G*v#+bNnegbYqT{rahu4sUt$l&h+l}I z1V@6;78{E6k7!%`MVUXHcY54tP?CxHfl@~AIkp@)4E`$y{Knpq-tf8X8Y+Dioq>#| z>rYMSkw=Kmk6HTlZJz5=borNlcSX4F{7WzuE5JZnzn<)AB>KS?+gGhPXbrm8L39)K zJ45*6K{(^?w&SRO8->AO*o%@R=pB6)bti|-y4EoIrYpKdep<(2_1Q3+R82PO3?9mM05vZA;`M2 zzjak!|9u%1345}etZY{T_ZwLmtL0yut}08Zt7SXgx2N9cOL#VB?2LJRA2s05^WzSz v6n6sj0ib?jbflErG-f~do_ukmlK+wD5r#p-fMLKeU>GnA7zP%^fZzNAK5J{( literal 0 HcmV?d00001 diff --git a/pyproject.toml b/pyproject.toml index bd1508193e..17f357c1ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ build-backend = "backend.dp_backend" backend-path = ["."] [project] -name = "deepmd-kit" +name = "deepmd-kit-dev" dynamic = ["version", "optional-dependencies", "scripts", "readme"] description = "A deep learning package for many-body potential energy representation and molecular dynamics" authors = [ @@ -142,7 +142,7 @@ jax = [ ] [tool.deepmd_build_backend.scripts] -dp = "deepmd.main:main" +dp_dev = "deepmd.main:main" 
[dependency-groups] dev = [ From a531e419e2e94f8c552873a55c6fc21b5353afa4 Mon Sep 17 00:00:00 2001 From: zhenyuan5090 Date: Mon, 16 Mar 2026 16:30:59 +0800 Subject: [PATCH 2/8] little change --- deepmd/pt/model/task/sog_energy_fitting.py | 37 ++++++++++------------ examples/water/sog/input_torch.json | 2 +- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/deepmd/pt/model/task/sog_energy_fitting.py b/deepmd/pt/model/task/sog_energy_fitting.py index 0b868aa710..e8a9b3a55d 100644 --- a/deepmd/pt/model/task/sog_energy_fitting.py +++ b/deepmd/pt/model/task/sog_energy_fitting.py @@ -34,7 +34,7 @@ ) -SOG_DEFAULT_SHIFT = np.array([ +SOG_DEFAULT_SHIFT = to_numpy_array(np.array([ 0.2750, 0.1375, 0.0688, @@ -47,8 +47,8 @@ 0.0005, 0.0003, 0.0001, -], dtype=env.GLOBAL_NP_FLOAT_PRECISION) -SOG_DEFAULT_AMPLITUDE = np.array([ +])) +SOG_DEFAULT_AMPLITUDE = to_numpy_array(np.array([ 2.8, 5.7, 11.4, @@ -61,7 +61,7 @@ 1456.0, 2912.0, 5823.9, -], dtype=env.GLOBAL_NP_FLOAT_PRECISION) +])) @LRFittingNet.register("sog_energy") @@ -181,9 +181,9 @@ def __init__( **kwargs, ) if isinstance(shift, (list, tuple)): - shift = np.array(shift, dtype=env.GLOBAL_NP_FLOAT_PRECISION) + shift = to_numpy_array(np.array(shift)) if isinstance(amplitude, (list, tuple)): - amplitude = np.array(amplitude, dtype=env.GLOBAL_NP_FLOAT_PRECISION) + amplitude = to_numpy_array(np.array(amplitude)) shift_tensor = to_torch_tensor(shift) amplitude_tensor = to_torch_tensor(amplitude) if shift_tensor is None: @@ -191,12 +191,15 @@ def __init__( if amplitude_tensor is None: amplitude_tensor = to_torch_tensor(SOG_DEFAULT_AMPLITUDE) # Register as trainable parameters so they are optimized with the fitting net. 
- self.shift = torch.nn.Parameter( - shift_tensor.to(dtype=dtype, device=device), + wl_tensor = to_torch_tensor(amplitude_tensor * (torch.sqrt(torch.tensor(torch.pi))**3)*shift_tensor**3) + sl_tensor = to_torch_tensor(-torch.log(2/shift_tensor)) + + self.wl = torch.nn.Parameter( + wl_tensor.to(dtype=dtype, device=device), requires_grad=bool(self.trainable), ) - self.amplitude = torch.nn.Parameter( - amplitude_tensor.to(dtype=dtype, device=device), + self.sl = torch.nn.Parameter( + sl_tensor.to(dtype=dtype, device=device), requires_grad=bool(self.trainable), ) @@ -231,22 +234,14 @@ def deserialize(cls, data: dict) -> "SOGEnergyFittingNet": # Backward compatibility: if serialized variables miss shift/amplitude, # keep defaults initialized in __init__. if shift_tensor is not None: - obj.shift = torch.nn.Parameter( - shift_tensor.to(dtype=dtype, device=device), - requires_grad=bool(obj.trainable), - ) + obj.shift = shift_tensor if amplitude_tensor is not None: - obj.amplitude = torch.nn.Parameter( - amplitude_tensor.to(dtype=dtype, device=device), - requires_grad=bool(obj.trainable), - ) + obj.amplitude = amplitude_tensor return obj def corr_head( self, lr_val: torch.Tensor, - amplitude: torch.Tensor, - shift: torch.Tensor, ) -> torch.Tensor: # TODO: # Long-range correction energy calculation @@ -272,7 +267,7 @@ def forward( aparam=aparam, ) short_energy = out["sr"] - corr_energy = self.corr_head(out["lr"], self.amplitude, self.shift) + corr_energy = self.corr_head(out["lr"]) result = {"energy": short_energy + corr_energy} if "middle_output" in out: result["middle_output"] = out["middle_output"] diff --git a/examples/water/sog/input_torch.json b/examples/water/sog/input_torch.json index d2ed256283..9be7bf249b 100644 --- a/examples/water/sog/input_torch.json +++ b/examples/water/sog/input_torch.json @@ -41,7 +41,7 @@ "type": "sog_energy", "var_name": "energy", "dim_out_sr": 1, - "dim_out_lr": 1, + "dim_out_lr": 2, "neuron_sr": [ 240, 240, From 
2993424fc7b779a4a985832a060f2362da99b797 Mon Sep 17 00:00:00 2001 From: zhenyuan5090 Date: Wed, 18 Mar 2026 21:31:17 +0800 Subject: [PATCH 3/8] baseline passed test --- .../pt/model/atomic_model/sog_atomic_model.py | 13 +- deepmd/pt/model/model/sog_model.py | 45 +++- deepmd/pt/model/task/sog_energy_fitting.py | 208 ++++++++++++++++-- examples/water/dpa3/input_torch_copy.json | 2 +- examples/water/sog/input_torch.json | 4 +- 5 files changed, 242 insertions(+), 30 deletions(-) diff --git a/deepmd/pt/model/atomic_model/sog_atomic_model.py b/deepmd/pt/model/atomic_model/sog_atomic_model.py index d83acd6f75..c24d33f75c 100644 --- a/deepmd/pt/model/atomic_model/sog_atomic_model.py +++ b/deepmd/pt/model/atomic_model/sog_atomic_model.py @@ -141,12 +141,17 @@ def forward_atomic( atype = extended_atype[:, :nloc] if self.do_grad_r() or self.do_grad_c(): extended_coord.requires_grad_(True) + + descriptor_comm_dict = comm_dict + if comm_dict is not None and "send_list" not in comm_dict: + descriptor_comm_dict = None + descriptor, rot_mat, g2, h2, sw = self.descriptor( extended_coord, extended_atype, nlist, mapping=mapping, - comm_dict=comm_dict, + comm_dict=descriptor_comm_dict, ) assert descriptor is not None if self.enable_eval_descriptor_hook: @@ -160,6 +165,12 @@ def forward_atomic( h2=h2, fparam=fparam, aparam=aparam, + coord=extended_coord[:, :nloc, :], + box=( + comm_dict["box"].view(nframes, 3, 3) + if comm_dict is not None and "box" in comm_dict + else None + ), ) if self.enable_eval_fitting_last_layer_hook and "middle_output" in energy_ret: diff --git a/deepmd/pt/model/model/sog_model.py b/deepmd/pt/model/model/sog_model.py index 3d74a70d0b..81d5a8705d 100644 --- a/deepmd/pt/model/model/sog_model.py +++ b/deepmd/pt/model/model/sog_model.py @@ -6,6 +6,13 @@ import torch +from deepmd.pt.model.model.transform_output import ( + communicate_extended_output, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + from 
deepmd.pt.model.atomic_model import ( SOGEnergyAtomicModel, ) @@ -100,14 +107,42 @@ def forward( aparam: Optional[torch.Tensor] = None, do_atomic_virial: bool = False, ) -> dict[str, torch.Tensor]: - model_ret = self.forward_common( - coord, + cc, bb, fp, ap, input_prec = self._input_type_cast( + coord, box=box, fparam=fparam, aparam=aparam + ) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + cc, atype, - box, - fparam=fparam, - aparam=aparam, + self.get_rcut(), + self.get_sel(), + mixed_types=True, + box=bb, + ) + comm_dict: dict[str, torch.Tensor] | None = None + if bb is not None: + comm_dict = {"box": bb} + model_predict_lower = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fp, + aparam=ap, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + ) + model_ret = communicate_extended_output( + model_predict_lower, + self.model_output_def(), + mapping, do_atomic_virial=do_atomic_virial, ) + model_ret = self._output_type_cast(model_ret, input_prec) if self.get_fitting_net() is not None: model_predict = {} model_predict["atom_energy"] = model_ret["energy"] diff --git a/deepmd/pt/model/task/sog_energy_fitting.py b/deepmd/pt/model/task/sog_energy_fitting.py index e8a9b3a55d..85bb1a2388 100644 --- a/deepmd/pt/model/task/sog_energy_fitting.py +++ b/deepmd/pt/model/task/sog_energy_fitting.py @@ -8,6 +8,7 @@ import numpy as np import torch +import pytorch_finufft from deepmd.dpmodel import ( FittingOutputDef, @@ -34,7 +35,7 @@ ) -SOG_DEFAULT_SHIFT = to_numpy_array(np.array([ +SOG_DEFAULT_AMPLITUDE = to_numpy_array(np.array([ 0.2750, 0.1375, 0.0688, @@ -48,7 +49,7 @@ 0.0003, 0.0001, ])) -SOG_DEFAULT_AMPLITUDE = to_numpy_array(np.array([ +SOG_DEFAULT_SHIFT = to_numpy_array(np.array([ 2.8, 5.7, 11.4, @@ -152,6 +153,7 @@ def __init__( default_fparam: Optional[list[float]] = None, shift: Optional[Union[list[float], torch.Tensor]] = None, amplitude: 
Optional[Union[list[float], torch.Tensor]] = None, + n_dl: int = 1, **kwargs: Any, ) -> None: super().__init__( @@ -190,18 +192,29 @@ def __init__( shift_tensor = to_torch_tensor(SOG_DEFAULT_SHIFT) if amplitude_tensor is None: amplitude_tensor = to_torch_tensor(SOG_DEFAULT_AMPLITUDE) - # Register as trainable parameters so they are optimized with the fitting net. - wl_tensor = to_torch_tensor(amplitude_tensor * (torch.sqrt(torch.tensor(torch.pi))**3)*shift_tensor**3) - sl_tensor = to_torch_tensor(-torch.log(2/shift_tensor)) - + + shift_tensor = shift_tensor.to(dtype=dtype, device=device) + amplitude_tensor = amplitude_tensor.to(dtype=dtype, device=device) + pi_tensor = torch.tensor(torch.pi, dtype=dtype, device=device) + sqr_pi_tensor = torch.sqrt(pi_tensor) + shift_safe = torch.clamp( + shift_tensor, + min=torch.finfo(shift_tensor.dtype).eps, + ) + wl_tensor = amplitude_tensor * (sqr_pi_tensor**3) * (shift_safe**3) + sl_tensor = -torch.log(2.0 / shift_safe) + + self.n_dl = max(1, int(n_dl)) self.wl = torch.nn.Parameter( - wl_tensor.to(dtype=dtype, device=device), + wl_tensor, requires_grad=bool(self.trainable), ) self.sl = torch.nn.Parameter( - sl_tensor.to(dtype=dtype, device=device), + sl_tensor, requires_grad=bool(self.trainable), ) + self.remove_self_interaction = False + self._nufft_fallback_warned = False def output_def(self) -> FittingOutputDef: @@ -220,8 +233,16 @@ def output_def(self) -> FittingOutputDef: def serialize(self) -> dict: data = super().serialize() data["type"] = "sog_energy" - data["@variables"]["shift"] = to_numpy_array(self.shift) - data["@variables"]["amplitude"] = to_numpy_array(self.amplitude) + data["@variables"]["wl"] = to_numpy_array(self.wl) + data["@variables"]["sl"] = to_numpy_array(self.sl) + + pi_tensor = torch.tensor(torch.pi, dtype=self.sl.dtype, device=self.sl.device) + sqr_pi_tensor = torch.sqrt(pi_tensor) + shift_tensor = 2.0 * torch.exp(self.sl) + amplitude_tensor = self.wl / ((sqr_pi_tensor**3) * (shift_tensor**3)) + 
data["@variables"]["shift"] = to_numpy_array(shift_tensor) + data["@variables"]["amplitude"] = to_numpy_array(amplitude_tensor) + data["n_dl"] = self.n_dl return data @classmethod @@ -229,23 +250,165 @@ def deserialize(cls, data: dict) -> "SOGEnergyFittingNet": data = data.copy() variables = data.get("@variables", {}) obj = super().deserialize(data) + wl_tensor = to_torch_tensor(variables.get("wl", None)) + sl_tensor = to_torch_tensor(variables.get("sl", None)) shift_tensor = to_torch_tensor(variables.get("shift", None)) amplitude_tensor = to_torch_tensor(variables.get("amplitude", None)) - # Backward compatibility: if serialized variables miss shift/amplitude, - # keep defaults initialized in __init__. - if shift_tensor is not None: - obj.shift = shift_tensor - if amplitude_tensor is not None: - obj.amplitude = amplitude_tensor + + with torch.no_grad(): + if wl_tensor is not None and sl_tensor is not None: + obj.wl.copy_(wl_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device)) + obj.sl.copy_(sl_tensor.to(dtype=obj.sl.dtype, device=obj.sl.device)) + elif shift_tensor is not None and amplitude_tensor is not None: + pi_tensor = torch.tensor(torch.pi, dtype=obj.wl.dtype, device=obj.wl.device) + sqr_pi_tensor = torch.sqrt(pi_tensor) + shift_tensor = shift_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device) + amplitude_tensor = amplitude_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device) + shift_safe = torch.clamp( + shift_tensor, + min=torch.finfo(shift_tensor.dtype).eps, + ) + wl_tensor = amplitude_tensor * (sqr_pi_tensor**3) * (shift_safe**3) + sl_tensor = -torch.log(2.0 / shift_safe) + obj.wl.copy_(wl_tensor) + obj.sl.copy_(sl_tensor) return obj - def corr_head( + def _kernel_params(self) -> tuple[torch.Tensor, torch.Tensor]: + return self.wl, self.sl + + def compute_potential_SOG_triclinic_NUFFT( + self, + r_raw: torch.Tensor, + q: torch.Tensor, + box: torch.Tensor, + ) -> torch.Tensor: + runtime_device = torch.device(device) + if q.dim() == 1: + q = 
q.unsqueeze(-1) + + if not torch.isfinite(r_raw).all(): + raise ValueError("`r_raw` contains non-finite values, cannot run NUFFT.") + if not torch.isfinite(box).all(): + raise ValueError("`box` contains non-finite values, cannot run NUFFT.") + + real_dtype = torch.float64 if r_raw.dtype == torch.float64 else torch.float32 + complex_dtype = torch.complex128 if real_dtype == torch.float64 else torch.complex64 + + r_raw = r_raw.to(dtype=real_dtype, device=runtime_device) + q = q.to(dtype=real_dtype, device=runtime_device) + box = box.to(dtype=real_dtype, device=runtime_device) + + volume = torch.det(box) + if torch.abs(volume) <= torch.finfo(real_dtype).eps: + raise ValueError("`box` is singular (near-zero volume), cannot run NUFFT.") + + cell_inv = torch.linalg.inv(box) + # Keep points strictly in the principal unit cell so NUFFT points are in [-pi, pi). + r_frac = torch.matmul(r_raw, cell_inv) + r_frac = torch.remainder(r_frac + 0.5, 1.0) - 0.5 + pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) + point_limit = pi_tensor - 32.0 * torch.finfo(real_dtype).eps + r_in = torch.clamp(2.0 * pi_tensor * r_frac, min=-point_limit, max=point_limit).contiguous() + nufft_points = r_in.transpose(0, 1).contiguous() + + norms = torch.norm(box, dim=1) + nk = [max(1, int(n.item() / self.n_dl)) for n in norms] + n1 = torch.arange(-nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype) + n2 = torch.arange(-nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype) + n3 = torch.arange(-nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype) + + kx_grid, ky_grid, kz_grid = torch.meshgrid(n1, n2, n3, indexing="ij") + k_sq = kx_grid**2 + ky_grid**2 + kz_grid**2 + mask = (k_sq==0) + + wl, sl = self._kernel_params() + min_term = -1.0 / torch.exp(-2.0 * sl) + kfac = wl.view(1, 1, 1, -1) * torch.exp(k_sq.unsqueeze(-1) * min_term) + # print("multiplier:",multiplier.shape,multiplier)c + kfac = kfac.sum(dim=-1) + kfac = kfac.to(dtype=real_dtype) + kfac[mask] = 0.0 
+ + atom_energy = torch.zeros((r_raw.shape[0], 1), dtype=real_dtype, device=runtime_device) + output_shape = tuple(int(x) for x in kx_grid.shape) + + def _eval_nufft(work_device: torch.device) -> torch.Tensor: + points_work = nufft_points.to(device=work_device) + q_work = q.to(device=work_device) + kfac_work = kfac.to(device=work_device) + volume_work = volume.to(device=work_device) + atom_energy_work = torch.zeros( + (r_raw.shape[0], 1), + dtype=real_dtype, + device=work_device, + ) + + q_work_t = q_work.transpose(0, 1).contiguous() + charge_work = torch.complex( + q_work_t, + torch.zeros_like(q_work_t), + ).to(dtype=complex_dtype) + recon = pytorch_finufft.functional.finufft_type1( + points_work, + charge_work, + output_shape=output_shape, + eps=1e-4, + isign=-1, + ) + con_sog = torch.mul(kfac_work.unsqueeze(0), recon) + ifft_con = pytorch_finufft.functional.finufft_type2( + points_work, + con_sog, + eps=1e-4, + isign=1, + ) / (2.0 * volume_work) + atom_energy_work += (charge_work * ifft_con).real.sum() + + if self.remove_self_interaction: + diag_sum = kfac_work.sum(dim=-1).sum(dim=-1).sum(dim=-1) / (2.0 * volume_work) + atom_energy_work -= torch.sum(q_work**2) * diag_sum + + return atom_energy_work.to(device=runtime_device) + + + atom_energy = _eval_nufft(runtime_device) + + return atom_energy + + def _corr_head( self, lr_val: torch.Tensor, + coord: Optional[torch.Tensor] = None, + box: Optional[torch.Tensor] = None, ) -> torch.Tensor: - # TODO: - # Long-range correction energy calculation - return torch.zeros_like(lr_val) + if lr_val.dim() != 3: + raise ValueError( + f"`lr_val` should have shape [nframe, nloc, nchan], got {tuple(lr_val.shape)}." 
+ ) + + if coord is None or box is None: + return torch.zeros( + (lr_val.shape[0], lr_val.shape[1], 1), + dtype=lr_val.dtype, + device=lr_val.device, + ) + + nf, nloc, _ = lr_val.shape + coord = coord.reshape(nf, nloc, 3) + box = box.reshape(nf, 3, 3) + + coord = coord.to(dtype=lr_val.dtype, device=lr_val.device) + box = box.to(dtype=lr_val.dtype, device=lr_val.device) + corr = torch.zeros((nf, nloc, 1), dtype=lr_val.dtype, device=lr_val.device) + for ff in range(nf): + box_now = box[ff] + corr[ff] = self.compute_potential_SOG_triclinic_NUFFT( + coord[ff], + lr_val[ff], + box_now, + ) + return corr def forward( self, @@ -256,6 +419,8 @@ def forward( h2: Optional[torch.Tensor] = None, fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, + coord: Optional[torch.Tensor] = None, + box: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: out = self._forward_common( descriptor=descriptor, @@ -267,7 +432,8 @@ def forward( aparam=aparam, ) short_energy = out["sr"] - corr_energy = self.corr_head(out["lr"]) + corr_energy = self._corr_head(out["lr"], coord=coord, box=box) + # corr_energy = 0 result = {"energy": short_energy + corr_energy} if "middle_output" in out: result["middle_output"] = out["middle_output"] diff --git a/examples/water/dpa3/input_torch_copy.json b/examples/water/dpa3/input_torch_copy.json index c2ed3891e3..3c951b6235 100644 --- a/examples/water/dpa3/input_torch_copy.json +++ b/examples/water/dpa3/input_torch_copy.json @@ -92,7 +92,7 @@ "batch_size": 1, "_comment": "that's all" }, - "numb_steps": 10000, + "numb_steps": 500000, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", diff --git a/examples/water/sog/input_torch.json b/examples/water/sog/input_torch.json index 9be7bf249b..570459cefe 100644 --- a/examples/water/sog/input_torch.json +++ b/examples/water/sog/input_torch.json @@ -41,7 +41,7 @@ "type": "sog_energy", "var_name": "energy", "dim_out_sr": 1, - "dim_out_lr": 2, + "dim_out_lr": 1, "neuron_sr": [ 
240, 240, @@ -124,7 +124,7 @@ ], "batch_size": 1 }, - "numb_steps": 10000, + "numb_steps": 500000, "gradient_max_norm": 5.0, "seed": 10, "disp_file": "lcurve.out", From d83257280499c3d39fba3fb6e5d178fe1c95b6d5 Mon Sep 17 00:00:00 2001 From: zhenyuan5090 Date: Mon, 23 Mar 2026 10:15:37 +0800 Subject: [PATCH 4/8] baseline-0.0.1 --- deepmd/pt/model/atomic_model/__init__.py | 4 + .../pt/model/atomic_model/les_atomic_model.py | 278 +++++++++++ deepmd/pt/model/model/__init__.py | 6 + deepmd/pt/model/model/les_model.py | 209 +++++++++ deepmd/pt/model/task/__init__.py | 4 + deepmd/pt/model/task/les_energy_fitting.py | 443 ++++++++++++++++++ deepmd/pt/model/task/lr_fitting.py | 8 +- deepmd/pt/model/task/sog_energy_fitting.py | 86 +++- deepmd/pt_expt/train/training.py | 4 + deepmd/utils/argcheck.py | 171 +++++++ examples/water/sog/input_torch.json | 4 +- pyproject.toml | 2 + 12 files changed, 1187 insertions(+), 32 deletions(-) create mode 100644 deepmd/pt/model/atomic_model/les_atomic_model.py create mode 100644 deepmd/pt/model/model/les_model.py create mode 100644 deepmd/pt/model/task/les_energy_fitting.py diff --git a/deepmd/pt/model/atomic_model/__init__.py b/deepmd/pt/model/atomic_model/__init__.py index c008d75575..c432ae22e4 100644 --- a/deepmd/pt/model/atomic_model/__init__.py +++ b/deepmd/pt/model/atomic_model/__init__.py @@ -48,6 +48,9 @@ from .sog_atomic_model import ( SOGEnergyAtomicModel, ) +from .les_atomic_model import ( + LESEnergyAtomicModel, +) __all__ = [ "BaseAtomicModel", @@ -59,6 +62,7 @@ "DPPropertyAtomicModel", "DPZBLLinearEnergyAtomicModel", "LinearEnergyAtomicModel", + "LESEnergyAtomicModel", "LREnergyAtomicModel", "PairTabAtomicModel", "SOGEnergyAtomicModel", diff --git a/deepmd/pt/model/atomic_model/les_atomic_model.py b/deepmd/pt/model/atomic_model/les_atomic_model.py new file mode 100644 index 0000000000..9036203f01 --- /dev/null +++ b/deepmd/pt/model/atomic_model/les_atomic_model.py @@ -0,0 +1,278 @@ +# SPDX-License-Identifier: 
LGPL-3.0-or-later +from typing import Any, Optional + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt.model.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt.model.task.les_energy_fitting import ( + LESEnergyFittingNet, +) +from deepmd.utils.version import ( + check_version_compatibility, +) + +from .base_atomic_model import ( + BaseAtomicModel, +) + + +@BaseAtomicModel.register("energy_les") +class LESEnergyAtomicModel(BaseAtomicModel): + """Energy model using a dedicated LES energy fitting net. + + The LES energy fitting net combines a short-range invariant fitting + and a long-range correction derived from another invariant fitting. + This avoids requiring a user-defined property name in the dataset. + """ + + def __init__( + self, + descriptor: Any, + type_map: list[str], + les_energy_fitting: Optional[LESEnergyFittingNet] = None, + fitting: Optional[Any] = None, + **kwargs: Any, + ) -> None: + super().__init__(type_map, **kwargs) + if les_energy_fitting is None: + les_energy_fitting = fitting + if not isinstance(les_energy_fitting, LESEnergyFittingNet): + raise TypeError( + "les_energy_fitting must be an instance of LESEnergyFittingNet" + ) + + self.descriptor = descriptor + self.fitting_net = les_energy_fitting + # self.les_energy_fitting = self.fitting_net + self.type_map = type_map + self.ntypes = len(type_map) + self.rcut = self.descriptor.get_rcut() + self.sel = self.descriptor.get_sel() + + super().init_out_stat() + + self.enable_eval_descriptor_hook = False + self.enable_eval_fitting_last_layer_hook = False + self.eval_descriptor_list: list[torch.Tensor] = [] + self.eval_fitting_last_layer_list: list[torch.Tensor] = [] + + @torch.jit.export + def fitting_output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + def get_rcut(self) 
-> float: + return self.rcut + + def get_sel(self) -> list[int]: + return self.sel + + def mixed_types(self) -> bool: + return self.descriptor.mixed_types() + + def has_message_passing(self) -> bool: + return self.descriptor.has_message_passing() + + def need_sorted_nlist_for_lower(self) -> bool: + return self.descriptor.need_sorted_nlist_for_lower() + + def set_case_embd(self, case_idx: int) -> None: + self.fitting_net.set_case_embd(case_idx) + + def get_dim_fparam(self) -> int: + return self.fitting_net.get_dim_fparam() + + def has_default_fparam(self) -> bool: + return self.fitting_net.has_default_fparam() + + def get_default_fparam(self) -> Optional[torch.Tensor]: + return self.fitting_net.get_default_fparam() + + def get_dim_aparam(self) -> int: + return self.fitting_net.get_dim_aparam() + + def get_sel_type(self) -> list[int]: + return self.fitting_net.get_sel_type() + + def is_aparam_nall(self) -> bool: + return False + + def set_eval_descriptor_hook(self, enable: bool) -> None: + self.enable_eval_descriptor_hook = enable + self.eval_descriptor_list.clear() + + def eval_descriptor(self) -> torch.Tensor: + return torch.concat(self.eval_descriptor_list) + + def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: + self.enable_eval_fitting_last_layer_hook = enable + self.fitting_net.set_return_middle_output(enable) + self.eval_fitting_last_layer_list.clear() + + def eval_fitting_last_layer(self) -> torch.Tensor: + return torch.concat(self.eval_fitting_last_layer_list) + + def forward_atomic( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + ) -> dict[str, torch.Tensor]: + nframes, nloc, _ = nlist.shape + atype = extended_atype[:, :nloc] + if self.do_grad_r() or self.do_grad_c(): + extended_coord.requires_grad_(True) + + 
descriptor_comm_dict = comm_dict + if comm_dict is not None and "send_list" not in comm_dict: + descriptor_comm_dict = None + + descriptor, rot_mat, g2, h2, sw = self.descriptor( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + comm_dict=descriptor_comm_dict, + ) + assert descriptor is not None + if self.enable_eval_descriptor_hook: + self.eval_descriptor_list.append(descriptor.detach()) + + energy_ret = self.fitting_net( + descriptor, + atype, + gr=rot_mat, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + coord=extended_coord[:, :nloc, :], + box=( + comm_dict["box"].view(nframes, 3, 3) + if comm_dict is not None and "box" in comm_dict + else None + ), + ) + + if self.enable_eval_fitting_last_layer_hook and "middle_output" in energy_ret: + self.eval_fitting_last_layer_list.append( + energy_ret["middle_output"].detach() + ) + + ret = { + "energy": energy_ret["energy"], + } + if "middle_output" in energy_ret: + ret["middle_output"] = energy_ret["middle_output"] + return ret + + def apply_out_stat( + self, + ret: dict[str, torch.Tensor], + atype: torch.Tensor, + ) -> dict[str, torch.Tensor]: + out_bias, out_std = self._fetch_out_stat(self.bias_keys) + for kk in self.bias_keys: + ret[kk] = ret[kk] + out_bias[kk][atype] + return ret + + def compute_or_load_stat( + self, + sampled_func: Any, + stat_file_path: Optional[Any] = None, + compute_or_load_out_stat: bool = True, + preset_observed_type: Optional[list[str]] = None, + ) -> None: + if stat_file_path is not None and self.type_map is not None: + stat_file_path /= " ".join(self.type_map) + + def wrapped_sampler() -> list[dict]: + sampled = sampled_func() + if self.pair_excl is not None: + pair_exclude_types = self.pair_excl.get_exclude_types() + for sample in sampled: + sample["pair_exclude_types"] = list(pair_exclude_types) + if self.atom_excl is not None: + atom_exclude_types = self.atom_excl.get_exclude_types() + for sample in sampled: + sample["atom_exclude_types"] = list(atom_exclude_types) + 
if ( + "find_fparam" not in sampled[0] + and "fparam" not in sampled[0] + and self.has_default_fparam() + ): + default_fparam = self.get_default_fparam() + for sample in sampled: + nframe = sample["atype"].shape[0] + sample["fparam"] = default_fparam.repeat(nframe, 1) + return sampled + + self.descriptor.compute_input_stats(wrapped_sampler, stat_file_path) + self.compute_fitting_input_stat(wrapped_sampler, stat_file_path) + if compute_or_load_out_stat: + self.compute_or_load_out_stat(wrapped_sampler, stat_file_path) + + self._collect_and_set_observed_type( + wrapped_sampler, stat_file_path, preset_observed_type + ) + + def compute_fitting_input_stat( + self, + sample_merged: Any, + stat_file_path: Optional[Any] = None, + ) -> None: + self.fitting_net.compute_input_stats( + sample_merged, + protection=self.data_stat_protect, + stat_file_path=stat_file_path, + ) + + def serialize(self) -> dict: + dd = BaseAtomicModel.serialize(self) + dd.update( + { + "@class": "Model", + "@version": 1, + "type": "energy_les", + "type_map": self.type_map, + "descriptor": self.descriptor.serialize(), + "les_energy_fitting": self.fitting_net.serialize(), + } + ) + return dd + + @classmethod + def deserialize(cls, data: dict) -> "LESEnergyAtomicModel": + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + data.pop("type", None) + descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) + les_energy_fitting_obj = LESEnergyFittingNet.deserialize( + data.pop("les_energy_fitting") + ) + obj = cls( + descriptor=descriptor_obj, + les_energy_fitting=les_energy_fitting_obj, + **data, + ) + return obj diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index aea6b1be12..fedebb6f22 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -76,6 +76,9 @@ from .sog_model import ( SOGEnergyModel, ) +from .les_model import ( + LESEnergyModel, +) def 
_get_standard_model_components(model_params: dict, ntypes: int) -> tuple: @@ -273,6 +276,8 @@ def get_standard_model(model_params: dict) -> BaseModel: modelcls = PropertyModel elif fitting_net_type == "sog_energy": modelcls = SOGEnergyModel + elif fitting_net_type == "les_energy": + modelcls = LESEnergyModel else: raise RuntimeError(f"Unknown fitting type: {fitting_net_type}") @@ -315,6 +320,7 @@ def get_model(model_params: dict) -> Any: "EnergyModel", "FrozenModel", "LinearEnergyModel", + "LESEnergyModel", "PolarModel", "SOGEnergyModel", "SpinEnergyModel", diff --git a/deepmd/pt/model/model/les_model.py b/deepmd/pt/model/model/les_model.py new file mode 100644 index 0000000000..da962d7000 --- /dev/null +++ b/deepmd/pt/model/model/les_model.py @@ -0,0 +1,209 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + Optional, +) + +import torch + +from deepmd.pt.model.model.transform_output import ( + communicate_extended_output, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + +from deepmd.pt.model.atomic_model import ( + LESEnergyAtomicModel, +) +from deepmd.pt.model.model.model import ( + BaseModel, +) + +from .dp_model import ( + DPModelCommon, +) +from .make_hessian_model import ( + make_hessian_model, +) +from .make_model import ( + make_model, +) + +LESEnergyModel_ = make_model(LESEnergyAtomicModel) + + +@BaseModel.register("les_ener") +class LESEnergyModel(DPModelCommon, LESEnergyModel_): + model_type = "les_ener" + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + DPModelCommon.__init__(self) + LESEnergyModel_.__init__(self, *args, **kwargs) + self._hessian_enabled = False + + def enable_hessian(self) -> None: + self.__class__ = make_hessian_model(type(self)) + self.hess_fitting_def = super(type(self), self).atomic_output_def() + self.requires_hessian("energy") + self._hessian_enabled = True + + @torch.jit.export + def get_observed_type_list(self) -> list[str]: + """Get observed 
types (elements) of the model during data statistics. + + Returns + ------- + observed_type_list: a list of the observed types in this model. + """ + type_map = self.get_type_map() + out_bias = self.atomic_model.get_out_bias()[0] + + assert out_bias is not None, "No out_bias found in the model." + assert out_bias.dim() == 2, "The supported out_bias should be a 2D tensor." + assert out_bias.size(0) == len(type_map), ( + "The out_bias shape does not match the type_map length." + ) + bias_mask = ( + torch.gt(torch.abs(out_bias), 1e-6).any(dim=-1).detach().cpu() + ) # 1e-6 for stability + + observed_type_list: list[str] = [] + for i in range(len(type_map)): + if bias_mask[i]: + observed_type_list.append(type_map[i]) + return observed_type_list + + def translated_output_def(self) -> dict[str, Any]: + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_energy": out_def_data["energy"], + "energy": out_def_data["energy_redu"], + } + if self.do_grad_r("energy"): + output_def["force"] = out_def_data["energy_derv_r"] + output_def["force"].squeeze(-2) + if self.do_grad_c("energy"): + output_def["virial"] = out_def_data["energy_derv_c_redu"] + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = out_def_data["energy_derv_c"] + output_def["atom_virial"].squeeze(-3) + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + if self._hessian_enabled: + output_def["hessian"] = out_def_data["energy_derv_r_derv_r"] + return output_def + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + cc, bb, fp, ap, input_prec = self._input_type_cast( + coord, box=box, fparam=fparam, aparam=aparam + ) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + cc, + atype, + self.get_rcut(), + 
self.get_sel(), + mixed_types=True, + box=bb, + ) + comm_dict: dict[str, torch.Tensor] | None = None + if bb is not None: + comm_dict = {"box": bb} + model_predict_lower = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fp, + aparam=ap, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + ) + model_ret = communicate_extended_output( + model_predict_lower, + self.model_output_def(), + mapping, + do_atomic_virial=do_atomic_virial, + ) + model_ret = self._output_type_cast(model_ret, input_prec) + if self.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze( + -3 + ) + else: + model_predict["force"] = model_ret["dforce"] + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + if self._hessian_enabled: + model_predict["hessian"] = model_ret["energy_derv_r_derv_r"].squeeze(-2) + else: + model_predict = model_ret + model_predict["updated_coord"] += coord + return model_predict + + @torch.jit.export + def forward_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + ) -> dict[str, torch.Tensor]: + model_ret = self.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + extra_nlist_sort=self.need_sorted_nlist_for_lower(), + ) + if 
self.get_fitting_net() is not None: + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy"): + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret[ + "energy_derv_c" + ].squeeze(-3) + else: + assert model_ret["dforce"] is not None + model_predict["dforce"] = model_ret["dforce"] + else: + model_predict = model_ret + return model_predict diff --git a/deepmd/pt/model/task/__init__.py b/deepmd/pt/model/task/__init__.py index 65d7983fd6..c3b7025358 100644 --- a/deepmd/pt/model/task/__init__.py +++ b/deepmd/pt/model/task/__init__.py @@ -33,6 +33,9 @@ from .sog_energy_fitting import ( SOGEnergyFittingNet, ) +from .les_energy_fitting import ( + LESEnergyFittingNet, +) __all__ = [ "BaseFitting", @@ -43,6 +46,7 @@ "EnergyFittingNetDirect", "Fitting", "LRFittingNet", + "LESEnergyFittingNet", "PolarFittingNet", "PropertyFittingNet", "SOGEnergyFittingNet", diff --git a/deepmd/pt/model/task/les_energy_fitting.py b/deepmd/pt/model/task/les_energy_fitting.py new file mode 100644 index 0000000000..c46906de31 --- /dev/null +++ b/deepmd/pt/model/task/les_energy_fitting.py @@ -0,0 +1,443 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import logging +from typing import ( + Any, + Optional, + Union, +) + +import numpy as np +import torch +import pytorch_finufft + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, + fitting_check_output, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.env import ( + DEFAULT_PRECISION, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, + to_torch_tensor, +) + +dtype = env.GLOBAL_PT_FLOAT_PRECISION +device = env.DEVICE + +log = logging.getLogger(__name__) +from deepmd.pt.model.task.lr_fitting import ( + LRFittingNet, +) 
+ + +LES_DEFAULT_AMPLITUDE = to_numpy_array(np.array([ + 0.2750, + 0.1375, + 0.0688, + 0.0344, + 0.0172, + 0.0086, + 0.0043, + 0.0021, + 0.0011, + 0.0005, + 0.0003, + 0.0001, +])) +LES_DEFAULT_SHIFT = to_numpy_array(np.array([ + 2.8, + 5.7, + 11.4, + 22.7, + 45.5, + 91.0, + 182.0, + 364.0, + 728.0, + 1456.0, + 2912.0, + 5823.9, +])) + + +@LRFittingNet.register("les_energy") +@fitting_check_output +class LESEnergyFittingNet(LRFittingNet): + """Construct a LES sr+lr interactions fitting net. + + Parameters + ---------- + var_name : str + The atomic property to fit. + ntypes : int + Element count. + dim_descrpt : int + Embedding width per atom. + dim_out_sr : int + The output dimension of the sr fitting net. + dim_out_lr : int + The output dimension of the lr fitting net. + neuron_sr : list[int] + Number of neurons in each hidden layers of the sr fitting net. + neuron_lr : list[int] + Number of neurons in each hidden layers of the lr fitting net. + bias_atom_e : torch.Tensor, optional + Average energy per atom for each element. + resnet_dt : bool + Using time-step in the ResNet construction. + numb_fparam : int + Number of frame parameters. + numb_aparam : int + Number of atomic parameters. + dim_case_embd : int + Dimension of case specific embedding. + activation_function : str + Activation function. + precision : str + Numerical precision. + mixed_types : bool + If true, use a uniform fitting net for all atom types, otherwise use + different fitting nets for different atom types. + rcond : float, optional + The condition number for the regression of atomic energy. + seed : int, optional + Random seed. + exclude_types: list[int] + Atomic contributions of the excluded atom types are set zero. + trainable : Union[list[bool], bool] + If the parameters in the fitting net are trainable. + Now this only supports setting all the parameters in the fitting net at one state. + When in list[bool], the trainable will be True only if all the boolean parameters are True. 
+ remove_vaccum_contribution: list[bool], optional + Remove vacuum contribution before the bias is added. The list assigned each + type. For `mixed_types` provide `[True]`, otherwise it should be a list of the same + length as `ntypes` signaling if or not removing the vacuum contribution for the atom types in the list. + type_map: list[str], Optional + A list of strings. Give the name to each type of atoms. + use_aparam_as_mask: bool + If True, the aparam will not be used in fitting net for embedding. + default_fparam: list[float], optional + The default frame parameter. If set, when `fparam.npy` files are not included in the data system, + this value will be used as the default value for the frame parameter in the fitting net. + """ + + def __init__( + self, + var_name: str, + ntypes: int, + dim_descrpt: int, + dim_out_sr: int, + dim_out_lr: int, + neuron_sr: list[int] = [128, 128, 128], + neuron_lr: list[int] = [128, 128, 128], + bias_atom_e: Optional[torch.Tensor] = None, + resnet_dt: bool = True, + numb_fparam: int = 0, + numb_aparam: int = 0, + dim_case_embd: int = 0, + activation_function: str = "tanh", + precision: str = DEFAULT_PRECISION, + mixed_types: bool = True, + rcond: Optional[float] = None, + seed: Optional[Union[int, list[int]]] = None, + exclude_types: list[int] = [], + trainable: Union[bool, list[bool]] = True, + remove_vaccum_contribution: Optional[list[bool]] = None, + type_map: Optional[list[str]] = None, + use_aparam_as_mask: bool = False, + default_fparam: Optional[list[float]] = None, + shift: Optional[Union[list[float], torch.Tensor]] = None, + amplitude: Optional[Union[list[float], torch.Tensor]] = None, + n_dl: int = 1, + **kwargs: Any, + ) -> None: + super().__init__( + var_name=var_name, + ntypes=ntypes, + dim_descrpt=dim_descrpt, + dim_out_sr=dim_out_sr, + dim_out_lr=dim_out_lr, + neuron_sr=neuron_sr, + neuron_lr=neuron_lr, + bias_atom_e=bias_atom_e, + resnet_dt=resnet_dt, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + 
dim_case_embd=dim_case_embd, + activation_function=activation_function, + precision=precision, + mixed_types=mixed_types, + rcond=rcond, + seed=seed, + exclude_types=exclude_types, + trainable=trainable, + remove_vaccum_contribution=remove_vaccum_contribution, + type_map=type_map, + use_aparam_as_mask=use_aparam_as_mask, + default_fparam=default_fparam, + **kwargs, + ) + if isinstance(shift, (list, tuple)): + shift = to_numpy_array(np.array(shift)) + if isinstance(amplitude, (list, tuple)): + amplitude = to_numpy_array(np.array(amplitude)) + shift_tensor = to_torch_tensor(shift) + amplitude_tensor = to_torch_tensor(amplitude) + if shift_tensor is None: + shift_tensor = to_torch_tensor(LES_DEFAULT_SHIFT) + if amplitude_tensor is None: + amplitude_tensor = to_torch_tensor(LES_DEFAULT_AMPLITUDE) + + shift_tensor = shift_tensor.to(dtype=dtype, device=device) + amplitude_tensor = amplitude_tensor.to(dtype=dtype, device=device) + pi_tensor = torch.tensor(torch.pi, dtype=dtype, device=device) + sqr_pi_tensor = torch.sqrt(pi_tensor) + shift_safe = torch.clamp( + shift_tensor, + min=torch.finfo(shift_tensor.dtype).eps, + ) + wl_tensor = amplitude_tensor * (sqr_pi_tensor**3) * (shift_safe**3) + sl_tensor = -torch.log(2.0 / shift_safe) + + self.n_dl = max(1, int(n_dl)) + self.wl = torch.nn.Parameter( + wl_tensor, + requires_grad=bool(self.trainable), + ) + self.sl = torch.nn.Parameter( + sl_tensor, + requires_grad=bool(self.trainable), + ) + self.remove_self_interaction = False + self._nufft_fallback_warned = False + + + def output_def(self) -> FittingOutputDef: + return FittingOutputDef( + [ + OutputVariableDef( + name="energy", + shape=[1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ) + ] + ) + + def serialize(self) -> dict: + data = super().serialize() + data["type"] = "les_energy" + data["@variables"]["wl"] = to_numpy_array(self.wl) + data["@variables"]["sl"] = to_numpy_array(self.sl) + + pi_tensor = torch.tensor(torch.pi, dtype=self.sl.dtype, 
device=self.sl.device) + sqr_pi_tensor = torch.sqrt(pi_tensor) + shift_tensor = 2.0 * torch.exp(self.sl) + amplitude_tensor = self.wl / ((sqr_pi_tensor**3) * (shift_tensor**3)) + data["@variables"]["shift"] = to_numpy_array(shift_tensor) + data["@variables"]["amplitude"] = to_numpy_array(amplitude_tensor) + data["n_dl"] = self.n_dl + return data + + @classmethod + def deserialize(cls, data: dict) -> "LESEnergyFittingNet": + data = data.copy() + variables = data.get("@variables", {}) + obj = super().deserialize(data) + wl_tensor = to_torch_tensor(variables.get("wl", None)) + sl_tensor = to_torch_tensor(variables.get("sl", None)) + shift_tensor = to_torch_tensor(variables.get("shift", None)) + amplitude_tensor = to_torch_tensor(variables.get("amplitude", None)) + + with torch.no_grad(): + if wl_tensor is not None and sl_tensor is not None: + obj.wl.copy_(wl_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device)) + obj.sl.copy_(sl_tensor.to(dtype=obj.sl.dtype, device=obj.sl.device)) + elif shift_tensor is not None and amplitude_tensor is not None: + pi_tensor = torch.tensor(torch.pi, dtype=obj.wl.dtype, device=obj.wl.device) + sqr_pi_tensor = torch.sqrt(pi_tensor) + shift_tensor = shift_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device) + amplitude_tensor = amplitude_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device) + shift_safe = torch.clamp( + shift_tensor, + min=torch.finfo(shift_tensor.dtype).eps, + ) + wl_tensor = amplitude_tensor * (sqr_pi_tensor**3) * (shift_safe**3) + sl_tensor = -torch.log(2.0 / shift_safe) + obj.wl.copy_(wl_tensor) + obj.sl.copy_(sl_tensor) + return obj + + def _kernel_params(self) -> tuple[torch.Tensor, torch.Tensor]: + return self.wl, self.sl + + def compute_potential_LES_triclinic_NUFFT( + self, + r_raw: torch.Tensor, + q: torch.Tensor, + box: torch.Tensor, + ) -> torch.Tensor: + runtime_device = torch.device(device) + if q.dim() == 1: + q = q.unsqueeze(-1) + + if not torch.isfinite(r_raw).all(): + raise ValueError("`r_raw` contains 
non-finite values, cannot run NUFFT.") + if not torch.isfinite(box).all(): + raise ValueError("`box` contains non-finite values, cannot run NUFFT.") + + real_dtype = torch.float64 if r_raw.dtype == torch.float64 else torch.float32 + complex_dtype = torch.complex128 if real_dtype == torch.float64 else torch.complex64 + + r_raw = r_raw.to(dtype=real_dtype, device=runtime_device) + q = q.to(dtype=real_dtype, device=runtime_device) + box = box.to(dtype=real_dtype, device=runtime_device) + + volume = torch.det(box) + if torch.abs(volume) <= torch.finfo(real_dtype).eps: + raise ValueError("`box` is singular (near-zero volume), cannot run NUFFT.") + + cell_inv = torch.linalg.inv(box) + # Keep points strictly in the principal unit cell so NUFFT points are in [-pi, pi). + r_frac = torch.matmul(r_raw, cell_inv) + r_frac = torch.remainder(r_frac + 0.5, 1.0) - 0.5 + pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) + point_limit = pi_tensor - 32.0 * torch.finfo(real_dtype).eps + r_in = torch.clamp(2.0 * pi_tensor * r_frac, min=-point_limit, max=point_limit).contiguous() + nufft_points = r_in.transpose(0, 1).contiguous() + + norms = torch.norm(box, dim=1) + nk = [max(1, int(n.item() / self.n_dl)) for n in norms] + n1 = torch.arange(-nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype) + n2 = torch.arange(-nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype) + n3 = torch.arange(-nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype) + + kx_grid, ky_grid, kz_grid = torch.meshgrid(n1, n2, n3, indexing="ij") + k_sq = kx_grid**2 + ky_grid**2 + kz_grid**2 + mask = (k_sq==0) + + wl, sl = self._kernel_params() + min_term = -1.0 / torch.exp(-2.0 * sl) + kfac = wl.view(1, 1, 1, -1) * torch.exp(k_sq.unsqueeze(-1) * min_term) + # print("multiplier:",multiplier.shape,multiplier)c + kfac = kfac.sum(dim=-1) + kfac = kfac.to(dtype=real_dtype) + kfac[mask] = 0.0 + + atom_energy = torch.zeros((r_raw.shape[0], 1), dtype=real_dtype, device=runtime_device) 
+ output_shape = tuple(int(x) for x in kx_grid.shape) + + def _eval_nufft(work_device: torch.device) -> torch.Tensor: + points_work = nufft_points.to(device=work_device) + q_work = q.to(device=work_device) + kfac_work = kfac.to(device=work_device) + volume_work = volume.to(device=work_device) + atom_energy_work = torch.zeros( + (r_raw.shape[0], 1), + dtype=real_dtype, + device=work_device, + ) + + q_work_t = q_work.transpose(0, 1).contiguous() + charge_work = torch.complex( + q_work_t, + torch.zeros_like(q_work_t), + ).to(dtype=complex_dtype) + recon = pytorch_finufft.functional.finufft_type1( + points_work, + charge_work, + output_shape=output_shape, + eps=1e-4, + isign=-1, + ) + con_les = torch.mul(kfac_work.unsqueeze(0), recon) + ifft_con = pytorch_finufft.functional.finufft_type2( + points_work, + con_les, + eps=1e-4, + isign=1, + ) / (2.0 * volume_work) + atom_energy_work += (charge_work * ifft_con).real.sum() + + if self.remove_self_interaction: + diag_sum = kfac_work.sum(dim=-1).sum(dim=-1).sum(dim=-1) / (2.0 * volume_work) + atom_energy_work -= torch.sum(q_work**2) * diag_sum + + return atom_energy_work.to(device=runtime_device) + + + atom_energy = _eval_nufft(runtime_device) + + return atom_energy + + def _corr_head( + self, + lr_val: torch.Tensor, + coord: Optional[torch.Tensor] = None, + box: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if lr_val.dim() != 3: + raise ValueError( + f"`lr_val` should have shape [nframe, nloc, nchan], got {tuple(lr_val.shape)}." 
+ ) + + if coord is None or box is None: + return torch.zeros( + (lr_val.shape[0], lr_val.shape[1], 1), + dtype=lr_val.dtype, + device=lr_val.device, + ) + + nf, nloc, _ = lr_val.shape + coord = coord.reshape(nf, nloc, 3) + box = box.reshape(nf, 3, 3) + + coord = coord.to(dtype=lr_val.dtype, device=lr_val.device) + box = box.to(dtype=lr_val.dtype, device=lr_val.device) + corr = torch.zeros((nf, nloc, 1), dtype=lr_val.dtype, device=lr_val.device) + for ff in range(nf): + box_now = box[ff] + corr[ff] = self.compute_potential_LES_triclinic_NUFFT( + coord[ff], + lr_val[ff], + box_now, + ) + return corr + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: Optional[torch.Tensor] = None, + g2: Optional[torch.Tensor] = None, + h2: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + coord: Optional[torch.Tensor] = None, + box: Optional[torch.Tensor] = None, + ) -> dict[str, torch.Tensor]: + out = self._forward_common( + descriptor=descriptor, + atype=atype, + gr=gr, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + short_energy = out["sr"] + corr_energy = self._corr_head(out["lr"], coord=coord, box=box) + # corr_energy = 0 + result = {"energy": short_energy + corr_energy} + if "middle_output" in out: + result["middle_output"] = out["middle_output"] + return result + + # make jit happy with torch 2.0.0 + exclude_types: list[int] \ No newline at end of file diff --git a/deepmd/pt/model/task/lr_fitting.py b/deepmd/pt/model/task/lr_fitting.py index aba8a5c238..82702a38de 100644 --- a/deepmd/pt/model/task/lr_fitting.py +++ b/deepmd/pt/model/task/lr_fitting.py @@ -235,7 +235,7 @@ def __init__( networks=[ FittingNet( in_dim, - 1, + self.lr_net_dim_out, self.neuron_lr, self.activation_function, self.resnet_dt, @@ -254,7 +254,7 @@ def __init__( networks=[ FittingNet( in_dim, - 1, + self.sr_net_dim_out, self.neuron_sr, self.activation_function, self.resnet_dt, @@ -572,7 +572,7 @@ def 
_forward_common( def _apply_networks( self, layers: NetworkCollection, - neuron: int, + neuron: list[int], dim_out: int, xx: torch.Tensor, xx_zeros: Optional[torch.Tensor], @@ -581,7 +581,7 @@ def _apply_networks( bool_bias: bool = False, ) -> torch.Tensor: nf, nloc, _ = xx.shape - outs = torch.zeros((nf, nloc, 1), dtype=self.prec, device=xx.device) + outs = torch.zeros((nf, nloc, dim_out), dtype=self.prec, device=xx.device) if self.mixed_types: atom_property = layers.networks[0](xx) if self.eval_return_middle_output and middle_output is not None: diff --git a/deepmd/pt/model/task/sog_energy_fitting.py b/deepmd/pt/model/task/sog_energy_fitting.py index 85bb1a2388..78e6c81503 100644 --- a/deepmd/pt/model/task/sog_energy_fitting.py +++ b/deepmd/pt/model/task/sog_energy_fitting.py @@ -216,7 +216,6 @@ def __init__( self.remove_self_interaction = False self._nufft_fallback_warned = False - def output_def(self) -> FittingOutputDef: return FittingOutputDef( [ @@ -229,47 +228,78 @@ def output_def(self) -> FittingOutputDef: ) ] ) - + + @staticmethod + def _wl_sl_to_shift_amplitude( + wl_tensor: torch.Tensor, + sl_tensor: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + pi_tensor = torch.tensor( + torch.pi, + dtype=sl_tensor.dtype, + device=sl_tensor.device, + ) + sqr_pi_tensor = torch.sqrt(pi_tensor) + shift_tensor = 2.0 * torch.exp(sl_tensor) + amplitude_tensor = wl_tensor / ((sqr_pi_tensor**3) * (shift_tensor**3)) + return shift_tensor, amplitude_tensor + + @staticmethod + def _shift_amplitude_to_wl_sl( + shift_tensor: torch.Tensor, + amplitude_tensor: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + pi_tensor = torch.tensor( + torch.pi, + dtype=shift_tensor.dtype, + device=shift_tensor.device, + ) + sqr_pi_tensor = torch.sqrt(pi_tensor) + shift_safe = torch.clamp( + shift_tensor, + min=torch.finfo(shift_tensor.dtype).eps, + ) + wl_tensor = amplitude_tensor * (sqr_pi_tensor**3) * (shift_safe**3) + sl_tensor = -torch.log(2.0 / shift_safe) + return 
wl_tensor, sl_tensor + def serialize(self) -> dict: data = super().serialize() data["type"] = "sog_energy" - data["@variables"]["wl"] = to_numpy_array(self.wl) - data["@variables"]["sl"] = to_numpy_array(self.sl) - - pi_tensor = torch.tensor(torch.pi, dtype=self.sl.dtype, device=self.sl.device) - sqr_pi_tensor = torch.sqrt(pi_tensor) - shift_tensor = 2.0 * torch.exp(self.sl) - amplitude_tensor = self.wl / ((sqr_pi_tensor**3) * (shift_tensor**3)) - data["@variables"]["shift"] = to_numpy_array(shift_tensor) - data["@variables"]["amplitude"] = to_numpy_array(amplitude_tensor) + variables = data["@variables"] + variables["wl"] = to_numpy_array(self.wl) + variables["sl"] = to_numpy_array(self.sl) + shift_tensor, amplitude_tensor = self._wl_sl_to_shift_amplitude( + self.wl, + self.sl, + ) + variables["shift"] = to_numpy_array(shift_tensor) + variables["amplitude"] = to_numpy_array(amplitude_tensor) data["n_dl"] = self.n_dl return data @classmethod def deserialize(cls, data: dict) -> "SOGEnergyFittingNet": data = data.copy() - variables = data.get("@variables", {}) + variables = data.get("@variables", {}).copy() + + wl_tensor = to_torch_tensor(variables.pop("wl", None)) + sl_tensor = to_torch_tensor(variables.pop("sl", None)) + shift_tensor = to_torch_tensor(variables.pop("shift", None)) + amplitude_tensor = to_torch_tensor(variables.pop("amplitude", None)) + data["@variables"] = variables + obj = super().deserialize(data) - wl_tensor = to_torch_tensor(variables.get("wl", None)) - sl_tensor = to_torch_tensor(variables.get("sl", None)) - shift_tensor = to_torch_tensor(variables.get("shift", None)) - amplitude_tensor = to_torch_tensor(variables.get("amplitude", None)) with torch.no_grad(): if wl_tensor is not None and sl_tensor is not None: obj.wl.copy_(wl_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device)) obj.sl.copy_(sl_tensor.to(dtype=obj.sl.dtype, device=obj.sl.device)) elif shift_tensor is not None and amplitude_tensor is not None: - pi_tensor = 
torch.tensor(torch.pi, dtype=obj.wl.dtype, device=obj.wl.device) - sqr_pi_tensor = torch.sqrt(pi_tensor) - shift_tensor = shift_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device) - amplitude_tensor = amplitude_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device) - shift_safe = torch.clamp( - shift_tensor, - min=torch.finfo(shift_tensor.dtype).eps, + wl_tensor, sl_tensor = cls._shift_amplitude_to_wl_sl( + shift_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device), + amplitude_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device), ) - wl_tensor = amplitude_tensor * (sqr_pi_tensor**3) * (shift_safe**3) - sl_tensor = -torch.log(2.0 / shift_safe) obj.wl.copy_(wl_tensor) obj.sl.copy_(sl_tensor) return obj @@ -386,6 +416,11 @@ def _corr_head( raise ValueError( f"`lr_val` should have shape [nframe, nloc, nchan], got {tuple(lr_val.shape)}." ) + if lr_val.shape[-1] != self.dim_out_lr: + raise ValueError( + f"`lr_val` channel mismatch: got {lr_val.shape[-1]}, " + f"expected dim_out_lr={self.dim_out_lr}." 
+ ) if coord is None or box is None: return torch.zeros( @@ -433,7 +468,6 @@ def forward( ) short_energy = out["sr"] corr_energy = self._corr_head(out["lr"], coord=coord, box=box) - # corr_energy = 0 result = {"energy": short_energy + corr_energy} if "middle_output" in out: result["middle_output"] = out["middle_output"] diff --git a/deepmd/pt_expt/train/training.py b/deepmd/pt_expt/train/training.py index f8730ed271..6504cad363 100644 --- a/deepmd/pt_expt/train/training.py +++ b/deepmd/pt_expt/train/training.py @@ -398,6 +398,10 @@ def __init__( # Model --------------------------------------------------------------- self.model = get_model(deepcopy(model_params)).to(DEVICE) + for module in self.model.modules(): + set_freq = getattr(module, "set_debug_print_freq", None) + if callable(set_freq): + set_freq(self.disp_freq) # Loss ---------------------------------------------------------------- self.loss = get_loss( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 4f4bb4db4b..aa1dafd57c 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -58,6 +58,7 @@ doc_dos = "Fit a density of states model. The total density of states / site-projected density of states labels should be provided by `dos.npy` or `atom_dos.npy` in each data system. The file has number of frames lines and number of energy grid columns (times number of atoms in `atom_dos.npy`). See `loss` parameter." doc_dipole = "Fit an atomic dipole model. Global dipole labels or atomic dipole labels for all the selected atoms (see `sel_type`) should be provided by `dipole.npy` in each data system. The file either has number of frames lines and 3 times of number of selected atoms columns, or has number of frames lines and 3 columns. See `loss` parameter." doc_sog = "Fit a SOG-energy model (SR+LR)." +doc_les = "Fit a LES-energy model (SR+LR)." doc_polar = "Fit an atomic polarizability model. 
Global polarizazbility labels or atomic polarizability labels for all the selected atoms (see `sel_type`) should be provided by `polarizability.npy` in each data system. The file with has number of frames lines and 9 times of number of selected atoms columns, or has number of frames lines and 9 columns. See `loss` parameter." # modifier doc_dipole_charge = "Use WFCC to model the electronic structure of the system. Correct the long-range interaction." @@ -2225,6 +2226,176 @@ def fitting_sog_energy() -> list[Argument]: ] +@fitting_args_plugin.register("les_energy", doc=doc_les) +def fitting_les_energy() -> list[Argument]: + doc_var_name = "The atomic property name used by the LR fitting net. Usually set to `energy`." + doc_dim_out_sr = "The output dimension of the short-range fitting branch." + doc_dim_out_lr = "The output dimension of the long-range fitting branch." + doc_neuron_sr = "The number of neurons in each hidden layer of the short-range fitting net." + doc_neuron_lr = "The number of neurons in each hidden layer of the long-range fitting net." + doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." + doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams." + doc_default_fparam = "The default frame parameter. If set, when `fparam.npy` files are not included in the data system, this value will be used as the default value for the frame parameter in the fitting net." + doc_dim_case_embd = "The dimension of the case embedding embedding. When training or fine-tuning a multitask model with case embedding embeddings, this number should be set to the number of model branches." + doc_activation_function = f'The activation function in the fitting net. 
Supported activation functions are {list_to_doc(ACTIVATION_FN_DICT.keys())} Note that "gelu" denotes the custom operator version, and "gelu_tf" denotes the TF standard version. If you set "None" or "none" here, no activation function will be used.' + doc_precision = f"The precision of the fitting net parameters, supported options are {list_to_doc(PRECISION_DICT.keys())} Default follows the interface precision." + doc_resnet_dt = 'Whether to use a "Timestep" in the skip connection' + doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be bool or list[bool]. When list[bool] is given, all values must be True to make parameters trainable." + doc_rcond = "The condition number used to determine the initial energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details." + doc_seed = "Random seed for parameter initialization of the fitting net" + doc_exclude_types = "The excluded atom types whose atomic contributions are set to zero." + doc_use_aparam_as_mask = ( + "Whether to use the aparam as a mask in input." + "If True, the aparam will not be used in fitting net for embedding." + ) + doc_shift = "Shift values of the LES long-range correction kernels." + doc_amplitude = "Amplitude values of the LES long-range correction kernels." 
+ + return [ + Argument( + "var_name", + str, + optional=True, + default="energy", + doc=doc_only_pt_supported + doc_var_name, + ), + Argument( + "dim_out_sr", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_dim_out_sr, + ), + Argument( + "dim_out_lr", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_dim_out_lr, + ), + Argument( + "neuron_sr", + list[int], + optional=True, + default=[128, 128, 128], + doc=doc_only_pt_supported + doc_neuron_sr, + ), + Argument( + "neuron_lr", + list[int], + optional=True, + default=[128, 128, 128], + doc=doc_only_pt_supported + doc_neuron_lr, + ), + Argument( + "numb_fparam", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_numb_fparam, + ), + Argument( + "numb_aparam", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_numb_aparam, + ), + Argument( + "default_fparam", + list[float], + optional=True, + default=None, + doc=doc_only_pt_supported + doc_default_fparam, + ), + Argument( + "dim_case_embd", + int, + optional=True, + default=0, + doc=doc_only_pt_supported + doc_dim_case_embd, + ), + Argument( + "activation_function", + str, + optional=True, + default="tanh", + doc=doc_activation_function, + ), + Argument("precision", str, optional=True, default="default", doc=doc_precision), + Argument("resnet_dt", bool, optional=True, default=True, doc=doc_resnet_dt), + Argument( + "trainable", + [list[bool], bool], + optional=True, + default=True, + doc=doc_only_pt_supported + doc_trainable, + ), + Argument( + "rcond", + [float, type(None)], + optional=True, + default=None, + doc=doc_only_pt_supported + doc_rcond, + ), + Argument("seed", [int, None], optional=True, doc=doc_seed), + Argument( + "exclude_types", + list[int], + optional=True, + default=[], + doc=doc_only_pt_supported + doc_exclude_types, + ), + Argument( + "use_aparam_as_mask", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_use_aparam_as_mask, + ), + Argument( 
+ "shift", + list[float], + optional=True, + default=[ + 2.8, + 5.7, + 11.4, + 22.7, + 45.5, + 91.0, + 182.0, + 364.0, + 728.0, + 1456.0, + 2912.0, + 5823.9, + ], + doc=doc_only_pt_supported + doc_shift, + ), + Argument( + "amplitude", + list[float], + optional=True, + default=[ + 0.2750, + 0.1375, + 0.0688, + 0.0344, + 0.0172, + 0.0086, + 0.0043, + 0.0021, + 0.0011, + 0.0005, + 0.0003, + 0.0001, + ], + doc=doc_only_pt_supported + doc_amplitude, + ), + ] + + + @fitting_args_plugin.register("dipole", doc=doc_dipole) def fitting_dipole() -> list[Argument]: diff --git a/examples/water/sog/input_torch.json b/examples/water/sog/input_torch.json index 570459cefe..5719832a7a 100644 --- a/examples/water/sog/input_torch.json +++ b/examples/water/sog/input_torch.json @@ -41,7 +41,7 @@ "type": "sog_energy", "var_name": "energy", "dim_out_sr": 1, - "dim_out_lr": 1, + "dim_out_lr": 2, "neuron_sr": [ 240, 240, @@ -127,7 +127,7 @@ "numb_steps": 500000, "gradient_max_norm": 5.0, "seed": 10, - "disp_file": "lcurve.out", + "disp_file": "lcurve_lr_2.out", "disp_freq": 100, "save_freq": 2000 } diff --git a/pyproject.toml b/pyproject.toml index 17f357c1ee..8c1dc4cb8b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -174,6 +174,8 @@ pin_pytorch_cpu = [ ] pin_pytorch_gpu = [ "torch==2.10.0", + "pytorch-finufft>=0.1.0", + "cufinufft>=2.5.0; platform_system=='Linux' and platform_machine=='x86_64'", ] pin_jax = [ "jax==0.5.0;python_version>='3.10'", From 5c3fbf3b2c6c4d112db58b905f98e549c6d74726 Mon Sep 17 00:00:00 2001 From: zhenyuan5090 Date: Wed, 1 Apr 2026 01:48:05 +0800 Subject: [PATCH 5/8] baseline-0.0.2, add les_draft, waiting for direct force calculation --- .../pt/model/atomic_model/les_atomic_model.py | 14 +- .../pt/model/atomic_model/sog_atomic_model.py | 14 +- deepmd/pt/model/model/les_model.py | 184 ++++++ deepmd/pt/model/model/sog_model.py | 411 ++++++++++++++ deepmd/pt/model/task/les_energy_fitting.py | 265 ++------- deepmd/pt/model/task/lr_fitting.py | 5 +
deepmd/pt/model/task/sog_energy_fitting.py | 155 +----- examples/water/sog/ab_retain_graph.py | 144 +++++ .../sog/check_sog_consistency_with_cace.py | 59 ++ examples/water/sog/compare_sog_dpa3_timing.py | 84 +++ examples/water/sog/input_torch.json | 34 +- examples/water/sog/profile_sog_timing.py | 526 ++++++++++++++++++ examples/water/sog/profile_sog_whatif.py | 70 +++ .../tests/pt/model/test_les_working_layer.py | 201 +++++++ .../tests/pt/model/test_sog_working_layer.py | 182 ++++++ 15 files changed, 1947 insertions(+), 401 deletions(-) create mode 100644 examples/water/sog/ab_retain_graph.py create mode 100644 examples/water/sog/check_sog_consistency_with_cace.py create mode 100644 examples/water/sog/compare_sog_dpa3_timing.py create mode 100644 examples/water/sog/profile_sog_timing.py create mode 100644 examples/water/sog/profile_sog_whatif.py create mode 100644 source/tests/pt/model/test_les_working_layer.py create mode 100644 source/tests/pt/model/test_sog_working_layer.py diff --git a/deepmd/pt/model/atomic_model/les_atomic_model.py b/deepmd/pt/model/atomic_model/les_atomic_model.py index 9036203f01..12dab13ab4 100644 --- a/deepmd/pt/model/atomic_model/les_atomic_model.py +++ b/deepmd/pt/model/atomic_model/les_atomic_model.py @@ -73,6 +73,13 @@ def fitting_output_def(self) -> FittingOutputDef: r_differentiable=True, c_differentiable=True, ), + OutputVariableDef( + name="latent_charge", + shape=[self.fitting_net.dim_out_lr], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), ] ) @@ -165,12 +172,6 @@ def forward_atomic( h2=h2, fparam=fparam, aparam=aparam, - coord=extended_coord[:, :nloc, :], - box=( - comm_dict["box"].view(nframes, 3, 3) - if comm_dict is not None and "box" in comm_dict - else None - ), ) if self.enable_eval_fitting_last_layer_hook and "middle_output" in energy_ret: @@ -180,6 +181,7 @@ def forward_atomic( ret = { "energy": energy_ret["energy"], + "latent_charge": energy_ret["latent_charge"], } if "middle_output" in 
energy_ret: ret["middle_output"] = energy_ret["middle_output"] diff --git a/deepmd/pt/model/atomic_model/sog_atomic_model.py b/deepmd/pt/model/atomic_model/sog_atomic_model.py index c24d33f75c..f1658743f8 100644 --- a/deepmd/pt/model/atomic_model/sog_atomic_model.py +++ b/deepmd/pt/model/atomic_model/sog_atomic_model.py @@ -73,6 +73,13 @@ def fitting_output_def(self) -> FittingOutputDef: r_differentiable=True, c_differentiable=True, ), + OutputVariableDef( + name="latent_charge", + shape=[self.fitting_net.dim_out_lr], + reducible=False, + r_differentiable=False, + c_differentiable=False, + ), ] ) @@ -165,12 +172,6 @@ def forward_atomic( h2=h2, fparam=fparam, aparam=aparam, - coord=extended_coord[:, :nloc, :], - box=( - comm_dict["box"].view(nframes, 3, 3) - if comm_dict is not None and "box" in comm_dict - else None - ), ) if self.enable_eval_fitting_last_layer_hook and "middle_output" in energy_ret: @@ -180,6 +181,7 @@ def forward_atomic( ret = { "energy": energy_ret["energy"], + "latent_charge": energy_ret["latent_charge"], } if "middle_output" in energy_ret: ret["middle_output"] = energy_ret["middle_output"] diff --git a/deepmd/pt/model/model/les_model.py b/deepmd/pt/model/model/les_model.py index da962d7000..2c5da94165 100644 --- a/deepmd/pt/model/model/les_model.py +++ b/deepmd/pt/model/model/les_model.py @@ -5,6 +5,7 @@ ) import torch +import pytorch_finufft from deepmd.pt.model.model.transform_output import ( communicate_extended_output, @@ -98,6 +99,189 @@ def translated_output_def(self) -> dict[str, Any]: output_def["hessian"] = out_def_data["energy_derv_r_derv_r"] return output_def + def _compute_les_frame_correction( + self, + coord: torch.Tensor, + latent_charge: torch.Tensor, + box: torch.Tensor, + ) -> torch.Tensor: + fitting = self.get_fitting_net() + runtime_device = coord.device + real_dtype = coord.dtype + complex_dtype = torch.complex128 if real_dtype == torch.float64 else torch.complex64 + latent_charge = latent_charge.to(device=runtime_device, 
dtype=real_dtype) + box = box.to(device=runtime_device, dtype=real_dtype) + + sigma_raw = getattr(fitting, "sigma", None) + if sigma_raw is None: + raise ValueError("LES fitting net should provide `sigma` for frame correction.") + sigma = torch.as_tensor( + sigma_raw, + dtype=real_dtype, + device=runtime_device, + ).reshape(-1)[0] + sigma = torch.clamp(sigma, min=torch.finfo(real_dtype).eps) + remove_self_interaction = bool(fitting.remove_self_interaction) + n_dl = int(fitting.n_dl) + + nf, _, _ = coord.shape + corr = torch.zeros((nf, 1), dtype=real_dtype, device=runtime_device) + + for ff in range(nf): + r_raw = coord[ff] + q = latent_charge[ff] + box_frame = box[ff] + + volume = torch.det(box_frame) + if torch.abs(volume) <= torch.finfo(real_dtype).eps: + raise ValueError("`box` is singular (near-zero volume), cannot run NUFFT.") + + cell_inv = torch.linalg.inv(box_frame) + r_frac = torch.matmul(r_raw, cell_inv) + r_frac = torch.remainder(r_frac + 0.5, 1.0) - 0.5 + pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) + point_limit = pi_tensor - 32.0 * torch.finfo(real_dtype).eps + r_in = torch.clamp( + 2.0 * pi_tensor * r_frac, + min=-point_limit, + max=point_limit, + ).contiguous() + nufft_points = r_in.transpose(0, 1).contiguous() + + norms = torch.norm(box_frame, dim=1) + nk = [max(1, int(n.item() / n_dl)) for n in norms] + n1 = torch.arange(-nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype) + n2 = torch.arange(-nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype) + n3 = torch.arange(-nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype) + + kx_grid, ky_grid, kz_grid = torch.meshgrid(n1, n2, n3, indexing="ij") + k_sq = kx_grid**2 + ky_grid**2 + kz_grid**2 + zero_mask = k_sq == 0 + + k_sq_safe = torch.where(zero_mask, torch.ones_like(k_sq), k_sq) + kfac = torch.exp(-0.5 * (sigma**2) * k_sq_safe) / k_sq_safe + kfac = kfac.to(dtype=real_dtype) + kfac[zero_mask] = 0.0 + + q_t = q.transpose(0, 1).contiguous() + charge = 
torch.complex(q_t, torch.zeros_like(q_t)).to(dtype=complex_dtype) + recon = pytorch_finufft.functional.finufft_type1( + nufft_points, + charge, + output_shape=tuple(int(x) for x in kx_grid.shape), + eps=1e-4, + isign=-1, + ) + conv = kfac.unsqueeze(0) * recon + ifft_conv = pytorch_finufft.functional.finufft_type2( + nufft_points, + conv, + eps=1e-4, + isign=1, + ) / (2.0 * volume) + corr[ff, 0] = (charge * ifft_conv).real.sum() + + if remove_self_interaction: + diag_sum = kfac.sum(dim=-1).sum(dim=-1).sum(dim=-1) / (2.0 * volume) + corr[ff, 0] -= torch.sum(q**2) * diag_sum + + return corr + + def _apply_frame_correction_lower( + self, + model_ret: dict[str, torch.Tensor], + extended_coord: torch.Tensor, + nlist: torch.Tensor, + box: Optional[torch.Tensor], + do_atomic_virial: bool, + ) -> dict[str, torch.Tensor]: + if box is None or "latent_charge" not in model_ret: + return model_ret + + nf, nloc, _ = nlist.shape + nall = extended_coord.shape[1] + coord_local = extended_coord[:, :nloc, :] + box_local = box.view(nf, 3, 3) + latent_charge = model_ret["latent_charge"] + corr_redu = self._compute_les_frame_correction(coord_local, latent_charge, box_local) + + model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to(model_ret["energy_redu"].dtype) + + if self.do_grad_r("energy") or self.do_grad_c("energy"): + corr_force_ext = -torch.autograd.grad( + corr_redu.sum(), + extended_coord, + create_graph=self.training, + retain_graph=True, + )[0].view(nf, nall, 3) + if "energy_derv_r" in model_ret: + model_ret["energy_derv_r"] = model_ret["energy_derv_r"] + corr_force_ext.unsqueeze(-2).to( + model_ret["energy_derv_r"].dtype + ) + + if self.do_grad_c("energy"): + corr_virial_ext = torch.einsum( + "fai,faj->faij", + corr_force_ext, + extended_coord, + ).reshape(nf, nall, 1, 9) + corr_virial_redu = corr_virial_ext.sum(dim=1) + if "energy_derv_c_redu" in model_ret: + model_ret["energy_derv_c_redu"] = model_ret["energy_derv_c_redu"] + corr_virial_redu.to( + 
model_ret["energy_derv_c_redu"].dtype + ) + if do_atomic_virial and "energy_derv_c" in model_ret: + corr_atom_virial = torch.zeros( + (nf, nall, 1, 9), + dtype=corr_virial_ext.dtype, + device=corr_virial_ext.device, + ) + corr_atom_virial[:, :nloc, :, :] = corr_virial_ext[:, :nloc, :, :] + model_ret["energy_derv_c"] = model_ret["energy_derv_c"] + corr_atom_virial.to( + model_ret["energy_derv_c"].dtype + ) + + return model_ret + + @torch.jit.export + def forward_common_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + extra_nlist_sort: bool = False, + extended_coord_corr: Optional[torch.Tensor] = None, + ) -> dict[str, torch.Tensor]: + if self.do_grad_r("energy") or self.do_grad_c("energy"): + extended_coord = extended_coord.requires_grad_(True) + model_ret = super().forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + extra_nlist_sort=extra_nlist_sort, + extended_coord_corr=extended_coord_corr, + ) + box = None + if comm_dict is not None and "box" in comm_dict: + box = comm_dict["box"] + return self._apply_frame_correction_lower( + model_ret, + extended_coord, + nlist, + box, + do_atomic_virial, + ) + def forward( self, coord: torch.Tensor, diff --git a/deepmd/pt/model/model/sog_model.py b/deepmd/pt/model/model/sog_model.py index 81d5a8705d..aecae0742c 100644 --- a/deepmd/pt/model/model/sog_model.py +++ b/deepmd/pt/model/model/sog_model.py @@ -5,6 +5,7 @@ ) import torch +import pytorch_finufft from deepmd.pt.model.model.transform_output import ( communicate_extended_output, @@ -45,6 +46,54 @@ def __init__( DPModelCommon.__init__(self) SOGEnergyModel_.__init__(self, *args, **kwargs) 
self._hessian_enabled = False + # Runtime-only caches for NUFFT correction path. + self._sog_param_cache: dict[tuple[Any, ...], tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {} + + @staticmethod + def _device_key(device: torch.device) -> str: + if device.index is None: + return device.type + return f"{device.type}:{device.index}" + + @staticmethod + def _trim_cache(cache: dict[Any, Any], max_size: int = 8) -> None: + if len(cache) > max_size: + oldest_key = next(iter(cache.keys())) + cache.pop(oldest_key, None) + + def _get_cached_sog_params( + self, + fitting: Any, + runtime_device: torch.device, + real_dtype: torch.dtype, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + wl_raw = fitting.wl + sl_raw = fitting.sl + grad_mode = torch.is_grad_enabled() and (wl_raw.requires_grad or sl_raw.requires_grad) + + wl = wl_raw if (wl_raw.device == runtime_device and wl_raw.dtype == real_dtype) else wl_raw.to(dtype=real_dtype, device=runtime_device) + sl = sl_raw if (sl_raw.device == runtime_device and sl_raw.dtype == real_dtype) else sl_raw.to(dtype=real_dtype, device=runtime_device) + min_term = -1.0 / torch.exp(-2.0 * sl) + + # Do not cache differentiable tensors across iterations. 
+ if grad_mode: + return wl, sl, min_term + + wl_version = int(getattr(fitting.wl, "_version", 0)) + sl_version = int(getattr(fitting.sl, "_version", 0)) + cache_key = ( + self._device_key(runtime_device), + str(real_dtype), + wl_version, + sl_version, + ) + cached = self._sog_param_cache.get(cache_key) + if cached is not None: + return cached + + self._sog_param_cache[cache_key] = (wl, sl, min_term) + self._trim_cache(self._sog_param_cache) + return wl, sl, min_term def enable_hessian(self) -> None: self.__class__ = make_hessian_model(type(self)) @@ -98,6 +147,368 @@ def translated_output_def(self) -> dict[str, Any]: output_def["hessian"] = out_def_data["energy_derv_r_derv_r"] return output_def + def _compute_sog_frame_correction( + self, + coord: torch.Tensor, + latent_charge: torch.Tensor, + box: torch.Tensor, + return_kspace_info: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, list[dict[str, torch.Tensor]]]: + if coord.dim() != 3: + raise ValueError(f"`coord` should be [nf, nloc, 3], got shape {tuple(coord.shape)}") + if latent_charge.dim() != 3: + raise ValueError( + f"`latent_charge` should be [nf, nloc, nq], got shape {tuple(latent_charge.shape)}" + ) + if coord.shape[:2] != latent_charge.shape[:2]: + raise ValueError( + "`coord` and `latent_charge` local dimensions mismatch: " + f"{tuple(coord.shape[:2])} vs {tuple(latent_charge.shape[:2])}" + ) + + fitting = self.get_fitting_net() + runtime_device = coord.device + real_dtype = coord.dtype + complex_dtype = torch.complex128 if real_dtype == torch.float64 else torch.complex64 + latent_charge = latent_charge.to(device=runtime_device, dtype=real_dtype) + box = box.to(device=runtime_device, dtype=real_dtype) + if box.dim() != 3 or box.shape[-2:] != (3, 3): + raise ValueError(f"`box` should be [nf, 3, 3], got shape {tuple(box.shape)}") + + wl, _sl, min_term = self._get_cached_sog_params( + fitting, + runtime_device, + real_dtype, + ) + remove_self_interaction = bool(fitting.remove_self_interaction) + 
n_dl = int(fitting.n_dl) + pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) + + nf, nloc, _ = coord.shape + corr = torch.zeros((nf, 1), dtype=real_dtype, device=runtime_device) + kspace_info: list[dict[str, torch.Tensor]] = [] + + for ff in range(nf): + r_raw = coord[ff] + q = latent_charge[ff] + box_frame = box[ff] + + volume = torch.det(box_frame) + if torch.abs(volume) <= torch.finfo(real_dtype).eps: + raise ValueError("`box` is singular (near-zero volume), cannot run NUFFT.") + + cell_inv = torch.linalg.inv(box_frame) + r_frac = torch.matmul(r_raw, cell_inv) + r_frac = torch.remainder(r_frac + 0.5, 1.0) - 0.5 + point_limit = pi_tensor - 32.0 * torch.finfo(real_dtype).eps + r_in = torch.clamp( + 2.0 * pi_tensor * r_frac, + min=-point_limit, + max=point_limit, + ).contiguous() + nufft_points = r_in.transpose(0, 1).contiguous() + + norms = torch.norm(box_frame, dim=1) + nk = tuple(max(1, int(n.item() / n_dl)) for n in norms) + n1 = torch.arange(-nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype) + n2 = torch.arange(-nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype) + n3 = torch.arange(-nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype) + kx_grid, ky_grid, kz_grid = torch.meshgrid(n1, n2, n3, indexing="ij") + k_sq = kx_grid**2 + ky_grid**2 + kz_grid**2 + zero_mask = k_sq == 0 + + kfac = wl.view(1, 1, 1, -1) * torch.exp(k_sq.unsqueeze(-1) * min_term) + kfac = kfac.sum(dim=-1) + kfac = kfac.to(dtype=real_dtype) + kfac[zero_mask] = 0.0 + output_shape = tuple(int(x) for x in kx_grid.shape) + + q_t = q.transpose(0, 1).contiguous() + charge = torch.complex(q_t, torch.zeros_like(q_t)).to(dtype=complex_dtype).contiguous() + recon = pytorch_finufft.functional.finufft_type1( + nufft_points, + charge, + output_shape=output_shape, + eps=1e-4, + isign=-1, + ) + conv = kfac.unsqueeze(0) * recon + ifft_conv = pytorch_finufft.functional.finufft_type2( + nufft_points, + conv, + eps=1e-4, + isign=1, + ) / (2.0 * volume) + corr[ff, 
0] = (charge * ifft_conv).real.sum() + + if return_kspace_info: + kspace_info.append( + { + "k_sq": k_sq, + "kfac": kfac, + "nufft_points": nufft_points, + "charge": charge, + "recon": recon, + "ifft_conv": ifft_conv, + "volume": volume, + } + ) + + if remove_self_interaction: + diag_sum = kfac.sum(dim=-1).sum(dim=-1).sum(dim=-1) / (2.0 * volume) + corr[ff, 0] -= torch.sum(q**2) * diag_sum + + if return_kspace_info: + return corr, kspace_info + return corr + + def analytic_sog_needs_kspace(self) -> bool: + """Whether analytic derivative hook requires k-space intermediates.""" + return False + + def compute_sog_correction_derivatives( + self, + coord: torch.Tensor, + latent_charge: torch.Tensor, + box: torch.Tensor, + energy_correction: torch.Tensor, + do_atomic_virial: bool, + kspace_info: Optional[list[dict[str, torch.Tensor]]] = None, + ) -> Optional[dict[str, torch.Tensor]]: + """Optional model-layer analytic derivatives hook. + + Override this in model subclasses if analytic force/virial is available. + """ + return None + + def _try_analytic_frame_correction_derivatives( + self, + coord_local: torch.Tensor, + latent_charge: torch.Tensor, + box_local: torch.Tensor, + corr_redu: torch.Tensor, + do_atomic_virial: bool, + kspace_info: Optional[list[dict[str, torch.Tensor]]] = None, + ) -> Optional[dict[str, torch.Tensor]]: + """Try to fetch analytic correction derivatives from fitting net. + + Contract for fitting-net hook (optional): + `compute_sog_correction_derivatives(coord, latent_charge, box, energy_correction, do_atomic_virial)` + + Returns a dict with: + - `force_local`: required, shape [nf, nloc, 3] + - `virial_local`: optional, shape [nf, nloc, 1, 9] + """ + out = self.compute_sog_correction_derivatives( + coord=coord_local, + latent_charge=latent_charge, + box=box_local, + energy_correction=corr_redu, + do_atomic_virial=do_atomic_virial, + kspace_info=kspace_info, + ) + if out is None: + # Backward compatibility: allow fitting-layer hook if present. 
+ fitting = self.get_fitting_net() + hook = getattr(fitting, "compute_sog_correction_derivatives", None) + if hook is not None: + out = hook( + coord=coord_local, + latent_charge=latent_charge, + box=box_local, + energy_correction=corr_redu, + do_atomic_virial=do_atomic_virial, + ) + if out is None: + return None + if not isinstance(out, dict): + raise TypeError( + "`compute_sog_correction_derivatives` should return dict[str, torch.Tensor] or None." + ) + if "force_local" not in out: + raise KeyError( + "`compute_sog_correction_derivatives` must provide `force_local`." + ) + + force_local = out["force_local"] + expected_force_shape = coord_local.shape + if force_local.shape != expected_force_shape: + raise ValueError( + "`force_local` shape mismatch: " + f"expected {tuple(expected_force_shape)}, got {tuple(force_local.shape)}" + ) + if force_local.device != coord_local.device: + raise ValueError( + "`force_local` device mismatch: " + f"expected {coord_local.device}, got {force_local.device}" + ) + + if "virial_local" in out: + virial_local = out["virial_local"] + expected_virial_shape = ( + coord_local.shape[0], + coord_local.shape[1], + 1, + 9, + ) + if virial_local.shape != expected_virial_shape: + raise ValueError( + "`virial_local` shape mismatch: " + f"expected {tuple(expected_virial_shape)}, got {tuple(virial_local.shape)}" + ) + if virial_local.device != coord_local.device: + raise ValueError( + "`virial_local` device mismatch: " + f"expected {coord_local.device}, got {virial_local.device}" + ) + + return out + + def _apply_frame_correction_lower( + self, + model_ret: dict[str, torch.Tensor], + extended_coord: torch.Tensor, + nlist: torch.Tensor, + box: Optional[torch.Tensor], + do_atomic_virial: bool, + ) -> dict[str, torch.Tensor]: + if box is None or "latent_charge" not in model_ret: + return model_ret + + nf, nloc, _ = nlist.shape + nall = extended_coord.shape[1] + coord_local = extended_coord[:, :nloc, :] + box_local = box.view(nf, 3, 3) + latent_charge 
= model_ret["latent_charge"] + # Keep latent_charge on the computational graph for both training and eval + # so SOG correction gradients can always propagate through the LR branch. + latent_charge_for_energy = latent_charge + kspace_info: Optional[list[dict[str, torch.Tensor]]] = None + if self.analytic_sog_needs_kspace(): + corr_out = self._compute_sog_frame_correction( + coord_local, + latent_charge_for_energy, + box_local, + return_kspace_info=True, + ) + assert isinstance(corr_out, tuple) + corr_redu, kspace_info = corr_out + else: + corr_redu = self._compute_sog_frame_correction( + coord_local, + latent_charge_for_energy, + box_local, + ) + + model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to(model_ret["energy_redu"].dtype) + + if self.do_grad_r("energy") or self.do_grad_c("energy"): + analytic = self._try_analytic_frame_correction_derivatives( + coord_local=coord_local, + latent_charge=latent_charge, + box_local=box_local, + corr_redu=corr_redu, + do_atomic_virial=do_atomic_virial, + kspace_info=kspace_info, + ) + if analytic is not None: + corr_force_local = analytic["force_local"].to(coord_local.dtype) + else: + # Force correction keeps full dependency on latent_charge. + # If latent_charge is differentiable, recompute correction with the + # same graph connectivity; otherwise reuse corr_redu. 
+ if self.training and latent_charge.requires_grad: + corr_redu_for_grad = self._compute_sog_frame_correction( + coord_local, + latent_charge, + box_local, + ) + else: + corr_redu_for_grad = corr_redu + corr_force_local = -torch.autograd.grad( + corr_redu_for_grad.sum(), + coord_local, + create_graph=self.training, + retain_graph=False, + )[0].view(nf, nloc, 3) + + corr_force_ext = torch.zeros( + (nf, nall, 3), + dtype=corr_force_local.dtype, + device=corr_force_local.device, + ) + corr_force_ext[:, :nloc, :] = corr_force_local + if "energy_derv_r" in model_ret: + model_ret["energy_derv_r"] = model_ret["energy_derv_r"] + corr_force_ext.unsqueeze(-2).to( + model_ret["energy_derv_r"].dtype + ) + + if self.do_grad_c("energy"): + if analytic is not None and "virial_local" in analytic: + corr_virial_local = analytic["virial_local"].to(corr_force_local.dtype) + else: + corr_virial_local = torch.einsum( + "fai,faj->faij", + corr_force_local, + coord_local, + ).reshape(nf, nloc, 1, 9) + corr_virial_redu = corr_virial_local.sum(dim=1) + if "energy_derv_c_redu" in model_ret: + model_ret["energy_derv_c_redu"] = model_ret["energy_derv_c_redu"] + corr_virial_redu.to( + model_ret["energy_derv_c_redu"].dtype + ) + if do_atomic_virial and "energy_derv_c" in model_ret: + corr_atom_virial = torch.zeros( + (nf, nall, 1, 9), + dtype=corr_virial_local.dtype, + device=corr_virial_local.device, + ) + corr_atom_virial[:, :nloc, :, :] = corr_virial_local + model_ret["energy_derv_c"] = model_ret["energy_derv_c"] + corr_atom_virial.to( + model_ret["energy_derv_c"].dtype + ) + + return model_ret + + @torch.jit.export + def forward_common_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: Optional[torch.Tensor] = None, + fparam: Optional[torch.Tensor] = None, + aparam: Optional[torch.Tensor] = None, + do_atomic_virial: bool = False, + comm_dict: Optional[dict[str, torch.Tensor]] = None, + extra_nlist_sort: bool = False, + 
extended_coord_corr: Optional[torch.Tensor] = None, + ) -> dict[str, torch.Tensor]: + if self.do_grad_r("energy") or self.do_grad_c("energy"): + extended_coord = extended_coord.requires_grad_(True) + model_ret = super().forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + extra_nlist_sort=extra_nlist_sort, + extended_coord_corr=extended_coord_corr, + ) + box = None + if comm_dict is not None and "box" in comm_dict: + box = comm_dict["box"] + return self._apply_frame_correction_lower( + model_ret, + extended_coord, + nlist, + box, + do_atomic_virial, + ) + def forward( self, coord: torch.Tensor, diff --git a/deepmd/pt/model/task/les_energy_fitting.py b/deepmd/pt/model/task/les_energy_fitting.py index c46906de31..64e89e03e2 100644 --- a/deepmd/pt/model/task/les_energy_fitting.py +++ b/deepmd/pt/model/task/les_energy_fitting.py @@ -8,7 +8,6 @@ import numpy as np import torch -import pytorch_finufft from deepmd.dpmodel import ( FittingOutputDef, @@ -35,34 +34,7 @@ ) -LES_DEFAULT_AMPLITUDE = to_numpy_array(np.array([ - 0.2750, - 0.1375, - 0.0688, - 0.0344, - 0.0172, - 0.0086, - 0.0043, - 0.0021, - 0.0011, - 0.0005, - 0.0003, - 0.0001, -])) -LES_DEFAULT_SHIFT = to_numpy_array(np.array([ - 2.8, - 5.7, - 11.4, - 22.7, - 45.5, - 91.0, - 182.0, - 364.0, - 728.0, - 1456.0, - 2912.0, - 5823.9, -])) +LES_DEFAULT_SIGMA = to_numpy_array(np.array(2.8 / np.sqrt(2.0))) @LRFittingNet.register("les_energy") @@ -151,8 +123,7 @@ def __init__( type_map: Optional[list[str]] = None, use_aparam_as_mask: bool = False, default_fparam: Optional[list[float]] = None, - shift: Optional[Union[list[float], torch.Tensor]] = None, - amplitude: Optional[Union[list[float], torch.Tensor]] = None, + sigma: Optional[Union[float, list[float], torch.Tensor]] = None, n_dl: int = 1, **kwargs: Any, ) -> None: @@ -182,35 +153,20 @@ def __init__( default_fparam=default_fparam, **kwargs, ) - if 
isinstance(shift, (list, tuple)): - shift = to_numpy_array(np.array(shift)) - if isinstance(amplitude, (list, tuple)): - amplitude = to_numpy_array(np.array(amplitude)) - shift_tensor = to_torch_tensor(shift) - amplitude_tensor = to_torch_tensor(amplitude) - if shift_tensor is None: - shift_tensor = to_torch_tensor(LES_DEFAULT_SHIFT) - if amplitude_tensor is None: - amplitude_tensor = to_torch_tensor(LES_DEFAULT_AMPLITUDE) - - shift_tensor = shift_tensor.to(dtype=dtype, device=device) - amplitude_tensor = amplitude_tensor.to(dtype=dtype, device=device) - pi_tensor = torch.tensor(torch.pi, dtype=dtype, device=device) - sqr_pi_tensor = torch.sqrt(pi_tensor) - shift_safe = torch.clamp( - shift_tensor, - min=torch.finfo(shift_tensor.dtype).eps, + if isinstance(sigma, (list, tuple)): + sigma = sigma[0] if len(sigma) > 0 else None + sigma_tensor = to_torch_tensor(sigma) + if sigma_tensor is None: + sigma_tensor = to_torch_tensor(LES_DEFAULT_SIGMA) + sigma_tensor = sigma_tensor.to(dtype=dtype, device=device).reshape(1) + sigma_tensor = torch.clamp( + sigma_tensor, + min=torch.finfo(sigma_tensor.dtype).eps, ) - wl_tensor = amplitude_tensor * (sqr_pi_tensor**3) * (shift_safe**3) - sl_tensor = -torch.log(2.0 / shift_safe) self.n_dl = max(1, int(n_dl)) - self.wl = torch.nn.Parameter( - wl_tensor, - requires_grad=bool(self.trainable), - ) - self.sl = torch.nn.Parameter( - sl_tensor, + self.sigma = torch.nn.Parameter( + sigma_tensor, requires_grad=bool(self.trainable), ) self.remove_self_interaction = False @@ -226,6 +182,13 @@ def output_def(self) -> FittingOutputDef: reducible=True, r_differentiable=True, c_differentiable=True, + ), + OutputVariableDef( + name="latent_charge", + shape=[self.dim_out_lr], + reducible=False, + r_differentiable=False, + c_differentiable=False, ) ] ) @@ -233,182 +196,30 @@ def output_def(self) -> FittingOutputDef: def serialize(self) -> dict: data = super().serialize() data["type"] = "les_energy" - data["@variables"]["wl"] = 
to_numpy_array(self.wl) - data["@variables"]["sl"] = to_numpy_array(self.sl) - - pi_tensor = torch.tensor(torch.pi, dtype=self.sl.dtype, device=self.sl.device) - sqr_pi_tensor = torch.sqrt(pi_tensor) - shift_tensor = 2.0 * torch.exp(self.sl) - amplitude_tensor = self.wl / ((sqr_pi_tensor**3) * (shift_tensor**3)) - data["@variables"]["shift"] = to_numpy_array(shift_tensor) - data["@variables"]["amplitude"] = to_numpy_array(amplitude_tensor) + data["@variables"]["sigma"] = to_numpy_array(self.sigma) data["n_dl"] = self.n_dl return data @classmethod def deserialize(cls, data: dict) -> "LESEnergyFittingNet": data = data.copy() - variables = data.get("@variables", {}) - obj = super().deserialize(data) - wl_tensor = to_torch_tensor(variables.get("wl", None)) - sl_tensor = to_torch_tensor(variables.get("sl", None)) - shift_tensor = to_torch_tensor(variables.get("shift", None)) - amplitude_tensor = to_torch_tensor(variables.get("amplitude", None)) - - with torch.no_grad(): - if wl_tensor is not None and sl_tensor is not None: - obj.wl.copy_(wl_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device)) - obj.sl.copy_(sl_tensor.to(dtype=obj.sl.dtype, device=obj.sl.device)) - elif shift_tensor is not None and amplitude_tensor is not None: - pi_tensor = torch.tensor(torch.pi, dtype=obj.wl.dtype, device=obj.wl.device) - sqr_pi_tensor = torch.sqrt(pi_tensor) - shift_tensor = shift_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device) - amplitude_tensor = amplitude_tensor.to(dtype=obj.wl.dtype, device=obj.wl.device) - shift_safe = torch.clamp( - shift_tensor, - min=torch.finfo(shift_tensor.dtype).eps, - ) - wl_tensor = amplitude_tensor * (sqr_pi_tensor**3) * (shift_safe**3) - sl_tensor = -torch.log(2.0 / shift_safe) - obj.wl.copy_(wl_tensor) - obj.sl.copy_(sl_tensor) - return obj - - def _kernel_params(self) -> tuple[torch.Tensor, torch.Tensor]: - return self.wl, self.sl - - def compute_potential_LES_triclinic_NUFFT( - self, - r_raw: torch.Tensor, - q: torch.Tensor, - box: torch.Tensor, 
- ) -> torch.Tensor: - runtime_device = torch.device(device) - if q.dim() == 1: - q = q.unsqueeze(-1) - - if not torch.isfinite(r_raw).all(): - raise ValueError("`r_raw` contains non-finite values, cannot run NUFFT.") - if not torch.isfinite(box).all(): - raise ValueError("`box` contains non-finite values, cannot run NUFFT.") - - real_dtype = torch.float64 if r_raw.dtype == torch.float64 else torch.float32 - complex_dtype = torch.complex128 if real_dtype == torch.float64 else torch.complex64 - - r_raw = r_raw.to(dtype=real_dtype, device=runtime_device) - q = q.to(dtype=real_dtype, device=runtime_device) - box = box.to(dtype=real_dtype, device=runtime_device) - - volume = torch.det(box) - if torch.abs(volume) <= torch.finfo(real_dtype).eps: - raise ValueError("`box` is singular (near-zero volume), cannot run NUFFT.") - - cell_inv = torch.linalg.inv(box) - # Keep points strictly in the principal unit cell so NUFFT points are in [-pi, pi). - r_frac = torch.matmul(r_raw, cell_inv) - r_frac = torch.remainder(r_frac + 0.5, 1.0) - 0.5 - pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) - point_limit = pi_tensor - 32.0 * torch.finfo(real_dtype).eps - r_in = torch.clamp(2.0 * pi_tensor * r_frac, min=-point_limit, max=point_limit).contiguous() - nufft_points = r_in.transpose(0, 1).contiguous() + variables = data.get("@variables", {}).copy() - norms = torch.norm(box, dim=1) - nk = [max(1, int(n.item() / self.n_dl)) for n in norms] - n1 = torch.arange(-nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype) - n2 = torch.arange(-nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype) - n3 = torch.arange(-nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype) + sigma_tensor = to_torch_tensor(variables.pop("sigma", None)) + data["@variables"] = variables - kx_grid, ky_grid, kz_grid = torch.meshgrid(n1, n2, n3, indexing="ij") - k_sq = kx_grid**2 + ky_grid**2 + kz_grid**2 - mask = (k_sq==0) - - wl, sl = self._kernel_params() - min_term = -1.0 / 
torch.exp(-2.0 * sl) - kfac = wl.view(1, 1, 1, -1) * torch.exp(k_sq.unsqueeze(-1) * min_term) - # print("multiplier:",multiplier.shape,multiplier)c - kfac = kfac.sum(dim=-1) - kfac = kfac.to(dtype=real_dtype) - kfac[mask] = 0.0 - - atom_energy = torch.zeros((r_raw.shape[0], 1), dtype=real_dtype, device=runtime_device) - output_shape = tuple(int(x) for x in kx_grid.shape) - - def _eval_nufft(work_device: torch.device) -> torch.Tensor: - points_work = nufft_points.to(device=work_device) - q_work = q.to(device=work_device) - kfac_work = kfac.to(device=work_device) - volume_work = volume.to(device=work_device) - atom_energy_work = torch.zeros( - (r_raw.shape[0], 1), - dtype=real_dtype, - device=work_device, - ) - - q_work_t = q_work.transpose(0, 1).contiguous() - charge_work = torch.complex( - q_work_t, - torch.zeros_like(q_work_t), - ).to(dtype=complex_dtype) - recon = pytorch_finufft.functional.finufft_type1( - points_work, - charge_work, - output_shape=output_shape, - eps=1e-4, - isign=-1, - ) - con_les = torch.mul(kfac_work.unsqueeze(0), recon) - ifft_con = pytorch_finufft.functional.finufft_type2( - points_work, - con_les, - eps=1e-4, - isign=1, - ) / (2.0 * volume_work) - atom_energy_work += (charge_work * ifft_con).real.sum() - - if self.remove_self_interaction: - diag_sum = kfac_work.sum(dim=-1).sum(dim=-1).sum(dim=-1) / (2.0 * volume_work) - atom_energy_work -= torch.sum(q_work**2) * diag_sum - - return atom_energy_work.to(device=runtime_device) - - - atom_energy = _eval_nufft(runtime_device) - - return atom_energy - - def _corr_head( - self, - lr_val: torch.Tensor, - coord: Optional[torch.Tensor] = None, - box: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if lr_val.dim() != 3: - raise ValueError( - f"`lr_val` should have shape [nframe, nloc, nchan], got {tuple(lr_val.shape)}." 
- ) + obj = super().deserialize(data) - if coord is None or box is None: - return torch.zeros( - (lr_val.shape[0], lr_val.shape[1], 1), - dtype=lr_val.dtype, - device=lr_val.device, + with torch.no_grad(): + if sigma_tensor is None: + raise ValueError("LES fitting net deserialize requires `sigma` in @variables.") + obj.sigma.copy_( + sigma_tensor.to(dtype=obj.sigma.dtype, device=obj.sigma.device).reshape(1) ) + return obj - nf, nloc, _ = lr_val.shape - coord = coord.reshape(nf, nloc, 3) - box = box.reshape(nf, 3, 3) - - coord = coord.to(dtype=lr_val.dtype, device=lr_val.device) - box = box.to(dtype=lr_val.dtype, device=lr_val.device) - corr = torch.zeros((nf, nloc, 1), dtype=lr_val.dtype, device=lr_val.device) - for ff in range(nf): - box_now = box[ff] - corr[ff] = self.compute_potential_LES_triclinic_NUFFT( - coord[ff], - lr_val[ff], - box_now, - ) - return corr + def _kernel_params(self) -> tuple[torch.Tensor]: + return (self.sigma,) def forward( self, @@ -419,8 +230,6 @@ def forward( h2: Optional[torch.Tensor] = None, fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, - coord: Optional[torch.Tensor] = None, - box: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: out = self._forward_common( descriptor=descriptor, @@ -431,10 +240,10 @@ def forward( fparam=fparam, aparam=aparam, ) - short_energy = out["sr"] - corr_energy = self._corr_head(out["lr"], coord=coord, box=box) - # corr_energy = 0 - result = {"energy": short_energy + corr_energy} + result = { + "energy": out["sr"], + "latent_charge": out["lr"], + } if "middle_output" in out: result["middle_output"] = out["middle_output"] return result diff --git a/deepmd/pt/model/task/lr_fitting.py b/deepmd/pt/model/task/lr_fitting.py index 82702a38de..38803ae67a 100644 --- a/deepmd/pt/model/task/lr_fitting.py +++ b/deepmd/pt/model/task/lr_fitting.py @@ -461,6 +461,10 @@ def _sr_net_out_dim(self) -> int: def _lr_net_out_dim(self) -> int: """Set the LR FittingNet output dim.""" return 
self.dim_out_lr + + def _corr_head(self, lr_out: torch.Tensor) -> torch.Tensor: + # TODO: Add latent_charge correction logic after LR output is finalized. + return lr_out def _extend_f_avg_std(self, xx: torch.Tensor, nb: int) -> torch.Tensor: return torch.tile(xx.view([1, self.numb_fparam]), [nb, 1]) @@ -566,6 +570,7 @@ def _forward_common( mask = self.emask(atype).to(torch.bool) sr_out = torch.where(mask[:, :, None], sr_out, 0.0) lr_out = torch.where(mask[:, :, None], lr_out, 0.0) + lr_out = self._corr_head(lr_out) results.update({"sr": sr_out, "lr": lr_out}) return results diff --git a/deepmd/pt/model/task/sog_energy_fitting.py b/deepmd/pt/model/task/sog_energy_fitting.py index 78e6c81503..480e10a27e 100644 --- a/deepmd/pt/model/task/sog_energy_fitting.py +++ b/deepmd/pt/model/task/sog_energy_fitting.py @@ -8,7 +8,6 @@ import numpy as np import torch -import pytorch_finufft from deepmd.dpmodel import ( FittingOutputDef, @@ -225,6 +224,13 @@ def output_def(self) -> FittingOutputDef: reducible=True, r_differentiable=True, c_differentiable=True, + ), + OutputVariableDef( + name="latent_charge", + shape=[self.dim_out_lr], + reducible=False, + r_differentiable=False, + c_differentiable=False, ) ] ) @@ -307,144 +313,6 @@ def deserialize(cls, data: dict) -> "SOGEnergyFittingNet": def _kernel_params(self) -> tuple[torch.Tensor, torch.Tensor]: return self.wl, self.sl - def compute_potential_SOG_triclinic_NUFFT( - self, - r_raw: torch.Tensor, - q: torch.Tensor, - box: torch.Tensor, - ) -> torch.Tensor: - runtime_device = torch.device(device) - if q.dim() == 1: - q = q.unsqueeze(-1) - - if not torch.isfinite(r_raw).all(): - raise ValueError("`r_raw` contains non-finite values, cannot run NUFFT.") - if not torch.isfinite(box).all(): - raise ValueError("`box` contains non-finite values, cannot run NUFFT.") - - real_dtype = torch.float64 if r_raw.dtype == torch.float64 else torch.float32 - complex_dtype = torch.complex128 if real_dtype == torch.float64 else torch.complex64 - - 
r_raw = r_raw.to(dtype=real_dtype, device=runtime_device) - q = q.to(dtype=real_dtype, device=runtime_device) - box = box.to(dtype=real_dtype, device=runtime_device) - - volume = torch.det(box) - if torch.abs(volume) <= torch.finfo(real_dtype).eps: - raise ValueError("`box` is singular (near-zero volume), cannot run NUFFT.") - - cell_inv = torch.linalg.inv(box) - # Keep points strictly in the principal unit cell so NUFFT points are in [-pi, pi). - r_frac = torch.matmul(r_raw, cell_inv) - r_frac = torch.remainder(r_frac + 0.5, 1.0) - 0.5 - pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) - point_limit = pi_tensor - 32.0 * torch.finfo(real_dtype).eps - r_in = torch.clamp(2.0 * pi_tensor * r_frac, min=-point_limit, max=point_limit).contiguous() - nufft_points = r_in.transpose(0, 1).contiguous() - - norms = torch.norm(box, dim=1) - nk = [max(1, int(n.item() / self.n_dl)) for n in norms] - n1 = torch.arange(-nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype) - n2 = torch.arange(-nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype) - n3 = torch.arange(-nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype) - - kx_grid, ky_grid, kz_grid = torch.meshgrid(n1, n2, n3, indexing="ij") - k_sq = kx_grid**2 + ky_grid**2 + kz_grid**2 - mask = (k_sq==0) - - wl, sl = self._kernel_params() - min_term = -1.0 / torch.exp(-2.0 * sl) - kfac = wl.view(1, 1, 1, -1) * torch.exp(k_sq.unsqueeze(-1) * min_term) - # print("multiplier:",multiplier.shape,multiplier)c - kfac = kfac.sum(dim=-1) - kfac = kfac.to(dtype=real_dtype) - kfac[mask] = 0.0 - - atom_energy = torch.zeros((r_raw.shape[0], 1), dtype=real_dtype, device=runtime_device) - output_shape = tuple(int(x) for x in kx_grid.shape) - - def _eval_nufft(work_device: torch.device) -> torch.Tensor: - points_work = nufft_points.to(device=work_device) - q_work = q.to(device=work_device) - kfac_work = kfac.to(device=work_device) - volume_work = volume.to(device=work_device) - atom_energy_work = 
torch.zeros( - (r_raw.shape[0], 1), - dtype=real_dtype, - device=work_device, - ) - - q_work_t = q_work.transpose(0, 1).contiguous() - charge_work = torch.complex( - q_work_t, - torch.zeros_like(q_work_t), - ).to(dtype=complex_dtype) - recon = pytorch_finufft.functional.finufft_type1( - points_work, - charge_work, - output_shape=output_shape, - eps=1e-4, - isign=-1, - ) - con_sog = torch.mul(kfac_work.unsqueeze(0), recon) - ifft_con = pytorch_finufft.functional.finufft_type2( - points_work, - con_sog, - eps=1e-4, - isign=1, - ) / (2.0 * volume_work) - atom_energy_work += (charge_work * ifft_con).real.sum() - - if self.remove_self_interaction: - diag_sum = kfac_work.sum(dim=-1).sum(dim=-1).sum(dim=-1) / (2.0 * volume_work) - atom_energy_work -= torch.sum(q_work**2) * diag_sum - - return atom_energy_work.to(device=runtime_device) - - - atom_energy = _eval_nufft(runtime_device) - - return atom_energy - - def _corr_head( - self, - lr_val: torch.Tensor, - coord: Optional[torch.Tensor] = None, - box: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if lr_val.dim() != 3: - raise ValueError( - f"`lr_val` should have shape [nframe, nloc, nchan], got {tuple(lr_val.shape)}." - ) - if lr_val.shape[-1] != self.dim_out_lr: - raise ValueError( - f"`lr_val` channel mismatch: got {lr_val.shape[-1]}, " - f"expected dim_out_lr={self.dim_out_lr}." 
- ) - - if coord is None or box is None: - return torch.zeros( - (lr_val.shape[0], lr_val.shape[1], 1), - dtype=lr_val.dtype, - device=lr_val.device, - ) - - nf, nloc, _ = lr_val.shape - coord = coord.reshape(nf, nloc, 3) - box = box.reshape(nf, 3, 3) - - coord = coord.to(dtype=lr_val.dtype, device=lr_val.device) - box = box.to(dtype=lr_val.dtype, device=lr_val.device) - corr = torch.zeros((nf, nloc, 1), dtype=lr_val.dtype, device=lr_val.device) - for ff in range(nf): - box_now = box[ff] - corr[ff] = self.compute_potential_SOG_triclinic_NUFFT( - coord[ff], - lr_val[ff], - box_now, - ) - return corr - def forward( self, descriptor: torch.Tensor, @@ -454,8 +322,6 @@ def forward( h2: Optional[torch.Tensor] = None, fparam: Optional[torch.Tensor] = None, aparam: Optional[torch.Tensor] = None, - coord: Optional[torch.Tensor] = None, - box: Optional[torch.Tensor] = None, ) -> dict[str, torch.Tensor]: out = self._forward_common( descriptor=descriptor, @@ -466,9 +332,10 @@ def forward( fparam=fparam, aparam=aparam, ) - short_energy = out["sr"] - corr_energy = self._corr_head(out["lr"], coord=coord, box=box) - result = {"energy": short_energy + corr_energy} + result = { + "energy": out["sr"], + "latent_charge": out["lr"], + } if "middle_output" in out: result["middle_output"] = out["middle_output"] return result diff --git a/examples/water/sog/ab_retain_graph.py b/examples/water/sog/ab_retain_graph.py new file mode 100644 index 0000000000..560d50fe71 --- /dev/null +++ b/examples/water/sog/ab_retain_graph.py @@ -0,0 +1,144 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import json +import time +from pathlib import Path +from types import MethodType + +import torch + +from deepmd.pt.model.model import get_model + + +def sync(dev: torch.device) -> None: + if dev.type == "cuda": + torch.cuda.synchronize(dev) + + +def build_input( + nf: int = 1, + nloc: int = 192, + box_len: float = 20.0, + device: str = "cuda", + dtype: torch.dtype = torch.float32, +) -> 
tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.device]: + dev = torch.device(device if torch.cuda.is_available() else "cpu") + g = torch.Generator(device=dev) + g.manual_seed(1234) + coord = torch.rand((nf, nloc, 3), device=dev, dtype=dtype, generator=g) * box_len + atype = torch.zeros((nf, nloc), device=dev, dtype=torch.long) + atype[:, 1::3] = 1 + atype[:, 2::3] = 1 + box = torch.zeros((nf, 3, 3), device=dev, dtype=dtype) + box[:, 0, 0] = box_len + box[:, 1, 1] = box_len + box[:, 2, 2] = box_len + return coord, atype, box, dev + + +def bench( + model, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor, + reps: int = 20, + warmup: int = 5, +) -> float: + model.eval() + for _ in range(warmup): + _ = model(coord, atype, box=box) + sync(coord.device) + t0 = time.perf_counter() + for _ in range(reps): + _ = model(coord, atype, box=box) + sync(coord.device) + return (time.perf_counter() - t0) / reps * 1000.0 + + +def make_patched_apply(): + def patched_apply(self, model_ret, extended_coord, nlist, box, do_atomic_virial): + if box is None or "latent_charge" not in model_ret: + return model_ret + nf, nloc, _ = nlist.shape + nall = extended_coord.shape[1] + coord_local = extended_coord[:, :nloc, :] + box_local = box.view(nf, 3, 3) + latent_charge = model_ret["latent_charge"] + corr_redu = self._compute_sog_frame_correction(coord_local, latent_charge, box_local) + model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to(model_ret["energy_redu"].dtype) + + if self.do_grad_r("energy") or self.do_grad_c("energy"): + corr_force_local = -torch.autograd.grad( + corr_redu.sum(), + coord_local, + create_graph=self.training, + retain_graph=False, + )[0].view(nf, nloc, 3) + + corr_force_ext = torch.zeros( + (nf, nall, 3), + dtype=corr_force_local.dtype, + device=corr_force_local.device, + ) + corr_force_ext[:, :nloc, :] = corr_force_local + if "energy_derv_r" in model_ret: + model_ret["energy_derv_r"] = model_ret["energy_derv_r"] + 
corr_force_ext.unsqueeze(-2).to( + model_ret["energy_derv_r"].dtype + ) + + if self.do_grad_c("energy"): + corr_virial_local = torch.einsum( + "fai,faj->faij", + corr_force_local, + coord_local, + ).reshape(nf, nloc, 1, 9) + corr_virial_redu = corr_virial_local.sum(dim=1) + if "energy_derv_c_redu" in model_ret: + model_ret["energy_derv_c_redu"] = model_ret["energy_derv_c_redu"] + corr_virial_redu.to( + model_ret["energy_derv_c_redu"].dtype + ) + if do_atomic_virial and "energy_derv_c" in model_ret: + corr_atom_virial = torch.zeros( + (nf, nall, 1, 9), + dtype=corr_virial_local.dtype, + device=corr_virial_local.device, + ) + corr_atom_virial[:, :nloc, :, :] = corr_virial_local + model_ret["energy_derv_c"] = model_ret["energy_derv_c"] + corr_atom_virial.to( + model_ret["energy_derv_c"].dtype + ) + + return model_ret + + return patched_apply + + +def main() -> None: + cfg = json.loads(Path("examples/water/sog/input_torch.json").read_text())["model"] + coord, atype, box, dev = build_input() + model = get_model(cfg).to(dev) + + base = bench(model, coord, atype, box) + + orig_apply = model._apply_frame_correction_lower + model._apply_frame_correction_lower = MethodType(make_patched_apply(), model) + patched = bench(model, coord, atype, box) + + model._apply_frame_correction_lower = orig_apply + out0 = model(coord, atype, box=box) + model._apply_frame_correction_lower = MethodType(make_patched_apply(), model) + out1 = model(coord, atype, box=box) + + max_de = (out0["energy"] - out1["energy"]).abs().max().item() + max_df = (out0["force"] - out1["force"]).abs().max().item() + + print(f"baseline_ms={base:.3f}") + print(f"retain_graph_false_ms={patched:.3f}") + print(f"speedup={(base / patched):.3f}x") + print(f"max|dE|={max_de:.3e}") + print(f"max|dF|={max_df:.3e}") + + +if __name__ == "__main__": + main() diff --git a/examples/water/sog/check_sog_consistency_with_cace.py b/examples/water/sog/check_sog_consistency_with_cace.py new file mode 100644 index 
0000000000..35c34e81c8 --- /dev/null +++ b/examples/water/sog/check_sog_consistency_with_cace.py @@ -0,0 +1,59 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import importlib.util +import json +import pathlib + +import torch + +from deepmd.pt.model.model import get_model + + +def main() -> None: + cfg = json.loads(pathlib.Path("examples/water/sog/input_torch.json").read_text())["model"] + dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = get_model(cfg).to(dev).eval() + fit = model.get_fitting_net() + + sog_path = pathlib.Path("/data/zyjin/cace/SOG-Net/CACE-SOG/cace/modules/sog.py") + spec = importlib.util.spec_from_file_location("cace_sog_module", sog_path) + if spec is None or spec.loader is None: + raise RuntimeError("Failed to load cace sog.py module") + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + SOGPotential = mod.SOGPotential + + cace_sog = SOGPotential(N_dl=int(fit.n_dl), Periodic=True).to(dev).eval() + with torch.no_grad(): + cace_sog.wl.copy_(fit.wl.detach().to(cace_sog.wl.dtype).to(dev)) + cace_sog.sl.copy_(fit.sl.detach().to(cace_sog.sl.dtype).to(dev)) + + nf, nloc, nq = 3, 32, 1 + coord = torch.rand(nf, nloc, 3, device=dev, dtype=torch.float32) * 10.0 + box = torch.zeros(nf, 3, 3, device=dev, dtype=torch.float32) + box[:, 0, 0] = 10.0 + box[:, 1, 1] = 11.0 + box[:, 2, 2] = 12.0 + latent = torch.randn(nf, nloc, nq, device=dev, dtype=torch.float32) + + with torch.no_grad(): + corr_deepmd = model._compute_sog_frame_correction(coord, latent, box).squeeze(-1) + corr_cace = [] + for i in range(nf): + v = cace_sog.compute_potential_SOG_triclinic_NUFFT(coord[i], latent[i], box[i]) + corr_cace.append(v.squeeze()) + corr_cace = torch.stack(corr_cace) + + abs_diff = (corr_deepmd - corr_cace).abs() + rel_diff = abs_diff / torch.clamp(corr_cace.abs(), min=1e-8) + + print("corr_deepmd", corr_deepmd.detach().cpu().numpy()) + print("corr_cace ", corr_cace.detach().cpu().numpy()) + 
print("max_abs_diff", abs_diff.max().item()) + print("mean_abs_diff", abs_diff.mean().item()) + print("max_rel_diff", rel_diff.max().item()) + + +if __name__ == "__main__": + main() diff --git a/examples/water/sog/compare_sog_dpa3_timing.py b/examples/water/sog/compare_sog_dpa3_timing.py new file mode 100644 index 0000000000..f067024d18 --- /dev/null +++ b/examples/water/sog/compare_sog_dpa3_timing.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import json +import time +from pathlib import Path + +import torch + +from deepmd.pt.model.model import get_model + + +def sync(dev: torch.device) -> None: + if dev.type == "cuda": + torch.cuda.synchronize(dev) + + +def build_input( + nf: int = 1, + nloc: int = 192, + box_len: float = 20.0, + device: str = "cuda", + dtype: torch.dtype = torch.float32, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.device]: + dev = torch.device(device if torch.cuda.is_available() else "cpu") + coord = torch.rand((nf, nloc, 3), device=dev, dtype=dtype) * box_len + atype = torch.zeros((nf, nloc), device=dev, dtype=torch.long) + atype[:, 1::3] = 1 + atype[:, 2::3] = 1 + box = torch.zeros((nf, 3, 3), device=dev, dtype=dtype) + box[:, 0, 0] = box_len + box[:, 1, 1] = box_len + box[:, 2, 2] = box_len + return coord, atype, box, dev + + +def bench_model( + model, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor, + warmup: int = 5, + reps: int = 30, +) -> float: + model.eval() + for _ in range(warmup): + _ = model(coord, atype, box=box) + sync(coord.device) + t0 = time.perf_counter() + for _ in range(reps): + _ = model(coord, atype, box=box) + sync(coord.device) + return (time.perf_counter() - t0) / reps * 1000.0 + + +def main() -> None: + sog_cfg = json.loads(Path("examples/water/sog/input_torch.json").read_text())["model"] + dpa_cfg = json.loads(Path("examples/water/dpa3/input_torch_copy.json").read_text())["model"] + + coord, atype, box, dev = build_input() + + sog = 
get_model(sog_cfg).to(dev) + dpa = get_model(dpa_cfg).to(dev) + if hasattr(sog.get_fitting_net(), "n_dl"): + sog.get_fitting_net().n_dl = 2 + + sog_ms = bench_model(sog, coord, atype, box) + + orig = sog._apply_frame_correction_lower + sog._apply_frame_correction_lower = lambda model_ret, *args, **kwargs: model_ret + sog_nocorr_ms = bench_model(sog, coord, atype, box) + sog._apply_frame_correction_lower = orig + + dpa_ms = bench_model(dpa, coord, atype, box) + + print(f"dpa3_total_ms={dpa_ms:.3f}") + print(f"sog_total_ms={sog_ms:.3f}") + print(f"sog_without_frame_corr_ms={sog_nocorr_ms:.3f}") + print(f"sog_extra_frame_corr_ms={sog_ms - sog_nocorr_ms:.3f}") + print(f"sog_vs_dpa3_delta_ms={sog_ms - dpa_ms:.3f}") + + +if __name__ == "__main__": + main() diff --git a/examples/water/sog/input_torch.json b/examples/water/sog/input_torch.json index 5719832a7a..d14e8ff812 100644 --- a/examples/water/sog/input_torch.json +++ b/examples/water/sog/input_torch.json @@ -41,7 +41,7 @@ "type": "sog_energy", "var_name": "energy", "dim_out_sr": 1, - "dim_out_lr": 2, + "dim_out_lr": 1, "neuron_sr": [ 240, 240, @@ -58,20 +58,6 @@ "seed": 1, "_comment": "SOG-specific long-range kernel parameters", "shift": [ - 0.275, - 0.1375, - 0.0688, - 0.0344, - 0.0172, - 0.0086, - 0.0043, - 0.0021, - 0.0011, - 0.0005, - 0.0003, - 0.0001 - ], - "amplitude": [ 2.8, 5.7, 11.4, @@ -84,6 +70,20 @@ 1456.0, 2912.0, 5823.9 + ], + "amplitude": [ + 0.275, + 0.1375, + 0.0688, + 0.0344, + 0.0172, + 0.0086, + 0.0043, + 0.0021, + 0.0011, + 0.0005, + 0.0003, + 0.0001 ] } }, @@ -124,10 +124,10 @@ ], "batch_size": 1 }, - "numb_steps": 500000, + "numb_steps": 1000, "gradient_max_norm": 5.0, "seed": 10, - "disp_file": "lcurve_lr_2.out", + "disp_file": "lcurve_test_3.out", "disp_freq": 100, "save_freq": 2000 } diff --git a/examples/water/sog/profile_sog_timing.py b/examples/water/sog/profile_sog_timing.py new file mode 100644 index 0000000000..cf913780a5 --- /dev/null +++ b/examples/water/sog/profile_sog_timing.py @@ 
-0,0 +1,526 @@ +#!/usr/bin/env python3 +"""Profile SOG model runtime breakdown using input_torch.json as model config.""" + +from __future__ import annotations + +import argparse +import json +import time +import types +from collections import defaultdict +from contextlib import nullcontext +from pathlib import Path +from typing import Any + +import torch +import pytorch_finufft + +from deepmd.pt.model.model import get_model +from deepmd.pt.model.model.sog_model import SOGEnergyModel_ +from deepmd.pt.model.model.transform_output import communicate_extended_output +from deepmd.pt.utils.nlist import extend_input_and_build_neighbor_list + + +def _sync_if_cuda(device: torch.device) -> None: + if device.type == "cuda": + torch.cuda.synchronize(device) + + +def _time_block(name: str, timings: dict[str, float], device: torch.device): + class _Timer: + def __enter__(self_inner): + _sync_if_cuda(device) + self_inner.t0 = time.perf_counter() + return self_inner + + def __exit__(self_inner, exc_type, exc, tb): + _sync_if_cuda(device) + timings[name] += time.perf_counter() - self_inner.t0 + + return _Timer() + + +def _build_synthetic_input( + nframes: int, + nloc: int, + box_len: float, + device: torch.device, + dtype: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + coord = torch.rand((nframes, nloc, 3), device=device, dtype=dtype) * box_len + atype = torch.zeros((nframes, nloc), device=device, dtype=torch.long) + + # Water-like type ratio: O:H = 1:2 + atype[:, 1::3] = 1 + atype[:, 2::3] = 1 + + box = torch.zeros((nframes, 3, 3), device=device, dtype=dtype) + box[:, 0, 0] = box_len + box[:, 1, 1] = box_len + box[:, 2, 2] = box_len + return coord, atype, box + + +def _load_model(config_path: Path, device: torch.device) -> Any: + cfg = json.loads(config_path.read_text()) + model = get_model(cfg["model"]) + model = model.to(device) + model.eval() + return model + + +def _install_fine_frame_corr_profiler( + model: Any, + detail_times: dict[str, float], + 
device: torch.device, + collect_flag: dict[str, bool], +) -> tuple[Any, Any]: + orig_compute = model._compute_sog_frame_correction + orig_apply = model._apply_frame_correction_lower + + def _timed_compute(self, coord: torch.Tensor, latent_charge: torch.Tensor, box: torch.Tensor) -> torch.Tensor: + if coord.dim() != 3: + raise ValueError(f"`coord` should be [nf, nloc, 3], got shape {tuple(coord.shape)}") + if latent_charge.dim() != 3: + raise ValueError( + f"`latent_charge` should be [nf, nloc, nq], got shape {tuple(latent_charge.shape)}" + ) + if coord.shape[:2] != latent_charge.shape[:2]: + raise ValueError( + "`coord` and `latent_charge` local dimensions mismatch: " + f"{tuple(coord.shape[:2])} vs {tuple(latent_charge.shape[:2])}" + ) + + fitting = self.get_fitting_net() + runtime_device = coord.device + real_dtype = coord.dtype + complex_dtype = torch.complex128 if real_dtype == torch.float64 else torch.complex64 + with _time_block("fc_cast_inputs", detail_times, device) if collect_flag["on"] else nullcontext(): + latent_charge = latent_charge.to(device=runtime_device, dtype=real_dtype) + box = box.to(device=runtime_device, dtype=real_dtype) + if box.dim() != 3 or box.shape[-2:] != (3, 3): + raise ValueError(f"`box` should be [nf, 3, 3], got shape {tuple(box.shape)}") + + with _time_block("fc_param_prepare", detail_times, device) if collect_flag["on"] else nullcontext(): + wl, _sl, min_term = self._get_cached_sog_params( + fitting, + runtime_device, + real_dtype, + ) + remove_self_interaction = bool(fitting.remove_self_interaction) + n_dl = int(fitting.n_dl) + pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) + + nf, _nloc, _ = coord.shape + corr = torch.zeros((nf, 1), dtype=real_dtype, device=runtime_device) + + for ff in range(nf): + r_raw = coord[ff] + q = latent_charge[ff] + box_frame = box[ff] + + with _time_block("fc_geom_and_points", detail_times, device) if collect_flag["on"] else nullcontext(): + volume = torch.det(box_frame) + 
if torch.abs(volume) <= torch.finfo(real_dtype).eps: + raise ValueError("`box` is singular (near-zero volume), cannot run NUFFT.") + + cell_inv = torch.linalg.inv(box_frame) + r_frac = torch.matmul(r_raw, cell_inv) + r_frac = torch.remainder(r_frac + 0.5, 1.0) - 0.5 + point_limit = pi_tensor - 32.0 * torch.finfo(real_dtype).eps + r_in = torch.clamp( + 2.0 * pi_tensor * r_frac, + min=-point_limit, + max=point_limit, + ).contiguous() + nufft_points = r_in.transpose(0, 1).contiguous() + + with _time_block("fc_build_k_grid", detail_times, device) if collect_flag["on"] else nullcontext(): + norms = torch.norm(box_frame, dim=1) + nk = tuple(max(1, int(n.item() / n_dl)) for n in norms) + n1 = torch.arange(-nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype) + n2 = torch.arange(-nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype) + n3 = torch.arange(-nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype) + kx_grid, ky_grid, kz_grid = torch.meshgrid(n1, n2, n3, indexing="ij") + k_sq = kx_grid**2 + ky_grid**2 + kz_grid**2 + zero_mask = k_sq == 0 + + with _time_block("fc_build_kfac", detail_times, device) if collect_flag["on"] else nullcontext(): + kfac = wl.view(1, 1, 1, -1) * torch.exp(k_sq.unsqueeze(-1) * min_term) + kfac = kfac.sum(dim=-1) + kfac = kfac.to(dtype=real_dtype) + kfac[zero_mask] = 0.0 + + with _time_block("fc_prepare_charge", detail_times, device) if collect_flag["on"] else nullcontext(): + q_t = q.transpose(0, 1).contiguous() + charge = torch.complex(q_t, torch.zeros_like(q_t)).to(dtype=complex_dtype).contiguous() + + with _time_block("fc_nufft_type1", detail_times, device) if collect_flag["on"] else nullcontext(): + recon = pytorch_finufft.functional.finufft_type1( + nufft_points, + charge, + output_shape=tuple(int(x) for x in kx_grid.shape), + eps=1e-4, + isign=-1, + ) + + with _time_block("fc_nufft_type2_reduce", detail_times, device) if collect_flag["on"] else nullcontext(): + conv = kfac.unsqueeze(0) * recon + ifft_conv = 
pytorch_finufft.functional.finufft_type2( + nufft_points, + conv, + eps=1e-4, + isign=1, + ) / (2.0 * volume) + corr[ff, 0] = (charge * ifft_conv).real.sum() + + if remove_self_interaction: + with _time_block("fc_self_interaction", detail_times, device) if collect_flag["on"] else nullcontext(): + diag_sum = kfac.sum(dim=-1).sum(dim=-1).sum(dim=-1) / (2.0 * volume) + corr[ff, 0] -= torch.sum(q**2) * diag_sum + + return corr + + def _timed_apply( + self, + model_ret: dict[str, torch.Tensor], + extended_coord: torch.Tensor, + nlist: torch.Tensor, + box: torch.Tensor | None, + do_atomic_virial: bool, + ) -> dict[str, torch.Tensor]: + with _time_block("fc_guard_and_slice", detail_times, device) if collect_flag["on"] else nullcontext(): + if box is None or "latent_charge" not in model_ret: + return model_ret + + nf, nloc, _ = nlist.shape + nall = extended_coord.shape[1] + coord_local = extended_coord[:, :nloc, :] + box_local = box.view(nf, 3, 3) + latent_charge = model_ret["latent_charge"] + latent_charge_for_energy = latent_charge if self.training else latent_charge.detach() + + with _time_block("fc_compute_corr_redu", detail_times, device) if collect_flag["on"] else nullcontext(): + corr_redu = self._compute_sog_frame_correction( + coord_local, + latent_charge_for_energy, + box_local, + ) + + with _time_block("fc_add_energy", detail_times, device) if collect_flag["on"] else nullcontext(): + model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to(model_ret["energy_redu"].dtype) + + if self.do_grad_r("energy") or self.do_grad_c("energy"): + with _time_block("fc_compute_corr_redu_for_grad", detail_times, device) if collect_flag["on"] else nullcontext(): + if self.training and latent_charge.requires_grad: + corr_redu_for_grad = self._compute_sog_frame_correction( + coord_local, + latent_charge.detach(), + box_local, + ) + else: + corr_redu_for_grad = corr_redu + + with _time_block("fc_autograd_force", detail_times, device) if collect_flag["on"] else 
nullcontext(): + corr_force_local = -torch.autograd.grad( + corr_redu_for_grad.sum(), + coord_local, + create_graph=self.training, + retain_graph=False, + )[0].view(nf, nloc, 3) + + with _time_block("fc_scatter_force", detail_times, device) if collect_flag["on"] else nullcontext(): + corr_force_ext = torch.zeros( + (nf, nall, 3), + dtype=corr_force_local.dtype, + device=corr_force_local.device, + ) + corr_force_ext[:, :nloc, :] = corr_force_local + if "energy_derv_r" in model_ret: + model_ret["energy_derv_r"] = model_ret["energy_derv_r"] + corr_force_ext.unsqueeze(-2).to( + model_ret["energy_derv_r"].dtype + ) + + if self.do_grad_c("energy"): + with _time_block("fc_virial_update", detail_times, device) if collect_flag["on"] else nullcontext(): + corr_virial_local = torch.einsum( + "fai,faj->faij", + corr_force_local, + coord_local, + ).reshape(nf, nloc, 1, 9) + corr_virial_redu = corr_virial_local.sum(dim=1) + if "energy_derv_c_redu" in model_ret: + model_ret["energy_derv_c_redu"] = model_ret["energy_derv_c_redu"] + corr_virial_redu.to( + model_ret["energy_derv_c_redu"].dtype + ) + if do_atomic_virial and "energy_derv_c" in model_ret: + corr_atom_virial = torch.zeros( + (nf, nall, 1, 9), + dtype=corr_virial_local.dtype, + device=corr_virial_local.device, + ) + corr_atom_virial[:, :nloc, :, :] = corr_virial_local + model_ret["energy_derv_c"] = model_ret["energy_derv_c"] + corr_atom_virial.to( + model_ret["energy_derv_c"].dtype + ) + + return model_ret + + model._compute_sog_frame_correction = types.MethodType(_timed_compute, model) + model._apply_frame_correction_lower = types.MethodType(_timed_apply, model) + return orig_compute, orig_apply + + +def profile( + model: Any, + nframes: int, + nloc: int, + box_len: float, + repeats: int, + warmup: int, + do_atomic_virial: bool, + dtype: torch.dtype, + fine_frame_profile: bool = False, +) -> dict[str, float]: + device = next(model.parameters()).device + + # NUFFT fine-grained timers by monkeypatching function calls. 
+ nufft_times: dict[str, float] = defaultdict(float) + collect_nufft = False + orig_type1 = pytorch_finufft.functional.finufft_type1 + orig_type2 = pytorch_finufft.functional.finufft_type2 + + def timed_type1(*args, **kwargs): + if collect_nufft: + with _time_block("nufft_type1", nufft_times, device): + return orig_type1(*args, **kwargs) + return orig_type1(*args, **kwargs) + + def timed_type2(*args, **kwargs): + if collect_nufft: + with _time_block("nufft_type2", nufft_times, device): + return orig_type2(*args, **kwargs) + return orig_type2(*args, **kwargs) + + pytorch_finufft.functional.finufft_type1 = timed_type1 + pytorch_finufft.functional.finufft_type2 = timed_type2 + + timings: dict[str, float] = defaultdict(float) + detail_times: dict[str, float] = defaultdict(float) + collect_detail = {"on": False} + orig_compute = None + orig_apply = None + + if fine_frame_profile: + orig_compute, orig_apply = _install_fine_frame_corr_profiler( + model, + detail_times, + device, + collect_detail, + ) + + try: + for _ in range(warmup + repeats): + coord, atype, box = _build_synthetic_input( + nframes=nframes, + nloc=nloc, + box_len=box_len, + device=device, + dtype=dtype, + ) + + is_warmup = _ < warmup + collect_nufft = not is_warmup + collect_detail["on"] = not is_warmup + iter_times: dict[str, float] = defaultdict(float) + + with _time_block("input_cast", iter_times, device): + cc, bb, fp, ap, input_prec = model._input_type_cast( + coord, box=box, fparam=None, aparam=None + ) + + with _time_block("build_nlist", iter_times, device): + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + cc, + atype, + model.get_rcut(), + model.get_sel(), + mixed_types=True, + box=bb, + ) + + comm_dict: dict[str, torch.Tensor] | None = {"box": bb} + + if model.do_grad_r("energy") or model.do_grad_c("energy"): + extended_coord = extended_coord.requires_grad_(True) + + with _time_block("lower_super", iter_times, device): + model_ret = 
SOGEnergyModel_.forward_common_lower( + model, + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fp, + aparam=ap, + do_atomic_virial=do_atomic_virial, + comm_dict=comm_dict, + extra_nlist_sort=False, + extended_coord_corr=None, + ) + + with _time_block("lower_frame_corr", iter_times, device): + model_ret = model._apply_frame_correction_lower( + model_ret, + extended_coord, + nlist, + bb, + do_atomic_virial, + ) + + with _time_block("communicate_output", iter_times, device): + model_ret = communicate_extended_output( + model_ret, + model.model_output_def(), + mapping, + do_atomic_virial=do_atomic_virial, + ) + + with _time_block("output_cast", iter_times, device): + model_ret = model._output_type_cast(model_ret, input_prec) + _ = model_ret["energy_redu"] + + if not is_warmup: + iter_times["total"] = sum( + iter_times[k] + for k in [ + "input_cast", + "build_nlist", + "lower_super", + "lower_frame_corr", + "communicate_output", + "output_cast", + ] + ) + for k, v in iter_times.items(): + timings[k] += v + + # Only keep averaged timings for measured iterations. 
+ for k in list(timings.keys()): + timings[k] /= repeats + for k, v in nufft_times.items(): + timings[k] = v / repeats + for k, v in detail_times.items(): + timings[k] = v / repeats + + return dict(timings) + finally: + pytorch_finufft.functional.finufft_type1 = orig_type1 + pytorch_finufft.functional.finufft_type2 = orig_type2 + if fine_frame_profile and orig_compute is not None and orig_apply is not None: + model._compute_sog_frame_correction = orig_compute + model._apply_frame_correction_lower = orig_apply + + +def _format_report(timings: dict[str, float]) -> str: + total = sum( + timings.get(k, 0.0) + for k in [ + "input_cast", + "build_nlist", + "lower_super", + "lower_frame_corr", + "communicate_output", + "output_cast", + ] + ) + + keys = [ + "input_cast", + "build_nlist", + "lower_super", + "lower_frame_corr", + "nufft_type1", + "nufft_type2", + "communicate_output", + "output_cast", + "fc_guard_and_slice", + "fc_compute_corr_redu", + "fc_cast_inputs", + "fc_param_prepare", + "fc_geom_and_points", + "fc_build_k_grid", + "fc_build_kfac", + "fc_prepare_charge", + "fc_nufft_type1", + "fc_nufft_type2_reduce", + "fc_self_interaction", + "fc_add_energy", + "fc_compute_corr_redu_for_grad", + "fc_autograd_force", + "fc_scatter_force", + "fc_virial_update", + ] + + lines = [] + lines.append("Timing breakdown (avg per iteration):") + for k in keys: + if k in timings: + ms = timings[k] * 1000.0 + ratio = (timings[k] / total * 100.0) if total > 0 else 0.0 + lines.append(f" - {k:20s}: {ms:10.3f} ms ({ratio:6.2f}%)") + + lines.append(f" - {'total(sum)':20s}: {total*1000.0:10.3f} ms (100.00%)") + return "\n".join(lines) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--config", + type=Path, + default=Path("examples/water/sog/input_torch.json"), + ) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float64"]) + 
parser.add_argument("--nframes", type=int, default=1) + parser.add_argument("--nloc", type=int, default=192) + parser.add_argument("--box-len", type=float, default=20.0) + parser.add_argument("--warmup", type=int, default=2) + parser.add_argument("--repeats", type=int, default=10) + parser.add_argument("--atomic-virial", action="store_true") + parser.add_argument("--n-dl-override", type=int, default=2) + parser.add_argument("--disable-energy-grad", action="store_true") + parser.add_argument("--fine-frame-profile", action="store_true") + args = parser.parse_args() + + device = torch.device(args.device if torch.cuda.is_available() else "cpu") + dtype = torch.float64 if args.dtype == "float64" else torch.float32 + + model = _load_model(args.config, device) + if args.n_dl_override > 0 and hasattr(model.get_fitting_net(), "n_dl"): + model.get_fitting_net().n_dl = int(args.n_dl_override) + if args.disable_energy_grad: + model.do_grad_r = types.MethodType(lambda self, _name: False, model) + model.do_grad_c = types.MethodType(lambda self, _name: False, model) + timings = profile( + model=model, + nframes=args.nframes, + nloc=args.nloc, + box_len=args.box_len, + repeats=args.repeats, + warmup=args.warmup, + do_atomic_virial=args.atomic_virial, + dtype=dtype, + fine_frame_profile=args.fine_frame_profile, + ) + print(_format_report(timings)) + + +if __name__ == "__main__": + main() diff --git a/examples/water/sog/profile_sog_whatif.py b/examples/water/sog/profile_sog_whatif.py new file mode 100644 index 0000000000..5132e5f180 --- /dev/null +++ b/examples/water/sog/profile_sog_whatif.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python3 +from __future__ import annotations + +import copy +import json +from pathlib import Path + +import torch + +from deepmd.pt.model.model import get_model +from profile_sog_timing import profile + + +CFG_PATH = Path("examples/water/sog/input_torch.json") + + +def run(tag: str, model_cfg: dict) -> None: + device = torch.device("cuda" if 
torch.cuda.is_available() else "cpu") + model = get_model(model_cfg).to(device) + model.eval() + t = profile( + model, + nframes=1, + nloc=192, + box_len=20.0, + repeats=6, + warmup=2, + do_atomic_virial=False, + dtype=torch.float32, + ) + total = ( + t["input_cast"] + + t["build_nlist"] + + t["lower_super"] + + t["lower_frame_corr"] + + t["communicate_output"] + + t["output_cast"] + ) * 1000.0 + print( + f"{tag}: total={total:.3f}ms, " + f"lower_super={t['lower_super']*1000.0:.3f}ms, " + f"lower_frame_corr={t['lower_frame_corr']*1000.0:.3f}ms, " + f"nufft1={t.get('nufft_type1', 0.0)*1000.0:.3f}ms, " + f"nufft2={t.get('nufft_type2', 0.0)*1000.0:.3f}ms" + ) + + +def main() -> None: + cfg = json.loads(CFG_PATH.read_text()) + base_model_cfg = cfg["model"] + + run("baseline", copy.deepcopy(base_model_cfg)) + + cfg_lr1 = copy.deepcopy(base_model_cfg) + cfg_lr1["fitting_net"]["dim_out_lr"] = 1 + run("dim_out_lr=1", cfg_lr1) + + cfg_small = copy.deepcopy(base_model_cfg) + cfg_small["fitting_net"]["neuron_sr"] = [128, 128, 128] + cfg_small["fitting_net"]["neuron_lr"] = [128, 128, 128] + run("neurons=128", cfg_small) + + cfg_both = copy.deepcopy(cfg_lr1) + cfg_both["fitting_net"]["neuron_sr"] = [128, 128, 128] + cfg_both["fitting_net"]["neuron_lr"] = [128, 128, 128] + run("dim_out_lr=1 + neurons=128", cfg_both) + + +if __name__ == "__main__": + main() diff --git a/source/tests/pt/model/test_les_working_layer.py b/source/tests/pt/model/test_les_working_layer.py new file mode 100644 index 0000000000..02d0cd7578 --- /dev/null +++ b/source/tests/pt/model/test_les_working_layer.py @@ -0,0 +1,201 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import torch + +try: + import pytorch_finufft # noqa: F401 + + HAS_FINUFFT = True +except Exception: + HAS_FINUFFT = False + +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.model.les_model import ( + LESEnergyModel, +) +from deepmd.pt.model.task.les_energy_fitting import ( + 
LESEnergyFittingNet, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +def _reduce_extended_tensor(extended_tensor: torch.Tensor, mapping: torch.Tensor, nloc: int): + nframes = extended_tensor.shape[0] + ext_dims = extended_tensor.shape[2:] + reduced_tensor = torch.zeros( + [nframes, nloc, *ext_dims], + dtype=extended_tensor.dtype, + device=extended_tensor.device, + ) + mldims = list(mapping.shape) + mapping_exp = mapping.view(mldims + [1] * len(ext_dims)).expand( + [-1] * len(mldims) + list(ext_dims) + ) + reduced_tensor = torch.scatter_reduce( + reduced_tensor, + 1, + index=mapping_exp, + src=extended_tensor, + reduce="sum", + ) + return reduced_tensor + + +@unittest.skipIf(not HAS_FINUFFT, "pytorch_finufft is required for LES tests") +class TestLESWorkingLayer(unittest.TestCase): + def setUp(self) -> None: + torch.manual_seed(2027) + self.nf = 2 + self.nloc = 4 + self.nt = 2 + self.rcut = 4.0 + self.rcut_smth = 3.5 + self.sel = [8, 8] + + self.descriptor = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + self.fitting = LESEnergyFittingNet( + var_name="energy", + ntypes=self.nt, + dim_descrpt=self.descriptor.get_dim_out(), + dim_out_sr=1, + dim_out_lr=1, + mixed_types=self.descriptor.mixed_types(), + n_dl=2, + ).to(env.DEVICE) + self.model = LESEnergyModel( + descriptor=self.descriptor, + fitting=self.fitting, + type_map=["A", "B"], + ).to(env.DEVICE) + + coord = torch.rand( + (self.nf, self.nloc, 3), + dtype=dtype, + device=env.DEVICE, + ) + cell = torch.eye(3, dtype=dtype, device=env.DEVICE).unsqueeze(0).repeat(self.nf, 1, 1) + cell = cell * 5.0 + self.coord = coord.reshape(self.nf, self.nloc * 3) + self.cell = cell.reshape(self.nf, 9) + self.atype = torch.tensor( + [[0, 0, 1, 1], [1, 0, 1, 0]], + dtype=torch.int64, + device=env.DEVICE, + ) + + def test_frame_correction_applies_once_per_frame(self) -> None: + 
coord3 = self.coord.view(self.nf, self.nloc, 3) + cell33 = self.cell.view(self.nf, 3, 3) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord3, + self.atype, + self.model.get_rcut(), + self.model.get_sel(), + mixed_types=True, + box=cell33, + ) + + lower_ret = self.model.forward_common_lower( + extended_coord=extended_coord, + extended_atype=extended_atype, + nlist=nlist, + mapping=mapping, + do_atomic_virial=False, + comm_dict={"box": cell33}, + ) + + frame_corr = self.model._compute_les_frame_correction( + extended_coord[:, : self.nloc, :], + lower_ret["latent_charge"], + cell33, + ).to(lower_ret["energy_redu"].dtype) + expected_energy_redu = lower_ret["energy"].sum(dim=1) + frame_corr + + torch.testing.assert_close( + lower_ret["energy_redu"], + expected_energy_redu, + rtol=1e-8, + atol=1e-8, + ) + + def test_forward_and_forward_lower_consistency(self) -> None: + fw = self.model.forward( + self.coord, + self.atype, + box=self.cell, + do_atomic_virial=True, + ) + + coord3 = self.coord.view(self.nf, self.nloc, 3) + cell33 = self.cell.view(self.nf, 3, 3) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord3, + self.atype, + self.model.get_rcut(), + self.model.get_sel(), + mixed_types=True, + box=cell33, + ) + + fw_lower = self.model.forward_lower( + extended_coord=extended_coord, + extended_atype=extended_atype, + nlist=nlist, + mapping=mapping, + do_atomic_virial=True, + comm_dict={"box": cell33}, + ) + + torch.testing.assert_close(fw_lower["energy"], fw["energy"], rtol=1e-8, atol=1e-8) + torch.testing.assert_close(fw_lower["virial"], fw["virial"], rtol=1e-7, atol=1e-7) + + reduced_force = _reduce_extended_tensor(fw_lower["extended_force"], mapping, self.nloc) + torch.testing.assert_close(reduced_force, fw["force"], rtol=1e-7, atol=1e-7) + + def test_long_range_params_have_gradient_path(self) -> None: + self.model.zero_grad(set_to_none=True) + out 
= self.model.forward( + self.coord, + self.atype, + box=self.cell, + do_atomic_virial=False, + ) + grads = torch.autograd.grad( + out["energy"].sum(), + [self.fitting.wl, self.fitting.sl], + retain_graph=False, + create_graph=False, + allow_unused=False, + ) + for gg in grads: + self.assertIsNotNone(gg) + self.assertTrue(torch.isfinite(gg).all().item()) + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt/model/test_sog_working_layer.py b/source/tests/pt/model/test_sog_working_layer.py new file mode 100644 index 0000000000..924b02bfa6 --- /dev/null +++ b/source/tests/pt/model/test_sog_working_layer.py @@ -0,0 +1,182 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import torch + +try: + import pytorch_finufft # noqa: F401 + + HAS_FINUFFT = True +except Exception: + HAS_FINUFFT = False + +from deepmd.pt.model.descriptor.se_a import ( + DescrptSeA, +) +from deepmd.pt.model.model.sog_model import ( + SOGEnergyModel, +) +from deepmd.pt.model.task.sog_energy_fitting import ( + SOGEnergyFittingNet, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) + + +dtype = env.GLOBAL_PT_FLOAT_PRECISION + + +def _reduce_extended_tensor(extended_tensor: torch.Tensor, mapping: torch.Tensor, nloc: int): + nframes = extended_tensor.shape[0] + ext_dims = extended_tensor.shape[2:] + reduced_tensor = torch.zeros( + [nframes, nloc, *ext_dims], + dtype=extended_tensor.dtype, + device=extended_tensor.device, + ) + mldims = list(mapping.shape) + mapping_exp = mapping.view(mldims + [1] * len(ext_dims)).expand( + [-1] * len(mldims) + list(ext_dims) + ) + reduced_tensor = torch.scatter_reduce( + reduced_tensor, + 1, + index=mapping_exp, + src=extended_tensor, + reduce="sum", + ) + return reduced_tensor + + +@unittest.skipIf(not HAS_FINUFFT, "pytorch_finufft is required for SOG tests") +class TestSOGWorkingLayer(unittest.TestCase): + def setUp(self) -> None: + 
torch.manual_seed(2026) + self.nf = 2 + self.nloc = 4 + self.nt = 2 + self.rcut = 4.0 + self.rcut_smth = 3.5 + self.sel = [8, 8] + + self.descriptor = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(env.DEVICE) + self.fitting = SOGEnergyFittingNet( + var_name="energy", + ntypes=self.nt, + dim_descrpt=self.descriptor.get_dim_out(), + dim_out_sr=1, + dim_out_lr=1, + mixed_types=self.descriptor.mixed_types(), + n_dl=2, + ).to(env.DEVICE) + self.model = SOGEnergyModel( + descriptor=self.descriptor, + fitting=self.fitting, + type_map=["A", "B"], + ).to(env.DEVICE) + + coord = torch.rand( + (self.nf, self.nloc, 3), + dtype=dtype, + device=env.DEVICE, + ) + cell = torch.eye(3, dtype=dtype, device=env.DEVICE).unsqueeze(0).repeat(self.nf, 1, 1) + cell = cell * 5.0 + self.coord = coord.reshape(self.nf, self.nloc * 3) + self.cell = cell.reshape(self.nf, 9) + self.atype = torch.tensor( + [[0, 0, 1, 1], [1, 0, 1, 0]], + dtype=torch.int64, + device=env.DEVICE, + ) + + def test_frame_correction_applies_once_per_frame(self) -> None: + coord3 = self.coord.view(self.nf, self.nloc, 3) + cell33 = self.cell.view(self.nf, 3, 3) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord3, + self.atype, + self.model.get_rcut(), + self.model.get_sel(), + mixed_types=True, + box=cell33, + ) + + lower_ret = self.model.forward_common_lower( + extended_coord=extended_coord, + extended_atype=extended_atype, + nlist=nlist, + mapping=mapping, + do_atomic_virial=False, + comm_dict={"box": cell33}, + ) + + frame_corr = self.model._compute_sog_frame_correction( + extended_coord[:, : self.nloc, :], + lower_ret["latent_charge"], + cell33, + ).to(lower_ret["energy_redu"].dtype) + expected_energy_redu = lower_ret["energy"].sum(dim=1) + frame_corr + + torch.testing.assert_close( + lower_ret["energy_redu"], + expected_energy_redu, + rtol=1e-8, + atol=1e-8, + ) + + def test_forward_and_forward_lower_consistency(self) -> None: + fw = 
self.model.forward( + self.coord, + self.atype, + box=self.cell, + do_atomic_virial=True, + ) + + coord3 = self.coord.view(self.nf, self.nloc, 3) + cell33 = self.cell.view(self.nf, 3, 3) + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord3, + self.atype, + self.model.get_rcut(), + self.model.get_sel(), + mixed_types=True, + box=cell33, + ) + + fw_lower = self.model.forward_lower( + extended_coord=extended_coord, + extended_atype=extended_atype, + nlist=nlist, + mapping=mapping, + do_atomic_virial=True, + comm_dict={"box": cell33}, + ) + + torch.testing.assert_close(fw_lower["energy"], fw["energy"], rtol=1e-8, atol=1e-8) + torch.testing.assert_close(fw_lower["virial"], fw["virial"], rtol=1e-7, atol=1e-7) + + reduced_force = _reduce_extended_tensor(fw_lower["extended_force"], mapping, self.nloc) + torch.testing.assert_close(reduced_force, fw["force"], rtol=1e-7, atol=1e-7) + + +if __name__ == "__main__": + unittest.main() From 6e8d01927e24d02c559410819273f580d37aff6d Mon Sep 17 00:00:00 2001 From: zhenyuan5090 Date: Thu, 2 Apr 2026 17:15:20 +0800 Subject: [PATCH 6/8] baseline-0.0.3 1. implementation of SOG-Net and LES, based on the local charge setting 2. direct calculation of force and virial, instead of autograd of energy todo: 1. improve of computational efficiency of kernel function(mainly at NUFFT) 2. more experiments to verify the correctness and effectiveness of the implementation 3. 
implementation on the LAMMPS end --- deepmd/pt/model/model/les_model.py | 135 ++++++++--- deepmd/pt/model/model/sog_model.py | 259 +++++++-------------- deepmd/pt/model/task/les_energy_fitting.py | 8 +- deepmd/pt/model/task/lr_fitting.py | 2 +- deepmd/pt/model/task/sog_energy_fitting.py | 9 +- deepmd/utils/argcheck.py | 36 +++ examples/water/sog/input_torch.json | 12 +- examples/water/sog/profile_sog_timing.py | 141 +++++++---- 8 files changed, 329 insertions(+), 273 deletions(-) diff --git a/deepmd/pt/model/model/les_model.py b/deepmd/pt/model/model/les_model.py index 2c5da94165..f0b08c0ca4 100644 --- a/deepmd/pt/model/model/les_model.py +++ b/deepmd/pt/model/model/les_model.py @@ -99,12 +99,15 @@ def translated_output_def(self) -> dict[str, Any]: output_def["hessian"] = out_def_data["energy_derv_r_derv_r"] return output_def - def _compute_les_frame_correction( + def _compute_les_frame_correction_bundle( self, coord: torch.Tensor, latent_charge: torch.Tensor, box: torch.Tensor, - ) -> torch.Tensor: + *, + need_force: bool, + need_virial: bool, + ) -> dict[str, torch.Tensor]: fitting = self.get_fitting_net() runtime_device = coord.device real_dtype = coord.dtype @@ -123,9 +126,21 @@ def _compute_les_frame_correction( sigma = torch.clamp(sigma, min=torch.finfo(real_dtype).eps) remove_self_interaction = bool(fitting.remove_self_interaction) n_dl = int(fitting.n_dl) + pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) + two_pi = torch.tensor(2.0 * torch.pi, dtype=real_dtype, device=runtime_device) - nf, _, _ = coord.shape + nf, nloc, _ = coord.shape corr = torch.zeros((nf, 1), dtype=real_dtype, device=runtime_device) + force_local = ( + torch.zeros((nf, nloc, 3), dtype=real_dtype, device=runtime_device) + if need_force + else None + ) + virial_local = ( + torch.zeros((nf, nloc, 1, 9), dtype=real_dtype, device=runtime_device) + if need_virial + else None + ) for ff in range(nf): r_raw = coord[ff] @@ -139,7 +154,6 @@ def 
_compute_les_frame_correction( cell_inv = torch.linalg.inv(box_frame) r_frac = torch.matmul(r_raw, cell_inv) r_frac = torch.remainder(r_frac + 0.5, 1.0) - 0.5 - pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) point_limit = pi_tensor - 32.0 * torch.finfo(real_dtype).eps r_in = torch.clamp( 2.0 * pi_tensor * r_frac, @@ -149,7 +163,7 @@ def _compute_les_frame_correction( nufft_points = r_in.transpose(0, 1).contiguous() norms = torch.norm(box_frame, dim=1) - nk = [max(1, int(n.item() / n_dl)) for n in norms] + nk = tuple(max(1, int(n.item() / n_dl)) for n in norms) n1 = torch.arange(-nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype) n2 = torch.arange(-nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype) n3 = torch.arange(-nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype) @@ -164,28 +178,73 @@ def _compute_les_frame_correction( kfac[zero_mask] = 0.0 q_t = q.transpose(0, 1).contiguous() - charge = torch.complex(q_t, torch.zeros_like(q_t)).to(dtype=complex_dtype) + charge = torch.complex(q_t, torch.zeros_like(q_t)).to(dtype=complex_dtype).contiguous() + output_shape = tuple(int(x) for x in kx_grid.shape) recon = pytorch_finufft.functional.finufft_type1( nufft_points, charge, - output_shape=tuple(int(x) for x in kx_grid.shape), + output_shape=output_shape, eps=1e-4, isign=-1, ) - conv = kfac.unsqueeze(0) * recon - ifft_conv = pytorch_finufft.functional.finufft_type2( - nufft_points, - conv, - eps=1e-4, - isign=1, - ) / (2.0 * volume) - corr[ff, 0] = (charge * ifft_conv).real.sum() + + rho_sq = recon.real.square() + recon.imag.square() + corr[ff, 0] = (kfac.unsqueeze(0) * rho_sq).sum() / (2.0 * volume) + + conv = None + if need_force: + conv = kfac.unsqueeze(0).to(dtype=complex_dtype) * recon + + if need_force: + assert conv is not None + kk1 = torch.fft.ifftshift(kx_grid, dim=0) + kk2 = torch.fft.ifftshift(ky_grid, dim=1) + kk3 = torch.fft.ifftshift(kz_grid, dim=2) + k_grid = torch.stack((kk1, kk2, kk3), dim=0) + g_cart = 
two_pi * torch.einsum("ik,k...->i...", cell_inv, k_grid) + grad_conv = (1j * g_cart.unsqueeze(1).to(dtype=complex_dtype)) * conv.unsqueeze(0) + grad_field = pytorch_finufft.functional.finufft_type2( + nufft_points, + grad_conv, + eps=1e-4, + isign=1, + ) + force_frame = -(q_t.unsqueeze(0) * grad_field.real.to(dtype=real_dtype)).sum(dim=1).transpose(0, 1) + force_frame = force_frame / volume + force_local[ff] = force_frame + + if need_virial: + virial_local[ff] = torch.einsum( + "ai,aj->aij", + force_frame, + r_raw, + ).reshape(nloc, 1, 9) if remove_self_interaction: diag_sum = kfac.sum(dim=-1).sum(dim=-1).sum(dim=-1) / (2.0 * volume) corr[ff, 0] -= torch.sum(q**2) * diag_sum - return corr + out: dict[str, torch.Tensor] = {"corr_redu": corr} + if force_local is not None: + out["force_local"] = force_local + if virial_local is not None: + out["virial_local"] = virial_local + return out + + def _compute_les_frame_correction( + self, + coord: torch.Tensor, + latent_charge: torch.Tensor, + box: torch.Tensor, + ) -> torch.Tensor: + out = self._compute_les_frame_correction_bundle( + coord, + latent_charge, + box, + need_force=False, + need_virial=False, + ) + return out["corr_redu"] def _apply_frame_correction_lower( self, @@ -203,29 +262,37 @@ def _apply_frame_correction_lower( coord_local = extended_coord[:, :nloc, :] box_local = box.view(nf, 3, 3) latent_charge = model_ret["latent_charge"] - corr_redu = self._compute_les_frame_correction(coord_local, latent_charge, box_local) + need_force = self.do_grad_r("energy") or self.do_grad_c("energy") + need_virial = self.do_grad_c("energy") + latent_charge_runtime = latent_charge if self.training else latent_charge.detach() + corr_bundle = self._compute_les_frame_correction_bundle( + coord_local, + latent_charge_runtime, + box_local, + need_force=need_force, + need_virial=need_virial, + ) + corr_redu = corr_bundle["corr_redu"] model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to(model_ret["energy_redu"].dtype) - 
if self.do_grad_r("energy") or self.do_grad_c("energy"): - corr_force_ext = -torch.autograd.grad( - corr_redu.sum(), - extended_coord, - create_graph=self.training, - retain_graph=True, - )[0].view(nf, nall, 3) + if need_force: + corr_force_local = corr_bundle["force_local"].to(coord_local.dtype) + + corr_force_ext = torch.zeros( + (nf, nall, 3), + dtype=corr_force_local.dtype, + device=corr_force_local.device, + ) + corr_force_ext[:, :nloc, :] = corr_force_local if "energy_derv_r" in model_ret: model_ret["energy_derv_r"] = model_ret["energy_derv_r"] + corr_force_ext.unsqueeze(-2).to( model_ret["energy_derv_r"].dtype ) - if self.do_grad_c("energy"): - corr_virial_ext = torch.einsum( - "fai,faj->faij", - corr_force_ext, - extended_coord, - ).reshape(nf, nall, 1, 9) - corr_virial_redu = corr_virial_ext.sum(dim=1) + if need_virial: + corr_virial_local = corr_bundle["virial_local"].to(corr_force_local.dtype) + corr_virial_redu = corr_virial_local.sum(dim=1) if "energy_derv_c_redu" in model_ret: model_ret["energy_derv_c_redu"] = model_ret["energy_derv_c_redu"] + corr_virial_redu.to( model_ret["energy_derv_c_redu"].dtype @@ -233,10 +300,10 @@ def _apply_frame_correction_lower( if do_atomic_virial and "energy_derv_c" in model_ret: corr_atom_virial = torch.zeros( (nf, nall, 1, 9), - dtype=corr_virial_ext.dtype, - device=corr_virial_ext.device, + dtype=corr_virial_local.dtype, + device=corr_virial_local.device, ) - corr_atom_virial[:, :nloc, :, :] = corr_virial_ext[:, :nloc, :, :] + corr_atom_virial[:, :nloc, :, :] = corr_virial_local model_ret["energy_derv_c"] = model_ret["energy_derv_c"] + corr_atom_virial.to( model_ret["energy_derv_c"].dtype ) diff --git a/deepmd/pt/model/model/sog_model.py b/deepmd/pt/model/model/sog_model.py index aecae0742c..5ef6f96bed 100644 --- a/deepmd/pt/model/model/sog_model.py +++ b/deepmd/pt/model/model/sog_model.py @@ -147,13 +147,15 @@ def translated_output_def(self) -> dict[str, Any]: output_def["hessian"] = 
out_def_data["energy_derv_r_derv_r"] return output_def - def _compute_sog_frame_correction( + def _compute_sog_frame_correction_bundle( self, coord: torch.Tensor, latent_charge: torch.Tensor, box: torch.Tensor, - return_kspace_info: bool = False, - ) -> torch.Tensor | tuple[torch.Tensor, list[dict[str, torch.Tensor]]]: + *, + need_force: bool, + need_virial: bool, + ) -> dict[str, torch.Tensor]: if coord.dim() != 3: raise ValueError(f"`coord` should be [nf, nloc, 3], got shape {tuple(coord.shape)}") if latent_charge.dim() != 3: @@ -183,10 +185,20 @@ def _compute_sog_frame_correction( remove_self_interaction = bool(fitting.remove_self_interaction) n_dl = int(fitting.n_dl) pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) + two_pi = torch.tensor(2.0 * torch.pi, dtype=real_dtype, device=runtime_device) nf, nloc, _ = coord.shape corr = torch.zeros((nf, 1), dtype=real_dtype, device=runtime_device) - kspace_info: list[dict[str, torch.Tensor]] = [] + force_local = ( + torch.zeros((nf, nloc, 3), dtype=real_dtype, device=runtime_device) + if need_force + else None + ) + virial_local = ( + torch.zeros((nf, nloc, 1, 9), dtype=real_dtype, device=runtime_device) + if need_virial + else None + ) for ff in range(nf): r_raw = coord[ff] @@ -232,137 +244,66 @@ def _compute_sog_frame_correction( eps=1e-4, isign=-1, ) - conv = kfac.unsqueeze(0) * recon - ifft_conv = pytorch_finufft.functional.finufft_type2( - nufft_points, - conv, - eps=1e-4, - isign=1, - ) / (2.0 * volume) - corr[ff, 0] = (charge * ifft_conv).real.sum() - - if return_kspace_info: - kspace_info.append( - { - "k_sq": k_sq, - "kfac": kfac, - "nufft_points": nufft_points, - "charge": charge, - "recon": recon, - "ifft_conv": ifft_conv, - "volume": volume, - } + + rho_sq = recon.real.square() + recon.imag.square() + corr[ff, 0] = (kfac.unsqueeze(0) * rho_sq).sum() / (2.0 * volume) + + conv = None + if need_force: + conv = kfac.unsqueeze(0).to(dtype=complex_dtype) * recon + + if need_force: + 
assert conv is not None + # Reuse the already built k-grid and only reorder it to the FFT + # storage order required by type-2 inputs. + kk1 = torch.fft.ifftshift(kx_grid, dim=0) + kk2 = torch.fft.ifftshift(ky_grid, dim=1) + kk3 = torch.fft.ifftshift(kz_grid, dim=2) + k_grid = torch.stack((kk1, kk2, kk3), dim=0) + g_cart = two_pi * torch.einsum("ik,k...->i...", cell_inv, k_grid) + grad_conv = (1j * g_cart.unsqueeze(1).to(dtype=complex_dtype)) * conv.unsqueeze(0) + grad_field = pytorch_finufft.functional.finufft_type2( + nufft_points, + grad_conv, + eps=1e-4, + isign=1, ) + force_frame = -(q_t.unsqueeze(0) * grad_field.real.to(dtype=real_dtype)).sum(dim=1).transpose(0, 1) + force_frame = force_frame / volume + force_local[ff] = force_frame + + if need_virial: + virial_local[ff] = torch.einsum( + "ai,aj->aij", + force_frame, + r_raw, + ).reshape(nloc, 1, 9) if remove_self_interaction: diag_sum = kfac.sum(dim=-1).sum(dim=-1).sum(dim=-1) / (2.0 * volume) corr[ff, 0] -= torch.sum(q**2) * diag_sum - if return_kspace_info: - return corr, kspace_info - return corr - - def analytic_sog_needs_kspace(self) -> bool: - """Whether analytic derivative hook requires k-space intermediates.""" - return False + out: dict[str, torch.Tensor] = {"corr_redu": corr} + if force_local is not None: + out["force_local"] = force_local + if virial_local is not None: + out["virial_local"] = virial_local + return out - def compute_sog_correction_derivatives( + def _compute_sog_frame_correction( self, coord: torch.Tensor, latent_charge: torch.Tensor, box: torch.Tensor, - energy_correction: torch.Tensor, - do_atomic_virial: bool, - kspace_info: Optional[list[dict[str, torch.Tensor]]] = None, - ) -> Optional[dict[str, torch.Tensor]]: - """Optional model-layer analytic derivatives hook. - - Override this in model subclasses if analytic force/virial is available. 
- """ - return None - - def _try_analytic_frame_correction_derivatives( - self, - coord_local: torch.Tensor, - latent_charge: torch.Tensor, - box_local: torch.Tensor, - corr_redu: torch.Tensor, - do_atomic_virial: bool, - kspace_info: Optional[list[dict[str, torch.Tensor]]] = None, - ) -> Optional[dict[str, torch.Tensor]]: - """Try to fetch analytic correction derivatives from fitting net. - - Contract for fitting-net hook (optional): - `compute_sog_correction_derivatives(coord, latent_charge, box, energy_correction, do_atomic_virial)` - - Returns a dict with: - - `force_local`: required, shape [nf, nloc, 3] - - `virial_local`: optional, shape [nf, nloc, 1, 9] - """ - out = self.compute_sog_correction_derivatives( - coord=coord_local, - latent_charge=latent_charge, - box=box_local, - energy_correction=corr_redu, - do_atomic_virial=do_atomic_virial, - kspace_info=kspace_info, + ) -> torch.Tensor: + out = self._compute_sog_frame_correction_bundle( + coord, + latent_charge, + box, + need_force=False, + need_virial=False, ) - if out is None: - # Backward compatibility: allow fitting-layer hook if present. - fitting = self.get_fitting_net() - hook = getattr(fitting, "compute_sog_correction_derivatives", None) - if hook is not None: - out = hook( - coord=coord_local, - latent_charge=latent_charge, - box=box_local, - energy_correction=corr_redu, - do_atomic_virial=do_atomic_virial, - ) - if out is None: - return None - if not isinstance(out, dict): - raise TypeError( - "`compute_sog_correction_derivatives` should return dict[str, torch.Tensor] or None." - ) - if "force_local" not in out: - raise KeyError( - "`compute_sog_correction_derivatives` must provide `force_local`." 
- ) - - force_local = out["force_local"] - expected_force_shape = coord_local.shape - if force_local.shape != expected_force_shape: - raise ValueError( - "`force_local` shape mismatch: " - f"expected {tuple(expected_force_shape)}, got {tuple(force_local.shape)}" - ) - if force_local.device != coord_local.device: - raise ValueError( - "`force_local` device mismatch: " - f"expected {coord_local.device}, got {force_local.device}" - ) - - if "virial_local" in out: - virial_local = out["virial_local"] - expected_virial_shape = ( - coord_local.shape[0], - coord_local.shape[1], - 1, - 9, - ) - if virial_local.shape != expected_virial_shape: - raise ValueError( - "`virial_local` shape mismatch: " - f"expected {tuple(expected_virial_shape)}, got {tuple(virial_local.shape)}" - ) - if virial_local.device != coord_local.device: - raise ValueError( - "`virial_local` device mismatch: " - f"expected {coord_local.device}, got {virial_local.device}" - ) - - return out + return out["corr_redu"] def _apply_frame_correction_lower( self, @@ -380,57 +321,22 @@ def _apply_frame_correction_lower( coord_local = extended_coord[:, :nloc, :] box_local = box.view(nf, 3, 3) latent_charge = model_ret["latent_charge"] - # Keep latent_charge on the computational graph for both training and eval - # so SOG correction gradients can always propagate through the LR branch. 
- latent_charge_for_energy = latent_charge - kspace_info: Optional[list[dict[str, torch.Tensor]]] = None - if self.analytic_sog_needs_kspace(): - corr_out = self._compute_sog_frame_correction( - coord_local, - latent_charge_for_energy, - box_local, - return_kspace_info=True, - ) - assert isinstance(corr_out, tuple) - corr_redu, kspace_info = corr_out - else: - corr_redu = self._compute_sog_frame_correction( - coord_local, - latent_charge_for_energy, - box_local, - ) + need_force = self.do_grad_r("energy") or self.do_grad_c("energy") + need_virial = self.do_grad_c("energy") + latent_charge_runtime = latent_charge if self.training else latent_charge.detach() + corr_bundle = self._compute_sog_frame_correction_bundle( + coord_local, + latent_charge_runtime, + box_local, + need_force=need_force, + need_virial=need_virial, + ) + corr_redu = corr_bundle["corr_redu"] model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to(model_ret["energy_redu"].dtype) - if self.do_grad_r("energy") or self.do_grad_c("energy"): - analytic = self._try_analytic_frame_correction_derivatives( - coord_local=coord_local, - latent_charge=latent_charge, - box_local=box_local, - corr_redu=corr_redu, - do_atomic_virial=do_atomic_virial, - kspace_info=kspace_info, - ) - if analytic is not None: - corr_force_local = analytic["force_local"].to(coord_local.dtype) - else: - # Force correction keeps full dependency on latent_charge. - # If latent_charge is differentiable, recompute correction with the - # same graph connectivity; otherwise reuse corr_redu. 
- if self.training and latent_charge.requires_grad: - corr_redu_for_grad = self._compute_sog_frame_correction( - coord_local, - latent_charge, - box_local, - ) - else: - corr_redu_for_grad = corr_redu - corr_force_local = -torch.autograd.grad( - corr_redu_for_grad.sum(), - coord_local, - create_graph=self.training, - retain_graph=False, - )[0].view(nf, nloc, 3) + if need_force: + corr_force_local = corr_bundle["force_local"].to(coord_local.dtype) corr_force_ext = torch.zeros( (nf, nall, 3), @@ -443,15 +349,8 @@ def _apply_frame_correction_lower( model_ret["energy_derv_r"].dtype ) - if self.do_grad_c("energy"): - if analytic is not None and "virial_local" in analytic: - corr_virial_local = analytic["virial_local"].to(corr_force_local.dtype) - else: - corr_virial_local = torch.einsum( - "fai,faj->faij", - corr_force_local, - coord_local, - ).reshape(nf, nloc, 1, 9) + if need_virial: + corr_virial_local = corr_bundle["virial_local"].to(corr_force_local.dtype) corr_virial_redu = corr_virial_local.sum(dim=1) if "energy_derv_c_redu" in model_ret: model_ret["energy_derv_c_redu"] = model_ret["energy_derv_c_redu"] + corr_virial_redu.to( diff --git a/deepmd/pt/model/task/les_energy_fitting.py b/deepmd/pt/model/task/les_energy_fitting.py index 64e89e03e2..92e60816a2 100644 --- a/deepmd/pt/model/task/les_energy_fitting.py +++ b/deepmd/pt/model/task/les_energy_fitting.py @@ -96,6 +96,10 @@ class LESEnergyFittingNet(LRFittingNet): default_fparam: list[float], optional The default frame parameter. If set, when `fparam.npy` files are not included in the data system, this value will be used as the default value for the frame parameter in the fitting net. + n_dl : int + NUFFT long-range grid density control factor. + remove_self_interaction : bool + If True, remove self interaction term in long-range correction. 
""" def __init__( @@ -125,6 +129,7 @@ def __init__( default_fparam: Optional[list[float]] = None, sigma: Optional[Union[float, list[float], torch.Tensor]] = None, n_dl: int = 1, + remove_self_interaction: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -169,7 +174,7 @@ def __init__( sigma_tensor, requires_grad=bool(self.trainable), ) - self.remove_self_interaction = False + self.remove_self_interaction = bool(remove_self_interaction) self._nufft_fallback_warned = False @@ -198,6 +203,7 @@ def serialize(self) -> dict: data["type"] = "les_energy" data["@variables"]["sigma"] = to_numpy_array(self.sigma) data["n_dl"] = self.n_dl + data["remove_self_interaction"] = bool(self.remove_self_interaction) return data @classmethod diff --git a/deepmd/pt/model/task/lr_fitting.py b/deepmd/pt/model/task/lr_fitting.py index 38803ae67a..b245632e54 100644 --- a/deepmd/pt/model/task/lr_fitting.py +++ b/deepmd/pt/model/task/lr_fitting.py @@ -415,7 +415,7 @@ def set_return_middle_output(self, return_middle_output: bool = True) -> None: def __setitem__(self, key: str, value: torch.Tensor) -> None: if key in ["bias_atom_e"]: - value = value.view([self.ntypes, self._net_out_dim()]) + value = value.view([self.ntypes, self._sr_net_out_dim()]) self.bias_atom_e = value elif key in ["fparam_avg"]: self.fparam_avg = value diff --git a/deepmd/pt/model/task/sog_energy_fitting.py b/deepmd/pt/model/task/sog_energy_fitting.py index 480e10a27e..e95757c981 100644 --- a/deepmd/pt/model/task/sog_energy_fitting.py +++ b/deepmd/pt/model/task/sog_energy_fitting.py @@ -123,6 +123,10 @@ class SOGEnergyFittingNet(LRFittingNet): default_fparam: list[float], optional The default frame parameter. If set, when `fparam.npy` files are not included in the data system, this value will be used as the default value for the frame parameter in the fitting net. + n_dl : int + NUFFT long-range grid density control factor. 
+ remove_self_interaction : bool + If True, remove self interaction term in long-range correction. """ def __init__( @@ -153,6 +157,7 @@ def __init__( shift: Optional[Union[list[float], torch.Tensor]] = None, amplitude: Optional[Union[list[float], torch.Tensor]] = None, n_dl: int = 1, + remove_self_interaction: bool = False, **kwargs: Any, ) -> None: super().__init__( @@ -212,7 +217,7 @@ def __init__( sl_tensor, requires_grad=bool(self.trainable), ) - self.remove_self_interaction = False + self.remove_self_interaction = bool(remove_self_interaction) self._nufft_fallback_warned = False def output_def(self) -> FittingOutputDef: @@ -282,11 +287,13 @@ def serialize(self) -> dict: variables["shift"] = to_numpy_array(shift_tensor) variables["amplitude"] = to_numpy_array(amplitude_tensor) data["n_dl"] = self.n_dl + data["remove_self_interaction"] = bool(self.remove_self_interaction) return data @classmethod def deserialize(cls, data: dict) -> "SOGEnergyFittingNet": data = data.copy() + variables = data.get("@variables", {}).copy() wl_tensor = to_torch_tensor(variables.pop("wl", None)) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index aa1dafd57c..448c359712 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2080,6 +2080,10 @@ def fitting_sog_energy() -> list[Argument]: ) doc_shift = "Shift values of the SOG long-range correction kernels." doc_amplitude = "Amplitude values of the SOG long-range correction kernels." + doc_n_dl = "NUFFT long-range grid density control factor." + doc_remove_self_interaction = ( + "Whether to remove self interaction term in long-range correction." 
+ ) return [ Argument( @@ -2183,6 +2187,20 @@ def fitting_sog_energy() -> list[Argument]: default=False, doc=doc_only_pt_supported + doc_use_aparam_as_mask, ), + Argument( + "n_dl", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_n_dl, + ), + Argument( + "remove_self_interaction", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_remove_self_interaction, + ), Argument( "shift", list[float], @@ -2250,6 +2268,10 @@ def fitting_les_energy() -> list[Argument]: ) doc_shift = "Shift values of the LES long-range correction kernels." doc_amplitude = "Amplitude values of the LES long-range correction kernels." + doc_n_dl = "NUFFT long-range grid density control factor." + doc_remove_self_interaction = ( + "Whether to remove self interaction term in long-range correction." + ) return [ Argument( @@ -2353,6 +2375,20 @@ def fitting_les_energy() -> list[Argument]: default=False, doc=doc_only_pt_supported + doc_use_aparam_as_mask, ), + Argument( + "n_dl", + int, + optional=True, + default=1, + doc=doc_only_pt_supported + doc_n_dl, + ), + Argument( + "remove_self_interaction", + bool, + optional=True, + default=False, + doc=doc_only_pt_supported + doc_remove_self_interaction, + ), Argument( "shift", list[float], diff --git a/examples/water/sog/input_torch.json b/examples/water/sog/input_torch.json index d14e8ff812..91ee791564 100644 --- a/examples/water/sog/input_torch.json +++ b/examples/water/sog/input_torch.json @@ -48,14 +48,16 @@ 240 ], "neuron_lr": [ - 240, - 240, - 240 + 120, + 120, + 120 ], "resnet_dt": true, "precision": "float32", "activation_function": "silut:10.0", "seed": 1, + "n_dl": 1, + "remove_self_interaction": true, "_comment": "SOG-specific long-range kernel parameters", "shift": [ 2.8, @@ -124,10 +126,10 @@ ], "batch_size": 1 }, - "numb_steps": 1000, + "numb_steps": 500000, "gradient_max_norm": 5.0, "seed": 10, - "disp_file": "lcurve_test_3.out", + "disp_file": "lcurve_remove_self.out", "disp_freq": 100, 
"save_freq": 2000 } diff --git a/examples/water/sog/profile_sog_timing.py b/examples/water/sog/profile_sog_timing.py index cf913780a5..55dd7e9079 100644 --- a/examples/water/sog/profile_sog_timing.py +++ b/examples/water/sog/profile_sog_timing.py @@ -75,10 +75,18 @@ def _install_fine_frame_corr_profiler( device: torch.device, collect_flag: dict[str, bool], ) -> tuple[Any, Any]: - orig_compute = model._compute_sog_frame_correction + orig_bundle = model._compute_sog_frame_correction_bundle orig_apply = model._apply_frame_correction_lower - def _timed_compute(self, coord: torch.Tensor, latent_charge: torch.Tensor, box: torch.Tensor) -> torch.Tensor: + def _timed_bundle( + self, + coord: torch.Tensor, + latent_charge: torch.Tensor, + box: torch.Tensor, + *, + need_force: bool, + need_virial: bool, + ) -> dict[str, torch.Tensor]: if coord.dim() != 3: raise ValueError(f"`coord` should be [nf, nloc, 3], got shape {tuple(coord.shape)}") if latent_charge.dim() != 3: @@ -110,9 +118,20 @@ def _timed_compute(self, coord: torch.Tensor, latent_charge: torch.Tensor, box: remove_self_interaction = bool(fitting.remove_self_interaction) n_dl = int(fitting.n_dl) pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) + two_pi = torch.tensor(2.0 * torch.pi, dtype=real_dtype, device=runtime_device) - nf, _nloc, _ = coord.shape + nf, nloc, _ = coord.shape corr = torch.zeros((nf, 1), dtype=real_dtype, device=runtime_device) + force_local = ( + torch.zeros((nf, nloc, 3), dtype=real_dtype, device=runtime_device) + if need_force + else None + ) + virial_local = ( + torch.zeros((nf, nloc, 1, 9), dtype=real_dtype, device=runtime_device) + if need_virial + else None + ) for ff in range(nf): r_raw = coord[ff] @@ -164,22 +183,54 @@ def _timed_compute(self, coord: torch.Tensor, latent_charge: torch.Tensor, box: isign=-1, ) - with _time_block("fc_nufft_type2_reduce", detail_times, device) if collect_flag["on"] else nullcontext(): - conv = kfac.unsqueeze(0) * recon - ifft_conv = 
pytorch_finufft.functional.finufft_type2( - nufft_points, - conv, - eps=1e-4, - isign=1, - ) / (2.0 * volume) - corr[ff, 0] = (charge * ifft_conv).real.sum() + with _time_block("fc_energy_reduce", detail_times, device) if collect_flag["on"] else nullcontext(): + rho_sq = recon.real.square() + recon.imag.square() + corr[ff, 0] = (kfac.unsqueeze(0) * rho_sq).sum() / (2.0 * volume) + + if need_force: + with _time_block("fc_prepare_force_conv", detail_times, device) if collect_flag["on"] else nullcontext(): + conv = kfac.unsqueeze(0).to(dtype=complex_dtype) * recon + + with _time_block("fc_prepare_force_kgrid", detail_times, device) if collect_flag["on"] else nullcontext(): + kk1 = torch.fft.ifftshift(kx_grid, dim=0) + kk2 = torch.fft.ifftshift(ky_grid, dim=1) + kk3 = torch.fft.ifftshift(kz_grid, dim=2) + k_grid = torch.stack((kk1, kk2, kk3), dim=0) + g_cart = two_pi * torch.einsum("ik,k...->i...", cell_inv, k_grid) + grad_conv = (1j * g_cart.unsqueeze(1).to(dtype=complex_dtype)) * conv.unsqueeze(0) + + with _time_block("fc_nufft_type2_force", detail_times, device) if collect_flag["on"] else nullcontext(): + grad_field = pytorch_finufft.functional.finufft_type2( + nufft_points, + grad_conv, + eps=1e-4, + isign=1, + ) + + with _time_block("fc_force_reduce", detail_times, device) if collect_flag["on"] else nullcontext(): + force_frame = -(q_t.unsqueeze(0) * grad_field.real.to(dtype=real_dtype)).sum(dim=1).transpose(0, 1) + force_frame = force_frame / volume + force_local[ff] = force_frame + + if need_virial: + with _time_block("fc_virial_local", detail_times, device) if collect_flag["on"] else nullcontext(): + virial_local[ff] = torch.einsum( + "ai,aj->aij", + force_frame, + r_raw, + ).reshape(nloc, 1, 9) if remove_self_interaction: with _time_block("fc_self_interaction", detail_times, device) if collect_flag["on"] else nullcontext(): diag_sum = kfac.sum(dim=-1).sum(dim=-1).sum(dim=-1) / (2.0 * volume) corr[ff, 0] -= torch.sum(q**2) * diag_sum - return corr + out: 
dict[str, torch.Tensor] = {"corr_redu": corr} + if force_local is not None: + out["force_local"] = force_local + if virial_local is not None: + out["virial_local"] = virial_local + return out def _timed_apply( self, @@ -198,36 +249,25 @@ def _timed_apply( coord_local = extended_coord[:, :nloc, :] box_local = box.view(nf, 3, 3) latent_charge = model_ret["latent_charge"] - latent_charge_for_energy = latent_charge if self.training else latent_charge.detach() + need_force = self.do_grad_r("energy") or self.do_grad_c("energy") + need_virial = self.do_grad_c("energy") + latent_charge_runtime = latent_charge if self.training else latent_charge.detach() - with _time_block("fc_compute_corr_redu", detail_times, device) if collect_flag["on"] else nullcontext(): - corr_redu = self._compute_sog_frame_correction( + with _time_block("fc_compute_corr_bundle", detail_times, device) if collect_flag["on"] else nullcontext(): + corr_bundle = self._compute_sog_frame_correction_bundle( coord_local, - latent_charge_for_energy, + latent_charge_runtime, box_local, + need_force=need_force, + need_virial=need_virial, ) + corr_redu = corr_bundle["corr_redu"] with _time_block("fc_add_energy", detail_times, device) if collect_flag["on"] else nullcontext(): model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to(model_ret["energy_redu"].dtype) - if self.do_grad_r("energy") or self.do_grad_c("energy"): - with _time_block("fc_compute_corr_redu_for_grad", detail_times, device) if collect_flag["on"] else nullcontext(): - if self.training and latent_charge.requires_grad: - corr_redu_for_grad = self._compute_sog_frame_correction( - coord_local, - latent_charge.detach(), - box_local, - ) - else: - corr_redu_for_grad = corr_redu - - with _time_block("fc_autograd_force", detail_times, device) if collect_flag["on"] else nullcontext(): - corr_force_local = -torch.autograd.grad( - corr_redu_for_grad.sum(), - coord_local, - create_graph=self.training, - retain_graph=False, - )[0].view(nf, nloc, 3) 
+ if need_force: + corr_force_local = corr_bundle["force_local"].to(coord_local.dtype) with _time_block("fc_scatter_force", detail_times, device) if collect_flag["on"] else nullcontext(): corr_force_ext = torch.zeros( @@ -241,13 +281,9 @@ def _timed_apply( model_ret["energy_derv_r"].dtype ) - if self.do_grad_c("energy"): + if need_virial: + corr_virial_local = corr_bundle["virial_local"].to(corr_force_local.dtype) with _time_block("fc_virial_update", detail_times, device) if collect_flag["on"] else nullcontext(): - corr_virial_local = torch.einsum( - "fai,faj->faij", - corr_force_local, - coord_local, - ).reshape(nf, nloc, 1, 9) corr_virial_redu = corr_virial_local.sum(dim=1) if "energy_derv_c_redu" in model_ret: model_ret["energy_derv_c_redu"] = model_ret["energy_derv_c_redu"] + corr_virial_redu.to( @@ -266,9 +302,9 @@ def _timed_apply( return model_ret - model._compute_sog_frame_correction = types.MethodType(_timed_compute, model) + model._compute_sog_frame_correction_bundle = types.MethodType(_timed_bundle, model) model._apply_frame_correction_lower = types.MethodType(_timed_apply, model) - return orig_compute, orig_apply + return orig_bundle, orig_apply def profile( @@ -308,11 +344,11 @@ def timed_type2(*args, **kwargs): timings: dict[str, float] = defaultdict(float) detail_times: dict[str, float] = defaultdict(float) collect_detail = {"on": False} - orig_compute = None + orig_bundle = None orig_apply = None if fine_frame_profile: - orig_compute, orig_apply = _install_fine_frame_corr_profiler( + orig_bundle, orig_apply = _install_fine_frame_corr_profiler( model, detail_times, device, @@ -422,8 +458,8 @@ def timed_type2(*args, **kwargs): finally: pytorch_finufft.functional.finufft_type1 = orig_type1 pytorch_finufft.functional.finufft_type2 = orig_type2 - if fine_frame_profile and orig_compute is not None and orig_apply is not None: - model._compute_sog_frame_correction = orig_compute + if fine_frame_profile and orig_bundle is not None and orig_apply is not None: 
+ model._compute_sog_frame_correction_bundle = orig_bundle model._apply_frame_correction_lower = orig_apply @@ -450,7 +486,7 @@ def _format_report(timings: dict[str, float]) -> str: "communicate_output", "output_cast", "fc_guard_and_slice", - "fc_compute_corr_redu", + "fc_compute_corr_bundle", "fc_cast_inputs", "fc_param_prepare", "fc_geom_and_points", @@ -458,11 +494,14 @@ def _format_report(timings: dict[str, float]) -> str: "fc_build_kfac", "fc_prepare_charge", "fc_nufft_type1", - "fc_nufft_type2_reduce", + "fc_energy_reduce", + "fc_prepare_force_conv", + "fc_prepare_force_kgrid", + "fc_nufft_type2_force", + "fc_force_reduce", + "fc_virial_local", "fc_self_interaction", "fc_add_energy", - "fc_compute_corr_redu_for_grad", - "fc_autograd_force", "fc_scatter_force", "fc_virial_update", ] From 023b56f599c7271036d74675670e530374f4282f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Apr 2026 09:22:25 +0000 Subject: [PATCH 7/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/model/atomic_model/__init__.py | 14 +- .../pt/model/atomic_model/les_atomic_model.py | 24 +- .../atomic_model/lr_energy_atomic_model.py | 50 ++-- .../pt/model/atomic_model/sog_atomic_model.py | 24 +- deepmd/pt/model/model/__init__.py | 14 +- deepmd/pt/model/model/les_model.py | 112 +++++---- deepmd/pt/model/model/sog_model.py | 136 ++++++---- deepmd/pt/model/task/__init__.py | 18 +- deepmd/pt/model/task/les_energy_fitting.py | 46 ++-- deepmd/pt/model/task/lr_fitting.py | 60 +++-- deepmd/pt/model/task/sog_energy_fitting.py | 101 ++++---- deepmd/utils/argcheck.py | 33 ++- examples/water/sog/ab_retain_graph.py | 43 ++-- .../sog/check_sog_consistency_with_cace.py | 21 +- examples/water/sog/compare_sog_dpa3_timing.py | 21 +- examples/water/sog/profile_sog_timing.py | 237 ++++++++++++++---- examples/water/sog/profile_sog_whatif.py | 26 +- 
.../tests/pt/model/test_les_working_layer.py | 23 +- .../tests/pt/model/test_sog_working_layer.py | 23 +- 19 files changed, 663 insertions(+), 363 deletions(-) diff --git a/deepmd/pt/model/atomic_model/__init__.py b/deepmd/pt/model/atomic_model/__init__.py index c432ae22e4..93451d17c5 100644 --- a/deepmd/pt/model/atomic_model/__init__.py +++ b/deepmd/pt/model/atomic_model/__init__.py @@ -29,10 +29,16 @@ from .energy_atomic_model import ( DPEnergyAtomicModel, ) +from .les_atomic_model import ( + LESEnergyAtomicModel, +) from .linear_atomic_model import ( DPZBLLinearEnergyAtomicModel, LinearEnergyAtomicModel, ) +from .lr_energy_atomic_model import ( + LREnergyAtomicModel, +) from .pairtab_atomic_model import ( PairTabAtomicModel, ) @@ -42,15 +48,9 @@ from .property_atomic_model import ( DPPropertyAtomicModel, ) -from .lr_energy_atomic_model import ( - LREnergyAtomicModel, -) from .sog_atomic_model import ( SOGEnergyAtomicModel, ) -from .les_atomic_model import ( - LESEnergyAtomicModel, -) __all__ = [ "BaseAtomicModel", @@ -61,9 +61,9 @@ "DPPolarAtomicModel", "DPPropertyAtomicModel", "DPZBLLinearEnergyAtomicModel", - "LinearEnergyAtomicModel", "LESEnergyAtomicModel", "LREnergyAtomicModel", + "LinearEnergyAtomicModel", "PairTabAtomicModel", "SOGEnergyAtomicModel", ] diff --git a/deepmd/pt/model/atomic_model/les_atomic_model.py b/deepmd/pt/model/atomic_model/les_atomic_model.py index 12dab13ab4..e6a19f01fc 100644 --- a/deepmd/pt/model/atomic_model/les_atomic_model.py +++ b/deepmd/pt/model/atomic_model/les_atomic_model.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import Any, Optional +from typing import ( + Any, +) import torch @@ -35,8 +37,8 @@ def __init__( self, descriptor: Any, type_map: list[str], - les_energy_fitting: Optional[LESEnergyFittingNet] = None, - fitting: Optional[Any] = None, + les_energy_fitting: LESEnergyFittingNet | None = None, + fitting: Any | None = None, **kwargs: Any, ) -> None: super().__init__(type_map, 
**kwargs) @@ -107,7 +109,7 @@ def get_dim_fparam(self) -> int: def has_default_fparam(self) -> bool: return self.fitting_net.has_default_fparam() - def get_default_fparam(self) -> Optional[torch.Tensor]: + def get_default_fparam(self) -> torch.Tensor | None: return self.fitting_net.get_default_fparam() def get_dim_aparam(self) -> int: @@ -139,10 +141,10 @@ def forward_atomic( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - mapping: Optional[torch.Tensor] = None, - fparam: Optional[torch.Tensor] = None, - aparam: Optional[torch.Tensor] = None, - comm_dict: Optional[dict[str, torch.Tensor]] = None, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + comm_dict: dict[str, torch.Tensor] | None = None, ) -> dict[str, torch.Tensor]: nframes, nloc, _ = nlist.shape atype = extended_atype[:, :nloc] @@ -200,9 +202,9 @@ def apply_out_stat( def compute_or_load_stat( self, sampled_func: Any, - stat_file_path: Optional[Any] = None, + stat_file_path: Any | None = None, compute_or_load_out_stat: bool = True, - preset_observed_type: Optional[list[str]] = None, + preset_observed_type: list[str] | None = None, ) -> None: if stat_file_path is not None and self.type_map is not None: stat_file_path /= " ".join(self.type_map) @@ -240,7 +242,7 @@ def wrapped_sampler() -> list[dict]: def compute_fitting_input_stat( self, sample_merged: Any, - stat_file_path: Optional[Any] = None, + stat_file_path: Any | None = None, ) -> None: self.fitting_net.compute_input_stats( sample_merged, diff --git a/deepmd/pt/model/atomic_model/lr_energy_atomic_model.py b/deepmd/pt/model/atomic_model/lr_energy_atomic_model.py index 43cdfd4c30..c44ef3f1d7 100644 --- a/deepmd/pt/model/atomic_model/lr_energy_atomic_model.py +++ b/deepmd/pt/model/atomic_model/lr_energy_atomic_model.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import Any, Iterable, Optional +from typing import ( + Any, +) 
+from collections.abc import Iterable import torch @@ -48,7 +51,7 @@ def __init__( energy_fitting: InvarFitting, property_fitting: PropertyFittingNet, type_map: list[str], - correction_hidden: Optional[Iterable[int]] = None, + correction_hidden: Iterable[int] | None = None, correction_activation: str = "tanh", **kwargs: Any, ) -> None: @@ -67,9 +70,13 @@ def __init__( ) if energy_fitting.get_dim_fparam() != property_fitting.get_dim_fparam(): - raise ValueError("energy_fitting and property_fitting must share the same dim_fparam") + raise ValueError( + "energy_fitting and property_fitting must share the same dim_fparam" + ) if energy_fitting.get_dim_aparam() != property_fitting.get_dim_aparam(): - raise ValueError("energy_fitting and property_fitting must share the same dim_aparam") + raise ValueError( + "energy_fitting and property_fitting must share the same dim_aparam" + ) self.descriptor = descriptor self.energy_fitting = energy_fitting @@ -80,7 +87,11 @@ def __init__( self.rcut = self.descriptor.get_rcut() self.sel = self.descriptor.get_sel() self.correction_activation = correction_activation - hidden = list(correction_hidden) if correction_hidden is not None else [property_fitting.dim_out] + hidden = ( + list(correction_hidden) + if correction_hidden is not None + else [property_fitting.dim_out] + ) self.correction_hidden = hidden self.correction_head = self._build_correction_head( property_fitting.dim_out, hidden, correction_activation @@ -152,7 +163,7 @@ def get_dim_fparam(self) -> int: def has_default_fparam(self) -> bool: return self.energy_fitting.has_default_fparam() - def get_default_fparam(self) -> Optional[torch.Tensor]: + def get_default_fparam(self) -> torch.Tensor | None: return self.energy_fitting.get_default_fparam() def get_dim_aparam(self) -> int: @@ -184,10 +195,10 @@ def forward_atomic( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - mapping: Optional[torch.Tensor] = None, - fparam: Optional[torch.Tensor] = 
None, - aparam: Optional[torch.Tensor] = None, - comm_dict: Optional[dict[str, torch.Tensor]] = None, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + comm_dict: dict[str, torch.Tensor] | None = None, ) -> dict[str, torch.Tensor]: nframes, nloc, _ = nlist.shape atype = extended_atype[:, :nloc] @@ -224,7 +235,9 @@ def forward_atomic( ) if self.enable_eval_fitting_last_layer_hook and "middle_output" in energy_ret: - self.eval_fitting_last_layer_list.append(energy_ret["middle_output"].detach()) + self.eval_fitting_last_layer_list.append( + energy_ret["middle_output"].detach() + ) q_val = prop_ret[self.property_name] corr_energy = self.correction_head(q_val) @@ -251,7 +264,7 @@ def apply_out_stat( def compute_or_load_stat( self, sampled_func: Any, - stat_file_path: Optional[Any] = None, + stat_file_path: Any | None = None, compute_or_load_out_stat: bool = True, ) -> None: if stat_file_path is not None and self.type_map is not None: @@ -286,7 +299,7 @@ def wrapped_sampler() -> list[dict]: def compute_fitting_input_stat( self, sample_merged: Any, - stat_file_path: Optional[Any] = None, + stat_file_path: Any | None = None, ) -> None: self.energy_fitting.compute_input_stats( sample_merged, @@ -315,7 +328,8 @@ def serialize(self) -> dict: "@variables": { **dd.get("@variables", {}), "correction_head": { - k: to_numpy_array(v) for k, v in self.correction_head.state_dict().items() + k: to_numpy_array(v) + for k, v in self.correction_head.state_dict().items() }, }, } @@ -331,7 +345,9 @@ def deserialize(cls, data: dict) -> "LREnergyAtomicModel": variables = data.pop("@variables", {}) descriptor_obj = BaseDescriptor.deserialize(data.pop("descriptor")) energy_fitting_obj = InvarFitting.deserialize(data.pop("energy_fitting")) - property_fitting_obj = PropertyFittingNet.deserialize(data.pop("property_fitting")) + property_fitting_obj = PropertyFittingNet.deserialize( + data.pop("property_fitting") + ) 
correction_hidden = data.pop("correction_hidden", None) correction_activation = data.pop("correction_activation", "tanh") obj = cls( @@ -344,5 +360,7 @@ def deserialize(cls, data: dict) -> "LREnergyAtomicModel": ) correction_state = variables.get("correction_head", None) if correction_state is not None: - obj.correction_head.load_state_dict({k: to_torch_tensor(v) for k, v in correction_state.items()}) + obj.correction_head.load_state_dict( + {k: to_torch_tensor(v) for k, v in correction_state.items()} + ) return obj diff --git a/deepmd/pt/model/atomic_model/sog_atomic_model.py b/deepmd/pt/model/atomic_model/sog_atomic_model.py index f1658743f8..e23c0b671d 100644 --- a/deepmd/pt/model/atomic_model/sog_atomic_model.py +++ b/deepmd/pt/model/atomic_model/sog_atomic_model.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import Any, Optional +from typing import ( + Any, +) import torch @@ -35,8 +37,8 @@ def __init__( self, descriptor: Any, type_map: list[str], - sog_energy_fitting: Optional[SOGEnergyFittingNet] = None, - fitting: Optional[Any] = None, + sog_energy_fitting: SOGEnergyFittingNet | None = None, + fitting: Any | None = None, **kwargs: Any, ) -> None: super().__init__(type_map, **kwargs) @@ -107,7 +109,7 @@ def get_dim_fparam(self) -> int: def has_default_fparam(self) -> bool: return self.fitting_net.has_default_fparam() - def get_default_fparam(self) -> Optional[torch.Tensor]: + def get_default_fparam(self) -> torch.Tensor | None: return self.fitting_net.get_default_fparam() def get_dim_aparam(self) -> int: @@ -139,10 +141,10 @@ def forward_atomic( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - mapping: Optional[torch.Tensor] = None, - fparam: Optional[torch.Tensor] = None, - aparam: Optional[torch.Tensor] = None, - comm_dict: Optional[dict[str, torch.Tensor]] = None, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + comm_dict: 
dict[str, torch.Tensor] | None = None, ) -> dict[str, torch.Tensor]: nframes, nloc, _ = nlist.shape atype = extended_atype[:, :nloc] @@ -200,9 +202,9 @@ def apply_out_stat( def compute_or_load_stat( self, sampled_func: Any, - stat_file_path: Optional[Any] = None, + stat_file_path: Any | None = None, compute_or_load_out_stat: bool = True, - preset_observed_type: Optional[list[str]] = None, + preset_observed_type: list[str] | None = None, ) -> None: if stat_file_path is not None and self.type_map is not None: stat_file_path /= " ".join(self.type_map) @@ -240,7 +242,7 @@ def wrapped_sampler() -> list[dict]: def compute_fitting_input_stat( self, sample_merged: Any, - stat_file_path: Optional[Any] = None, + stat_file_path: Any | None = None, ) -> None: self.fitting_net.compute_input_stats( sample_merged, diff --git a/deepmd/pt/model/model/__init__.py b/deepmd/pt/model/model/__init__.py index fedebb6f22..323167dbbe 100644 --- a/deepmd/pt/model/model/__init__.py +++ b/deepmd/pt/model/model/__init__.py @@ -54,6 +54,9 @@ from .frozen import ( FrozenModel, ) +from .les_model import ( + LESEnergyModel, +) from .make_hessian_model import ( make_hessian_model, ) @@ -69,15 +72,12 @@ from .property_model import ( PropertyModel, ) -from .spin_model import ( - SpinEnergyModel, - SpinModel, -) from .sog_model import ( SOGEnergyModel, ) -from .les_model import ( - LESEnergyModel, +from .spin_model import ( + SpinEnergyModel, + SpinModel, ) @@ -319,8 +319,8 @@ def get_model(model_params: dict) -> Any: "DipoleModel", "EnergyModel", "FrozenModel", - "LinearEnergyModel", "LESEnergyModel", + "LinearEnergyModel", "PolarModel", "SOGEnergyModel", "SpinEnergyModel", diff --git a/deepmd/pt/model/model/les_model.py b/deepmd/pt/model/model/les_model.py index f0b08c0ca4..c530871294 100644 --- a/deepmd/pt/model/model/les_model.py +++ b/deepmd/pt/model/model/les_model.py @@ -1,18 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, - Optional, ) -import torch import 
pytorch_finufft - -from deepmd.pt.model.model.transform_output import ( - communicate_extended_output, -) -from deepmd.pt.utils.nlist import ( - extend_input_and_build_neighbor_list, -) +import torch from deepmd.pt.model.atomic_model import ( LESEnergyAtomicModel, @@ -20,6 +12,12 @@ from deepmd.pt.model.model.model import ( BaseModel, ) +from deepmd.pt.model.model.transform_output import ( + communicate_extended_output, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) from .dp_model import ( DPModelCommon, @@ -111,13 +109,17 @@ def _compute_les_frame_correction_bundle( fitting = self.get_fitting_net() runtime_device = coord.device real_dtype = coord.dtype - complex_dtype = torch.complex128 if real_dtype == torch.float64 else torch.complex64 + complex_dtype = ( + torch.complex128 if real_dtype == torch.float64 else torch.complex64 + ) latent_charge = latent_charge.to(device=runtime_device, dtype=real_dtype) box = box.to(device=runtime_device, dtype=real_dtype) sigma_raw = getattr(fitting, "sigma", None) if sigma_raw is None: - raise ValueError("LES fitting net should provide `sigma` for frame correction.") + raise ValueError( + "LES fitting net should provide `sigma` for frame correction." + ) sigma = torch.as_tensor( sigma_raw, dtype=real_dtype, @@ -149,7 +151,9 @@ def _compute_les_frame_correction_bundle( volume = torch.det(box_frame) if torch.abs(volume) <= torch.finfo(real_dtype).eps: - raise ValueError("`box` is singular (near-zero volume), cannot run NUFFT.") + raise ValueError( + "`box` is singular (near-zero volume), cannot run NUFFT." 
+ ) cell_inv = torch.linalg.inv(box_frame) r_frac = torch.matmul(r_raw, cell_inv) @@ -164,9 +168,15 @@ def _compute_les_frame_correction_bundle( norms = torch.norm(box_frame, dim=1) nk = tuple(max(1, int(n.item() / n_dl)) for n in norms) - n1 = torch.arange(-nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype) - n2 = torch.arange(-nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype) - n3 = torch.arange(-nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype) + n1 = torch.arange( + -nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype + ) + n2 = torch.arange( + -nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype + ) + n3 = torch.arange( + -nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype + ) kx_grid, ky_grid, kz_grid = torch.meshgrid(n1, n2, n3, indexing="ij") k_sq = kx_grid**2 + ky_grid**2 + kz_grid**2 @@ -178,7 +188,11 @@ def _compute_les_frame_correction_bundle( kfac[zero_mask] = 0.0 q_t = q.transpose(0, 1).contiguous() - charge = torch.complex(q_t, torch.zeros_like(q_t)).to(dtype=complex_dtype).contiguous() + charge = ( + torch.complex(q_t, torch.zeros_like(q_t)) + .to(dtype=complex_dtype) + .contiguous() + ) output_shape = tuple(int(x) for x in kx_grid.shape) recon = pytorch_finufft.functional.finufft_type1( nufft_points, @@ -202,14 +216,20 @@ def _compute_les_frame_correction_bundle( kk3 = torch.fft.ifftshift(kz_grid, dim=2) k_grid = torch.stack((kk1, kk2, kk3), dim=0) g_cart = two_pi * torch.einsum("ik,k...->i...", cell_inv, k_grid) - grad_conv = (1j * g_cart.unsqueeze(1).to(dtype=complex_dtype)) * conv.unsqueeze(0) + grad_conv = ( + 1j * g_cart.unsqueeze(1).to(dtype=complex_dtype) + ) * conv.unsqueeze(0) grad_field = pytorch_finufft.functional.finufft_type2( nufft_points, grad_conv, eps=1e-4, isign=1, ) - force_frame = -(q_t.unsqueeze(0) * grad_field.real.to(dtype=real_dtype)).sum(dim=1).transpose(0, 1) + force_frame = ( + -(q_t.unsqueeze(0) * grad_field.real.to(dtype=real_dtype)) + .sum(dim=1) + .transpose(0, 1) + ) 
force_frame = force_frame / volume force_local[ff] = force_frame @@ -251,7 +271,7 @@ def _apply_frame_correction_lower( model_ret: dict[str, torch.Tensor], extended_coord: torch.Tensor, nlist: torch.Tensor, - box: Optional[torch.Tensor], + box: torch.Tensor | None, do_atomic_virial: bool, ) -> dict[str, torch.Tensor]: if box is None or "latent_charge" not in model_ret: @@ -264,7 +284,9 @@ def _apply_frame_correction_lower( latent_charge = model_ret["latent_charge"] need_force = self.do_grad_r("energy") or self.do_grad_c("energy") need_virial = self.do_grad_c("energy") - latent_charge_runtime = latent_charge if self.training else latent_charge.detach() + latent_charge_runtime = ( + latent_charge if self.training else latent_charge.detach() + ) corr_bundle = self._compute_les_frame_correction_bundle( coord_local, latent_charge_runtime, @@ -274,7 +296,9 @@ def _apply_frame_correction_lower( ) corr_redu = corr_bundle["corr_redu"] - model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to(model_ret["energy_redu"].dtype) + model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to( + model_ret["energy_redu"].dtype + ) if need_force: corr_force_local = corr_bundle["force_local"].to(coord_local.dtype) @@ -286,17 +310,19 @@ def _apply_frame_correction_lower( ) corr_force_ext[:, :nloc, :] = corr_force_local if "energy_derv_r" in model_ret: - model_ret["energy_derv_r"] = model_ret["energy_derv_r"] + corr_force_ext.unsqueeze(-2).to( - model_ret["energy_derv_r"].dtype - ) + model_ret["energy_derv_r"] = model_ret[ + "energy_derv_r" + ] + corr_force_ext.unsqueeze(-2).to(model_ret["energy_derv_r"].dtype) if need_virial: - corr_virial_local = corr_bundle["virial_local"].to(corr_force_local.dtype) + corr_virial_local = corr_bundle["virial_local"].to( + corr_force_local.dtype + ) corr_virial_redu = corr_virial_local.sum(dim=1) if "energy_derv_c_redu" in model_ret: - model_ret["energy_derv_c_redu"] = model_ret["energy_derv_c_redu"] + corr_virial_redu.to( - 
model_ret["energy_derv_c_redu"].dtype - ) + model_ret["energy_derv_c_redu"] = model_ret[ + "energy_derv_c_redu" + ] + corr_virial_redu.to(model_ret["energy_derv_c_redu"].dtype) if do_atomic_virial and "energy_derv_c" in model_ret: corr_atom_virial = torch.zeros( (nf, nall, 1, 9), @@ -304,9 +330,9 @@ def _apply_frame_correction_lower( device=corr_virial_local.device, ) corr_atom_virial[:, :nloc, :, :] = corr_virial_local - model_ret["energy_derv_c"] = model_ret["energy_derv_c"] + corr_atom_virial.to( - model_ret["energy_derv_c"].dtype - ) + model_ret["energy_derv_c"] = model_ret[ + "energy_derv_c" + ] + corr_atom_virial.to(model_ret["energy_derv_c"].dtype) return model_ret @@ -316,13 +342,13 @@ def forward_common_lower( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - mapping: Optional[torch.Tensor] = None, - fparam: Optional[torch.Tensor] = None, - aparam: Optional[torch.Tensor] = None, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, - comm_dict: Optional[dict[str, torch.Tensor]] = None, + comm_dict: dict[str, torch.Tensor] | None = None, extra_nlist_sort: bool = False, - extended_coord_corr: Optional[torch.Tensor] = None, + extended_coord_corr: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: if self.do_grad_r("energy") or self.do_grad_c("energy"): extended_coord = extended_coord.requires_grad_(True) @@ -353,9 +379,9 @@ def forward( self, coord: torch.Tensor, atype: torch.Tensor, - box: Optional[torch.Tensor] = None, - fparam: Optional[torch.Tensor] = None, - aparam: Optional[torch.Tensor] = None, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, ) -> dict[str, torch.Tensor]: cc, bb, fp, ap, input_prec = self._input_type_cast( @@ -423,11 +449,11 @@ def forward_lower( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: 
torch.Tensor, - mapping: Optional[torch.Tensor] = None, - fparam: Optional[torch.Tensor] = None, - aparam: Optional[torch.Tensor] = None, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, - comm_dict: Optional[dict[str, torch.Tensor]] = None, + comm_dict: dict[str, torch.Tensor] | None = None, ) -> dict[str, torch.Tensor]: model_ret = self.forward_common_lower( extended_coord, diff --git a/deepmd/pt/model/model/sog_model.py b/deepmd/pt/model/model/sog_model.py index 5ef6f96bed..44628737bc 100644 --- a/deepmd/pt/model/model/sog_model.py +++ b/deepmd/pt/model/model/sog_model.py @@ -1,18 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from typing import ( Any, - Optional, ) -import torch import pytorch_finufft - -from deepmd.pt.model.model.transform_output import ( - communicate_extended_output, -) -from deepmd.pt.utils.nlist import ( - extend_input_and_build_neighbor_list, -) +import torch from deepmd.pt.model.atomic_model import ( SOGEnergyAtomicModel, @@ -20,6 +12,12 @@ from deepmd.pt.model.model.model import ( BaseModel, ) +from deepmd.pt.model.model.transform_output import ( + communicate_extended_output, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) from .dp_model import ( DPModelCommon, @@ -47,7 +45,9 @@ def __init__( SOGEnergyModel_.__init__(self, *args, **kwargs) self._hessian_enabled = False # Runtime-only caches for NUFFT correction path. 
- self._sog_param_cache: dict[tuple[Any, ...], tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {} + self._sog_param_cache: dict[ + tuple[Any, ...], tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ] = {} @staticmethod def _device_key(device: torch.device) -> str: @@ -69,10 +69,20 @@ def _get_cached_sog_params( ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: wl_raw = fitting.wl sl_raw = fitting.sl - grad_mode = torch.is_grad_enabled() and (wl_raw.requires_grad or sl_raw.requires_grad) + grad_mode = torch.is_grad_enabled() and ( + wl_raw.requires_grad or sl_raw.requires_grad + ) - wl = wl_raw if (wl_raw.device == runtime_device and wl_raw.dtype == real_dtype) else wl_raw.to(dtype=real_dtype, device=runtime_device) - sl = sl_raw if (sl_raw.device == runtime_device and sl_raw.dtype == real_dtype) else sl_raw.to(dtype=real_dtype, device=runtime_device) + wl = ( + wl_raw + if (wl_raw.device == runtime_device and wl_raw.dtype == real_dtype) + else wl_raw.to(dtype=real_dtype, device=runtime_device) + ) + sl = ( + sl_raw + if (sl_raw.device == runtime_device and sl_raw.dtype == real_dtype) + else sl_raw.to(dtype=real_dtype, device=runtime_device) + ) min_term = -1.0 / torch.exp(-2.0 * sl) # Do not cache differentiable tensors across iterations. 
@@ -157,7 +167,9 @@ def _compute_sog_frame_correction_bundle( need_virial: bool, ) -> dict[str, torch.Tensor]: if coord.dim() != 3: - raise ValueError(f"`coord` should be [nf, nloc, 3], got shape {tuple(coord.shape)}") + raise ValueError( + f"`coord` should be [nf, nloc, 3], got shape {tuple(coord.shape)}" + ) if latent_charge.dim() != 3: raise ValueError( f"`latent_charge` should be [nf, nloc, nq], got shape {tuple(latent_charge.shape)}" @@ -171,11 +183,15 @@ def _compute_sog_frame_correction_bundle( fitting = self.get_fitting_net() runtime_device = coord.device real_dtype = coord.dtype - complex_dtype = torch.complex128 if real_dtype == torch.float64 else torch.complex64 + complex_dtype = ( + torch.complex128 if real_dtype == torch.float64 else torch.complex64 + ) latent_charge = latent_charge.to(device=runtime_device, dtype=real_dtype) box = box.to(device=runtime_device, dtype=real_dtype) if box.dim() != 3 or box.shape[-2:] != (3, 3): - raise ValueError(f"`box` should be [nf, 3, 3], got shape {tuple(box.shape)}") + raise ValueError( + f"`box` should be [nf, 3, 3], got shape {tuple(box.shape)}" + ) wl, _sl, min_term = self._get_cached_sog_params( fitting, @@ -207,7 +223,9 @@ def _compute_sog_frame_correction_bundle( volume = torch.det(box_frame) if torch.abs(volume) <= torch.finfo(real_dtype).eps: - raise ValueError("`box` is singular (near-zero volume), cannot run NUFFT.") + raise ValueError( + "`box` is singular (near-zero volume), cannot run NUFFT." 
+ ) cell_inv = torch.linalg.inv(box_frame) r_frac = torch.matmul(r_raw, cell_inv) @@ -222,9 +240,15 @@ def _compute_sog_frame_correction_bundle( norms = torch.norm(box_frame, dim=1) nk = tuple(max(1, int(n.item() / n_dl)) for n in norms) - n1 = torch.arange(-nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype) - n2 = torch.arange(-nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype) - n3 = torch.arange(-nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype) + n1 = torch.arange( + -nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype + ) + n2 = torch.arange( + -nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype + ) + n3 = torch.arange( + -nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype + ) kx_grid, ky_grid, kz_grid = torch.meshgrid(n1, n2, n3, indexing="ij") k_sq = kx_grid**2 + ky_grid**2 + kz_grid**2 zero_mask = k_sq == 0 @@ -236,7 +260,11 @@ def _compute_sog_frame_correction_bundle( output_shape = tuple(int(x) for x in kx_grid.shape) q_t = q.transpose(0, 1).contiguous() - charge = torch.complex(q_t, torch.zeros_like(q_t)).to(dtype=complex_dtype).contiguous() + charge = ( + torch.complex(q_t, torch.zeros_like(q_t)) + .to(dtype=complex_dtype) + .contiguous() + ) recon = pytorch_finufft.functional.finufft_type1( nufft_points, charge, @@ -261,14 +289,20 @@ def _compute_sog_frame_correction_bundle( kk3 = torch.fft.ifftshift(kz_grid, dim=2) k_grid = torch.stack((kk1, kk2, kk3), dim=0) g_cart = two_pi * torch.einsum("ik,k...->i...", cell_inv, k_grid) - grad_conv = (1j * g_cart.unsqueeze(1).to(dtype=complex_dtype)) * conv.unsqueeze(0) + grad_conv = ( + 1j * g_cart.unsqueeze(1).to(dtype=complex_dtype) + ) * conv.unsqueeze(0) grad_field = pytorch_finufft.functional.finufft_type2( nufft_points, grad_conv, eps=1e-4, isign=1, ) - force_frame = -(q_t.unsqueeze(0) * grad_field.real.to(dtype=real_dtype)).sum(dim=1).transpose(0, 1) + force_frame = ( + -(q_t.unsqueeze(0) * grad_field.real.to(dtype=real_dtype)) + .sum(dim=1) + .transpose(0, 1) + 
) force_frame = force_frame / volume force_local[ff] = force_frame @@ -310,7 +344,7 @@ def _apply_frame_correction_lower( model_ret: dict[str, torch.Tensor], extended_coord: torch.Tensor, nlist: torch.Tensor, - box: Optional[torch.Tensor], + box: torch.Tensor | None, do_atomic_virial: bool, ) -> dict[str, torch.Tensor]: if box is None or "latent_charge" not in model_ret: @@ -323,7 +357,9 @@ def _apply_frame_correction_lower( latent_charge = model_ret["latent_charge"] need_force = self.do_grad_r("energy") or self.do_grad_c("energy") need_virial = self.do_grad_c("energy") - latent_charge_runtime = latent_charge if self.training else latent_charge.detach() + latent_charge_runtime = ( + latent_charge if self.training else latent_charge.detach() + ) corr_bundle = self._compute_sog_frame_correction_bundle( coord_local, latent_charge_runtime, @@ -333,7 +369,9 @@ def _apply_frame_correction_lower( ) corr_redu = corr_bundle["corr_redu"] - model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to(model_ret["energy_redu"].dtype) + model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to( + model_ret["energy_redu"].dtype + ) if need_force: corr_force_local = corr_bundle["force_local"].to(coord_local.dtype) @@ -345,17 +383,19 @@ def _apply_frame_correction_lower( ) corr_force_ext[:, :nloc, :] = corr_force_local if "energy_derv_r" in model_ret: - model_ret["energy_derv_r"] = model_ret["energy_derv_r"] + corr_force_ext.unsqueeze(-2).to( - model_ret["energy_derv_r"].dtype - ) + model_ret["energy_derv_r"] = model_ret[ + "energy_derv_r" + ] + corr_force_ext.unsqueeze(-2).to(model_ret["energy_derv_r"].dtype) if need_virial: - corr_virial_local = corr_bundle["virial_local"].to(corr_force_local.dtype) + corr_virial_local = corr_bundle["virial_local"].to( + corr_force_local.dtype + ) corr_virial_redu = corr_virial_local.sum(dim=1) if "energy_derv_c_redu" in model_ret: - model_ret["energy_derv_c_redu"] = model_ret["energy_derv_c_redu"] + corr_virial_redu.to( - 
model_ret["energy_derv_c_redu"].dtype - ) + model_ret["energy_derv_c_redu"] = model_ret[ + "energy_derv_c_redu" + ] + corr_virial_redu.to(model_ret["energy_derv_c_redu"].dtype) if do_atomic_virial and "energy_derv_c" in model_ret: corr_atom_virial = torch.zeros( (nf, nall, 1, 9), @@ -363,9 +403,9 @@ def _apply_frame_correction_lower( device=corr_virial_local.device, ) corr_atom_virial[:, :nloc, :, :] = corr_virial_local - model_ret["energy_derv_c"] = model_ret["energy_derv_c"] + corr_atom_virial.to( - model_ret["energy_derv_c"].dtype - ) + model_ret["energy_derv_c"] = model_ret[ + "energy_derv_c" + ] + corr_atom_virial.to(model_ret["energy_derv_c"].dtype) return model_ret @@ -375,13 +415,13 @@ def forward_common_lower( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - mapping: Optional[torch.Tensor] = None, - fparam: Optional[torch.Tensor] = None, - aparam: Optional[torch.Tensor] = None, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, - comm_dict: Optional[dict[str, torch.Tensor]] = None, + comm_dict: dict[str, torch.Tensor] | None = None, extra_nlist_sort: bool = False, - extended_coord_corr: Optional[torch.Tensor] = None, + extended_coord_corr: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: if self.do_grad_r("energy") or self.do_grad_c("energy"): extended_coord = extended_coord.requires_grad_(True) @@ -412,9 +452,9 @@ def forward( self, coord: torch.Tensor, atype: torch.Tensor, - box: Optional[torch.Tensor] = None, - fparam: Optional[torch.Tensor] = None, - aparam: Optional[torch.Tensor] = None, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, ) -> dict[str, torch.Tensor]: cc, bb, fp, ap, input_prec = self._input_type_cast( @@ -482,11 +522,11 @@ def forward_lower( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: 
torch.Tensor, - mapping: Optional[torch.Tensor] = None, - fparam: Optional[torch.Tensor] = None, - aparam: Optional[torch.Tensor] = None, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, - comm_dict: Optional[dict[str, torch.Tensor]] = None, + comm_dict: dict[str, torch.Tensor] | None = None, ) -> dict[str, torch.Tensor]: model_ret = self.forward_common_lower( extended_coord, diff --git a/deepmd/pt/model/task/__init__.py b/deepmd/pt/model/task/__init__.py index c3b7025358..7cdfbd35a4 100644 --- a/deepmd/pt/model/task/__init__.py +++ b/deepmd/pt/model/task/__init__.py @@ -18,23 +18,23 @@ from .fitting import ( Fitting, ) +from .les_energy_fitting import ( + LESEnergyFittingNet, +) +from .lr_fitting import ( + LRFittingNet, +) from .polarizability import ( PolarFittingNet, ) from .property import ( PropertyFittingNet, ) -from .type_predict import ( - TypePredictNet, -) -from .lr_fitting import ( - LRFittingNet, -) from .sog_energy_fitting import ( SOGEnergyFittingNet, ) -from .les_energy_fitting import ( - LESEnergyFittingNet, +from .type_predict import ( + TypePredictNet, ) __all__ = [ @@ -45,8 +45,8 @@ "EnergyFittingNet", "EnergyFittingNetDirect", "Fitting", - "LRFittingNet", "LESEnergyFittingNet", + "LRFittingNet", "PolarFittingNet", "PropertyFittingNet", "SOGEnergyFittingNet", diff --git a/deepmd/pt/model/task/les_energy_fitting.py b/deepmd/pt/model/task/les_energy_fitting.py index 92e60816a2..325355291d 100644 --- a/deepmd/pt/model/task/les_energy_fitting.py +++ b/deepmd/pt/model/task/les_energy_fitting.py @@ -2,8 +2,6 @@ import logging from typing import ( Any, - Optional, - Union, ) import numpy as np @@ -33,7 +31,6 @@ LRFittingNet, ) - LES_DEFAULT_SIGMA = to_numpy_array(np.array(2.8 / np.sqrt(2.0))) @@ -111,7 +108,7 @@ def __init__( dim_out_lr: int, neuron_sr: list[int] = [128, 128, 128], neuron_lr: list[int] = [128, 128, 128], - bias_atom_e: 
Optional[torch.Tensor] = None, + bias_atom_e: torch.Tensor | None = None, resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, @@ -119,15 +116,15 @@ def __init__( activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, - rcond: Optional[float] = None, - seed: Optional[Union[int, list[int]]] = None, + rcond: float | None = None, + seed: int | list[int] | None = None, exclude_types: list[int] = [], - trainable: Union[bool, list[bool]] = True, - remove_vaccum_contribution: Optional[list[bool]] = None, - type_map: Optional[list[str]] = None, + trainable: bool | list[bool] = True, + remove_vaccum_contribution: list[bool] | None = None, + type_map: list[str] | None = None, use_aparam_as_mask: bool = False, - default_fparam: Optional[list[float]] = None, - sigma: Optional[Union[float, list[float], torch.Tensor]] = None, + default_fparam: list[float] | None = None, + sigma: float | list[float] | torch.Tensor | None = None, n_dl: int = 1, remove_self_interaction: bool = False, **kwargs: Any, @@ -177,7 +174,6 @@ def __init__( self.remove_self_interaction = bool(remove_self_interaction) self._nufft_fallback_warned = False - def output_def(self) -> FittingOutputDef: return FittingOutputDef( [ @@ -194,10 +190,10 @@ def output_def(self) -> FittingOutputDef: reducible=False, r_differentiable=False, c_differentiable=False, - ) + ), ] ) - + def serialize(self) -> dict: data = super().serialize() data["type"] = "les_energy" @@ -218,9 +214,13 @@ def deserialize(cls, data: dict) -> "LESEnergyFittingNet": with torch.no_grad(): if sigma_tensor is None: - raise ValueError("LES fitting net deserialize requires `sigma` in @variables.") + raise ValueError( + "LES fitting net deserialize requires `sigma` in @variables." 
+ ) obj.sigma.copy_( - sigma_tensor.to(dtype=obj.sigma.dtype, device=obj.sigma.device).reshape(1) + sigma_tensor.to(dtype=obj.sigma.dtype, device=obj.sigma.device).reshape( + 1 + ) ) return obj @@ -231,11 +231,11 @@ def forward( self, descriptor: torch.Tensor, atype: torch.Tensor, - gr: Optional[torch.Tensor] = None, - g2: Optional[torch.Tensor] = None, - h2: Optional[torch.Tensor] = None, - fparam: Optional[torch.Tensor] = None, - aparam: Optional[torch.Tensor] = None, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: out = self._forward_common( descriptor=descriptor, @@ -253,6 +253,6 @@ def forward( if "middle_output" in out: result["middle_output"] = out["middle_output"] return result - + # make jit happy with torch 2.0.0 - exclude_types: list[int] \ No newline at end of file + exclude_types: list[int] diff --git a/deepmd/pt/model/task/lr_fitting.py b/deepmd/pt/model/task/lr_fitting.py index b245632e54..acf1d7eee2 100644 --- a/deepmd/pt/model/task/lr_fitting.py +++ b/deepmd/pt/model/task/lr_fitting.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import Any, Optional, Union -from abc import abstractmethod +from typing import ( + Any, + Optional, +) import numpy as np import torch -from deepmd.dpmodel import ( - FittingOutputDef, - OutputVariableDef, -) from deepmd.dpmodel.utils.seed import ( child_seed, ) @@ -37,13 +35,11 @@ get_index_between_two_maps, map_atom_exclude_types, ) -from deepmd.utils.version import ( - check_version_compatibility, -) dtype = env.GLOBAL_PT_FLOAT_PRECISION device = env.DEVICE + @Fitting.register("lr") class LRFittingNet(Fitting): """Construct a general sr+lr interactions fitting net. 
@@ -113,7 +109,7 @@ def __init__( dim_out_lr: int, neuron_sr: list[int] = [128, 128, 128], neuron_lr: list[int] = [128, 128, 128], - bias_atom_e: Optional[torch.Tensor] = None, + bias_atom_e: torch.Tensor | None = None, resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, @@ -121,14 +117,14 @@ def __init__( activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, - rcond: Optional[float] = None, - seed: Optional[Union[int, list[int]]] = None, + rcond: float | None = None, + seed: int | list[int] | None = None, exclude_types: list[int] = [], - trainable: Union[bool, list[bool]] = True, - remove_vaccum_contribution: Optional[list[bool]] = None, - type_map: Optional[list[str]] = None, + trainable: bool | list[bool] = True, + remove_vaccum_contribution: list[bool] | None = None, + type_map: list[str] | None = None, use_aparam_as_mask: bool = False, - default_fparam: Optional[list[float]] = None, + default_fparam: list[float] | None = None, **kwargs: Any, ) -> None: super().__init__() @@ -373,7 +369,7 @@ def has_default_fparam(self) -> bool: """Check if the fitting has default frame parameters.""" return self.default_fparam is not None - def get_default_fparam(self) -> Optional[torch.Tensor]: + def get_default_fparam(self) -> torch.Tensor | None: return self.default_fparam_tensor def get_dim_aparam(self) -> int: @@ -457,7 +453,7 @@ def __getitem__(self, key: str) -> torch.Tensor: def _sr_net_out_dim(self) -> int: """Set the SRFittingNet output dim.""" return self.dim_out_sr - + def _lr_net_out_dim(self) -> int: """Set the LR FittingNet output dim.""" return self.dim_out_lr @@ -465,7 +461,7 @@ def _lr_net_out_dim(self) -> int: def _corr_head(self, lr_out: torch.Tensor) -> torch.Tensor: # TODO: Add latent_charge correction logic after LR output is finalized. 
return lr_out - + def _extend_f_avg_std(self, xx: torch.Tensor, nb: int) -> torch.Tensor: return torch.tile(xx.view([1, self.numb_fparam]), [nb, 1]) @@ -476,11 +472,11 @@ def _forward_common( self, descriptor: torch.Tensor, atype: torch.Tensor, - gr: Optional[torch.Tensor] = None, - g2: Optional[torch.Tensor] = None, - h2: Optional[torch.Tensor] = None, - fparam: Optional[torch.Tensor] = None, - aparam: Optional[torch.Tensor] = None, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: xx = descriptor.to(self.prec) nf, nloc, nd = xx.shape @@ -573,16 +569,16 @@ def _forward_common( lr_out = self._corr_head(lr_out) results.update({"sr": sr_out, "lr": lr_out}) return results - + def _apply_networks( self, layers: NetworkCollection, neuron: list[int], dim_out: int, xx: torch.Tensor, - xx_zeros: Optional[torch.Tensor], + xx_zeros: torch.Tensor | None, atype: torch.Tensor, - middle_output: Optional[dict[str, torch.Tensor]], + middle_output: dict[str, torch.Tensor] | None, bool_bias: bool = False, ) -> torch.Tensor: nf, nloc, _ = xx.shape @@ -590,9 +586,7 @@ def _apply_networks( if self.mixed_types: atom_property = layers.networks[0](xx) if self.eval_return_middle_output and middle_output is not None: - middle_output["middle_output"] = layers.networks[0].call_until_last( - xx - ) + middle_output["middle_output"] = layers.networks[0].call_until_last(xx) if xx_zeros is not None: atom_property -= layers.networks[0](xx_zeros) outs = outs + atom_property @@ -626,9 +620,11 @@ def _apply_networks( ): atom_property -= ll(xx_zeros) if bool_bias: - atom_property = atom_property + self.bias_atom_e[type_i].to(self.prec) + atom_property = atom_property + self.bias_atom_e[type_i].to( + self.prec + ) else: atom_property = atom_property atom_property = torch.where(mask, atom_property, 0.0) outs = outs + atom_property - return outs \ No 
newline at end of file + return outs diff --git a/deepmd/pt/model/task/sog_energy_fitting.py b/deepmd/pt/model/task/sog_energy_fitting.py index e95757c981..eff762eca7 100644 --- a/deepmd/pt/model/task/sog_energy_fitting.py +++ b/deepmd/pt/model/task/sog_energy_fitting.py @@ -2,8 +2,6 @@ import logging from typing import ( Any, - Optional, - Union, ) import numpy as np @@ -33,35 +31,42 @@ LRFittingNet, ) - -SOG_DEFAULT_AMPLITUDE = to_numpy_array(np.array([ - 0.2750, - 0.1375, - 0.0688, - 0.0344, - 0.0172, - 0.0086, - 0.0043, - 0.0021, - 0.0011, - 0.0005, - 0.0003, - 0.0001, -])) -SOG_DEFAULT_SHIFT = to_numpy_array(np.array([ - 2.8, - 5.7, - 11.4, - 22.7, - 45.5, - 91.0, - 182.0, - 364.0, - 728.0, - 1456.0, - 2912.0, - 5823.9, -])) +SOG_DEFAULT_AMPLITUDE = to_numpy_array( + np.array( + [ + 0.2750, + 0.1375, + 0.0688, + 0.0344, + 0.0172, + 0.0086, + 0.0043, + 0.0021, + 0.0011, + 0.0005, + 0.0003, + 0.0001, + ] + ) +) +SOG_DEFAULT_SHIFT = to_numpy_array( + np.array( + [ + 2.8, + 5.7, + 11.4, + 22.7, + 45.5, + 91.0, + 182.0, + 364.0, + 728.0, + 1456.0, + 2912.0, + 5823.9, + ] + ) +) @LRFittingNet.register("sog_energy") @@ -138,7 +143,7 @@ def __init__( dim_out_lr: int, neuron_sr: list[int] = [128, 128, 128], neuron_lr: list[int] = [128, 128, 128], - bias_atom_e: Optional[torch.Tensor] = None, + bias_atom_e: torch.Tensor | None = None, resnet_dt: bool = True, numb_fparam: int = 0, numb_aparam: int = 0, @@ -146,16 +151,16 @@ def __init__( activation_function: str = "tanh", precision: str = DEFAULT_PRECISION, mixed_types: bool = True, - rcond: Optional[float] = None, - seed: Optional[Union[int, list[int]]] = None, + rcond: float | None = None, + seed: int | list[int] | None = None, exclude_types: list[int] = [], - trainable: Union[bool, list[bool]] = True, - remove_vaccum_contribution: Optional[list[bool]] = None, - type_map: Optional[list[str]] = None, + trainable: bool | list[bool] = True, + remove_vaccum_contribution: list[bool] | None = None, + type_map: list[str] | 
None = None, use_aparam_as_mask: bool = False, - default_fparam: Optional[list[float]] = None, - shift: Optional[Union[list[float], torch.Tensor]] = None, - amplitude: Optional[Union[list[float], torch.Tensor]] = None, + default_fparam: list[float] | None = None, + shift: list[float] | torch.Tensor | None = None, + amplitude: list[float] | torch.Tensor | None = None, n_dl: int = 1, remove_self_interaction: bool = False, **kwargs: Any, @@ -236,7 +241,7 @@ def output_def(self) -> FittingOutputDef: reducible=False, r_differentiable=False, c_differentiable=False, - ) + ), ] ) @@ -324,11 +329,11 @@ def forward( self, descriptor: torch.Tensor, atype: torch.Tensor, - gr: Optional[torch.Tensor] = None, - g2: Optional[torch.Tensor] = None, - h2: Optional[torch.Tensor] = None, - fparam: Optional[torch.Tensor] = None, - aparam: Optional[torch.Tensor] = None, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, ) -> dict[str, torch.Tensor]: out = self._forward_common( descriptor=descriptor, @@ -346,6 +351,6 @@ def forward( if "middle_output" in out: result["middle_output"] = out["middle_output"] return result - + # make jit happy with torch 2.0.0 - exclude_types: list[int] \ No newline at end of file + exclude_types: list[int] diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 448c359712..a453b0141c 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2058,11 +2058,17 @@ def fitting_polar() -> list[Argument]: # return fitting_polar() @fitting_args_plugin.register("sog_energy", doc=doc_sog) def fitting_sog_energy() -> list[Argument]: - doc_var_name = "The atomic property name used by the LR fitting net. Usually set to `energy`." + doc_var_name = ( + "The atomic property name used by the LR fitting net. Usually set to `energy`." + ) doc_dim_out_sr = "The output dimension of the short-range fitting branch." 
doc_dim_out_lr = "The output dimension of the long-range fitting branch." - doc_neuron_sr = "The number of neurons in each hidden layer of the short-range fitting net." - doc_neuron_lr = "The number of neurons in each hidden layer of the long-range fitting net." + doc_neuron_sr = ( + "The number of neurons in each hidden layer of the short-range fitting net." + ) + doc_neuron_lr = ( + "The number of neurons in each hidden layer of the long-range fitting net." + ) doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams." doc_default_fparam = "The default frame parameter. If set, when `fparam.npy` files are not included in the data system, this value will be used as the default value for the frame parameter in the fitting net." @@ -2073,7 +2079,9 @@ def fitting_sog_energy() -> list[Argument]: doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be bool or list[bool]. When list[bool] is given, all values must be True to make parameters trainable." doc_rcond = "The condition number used to determine the initial energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details." doc_seed = "Random seed for parameter initialization of the fitting net" - doc_exclude_types = "The excluded atom types whose atomic contributions are set to zero." + doc_exclude_types = ( + "The excluded atom types whose atomic contributions are set to zero." + ) doc_use_aparam_as_mask = ( "Whether to use the aparam as a mask in input." "If True, the aparam will not be used in fitting net for embedding." 
@@ -2246,11 +2254,17 @@ def fitting_sog_energy() -> list[Argument]: @fitting_args_plugin.register("les_energy", doc=doc_les) def fitting_les_energy() -> list[Argument]: - doc_var_name = "The atomic property name used by the LR fitting net. Usually set to `energy`." + doc_var_name = ( + "The atomic property name used by the LR fitting net. Usually set to `energy`." + ) doc_dim_out_sr = "The output dimension of the short-range fitting branch." doc_dim_out_lr = "The output dimension of the long-range fitting branch." - doc_neuron_sr = "The number of neurons in each hidden layer of the short-range fitting net." - doc_neuron_lr = "The number of neurons in each hidden layer of the long-range fitting net." + doc_neuron_sr = ( + "The number of neurons in each hidden layer of the short-range fitting net." + ) + doc_neuron_lr = ( + "The number of neurons in each hidden layer of the long-range fitting net." + ) doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." doc_numb_aparam = "The dimension of the atomic parameter. If set to >0, file `aparam.npy` should be included to provided the input aparams." doc_default_fparam = "The default frame parameter. If set, when `fparam.npy` files are not included in the data system, this value will be used as the default value for the frame parameter in the fitting net." @@ -2261,7 +2275,9 @@ def fitting_les_energy() -> list[Argument]: doc_trainable = "Whether the parameters in the fitting net are trainable. This option can be bool or list[bool]. When list[bool] is given, all values must be True to make parameters trainable." doc_rcond = "The condition number used to determine the initial energy shift for each type of atoms. See `rcond` in :py:meth:`numpy.linalg.lstsq` for more details." doc_seed = "Random seed for parameter initialization of the fitting net" - doc_exclude_types = "The excluded atom types whose atomic contributions are set to zero." 
+ doc_exclude_types = ( + "The excluded atom types whose atomic contributions are set to zero." + ) doc_use_aparam_as_mask = ( "Whether to use the aparam as a mask in input." "If True, the aparam will not be used in fitting net for embedding." @@ -2432,7 +2448,6 @@ def fitting_les_energy() -> list[Argument]: ] - @fitting_args_plugin.register("dipole", doc=doc_dipole) def fitting_dipole() -> list[Argument]: doc_numb_fparam = "The dimension of the frame parameter. If set to >0, file `fparam.npy` should be included to provided the input fparams." diff --git a/examples/water/sog/ab_retain_graph.py b/examples/water/sog/ab_retain_graph.py index 560d50fe71..96b9044fa9 100644 --- a/examples/water/sog/ab_retain_graph.py +++ b/examples/water/sog/ab_retain_graph.py @@ -1,14 +1,23 @@ #!/usr/bin/env python3 -from __future__ import annotations +# SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import ( + annotations, +) import json import time -from pathlib import Path -from types import MethodType +from pathlib import ( + Path, +) +from types import ( + MethodType, +) import torch -from deepmd.pt.model.model import get_model +from deepmd.pt.model.model import ( + get_model, +) def sync(dev: torch.device) -> None: @@ -65,8 +74,12 @@ def patched_apply(self, model_ret, extended_coord, nlist, box, do_atomic_virial) coord_local = extended_coord[:, :nloc, :] box_local = box.view(nf, 3, 3) latent_charge = model_ret["latent_charge"] - corr_redu = self._compute_sog_frame_correction(coord_local, latent_charge, box_local) - model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to(model_ret["energy_redu"].dtype) + corr_redu = self._compute_sog_frame_correction( + coord_local, latent_charge, box_local + ) + model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to( + model_ret["energy_redu"].dtype + ) if self.do_grad_r("energy") or self.do_grad_c("energy"): corr_force_local = -torch.autograd.grad( @@ -83,9 +96,9 @@ def patched_apply(self, model_ret, 
extended_coord, nlist, box, do_atomic_virial) ) corr_force_ext[:, :nloc, :] = corr_force_local if "energy_derv_r" in model_ret: - model_ret["energy_derv_r"] = model_ret["energy_derv_r"] + corr_force_ext.unsqueeze(-2).to( - model_ret["energy_derv_r"].dtype - ) + model_ret["energy_derv_r"] = model_ret[ + "energy_derv_r" + ] + corr_force_ext.unsqueeze(-2).to(model_ret["energy_derv_r"].dtype) if self.do_grad_c("energy"): corr_virial_local = torch.einsum( @@ -95,9 +108,9 @@ def patched_apply(self, model_ret, extended_coord, nlist, box, do_atomic_virial) ).reshape(nf, nloc, 1, 9) corr_virial_redu = corr_virial_local.sum(dim=1) if "energy_derv_c_redu" in model_ret: - model_ret["energy_derv_c_redu"] = model_ret["energy_derv_c_redu"] + corr_virial_redu.to( - model_ret["energy_derv_c_redu"].dtype - ) + model_ret["energy_derv_c_redu"] = model_ret[ + "energy_derv_c_redu" + ] + corr_virial_redu.to(model_ret["energy_derv_c_redu"].dtype) if do_atomic_virial and "energy_derv_c" in model_ret: corr_atom_virial = torch.zeros( (nf, nall, 1, 9), @@ -105,9 +118,9 @@ def patched_apply(self, model_ret, extended_coord, nlist, box, do_atomic_virial) device=corr_virial_local.device, ) corr_atom_virial[:, :nloc, :, :] = corr_virial_local - model_ret["energy_derv_c"] = model_ret["energy_derv_c"] + corr_atom_virial.to( - model_ret["energy_derv_c"].dtype - ) + model_ret["energy_derv_c"] = model_ret[ + "energy_derv_c" + ] + corr_atom_virial.to(model_ret["energy_derv_c"].dtype) return model_ret diff --git a/examples/water/sog/check_sog_consistency_with_cace.py b/examples/water/sog/check_sog_consistency_with_cace.py index 35c34e81c8..4efd9e3b96 100644 --- a/examples/water/sog/check_sog_consistency_with_cace.py +++ b/examples/water/sog/check_sog_consistency_with_cace.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 -from __future__ import annotations +# SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import ( + annotations, +) import importlib.util import json @@ -7,11 +10,15 @@ import torch 
-from deepmd.pt.model.model import get_model +from deepmd.pt.model.model import ( + get_model, +) def main() -> None: - cfg = json.loads(pathlib.Path("examples/water/sog/input_torch.json").read_text())["model"] + cfg = json.loads(pathlib.Path("examples/water/sog/input_torch.json").read_text())[ + "model" + ] dev = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = get_model(cfg).to(dev).eval() fit = model.get_fitting_net() @@ -38,10 +45,14 @@ def main() -> None: latent = torch.randn(nf, nloc, nq, device=dev, dtype=torch.float32) with torch.no_grad(): - corr_deepmd = model._compute_sog_frame_correction(coord, latent, box).squeeze(-1) + corr_deepmd = model._compute_sog_frame_correction(coord, latent, box).squeeze( + -1 + ) corr_cace = [] for i in range(nf): - v = cace_sog.compute_potential_SOG_triclinic_NUFFT(coord[i], latent[i], box[i]) + v = cace_sog.compute_potential_SOG_triclinic_NUFFT( + coord[i], latent[i], box[i] + ) corr_cace.append(v.squeeze()) corr_cace = torch.stack(corr_cace) diff --git a/examples/water/sog/compare_sog_dpa3_timing.py b/examples/water/sog/compare_sog_dpa3_timing.py index f067024d18..d19098fb7b 100644 --- a/examples/water/sog/compare_sog_dpa3_timing.py +++ b/examples/water/sog/compare_sog_dpa3_timing.py @@ -1,13 +1,20 @@ #!/usr/bin/env python3 -from __future__ import annotations +# SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import ( + annotations, +) import json import time -from pathlib import Path +from pathlib import ( + Path, +) import torch -from deepmd.pt.model.model import get_model +from deepmd.pt.model.model import ( + get_model, +) def sync(dev: torch.device) -> None: @@ -54,8 +61,12 @@ def bench_model( def main() -> None: - sog_cfg = json.loads(Path("examples/water/sog/input_torch.json").read_text())["model"] - dpa_cfg = json.loads(Path("examples/water/dpa3/input_torch_copy.json").read_text())["model"] + sog_cfg = json.loads(Path("examples/water/sog/input_torch.json").read_text())[ + "model" + ] 
+ dpa_cfg = json.loads(Path("examples/water/dpa3/input_torch_copy.json").read_text())[ + "model" + ] coord, atype, box, dev = build_input() diff --git a/examples/water/sog/profile_sog_timing.py b/examples/water/sog/profile_sog_timing.py index 55dd7e9079..78876f40ce 100644 --- a/examples/water/sog/profile_sog_timing.py +++ b/examples/water/sog/profile_sog_timing.py @@ -1,24 +1,43 @@ #!/usr/bin/env python3 +# SPDX-License-Identifier: LGPL-3.0-or-later """Profile SOG model runtime breakdown using input_torch.json as model config.""" -from __future__ import annotations +from __future__ import ( + annotations, +) import argparse import json import time import types -from collections import defaultdict -from contextlib import nullcontext -from pathlib import Path -from typing import Any +from collections import ( + defaultdict, +) +from contextlib import ( + nullcontext, +) +from pathlib import ( + Path, +) +from typing import ( + Any, +) -import torch import pytorch_finufft +import torch -from deepmd.pt.model.model import get_model -from deepmd.pt.model.model.sog_model import SOGEnergyModel_ -from deepmd.pt.model.model.transform_output import communicate_extended_output -from deepmd.pt.utils.nlist import extend_input_and_build_neighbor_list +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.model.model.sog_model import ( + SOGEnergyModel_, +) +from deepmd.pt.model.model.transform_output import ( + communicate_extended_output, +) +from deepmd.pt.utils.nlist import ( + extend_input_and_build_neighbor_list, +) def _sync_if_cuda(device: torch.device) -> None: @@ -88,7 +107,9 @@ def _timed_bundle( need_virial: bool, ) -> dict[str, torch.Tensor]: if coord.dim() != 3: - raise ValueError(f"`coord` should be [nf, nloc, 3], got shape {tuple(coord.shape)}") + raise ValueError( + f"`coord` should be [nf, nloc, 3], got shape {tuple(coord.shape)}" + ) if latent_charge.dim() != 3: raise ValueError( f"`latent_charge` should be [nf, nloc, nq], got shape 
{tuple(latent_charge.shape)}" @@ -102,14 +123,26 @@ def _timed_bundle( fitting = self.get_fitting_net() runtime_device = coord.device real_dtype = coord.dtype - complex_dtype = torch.complex128 if real_dtype == torch.float64 else torch.complex64 - with _time_block("fc_cast_inputs", detail_times, device) if collect_flag["on"] else nullcontext(): + complex_dtype = ( + torch.complex128 if real_dtype == torch.float64 else torch.complex64 + ) + with ( + _time_block("fc_cast_inputs", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): latent_charge = latent_charge.to(device=runtime_device, dtype=real_dtype) box = box.to(device=runtime_device, dtype=real_dtype) if box.dim() != 3 or box.shape[-2:] != (3, 3): - raise ValueError(f"`box` should be [nf, 3, 3], got shape {tuple(box.shape)}") + raise ValueError( + f"`box` should be [nf, 3, 3], got shape {tuple(box.shape)}" + ) - with _time_block("fc_param_prepare", detail_times, device) if collect_flag["on"] else nullcontext(): + with ( + _time_block("fc_param_prepare", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): wl, _sl, min_term = self._get_cached_sog_params( fitting, runtime_device, @@ -118,7 +151,9 @@ def _timed_bundle( remove_self_interaction = bool(fitting.remove_self_interaction) n_dl = int(fitting.n_dl) pi_tensor = torch.tensor(torch.pi, dtype=real_dtype, device=runtime_device) - two_pi = torch.tensor(2.0 * torch.pi, dtype=real_dtype, device=runtime_device) + two_pi = torch.tensor( + 2.0 * torch.pi, dtype=real_dtype, device=runtime_device + ) nf, nloc, _ = coord.shape corr = torch.zeros((nf, 1), dtype=real_dtype, device=runtime_device) @@ -138,10 +173,16 @@ def _timed_bundle( q = latent_charge[ff] box_frame = box[ff] - with _time_block("fc_geom_and_points", detail_times, device) if collect_flag["on"] else nullcontext(): + with ( + _time_block("fc_geom_and_points", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): volume = torch.det(box_frame) if 
torch.abs(volume) <= torch.finfo(real_dtype).eps: - raise ValueError("`box` is singular (near-zero volume), cannot run NUFFT.") + raise ValueError( + "`box` is singular (near-zero volume), cannot run NUFFT." + ) cell_inv = torch.linalg.inv(box_frame) r_frac = torch.matmul(r_raw, cell_inv) @@ -154,27 +195,53 @@ def _timed_bundle( ).contiguous() nufft_points = r_in.transpose(0, 1).contiguous() - with _time_block("fc_build_k_grid", detail_times, device) if collect_flag["on"] else nullcontext(): + with ( + _time_block("fc_build_k_grid", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): norms = torch.norm(box_frame, dim=1) nk = tuple(max(1, int(n.item() / n_dl)) for n in norms) - n1 = torch.arange(-nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype) - n2 = torch.arange(-nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype) - n3 = torch.arange(-nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype) + n1 = torch.arange( + -nk[0], nk[0] + 1, device=runtime_device, dtype=real_dtype + ) + n2 = torch.arange( + -nk[1], nk[1] + 1, device=runtime_device, dtype=real_dtype + ) + n3 = torch.arange( + -nk[2], nk[2] + 1, device=runtime_device, dtype=real_dtype + ) kx_grid, ky_grid, kz_grid = torch.meshgrid(n1, n2, n3, indexing="ij") k_sq = kx_grid**2 + ky_grid**2 + kz_grid**2 zero_mask = k_sq == 0 - with _time_block("fc_build_kfac", detail_times, device) if collect_flag["on"] else nullcontext(): + with ( + _time_block("fc_build_kfac", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): kfac = wl.view(1, 1, 1, -1) * torch.exp(k_sq.unsqueeze(-1) * min_term) kfac = kfac.sum(dim=-1) kfac = kfac.to(dtype=real_dtype) kfac[zero_mask] = 0.0 - with _time_block("fc_prepare_charge", detail_times, device) if collect_flag["on"] else nullcontext(): + with ( + _time_block("fc_prepare_charge", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): q_t = q.transpose(0, 1).contiguous() - charge = torch.complex(q_t, 
torch.zeros_like(q_t)).to(dtype=complex_dtype).contiguous() + charge = ( + torch.complex(q_t, torch.zeros_like(q_t)) + .to(dtype=complex_dtype) + .contiguous() + ) - with _time_block("fc_nufft_type1", detail_times, device) if collect_flag["on"] else nullcontext(): + with ( + _time_block("fc_nufft_type1", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): recon = pytorch_finufft.functional.finufft_type1( nufft_points, charge, @@ -183,23 +250,41 @@ def _timed_bundle( isign=-1, ) - with _time_block("fc_energy_reduce", detail_times, device) if collect_flag["on"] else nullcontext(): + with ( + _time_block("fc_energy_reduce", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): rho_sq = recon.real.square() + recon.imag.square() corr[ff, 0] = (kfac.unsqueeze(0) * rho_sq).sum() / (2.0 * volume) if need_force: - with _time_block("fc_prepare_force_conv", detail_times, device) if collect_flag["on"] else nullcontext(): + with ( + _time_block("fc_prepare_force_conv", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): conv = kfac.unsqueeze(0).to(dtype=complex_dtype) * recon - with _time_block("fc_prepare_force_kgrid", detail_times, device) if collect_flag["on"] else nullcontext(): + with ( + _time_block("fc_prepare_force_kgrid", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): kk1 = torch.fft.ifftshift(kx_grid, dim=0) kk2 = torch.fft.ifftshift(ky_grid, dim=1) kk3 = torch.fft.ifftshift(kz_grid, dim=2) k_grid = torch.stack((kk1, kk2, kk3), dim=0) g_cart = two_pi * torch.einsum("ik,k...->i...", cell_inv, k_grid) - grad_conv = (1j * g_cart.unsqueeze(1).to(dtype=complex_dtype)) * conv.unsqueeze(0) - - with _time_block("fc_nufft_type2_force", detail_times, device) if collect_flag["on"] else nullcontext(): + grad_conv = ( + 1j * g_cart.unsqueeze(1).to(dtype=complex_dtype) + ) * conv.unsqueeze(0) + + with ( + _time_block("fc_nufft_type2_force", detail_times, device) + if collect_flag["on"] + else 
nullcontext() + ): grad_field = pytorch_finufft.functional.finufft_type2( nufft_points, grad_conv, @@ -207,13 +292,25 @@ def _timed_bundle( isign=1, ) - with _time_block("fc_force_reduce", detail_times, device) if collect_flag["on"] else nullcontext(): - force_frame = -(q_t.unsqueeze(0) * grad_field.real.to(dtype=real_dtype)).sum(dim=1).transpose(0, 1) + with ( + _time_block("fc_force_reduce", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + force_frame = ( + -(q_t.unsqueeze(0) * grad_field.real.to(dtype=real_dtype)) + .sum(dim=1) + .transpose(0, 1) + ) force_frame = force_frame / volume force_local[ff] = force_frame if need_virial: - with _time_block("fc_virial_local", detail_times, device) if collect_flag["on"] else nullcontext(): + with ( + _time_block("fc_virial_local", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): virial_local[ff] = torch.einsum( "ai,aj->aij", force_frame, @@ -221,7 +318,11 @@ def _timed_bundle( ).reshape(nloc, 1, 9) if remove_self_interaction: - with _time_block("fc_self_interaction", detail_times, device) if collect_flag["on"] else nullcontext(): + with ( + _time_block("fc_self_interaction", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): diag_sum = kfac.sum(dim=-1).sum(dim=-1).sum(dim=-1) / (2.0 * volume) corr[ff, 0] -= torch.sum(q**2) * diag_sum @@ -240,7 +341,11 @@ def _timed_apply( box: torch.Tensor | None, do_atomic_virial: bool, ) -> dict[str, torch.Tensor]: - with _time_block("fc_guard_and_slice", detail_times, device) if collect_flag["on"] else nullcontext(): + with ( + _time_block("fc_guard_and_slice", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): if box is None or "latent_charge" not in model_ret: return model_ret @@ -251,9 +356,15 @@ def _timed_apply( latent_charge = model_ret["latent_charge"] need_force = self.do_grad_r("energy") or self.do_grad_c("energy") need_virial = self.do_grad_c("energy") - latent_charge_runtime = 
latent_charge if self.training else latent_charge.detach() + latent_charge_runtime = ( + latent_charge if self.training else latent_charge.detach() + ) - with _time_block("fc_compute_corr_bundle", detail_times, device) if collect_flag["on"] else nullcontext(): + with ( + _time_block("fc_compute_corr_bundle", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): corr_bundle = self._compute_sog_frame_correction_bundle( coord_local, latent_charge_runtime, @@ -263,13 +374,23 @@ def _timed_apply( ) corr_redu = corr_bundle["corr_redu"] - with _time_block("fc_add_energy", detail_times, device) if collect_flag["on"] else nullcontext(): - model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to(model_ret["energy_redu"].dtype) + with ( + _time_block("fc_add_energy", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): + model_ret["energy_redu"] = model_ret["energy_redu"] + corr_redu.to( + model_ret["energy_redu"].dtype + ) if need_force: corr_force_local = corr_bundle["force_local"].to(coord_local.dtype) - with _time_block("fc_scatter_force", detail_times, device) if collect_flag["on"] else nullcontext(): + with ( + _time_block("fc_scatter_force", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): corr_force_ext = torch.zeros( (nf, nall, 3), dtype=corr_force_local.dtype, @@ -277,18 +398,26 @@ def _timed_apply( ) corr_force_ext[:, :nloc, :] = corr_force_local if "energy_derv_r" in model_ret: - model_ret["energy_derv_r"] = model_ret["energy_derv_r"] + corr_force_ext.unsqueeze(-2).to( + model_ret["energy_derv_r"] = model_ret[ + "energy_derv_r" + ] + corr_force_ext.unsqueeze(-2).to( model_ret["energy_derv_r"].dtype ) if need_virial: - corr_virial_local = corr_bundle["virial_local"].to(corr_force_local.dtype) - with _time_block("fc_virial_update", detail_times, device) if collect_flag["on"] else nullcontext(): + corr_virial_local = corr_bundle["virial_local"].to( + corr_force_local.dtype + ) + with ( + 
_time_block("fc_virial_update", detail_times, device) + if collect_flag["on"] + else nullcontext() + ): corr_virial_redu = corr_virial_local.sum(dim=1) if "energy_derv_c_redu" in model_ret: - model_ret["energy_derv_c_redu"] = model_ret["energy_derv_c_redu"] + corr_virial_redu.to( - model_ret["energy_derv_c_redu"].dtype - ) + model_ret["energy_derv_c_redu"] = model_ret[ + "energy_derv_c_redu" + ] + corr_virial_redu.to(model_ret["energy_derv_c_redu"].dtype) if do_atomic_virial and "energy_derv_c" in model_ret: corr_atom_virial = torch.zeros( (nf, nall, 1, 9), @@ -296,9 +425,9 @@ def _timed_apply( device=corr_virial_local.device, ) corr_atom_virial[:, :nloc, :, :] = corr_virial_local - model_ret["energy_derv_c"] = model_ret["energy_derv_c"] + corr_atom_virial.to( - model_ret["energy_derv_c"].dtype - ) + model_ret["energy_derv_c"] = model_ret[ + "energy_derv_c" + ] + corr_atom_virial.to(model_ret["energy_derv_c"].dtype) return model_ret @@ -514,7 +643,7 @@ def _format_report(timings: dict[str, float]) -> str: ratio = (timings[k] / total * 100.0) if total > 0 else 0.0 lines.append(f" - {k:20s}: {ms:10.3f} ms ({ratio:6.2f}%)") - lines.append(f" - {'total(sum)':20s}: {total*1000.0:10.3f} ms (100.00%)") + lines.append(f" - {'total(sum)':20s}: {total * 1000.0:10.3f} ms (100.00%)") return "\n".join(lines) @@ -526,7 +655,9 @@ def main() -> None: default=Path("examples/water/sog/input_torch.json"), ) parser.add_argument("--device", type=str, default="cuda") - parser.add_argument("--dtype", type=str, default="float32", choices=["float32", "float64"]) + parser.add_argument( + "--dtype", type=str, default="float32", choices=["float32", "float64"] + ) parser.add_argument("--nframes", type=int, default=1) parser.add_argument("--nloc", type=int, default=192) parser.add_argument("--box-len", type=float, default=20.0) diff --git a/examples/water/sog/profile_sog_whatif.py b/examples/water/sog/profile_sog_whatif.py index 5132e5f180..d56e30091a 100644 --- 
a/examples/water/sog/profile_sog_whatif.py +++ b/examples/water/sog/profile_sog_whatif.py @@ -1,15 +1,23 @@ #!/usr/bin/env python3 -from __future__ import annotations +# SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import ( + annotations, +) import copy import json -from pathlib import Path +from pathlib import ( + Path, +) import torch +from profile_sog_timing import ( + profile, +) -from deepmd.pt.model.model import get_model -from profile_sog_timing import profile - +from deepmd.pt.model.model import ( + get_model, +) CFG_PATH = Path("examples/water/sog/input_torch.json") @@ -38,10 +46,10 @@ def run(tag: str, model_cfg: dict) -> None: ) * 1000.0 print( f"{tag}: total={total:.3f}ms, " - f"lower_super={t['lower_super']*1000.0:.3f}ms, " - f"lower_frame_corr={t['lower_frame_corr']*1000.0:.3f}ms, " - f"nufft1={t.get('nufft_type1', 0.0)*1000.0:.3f}ms, " - f"nufft2={t.get('nufft_type2', 0.0)*1000.0:.3f}ms" + f"lower_super={t['lower_super'] * 1000.0:.3f}ms, " + f"lower_frame_corr={t['lower_frame_corr'] * 1000.0:.3f}ms, " + f"nufft1={t.get('nufft_type1', 0.0) * 1000.0:.3f}ms, " + f"nufft2={t.get('nufft_type2', 0.0) * 1000.0:.3f}ms" ) diff --git a/source/tests/pt/model/test_les_working_layer.py b/source/tests/pt/model/test_les_working_layer.py index 02d0cd7578..b0ce487a1f 100644 --- a/source/tests/pt/model/test_les_working_layer.py +++ b/source/tests/pt/model/test_les_working_layer.py @@ -26,11 +26,12 @@ extend_input_and_build_neighbor_list, ) - dtype = env.GLOBAL_PT_FLOAT_PRECISION -def _reduce_extended_tensor(extended_tensor: torch.Tensor, mapping: torch.Tensor, nloc: int): +def _reduce_extended_tensor( + extended_tensor: torch.Tensor, mapping: torch.Tensor, nloc: int +): nframes = extended_tensor.shape[0] ext_dims = extended_tensor.shape[2:] reduced_tensor = torch.zeros( @@ -88,7 +89,11 @@ def setUp(self) -> None: dtype=dtype, device=env.DEVICE, ) - cell = torch.eye(3, dtype=dtype, device=env.DEVICE).unsqueeze(0).repeat(self.nf, 1, 1) + cell = ( + 
torch.eye(3, dtype=dtype, device=env.DEVICE) + .unsqueeze(0) + .repeat(self.nf, 1, 1) + ) cell = cell * 5.0 self.coord = coord.reshape(self.nf, self.nloc * 3) self.cell = cell.reshape(self.nf, 9) @@ -171,10 +176,16 @@ def test_forward_and_forward_lower_consistency(self) -> None: comm_dict={"box": cell33}, ) - torch.testing.assert_close(fw_lower["energy"], fw["energy"], rtol=1e-8, atol=1e-8) - torch.testing.assert_close(fw_lower["virial"], fw["virial"], rtol=1e-7, atol=1e-7) + torch.testing.assert_close( + fw_lower["energy"], fw["energy"], rtol=1e-8, atol=1e-8 + ) + torch.testing.assert_close( + fw_lower["virial"], fw["virial"], rtol=1e-7, atol=1e-7 + ) - reduced_force = _reduce_extended_tensor(fw_lower["extended_force"], mapping, self.nloc) + reduced_force = _reduce_extended_tensor( + fw_lower["extended_force"], mapping, self.nloc + ) torch.testing.assert_close(reduced_force, fw["force"], rtol=1e-7, atol=1e-7) def test_long_range_params_have_gradient_path(self) -> None: diff --git a/source/tests/pt/model/test_sog_working_layer.py b/source/tests/pt/model/test_sog_working_layer.py index 924b02bfa6..7e4ed7d89b 100644 --- a/source/tests/pt/model/test_sog_working_layer.py +++ b/source/tests/pt/model/test_sog_working_layer.py @@ -26,11 +26,12 @@ extend_input_and_build_neighbor_list, ) - dtype = env.GLOBAL_PT_FLOAT_PRECISION -def _reduce_extended_tensor(extended_tensor: torch.Tensor, mapping: torch.Tensor, nloc: int): +def _reduce_extended_tensor( + extended_tensor: torch.Tensor, mapping: torch.Tensor, nloc: int +): nframes = extended_tensor.shape[0] ext_dims = extended_tensor.shape[2:] reduced_tensor = torch.zeros( @@ -88,7 +89,11 @@ def setUp(self) -> None: dtype=dtype, device=env.DEVICE, ) - cell = torch.eye(3, dtype=dtype, device=env.DEVICE).unsqueeze(0).repeat(self.nf, 1, 1) + cell = ( + torch.eye(3, dtype=dtype, device=env.DEVICE) + .unsqueeze(0) + .repeat(self.nf, 1, 1) + ) cell = cell * 5.0 self.coord = coord.reshape(self.nf, self.nloc * 3) self.cell = 
cell.reshape(self.nf, 9) @@ -171,10 +176,16 @@ def test_forward_and_forward_lower_consistency(self) -> None: comm_dict={"box": cell33}, ) - torch.testing.assert_close(fw_lower["energy"], fw["energy"], rtol=1e-8, atol=1e-8) - torch.testing.assert_close(fw_lower["virial"], fw["virial"], rtol=1e-7, atol=1e-7) + torch.testing.assert_close( + fw_lower["energy"], fw["energy"], rtol=1e-8, atol=1e-8 + ) + torch.testing.assert_close( + fw_lower["virial"], fw["virial"], rtol=1e-7, atol=1e-7 + ) - reduced_force = _reduce_extended_tensor(fw_lower["extended_force"], mapping, self.nloc) + reduced_force = _reduce_extended_tensor( + fw_lower["extended_force"], mapping, self.nloc + ) torch.testing.assert_close(reduced_force, fw["force"], rtol=1e-7, atol=1e-7) From 6c3f79298b4280cda9fb8e782a06151492b9a611 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Apr 2026 09:23:50 +0000 Subject: [PATCH 8/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pt/model/atomic_model/lr_energy_atomic_model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/deepmd/pt/model/atomic_model/lr_energy_atomic_model.py b/deepmd/pt/model/atomic_model/lr_energy_atomic_model.py index c44ef3f1d7..a21127595c 100644 --- a/deepmd/pt/model/atomic_model/lr_energy_atomic_model.py +++ b/deepmd/pt/model/atomic_model/lr_energy_atomic_model.py @@ -1,8 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from collections.abc import ( + Iterable, +) from typing import ( Any, ) -from collections.abc import Iterable import torch