From ec2e031e4384a66e96c2d64d45f4dec897bb769d Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 07:48:20 +0800 Subject: [PATCH 01/60] implement pytorch-exportable for se_e2_a descriptor --- deepmd/backend/pt_expt.py | 126 ++++++++++++++++ deepmd/dpmodel/descriptor/se_e2_a.py | 6 +- deepmd/pt_expt/__init__.py | 1 + deepmd/pt_expt/descriptor/__init__.py | 8 ++ deepmd/pt_expt/descriptor/se_e2_a.py | 101 +++++++++++++ deepmd/pt_expt/utils/__init__.py | 1 + deepmd/pt_expt/utils/network.py | 130 +++++++++++++++++ source/tests/consistent/common.py | 80 ++++++++++- source/tests/consistent/descriptor/common.py | 31 +++- .../consistent/descriptor/test_se_e2_a.py | 60 ++++++++ source/tests/pt_expt/__init__.py | 1 + source/tests/pt_expt/model/__init__.py | 1 + source/tests/pt_expt/model/test_se_e2_a.py | 135 ++++++++++++++++++ 13 files changed, 676 insertions(+), 5 deletions(-) create mode 100644 deepmd/backend/pt_expt.py create mode 100644 deepmd/pt_expt/__init__.py create mode 100644 deepmd/pt_expt/descriptor/__init__.py create mode 100644 deepmd/pt_expt/descriptor/se_e2_a.py create mode 100644 deepmd/pt_expt/utils/__init__.py create mode 100644 deepmd/pt_expt/utils/network.py create mode 100644 source/tests/pt_expt/__init__.py create mode 100644 source/tests/pt_expt/model/__init__.py create mode 100644 source/tests/pt_expt/model/test_se_e2_a.py diff --git a/deepmd/backend/pt_expt.py b/deepmd/backend/pt_expt.py new file mode 100644 index 0000000000..38745c690c --- /dev/null +++ b/deepmd/backend/pt_expt.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from collections.abc import ( + Callable, +) +from importlib.util import ( + find_spec, +) +from typing import ( + TYPE_CHECKING, + ClassVar, +) + +from deepmd.backend.backend import ( + Backend, +) + +if TYPE_CHECKING: + from argparse import ( + Namespace, + ) + + from deepmd.infer.deep_eval import ( + DeepEvalBackend, + ) + from deepmd.utils.neighbor_stat import ( + NeighborStat, + ) + + +@Backend.register("pt-expt") +@Backend.register("pytorch-exportable") +class PyTorchExportableBackend(Backend): + """PyTorch exportable backend.""" + + name = "PyTorch Exportable" + """The formal name of the backend.""" + features: ClassVar[Backend.Feature] = ( + Backend.Feature.ENTRY_POINT + | Backend.Feature.DEEP_EVAL + | Backend.Feature.NEIGHBOR_STAT + | Backend.Feature.IO + ) + """The features of the backend.""" + suffixes: ClassVar[list[str]] = [".pth", ".pt"] + """The suffixes of the backend.""" + + def is_available(self) -> bool: + """Check if the backend is available. + + Returns + ------- + bool + Whether the backend is available. + """ + return find_spec("torch") is not None + + @property + def entry_point_hook(self) -> Callable[["Namespace"], None]: + """The entry point hook of the backend. + + Returns + ------- + Callable[[Namespace], None] + The entry point hook of the backend. + """ + from deepmd.pt.entrypoints.main import main as deepmd_main + + return deepmd_main + + @property + def deep_eval(self) -> type["DeepEvalBackend"]: + """The Deep Eval backend of the backend. + + Returns + ------- + type[DeepEvalBackend] + The Deep Eval backend of the backend. + """ + from deepmd.pt.infer.deep_eval import DeepEval as DeepEvalPT + + return DeepEvalPT + + @property + def neighbor_stat(self) -> type["NeighborStat"]: + """The neighbor statistics of the backend. + + Returns + ------- + type[NeighborStat] + The neighbor statistics of the backend. + """ + from deepmd.pt.utils.neighbor_stat import ( + NeighborStat, + ) + + return NeighborStat + + @property + def serialize_hook(self) -> Callable[[str], dict]: + """The serialize hook to convert the model file to a dictionary. + + Returns + ------- + Callable[[str], dict] + The serialize hook of the backend. + """ + from deepmd.pt.utils.serialization import ( + serialize_from_file, + ) + + return serialize_from_file + + @property + def deserialize_hook(self) -> Callable[[str, dict], None]: + """The deserialize hook to convert the dictionary to a model file. + + Returns + ------- + Callable[[str, dict], None] + The deserialize hook of the backend. + """ + from deepmd.pt.utils.serialization import ( + deserialize_to_file, + ) + + return deserialize_to_file diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index c09a6cbdc3..a6b17bf69a 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -607,7 +607,11 @@ def call( sec = self.sel_cumsum ng = self.neuron[-1] - gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype) + gr = xp.zeros( + [nf * nloc, ng, 4], + dtype=self.dstd.dtype, + device=array_api_compat.device(coord_ext), + ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # merge nf and nloc axis, so for type_one_side == False, # we don't require atype is the same in all frames diff --git a/deepmd/pt_expt/__init__.py b/deepmd/pt_expt/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt_expt/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py new file mode 100644 index 0000000000..fdac48ed41 --- /dev/null +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .se_e2_a import ( + DescrptSeA, +) + +__all__ = [ + "DescrptSeA", +] diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py new file mode 100644 index 0000000000..4334011ec3 --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch # noqa: TID253 + +from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP +from deepmd.pt.model.descriptor.base_descriptor import ( # noqa: TID253 + BaseDescriptor, +) +from deepmd.pt.utils import ( # noqa: TID253 + env, +) +from deepmd.pt.utils.exclude_mask import ( # noqa: TID253 + PairExcludeMask, +) +from deepmd.pt_expt.utils.network import ( + NetworkCollection, +) + + +@BaseDescriptor.register("se_e2_a_expt") +@BaseDescriptor.register("se_a_expt") +class DescrptSeA(DescrptSeADP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptSeADP.__init__(self, *args, **kwargs) + self._convert_state() + + def __setattr__(self, name: str, value: Any) -> None: + if name in {"davg", "dstd"} and "_buffers" in self.__dict__: + tensor = ( + None if value is None else torch.as_tensor(value, device=env.DEVICE) + ) + if name in self._buffers: + self._buffers[name] = tensor + return + return super().__setattr__(name, tensor) + if name == "embeddings" and "_modules" in self.__dict__: + if value is not None and not isinstance(value, torch.nn.Module): + if hasattr(value, "serialize"): + value = NetworkCollection.deserialize(value.serialize()) + elif isinstance(value, dict): + value = NetworkCollection.deserialize(value) + return super().__setattr__(name, value) + if name == "emask" and "_modules" in self.__dict__: + if value is not None and not isinstance(value, torch.nn.Module): + value = PairExcludeMask( + self.ntypes, exclude_types=list(value.get_exclude_types()) + ) + return super().__setattr__(name, value) + return super().__setattr__(name, value) + + def _convert_state(self) -> None: + if self.davg is not None: + davg = torch.as_tensor(self.davg, device=env.DEVICE) + if "davg" in self._buffers: + self._buffers["davg"] = davg + else: + if hasattr(self, "davg"): + delattr(self, "davg") + self.register_buffer("davg", davg) + if self.dstd is not None: + dstd = torch.as_tensor(self.dstd, device=env.DEVICE) + if "dstd" in self._buffers: + self._buffers["dstd"] = dstd + else: + if hasattr(self, "dstd"): + delattr(self, "dstd") + self.register_buffer("dstd", dstd) + if self.embeddings is not None: + self.embeddings = NetworkCollection.deserialize(self.embeddings.serialize()) + if self.emask is not None: + self.emask = PairExcludeMask( + self.ntypes, exclude_types=list(self.emask.get_exclude_types()) + ) + + def forward( + self, + nlist: torch.Tensor, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_atype_embd: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + type_embedding: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + del extended_atype_embd, type_embedding + descrpt, rot_mat, g2, h2, sw = self.call( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + return descrpt, rot_mat, g2, h2, sw diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt_expt/utils/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py new file mode 100644 index 0000000000..f29d8970b3 --- /dev/null +++ b/deepmd/pt_expt/utils/network.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + ClassVar, + Self, +) + +import numpy as np +import torch # noqa: TID253 + +from deepmd.dpmodel.common import ( + NativeOP, +) +from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP +from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP +from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP +from deepmd.dpmodel.utils.network import ( + make_embedding_network, + make_fitting_network, + make_multilayer_network, +) +from deepmd.pt.utils import ( # noqa: TID253 + env, +) + + +def _to_torch_array(value: Any) -> torch.Tensor | None: + if value is None: + return None + if torch.is_tensor(value): + return value + return torch.as_tensor(value, device=env.DEVICE) + + +class TorchArrayParam(torch.nn.Parameter): + def __new__(cls, data: Any = None, requires_grad: bool = True) -> Self: + return torch.nn.Parameter.__new__(cls, data, requires_grad) + + def __array__(self, dtype: Any | None = None) -> np.ndarray: + arr = self.detach().cpu().numpy() + if dtype is None: + return arr + return arr.astype(dtype) + + +class NativeLayer(NativeLayerDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + NativeLayerDP.__init__(self, *args, **kwargs) + for name in ("w", "b", "idt"): + if name in self._parameters or name in self._buffers: + continue + val = _to_torch_array(getattr(self, name)) + if val is None: + continue + if self.trainable: + if hasattr(self, name) and name not in self._parameters: + delattr(self, name) + self.register_parameter(name, TorchArrayParam(val, requires_grad=True)) + else: + if hasattr(self, name) and name not in self._buffers: + delattr(self, name) + self.register_buffer(name, val) + + def __setattr__(self, name: str, value: Any) -> None: + if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: + val = _to_torch_array(value) + if val is None: + return super().__setattr__(name, None) + if getattr(self, "trainable", False): + param = ( + value + if isinstance(value, TorchArrayParam) + else TorchArrayParam(val, requires_grad=True) + ) + if name in self._parameters: + self._parameters[name] = param + return + return super().__setattr__(name, param) + if name in self._buffers: + self._buffers[name] = val + return + return super().__setattr__(name, val) + return super().__setattr__(name, value) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.call(x) + + +class NativeNet(make_multilayer_network(NativeLayer, NativeOP), torch.nn.Module): + def __init__(self, layers: list[dict] | None = None) -> None: + torch.nn.Module.__init__(self) + super().__init__(layers) + self.layers = torch.nn.ModuleList(self.layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.call(x) + + +class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)): + pass + + +class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): + pass + + +class NetworkCollection(NetworkCollectionDP, torch.nn.Module): + NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = { + "network": NativeNet, + "embedding_network": EmbeddingNet, + "fitting_network": FittingNet, + } + + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + super().__init__(*args, **kwargs) + self._module_networks = torch.nn.ModuleDict() + for idx, net in enumerate(self._networks): + if isinstance(net, torch.nn.Module): + self._module_networks[str(idx)] = net + + def __setitem__(self, key: int | tuple, value: Any) -> None: + super().__setitem__(key, value) + if isinstance(value, torch.nn.Module): + self._module_networks[str(self._convert_key(key))] = value + + +class LayerNorm(LayerNormDP, NativeLayer): + pass diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index 88fad4e10b..3d60f6def0 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -41,6 +41,11 @@ INSTALLED_TF = Backend.get_backend("tensorflow")().is_available() INSTALLED_PT = Backend.get_backend("pytorch")().is_available() +try: + _PT_EXPT_BACKEND = Backend.get_backend("pytorch-exportable") +except (KeyError, RuntimeError): + _PT_EXPT_BACKEND = None +INSTALLED_PT_EXPT = _PT_EXPT_BACKEND is not None and _PT_EXPT_BACKEND().is_available() INSTALLED_JAX = Backend.get_backend("jax")().is_available() INSTALLED_PD = Backend.get_backend("paddle")().is_available() INSTALLED_ARRAY_API_STRICT = find_spec("array_api_strict") is not None @@ -67,6 +72,7 @@ "INSTALLED_JAX", "INSTALLED_PD", "INSTALLED_PT", + "INSTALLED_PT_EXPT", "INSTALLED_TF", "CommonTest", "CommonTest", @@ -86,6 +92,8 @@ class CommonTest(ABC): """Native DP model class.""" pt_class: ClassVar[type | None] """PyTorch model class.""" + pt_expt_class: ClassVar[type | None] + """PyTorch exportable model class.""" jax_class: ClassVar[type | None] """JAX model class.""" pd_class: ClassVar[type | None] @@ -99,6 +107,8 @@ class CommonTest(ABC): """Whether to skip the TensorFlow model.""" skip_pt: ClassVar[bool] = not INSTALLED_PT """Whether to skip the PyTorch model.""" + skip_pt_expt: ClassVar[bool] = not INSTALLED_PT_EXPT + """Whether to skip the PyTorch exportable model.""" # we may usually skip jax before jax is fully supported skip_jax: ClassVar[bool] = True """Whether to skip the JAX model.""" @@ -176,6 +186,16 @@ def eval_pt(self, pt_obj: Any) -> Any: The object of PT """ + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + """Evaluate the return value of PT exportable. + + Parameters + ---------- + pt_expt_obj : Any + The object of PT exportable + """ + raise NotImplementedError("Not implemented") + def eval_jax(self, jax_obj: Any) -> Any: """Evaluate the return value of JAX. @@ -212,9 +232,10 @@ class RefBackend(Enum): TF = 1 DP = 2 PT = 3 - PD = 4 - JAX = 5 - ARRAY_API_STRICT = 6 + PT_EXPT = 4 + PD = 5 + JAX = 6 + ARRAY_API_STRICT = 7 @abstractmethod def extract_ret(self, ret: Any, backend: RefBackend) -> tuple[np.ndarray, ...]: @@ -275,6 +296,11 @@ def get_dp_ret_serialization_from_cls(self, obj): data = obj.serialize() return ret, data + def get_pt_expt_ret_serialization_from_cls(self, obj): + ret = self.eval_pt_expt(obj) + data = obj.serialize() + return ret, data + def get_jax_ret_serialization_from_cls(self, obj): ret = self.eval_jax(obj) data = obj.serialize() @@ -301,6 +327,8 @@ def get_reference_backend(self): return self.RefBackend.TF if not self.skip_pt: return self.RefBackend.PT + if not self.skip_pt_expt and self.pt_expt_class is not None: + return self.RefBackend.PT_EXPT if not self.skip_jax: return self.RefBackend.JAX if not self.skip_pd: @@ -320,6 +348,11 @@ def get_reference_ret_serialization(self, ref: RefBackend): if ref == self.RefBackend.PT: obj = self.init_backend_cls(self.pt_class) return self.get_pt_ret_serialization_from_cls(obj) + if ref == self.RefBackend.PT_EXPT: + if self.pt_expt_class is None: + raise ValueError("PT exportable class is not set") + obj = self.init_backend_cls(self.pt_expt_class) + return self.get_pt_expt_ret_serialization_from_cls(obj) if ref == self.RefBackend.JAX: obj = self.init_backend_cls(self.jax_class) return self.get_jax_ret_serialization_from_cls(obj) @@ -456,6 +489,47 @@ def test_pt_self_consistent(self) -> None: else: self.assertEqual(rr1, rr2) + def test_pt_expt_consistent_with_ref(self) -> None: + """Test whether PT exportable and reference are consistent.""" + if self.skip_pt_expt or self.pt_expt_class is None: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.PT_EXPT: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = self.extract_ret(ret1, ref_backend) + obj = self.pt_expt_class.deserialize(data1) + ret2 = self.eval_pt_expt(obj) + ret2 = self.extract_ret(ret2, self.RefBackend.PT_EXPT) + data2 = obj.serialize() + if obj.__class__.__name__.startswith(("Polar", "Dipole", "DOS")): + common_keys = set(data1.keys()) & set(data2.keys()) + data1 = {k: data1[k] for k in common_keys} + data2 = {k: data2[k] for k in common_keys} + # drop @variables since they are not equal + data1.pop("@variables", None) + data2.pop("@variables", None) + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2, strict=True): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + + def test_pt_expt_self_consistent(self) -> None: + """Test whether PT exportable is self consistent.""" + if self.skip_pt_expt or self.pt_expt_class is None: + self.skipTest("Unsupported backend") + obj1 = self.init_backend_cls(self.pt_expt_class) + ret1, data1 = self.get_pt_expt_ret_serialization_from_cls(obj1) + obj2 = self.pt_expt_class.deserialize(data1) + ret2, data2 = self.get_pt_expt_ret_serialization_from_cls(obj2) + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2, strict=True): + if isinstance(rr1, np.ndarray) and isinstance(rr2, np.ndarray): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + else: + self.assertEqual(rr1, rr2) + def test_jax_consistent_with_ref(self) -> None: """Test whether JAX and reference are consistent.""" if self.skip_jax: diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 8af1c7ea64..7c8cbce744 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -21,10 +21,11 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, ) -if INSTALLED_PT: +if INSTALLED_PT or INSTALLED_PT_EXPT: import torch from deepmd.pt.utils.env import DEVICE as PT_DEVICE @@ -143,6 +144,34 @@ def eval_pt_descriptor( for x in pt_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) ] + def eval_pt_expt_descriptor( + self, + pt_expt_obj: Any, + natoms: np.ndarray, + coords: np.ndarray, + atype: np.ndarray, + box: np.ndarray, + mixed_types: bool = False, + ) -> Any: + ext_coords, ext_atype, mapping = extend_coord_with_ghosts( + torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), + torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), + torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), + pt_expt_obj.get_rcut(), + ) + nlist = build_neighbor_list( + ext_coords, + ext_atype, + natoms[0], + pt_expt_obj.get_rcut(), + pt_expt_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + return [ + x.detach().cpu().numpy() if torch.is_tensor(x) else x + for x in pt_expt_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) + ] + def eval_jax_descriptor( self, jax_obj: Any, diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index b345a61ed3..68a0068965 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -16,6 +16,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -33,6 +34,10 @@ ) else: DescrptSeAPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.se_e2_a import DescrptSeA as DescrptSeAPTExpt +else: + DescrptSeAPTExpt = None if INSTALLED_TF: from deepmd.tf.descriptor.se_a import DescrptSeA as DescrptSeATF else: @@ -107,6 +112,17 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + env_protection, + ) = self.param + return (not type_one_side) or CommonTest.skip_pt_expt + @property def skip_dp(self) -> bool: ( @@ -165,6 +181,7 @@ def skip_array_api_strict(self) -> bool: tf_class = DescrptSeATF dp_class = DescrptSeADP pt_class = DescrptSeAPT + pt_expt_class = DescrptSeAPTExpt jax_class = DescrptSeAJAX pd_class = DescrptSeAPD array_api_strict_class = DescrptSeAArrayAPIStrict @@ -244,6 +261,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, @@ -351,6 +377,17 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + env_protection, + ) = self.param + return (not type_one_side) or CommonTest.skip_pt_expt + @property def skip_dp(self) -> bool: ( @@ -402,6 +439,7 @@ def skip_array_api_strict(self) -> bool: tf_class = DescrptSeATF dp_class = DescrptSeADP pt_class = DescrptSeAPT + pt_expt_class = DescrptSeAPTExpt jax_class = DescrptSeAJAX pd_class = DescrptSeAPD array_api_strict_class = DescrptSeAArrayAPIStrict @@ -505,6 +543,28 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + pt_expt_obj.compute_input_stats( + [ + { + "r0": None, + "coord": torch.from_numpy(self.coords) + .reshape(-1, self.natoms[0], 3) + .to(env.DEVICE), + "atype": torch.from_numpy(self.atype.reshape(1, -1)).to(env.DEVICE), + "box": torch.from_numpy(self.box.reshape(1, 3, 3)).to(env.DEVICE), + "natoms": self.natoms[0], + } + ] + ) + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: jax_obj.compute_input_stats( [ diff --git a/source/tests/pt_expt/__init__.py b/source/tests/pt_expt/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/model/__init__.py b/source/tests/pt_expt/model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/model/test_se_e2_a.py b/source/tests/pt_expt/model/test_se_e2_a.py new file mode 100644 index 0000000000..b9b834849f --- /dev/null +++ b/source/tests/pt_expt/model/test_se_e2_a.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch # noqa: TID253 + +from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA +from deepmd.pt.utils import ( # noqa: TID253 + env, +) +from deepmd.pt.utils.env import ( # noqa: TID253 + PRECISION_DICT, +) +from deepmd.pt.utils.exclude_mask import ( # noqa: TID253 + PairExcludeMask, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec, em in itertools.product( + [False, True], + ["float64", "float32"], + [[], [[0, 1]], [[1, 1]]], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + exclude_types=em, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + dd1 = DescrptSeA.deserialize(dd0.serialize()) + rd1, gr1, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeA.deserialize(dd0.serialize()) + rd2, gr2, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + for aa, bb in zip([rd1, gr1, sw1], [rd2, gr2, sw2], strict=True): + np.testing.assert_allclose( + aa.detach().cpu().numpy(), + bb, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + if em: + dd1.reinit_exclude([tuple(x) for x in em]) + self.assertIsInstance(dd1.emask, PairExcludeMask) + + def test_exportable(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) From b8a48ffe6bdd97ff3dbee7ef13182aa4bcf03a87 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 07:53:13 +0800 Subject: [PATCH 02/60] better type for xp.zeros --- deepmd/dpmodel/descriptor/se_e2_a.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index a6b17bf69a..3ca28ba556 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -609,7 +609,7 @@ def call( ng = self.neuron[-1] gr = xp.zeros( [nf * nloc, ng, 4], - dtype=self.dstd.dtype, + dtype=input_dtype, device=array_api_compat.device(coord_ext), ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) From 1cc001f7f262e6d76f40a8c7d36d9e80a9f1dd19 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 09:51:22 +0800 Subject: [PATCH 03/60] implement env, base_descriptor and exclude_mask, remove the dependency on pt backend. --- deepmd/pt_expt/descriptor/__init__.py | 4 + deepmd/pt_expt/descriptor/base_descriptor.py | 10 ++ deepmd/pt_expt/descriptor/se_e2_a.py | 11 +- deepmd/pt_expt/utils/__init__.py | 10 ++ deepmd/pt_expt/utils/env.py | 117 +++++++++++++++++++ deepmd/pt_expt/utils/exclude_mask.py | 27 +++++ deepmd/pt_expt/utils/network.py | 6 +- source/tests/pt_expt/model/test_se_e2_a.py | 16 +-- 8 files changed, 187 insertions(+), 14 deletions(-) create mode 100644 deepmd/pt_expt/descriptor/base_descriptor.py create mode 100644 deepmd/pt_expt/utils/env.py create mode 100644 deepmd/pt_expt/utils/exclude_mask.py diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py index fdac48ed41..089e5619e0 100644 --- a/deepmd/pt_expt/descriptor/__init__.py +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -1,8 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .base_descriptor import ( + BaseDescriptor, +) from .se_e2_a import ( DescrptSeA, ) __all__ = [ + "BaseDescriptor", "DescrptSeA", ] diff --git a/deepmd/pt_expt/descriptor/base_descriptor.py b/deepmd/pt_expt/descriptor/base_descriptor.py new file mode 100644 index 0000000000..51e9325bba --- /dev/null +++ b/deepmd/pt_expt/descriptor/base_descriptor.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib + +from deepmd.dpmodel.descriptor import ( + make_base_descriptor, +) + +torch = importlib.import_module("torch") + +BaseDescriptor = make_base_descriptor(torch.Tensor, "forward") diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 4334011ec3..bb0c0cb2bd 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -1,24 +1,25 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import importlib from typing import ( Any, ) -import torch # noqa: TID253 - from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP -from deepmd.pt.model.descriptor.base_descriptor import ( # noqa: TID253 +from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.pt.utils import ( # noqa: TID253 +from deepmd.pt_expt.utils import ( env, ) -from deepmd.pt.utils.exclude_mask import ( # noqa: TID253 +from deepmd.pt_expt.utils.exclude_mask import ( PairExcludeMask, ) from deepmd.pt_expt.utils.network import ( NetworkCollection, ) +torch = importlib.import_module("torch") + @BaseDescriptor.register("se_e2_a_expt") @BaseDescriptor.register("se_a_expt") diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index 6ceb116d85..f90cf82249 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ b/deepmd/pt_expt/utils/__init__.py @@ -1 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later + +from .exclude_mask import ( + AtomExcludeMask, + PairExcludeMask, +) + +__all__ = [ + "AtomExcludeMask", + "PairExcludeMask", +] diff --git a/deepmd/pt_expt/utils/env.py b/deepmd/pt_expt/utils/env.py new file mode 100644 index 0000000000..bd644e7206 --- /dev/null +++ b/deepmd/pt_expt/utils/env.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import logging +import multiprocessing +import os +import sys + +import numpy as np + +from deepmd.common import ( + VALID_PRECISION, +) +from deepmd.env import ( + GLOBAL_ENER_FLOAT_PRECISION, + GLOBAL_NP_FLOAT_PRECISION, + get_default_nthreads, + set_default_nthreads, +) + +log = logging.getLogger(__name__) +torch = importlib.import_module("torch") + +if sys.platform != "win32": + try: + multiprocessing.set_start_method("fork", force=True) + log.debug("Successfully set multiprocessing start method to 'fork'.") + except (RuntimeError, ValueError) as err: + log.warning(f"Could not set multiprocessing start method: {err}") +else: + log.debug("Skipping fork start method on Windows (not supported).") + +SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False) +DP_DTYPE_PROMOTION_STRICT = os.environ.get("DP_DTYPE_PROMOTION_STRICT", "0") == "1" +try: + # only linux + ncpus = len(os.sched_getaffinity(0)) +except AttributeError: + ncpus = os.cpu_count() +NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(4, ncpus))) +if multiprocessing.get_start_method() != "fork": + # spawn or forkserver does not support NUM_WORKERS > 0 for DataLoader + log.warning( + "NUM_WORKERS > 0 is not supported with spawn or forkserver start method. " + "Setting NUM_WORKERS to 0." + ) + NUM_WORKERS = 0 + +# Make sure DDP uses correct device if applicable +LOCAL_RANK = os.environ.get("LOCAL_RANK") +LOCAL_RANK = int(0 if LOCAL_RANK is None else LOCAL_RANK) + +if os.environ.get("DEVICE") == "cpu" or torch.cuda.is_available() is False: + DEVICE = torch.device("cpu") +else: + DEVICE = torch.device(f"cuda:{LOCAL_RANK}") + +JIT = False +CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory +ENERGY_BIAS_TRAINABLE = True +CUSTOM_OP_USE_JIT = False + +PRECISION_DICT = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "half": torch.float16, + "single": torch.float32, + "double": torch.float64, + "int32": torch.int32, + "int64": torch.int64, + "bfloat16": torch.bfloat16, + "bool": torch.bool, +} +GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name] +GLOBAL_PT_ENER_FLOAT_PRECISION = PRECISION_DICT[ + np.dtype(GLOBAL_ENER_FLOAT_PRECISION).name +] +PRECISION_DICT["default"] = GLOBAL_PT_FLOAT_PRECISION +assert VALID_PRECISION.issubset(PRECISION_DICT.keys()) +# cannot automatically generated +RESERVED_PRECISION_DICT = { + torch.float16: "float16", + torch.float32: "float32", + torch.float64: "float64", + torch.int32: "int32", + torch.int64: "int64", + torch.bfloat16: "bfloat16", + torch.bool: "bool", +} +assert set(PRECISION_DICT.values()) == set(RESERVED_PRECISION_DICT.keys()) +DEFAULT_PRECISION = "float64" + +# throw warnings if threads not set +set_default_nthreads() +inter_nthreads, intra_nthreads = get_default_nthreads() +if inter_nthreads > 0: # the behavior of 0 is not documented + torch.set_num_interop_threads(inter_nthreads) +if intra_nthreads > 0: + torch.set_num_threads(intra_nthreads) + +__all__ = [ + "CACHE_PER_SYS", + "CUSTOM_OP_USE_JIT", + "DEFAULT_PRECISION", + "DEVICE", + "ENERGY_BIAS_TRAINABLE", + "GLOBAL_ENER_FLOAT_PRECISION", + "GLOBAL_NP_FLOAT_PRECISION", + "GLOBAL_PT_ENER_FLOAT_PRECISION", + "GLOBAL_PT_FLOAT_PRECISION", + "JIT", + "LOCAL_RANK", + "NUM_WORKERS", + "PRECISION_DICT", + "RESERVED_PRECISION_DICT", + "SAMPLER_RECORD", +] diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py new file mode 100644 index 0000000000..ed296c9f98 --- /dev/null +++ b/deepmd/pt_expt/utils/exclude_mask.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +from typing import ( + Any, +) + +from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP +from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP +from deepmd.pt_expt.utils import ( + env, +) + +torch = importlib.import_module("torch") + + +class AtomExcludeMask(AtomExcludeMaskDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "type_mask": + value = None if value is None else torch.as_tensor(value, device=env.DEVICE) + return super().__setattr__(name, value) + + +class PairExcludeMask(PairExcludeMaskDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "type_mask": + value = None if value is None else torch.as_tensor(value, device=env.DEVICE) + return super().__setattr__(name, value) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index f29d8970b3..91a6999766 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import importlib from typing import ( Any, ClassVar, @@ -6,7 +7,6 @@ ) import numpy as np -import torch # noqa: TID253 from deepmd.dpmodel.common import ( NativeOP, @@ -19,10 +19,12 @@ make_fitting_network, make_multilayer_network, ) -from deepmd.pt.utils import ( # noqa: TID253 +from deepmd.pt_expt.utils import ( env, ) +torch = importlib.import_module("torch") + def _to_torch_array(value: Any) -> torch.Tensor | None: if value is None: diff --git a/source/tests/pt_expt/model/test_se_e2_a.py b/source/tests/pt_expt/model/test_se_e2_a.py index b9b834849f..57923b97a3 100644 --- a/source/tests/pt_expt/model/test_se_e2_a.py +++ b/source/tests/pt_expt/model/test_se_e2_a.py @@ -1,23 +1,23 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import importlib import itertools import unittest import numpy as np -import torch # noqa: TID253 from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA -from deepmd.pt.utils import ( # noqa: TID253 +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.utils import ( env, ) -from deepmd.pt.utils.env import ( # noqa: TID253 +from deepmd.pt_expt.utils.env import ( PRECISION_DICT, ) -from deepmd.pt.utils.exclude_mask import ( # noqa: TID253 +from deepmd.pt_expt.utils.exclude_mask import ( PairExcludeMask, ) -from deepmd.pt_expt.descriptor.se_e2_a import ( - DescrptSeA, -) from ...pt.model.test_env_mat import ( TestCaseSingleFrameWithNlist, @@ -29,6 +29,8 @@ GLOBAL_SEED, ) +torch = importlib.import_module("torch") + class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self) -> None: From f2fbe8884fdacf4369dbec847c06ebc54b453d02 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 10:08:29 +0800 Subject: [PATCH 04/60] mv to_torch_tensor to common --- deepmd/pt_expt/common.py | 35 +++++++++++++++++++++++++++++++++ deepmd/pt_expt/utils/network.py | 16 ++++----------- 2 files changed, 39 insertions(+), 12 deletions(-) create mode 100644 deepmd/pt_expt/common.py diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py new file mode 100644 index 0000000000..f065eeb76d --- /dev/null +++ b/deepmd/pt_expt/common.py @@ -0,0 +1,35 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +from typing import ( + Any, + overload, +) + +import numpy as np + +from deepmd.pt_expt.utils import ( + env, +) + +torch = importlib.import_module("torch") + + +@overload +def to_torch_array(array: np.ndarray) -> torch.Tensor: ... + + +@overload +def to_torch_array(array: None) -> None: ... + + +@overload +def to_torch_array(array: torch.Tensor) -> torch.Tensor: ... + + +def to_torch_array(array: Any) -> torch.Tensor | None: + """Convert input to a torch tensor on the pt-expt device.""" + if array is None: + return None + if torch.is_tensor(array): + return array + return torch.as_tensor(array, device=env.DEVICE) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 91a6999766..18840200be 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -19,21 +19,13 @@ make_fitting_network, make_multilayer_network, ) -from deepmd.pt_expt.utils import ( - env, +from deepmd.pt_expt.common import ( + to_torch_array, ) torch = importlib.import_module("torch") -def _to_torch_array(value: Any) -> torch.Tensor | None: - if value is None: - return None - if torch.is_tensor(value): - return value - return torch.as_tensor(value, device=env.DEVICE) - - class TorchArrayParam(torch.nn.Parameter): def __new__(cls, data: Any = None, requires_grad: bool = True) -> Self: return torch.nn.Parameter.__new__(cls, data, requires_grad) @@ -52,7 +44,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: for name in ("w", "b", "idt"): if name in self._parameters or name in self._buffers: continue - val = _to_torch_array(getattr(self, name)) + val = to_torch_array(getattr(self, name)) if val is None: continue if self.trainable: @@ -66,7 +58,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: - val = _to_torch_array(value) + val = to_torch_array(value) if val is None: return super().__setattr__(name, None) if getattr(self, "trainable", False): From e2afbe9c190ffef45315cac5089e067c7da800c5 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 10:26:15 +0800 Subject: [PATCH 05/60] simplify __init__ of the NaiveLayer --- deepmd/pt_expt/utils/network.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 18840200be..5708197c66 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -41,20 +41,6 @@ class NativeLayer(NativeLayerDP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) NativeLayerDP.__init__(self, *args, **kwargs) - for name in ("w", "b", "idt"): - if name in self._parameters or name in self._buffers: - continue - val = to_torch_array(getattr(self, name)) - if val is None: - continue - if self.trainable: - if hasattr(self, name) and name not in self._parameters: - delattr(self, name) - self.register_parameter(name, TorchArrayParam(val, requires_grad=True)) - else: - if hasattr(self, name) and name not in self._buffers: - delattr(self, name) - self.register_buffer(name, val) def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: From 4ba511ac49d1093b3cabced160a953c90e8e0f81 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 10:32:09 +0800 Subject: [PATCH 06/60] fix bug --- deepmd/pt_expt/utils/network.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 5708197c66..f2230383de 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -3,7 +3,6 @@ from typing import ( Any, ClassVar, - Self, ) import numpy as np @@ -27,7 +26,9 @@ class TorchArrayParam(torch.nn.Parameter): - def __new__(cls, data: Any = None, requires_grad: bool = True) -> Self: + def __new__( # noqa: PYI034 + cls, data: Any = None, requires_grad: bool = True + ) -> "TorchArrayParam": return torch.nn.Parameter.__new__(cls, data, requires_grad) def __array__(self, dtype: Any | None = None) -> np.ndarray: From fb9598a68d13b1712b11b9ea481fd3e2ca85b502 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 10:43:19 +0800 Subject: [PATCH 07/60] fix bug --- deepmd/pt_expt/utils/network.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index f2230383de..7a85634dca 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -95,16 +95,18 @@ class NetworkCollection(NetworkCollectionDP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) - super().__init__(*args, **kwargs) self._module_networks = torch.nn.ModuleDict() - for idx, net in enumerate(self._networks): - if isinstance(net, torch.nn.Module): - self._module_networks[str(idx)] = net + super().__init__(*args, **kwargs) def __setitem__(self, key: int | tuple, value: Any) -> None: + idx = self._convert_key(key) super().__setitem__(key, value) - if isinstance(value, torch.nn.Module): - self._module_networks[str(self._convert_key(key))] = value + net = self._networks[idx] + key_str = str(idx) + if isinstance(net, torch.nn.Module): + self._module_networks[key_str] = net + elif key_str in self._module_networks: + del self._module_networks[key_str] class LayerNorm(LayerNormDP, NativeLayer): From fa03351be77fe9b1f1e5173d0855d6e6d912ab9e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 11:25:14 +0800 Subject: [PATCH 08/60] simplify init method of se_e2_a descriptor. fig bug in consistent UT --- deepmd/pt_expt/descriptor/se_e2_a.py | 25 ------------------------- source/tests/consistent/common.py | 2 +- 2 files changed, 1 insertion(+), 26 deletions(-) diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index bb0c0cb2bd..19a0d56734 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -27,7 +27,6 @@ class DescrptSeA(DescrptSeADP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) DescrptSeADP.__init__(self, *args, **kwargs) - self._convert_state() def __setattr__(self, name: str, value: Any) -> None: if name in {"davg", "dstd"} and "_buffers" in self.__dict__: @@ -53,30 +52,6 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) return super().__setattr__(name, value) - def _convert_state(self) -> None: - if self.davg is not None: - davg = torch.as_tensor(self.davg, device=env.DEVICE) - if "davg" in self._buffers: - self._buffers["davg"] = davg - else: - if hasattr(self, "davg"): - delattr(self, "davg") - self.register_buffer("davg", davg) - if self.dstd is not None: - dstd = torch.as_tensor(self.dstd, device=env.DEVICE) - if "dstd" in self._buffers: - self._buffers["dstd"] = dstd - else: - if hasattr(self, "dstd"): - delattr(self, "dstd") - self.register_buffer("dstd", dstd) - if self.embeddings is not None: - self.embeddings = NetworkCollection.deserialize(self.embeddings.serialize()) - if self.emask is not None: - self.emask = PairExcludeMask( - self.ntypes, exclude_types=list(self.emask.get_exclude_types()) - ) - def forward( self, nlist: torch.Tensor, diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index 3d60f6def0..76b7e9cb53 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -92,7 +92,7 @@ class CommonTest(ABC): """Native DP model class.""" pt_class: ClassVar[type | None] """PyTorch model class.""" - pt_expt_class: ClassVar[type | None] + pt_expt_class: ClassVar[type | None] = None """PyTorch exportable model class.""" jax_class: ClassVar[type | None] """JAX model class.""" From 09b33f19daef30dad5dcd27be4811cde994b8bad Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 11:34:44 +0800 Subject: [PATCH 09/60] restructure the test folders. add test_common. --- deepmd/pt_expt/common.py | 2 +- source/tests/pt_expt/descriptor/__init__.py | 1 + .../{model => descriptor}/test_se_e2_a.py | 0 source/tests/pt_expt/utils/__init__.py | 1 + source/tests/pt_expt/utils/test_common.py | 25 +++++++++++++++++++ 5 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 source/tests/pt_expt/descriptor/__init__.py rename source/tests/pt_expt/{model => descriptor}/test_se_e2_a.py (100%) create mode 100644 source/tests/pt_expt/utils/__init__.py create mode 100644 source/tests/pt_expt/utils/test_common.py diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index f065eeb76d..b66c0ff66d 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -31,5 +31,5 @@ def to_torch_array(array: Any) -> torch.Tensor | None: if array is None: return None if torch.is_tensor(array): - return array + return array.to(device=env.DEVICE) return torch.as_tensor(array, device=env.DEVICE) diff --git a/source/tests/pt_expt/descriptor/__init__.py b/source/tests/pt_expt/descriptor/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/descriptor/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/model/test_se_e2_a.py b/source/tests/pt_expt/descriptor/test_se_e2_a.py similarity index 100% rename from source/tests/pt_expt/model/test_se_e2_a.py rename to source/tests/pt_expt/descriptor/test_se_e2_a.py diff --git a/source/tests/pt_expt/utils/__init__.py b/source/tests/pt_expt/utils/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/utils/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/utils/test_common.py b/source/tests/pt_expt/utils/test_common.py new file mode 100644 index 0000000000..63c4983f23 --- /dev/null +++ b/source/tests/pt_expt/utils/test_common.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib + +import numpy as np + +from deepmd.pt_expt.common import ( + to_torch_array, +) +from deepmd.pt_expt.utils import ( + env, +) + +torch = importlib.import_module("torch") + + +def test_to_torch_array_moves_device() -> None: + arr = np.arange(6, dtype=np.float32).reshape(2, 3) + tensor = to_torch_array(arr) + assert torch.is_tensor(tensor) + assert tensor.device == env.DEVICE + + input_tensor = torch.as_tensor(arr, device=torch.device("cpu")) + output_tensor = to_torch_array(input_tensor) + assert torch.is_tensor(output_tensor) + assert output_tensor.device == env.DEVICE From 67f2e544a6d228652ba9ea311a52a5c8defec210 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 11:54:41 +0800 Subject: [PATCH 10/60] add test_exclusion_mask.py --- .../pt_expt/utils/test_exclusion_mask.py | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 source/tests/pt_expt/utils/test_exclusion_mask.py diff --git a/source/tests/pt_expt/utils/test_exclusion_mask.py b/source/tests/pt_expt/utils/test_exclusion_mask.py new file mode 100644 index 0000000000..7168579052 --- /dev/null +++ b/source/tests/pt_expt/utils/test_exclusion_mask.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import unittest + +import numpy as np + +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.exclude_mask import ( + AtomExcludeMask, + PairExcludeMask, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +torch = importlib.import_module("torch") + + +class TestAtomExcludeMask(unittest.TestCase): + def test_build_type_exclude_mask(self) -> None: + nf = 2 + nt = 3 + exclude_types = [0, 2] + atype = np.array( + [ + [0, 2, 1, 2, 0, 1, 0], + [1, 2, 0, 0, 2, 2, 1], + ], + dtype=np.int32, + ).reshape([nf, -1]) + expected_mask = np.array( + [ + [0, 0, 1, 0, 0, 1, 0], + [1, 0, 0, 0, 0, 0, 1], + ] + ).reshape([nf, -1]) + des = AtomExcludeMask(nt, exclude_types=exclude_types) + mask = des.build_type_exclude_mask(torch.as_tensor(atype, device=env.DEVICE)) + np.testing.assert_equal(mask.detach().cpu().numpy(), expected_mask) + + +class TestPairExcludeMask(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + + def test_build_type_exclude_mask(self) -> None: + exclude_types = [[0, 1]] + expected_mask = np.array( + [ + [1, 1, 1, 1, 1, 0, 1], + [1, 1, 1, 1, 1, 0, 1], + [0, 0, 1, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 0, 1], + [1, 1, 1, 1, 1, 0, 1], + ] + ).reshape(self.nf, self.nloc, sum(self.sel)) + des = PairExcludeMask(self.nt, exclude_types=exclude_types) + mask = des.build_type_exclude_mask( + torch.as_tensor(self.nlist, device=env.DEVICE), + torch.as_tensor(self.atype_ext, device=env.DEVICE), + ) + np.testing.assert_equal(mask.detach().cpu().numpy(), expected_mask) From f7d83ddfae60920d05b76f8a13fd4a35bb350a79 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 11:58:38 +0800 Subject: [PATCH 11/60] fix poitential import issue in test. --- source/tests/pt_expt/conftest.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 source/tests/pt_expt/conftest.py diff --git a/source/tests/pt_expt/conftest.py b/source/tests/pt_expt/conftest.py new file mode 100644 index 0000000000..ec025c2202 --- /dev/null +++ b/source/tests/pt_expt/conftest.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import pytest + +pytest.importorskip("torch") From 0c96bb6fecf1564433d0f98d00d05866dcb9fbd5 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 12:11:12 +0800 Subject: [PATCH 12/60] correct __call__(). fix bug --- deepmd/pt_expt/descriptor/se_e2_a.py | 6 +++++- deepmd/pt_expt/utils/network.py | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 19a0d56734..7a4d4a71d9 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -28,6 +28,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) DescrptSeADP.__init__(self, *args, **kwargs) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + def __setattr__(self, name: str, value: Any) -> None: if name in {"davg", "dstd"} and "_buffers" in self.__dict__: tensor = ( @@ -54,9 +58,9 @@ def __setattr__(self, name: str, value: Any) -> None: def forward( self, - nlist: torch.Tensor, extended_coord: torch.Tensor, extended_atype: torch.Tensor, + nlist: torch.Tensor, extended_atype_embd: torch.Tensor | None = None, mapping: torch.Tensor | None = None, type_embedding: torch.Tensor | None = None, diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 7a85634dca..fffb98b1ef 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -43,6 +43,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) NativeLayerDP.__init__(self, *args, **kwargs) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return torch.nn.Module.__call__(self, *args, **kwargs) + def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: val = to_torch_array(value) @@ -74,6 +77,9 @@ def __init__(self, layers: list[dict] | None = None) -> None: super().__init__(layers) self.layers = torch.nn.ModuleList(self.layers) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return torch.nn.Module.__call__(self, *args, **kwargs) + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.call(x) From 9dca9128b50ecf7a92f902e6a6dca0b475ac1e9c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 12:54:28 +0800 Subject: [PATCH 13/60] fix registration issue --- deepmd/pt_expt/descriptor/se_e2_a.py | 4 +++- deepmd/pt_expt/utils/network.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 7a4d4a71d9..7df1148e38 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -40,7 +40,9 @@ def __setattr__(self, name: str, value: Any) -> None: if name in self._buffers: self._buffers[name] = tensor return - return super().__setattr__(name, tensor) + # Register on first assignment so buffers are in state_dict and moved by .to(). + self.register_buffer(name, tensor) + return if name == "embeddings" and "_modules" in self.__dict__: if value is not None and not isinstance(value, torch.nn.Module): if hasattr(value, "serialize"): diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index fffb98b1ef..5f66959d16 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -64,7 +64,9 @@ def __setattr__(self, name: str, value: Any) -> None: if name in self._buffers: self._buffers[name] = val return - return super().__setattr__(name, val) + # Register on first assignment so tensors are in state_dict and moved by .to(). + self.register_buffer(name, val) + return return super().__setattr__(name, value) def forward(self, x: torch.Tensor) -> torch.Tensor: From 17f0a5d1ae06eecfa25b4f2e3e2e09a4570fc0ef Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 13:00:54 +0800 Subject: [PATCH 14/60] fix pt-expt file extension --- deepmd/backend/pt_expt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/backend/pt_expt.py b/deepmd/backend/pt_expt.py index 38745c690c..e651332e2b 100644 --- a/deepmd/backend/pt_expt.py +++ b/deepmd/backend/pt_expt.py @@ -41,7 +41,7 @@ class PyTorchExportableBackend(Backend): | Backend.Feature.IO ) """The features of the backend.""" - suffixes: ClassVar[list[str]] = [".pth", ".pt"] + suffixes: ClassVar[list[str]] = [".pte"] """The suffixes of the backend.""" def is_available(self) -> bool: From 8ce93baafa42c0c51046152084765f44101809b9 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 13:03:21 +0800 Subject: [PATCH 15/60] fix(pt): expansion of get_default_nthreads() --- deepmd/pt/utils/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 90d0d536c1..aa384b31b5 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -91,7 +91,7 @@ # throw warnings if threads not set set_default_nthreads() -inter_nthreads, intra_nthreads = get_default_nthreads() +intra_nthreads, inter_nthreads = get_default_nthreads() if inter_nthreads > 0: # the behavior of 0 is not documented torch.set_num_interop_threads(inter_nthreads) if intra_nthreads > 0: From 309198894d37ad0f0417fc3475919fb02e3fc63c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 13:08:58 +0800 Subject: [PATCH 16/60] fix bug of intra-inter --- deepmd/pt_expt/utils/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt_expt/utils/env.py b/deepmd/pt_expt/utils/env.py index bd644e7206..b5042f6f2a 100644 --- a/deepmd/pt_expt/utils/env.py +++ b/deepmd/pt_expt/utils/env.py @@ -92,7 +92,7 @@ # throw warnings if threads not set set_default_nthreads() -inter_nthreads, intra_nthreads = get_default_nthreads() +intra_nthreads, inter_nthreads = get_default_nthreads() if inter_nthreads > 0: # the behavior of 0 is not documented torch.set_num_interop_threads(inter_nthreads) if intra_nthreads > 0: From 85f05833353a8bb388ce0fd16cdf0059d579b5e4 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 13:13:42 +0800 Subject: [PATCH 17/60] fix bug of default dp inter value --- deepmd/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/env.py b/deepmd/env.py index 7b29a338f1..c9d0fb241f 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ -138,7 +138,7 @@ def get_default_nthreads() -> tuple[int, int]: ), int( os.environ.get( "DP_INTER_OP_PARALLELISM_THREADS", - os.environ.get("TF_INTRA_OP_PARALLELISM_THREADS", "0"), + os.environ.get("TF_INTER_OP_PARALLELISM_THREADS", "0"), ) ) From d33324de397d728d15fa773bdb828783085e4156 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 15:15:07 +0800 Subject: [PATCH 18/60] fix cicd --- source/tests/consistent/descriptor/common.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 7c8cbce744..50efe32a08 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -153,13 +153,17 @@ def eval_pt_expt_descriptor( box: np.ndarray, mixed_types: bool = False, ) -> Any: - ext_coords, ext_atype, mapping = extend_coord_with_ghosts( + # Use the torch-native neighbor list utilities to avoid array_api_compat + # allocations on CUDA. The array_api path can hit torch empty/ones/eye/etc + # on CUDA, which all rely on aten::empty_strided and fail in CI builds + # where that CUDA kernel is not available. + ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), pt_expt_obj.get_rcut(), ) - nlist = build_neighbor_list( + nlist = build_neighbor_list_pt( ext_coords, ext_atype, natoms[0], From 4de9a565c01c6193a94eaecbc18475c4be04dc08 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 15:57:27 +0800 Subject: [PATCH 19/60] feat: add support for se_r --- deepmd/dpmodel/descriptor/se_r.py | 4 +- deepmd/pt_expt/descriptor/__init__.py | 4 + deepmd/pt_expt/descriptor/se_r.py | 83 +++++++++++ .../tests/consistent/descriptor/test_se_r.py | 25 ++++ source/tests/pt_expt/descriptor/test_se_r.py | 132 ++++++++++++++++++ 5 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 deepmd/pt_expt/descriptor/se_r.py create mode 100644 source/tests/pt_expt/descriptor/test_se_r.py diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 6decd91a23..b38d561e95 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -391,7 +391,9 @@ def call( ng = self.neuron[-1] xyz_scatter = xp.zeros( - [nf, nloc, ng], dtype=get_xp_precision(xp, self.precision) + [nf, nloc, ng], + dtype=get_xp_precision(xp, self.precision), + device=array_api_compat.device(coord_ext), ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) rr = xp.astype(rr, xyz_scatter.dtype) diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py index 089e5619e0..4d9469a93a 100644 --- a/deepmd/pt_expt/descriptor/__init__.py +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -5,8 +5,12 @@ from .se_e2_a import ( DescrptSeA, ) +from .se_r import ( + DescrptSeR, +) __all__ = [ "BaseDescriptor", "DescrptSeA", + "DescrptSeR", ] diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py new file mode 100644 index 0000000000..f4969ce927 --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.pt_expt.utils.network import ( + NetworkCollection, +) + +torch = importlib.import_module("torch") + + +@BaseDescriptor.register("se_e2_r_expt") +@BaseDescriptor.register("se_r_expt") +class DescrptSeR(DescrptSeRDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptSeRDP.__init__(self, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + if name in {"davg", "dstd"} and "_buffers" in self.__dict__: + tensor = ( + None if value is None else torch.as_tensor(value, device=env.DEVICE) + ) + if name in self._buffers: + self._buffers[name] = tensor + return + # Register on first assignment so buffers are in state_dict and moved by .to(). + self.register_buffer(name, tensor) + return + if name == "embeddings" and "_modules" in self.__dict__: + if value is not None and not isinstance(value, torch.nn.Module): + if hasattr(value, "serialize"): + value = NetworkCollection.deserialize(value.serialize()) + elif isinstance(value, dict): + value = NetworkCollection.deserialize(value) + return super().__setattr__(name, value) + if name == "emask" and "_modules" in self.__dict__: + if value is not None and not isinstance(value, torch.nn.Module): + value = PairExcludeMask( + self.ntypes, exclude_types=list(value.get_exclude_types()) + ) + return super().__setattr__(name, value) + return super().__setattr__(name, value) + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + extended_atype_embd: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + type_embedding: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + del extended_atype_embd, type_embedding + descrpt, rot_mat, g2, h2, sw = self.call( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + return descrpt, rot_mat, g2, h2, sw diff --git a/source/tests/consistent/descriptor/test_se_r.py b/source/tests/consistent/descriptor/test_se_r.py index 3420c5592f..9aafea1578 100644 --- a/source/tests/consistent/descriptor/test_se_r.py +++ b/source/tests/consistent/descriptor/test_se_r.py @@ -15,6 +15,7 @@ INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -27,6 +28,10 @@ from deepmd.pt.model.descriptor.se_r import DescrptSeR as DescrptSeRPT else: DescrptSeAPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.se_r import DescrptSeR as DescrptSeRPTExpt +else: + DescrptSeRPTExpt = None if INSTALLED_TF: from deepmd.tf.descriptor.se_r import DescrptSeR as DescrptSeRTF else: @@ -84,6 +89,16 @@ def skip_pt(self) -> bool: ) = self.param return not type_one_side or CommonTest.skip_pt + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + ) = self.param + return not type_one_side or CommonTest.skip_pt_expt + @property def skip_dp(self) -> bool: ( @@ -117,6 +132,7 @@ def skip_array_api_strict(self) -> bool: tf_class = DescrptSeRTF dp_class = DescrptSeRDP pt_class = DescrptSeRPT + pt_expt_class = DescrptSeRPTExpt jax_class = DescrptSeRJAX array_api_strict_class = DescrptSeRArrayAPIStrict args = descrpt_se_r_args() @@ -183,6 +199,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, diff --git a/source/tests/pt_expt/descriptor/test_se_r.py b/source/tests/pt_expt/descriptor/test_se_r.py new file mode 100644 index 0000000000..6e7339801c --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_se_r.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import itertools +import unittest + +import numpy as np + +from deepmd.dpmodel.descriptor import DescrptSeR as DPDescrptSeR +from deepmd.pt_expt.descriptor.se_r import ( + DescrptSeR, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + +torch = importlib.import_module("torch") + + +class TestDescrptSeR(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 1)) + dstd = rng.normal(size=(self.nt, nnei, 1)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec, em in itertools.product( + [False, True], + ["float64", "float32"], + [[], [[0, 1]], [[1, 1]]], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + exclude_types=em, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + dd1 = DescrptSeR.deserialize(dd0.serialize()) + rd1, _, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeR.deserialize(dd0.serialize()) + rd2, _, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + for aa, bb in zip([rd1, sw1], [rd2, sw2], strict=True): + np.testing.assert_allclose( + aa.detach().cpu().numpy(), + bb, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_exportable(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 1)) + dstd = rng.normal(size=(self.nt, nnei, 1)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) From f4dc0afec4909dd4c052e9dcd565b4be01b6ed92 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 17:50:35 +0800 Subject: [PATCH 20/60] fix device of xp array --- deepmd/dpmodel/descriptor/se_r.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index b38d561e95..4fdf50beba 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -309,9 +309,12 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.dstd) + device = array_api_compat.device(self.dstd) if not self.set_davg_zero: - self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True) - self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True) + self.davg = xp.asarray( + mean, dtype=self.davg.dtype, copy=True, device=device + ) + self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device) def set_stat_mean_and_stddev( self, From 238483531cd8455219ad5d21065f516f0486c97b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 17:51:50 +0800 Subject: [PATCH 21/60] fix device of xp array --- deepmd/dpmodel/descriptor/dpa1.py | 9 +++++++-- deepmd/dpmodel/descriptor/repflows.py | 9 +++++++-- deepmd/dpmodel/descriptor/repformers.py | 9 +++++++-- deepmd/dpmodel/descriptor/se_e2_a.py | 7 +++++-- deepmd/dpmodel/descriptor/se_t.py | 7 +++++-- deepmd/dpmodel/descriptor/se_t_tebd.py | 9 +++++++-- 6 files changed, 38 insertions(+), 12 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 5228ba55b2..f09ab24dfe 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -909,9 +909,14 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.stddev) + device = array_api_compat.device(self.stddev) if not self.set_davg_zero: - self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True) - self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True) + self.mean = xp.asarray( + mean, dtype=self.mean.dtype, copy=True, device=device + ) + self.stddev = xp.asarray( + stddev, dtype=self.stddev.dtype, copy=True, device=device + ) def get_stats(self) -> dict[str, StatItem]: """Get the statistics of the descriptor.""" diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 706fc690e4..7ba4f92662 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -453,9 +453,14 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.stddev) + device = array_api_compat.device(self.stddev) if not self.set_davg_zero: - self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True) - self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True) + self.mean = xp.asarray( + mean, dtype=self.mean.dtype, copy=True, device=device + ) + self.stddev = xp.asarray( + stddev, dtype=self.stddev.dtype, copy=True, device=device + ) def get_stats(self) -> dict[str, StatItem]: """Get the statistics of the descriptor.""" diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index 79d4f9228f..06f5c1c943 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -417,9 +417,14 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.stddev) + device = array_api_compat.device(self.stddev) if not self.set_davg_zero: - self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True) - self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True) + self.mean = xp.asarray( + mean, dtype=self.mean.dtype, copy=True, device=device + ) + self.stddev = xp.asarray( + stddev, dtype=self.stddev.dtype, copy=True, device=device + ) def get_stats(self) -> dict[str, StatItem]: """Get the statistics of the descriptor.""" diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 3ca28ba556..77afb110e9 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -350,9 +350,12 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.dstd) + device = array_api_compat.device(self.dstd) if not self.set_davg_zero: - self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True) - self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True) + self.davg = xp.asarray( + mean, dtype=self.davg.dtype, copy=True, device=device + ) + self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device) def set_stat_mean_and_stddev( self, diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index 863187dd4c..749a5da188 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -290,9 +290,12 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.dstd) + device = array_api_compat.device(self.dstd) if not self.set_davg_zero: - self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True) - self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True) + self.davg = xp.asarray( + mean, dtype=self.davg.dtype, copy=True, device=device + ) + self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device) def set_stat_mean_and_stddev( self, diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index e118d5abd4..0a2d46c015 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -694,9 +694,14 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.stddev) + device = array_api_compat.device(self.stddev) if not self.set_davg_zero: - self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True) - self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True) + self.mean = xp.asarray( + mean, dtype=self.mean.dtype, copy=True, device=device + ) + self.stddev = xp.asarray( + stddev, dtype=self.stddev.dtype, copy=True, device=device + ) def get_stats(self) -> dict[str, StatItem]: """Get the statistics of the descriptor.""" From 9646d71b6d05d7c30fb70ec7e0be708e122bdbde Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 17:52:16 +0800 Subject: [PATCH 22/60] revert extend_coord_with_ghosts --- source/tests/consistent/descriptor/common.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 50efe32a08..7c8cbce744 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -153,17 +153,13 @@ def eval_pt_expt_descriptor( box: np.ndarray, mixed_types: bool = False, ) -> Any: - # Use the torch-native neighbor list utilities to avoid array_api_compat - # allocations on CUDA. The array_api path can hit torch empty/ones/eye/etc - # on CUDA, which all rely on aten::empty_strided and fail in CI builds - # where that CUDA kernel is not available. - ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( + ext_coords, ext_atype, mapping = extend_coord_with_ghosts( torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), pt_expt_obj.get_rcut(), ) - nlist = build_neighbor_list_pt( + nlist = build_neighbor_list( ext_coords, ext_atype, natoms[0], From f270069dd07c10bb9d908bd203de0e6e7ba72412 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 18:11:20 +0800 Subject: [PATCH 23/60] raise error for non-implemented methods --- deepmd/backend/pt_expt.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/deepmd/backend/pt_expt.py b/deepmd/backend/pt_expt.py index e651332e2b..ade9eb51f3 100644 --- a/deepmd/backend/pt_expt.py +++ b/deepmd/backend/pt_expt.py @@ -76,9 +76,7 @@ def deep_eval(self) -> type["DeepEvalBackend"]: type[DeepEvalBackend] The Deep Eval backend of the backend. """ - from deepmd.pt.infer.deep_eval import DeepEval as DeepEvalPT - - return DeepEvalPT + raise NotImplementedError @property def neighbor_stat(self) -> type["NeighborStat"]: @@ -89,11 +87,7 @@ def neighbor_stat(self) -> type["NeighborStat"]: type[NeighborStat] The neighbor statistics of the backend. """ - from deepmd.pt.utils.neighbor_stat import ( - NeighborStat, - ) - - return NeighborStat + raise NotImplementedError @property def serialize_hook(self) -> Callable[[str], dict]: @@ -104,11 +98,7 @@ def serialize_hook(self) -> Callable[[str], dict]: Callable[[str], dict] The serialize hook of the backend. """ - from deepmd.pt.utils.serialization import ( - serialize_from_file, - ) - - return serialize_from_file + raise NotImplementedError @property def deserialize_hook(self) -> Callable[[str, dict], None]: @@ -119,8 +109,4 @@ def deserialize_hook(self) -> Callable[[str, dict], None]: Callable[[str, dict], None] The deserialize hook of the backend. """ - from deepmd.pt.utils.serialization import ( - deserialize_to_file, - ) - - return deserialize_to_file + raise NotImplementedError From 57433d3e1e82b00006a9062c4a3610bbf3b52c45 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 22:17:46 +0800 Subject: [PATCH 24/60] restore import torch --- deepmd/pt_expt/common.py | 4 +--- deepmd/pt_expt/descriptor/base_descriptor.py | 5 ++--- deepmd/pt_expt/descriptor/se_e2_a.py | 5 ++--- deepmd/pt_expt/descriptor/se_r.py | 5 ++--- deepmd/pt_expt/utils/env.py | 3 +-- deepmd/pt_expt/utils/exclude_mask.py | 5 ++--- deepmd/pt_expt/utils/network.py | 4 +--- pyproject.toml | 3 +++ source/tests/pt_expt/descriptor/test_se_e2_a.py | 4 +--- source/tests/pt_expt/descriptor/test_se_r.py | 4 +--- source/tests/pt_expt/utils/test_common.py | 4 +--- source/tests/pt_expt/utils/test_exclusion_mask.py | 4 +--- 12 files changed, 18 insertions(+), 32 deletions(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index b66c0ff66d..db8b94989b 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -1,18 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, overload, ) import numpy as np +import torch from deepmd.pt_expt.utils import ( env, ) -torch = importlib.import_module("torch") - @overload def to_torch_array(array: np.ndarray) -> torch.Tensor: ... diff --git a/deepmd/pt_expt/descriptor/base_descriptor.py b/deepmd/pt_expt/descriptor/base_descriptor.py index 51e9325bba..986435205a 100644 --- a/deepmd/pt_expt/descriptor/base_descriptor.py +++ b/deepmd/pt_expt/descriptor/base_descriptor.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib + +import torch from deepmd.dpmodel.descriptor import ( make_base_descriptor, ) -torch = importlib.import_module("torch") - BaseDescriptor = make_base_descriptor(torch.Tensor, "forward") diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 7df1148e38..21c0a4eeb7 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, ) +import torch + from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, @@ -18,8 +19,6 @@ NetworkCollection, ) -torch = importlib.import_module("torch") - @BaseDescriptor.register("se_e2_a_expt") @BaseDescriptor.register("se_a_expt") diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index f4969ce927..508785949c 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, ) +import torch + from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, @@ -18,8 +19,6 @@ NetworkCollection, ) -torch = importlib.import_module("torch") - @BaseDescriptor.register("se_e2_r_expt") @BaseDescriptor.register("se_r_expt") diff --git a/deepmd/pt_expt/utils/env.py b/deepmd/pt_expt/utils/env.py index b5042f6f2a..ce13e4ef42 100644 --- a/deepmd/pt_expt/utils/env.py +++ b/deepmd/pt_expt/utils/env.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib import logging import multiprocessing import os @@ -18,7 +17,7 @@ ) log = logging.getLogger(__name__) -torch = importlib.import_module("torch") +import torch if sys.platform != "win32": try: diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py index ed296c9f98..15fbbc8e34 100644 --- a/deepmd/pt_expt/utils/exclude_mask.py +++ b/deepmd/pt_expt/utils/exclude_mask.py @@ -1,17 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, ) +import torch + from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP from deepmd.pt_expt.utils import ( env, ) -torch = importlib.import_module("torch") - class AtomExcludeMask(AtomExcludeMaskDP): def __setattr__(self, name: str, value: Any) -> None: diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 5f66959d16..3effcfc488 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, ClassVar, ) import numpy as np +import torch from deepmd.dpmodel.common import ( NativeOP, @@ -22,8 +22,6 @@ to_torch_array, ) -torch = importlib.import_module("torch") - class TorchArrayParam(torch.nn.Parameter): def __new__( # noqa: PYI034 diff --git a/pyproject.toml b/pyproject.toml index bd403dfaf2..15eb0b2ae5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -411,6 +411,7 @@ convention = "numpy" banned-module-level-imports = [ "deepmd.tf", "deepmd.pt", + "deepmd.pt_expt", "deepmd.pd", "deepmd.jax", "tensorflow", @@ -432,12 +433,14 @@ runtime-evaluated-base-classes = ["torch.nn.Module"] "data/**" = ["ANN"] "deepmd/tf/**" = ["TID253", "ANN"] "deepmd/pt/**" = ["TID253", "B905"] +"deepmd/pt_expt/**" = ["TID253", "B905"] "deepmd/jax/**" = ["TID253"] "deepmd/pd/**" = ["TID253", "B905"] "source/**" = ["ANN"] "source/tests/tf/**" = ["TID253", "ANN"] "source/tests/pt/**" = ["TID253", "ANN"] +"source/tests/pt_expt/**" = ["TID253", "ANN"] "source/tests/jax/**" = ["TID253", "ANN"] "source/tests/pd/**" = ["TID253", "ANN"] "source/tests/universal/pt/**" = ["TID253", "ANN"] diff --git a/source/tests/pt_expt/descriptor/test_se_e2_a.py b/source/tests/pt_expt/descriptor/test_se_e2_a.py index 57923b97a3..e63138e43b 100644 --- a/source/tests/pt_expt/descriptor/test_se_e2_a.py +++ b/source/tests/pt_expt/descriptor/test_se_e2_a.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib import itertools import unittest import numpy as np +import torch from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA from deepmd.pt_expt.descriptor.se_e2_a import ( @@ -29,8 +29,6 @@ GLOBAL_SEED, ) -torch = importlib.import_module("torch") - class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self) -> None: diff --git a/source/tests/pt_expt/descriptor/test_se_r.py b/source/tests/pt_expt/descriptor/test_se_r.py index 6e7339801c..c789b13652 100644 --- a/source/tests/pt_expt/descriptor/test_se_r.py +++ b/source/tests/pt_expt/descriptor/test_se_r.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib import itertools import unittest import numpy as np +import torch from deepmd.dpmodel.descriptor import DescrptSeR as DPDescrptSeR from deepmd.pt_expt.descriptor.se_r import ( @@ -26,8 +26,6 @@ GLOBAL_SEED, ) -torch = importlib.import_module("torch") - class TestDescrptSeR(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self) -> None: diff --git a/source/tests/pt_expt/utils/test_common.py b/source/tests/pt_expt/utils/test_common.py index 63c4983f23..ee8a7ca324 100644 --- a/source/tests/pt_expt/utils/test_common.py +++ b/source/tests/pt_expt/utils/test_common.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib import numpy as np +import torch from deepmd.pt_expt.common import ( to_torch_array, @@ -10,8 +10,6 @@ env, ) -torch = importlib.import_module("torch") - def test_to_torch_array_moves_device() -> None: arr = np.arange(6, dtype=np.float32).reshape(2, 3) diff --git a/source/tests/pt_expt/utils/test_exclusion_mask.py b/source/tests/pt_expt/utils/test_exclusion_mask.py index 7168579052..b3707ef69d 100644 --- a/source/tests/pt_expt/utils/test_exclusion_mask.py +++ b/source/tests/pt_expt/utils/test_exclusion_mask.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib import unittest import numpy as np +import torch from deepmd.pt_expt.utils import ( env, @@ -16,8 +16,6 @@ TestCaseSingleFrameWithNlist, ) -torch = importlib.import_module("torch") - class TestAtomExcludeMask(unittest.TestCase): def test_build_type_exclude_mask(self) -> None: From eedcbaf4f67ff6a9dface303450c4110b6c139b2 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 22:31:03 +0800 Subject: [PATCH 25/60] fix(pt,pt-expt): guard thread setters --- deepmd/pt/utils/env.py | 15 ++++++++++++-- deepmd/pt_expt/utils/env.py | 15 ++++++++++++-- source/tests/pt/test_env_threads.py | 28 ++++++++++++++++++++++++++ source/tests/pt_expt/utils/test_env.py | 28 ++++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 4 deletions(-) create mode 100644 source/tests/pt/test_env_threads.py create mode 100644 source/tests/pt_expt/utils/test_env.py diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index aa384b31b5..9f453c895c 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -93,9 +93,20 @@ set_default_nthreads() intra_nthreads, inter_nthreads = get_default_nthreads() if inter_nthreads > 0: # the behavior of 0 is not documented - torch.set_num_interop_threads(inter_nthreads) + # torch.set_num_interop_threads can only be called once per process. + # Guard to avoid RuntimeError when multiple backends are imported. + try: + if torch.get_num_interop_threads() != inter_nthreads: + torch.set_num_interop_threads(inter_nthreads) + except RuntimeError as err: + log.warning(f"Could not set torch interop threads: {err}") if intra_nthreads > 0: - torch.set_num_threads(intra_nthreads) + # torch.set_num_threads can also fail if called after threads are created. + try: + if torch.get_num_threads() != intra_nthreads: + torch.set_num_threads(intra_nthreads) + except RuntimeError as err: + log.warning(f"Could not set torch intra threads: {err}") __all__ = [ "CACHE_PER_SYS", diff --git a/deepmd/pt_expt/utils/env.py b/deepmd/pt_expt/utils/env.py index ce13e4ef42..56cec25d49 100644 --- a/deepmd/pt_expt/utils/env.py +++ b/deepmd/pt_expt/utils/env.py @@ -93,9 +93,20 @@ set_default_nthreads() intra_nthreads, inter_nthreads = get_default_nthreads() if inter_nthreads > 0: # the behavior of 0 is not documented - torch.set_num_interop_threads(inter_nthreads) + # torch.set_num_interop_threads can only be called once per process. + # Guard to avoid RuntimeError when both pt and pt_expt env modules are imported. + try: + if torch.get_num_interop_threads() != inter_nthreads: + torch.set_num_interop_threads(inter_nthreads) + except RuntimeError as err: + log.warning(f"Could not set torch interop threads: {err}") if intra_nthreads > 0: - torch.set_num_threads(intra_nthreads) + # torch.set_num_threads can also fail if called after threads are created. + try: + if torch.get_num_threads() != intra_nthreads: + torch.set_num_threads(intra_nthreads) + except RuntimeError as err: + log.warning(f"Could not set torch intra threads: {err}") __all__ = [ "CACHE_PER_SYS", diff --git a/source/tests/pt/test_env_threads.py b/source/tests/pt/test_env_threads.py new file mode 100644 index 0000000000..eb6604ceb8 --- /dev/null +++ b/source/tests/pt/test_env_threads.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import logging + +import torch + +import deepmd.env as common_env + + +def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None: + def raise_err(*_args, **_kwargs) -> None: + raise RuntimeError("boom") + + monkeypatch.setattr(common_env, "set_default_nthreads", lambda: None) + monkeypatch.setattr(common_env, "get_default_nthreads", lambda: (1, 1)) + monkeypatch.setattr(torch, "get_num_interop_threads", lambda: 2) + monkeypatch.setattr(torch, "set_num_interop_threads", raise_err) + monkeypatch.setattr(torch, "get_num_threads", lambda: 2) + monkeypatch.setattr(torch, "set_num_threads", raise_err) + + caplog.set_level(logging.WARNING, logger="deepmd.pt.utils.env") + import deepmd.pt.utils.env as env + + importlib.reload(env) + + messages = [record.getMessage() for record in caplog.records] + assert any("Could not set torch interop threads" in msg for msg in messages) + assert any("Could not set torch intra threads" in msg for msg in messages) diff --git a/source/tests/pt_expt/utils/test_env.py b/source/tests/pt_expt/utils/test_env.py new file mode 100644 index 0000000000..bbdc696aea --- /dev/null +++ b/source/tests/pt_expt/utils/test_env.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import logging + +import torch + +import deepmd.env as common_env + + +def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None: + def raise_err(*_args, **_kwargs) -> None: + raise RuntimeError("boom") + + monkeypatch.setattr(common_env, "set_default_nthreads", lambda: None) + monkeypatch.setattr(common_env, "get_default_nthreads", lambda: (1, 1)) + monkeypatch.setattr(torch, "get_num_interop_threads", lambda: 2) + monkeypatch.setattr(torch, "set_num_interop_threads", raise_err) + monkeypatch.setattr(torch, "get_num_threads", lambda: 2) + monkeypatch.setattr(torch, "set_num_threads", raise_err) + + caplog.set_level(logging.WARNING, logger="deepmd.pt_expt.utils.env") + import deepmd.pt_expt.utils.env as env + + importlib.reload(env) + + messages = [record.getMessage() for record in caplog.records] + assert any("Could not set torch interop threads" in msg for msg in messages) + assert any("Could not set torch intra threads" in msg for msg in messages) From d8b2cf43faa618a15e1e590688bd002bf75abe28 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 22:54:34 +0800 Subject: [PATCH 26/60] make exclusion mask modules --- deepmd/pt_expt/utils/exclude_mask.py | 26 ++++++++++++++++--- .../pt_expt/utils/test_exclusion_mask.py | 8 ++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py index 15fbbc8e34..e757283e1c 100644 --- a/deepmd/pt_expt/utils/exclude_mask.py +++ b/deepmd/pt_expt/utils/exclude_mask.py @@ -12,15 +12,33 @@ ) -class AtomExcludeMask(AtomExcludeMaskDP): +class AtomExcludeMask(AtomExcludeMaskDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + AtomExcludeMaskDP.__init__(self, *args, **kwargs) + def __setattr__(self, name: str, value: Any) -> None: - if name == "type_mask": + if name == "type_mask" and "_buffers" in self.__dict__: value = None if value is None else torch.as_tensor(value, device=env.DEVICE) + if name in self._buffers: + self._buffers[name] = value + return + self.register_buffer(name, value) + return return super().__setattr__(name, value) -class PairExcludeMask(PairExcludeMaskDP): +class PairExcludeMask(PairExcludeMaskDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + PairExcludeMaskDP.__init__(self, *args, **kwargs) + def __setattr__(self, name: str, value: Any) -> None: - if name == "type_mask": + if name == "type_mask" and "_buffers" in self.__dict__: value = None if value is None else torch.as_tensor(value, device=env.DEVICE) + if name in self._buffers: + self._buffers[name] = value + return + self.register_buffer(name, value) + return return super().__setattr__(name, value) diff --git a/source/tests/pt_expt/utils/test_exclusion_mask.py b/source/tests/pt_expt/utils/test_exclusion_mask.py index b3707ef69d..6f836913af 100644 --- a/source/tests/pt_expt/utils/test_exclusion_mask.py +++ b/source/tests/pt_expt/utils/test_exclusion_mask.py @@ -39,6 +39,10 @@ def test_build_type_exclude_mask(self) -> None: mask = des.build_type_exclude_mask(torch.as_tensor(atype, device=env.DEVICE)) np.testing.assert_equal(mask.detach().cpu().numpy(), expected_mask) + def test_type_mask_is_buffer(self) -> None: + des = AtomExcludeMask(3, exclude_types=[0]) + assert "type_mask" in des.state_dict() + class TestPairExcludeMask(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self) -> None: @@ -62,3 +66,7 @@ def test_build_type_exclude_mask(self) -> None: torch.as_tensor(self.atype_ext, device=env.DEVICE), ) np.testing.assert_equal(mask.detach().cpu().numpy(), expected_mask) + + def test_type_mask_is_buffer(self) -> None: + des = PairExcludeMask(self.nt, exclude_types=[[0, 1]]) + assert "type_mask" in des.state_dict() From aeef15a99d6e9ccf8f9cbb93a149451adf7b615e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 23:46:24 +0800 Subject: [PATCH 27/60] fix(pt-expt): clear params on None --- deepmd/pt_expt/utils/network.py | 6 ++++++ source/tests/pt_expt/utils/test_network.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 source/tests/pt_expt/utils/test_network.py diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 3effcfc488..721a511f5f 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -48,6 +48,12 @@ def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: val = to_torch_array(value) if val is None: + if name in self._parameters: + self._parameters[name] = None + return + if name in self._buffers: + self._buffers[name] = None + return return super().__setattr__(name, None) if getattr(self, "trainable", False): param = ( diff --git a/source/tests/pt_expt/utils/test_network.py b/source/tests/pt_expt/utils/test_network.py new file mode 100644 index 0000000000..ad7c2a7e3d --- /dev/null +++ b/source/tests/pt_expt/utils/test_network.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +from deepmd.pt_expt.utils.network import ( + NativeLayer, +) + + +def test_native_layer_clears_parameter_on_none() -> None: + layer = NativeLayer(2, 3, trainable=True) + assert layer.w is not None + layer.w = None + assert layer.w is None + assert layer._parameters.get("w") is None + + +def test_native_layer_clears_buffer_on_none() -> None: + layer = NativeLayer(2, 3, trainable=False) + assert layer.w is not None + layer.w = None + assert layer.w is None + assert layer._buffers.get("w") is None From 8bdb1f89eb509efc8d6133812ec5e9a2c678ae7b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 7 Feb 2026 18:51:50 +0800 Subject: [PATCH 28/60] fix bug --- source/tests/pt/test_env_threads.py | 12 +++++++++--- source/tests/pt_expt/utils/test_env.py | 12 +++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/source/tests/pt/test_env_threads.py b/source/tests/pt/test_env_threads.py index eb6604ceb8..50de1996d8 100644 --- a/source/tests/pt/test_env_threads.py +++ b/source/tests/pt/test_env_threads.py @@ -7,7 +7,7 @@ import deepmd.env as common_env -def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None: +def test_env_threads_guard_handles_runtimeerror(monkeypatch) -> None: def raise_err(*_args, **_kwargs) -> None: raise RuntimeError("boom") @@ -18,11 +18,17 @@ def raise_err(*_args, **_kwargs) -> None: monkeypatch.setattr(torch, "get_num_threads", lambda: 2) monkeypatch.setattr(torch, "set_num_threads", raise_err) - caplog.set_level(logging.WARNING, logger="deepmd.pt.utils.env") + messages: list[str] = [] + original_warning = logging.Logger.warning + + def capture_warning(self, msg, *args, **kwargs): # type: ignore[no-untyped-def] + messages.append(str(msg)) + return original_warning(self, msg, *args, **kwargs) + + monkeypatch.setattr(logging.Logger, "warning", capture_warning) import deepmd.pt.utils.env as env importlib.reload(env) - messages = [record.getMessage() for record in caplog.records] assert any("Could not set torch interop threads" in msg for msg in messages) assert any("Could not set torch intra threads" in msg for msg in messages) diff --git a/source/tests/pt_expt/utils/test_env.py b/source/tests/pt_expt/utils/test_env.py index bbdc696aea..a589c80ae1 100644 --- a/source/tests/pt_expt/utils/test_env.py +++ b/source/tests/pt_expt/utils/test_env.py @@ -7,7 +7,7 @@ import deepmd.env as common_env -def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None: +def test_env_threads_guard_handles_runtimeerror(monkeypatch) -> None: def raise_err(*_args, **_kwargs) -> None: raise RuntimeError("boom") @@ -18,11 +18,17 @@ def raise_err(*_args, **_kwargs) -> None: monkeypatch.setattr(torch, "get_num_threads", lambda: 2) monkeypatch.setattr(torch, "set_num_threads", raise_err) - caplog.set_level(logging.WARNING, logger="deepmd.pt_expt.utils.env") + messages: list[str] = [] + original_warning = logging.Logger.warning + + def capture_warning(self, msg, *args, **kwargs): # type: ignore[no-untyped-def] + messages.append(str(msg)) + return original_warning(self, msg, *args, **kwargs) + + monkeypatch.setattr(logging.Logger, "warning", capture_warning) import deepmd.pt_expt.utils.env as env importlib.reload(env) - messages = [record.getMessage() for record in caplog.records] assert any("Could not set torch interop threads" in msg for msg in messages) assert any("Could not set torch intra threads" in msg for msg in messages) From d3b01da5075898e807aa08a34e2fcb49d3b47749 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 16:36:56 +0800 Subject: [PATCH 29/60] utility to handel dpmodel -> pt_expt conversion --- deepmd/pt_expt/common.py | 253 ++++++++++++++++++++++++++- deepmd/pt_expt/descriptor/se_e2_a.py | 39 +---- deepmd/pt_expt/descriptor/se_r.py | 39 +---- deepmd/pt_expt/utils/__init__.py | 4 + deepmd/pt_expt/utils/exclude_mask.py | 39 +++-- deepmd/pt_expt/utils/network.py | 7 + 6 files changed, 293 insertions(+), 88 deletions(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index db8b94989b..e687fa8e48 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -1,4 +1,22 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +"""Common utilities for the pt_expt backend. + +This module provides the core infrastructure for automatically wrapping dpmodel +classes (array_api_compat-based) as PyTorch modules. The key insight is to +detect attributes by their **value type** rather than by hard-coded names: + +- numpy arrays → torch buffers (persistent state like statistics, masks) +- dpmodel objects → pt_expt torch.nn.Module wrappers (via registry lookup) +- None values → clear existing buffers + +This eliminates the need to manually enumerate attribute names in each wrapper's +__setattr__ method, making the codebase more maintainable when dpmodel adds +new attributes. +""" + +from collections.abc import ( + Callable, +) from typing import ( Any, overload, @@ -7,11 +25,203 @@ import numpy as np import torch -from deepmd.pt_expt.utils import ( - env, -) +# --------------------------------------------------------------------------- +# dpmodel → pt_expt converter registry +# --------------------------------------------------------------------------- +_DPMODEL_TO_PT_EXPT: dict[type, Callable[[Any], torch.nn.Module]] = {} +"""Registry mapping dpmodel classes to their pt_expt converter functions. + +This registry is populated at module import time via `register_dpmodel_mapping` +calls in each pt_expt wrapper module (e.g., exclude_mask.py, network.py). When +dpmodel_setattr encounters a dpmodel object, it looks up the object's type in +this registry to find the appropriate converter. + +Examples of registered mappings: +- AtomExcludeMaskDP → lambda v: AtomExcludeMask(v.ntypes, exclude_types=...) +- NetworkCollectionDP → lambda v: NetworkCollection.deserialize(v.serialize()) +""" + + +def register_dpmodel_mapping( + dpmodel_cls: type, converter: Callable[[Any], torch.nn.Module] +) -> None: + """Register a converter that turns a dpmodel instance into a pt_expt Module. + + This function is called at module import time by each pt_expt wrapper to + register how dpmodel objects should be converted when they're assigned as + attributes. The converter is a callable that takes a dpmodel instance and + returns the corresponding pt_expt torch.nn.Module wrapper. + + Parameters + ---------- + dpmodel_cls : type + The dpmodel class to register (e.g., AtomExcludeMaskDP, NetworkCollectionDP). + This is the key used for lookup in dpmodel_setattr. + converter : Callable[[Any], torch.nn.Module] + A callable that converts a dpmodel instance to a pt_expt module. + Common patterns: + - Reconstruct from constructor args: lambda v: PtExptClass(v.ntypes, ...) + - Round-trip via serialization: lambda v: PtExptClass.deserialize(v.serialize()) + + Notes + ----- + This function must be called AFTER the pt_expt wrapper class is defined but + BEFORE dpmodel_setattr might encounter instances of dpmodel_cls. In practice, + this means calling it immediately after the wrapper class definition at module + import time. + + Examples + -------- + >>> register_dpmodel_mapping( + ... AtomExcludeMaskDP, + ... lambda v: AtomExcludeMask( + ... v.ntypes, exclude_types=list(v.get_exclude_types()) + ... ), + ... ) + """ + _DPMODEL_TO_PT_EXPT[dpmodel_cls] = converter + + +def try_convert_module(value: Any) -> torch.nn.Module | None: + """Convert a dpmodel object to its pt_expt wrapper if a converter is registered. + + This function looks up the exact type of *value* in the _DPMODEL_TO_PT_EXPT + registry. If a converter is found, it invokes it to produce a torch.nn.Module + wrapper; otherwise it returns None. + + Parameters + ---------- + value : Any + The value to potentially convert. Typically a dpmodel object like + AtomExcludeMaskDP or NetworkCollectionDP. + + Returns + ------- + torch.nn.Module or None + The converted pt_expt module if a converter is registered for value's + type, otherwise None. + + Notes + ----- + This function uses exact type matching (not isinstance checks) to ensure + predictable behavior. Each dpmodel class must be explicitly registered via + register_dpmodel_mapping. + + The function is called by dpmodel_setattr when it encounters an object that + might be a dpmodel instance. If conversion succeeds, the caller should use + the converted module instead of the original value. + """ + converter = _DPMODEL_TO_PT_EXPT.get(type(value)) + if converter is not None: + return converter(value) + return None + + +def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool, Any]: + """Common __setattr__ logic for pt_expt wrappers around dpmodel classes. + This function implements automatic attribute detection by value type, eliminating + the need to hard-code attribute names in each wrapper's __setattr__ method. It + handles three cases: + 1. **numpy arrays → torch buffers**: Persistent state like statistics (davg, dstd) + or masks that should be saved in state_dict and moved with .to(device). + 2. **None values → clear buffers**: Setting an existing buffer to None. + 3. **dpmodel objects → pt_expt modules**: Nested dpmodel objects like + AtomExcludeMaskDP or NetworkCollectionDP are converted to their pt_expt + wrappers via the registry. + + Parameters + ---------- + obj : torch.nn.Module + The pt_expt wrapper object whose attribute is being set. Must be a + torch.nn.Module (caller should verify this). + name : str + The attribute name being set. + value : Any + The value being assigned. This function inspects the type to determine + how to handle it. + + Returns + ------- + handled : bool + True if the attribute has been fully set (caller should NOT call + super().__setattr__). False if the caller should forward the (possibly + converted) value to super().__setattr__(name, value). + value : Any + The value to use. May be converted (e.g., dpmodel object → pt_expt module) + or unchanged (e.g., scalar, list, or unregistered object). + + Notes + ----- + **Why this design is safe:** + + - In dpmodel, all persistent arrays use `self.xxx = np.array(...)`. Scalars + use `.item()`, lists use `.tolist()`. So `isinstance(value, np.ndarray)` + reliably identifies buffer-worthy attributes. + - torch.Tensor values assigned to existing buffers fall through to + torch.nn.Module.__setattr__, which correctly updates them. + - dpmodel objects are identified by registry lookup (exact type match), so + only explicitly registered types are converted. + - The function checks `"_buffers" in obj.__dict__` to ensure the object has + been initialized as a torch.nn.Module before attempting buffer operations. + + **Circular import resolution:** + + The function uses a deferred import `from deepmd.pt_expt.utils import env` + inside the function body. This breaks the circular dependency chain: + common.py → utils/__init__.py → exclude_mask.py → common.py. The import is + cached by Python after the first call, so there's no performance penalty. + + **Usage pattern:** + + Typical wrapper classes use this three-line pattern: + + >>> class MyWrapper(MyDPModel, torch.nn.Module): + ... def __setattr__(self, name, value): + ... handled, value = dpmodel_setattr(self, name, value) + ... if not handled: + ... super().__setattr__(name, value) + + Examples + -------- + >>> # Case 1: numpy array → buffer + >>> obj.davg = np.array([1.0, 2.0]) # becomes torch.Tensor buffer + >>> + >>> # Case 2: clear buffer + >>> obj.davg = None # sets buffer to None + >>> + >>> # Case 3: dpmodel object → pt_expt module + >>> obj.emask = AtomExcludeMaskDP(...) # becomes AtomExcludeMask module + """ + from deepmd.pt_expt.utils import env # deferred - avoids circular import + + # numpy array → torch buffer + if isinstance(value, np.ndarray) and "_buffers" in obj.__dict__: + tensor = torch.as_tensor(value, device=env.DEVICE) + if name in obj._buffers: + obj._buffers[name] = tensor + return True, tensor + obj.register_buffer(name, tensor) + return True, tensor + + # clear an existing buffer to None + if value is None and "_buffers" in obj.__dict__ and name in obj._buffers: + obj._buffers[name] = None + return True, None + + # dpmodel object → pt_expt module + if "_modules" in obj.__dict__: + converted = try_convert_module(value) + if converted is not None: + return False, converted + + return False, value + + +# --------------------------------------------------------------------------- +# Utility +# --------------------------------------------------------------------------- @overload def to_torch_array(array: np.ndarray) -> torch.Tensor: ... @@ -25,7 +235,42 @@ def to_torch_array(array: torch.Tensor) -> torch.Tensor: ... def to_torch_array(array: Any) -> torch.Tensor | None: - """Convert input to a torch tensor on the pt-expt device.""" + """Convert input to a torch tensor on the pt_expt device. + + This utility function handles conversion from various array-like types (numpy + arrays, torch tensors on different devices, etc.) to torch tensors on the + pt_expt backend's configured device. + + Parameters + ---------- + array : Any + The input to convert. Can be: + - None (returns None) + - torch.Tensor (moves to pt_expt device) + - numpy array or array-like (converts to torch.Tensor on pt_expt device) + + Returns + ------- + torch.Tensor or None + The input as a torch tensor on the pt_expt device (env.DEVICE), or None + if the input was None. + + Notes + ----- + This function uses the same deferred import pattern as dpmodel_setattr to + avoid circular dependencies. The env module determines the target device + (typically CPU for pt_expt). + + Examples + -------- + >>> import numpy as np + >>> arr = np.array([1.0, 2.0, 3.0]) + >>> tensor = to_torch_array(arr) + >>> tensor.device + device(type='cpu') # or whatever env.DEVICE is set to + """ + from deepmd.pt_expt.utils import env # deferred - avoids circular import + if array is None: return None if torch.is_tensor(array): diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 21c0a4eeb7..1ccb4d2dda 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -6,18 +6,12 @@ import torch from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.pt_expt.utils import ( - env, -) -from deepmd.pt_expt.utils.exclude_mask import ( - PairExcludeMask, -) -from deepmd.pt_expt.utils.network import ( - NetworkCollection, -) @BaseDescriptor.register("se_e2_a_expt") @@ -32,30 +26,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return torch.nn.Module.__call__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name in {"davg", "dstd"} and "_buffers" in self.__dict__: - tensor = ( - None if value is None else torch.as_tensor(value, device=env.DEVICE) - ) - if name in self._buffers: - self._buffers[name] = tensor - return - # Register on first assignment so buffers are in state_dict and moved by .to(). - self.register_buffer(name, tensor) - return - if name == "embeddings" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - if hasattr(value, "serialize"): - value = NetworkCollection.deserialize(value.serialize()) - elif isinstance(value, dict): - value = NetworkCollection.deserialize(value) - return super().__setattr__(name, value) - if name == "emask" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - value = PairExcludeMask( - self.ntypes, exclude_types=list(value.get_exclude_types()) - ) - return super().__setattr__(name, value) - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) def forward( self, diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index 508785949c..7a406fb499 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -6,18 +6,12 @@ import torch from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.pt_expt.utils import ( - env, -) -from deepmd.pt_expt.utils.exclude_mask import ( - PairExcludeMask, -) -from deepmd.pt_expt.utils.network import ( - NetworkCollection, -) @BaseDescriptor.register("se_e2_r_expt") @@ -32,30 +26,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return torch.nn.Module.__call__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name in {"davg", "dstd"} and "_buffers" in self.__dict__: - tensor = ( - None if value is None else torch.as_tensor(value, device=env.DEVICE) - ) - if name in self._buffers: - self._buffers[name] = tensor - return - # Register on first assignment so buffers are in state_dict and moved by .to(). - self.register_buffer(name, tensor) - return - if name == "embeddings" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - if hasattr(value, "serialize"): - value = NetworkCollection.deserialize(value.serialize()) - elif isinstance(value, dict): - value = NetworkCollection.deserialize(value) - return super().__setattr__(name, value) - if name == "emask" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - value = PairExcludeMask( - self.ntypes, exclude_types=list(value.get_exclude_types()) - ) - return super().__setattr__(name, value) - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) def forward( self, diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index f90cf82249..93f765a27c 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ b/deepmd/pt_expt/utils/__init__.py @@ -4,8 +4,12 @@ AtomExcludeMask, PairExcludeMask, ) +from .network import ( + NetworkCollection, +) __all__ = [ "AtomExcludeMask", + "NetworkCollection", "PairExcludeMask", ] diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py index e757283e1c..4060b8c446 100644 --- a/deepmd/pt_expt/utils/exclude_mask.py +++ b/deepmd/pt_expt/utils/exclude_mask.py @@ -7,8 +7,9 @@ from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP -from deepmd.pt_expt.utils import ( - env, +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, ) @@ -18,14 +19,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: AtomExcludeMaskDP.__init__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name == "type_mask" and "_buffers" in self.__dict__: - value = None if value is None else torch.as_tensor(value, device=env.DEVICE) - if name in self._buffers: - self._buffers[name] = value - return - self.register_buffer(name, value) - return - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + +register_dpmodel_mapping( + AtomExcludeMaskDP, + lambda v: AtomExcludeMask(v.ntypes, exclude_types=list(v.get_exclude_types())), +) class PairExcludeMask(PairExcludeMaskDP, torch.nn.Module): @@ -34,11 +36,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: PairExcludeMaskDP.__init__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name == "type_mask" and "_buffers" in self.__dict__: - value = None if value is None else torch.as_tensor(value, device=env.DEVICE) - if name in self._buffers: - self._buffers[name] = value - return - self.register_buffer(name, value) - return - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + +register_dpmodel_mapping( + PairExcludeMaskDP, + lambda v: PairExcludeMask(v.ntypes, exclude_types=list(v.get_exclude_types())), +) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 721a511f5f..84d0024a85 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -19,6 +19,7 @@ make_multilayer_network, ) from deepmd.pt_expt.common import ( + register_dpmodel_mapping, to_torch_array, ) @@ -121,5 +122,11 @@ def __setitem__(self, key: int | tuple, value: Any) -> None: del self._module_networks[key_str] +register_dpmodel_mapping( + NetworkCollectionDP, + lambda v: NetworkCollection.deserialize(v.serialize()), +) + + class LayerNorm(LayerNormDP, NativeLayer): pass From 3452a2a8c0ef161d4bd827066c77e2b1bda7be77 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 17:06:19 +0800 Subject: [PATCH 30/60] fix to_numpy_array device --- deepmd/dpmodel/common.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index bd6f7dac49..dabbc34e01 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -121,10 +121,11 @@ def to_numpy_array(x: Optional["Array"]) -> np.ndarray | None: try: # asarray is not within Array API standard, so may fail return np.asarray(x) - except (ValueError, AttributeError): + except (ValueError, AttributeError, TypeError): xp = array_api_compat.array_namespace(x) # to fix BufferError: Cannot export readonly array since signalling readonly is unsupported by DLPack. - x = xp.asarray(x, copy=True) + # Move to CPU device to ensure numpy compatibility + x = xp.asarray(x, device="cpu", copy=True) return np.from_dlpack(x) From ba8e7abfaa6a364e14a3f881b9eadeef4b1ca3b6 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 21:17:53 +0800 Subject: [PATCH 31/60] chore(dpmodel,pt_expt): refactorize the implementation of embedding net --- deepmd/dpmodel/utils/network.py | 107 +++++++- deepmd/pt_expt/utils/network.py | 25 +- source/tests/common/dpmodel/test_network.py | 70 ++++++ source/tests/pt_expt/utils/test_network.py | 259 ++++++++++++++++++++ 4 files changed, 457 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index e712adfdd8..a5502b94cd 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -788,7 +788,112 @@ def deserialize(cls, data: dict) -> "EmbeddingNet": return EN -EmbeddingNet = make_embedding_network(NativeNet, NativeLayer) +class EmbeddingNet(NativeNet): + """The embedding network. + + Parameters + ---------- + in_dim + Input dimension. + neuron + The number of neurons in each layer. The output dimension + is the same as the dimension of the last layer. + activation_function + The activation function. + resnet_dt + Use time step at the resnet architecture. + precision + Floating point precision for the model parameters. + seed : int, optional + Random seed. + bias : bool, Optional + Whether to use bias in the embedding layer. + trainable : bool or list[bool], Optional + Whether the weights are trainable. If a list, each element + corresponds to a layer. + """ + + def __init__( + self, + in_dim: int, + neuron: list[int] = [24, 48, 96], + activation_function: str = "tanh", + resnet_dt: bool = False, + precision: str = DEFAULT_PRECISION, + seed: int | list[int] | None = None, + bias: bool = True, + trainable: bool | list[bool] = True, + ) -> None: + layers = [] + i_in = in_dim + if isinstance(trainable, bool): + trainable = [trainable] * len(neuron) + for idx, ii in enumerate(neuron): + i_ot = ii + layers.append( + NativeLayer( + i_in, + i_ot, + bias=bias, + use_timestep=resnet_dt, + activation_function=activation_function, + resnet=True, + precision=precision, + seed=child_seed(seed, idx), + trainable=trainable[idx], + ).serialize() + ) + i_in = i_ot + super().__init__(layers) + self.in_dim = in_dim + self.neuron = neuron + self.activation_function = activation_function + self.resnet_dt = resnet_dt + self.precision = precision + self.bias = bias + + def serialize(self) -> dict: + """Serialize the network to a dict. + + Returns + ------- + dict + The serialized network. + """ + return { + "@class": "EmbeddingNetwork", + "@version": 2, + "in_dim": self.in_dim, + "neuron": self.neuron.copy(), + "activation_function": self.activation_function, + "resnet_dt": self.resnet_dt, + "bias": self.bias, + # make deterministic + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "layers": [layer.serialize() for layer in self.layers], + } + + @classmethod + def deserialize(cls, data: dict) -> "EmbeddingNet": + """Deserialize the network from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 2, 1) + data.pop("@class", None) + layers = data.pop("layers") + obj = cls(**data) + # Reinitialize layers from serialized data, using the same layer type + # that __init__ created (respects subclass overrides via MRO). + layer_type = type(obj.layers[0]) + obj.layers = type(obj.layers)( + [layer_type.deserialize(layer) for layer in layers] + ) + return obj def make_fitting_network( diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 84d0024a85..b115214056 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -10,11 +10,11 @@ from deepmd.dpmodel.common import ( NativeOP, ) +from deepmd.dpmodel.utils.network import EmbeddingNet as EmbeddingNetDP from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP from deepmd.dpmodel.utils.network import ( - make_embedding_network, make_fitting_network, make_multilayer_network, ) @@ -91,8 +91,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.call(x) -class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)): - pass +class EmbeddingNet(EmbeddingNetDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + EmbeddingNetDP.__init__(self, *args, **kwargs) + # EmbeddingNetDP.__init__ creates dpmodel NativeLayer instances. + # Convert to pt_expt NativeLayer and wrap in ModuleList. + self.layers = torch.nn.ModuleList( + [NativeLayer.deserialize(layer.serialize()) for layer in self.layers] + ) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return torch.nn.Module.__call__(self, *args, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.call(x) + + +register_dpmodel_mapping( + EmbeddingNetDP, + lambda v: EmbeddingNet.deserialize(v.serialize()), +) class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): diff --git a/source/tests/common/dpmodel/test_network.py b/source/tests/common/dpmodel/test_network.py index 1ea5b1fdf9..3a95dd7af0 100644 --- a/source/tests/common/dpmodel/test_network.py +++ b/source/tests/common/dpmodel/test_network.py @@ -180,6 +180,76 @@ def test_embedding_net(self) -> None: inp = np.ones([ni], dtype=get_xp_precision(np, prec)) np.testing.assert_allclose(en0.call(inp), en1.call(inp)) + def test_is_concrete_class(self) -> None: + """Verify EmbeddingNet is a concrete class, not factory-generated.""" + in_dim = 4 + neuron = [8, 16, 32] + net = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + activation_function="tanh", + resnet_dt=True, + precision="float64", + ) + # Check it's the actual EmbeddingNet class, not a dynamic class + self.assertEqual(net.__class__.__name__, "EmbeddingNet") + self.assertEqual(net.__class__.__module__, "deepmd.dpmodel.utils.network") + # Verify it has the expected attributes + self.assertEqual(net.in_dim, in_dim) + self.assertEqual(net.neuron, neuron) + self.assertEqual(net.activation_function, "tanh") + self.assertEqual(net.resnet_dt, True) + self.assertEqual(len(net.layers), len(neuron)) + + def test_forward_pass(self) -> None: + """Test EmbeddingNet forward pass produces correct shapes.""" + in_dim = 4 + neuron = [8, 16, 32] + net = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + activation_function="tanh", + resnet_dt=True, + precision="float64", + ) + rng = np.random.default_rng() + x = rng.standard_normal((5, in_dim)) + out = net.call(x) + self.assertEqual(out.shape, (5, neuron[-1])) + self.assertEqual(out.dtype, np.float64) + + def test_trainable_parameter_variants(self) -> None: + """Test EmbeddingNet with different trainable configurations.""" + in_dim = 4 + neuron = [8, 16] + + # All trainable + net_trainable = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + trainable=True, + ) + for layer in net_trainable.layers: + self.assertTrue(layer.trainable) + + # All frozen + net_frozen = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + trainable=False, + ) + for layer in net_frozen.layers: + self.assertFalse(layer.trainable) + + # Mixed trainable + net_mixed = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + trainable=[True, False], + ) + self.assertTrue(net_mixed.layers[0].trainable) + self.assertFalse(net_mixed.layers[1].trainable) + class TestFittingNet(unittest.TestCase): def test_fitting_net(self) -> None: diff --git a/source/tests/pt_expt/utils/test_network.py b/source/tests/pt_expt/utils/test_network.py index ad7c2a7e3d..1510f28bd3 100644 --- a/source/tests/pt_expt/utils/test_network.py +++ b/source/tests/pt_expt/utils/test_network.py @@ -1,9 +1,19 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +import numpy as np +import torch + +from deepmd.dpmodel.utils.network import EmbeddingNet as DPEmbeddingNet from deepmd.pt_expt.utils.network import ( + EmbeddingNet, NativeLayer, ) +from ...seed import ( + GLOBAL_SEED, +) + def test_native_layer_clears_parameter_on_none() -> None: layer = NativeLayer(2, 3, trainable=True) @@ -19,3 +29,252 @@ def test_native_layer_clears_buffer_on_none() -> None: layer.w = None assert layer.w is None assert layer._buffers.get("w") is None + + +class TestEmbeddingNetRefactor(unittest.TestCase): + """Tests for the refactored EmbeddingNet pt_expt wrapper and integration.""" + + def setUp(self) -> None: + self.in_dim = 4 + self.neuron = [8, 16, 32] + self.activation = "tanh" + self.resnet_dt = True + self.precision = "float64" + + def test_pt_expt_embedding_net_wraps_dpmodel(self) -> None: + """Verify pt_expt EmbeddingNet correctly wraps dpmodel.""" + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + # Check it's a torch.nn.Module + self.assertIsInstance(net, torch.nn.Module) + # Check it's also a DPEmbeddingNet + self.assertIsInstance(net, DPEmbeddingNet) + # Check layers are converted to pt_expt NativeLayer (torch modules) + self.assertIsInstance(net.layers, torch.nn.ModuleList) + for layer in net.layers: + self.assertIsInstance(layer, NativeLayer) + self.assertIsInstance(layer, torch.nn.Module) + + def test_pt_expt_embedding_net_forward(self) -> None: + """Test pt_expt EmbeddingNet forward pass returns torch.Tensor.""" + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + x = torch.randn(5, self.in_dim, dtype=torch.float64) + out = net(x) + self.assertIsInstance(out, torch.Tensor) + self.assertEqual(out.shape, (5, self.neuron[-1])) + self.assertEqual(out.dtype, torch.float64) + + def test_serialization_round_trip_pt_expt(self) -> None: + """Test pt_expt EmbeddingNet serialization/deserialization.""" + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + x = torch.randn(5, self.in_dim, dtype=torch.float64) + out1 = net(x) + + # Serialize and deserialize + serialized = net.serialize() + net2 = EmbeddingNet.deserialize(serialized) + + # Verify layers are still pt_expt NativeLayer modules + self.assertIsInstance(net2.layers, torch.nn.ModuleList) + for layer in net2.layers: + self.assertIsInstance(layer, NativeLayer) + + out2 = net2(x) + np.testing.assert_allclose( + out1.detach().cpu().numpy(), + out2.detach().cpu().numpy(), + ) + + def test_deserialize_preserves_layer_type(self) -> None: + """Test that deserialize uses type(obj.layers[0]) to preserve subclass layers. + + This is the key fix: dpmodel's deserialize no longer hardcodes + super(EmbeddingNet, obj).__init__(layers), which would overwrite + pt_expt's converted layers. Instead it uses type(obj.layers[0]) + to respect the subclass's layer type. + """ + # Create pt_expt EmbeddingNet + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Verify layers are pt_expt NativeLayer (torch modules) + for layer in net.layers: + self.assertIsInstance(layer, torch.nn.Module) + self.assertTrue(hasattr(layer, "_parameters")) + + # Deserialize + serialized = net.serialize() + net2 = EmbeddingNet.deserialize(serialized) + + # Verify deserialized layers are STILL pt_expt NativeLayer, not dpmodel + for layer in net2.layers: + self.assertIsInstance(layer, torch.nn.Module) + self.assertTrue(hasattr(layer, "_parameters")) + # This would fail if deserialize used hardcoded dpmodel layers + self.assertIsInstance(layer, NativeLayer) + + def test_cross_backend_consistency(self) -> None: + """Test numerical consistency between dpmodel and pt_expt EmbeddingNet.""" + # Create both with same seed + dp_net = DPEmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + pt_net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Test forward pass + rng = np.random.default_rng() + x_np = rng.standard_normal((5, self.in_dim)) + x_torch = torch.from_numpy(x_np) + + out_dp = dp_net.call(x_np) + out_pt = pt_net(x_torch).detach().cpu().numpy() + + np.testing.assert_allclose(out_dp, out_pt, rtol=1e-10, atol=1e-10) + + def test_registry_converts_dpmodel_to_pt_expt(self) -> None: + """Test that the registry auto-converts dpmodel EmbeddingNet to pt_expt.""" + from deepmd.pt_expt.common import ( + try_convert_module, + ) + + # Create dpmodel EmbeddingNet + dp_net = DPEmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Try to convert via registry + converted = try_convert_module(dp_net) + + # Should return pt_expt EmbeddingNet + self.assertIsNotNone(converted) + self.assertIsInstance(converted, torch.nn.Module) + self.assertIsInstance(converted, EmbeddingNet) + + # Verify layers are pt_expt NativeLayer + for layer in converted.layers: + self.assertIsInstance(layer, NativeLayer) + self.assertIsInstance(layer, torch.nn.Module) + + def test_auto_conversion_in_setattr(self) -> None: + """Test that dpmodel_setattr auto-converts EmbeddingNet attributes.""" + from deepmd.pt_expt.common import ( + dpmodel_setattr, + ) + + # Create a simple torch module + class TestModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.dummy = None + + obj = TestModule() + + # Create dpmodel EmbeddingNet + dp_net = DPEmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Use dpmodel_setattr to set it + handled, value = dpmodel_setattr(obj, "embedding_net", dp_net) + + # Should not be handled (returns converted value for caller to set) + self.assertFalse(handled) + # Value should be converted to pt_expt EmbeddingNet + self.assertIsInstance(value, torch.nn.Module) + self.assertIsInstance(value, EmbeddingNet) + + def test_trainable_parameter_handling(self) -> None: + """Test that trainable parameters work correctly in pt_expt.""" + # Test with trainable=True + net_trainable = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + trainable=True, + seed=GLOBAL_SEED, + ) + + # Count trainable parameters + param_count = sum( + p.numel() for p in net_trainable.parameters() if p.requires_grad + ) + self.assertGreater(param_count, 0) + + # Check all layer parameters are trainable + for layer in net_trainable.layers: + if layer.w is not None: + self.assertTrue(layer.w.requires_grad) + if layer.b is not None: + self.assertTrue(layer.b.requires_grad) + + # Test with trainable=False + net_frozen = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + trainable=False, + seed=GLOBAL_SEED, + ) + + # Count trainable parameters (should be 0) + param_count_frozen = sum( + p.numel() for p in net_frozen.parameters() if p.requires_grad + ) + self.assertEqual(param_count_frozen, 0) + + # Check all layer weights are buffers, not parameters + for layer in net_frozen.layers: + if layer.w is not None: + self.assertFalse(layer.w.requires_grad) From 621c7ccbbcc31c064b4f186af9bd27fb60211787 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 22:15:11 +0800 Subject: [PATCH 32/60] feat: se_t and se_t_tebd descriptors for the pytroch exportable backend. --- deepmd/dpmodel/descriptor/descriptor.py | 21 ++- deepmd/dpmodel/descriptor/dpa1.py | 2 + deepmd/dpmodel/descriptor/se_t.py | 6 +- deepmd/dpmodel/descriptor/se_t_tebd.py | 6 +- deepmd/dpmodel/utils/type_embed.py | 51 ++++-- deepmd/pt_expt/common.py | 7 +- deepmd/pt_expt/descriptor/__init__.py | 10 ++ deepmd/pt_expt/descriptor/se_t.py | 56 ++++++ deepmd/pt_expt/descriptor/se_t_tebd.py | 54 ++++++ deepmd/pt_expt/descriptor/se_t_tebd_block.py | 31 ++++ deepmd/pt_expt/utils/__init__.py | 4 + deepmd/pt_expt/utils/type_embed.py | 41 +++++ .../tests/consistent/descriptor/test_se_t.py | 25 +++ .../consistent/descriptor/test_se_t_tebd.py | 35 ++++ source/tests/pt_expt/descriptor/test_se_t.py | 134 +++++++++++++++ .../pt_expt/descriptor/test_se_t_tebd.py | 159 ++++++++++++++++++ 16 files changed, 620 insertions(+), 22 deletions(-) create mode 100644 deepmd/pt_expt/descriptor/se_t.py create mode 100644 deepmd/pt_expt/descriptor/se_t_tebd.py create mode 100644 deepmd/pt_expt/descriptor/se_t_tebd_block.py create mode 100644 deepmd/pt_expt/utils/type_embed.py create mode 100644 source/tests/pt_expt/descriptor/test_se_t.py create mode 100644 source/tests/pt_expt/descriptor/test_se_t_tebd.py diff --git a/deepmd/dpmodel/descriptor/descriptor.py b/deepmd/dpmodel/descriptor/descriptor.py index 9b0e067972..ad49a7cb8d 100644 --- a/deepmd/dpmodel/descriptor/descriptor.py +++ b/deepmd/dpmodel/descriptor/descriptor.py @@ -12,7 +12,7 @@ NoReturn, ) -import numpy as np +import array_api_compat from deepmd.dpmodel.array_api import ( Array, @@ -173,7 +173,18 @@ def extend_descrpt_stat( extend_dstd = des_with_stat["dstd"] else: extend_shape = [len(type_map), *list(des["davg"].shape[1:])] - extend_davg = np.zeros(extend_shape, dtype=des["davg"].dtype) - extend_dstd = np.ones(extend_shape, dtype=des["dstd"].dtype) - des["davg"] = np.concatenate([des["davg"], extend_davg], axis=0) - des["dstd"] = np.concatenate([des["dstd"], extend_dstd], axis=0) + # Use array_api_compat to infer device and dtype from context + xp = array_api_compat.array_namespace(des["davg"]) + extend_davg = xp.zeros( + extend_shape, + dtype=des["davg"].dtype, + device=array_api_compat.device(des["davg"]), + ) + extend_dstd = xp.ones( + extend_shape, + dtype=des["dstd"].dtype, + device=array_api_compat.device(des["dstd"]), + ) + xp = array_api_compat.array_namespace(des["davg"]) + des["davg"] = xp.concat([des["davg"], extend_davg], axis=0) + des["dstd"] = xp.concat([des["dstd"], extend_dstd], axis=0) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index f09ab24dfe..2f9aa69b62 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -1049,6 +1049,8 @@ def call( idx_j = xp.reshape(nei_type, (-1,)) # (nf x nl x nnei) x ng idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng)) + # Cast to int64 for PyTorch backend (take_along_dim requires Long indices) + idx = xp.astype(idx, xp.int64) # (ntypes) * ntypes * nt type_embedding_nei = xp.tile( xp.reshape(type_embedding, (1, ntypes_with_padding, nt)), diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index 749a5da188..95b66759de 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -369,7 +369,11 @@ def call( sec = self.sel_cumsum ng = self.neuron[-1] - result = xp.zeros([nf * nloc, ng], dtype=get_xp_precision(xp, self.precision)) + result = xp.zeros( + [nf * nloc, ng], + dtype=get_xp_precision(xp, self.precision), + device=array_api_compat.device(coord_ext), + ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # merge nf and nloc axis, so for type_one_side == False, # we don't require atype is the same in all frames diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index 0a2d46c015..05c9dc77af 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -769,7 +769,9 @@ def call( sw = xp.where( nlist_mask[:, :, None], xp.reshape(sw, (nf * nloc, nnei, 1)), - xp.zeros((nf * nloc, nnei, 1), dtype=sw.dtype), + xp.zeros( + (nf * nloc, nnei, 1), dtype=sw.dtype, device=array_api_compat.device(sw) + ), ) # nfnl x nnei x 4 @@ -832,6 +834,8 @@ def call( # (nf x nl x nt_i x nt_j) x ng idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng)) + # Cast to int64 for PyTorch backend (take_along_dim requires Long indices) + idx = xp.astype(idx, xp.int64) # ntypes * (ntypes) * nt type_embedding_i = xp.tile( diff --git a/deepmd/dpmodel/utils/type_embed.py b/deepmd/dpmodel/utils/type_embed.py index a1b698b698..0d01ccc9d8 100644 --- a/deepmd/dpmodel/utils/type_embed.py +++ b/deepmd/dpmodel/utils/type_embed.py @@ -102,11 +102,21 @@ def call(self) -> Array: sample_array = self.embedding_net[0]["w"] xp = array_api_compat.array_namespace(sample_array) if not self.use_econf_tebd: - embed = self.embedding_net(xp.eye(self.ntypes, dtype=sample_array.dtype)) + embed = self.embedding_net( + xp.eye( + self.ntypes, + dtype=sample_array.dtype, + device=array_api_compat.device(sample_array), + ) + ) else: embed = self.embedding_net(self.econf_tebd) if self.padding: - embed_pad = xp.zeros((1, embed.shape[-1]), dtype=embed.dtype) + embed_pad = xp.zeros( + (1, embed.shape[-1]), + dtype=embed.dtype, + device=array_api_compat.device(embed), + ) embed = xp.concat([embed, embed_pad], axis=0) return embed @@ -182,32 +192,51 @@ def change_type_map( "'activation_function' must be 'Linear' when performing type changing on resnet structure!" ) first_layer_matrix = self.embedding_net.layers[0].w - eye_vector = np.eye(self.ntypes, dtype=PRECISION_DICT[self.precision]) + # Use array_api_compat to handle both numpy and torch + xp = array_api_compat.array_namespace(first_layer_matrix) + eye_vector = xp.eye( + self.ntypes, + dtype=first_layer_matrix.dtype, + device=array_api_compat.device(first_layer_matrix), + ) # preprocess for resnet connection if self.neuron[0] == self.ntypes: - first_layer_matrix += eye_vector + first_layer_matrix = first_layer_matrix + eye_vector elif self.neuron[0] == self.ntypes * 2: - first_layer_matrix += np.concatenate([eye_vector, eye_vector], axis=-1) + first_layer_matrix = first_layer_matrix + xp.concat( + [eye_vector, eye_vector], axis=-1 + ) # randomly initialize params for the unseen types - rng = np.random.default_rng() if has_new_type: - extend_type_params = rng.random( + # Create random params with same dtype and device as first_layer_matrix + extend_type_params = np.random.default_rng().random( [len(type_map), first_layer_matrix.shape[-1]], + dtype=PRECISION_DICT[self.precision], + ) + extend_type_params = xp.asarray( + extend_type_params, dtype=first_layer_matrix.dtype, + device=array_api_compat.device(first_layer_matrix), ) - first_layer_matrix = np.concatenate( + first_layer_matrix = xp.concat( [first_layer_matrix, extend_type_params], axis=0 ) first_layer_matrix = first_layer_matrix[remap_index] new_ntypes = len(type_map) - eye_vector = np.eye(new_ntypes, dtype=PRECISION_DICT[self.precision]) + eye_vector = xp.eye( + new_ntypes, + dtype=first_layer_matrix.dtype, + device=array_api_compat.device(first_layer_matrix), + ) if self.neuron[0] == new_ntypes: - first_layer_matrix -= eye_vector + first_layer_matrix = first_layer_matrix - eye_vector elif self.neuron[0] == new_ntypes * 2: - first_layer_matrix -= np.concatenate([eye_vector, eye_vector], axis=-1) + first_layer_matrix = first_layer_matrix - xp.concat( + [eye_vector, eye_vector], axis=-1 + ) self.embedding_net.layers[0].num_in = new_ntypes self.embedding_net.layers[0].w = first_layer_matrix diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index e687fa8e48..c005804d74 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -85,7 +85,7 @@ def register_dpmodel_mapping( def try_convert_module(value: Any) -> torch.nn.Module | None: """Convert a dpmodel object to its pt_expt wrapper if a converter is registered. - This function looks up the exact type of *value* in the _DPMODEL_TO_PT_EXPT + This function looks up the type of *value* in the _DPMODEL_TO_PT_EXPT registry. If a converter is found, it invokes it to produce a torch.nn.Module wrapper; otherwise it returns None. @@ -103,9 +103,8 @@ def try_convert_module(value: Any) -> torch.nn.Module | None: Notes ----- - This function uses exact type matching (not isinstance checks) to ensure - predictable behavior. Each dpmodel class must be explicitly registered via - register_dpmodel_mapping. + This function uses exact type matching. Each dpmodel class must be explicitly + registered via register_dpmodel_mapping. The function is called by dpmodel_setattr when it encounters an object that might be a dpmodel instance. If conversion succeeds, the caller should use diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py index 4d9469a93a..7feda7d703 100644 --- a/deepmd/pt_expt/descriptor/__init__.py +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +# Import to register converters +from . import se_t_tebd_block # noqa: F401 from .base_descriptor import ( BaseDescriptor, ) @@ -8,9 +10,17 @@ from .se_r import ( DescrptSeR, ) +from .se_t import ( + DescrptSeT, +) +from .se_t_tebd import ( + DescrptSeTTebd, +) __all__ = [ "BaseDescriptor", "DescrptSeA", "DescrptSeR", + "DescrptSeT", + "DescrptSeTTebd", ] diff --git a/deepmd/pt_expt/descriptor/se_t.py b/deepmd/pt_expt/descriptor/se_t.py new file mode 100644 index 0000000000..604dd6a5c0 --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_t.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("se_e3_expt") +@BaseDescriptor.register("se_at_expt") +@BaseDescriptor.register("se_a_3be_expt") +class DescrptSeT(DescrptSeTDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptSeTDP.__init__(self, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + extended_atype_embd: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + type_embedding: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + del extended_atype_embd, type_embedding + descrpt, rot_mat, g2, h2, sw = self.call( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + return descrpt, rot_mat, g2, h2, sw diff --git a/deepmd/pt_expt/descriptor/se_t_tebd.py b/deepmd/pt_expt/descriptor/se_t_tebd.py new file mode 100644 index 0000000000..235ba1bfe9 --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_t_tebd.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("se_e3_tebd_expt") +class DescrptSeTTebd(DescrptSeTTebdDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptSeTTebdDP.__init__(self, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + extended_atype_embd: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + type_embedding: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + del extended_atype_embd, mapping, type_embedding + descrpt, rot_mat, g2, h2, sw = self.call( + extended_coord, + extended_atype, + nlist, + mapping=None, + ) + return descrpt, rot_mat, g2, h2, sw diff --git a/deepmd/pt_expt/descriptor/se_t_tebd_block.py b/deepmd/pt_expt/descriptor/se_t_tebd_block.py new file mode 100644 index 0000000000..7a0faaf170 --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_t_tebd_block.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.descriptor.se_t_tebd import ( + DescrptBlockSeTTebd as DescrptBlockSeTTebdDP, +) +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, +) + + +class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptBlockSeTTebdDP.__init__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + +register_dpmodel_mapping( + DescrptBlockSeTTebdDP, + lambda v: DescrptBlockSeTTebd.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index 93f765a27c..99ae559e30 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ b/deepmd/pt_expt/utils/__init__.py @@ -7,9 +7,13 @@ from .network import ( NetworkCollection, ) +from .type_embed import ( + TypeEmbedNet, +) __all__ = [ "AtomExcludeMask", "NetworkCollection", "PairExcludeMask", + "TypeEmbedNet", ] diff --git a/deepmd/pt_expt/utils/type_embed.py b/deepmd/pt_expt/utils/type_embed.py new file mode 100644 index 0000000000..da4cf09028 --- /dev/null +++ b/deepmd/pt_expt/utils/type_embed.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, +) + +# Import network to ensure EmbeddingNet is registered before TypeEmbedNet is used +from deepmd.pt_expt.utils import network # noqa: F401 + + +class TypeEmbedNet(TypeEmbedNetDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + TypeEmbedNetDP.__init__(self, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + # Use common dpmodel_setattr which handles embedding_net conversion via registry + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward(self) -> torch.Tensor: + # Call dpmodel's implementation (now with proper device handling) + return self.call() + + +register_dpmodel_mapping( + TypeEmbedNetDP, + lambda v: TypeEmbedNet.deserialize(v.serialize()), +) diff --git a/source/tests/consistent/descriptor/test_se_t.py b/source/tests/consistent/descriptor/test_se_t.py index 56655cabe1..2ef9864f47 100644 --- a/source/tests/consistent/descriptor/test_se_t.py +++ b/source/tests/consistent/descriptor/test_se_t.py @@ -15,6 +15,7 @@ INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -27,6 +28,10 @@ from deepmd.pt.model.descriptor.se_t import DescrptSeT as DescrptSeTPT else: DescrptSeTPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.se_t import DescrptSeT as DescrptSeTPTExpt +else: + DescrptSeTPTExpt = None if INSTALLED_TF: from deepmd.tf.descriptor.se_t import DescrptSeT as DescrptSeTTF else: @@ -91,6 +96,16 @@ def skip_dp(self) -> bool: ) = self.param return CommonTest.skip_dp + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + excluded_types, + precision, + env_protection, + ) = self.param + return CommonTest.skip_pt_expt + @property def skip_tf(self) -> bool: ( @@ -107,6 +122,7 @@ def skip_tf(self) -> bool: tf_class = DescrptSeTTF dp_class = DescrptSeTDP pt_class = DescrptSeTPT + pt_expt_class = DescrptSeTPTExpt jax_class = DescrptSeTJAX array_api_strict_class = DescrptSeTStrict args = descrpt_se_t_args() @@ -183,6 +199,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, diff --git a/source/tests/consistent/descriptor/test_se_t_tebd.py b/source/tests/consistent/descriptor/test_se_t_tebd.py index 9cdca9bde3..a60ff1fdad 100644 --- a/source/tests/consistent/descriptor/test_se_t_tebd.py +++ b/source/tests/consistent/descriptor/test_se_t_tebd.py @@ -19,6 +19,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, CommonTest, parameterized, ) @@ -30,6 +31,12 @@ from deepmd.pt.model.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdPT else: DescrptSeTTebdPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.se_t_tebd import ( + DescrptSeTTebd as DescrptSeTTebdPTExpt, + ) +else: + DescrptSeTTebdPTExpt = None DescrptSeTTebdTF = None if INSTALLED_JAX: from deepmd.jax.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdJAX @@ -117,6 +124,23 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_pt_expt(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + excluded_types, + env_protection, + set_davg_zero, + smooth, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return CommonTest.skip_pt_expt + @property def skip_dp(self) -> bool: ( @@ -158,6 +182,7 @@ def skip_tf(self) -> bool: tf_class = DescrptSeTTebdTF dp_class = DescrptSeTTebdDP pt_class = DescrptSeTTebdPT + pt_expt_class = DescrptSeTTebdPTExpt pd_class = DescrptSeTTebdPD jax_class = DescrptSeTTebdJAX array_api_strict_class = DescrptSeTTebdStrict @@ -240,6 +265,16 @@ def eval_pt(self, pt_obj: Any) -> Any: mixed_types=True, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, diff --git a/source/tests/pt_expt/descriptor/test_se_t.py b/source/tests/pt_expt/descriptor/test_se_t.py new file mode 100644 index 0000000000..921f10a54a --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_se_t.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import DescrptSeT as DPDescrptSeT +from deepmd.pt_expt.descriptor.se_t import ( + DescrptSeT, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDescrptSeT(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeT( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + dd1 = DescrptSeT.deserialize(dd0.serialize()) + rd1, gr1, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeT.deserialize(dd0.serialize()) + rd2, gr2, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + # se_t returns None for gr/g2/h2, only compare rd and sw + np.testing.assert_allclose( + rd1.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + sw1.detach().cpu().numpy(), + sw2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_exportable(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeT( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) diff --git a/source/tests/pt_expt/descriptor/test_se_t_tebd.py b/source/tests/pt_expt/descriptor/test_se_t_tebd.py new file mode 100644 index 0000000000..e774cda6d5 --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_se_t_tebd.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import DescrptSeTTebd as DPDescrptSeTTebd +from deepmd.pt_expt.descriptor.se_t_tebd import ( + DescrptSeTTebd, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDescrptSeTTebd(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [True], # SeTTebd typically uses resnet_dt=True + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeTTebd( + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + + # Type embedding input + type_embedding = torch.randn( + [self.nt, dd0.tebd_dim], dtype=dtype, device=self.device + ) + + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + type_embedding=type_embedding, + ) + dd1 = DescrptSeTTebd.deserialize(dd0.serialize()) + rd1, gr1, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + type_embedding=type_embedding, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeTTebd.deserialize(dd0.serialize()) + rd2, gr2, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + # se_t_tebd should return gr and sw, compare only descriptor and sw for now + # TODO: investigate why gr is None + np.testing.assert_allclose( + rd1.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + if gr1 is not None and gr2 is not None: + np.testing.assert_allclose( + gr1.detach().cpu().numpy(), + gr2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + sw1.detach().cpu().numpy(), + sw2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_exportable(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeTTebd( + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + + type_embedding = torch.randn( + [self.nt, dd0.tebd_dim], dtype=dtype, device=self.device + ) + + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + type_embedding, + ) + torch.export.export(dd0, inputs) From faa4026da8ee7aea5430aa2c6b992f6c0b16d6bf Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 22:21:28 +0800 Subject: [PATCH 33/60] fix bug --- source/tests/pt_expt/utils/test_network.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/source/tests/pt_expt/utils/test_network.py b/source/tests/pt_expt/utils/test_network.py index 1510f28bd3..c80946a018 100644 --- a/source/tests/pt_expt/utils/test_network.py +++ b/source/tests/pt_expt/utils/test_network.py @@ -5,6 +5,9 @@ import torch from deepmd.dpmodel.utils.network import EmbeddingNet as DPEmbeddingNet +from deepmd.pt_expt.utils import ( + env, +) from deepmd.pt_expt.utils.network import ( EmbeddingNet, NativeLayer, @@ -87,7 +90,7 @@ def test_serialization_round_trip_pt_expt(self) -> None: precision=self.precision, seed=GLOBAL_SEED, ) - x = torch.randn(5, self.in_dim, dtype=torch.float64) + x = torch.randn(5, self.in_dim, dtype=torch.float64, device=env.DEVICE) out1 = net(x) # Serialize and deserialize From e263270ca2fa02b144f59c36013ac3eb828365a2 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 22:40:48 +0800 Subject: [PATCH 34/60] refact: fitting net --- deepmd/dpmodel/utils/network.py | 113 +++++++++++++++++- deepmd/pt_expt/utils/network.py | 24 +++- source/tests/common/dpmodel/test_network.py | 98 ++++++++++++++++ source/tests/pt_expt/utils/test_network.py | 121 ++++++++++++++++++++ 4 files changed, 352 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index a5502b94cd..a7e194f644 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -1009,7 +1009,118 @@ def deserialize(cls, data: dict) -> "FittingNet": return FN -FittingNet = make_fitting_network(EmbeddingNet, NativeNet, NativeLayer) +class FittingNet(EmbeddingNet): + """The fitting network. It may be implemented as an embedding + net connected with a linear output layer. + + Parameters + ---------- + in_dim + Input dimension. + out_dim + Output dimension + neuron + The number of neurons in each hidden layer. + activation_function + The activation function. + resnet_dt + Use time step at the resnet architecture. + precision + Floating point precision for the model parameters. + bias_out + The last linear layer has bias. + seed : int, optional + Random seed. + trainable : bool or list[bool], optional + Whether the network is trainable. + """ + + def __init__( + self, + in_dim: int, + out_dim: int, + neuron: list[int] = [24, 48, 96], + activation_function: str = "tanh", + resnet_dt: bool = False, + precision: str = DEFAULT_PRECISION, + bias_out: bool = True, + seed: int | list[int] | None = None, + trainable: bool | list[bool] = True, + ) -> None: + if trainable is None: + trainable = [True] * (len(neuron) + 1) + elif isinstance(trainable, bool): + trainable = [trainable] * (len(neuron) + 1) + else: + pass + super().__init__( + in_dim, + neuron=neuron, + activation_function=activation_function, + resnet_dt=resnet_dt, + precision=precision, + seed=seed, + trainable=trainable[:-1], + ) + i_in = neuron[-1] if len(neuron) > 0 else in_dim + i_ot = out_dim + self.layers.append( + NativeLayer( + i_in, + i_ot, + bias=bias_out, + use_timestep=False, + activation_function=None, + resnet=False, + precision=precision, + seed=child_seed(seed, len(neuron)), + trainable=trainable[-1], + ) + ) + self.out_dim = out_dim + self.bias_out = bias_out + + def serialize(self) -> dict: + """Serialize the network to a dict. + + Returns + ------- + dict + The serialized network. + """ + return { + "@class": "FittingNetwork", + "@version": 1, + "in_dim": self.in_dim, + "out_dim": self.out_dim, + "neuron": self.neuron.copy(), + "activation_function": self.activation_function, + "resnet_dt": self.resnet_dt, + "precision": self.precision, + "bias_out": self.bias_out, + "layers": [layer.serialize() for layer in self.layers], + } + + @classmethod + def deserialize(cls, data: dict) -> "FittingNet": + """Deserialize the network from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. + """ + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 1, 1) + data.pop("@class", None) + layers = data.pop("layers") + obj = cls(**data) + # Use type(obj.layers[0]) to respect subclass layer types + layer_type = type(obj.layers[0]) + obj.layers = type(obj.layers)( + [layer_type.deserialize(layer) for layer in layers] + ) + return obj class NetworkCollection: diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index b115214056..ee957316c9 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -11,11 +11,11 @@ NativeOP, ) from deepmd.dpmodel.utils.network import EmbeddingNet as EmbeddingNetDP +from deepmd.dpmodel.utils.network import FittingNet as FittingNetDP from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP from deepmd.dpmodel.utils.network import ( - make_fitting_network, make_multilayer_network, ) from deepmd.pt_expt.common import ( @@ -114,8 +114,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) -class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): - pass +class FittingNet(FittingNetDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + FittingNetDP.__init__(self, *args, **kwargs) + # Convert dpmodel layers to pt_expt NativeLayer + self.layers = torch.nn.ModuleList( + [NativeLayer.deserialize(layer.serialize()) for layer in self.layers] + ) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return torch.nn.Module.__call__(self, *args, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.call(x) + + +register_dpmodel_mapping( + FittingNetDP, + lambda v: FittingNet.deserialize(v.serialize()), +) class NetworkCollection(NetworkCollectionDP, torch.nn.Module): diff --git a/source/tests/common/dpmodel/test_network.py b/source/tests/common/dpmodel/test_network.py index 3a95dd7af0..b091495701 100644 --- a/source/tests/common/dpmodel/test_network.py +++ b/source/tests/common/dpmodel/test_network.py @@ -275,6 +275,104 @@ def test_fitting_net(self) -> None: en1.call(inp) np.testing.assert_allclose(en0.call(inp), en1.call(inp)) + def test_is_concrete_class(self) -> None: + """Verify FittingNet is a concrete class, not factory-generated.""" + in_dim = 4 + out_dim = 1 + neuron = [8, 16] + net = FittingNet( + in_dim=in_dim, + out_dim=out_dim, + neuron=neuron, + activation_function="tanh", + resnet_dt=True, + precision="float64", + bias_out=True, + ) + # Check it's the actual FittingNet class, not a dynamic class + self.assertEqual(net.__class__.__name__, "FittingNet") + self.assertEqual(net.__class__.__module__, "deepmd.dpmodel.utils.network") + # Verify it has the expected attributes + self.assertEqual(net.in_dim, in_dim) + self.assertEqual(net.out_dim, out_dim) + self.assertEqual(net.neuron, neuron) + self.assertEqual(net.activation_function, "tanh") + self.assertEqual(net.resnet_dt, True) + self.assertEqual(net.bias_out, True) + # FittingNet has len(neuron) embedding layers + 1 output layer + self.assertEqual(len(net.layers), len(neuron) + 1) + + def test_forward_pass(self) -> None: + """Test FittingNet forward pass produces correct output shape.""" + in_dim = 4 + out_dim = 3 + neuron = [8, 16, 32] + net = FittingNet( + in_dim=in_dim, + out_dim=out_dim, + neuron=neuron, + activation_function="tanh", + resnet_dt=True, + precision="float64", + ) + # Single sample + rng = np.random.default_rng() + x = rng.standard_normal(in_dim) + out = net.call(x) + self.assertEqual(out.shape, (out_dim,)) + + # Batch of samples + batch_size = 5 + x_batch = rng.standard_normal((batch_size, in_dim)) + out_batch = net.call(x_batch) + self.assertEqual(out_batch.shape, (batch_size, out_dim)) + + def test_trainable_parameter_variants(self) -> None: + """Test FittingNet with different trainable configurations.""" + in_dim = 4 + out_dim = 2 + neuron = [8, 16] + + # Test 1: All layers trainable (default) + net_all_trainable = FittingNet( + in_dim=in_dim, + out_dim=out_dim, + neuron=neuron, + trainable=True, + ) + for layer in net_all_trainable.layers: + self.assertTrue(layer.trainable) + + # Test 2: All layers frozen + net_all_frozen = FittingNet( + in_dim=in_dim, + out_dim=out_dim, + neuron=neuron, + trainable=False, + ) + for layer in net_all_frozen.layers: + self.assertFalse(layer.trainable) + + # Test 3: Mixed trainable (embedding layers frozen, output layer trainable) + trainable_list = [False, False, True] # 2 embedding layers + 1 output layer + net_mixed = FittingNet( + in_dim=in_dim, + out_dim=out_dim, + neuron=neuron, + trainable=trainable_list, + ) + self.assertFalse(net_mixed.layers[0].trainable) # First embedding layer + self.assertFalse(net_mixed.layers[1].trainable) # Second embedding layer + self.assertTrue(net_mixed.layers[2].trainable) # Output layer + + # Test 4: Serialize/deserialize preserves trainable + serialized = net_mixed.serialize() + net_restored = FittingNet.deserialize(serialized) + for orig_layer, restored_layer in zip( + net_mixed.layers, net_restored.layers, strict=True + ): + self.assertEqual(orig_layer.trainable, restored_layer.trainable) + class TestNetworkCollection(unittest.TestCase): def setUp(self) -> None: diff --git a/source/tests/pt_expt/utils/test_network.py b/source/tests/pt_expt/utils/test_network.py index c80946a018..1385cb83d4 100644 --- a/source/tests/pt_expt/utils/test_network.py +++ b/source/tests/pt_expt/utils/test_network.py @@ -281,3 +281,124 @@ def test_trainable_parameter_handling(self) -> None: for layer in net_frozen.layers: if layer.w is not None: self.assertFalse(layer.w.requires_grad) + + +class TestFittingNetRefactor(unittest.TestCase): + """Tests for the refactored FittingNet pt_expt wrapper.""" + + def setUp(self) -> None: + self.in_dim = 4 + self.out_dim = 1 + self.neuron = [8, 16] + self.activation = "tanh" + self.resnet_dt = True + self.precision = "float64" + + def test_pt_expt_fitting_net_wraps_dpmodel(self) -> None: + """Verify pt_expt FittingNet correctly wraps dpmodel.""" + from deepmd.pt_expt.utils.network import ( + FittingNet, + ) + + net = FittingNet( + in_dim=self.in_dim, + out_dim=self.out_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + # Check it's a torch.nn.Module + self.assertIsInstance(net, torch.nn.Module) + # Check layers are converted to pt_expt NativeLayer (torch modules) + self.assertIsInstance(net.layers, torch.nn.ModuleList) + for layer in net.layers: + self.assertIsInstance(layer, torch.nn.Module) + + def test_pt_expt_fitting_net_forward(self) -> None: + """Test pt_expt FittingNet forward pass returns torch.Tensor.""" + from deepmd.pt_expt.utils.network import ( + FittingNet, + ) + + net = FittingNet( + in_dim=self.in_dim, + out_dim=self.out_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + x = torch.randn(5, self.in_dim, dtype=torch.float64, device=env.DEVICE) + out = net(x) + self.assertIsInstance(out, torch.Tensor) + self.assertEqual(out.shape, (5, self.out_dim)) + self.assertEqual(out.dtype, torch.float64) + + def test_serialization_round_trip_pt_expt(self) -> None: + """Test pt_expt FittingNet serialization/deserialization.""" + from deepmd.pt_expt.utils.network import ( + FittingNet, + ) + + net = FittingNet( + in_dim=self.in_dim, + out_dim=self.out_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + x = torch.randn(5, self.in_dim, dtype=torch.float64, device=env.DEVICE) + out1 = net(x) + + # Serialize and deserialize + serialized = net.serialize() + net2 = FittingNet.deserialize(serialized) + + # Verify layers are still pt_expt NativeLayer modules + self.assertIsInstance(net2.layers, torch.nn.ModuleList) + for layer in net2.layers: + self.assertIsInstance(layer, torch.nn.Module) + + out2 = net2(x) + np.testing.assert_allclose( + out1.detach().cpu().numpy(), + out2.detach().cpu().numpy(), + ) + + def test_registry_converts_dpmodel_to_pt_expt(self) -> None: + """Test that dpmodel FittingNet can be converted to pt_expt via registry.""" + from deepmd.dpmodel.utils.network import FittingNet as DPFittingNet + from deepmd.pt_expt.common import ( + try_convert_module, + ) + from deepmd.pt_expt.utils.network import ( + FittingNet, + ) + + # Create dpmodel FittingNet + dp_net = DPFittingNet( + in_dim=self.in_dim, + out_dim=self.out_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Try to convert via registry + converted = try_convert_module(dp_net) + + # Should return pt_expt FittingNet + self.assertIsNotNone(converted) + self.assertIsInstance(converted, torch.nn.Module) + self.assertIsInstance(converted, FittingNet) + + # Verify layers are pt_expt modules + for layer in converted.layers: + self.assertIsInstance(layer, torch.nn.Module) From ea6114167e591d30427f0f304a2a62d2e7bc91db Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 23:09:19 +0800 Subject: [PATCH 35/60] fix bug --- source/tests/pt_expt/utils/test_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/pt_expt/utils/test_network.py b/source/tests/pt_expt/utils/test_network.py index 1385cb83d4..e56a20e245 100644 --- a/source/tests/pt_expt/utils/test_network.py +++ b/source/tests/pt_expt/utils/test_network.py @@ -74,7 +74,7 @@ def test_pt_expt_embedding_net_forward(self) -> None: precision=self.precision, seed=GLOBAL_SEED, ) - x = torch.randn(5, self.in_dim, dtype=torch.float64) + x = torch.randn(5, self.in_dim, dtype=torch.float64, device=env.DEVICE) out = net(x) self.assertIsInstance(out, torch.Tensor) self.assertEqual(out.shape, (5, self.neuron[-1])) From 9311ed567b6d225e8e781eec196eb13caede1ba1 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 9 Feb 2026 19:44:47 +0800 Subject: [PATCH 36/60] feat(pt_expt): add fitting --- deepmd/dpmodel/fitting/general_fitting.py | 32 +- deepmd/pt_expt/descriptor/se_e2_a.py | 4 +- deepmd/pt_expt/descriptor/se_r.py | 4 +- deepmd/pt_expt/descriptor/se_t.py | 6 +- deepmd/pt_expt/descriptor/se_t_tebd.py | 2 +- deepmd/pt_expt/fitting/__init__.py | 16 + deepmd/pt_expt/fitting/base_fitting.py | 9 + deepmd/pt_expt/fitting/ener_fitting.py | 68 +++ deepmd/pt_expt/fitting/invar_fitting.py | 62 +++ source/tests/consistent/fitting/test_ener.py | 425 ++++++++++++++++++ source/tests/pt_expt/fitting/__init__.py | 1 + .../fitting/test_fitting_ener_fitting.py | 175 ++++++++ .../fitting/test_fitting_invar_fitting.py | 311 +++++++++++++ .../pt_expt/fitting/test_fitting_stat.py | 125 ++++++ 14 files changed, 1227 insertions(+), 13 deletions(-) create mode 100644 deepmd/pt_expt/fitting/__init__.py create mode 100644 deepmd/pt_expt/fitting/base_fitting.py create mode 100644 deepmd/pt_expt/fitting/ener_fitting.py create mode 100644 deepmd/pt_expt/fitting/invar_fitting.py create mode 100644 source/tests/pt_expt/fitting/__init__.py create mode 100644 source/tests/pt_expt/fitting/test_fitting_ener_fitting.py create mode 100644 source/tests/pt_expt/fitting/test_fitting_invar_fitting.py create mode 100644 source/tests/pt_expt/fitting/test_fitting_stat.py diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index a4089468f3..fabc39ae96 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -261,8 +261,18 @@ def compute_input_stats( fparam_std, ) fparam_inv_std = 1.0 / fparam_std - self.fparam_avg = fparam_avg.astype(self.fparam_avg.dtype) - self.fparam_inv_std = fparam_inv_std.astype(self.fparam_inv_std.dtype) + # Use array_api_compat to handle both numpy and torch + xp = array_api_compat.array_namespace(self.fparam_avg) + self.fparam_avg = xp.asarray( + fparam_avg, + dtype=self.fparam_avg.dtype, + device=array_api_compat.device(self.fparam_avg), + ) + self.fparam_inv_std = xp.asarray( + fparam_inv_std, + dtype=self.fparam_inv_std.dtype, + device=array_api_compat.device(self.fparam_inv_std), + ) # stat aparam if self.numb_aparam > 0: sys_sumv = [] @@ -284,8 +294,18 @@ def compute_input_stats( aparam_std, ) aparam_inv_std = 1.0 / aparam_std - self.aparam_avg = aparam_avg.astype(self.aparam_avg.dtype) - self.aparam_inv_std = aparam_inv_std.astype(self.aparam_inv_std.dtype) + # Use array_api_compat to handle both numpy and torch + xp = array_api_compat.array_namespace(self.aparam_avg) + self.aparam_avg = xp.asarray( + aparam_avg, + dtype=self.aparam_avg.dtype, + device=array_api_compat.device(self.aparam_avg), + ) + self.aparam_inv_std = xp.asarray( + aparam_inv_std, + dtype=self.aparam_inv_std.dtype, + device=array_api_compat.device(self.aparam_inv_std), + ) @abstractmethod def _net_out_dim(self) -> int: @@ -566,7 +586,9 @@ def _call_common( # calculate the prediction if not self.mixed_types: outs = xp.zeros( - [nf, nloc, net_dim_out], dtype=get_xp_precision(xp, self.precision) + [nf, nloc, net_dim_out], + dtype=get_xp_precision(xp, self.precision), + device=array_api_compat.device(descriptor), ) for type_i in range(self.ntypes): mask = xp.tile( diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 1ccb4d2dda..b5d3a92ea8 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -14,8 +14,8 @@ ) -@BaseDescriptor.register("se_e2_a_expt") -@BaseDescriptor.register("se_a_expt") +@BaseDescriptor.register("se_e2_a") +@BaseDescriptor.register("se_a") class DescrptSeA(DescrptSeADP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index 7a406fb499..43ed4fb7e9 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -14,8 +14,8 @@ ) -@BaseDescriptor.register("se_e2_r_expt") -@BaseDescriptor.register("se_r_expt") +@BaseDescriptor.register("se_e2_r") +@BaseDescriptor.register("se_r") class DescrptSeR(DescrptSeRDP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) diff --git a/deepmd/pt_expt/descriptor/se_t.py b/deepmd/pt_expt/descriptor/se_t.py index 604dd6a5c0..e3c7f6bc7b 100644 --- a/deepmd/pt_expt/descriptor/se_t.py +++ b/deepmd/pt_expt/descriptor/se_t.py @@ -14,9 +14,9 @@ ) -@BaseDescriptor.register("se_e3_expt") -@BaseDescriptor.register("se_at_expt") -@BaseDescriptor.register("se_a_3be_expt") +@BaseDescriptor.register("se_e3") +@BaseDescriptor.register("se_at") +@BaseDescriptor.register("se_a_3be") class DescrptSeT(DescrptSeTDP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) diff --git a/deepmd/pt_expt/descriptor/se_t_tebd.py b/deepmd/pt_expt/descriptor/se_t_tebd.py index 235ba1bfe9..081a798ae4 100644 --- a/deepmd/pt_expt/descriptor/se_t_tebd.py +++ b/deepmd/pt_expt/descriptor/se_t_tebd.py @@ -14,7 +14,7 @@ ) -@BaseDescriptor.register("se_e3_tebd_expt") +@BaseDescriptor.register("se_e3_tebd") class DescrptSeTTebd(DescrptSeTTebdDP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) diff --git a/deepmd/pt_expt/fitting/__init__.py b/deepmd/pt_expt/fitting/__init__.py new file mode 100644 index 0000000000..4a7c8100de --- /dev/null +++ b/deepmd/pt_expt/fitting/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .base_fitting import ( + BaseFitting, +) +from .ener_fitting import ( + EnergyFittingNet, +) +from .invar_fitting import ( + InvarFitting, +) + +__all__ = [ + "BaseFitting", + "EnergyFittingNet", + "InvarFitting", +] diff --git a/deepmd/pt_expt/fitting/base_fitting.py b/deepmd/pt_expt/fitting/base_fitting.py new file mode 100644 index 0000000000..f42e572578 --- /dev/null +++ b/deepmd/pt_expt/fitting/base_fitting.py @@ -0,0 +1,9 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import torch + +from deepmd.dpmodel.fitting import ( + make_base_fitting, +) + +BaseFitting = make_base_fitting(torch.Tensor, "forward") diff --git a/deepmd/pt_expt/fitting/ener_fitting.py b/deepmd/pt_expt/fitting/ener_fitting.py new file mode 100644 index 0000000000..425040ae75 --- /dev/null +++ b/deepmd/pt_expt/fitting/ener_fitting.py @@ -0,0 +1,68 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, +) +from deepmd.pt_expt.utils.network import ( + NetworkCollection, +) + +from .base_fitting import ( + BaseFitting, +) + + +@BaseFitting.register("ener") +class EnergyFittingNet(EnergyFittingNetDP, torch.nn.Module): + """Energy fitting net for pt_expt backend. + + This inherits from dpmodel EnergyFittingNet to get the correct serialize() method. + """ + + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + EnergyFittingNetDP.__init__(self, *args, **kwargs) + # Convert dpmodel NetworkCollection to pt_expt NetworkCollection + self.nets = NetworkCollection.deserialize(self.nets.serialize()) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + return self.call( + descriptor, + atype, + gr=gr, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + + +register_dpmodel_mapping( + EnergyFittingNetDP, + lambda v: EnergyFittingNet.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/fitting/invar_fitting.py b/deepmd/pt_expt/fitting/invar_fitting.py new file mode 100644 index 0000000000..aa37026284 --- /dev/null +++ b/deepmd/pt_expt/fitting/invar_fitting.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.fitting.invar_fitting import InvarFitting as InvarFittingDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, +) +from deepmd.pt_expt.fitting.base_fitting import ( + BaseFitting, +) +from deepmd.pt_expt.utils.network import ( + NetworkCollection, +) + + +@BaseFitting.register("invar") +class InvarFitting(InvarFittingDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + InvarFittingDP.__init__(self, *args, **kwargs) + # Convert dpmodel NetworkCollection to pt_expt NetworkCollection + self.nets = NetworkCollection.deserialize(self.nets.serialize()) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + return self.call( + descriptor, + atype, + gr=gr, + g2=g2, + h2=h2, + fparam=fparam, + aparam=aparam, + ) + + +register_dpmodel_mapping( + InvarFittingDP, + lambda v: InvarFitting.deserialize(v.serialize()), +) diff --git a/source/tests/consistent/fitting/test_ener.py b/source/tests/consistent/fitting/test_ener.py index ad70bd0bfa..a227fa1aed 100644 --- a/source/tests/consistent/fitting/test_ener.py +++ b/source/tests/consistent/fitting/test_ener.py @@ -19,6 +19,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -34,6 +35,13 @@ from deepmd.pt.utils.env import DEVICE as PT_DEVICE else: EnerFittingPT = object +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.fitting.ener_fitting import ( + EnergyFittingNet as EnerFittingPTExpt, + ) + from deepmd.pt_expt.utils.env import DEVICE as PT_EXPT_DEVICE +else: + EnerFittingPTExpt = None if INSTALLED_TF: from deepmd.tf.fit.ener import EnerFitting as EnerFittingTF else: @@ -150,9 +158,23 @@ def skip_tf(self) -> bool: ) = self.param return not INSTALLED_TF or default_fparam is not None + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # PyTorch does not support bfloat16 for some operations + return CommonTest.skip_pt_expt or precision == "bfloat16" + tf_class = EnerFittingTF dp_class = EnerFittingDP pt_class = EnerFittingPT + pt_expt_class = EnerFittingPTExpt jax_class = EnerFittingJAX pd_class = EnerFittingPD array_api_strict_class = EnerFittingStrict @@ -236,6 +258,35 @@ def eval_pt(self, pt_obj: Any) -> Any: .numpy() ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + return ( + pt_expt_obj( + torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE), + fparam=( + torch.from_numpy(self.fparam).to(device=PT_EXPT_DEVICE) + if (numb_fparam and default_fparam is None) # test default_fparam + else None + ), + aparam=( + torch.from_numpy(self.aparam).to(device=PT_EXPT_DEVICE) + if numb_aparam + else None + ), + )["energy"] + .detach() + .cpu() + .numpy() + ) + def eval_dp(self, dp_obj: Any) -> Any: ( resnet_dt, @@ -366,3 +417,377 @@ def atol(self) -> float: return 1e-1 else: raise ValueError(f"Unknown precision: {precision}") + + +@parameterized( + (True,), # resnet_dt + ("float64",), # precision + (True,), # mixed_types + ((3, None),), # (numb_fparam, default_fparam) + ((3, False),), # (numb_aparam, use_aparam_as_mask) + ([],), # atom_ener +) +class TestEnerStat(CommonTest, FittingTest, unittest.TestCase): + @property + def data(self) -> dict: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + return { + "neuron": [5, 5, 5], + "resnet_dt": resnet_dt, + "precision": precision, + "numb_fparam": numb_fparam, + "numb_aparam": numb_aparam, + "default_fparam": default_fparam, + "seed": 20240217, + "atom_ener": atom_ener, + "use_aparam_as_mask": use_aparam_as_mask, + } + + @property + def skip_pt(self) -> bool: + return CommonTest.skip_pt + + @property + def skip_pt_expt(self) -> bool: + return CommonTest.skip_pt_expt + + @property + def skip_tf(self) -> bool: + return True + + skip_jax = not INSTALLED_JAX + + @property + def skip_array_api_strict(self) -> bool: + return not INSTALLED_ARRAY_API_STRICT + + @property + def skip_pd(self) -> bool: + return not INSTALLED_PD + + tf_class = EnerFittingTF + dp_class = EnerFittingDP + pt_class = EnerFittingPT + pt_expt_class = EnerFittingPTExpt + jax_class = EnerFittingJAX + pd_class = EnerFittingPD + array_api_strict_class = EnerFittingStrict + args = fitting_ener() + + def setUp(self) -> None: + CommonTest.setUp(self) + + self.ntypes = 2 + self.natoms = np.array([6, 6, 2, 4], dtype=np.int32) + self.inputs = np.ones((1, 6, 20), dtype=GLOBAL_NP_FLOAT_PRECISION) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32) + # inconsistent if not sorted + self.atype.sort() + + # Prepare data for compute_input_stats + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + + # Create fparam and aparam with correct dimensions + rng = np.random.default_rng(20240217) + self.fparam = ( + rng.normal(size=(1, numb_fparam)).astype(GLOBAL_NP_FLOAT_PRECISION) + if numb_fparam > 0 + else None + ) + self.aparam = ( + rng.normal(size=(1, 6, numb_aparam)).astype(GLOBAL_NP_FLOAT_PRECISION) + if numb_aparam > 0 + else None + ) + + self.stat_data = [ + { + "fparam": rng.normal(size=(2, numb_fparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ), + "aparam": rng.normal(size=(2, 6, numb_aparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ), + }, + { + "fparam": rng.normal(size=(3, numb_fparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ), + "aparam": rng.normal(size=(3, 6, numb_aparam)).astype( + GLOBAL_NP_FLOAT_PRECISION + ), + }, + ] + + @property + def additional_data(self) -> dict: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + return { + "ntypes": self.ntypes, + "dim_descrpt": self.inputs.shape[-1], + "mixed_types": mixed_types, + } + + def build_tf(self, obj: Any, suffix: str) -> tuple[list, dict]: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + return self.build_tf_fitting( + obj, + self.inputs.ravel(), + self.natoms, + self.atype, + self.fparam if numb_fparam else None, + self.aparam if numb_aparam else None, + suffix, + ) + + def eval_pt(self, pt_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # Convert stat_data to torch tensors for pt backend + pt_stat_data = [ + { + "fparam": torch.from_numpy(d["fparam"]).to(PT_DEVICE), + "aparam": torch.from_numpy(d["aparam"]).to(PT_DEVICE), + } + for d in self.stat_data + ] + pt_obj.compute_input_stats(pt_stat_data, protection=1e-2) + return ( + pt_obj( + torch.from_numpy(self.inputs).to(device=PT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_DEVICE), + fparam=( + torch.from_numpy(self.fparam).to(device=PT_DEVICE) + if self.fparam is not None + else None + ), + aparam=( + torch.from_numpy(self.aparam).to(device=PT_DEVICE) + if self.aparam is not None + else None + ), + )["energy"] + .detach() + .cpu() + .numpy() + ) + + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # dpmodel's compute_input_stats accepts numpy arrays + pt_expt_obj.compute_input_stats(self.stat_data, protection=1e-2) + return ( + pt_expt_obj( + torch.from_numpy(self.inputs).to(device=PT_EXPT_DEVICE), + torch.from_numpy(self.atype.reshape(1, -1)).to(device=PT_EXPT_DEVICE), + fparam=( + torch.from_numpy(self.fparam).to(device=PT_EXPT_DEVICE) + if self.fparam is not None + else None + ), + aparam=( + torch.from_numpy(self.aparam).to(device=PT_EXPT_DEVICE) + if self.aparam is not None + else None + ), + )["energy"] + .detach() + .cpu() + .numpy() + ) + + def eval_dp(self, dp_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + dp_obj.compute_input_stats(self.stat_data, protection=1e-2) + return dp_obj( + self.inputs, + self.atype.reshape(1, -1), + fparam=self.fparam, + aparam=self.aparam, + )["energy"] + + def eval_jax(self, jax_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # Convert stat_data to jax arrays + jax_stat_data = [ + { + "fparam": jnp.asarray(d["fparam"]), + "aparam": jnp.asarray(d["aparam"]), + } + for d in self.stat_data + ] + jax_obj.compute_input_stats(jax_stat_data, protection=1e-2) + return np.asarray( + jax_obj( + jnp.asarray(self.inputs), + jnp.asarray(self.atype.reshape(1, -1)), + fparam=jnp.asarray(self.fparam) if self.fparam is not None else None, + aparam=jnp.asarray(self.aparam) if self.aparam is not None else None, + )["energy"] + ) + + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # Convert stat_data to array_api_strict arrays + strict_stat_data = [ + { + "fparam": array_api_strict.asarray(d["fparam"]), + "aparam": array_api_strict.asarray(d["aparam"]), + } + for d in self.stat_data + ] + array_api_strict_obj.compute_input_stats(strict_stat_data, protection=1e-2) + return to_numpy_array( + array_api_strict_obj( + array_api_strict.asarray(self.inputs), + array_api_strict.asarray(self.atype.reshape(1, -1)), + fparam=array_api_strict.asarray(self.fparam) + if self.fparam is not None + else None, + aparam=array_api_strict.asarray(self.aparam) + if self.aparam is not None + else None, + )["energy"] + ) + + def eval_pd(self, pd_obj: Any) -> Any: + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + # Convert stat_data to paddle tensors + pd_stat_data = [ + { + "fparam": paddle.to_tensor(d["fparam"]).to(PD_DEVICE), + "aparam": paddle.to_tensor(d["aparam"]).to(PD_DEVICE), + } + for d in self.stat_data + ] + pd_obj.compute_input_stats(pd_stat_data, protection=1e-2) + return ( + pd_obj( + paddle.to_tensor(self.inputs).to(device=PD_DEVICE), + paddle.to_tensor(self.atype.reshape([1, -1])).to(device=PD_DEVICE), + fparam=( + paddle.to_tensor(self.fparam).to(device=PD_DEVICE) + if self.fparam is not None + else None + ), + aparam=( + paddle.to_tensor(self.aparam).to(device=PD_DEVICE) + if self.aparam is not None + else None + ), + )["energy"] + .detach() + .cpu() + .numpy() + ) + + def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: + if backend == self.RefBackend.TF: + # shape is not same + ret = ret[0].reshape(-1, self.natoms[0], 1) + return (ret,) + + @property + def rtol(self) -> float: + """Relative tolerance for comparing the return value.""" + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") + + @property + def atol(self) -> float: + """Absolute tolerance for comparing the return value.""" + ( + resnet_dt, + precision, + mixed_types, + (numb_fparam, default_fparam), + (numb_aparam, use_aparam_as_mask), + atom_ener, + ) = self.param + if precision == "float64": + return 1e-10 + elif precision == "float32": + return 1e-4 + else: + raise ValueError(f"Unknown precision: {precision}") diff --git a/source/tests/pt_expt/fitting/__init__.py b/source/tests/pt_expt/fitting/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/fitting/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/fitting/test_fitting_ener_fitting.py b/source/tests/pt_expt/fitting/test_fitting_ener_fitting.py new file mode 100644 index 0000000000..63ae82ab9a --- /dev/null +++ b/source/tests/pt_expt/fitting/test_fitting_ener_fitting.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + EnergyFittingNet, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestEnergyFittingNet(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_self_consistency( + self, + ) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + for nfp, nap in [(0, 0), (3, 0), (0, 4), (3, 4)]: + efn0 = EnergyFittingNet( + self.nt, + ds.dim_out, + numb_fparam=nfp, + numb_aparam=nap, + ).to(self.device) + efn1 = EnergyFittingNet.deserialize(efn0.serialize()).to(self.device) + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device + ) + else: + iap = None + ret0 = efn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + ret1 = efn1( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + np.testing.assert_allclose( + ret0["energy"].detach().cpu().numpy(), + ret1["energy"].detach().cpu().numpy(), + ) + + def test_serialize_has_correct_type(self) -> None: + """Test that EnergyFittingNet serializes with type='ener' not 'invar'.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + + efn = EnergyFittingNet( + self.nt, + ds.dim_out, + ).to(self.device) + serialized = efn.serialize() + + # Check that the type is 'ener' not 'invar' + self.assertEqual(serialized["type"], "ener") + + # Check that it can be deserialized + efn2 = EnergyFittingNet.deserialize(serialized).to(self.device) + self.assertIsInstance(efn2, EnergyFittingNet) + + def test_torch_export_simple(self) -> None: + """Test that EnergyFittingNet can be exported with torch.export.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + efn = EnergyFittingNet( + self.nt, + ds.dim_out, + numb_fparam=0, + numb_aparam=0, + ).to(self.device) + + # Prepare inputs + descriptor = torch.from_numpy( + rng.standard_normal((self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + + # Test forward pass works + ret = efn(descriptor, atype) + self.assertIn("energy", ret) + + # Test torch.export + exported = torch.export.export( + efn, + (descriptor, atype), + kwargs={}, + strict=False, + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + ret_exported = exported.module()(descriptor, atype) + np.testing.assert_allclose( + ret["energy"].detach().cpu().numpy(), + ret_exported["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + + def test_torch_export_with_aparam(self) -> None: + """Test that EnergyFittingNet with aparam can be exported.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + efn = EnergyFittingNet( + self.nt, + ds.dim_out, + numb_fparam=0, + numb_aparam=4, + ).to(self.device) + + # Prepare inputs + descriptor = torch.from_numpy( + rng.normal(size=(self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + aparam = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, 4))).to( + self.device + ) + + # Test forward pass works + ret = efn(descriptor, atype, aparam=aparam) + self.assertIn("energy", ret) + + # Test torch.export + exported = torch.export.export( + efn, + (descriptor, atype), + kwargs={"aparam": aparam}, + strict=False, + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + ret_exported = exported.module()(descriptor, atype, aparam=aparam) + np.testing.assert_allclose( + ret["energy"].detach().cpu().numpy(), + ret_exported["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py b/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py new file mode 100644 index 0000000000..d682b37145 --- /dev/null +++ b/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py @@ -0,0 +1,311 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + InvarFitting, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestInvarFitting(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_self_consistency( + self, + ) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + for ( + mixed_types, + od, + nfp, + nap, + et, + ) in itertools.product( + [True, False], + [1, 2], + [0, 3], + [0, 4], + [[], [0], [1]], + ): + ifn0 = InvarFitting( + "energy", + self.nt, + ds.dim_out, + od, + numb_fparam=nfp, + numb_aparam=nap, + mixed_types=mixed_types, + exclude_types=et, + ).to(self.device) + ifn1 = InvarFitting.deserialize(ifn0.serialize()).to(self.device) + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device + ) + else: + iap = None + ret0 = ifn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + ret1 = ifn1( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + np.testing.assert_allclose( + ret0["energy"].detach().cpu().numpy(), + ret1["energy"].detach().cpu().numpy(), + ) + sel_set = set(ifn0.get_sel_type()) + exclude_set = set(et) + self.assertEqual(sel_set | exclude_set, set(range(self.nt))) + self.assertEqual(sel_set & exclude_set, set()) + + def test_mask(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + od = 2 + mixed_types = True + # exclude type 1 + et = [1] + ifn0 = InvarFitting( + "energy", + self.nt, + ds.dim_out, + od, + mixed_types=mixed_types, + exclude_types=et, + ).to(self.device) + ret0 = ifn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + ) + # atom index 2 is of type 1 that is excluded + zero_idx = 2 + np.testing.assert_allclose( + ret0["energy"][0, zero_idx, :].detach().cpu().numpy(), + np.zeros_like(ret0["energy"][0, zero_idx, :].detach().cpu().numpy()), + ) + zero_idx = 0 + np.testing.assert_allclose( + ret0["energy"][1, zero_idx, :].detach().cpu().numpy(), + np.zeros_like(ret0["energy"][1, zero_idx, :].detach().cpu().numpy()), + ) + + def test_self_exception( + self, + ) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + dd = ds.call(self.coord_ext, self.atype_ext, self.nlist) + atype = self.atype_ext[:, :nloc] + + for ( + mixed_types, + od, + nfp, + nap, + ) in itertools.product( + [True, False], + [1, 2], + [0, 3], + [0, 4], + ): + ifn0 = InvarFitting( + "energy", + self.nt, + ds.dim_out, + od, + numb_fparam=nfp, + numb_aparam=nap, + mixed_types=mixed_types, + ).to(self.device) + + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to(self.device) + else: + ifp = None + if nap > 0: + iap = torch.from_numpy(rng.normal(size=(self.nf, self.nloc, nap))).to( + self.device + ) + else: + iap = None + with self.assertRaises(ValueError) as context: + ret0 = ifn0( + torch.from_numpy(dd[0][:, :, :-2]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + self.assertIn("input descriptor", str(context.exception)) + + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp - 1))).to( + self.device + ) + with self.assertRaises(ValueError) as context: + ret0 = ifn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + self.assertIn("input fparam", str(context.exception)) + + if nap > 0: + iap = torch.from_numpy( + rng.normal(size=(self.nf, self.nloc, nap - 1)) + ).to(self.device) + with self.assertRaises(ValueError) as context: + ifn0( + torch.from_numpy(dd[0]).to(self.device), + torch.from_numpy(atype).to(self.device), + fparam=ifp, + aparam=iap, + ) + self.assertIn("input aparam", str(context.exception)) + + def test_get_set(self) -> None: + ifn0 = InvarFitting( + "energy", + self.nt, + 3, + 1, + ).to(self.device) + rng = np.random.default_rng(GLOBAL_SEED) + foo = rng.normal([3, 4]) + for ii in [ + "bias_atom_e", + "fparam_avg", + "fparam_inv_std", + "aparam_avg", + "aparam_inv_std", + ]: + ifn0[ii] = torch.from_numpy(foo).to(self.device) + np.testing.assert_allclose( + foo, ifn0[ii].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + + def test_torch_export_simple(self) -> None: + """Test that InvarFitting can be exported with torch.export.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + ifn = InvarFitting( + "energy", + self.nt, + ds.dim_out, + 1, + numb_fparam=0, + numb_aparam=0, + mixed_types=True, + ).to(self.device) + + # Prepare inputs + descriptor = torch.from_numpy( + rng.standard_normal((self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + + # Test forward pass works + ret = ifn(descriptor, atype) + self.assertIn("energy", ret) + + # Test torch.export + exported = torch.export.export( + ifn, + (descriptor, atype), + kwargs={}, + strict=False, # Use strict=False for now to handle dynamic shapes + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + ret_exported = exported.module()(descriptor, atype) + np.testing.assert_allclose( + ret["energy"].detach().cpu().numpy(), + ret_exported["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + + def test_torch_export_with_fparam(self) -> None: + """Test that InvarFitting with fparam can be exported.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA(self.rcut, self.rcut_smth, self.sel) + rng = np.random.default_rng(GLOBAL_SEED) + + ifn = InvarFitting( + "energy", + self.nt, + ds.dim_out, + 1, + numb_fparam=3, + numb_aparam=0, + mixed_types=True, + ).to(self.device) + + # Prepare inputs + descriptor = torch.from_numpy( + rng.normal(size=(self.nf, self.nloc, ds.dim_out)) + ).to(self.device) + atype = torch.from_numpy(self.atype_ext[:, :nloc]).to(self.device) + fparam = torch.from_numpy(rng.normal(size=(self.nf, 3))).to(self.device) + + # Test forward pass works + ret = ifn(descriptor, atype, fparam=fparam) + self.assertIn("energy", ret) + + # Test torch.export + exported = torch.export.export( + ifn, + (descriptor, atype), + kwargs={"fparam": fparam}, + strict=False, + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + ret_exported = exported.module()(descriptor, atype, fparam=fparam) + np.testing.assert_allclose( + ret["energy"].detach().cpu().numpy(), + ret_exported["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) diff --git a/source/tests/pt_expt/fitting/test_fitting_stat.py b/source/tests/pt_expt/fitting/test_fitting_stat.py new file mode 100644 index 0000000000..b473c9309c --- /dev/null +++ b/source/tests/pt_expt/fitting/test_fitting_stat.py @@ -0,0 +1,125 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + EnergyFittingNet, +) +from deepmd.pt_expt.utils import ( + env, +) + + +def _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds): + """Make fake data as numpy arrays for dpmodel compute_input_stats.""" + merged_output_stat = [] + nsys = len(sys_natoms) + ndof = len(avgs) + for ii in range(nsys): + sys_dict = {} + tmp_data_f = [] + tmp_data_a = [] + for jj in range(ndof): + rng = np.random.default_rng(2025 * ii + 220 * jj) + tmp_data_f.append( + rng.normal(loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], 1)) + ) + rng = np.random.default_rng(220 * ii + 1636 * jj) + tmp_data_a.append( + rng.normal( + loc=avgs[jj], scale=stds[jj], size=(sys_nframes[ii], sys_natoms[ii]) + ) + ) + tmp_data_f = np.transpose(tmp_data_f, (1, 2, 0)) + tmp_data_a = np.transpose(tmp_data_a, (1, 2, 0)) + # dpmodel's compute_input_stats expects numpy arrays + sys_dict["fparam"] = tmp_data_f + sys_dict["aparam"] = tmp_data_a + merged_output_stat.append(sys_dict) + return merged_output_stat + + +def _brute_fparam_pt(data, ndim): + adata = [ii["fparam"] for ii in data] + all_data = [] + for ii in adata: + tmp = np.reshape(ii, [-1, ndim]) + if len(all_data) == 0: + all_data = np.array(tmp) + else: + all_data = np.concatenate((all_data, tmp), axis=0) + avg = np.average(all_data, axis=0) + std = np.std(all_data, axis=0) + return avg, std + + +def _brute_aparam_pt(data, ndim): + adata = [ii["aparam"] for ii in data] + all_data = [] + for ii in adata: + tmp = np.reshape(ii, [-1, ndim]) + if len(all_data) == 0: + all_data = np.array(tmp) + else: + all_data = np.concatenate((all_data, tmp), axis=0) + avg = np.average(all_data, axis=0) + std = np.std(all_data, axis=0) + return avg, std + + +class TestEnerFittingStat(unittest.TestCase): + def setUp(self) -> None: + self.device = env.DEVICE + + def test(self) -> None: + descrpt = DescrptSeA(6.0, 5.8, [46, 92], neuron=[25, 50, 100], axis_neuron=16) + fitting = EnergyFittingNet( + descrpt.get_ntypes(), + descrpt.get_dim_out(), + neuron=[240, 240, 240], + resnet_dt=True, + numb_fparam=3, + numb_aparam=3, + ).to(self.device) + avgs = [0, 10, 100] + stds = [2, 0.4, 0.00001] + sys_natoms = [10, 100] + sys_nframes = [5, 2] + all_data = _make_fake_data_pt(sys_natoms, sys_nframes, avgs, stds) + frefa, frefs = _brute_fparam_pt(all_data, len(avgs)) + arefa, arefs = _brute_aparam_pt(all_data, len(avgs)) + fitting.compute_input_stats(all_data, protection=1e-2) + frefs_inv = 1.0 / frefs + arefs_inv = 1.0 / arefs + frefs_inv[frefs_inv > 100] = 100 + arefs_inv[arefs_inv > 100] = 100 + # fparam_avg and fparam_inv_std are torch tensors on device + fparam_avg_np = ( + fitting.fparam_avg.detach().cpu().numpy() + if torch.is_tensor(fitting.fparam_avg) + else fitting.fparam_avg + ) + fparam_inv_std_np = ( + fitting.fparam_inv_std.detach().cpu().numpy() + if torch.is_tensor(fitting.fparam_inv_std) + else fitting.fparam_inv_std + ) + aparam_avg_np = ( + fitting.aparam_avg.detach().cpu().numpy() + if torch.is_tensor(fitting.aparam_avg) + else fitting.aparam_avg + ) + aparam_inv_std_np = ( + fitting.aparam_inv_std.detach().cpu().numpy() + if torch.is_tensor(fitting.aparam_inv_std) + else fitting.aparam_inv_std + ) + np.testing.assert_almost_equal(frefa, fparam_avg_np) + np.testing.assert_almost_equal(frefs_inv, fparam_inv_std_np) + np.testing.assert_almost_equal(arefa, aparam_avg_np) + np.testing.assert_almost_equal(arefs_inv, aparam_inv_std_np) From 165d1df5c1426b835e715f6b88a8f5c6482dcc20 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 11 Feb 2026 18:25:50 +0800 Subject: [PATCH 37/60] fix the API consistency issue in descriptors --- deepmd/pt_expt/descriptor/se_e2_a.py | 3 --- deepmd/pt_expt/descriptor/se_r.py | 3 --- deepmd/pt_expt/descriptor/se_t.py | 3 --- deepmd/pt_expt/descriptor/se_t_tebd.py | 5 +---- source/tests/pt_expt/descriptor/test_se_t_tebd.py | 12 ------------ 5 files changed, 1 insertion(+), 25 deletions(-) diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index b5d3a92ea8..f8a98abd86 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -35,9 +35,7 @@ def forward( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - extended_atype_embd: torch.Tensor | None = None, mapping: torch.Tensor | None = None, - type_embedding: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -45,7 +43,6 @@ def forward( torch.Tensor | None, torch.Tensor | None, ]: - del extended_atype_embd, type_embedding descrpt, rot_mat, g2, h2, sw = self.call( extended_coord, extended_atype, diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index 43ed4fb7e9..0484c0dea4 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -35,9 +35,7 @@ def forward( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - extended_atype_embd: torch.Tensor | None = None, mapping: torch.Tensor | None = None, - type_embedding: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -45,7 +43,6 @@ def forward( torch.Tensor | None, torch.Tensor | None, ]: - del extended_atype_embd, type_embedding descrpt, rot_mat, g2, h2, sw = self.call( extended_coord, extended_atype, diff --git a/deepmd/pt_expt/descriptor/se_t.py b/deepmd/pt_expt/descriptor/se_t.py index e3c7f6bc7b..6d732790ca 100644 --- a/deepmd/pt_expt/descriptor/se_t.py +++ b/deepmd/pt_expt/descriptor/se_t.py @@ -36,9 +36,7 @@ def forward( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - extended_atype_embd: torch.Tensor | None = None, mapping: torch.Tensor | None = None, - type_embedding: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -46,7 +44,6 @@ def forward( torch.Tensor | None, torch.Tensor | None, ]: - del extended_atype_embd, type_embedding descrpt, rot_mat, g2, h2, sw = self.call( extended_coord, extended_atype, diff --git a/deepmd/pt_expt/descriptor/se_t_tebd.py b/deepmd/pt_expt/descriptor/se_t_tebd.py index 081a798ae4..f28e1564cc 100644 --- a/deepmd/pt_expt/descriptor/se_t_tebd.py +++ b/deepmd/pt_expt/descriptor/se_t_tebd.py @@ -34,9 +34,7 @@ def forward( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - extended_atype_embd: torch.Tensor | None = None, mapping: torch.Tensor | None = None, - type_embedding: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -44,11 +42,10 @@ def forward( torch.Tensor | None, torch.Tensor | None, ]: - del extended_atype_embd, mapping, type_embedding descrpt, rot_mat, g2, h2, sw = self.call( extended_coord, extended_atype, nlist, - mapping=None, + mapping=mapping, ) return descrpt, rot_mat, g2, h2, sw diff --git a/source/tests/pt_expt/descriptor/test_se_t_tebd.py b/source/tests/pt_expt/descriptor/test_se_t_tebd.py index e774cda6d5..e84080882a 100644 --- a/source/tests/pt_expt/descriptor/test_se_t_tebd.py +++ b/source/tests/pt_expt/descriptor/test_se_t_tebd.py @@ -58,23 +58,16 @@ def test_consistency(self) -> None: dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) - # Type embedding input - type_embedding = torch.randn( - [self.nt, dd0.tebd_dim], dtype=dtype, device=self.device - ) - rd0, _, _, _, _ = dd0( torch.tensor(self.coord_ext, dtype=dtype, device=self.device), torch.tensor(self.atype_ext, dtype=int, device=self.device), torch.tensor(self.nlist, dtype=int, device=self.device), - type_embedding=type_embedding, ) dd1 = DescrptSeTTebd.deserialize(dd0.serialize()) rd1, gr1, _, _, sw1 = dd1( torch.tensor(self.coord_ext, dtype=dtype, device=self.device), torch.tensor(self.atype_ext, dtype=int, device=self.device), torch.tensor(self.nlist, dtype=int, device=self.device), - type_embedding=type_embedding, ) np.testing.assert_allclose( rd0.detach().cpu().numpy(), @@ -146,14 +139,9 @@ def test_exportable(self) -> None: dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) dd0 = dd0.eval() - type_embedding = torch.randn( - [self.nt, dd0.tebd_dim], dtype=dtype, device=self.device - ) - inputs = ( torch.tensor(self.coord_ext, dtype=dtype, device=self.device), torch.tensor(self.atype_ext, dtype=int, device=self.device), torch.tensor(self.nlist, dtype=int, device=self.device), - type_embedding, ) torch.export.export(dd0, inputs) From e76b7020d5b2f02233e951ae960c8d40fddc502f Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 11 Feb 2026 22:48:31 +0800 Subject: [PATCH 38/60] feat: add stat for dpmodel's atomic model. implement atomic model for pt_expt --- .../dpmodel/atomic_model/base_atomic_model.py | 196 +++++ deepmd/dpmodel/common.py | 3 +- deepmd/pt_expt/atomic_model/__init__.py | 12 + .../pt_expt/atomic_model/dp_atomic_model.py | 78 ++ .../atomic_model/energy_atomic_model.py | 27 + source/tests/pt_expt/atomic_model/__init__.py | 1 + .../test_atomic_model_atomic_stat.py | 471 +++++++++++ .../test_atomic_model_global_stat.py | 759 ++++++++++++++++++ .../atomic_model/test_dp_atomic_model.py | 287 +++++++ 9 files changed, 1833 insertions(+), 1 deletion(-) create mode 100644 deepmd/pt_expt/atomic_model/__init__.py create mode 100644 deepmd/pt_expt/atomic_model/dp_atomic_model.py create mode 100644 deepmd/pt_expt/atomic_model/energy_atomic_model.py create mode 100644 source/tests/pt_expt/atomic_model/__init__.py create mode 100644 source/tests/pt_expt/atomic_model/test_atomic_model_atomic_stat.py create mode 100644 source/tests/pt_expt/atomic_model/test_atomic_model_global_stat.py create mode 100644 source/tests/pt_expt/atomic_model/test_dp_atomic_model.py diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index 2353e207a3..9866ddbc3a 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -1,5 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later import math +from collections.abc import ( + Callable, +) from typing import ( Any, ) @@ -30,6 +33,9 @@ map_atom_exclude_types, map_pair_exclude_types, ) +from deepmd.utils.path import ( + DPPath, +) from .make_base_atomic_model import ( make_base_atomic_model, @@ -246,6 +252,196 @@ def call( aparam=aparam, ) + def get_intensive(self) -> bool: + """Whether the fitting property is intensive.""" + return False + + def get_compute_stats_distinguish_types(self) -> bool: + """Get whether the fitting net computes stats which are not distinguished between different types of atoms.""" + return True + + def compute_or_load_out_stat( + self, + merged: Callable[[], list[dict]] | list[dict], + stat_file_path: DPPath | None = None, + ) -> None: + """ + Compute the output statistics (e.g. energy bias) for the fitting net from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `np.ndarray` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + stat_file_path : Optional[DPPath] + The path to the stat file. + + """ + self.change_out_bias( + merged, + stat_file_path=stat_file_path, + bias_adjust_mode="set-by-statistic", + ) + + def change_out_bias( + self, + sample_merged: Callable[[], list[dict]] | list[dict], + stat_file_path: DPPath | None = None, + bias_adjust_mode: str = "change-by-statistic", + ) -> None: + """Change the output bias according to the input data and the pretrained model. + + Parameters + ---------- + sample_merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `np.ndarray` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + bias_adjust_mode : str + The mode for changing output bias : ['change-by-statistic', 'set-by-statistic'] + 'change-by-statistic' : perform predictions on labels of target dataset, + and do least square on the errors to obtain the target shift as bias. + 'set-by-statistic' : directly use the statistic output bias in the target dataset. + stat_file_path : Optional[DPPath] + The path to the stat file. + """ + from deepmd.dpmodel.utils.stat import ( + compute_output_stats, + ) + + if bias_adjust_mode == "change-by-statistic": + delta_bias, out_std = compute_output_stats( + sample_merged, + self.get_ntypes(), + keys=list(self.atomic_output_def().keys()), + stat_file_path=stat_file_path, + model_forward=self._get_forward_wrapper_func(), + rcond=self.rcond, + preset_bias=self.preset_out_bias, + stats_distinguish_types=self.get_compute_stats_distinguish_types(), + intensive=self.get_intensive(), + ) + self._store_out_stat(delta_bias, out_std, add=True) + elif bias_adjust_mode == "set-by-statistic": + bias_out, std_out = compute_output_stats( + sample_merged, + self.get_ntypes(), + keys=list(self.atomic_output_def().keys()), + stat_file_path=stat_file_path, + rcond=self.rcond, + preset_bias=self.preset_out_bias, + stats_distinguish_types=self.get_compute_stats_distinguish_types(), + intensive=self.get_intensive(), + ) + self._store_out_stat(bias_out, std_out) + else: + raise RuntimeError("Unknown bias_adjust_mode mode: " + bias_adjust_mode) + + def _store_out_stat( + self, + out_bias: dict[str, np.ndarray], + out_std: dict[str, np.ndarray], + add: bool = False, + ) -> None: + """Store output bias and std into the model.""" + ntypes = self.get_ntypes() + out_bias_data = np.copy(self.out_bias) + out_std_data = np.copy(self.out_std) + for kk in out_bias.keys(): + assert kk in out_std.keys() + idx = self._get_bias_index(kk) + size = self._varsize(self.atomic_output_def()[kk].shape) + if not add: + out_bias_data[idx, :, :size] = out_bias[kk].reshape(ntypes, size) + else: + out_bias_data[idx, :, :size] += out_bias[kk].reshape(ntypes, size) + out_std_data[idx, :, :size] = out_std[kk].reshape(ntypes, size) + self.out_bias = out_bias_data + self.out_std = out_std_data + + def _get_forward_wrapper_func(self) -> Callable[..., dict[str, np.ndarray]]: + """Get a forward wrapper of the atomic model for output bias calculation.""" + import array_api_compat + + from deepmd.dpmodel.utils.nlist import ( + extend_input_and_build_neighbor_list, + ) + + def model_forward( + coord: np.ndarray, + atype: np.ndarray, + box: np.ndarray | None, + fparam: np.ndarray | None = None, + aparam: np.ndarray | None = None, + ) -> dict[str, np.ndarray]: + # Get reference array to determine the target array type and device + # Use out_bias as reference since it's always present + ref_array = self.out_bias + xp = array_api_compat.array_namespace(ref_array) + + # Convert numpy inputs to the model's array type with correct device + device = getattr(ref_array, "device", None) + if device is not None: + # For torch tensors + coord = xp.asarray(coord, device=device) + atype = xp.asarray(atype, device=device) + if box is not None: + # Check if box is all zeros before converting + if np.allclose(box, 0.0): + box = None + else: + box = xp.asarray(box, device=device) + if fparam is not None: + fparam = xp.asarray(fparam, device=device) + if aparam is not None: + aparam = xp.asarray(aparam, device=device) + else: + # For numpy arrays + coord = xp.asarray(coord) + atype = xp.asarray(atype) + if box is not None: + if np.allclose(box, 0.0): + box = None + else: + box = xp.asarray(box) + if fparam is not None: + fparam = xp.asarray(fparam) + if aparam is not None: + aparam = xp.asarray(aparam) + + ( + extended_coord, + extended_atype, + mapping, + nlist, + ) = extend_input_and_build_neighbor_list( + coord, + atype, + self.get_rcut(), + self.get_sel(), + mixed_types=self.mixed_types(), + box=box, + ) + atomic_ret = self.forward_common_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + # Convert outputs back to numpy arrays + return {kk: to_numpy_array(vv) for kk, vv in atomic_ret.items()} + + return model_forward + def serialize(self) -> dict: return { "type_map": self.type_map, diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index dabbc34e01..cc730ddda6 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -121,7 +121,8 @@ def to_numpy_array(x: Optional["Array"]) -> np.ndarray | None: try: # asarray is not within Array API standard, so may fail return np.asarray(x) - except (ValueError, AttributeError, TypeError): + except (ValueError, AttributeError, TypeError, RuntimeError): + # RuntimeError: handles torch tensors with requires_grad=True xp = array_api_compat.array_namespace(x) # to fix BufferError: Cannot export readonly array since signalling readonly is unsupported by DLPack. # Move to CPU device to ensure numpy compatibility diff --git a/deepmd/pt_expt/atomic_model/__init__.py b/deepmd/pt_expt/atomic_model/__init__.py new file mode 100644 index 0000000000..51ee9f4186 --- /dev/null +++ b/deepmd/pt_expt/atomic_model/__init__.py @@ -0,0 +1,12 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .dp_atomic_model import ( + DPAtomicModel, +) +from .energy_atomic_model import ( + DPEnergyAtomicModel, +) + +__all__ = [ + "DPAtomicModel", + "DPEnergyAtomicModel", +] diff --git a/deepmd/pt_expt/atomic_model/dp_atomic_model.py b/deepmd/pt_expt/atomic_model/dp_atomic_model.py new file mode 100644 index 0000000000..5c00192661 --- /dev/null +++ b/deepmd/pt_expt/atomic_model/dp_atomic_model.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, +) + + +class DPAtomicModel(DPAtomicModelDP, torch.nn.Module): + # Import at class level to set base classes for deserialization + # These will be used by the dpmodel deserialize method to create pt_expt instances + from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, + ) + from deepmd.pt_expt.fitting.base_fitting import ( + BaseFitting, + ) + + base_descriptor_cls = BaseDescriptor + base_fitting_cls = BaseFitting + + def __init__( + self, descriptor: Any, fitting: Any, *args: Any, **kwargs: Any + ) -> None: + torch.nn.Module.__init__(self) + # Convert descriptor and fitting to pt_expt versions if they are dpmodel instances + # The dpmodel_setattr mechanism will handle this automatically via registry + from deepmd.pt_expt.common import ( + try_convert_module, + ) + + descriptor_pt = try_convert_module(descriptor) + fitting_pt = try_convert_module(fitting) + # If conversion failed (not registered), use original (assume already pt_expt) + if descriptor_pt is None: + descriptor_pt = descriptor + if fitting_pt is None: + fitting_pt = fitting + DPAtomicModelDP.__init__(self, descriptor_pt, fitting_pt, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + return self.forward_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + + +register_dpmodel_mapping( + DPAtomicModelDP, + lambda v: DPAtomicModel.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/atomic_model/energy_atomic_model.py b/deepmd/pt_expt/atomic_model/energy_atomic_model.py new file mode 100644 index 0000000000..5f34d215cf --- /dev/null +++ b/deepmd/pt_expt/atomic_model/energy_atomic_model.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.atomic_model.energy_atomic_model import ( + DPEnergyAtomicModel as DPEnergyAtomicModelDP, +) +from deepmd.pt_expt.common import ( + register_dpmodel_mapping, +) + +from .dp_atomic_model import ( + DPAtomicModel, +) + + +class DPEnergyAtomicModel(DPAtomicModel): + """Energy atomic model for pt_expt backend. + + This is a thin wrapper around DPAtomicModel that validates + the fitting is an EnergyFittingNet or InvarFitting. + """ + + pass + + +register_dpmodel_mapping( + DPEnergyAtomicModelDP, + lambda v: DPEnergyAtomicModel.deserialize(v.serialize()), +) diff --git a/source/tests/pt_expt/atomic_model/__init__.py b/source/tests/pt_expt/atomic_model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/atomic_model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/atomic_model/test_atomic_model_atomic_stat.py b/source/tests/pt_expt/atomic_model/test_atomic_model_atomic_stat.py new file mode 100644 index 0000000000..c393ad4b3b --- /dev/null +++ b/source/tests/pt_expt/atomic_model/test_atomic_model_atomic_stat.py @@ -0,0 +1,471 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tempfile +import unittest +from pathlib import ( + Path, +) +from typing import ( + NoReturn, +) + +import h5py +import numpy as np +import torch + +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt_expt.atomic_model import ( + DPAtomicModel, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting.base_fitting import ( + BaseFitting, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.utils.path import ( + DPPath, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + + +class FooFitting(BaseFitting, torch.nn.Module): + """Test fitting that returns fixed values for testing bias computation.""" + + def __init__(self): + torch.nn.Module.__init__(self) + BaseFitting.__init__(self) + + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef( + "foo", + [1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + "bar", + [1, 2], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + def serialize(self) -> dict: + return { + "@class": "Fitting", + "type": "foo", + "@version": 1, + } + + @classmethod + def deserialize(cls, data: dict): + return cls() + + def get_dim_fparam(self) -> int: + return 0 + + def get_dim_aparam(self) -> int: + return 0 + + def get_sel_type(self) -> list[int]: + return [] + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + pass + + def get_type_map(self) -> list[str]: + return [] + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ): + nf, nloc, _ = descriptor.shape + ret = {} + ret["foo"] = ( + torch.Tensor( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + .view([nf, nloc, *self.output_def()["foo"].shape]) + .to(dtype=torch.float64, device=env.DEVICE) + ) + ret["bar"] = ( + torch.Tensor( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ) + .view([nf, nloc, *self.output_def()["bar"].shape]) + .to(dtype=torch.float64, device=env.DEVICE) + ) + return ret + + +class TestAtomicModelStat(unittest.TestCase, TestCaseSingleFrameWithNlist): + def tearDown(self) -> None: + self.tempdir.cleanup() + + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + self.merged_output_stat = [ + { + "coord": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "atype": torch.tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32), device=self.device + ), + "atype_ext": torch.tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32), + device=self.device, + ), + "box": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "natoms": torch.tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32), + device=self.device, + ), + # bias of foo: 5, 6 + "atom_foo": torch.tensor( + np.array([[5.0, 5.0, 5.0], [5.0, 6.0, 7.0]]).reshape(2, 3, 1), + device=self.device, + ), + # bias of bar: [1, 5], [3, 2] + "bar": torch.tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2), device=self.device + ), + "find_atom_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + { + "coord": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "atype": torch.tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32), device=self.device + ), + "atype_ext": torch.tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32), + device=self.device, + ), + "box": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "natoms": torch.tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32), + device=self.device, + ), + # bias of foo: 5, 6 from atomic label. + "foo": torch.tensor( + np.array([5.0, 7.0]).reshape(2, 1), device=self.device + ), + # bias of bar: [1, 5], [3, 2] + "bar": torch.tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2), device=self.device + ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_output_stat(self) -> None: + """Test output statistics computation for pt_expt atomic model.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = FooFitting().to(self.device) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: vv.detach().cpu().numpy() for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + expected_std = np.ones( + (2, 2, 2), dtype=np.float64 + ) # 2 keys, 2 atypes, 2 max dims. + expected_std[0, :, :1] = np.array([0.0, 0.816496]).reshape( + 2, 1 + ) # updating std for foo based on [5.0, 5.0, 5.0], [5.0, 6.0, 7.0]] + np.testing.assert_almost_equal( + md0.out_std.detach().cpu().numpy(), expected_std, decimal=4 + ) + ret1 = cvt_ret(ret1) + # nt x odim + foo_bias = np.array([5.0, 6.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + # 3. test bias load from file + def raise_error() -> NoReturn: + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + np.testing.assert_almost_equal( + md0.out_std.detach().cpu().numpy(), expected_std, decimal=4 + ) + + # 4. test change bias + md0.change_out_bias( + self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + # use atype_ext from merged_output_stat for inference (matching pt backend test) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + self.merged_output_stat[0]["atype_ext"].to( + dtype=torch.int64, device=self.device + ), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + expected_std[0, :, :1] = np.array([1.24722, 0.47140]).reshape( + 2, 1 + ) # updating std for foo based on [4.0, 3.0, 2.0], [1.0, 1.0, 1.0]] + expected_ret3 = {} + # new bias [2.666, 1.333] + expected_ret3["foo"] = np.array( + [[3.6667, 4.6667, 4.3333], [6.6667, 6.3333, 7.3333]] + ).reshape(2, 3, 1) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) + np.testing.assert_almost_equal( + md0.out_std.detach().cpu().numpy(), expected_std, decimal=4 + ) + + +class TestAtomicModelStatMergeGlobalAtomic( + unittest.TestCase, TestCaseSingleFrameWithNlist +): + """Test merging atomic and global stat when atomic label only covers some types.""" + + def tearDown(self) -> None: + self.tempdir.cleanup() + + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + self.merged_output_stat = [ + { + "coord": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "atype": torch.tensor( + np.array([[0, 0, 0], [0, 0, 0]], dtype=np.int32), + device=self.device, + ), + "atype_ext": torch.tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32), + device=self.device, + ), + "box": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "natoms": torch.tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32), + device=self.device, + ), + # bias of foo: 5.5, nan (only type 0 atoms) + "atom_foo": torch.tensor( + np.array([[5.0, 5.0, 5.0], [5.0, 6.0, 7.0]]).reshape(2, 3, 1), + device=self.device, + ), + # bias of bar: [1, 5], [3, 2] + "bar": torch.tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2), + device=self.device, + ), + "find_atom_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + { + "coord": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "atype": torch.tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32), + device=self.device, + ), + "atype_ext": torch.tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32), + device=self.device, + ), + "box": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "natoms": torch.tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32), + device=self.device, + ), + # bias of foo: 5.5, 3 from global label. + "foo": torch.tensor( + np.array([5.0, 7.0]).reshape(2, 1), device=self.device + ), + # bias of bar: [1, 5], [3, 2] + "bar": torch.tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2), + device=self.device, + ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_output_stat(self) -> None: + """Test merging atomic (type 0 only) and global stat for type 1.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = FooFitting().to(self.device) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: vv.detach().cpu().numpy() for kk, vv in x.items()} + + # 1. test run without bias + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + # foo: type 0 from atomic (mean=5.5), type 1 from global (lstsq=3.0) + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + # nt x odim + foo_bias = np.array([5.5, 3.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + # 3. test bias load from file + def raise_error() -> NoReturn: + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + + # 4. test change bias + md0.change_out_bias( + self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + # use atype_ext from merged_output_stat for inference + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + self.merged_output_stat[0]["atype_ext"].to( + dtype=torch.int64, device=self.device + ), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + expected_ret3 = {} + # new bias [2, -5] + expected_ret3["foo"] = np.array([[3, 4, -2], [6, 0, 1]]).reshape(2, 3, 1) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) diff --git a/source/tests/pt_expt/atomic_model/test_atomic_model_global_stat.py b/source/tests/pt_expt/atomic_model/test_atomic_model_global_stat.py new file mode 100644 index 0000000000..e09e7c0c91 --- /dev/null +++ b/source/tests/pt_expt/atomic_model/test_atomic_model_global_stat.py @@ -0,0 +1,759 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import tempfile +import unittest +from pathlib import ( + Path, +) +from typing import ( + NoReturn, +) + +import h5py +import numpy as np +import torch + +from deepmd.dpmodel.atomic_model import DPAtomicModel as DPDPAtomicModel +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + OutputVariableDef, +) +from deepmd.pt_expt.atomic_model import ( + DPAtomicModel, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + InvarFitting, +) +from deepmd.pt_expt.fitting.base_fitting import ( + BaseFitting, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.utils.path import ( + DPPath, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class FooFitting(BaseFitting, torch.nn.Module): + """Test fitting with multiple outputs for testing global statistics.""" + + def __init__(self): + torch.nn.Module.__init__(self) + BaseFitting.__init__(self) + + def output_def(self): + return FittingOutputDef( + [ + OutputVariableDef( + "foo", + [1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + "pix", + [1], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + OutputVariableDef( + "bar", + [1, 2], + reducible=True, + r_differentiable=True, + c_differentiable=True, + ), + ] + ) + + def serialize(self) -> dict: + return { + "@class": "Fitting", + "type": "foo", + "@version": 1, + } + + @classmethod + def deserialize(cls, data: dict): + return cls() + + def get_dim_fparam(self) -> int: + return 0 + + def get_dim_aparam(self) -> int: + return 0 + + def get_sel_type(self) -> list[int]: + return [] + + def change_type_map( + self, type_map: list[str], model_with_new_type_stat=None + ) -> None: + pass + + def get_type_map(self) -> list[str]: + return [] + + def forward( + self, + descriptor: torch.Tensor, + atype: torch.Tensor, + gr: torch.Tensor | None = None, + g2: torch.Tensor | None = None, + h2: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + ): + nf, nloc, _ = descriptor.shape + ret = {} + ret["foo"] = ( + torch.Tensor( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ) + .view([nf, nloc, *self.output_def()["foo"].shape]) + .to(dtype=torch.float64, device=env.DEVICE) + ) + ret["pix"] = ( + torch.Tensor( + [ + [3.0, 2.0, 1.0], + [6.0, 5.0, 4.0], + ] + ) + .view([nf, nloc, *self.output_def()["pix"].shape]) + .to(dtype=torch.float64, device=env.DEVICE) + ) + ret["bar"] = ( + torch.Tensor( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ) + .view([nf, nloc, *self.output_def()["bar"].shape]) + .to(dtype=torch.float64, device=env.DEVICE) + ) + return ret + + +def _to_numpy(x): + return x.detach().cpu().numpy() + + +class TestAtomicModelStat(unittest.TestCase, TestCaseSingleFrameWithNlist): + def tearDown(self) -> None: + self.tempdir.cleanup() + + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + self.merged_output_stat = [ + { + "coord": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "atype": torch.tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32), + device=self.device, + ), + "atype_ext": torch.tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32), + device=self.device, + ), + "box": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "natoms": torch.tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32), + device=self.device, + ), + # bias of foo: 1, 3 + "foo": torch.tensor( + np.array([5.0, 7.0]).reshape(2, 1), device=self.device + ), + # no bias of pix + # bias of bar: [1, 5], [3, 2] + "bar": torch.tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2), + device=self.device, + ), + "find_foo": np.float32(1.0), + "find_bar": np.float32(1.0), + } + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_output_stat(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = FooFitting().to(self.device) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: _to_numpy(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["pix"] = np.array( + [ + [3.0, 2.0, 1.0], + [6.0, 5.0, 4.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["pix"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + expected_std = np.ones((3, 2, 2)) # 3 keys, 2 atypes, 2 max dims. + # nt x odim + foo_bias = np.array([1.0, 3.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["pix"] = ret0["pix"] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + np.testing.assert_almost_equal(_to_numpy(md0.out_std), expected_std) + + # 3. test bias load from file + def raise_error() -> NoReturn: + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + np.testing.assert_almost_equal(_to_numpy(md0.out_std), expected_std) + + # 4. test change bias + md0.change_out_bias( + self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + # use atype_ext from merged_output_stat for inference (matching pt backend test) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + self.merged_output_stat[0]["atype_ext"].to( + dtype=torch.int64, device=self.device + ), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + ## model output on foo: [[2, 3, 6], [5, 8, 9]] given bias [1, 3] + ## foo sumed: [11, 22] compared with [5, 7], fit target is [-6, -15] + ## fit bias is [1, -8] + ## old bias + fit bias [2, -5] + ## new model output is [[3, 4, -2], [6, 0, 1]], which sumed to [5, 7] + expected_ret3 = {} + expected_ret3["foo"] = np.array([[3, 4, -2], [6, 0, 1]]).reshape(2, 3, 1) + expected_ret3["pix"] = ret0["pix"] + for kk in ["foo", "pix"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk]) + # bar is too complicated to be manually computed. + np.testing.assert_almost_equal(_to_numpy(md0.out_std), expected_std) + + def test_preset_bias(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = FooFitting().to(self.device) + type_map = ["foo", "bar"] + preset_out_bias = { + "foo": [None, 2], + "bar": np.array([7.0, 5.0, 13.0, 11.0]).reshape(2, 1, 2), + } + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + preset_out_bias=preset_out_bias, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: _to_numpy(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["pix"] = np.array( + [ + [3.0, 2.0, 1.0], + [6.0, 5.0, 4.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["pix"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + # foo sums: [5, 7], + # given bias of type 1 being 2, the bias left for type 0 is [5-2*1, 7-2*2] = [3,3] + # the solution of type 0 is 1.8 + foo_bias = np.array([1.8, preset_out_bias["foo"][1]]).reshape(2, 1) + bar_bias = preset_out_bias["bar"] + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["pix"] = ret0["pix"] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + # 3. test bias load from file + def raise_error() -> NoReturn: + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + ret2 = cvt_ret(ret2) + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], ret2[kk]) + + # 4. test change bias + md0.change_out_bias( + self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + # use atype_ext from merged_output_stat for inference + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + self.merged_output_stat[0]["atype_ext"].to( + dtype=torch.int64, device=self.device + ), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + ## model output on foo: [[2.8, 3.8, 5], [5.8, 7., 8.]] given bias [1.8, 2] + ## foo sumed: [11.6, 20.8] compared with [5, 7], fit target is [-6.6, -13.8] + ## fit bias is [-7, 2] (2 is assigned. -7 is fit to [-8.6, -17.8]) + ## old bias[1.8,2] + fit bias[-7, 2] = [-5.2, 4] + ## new model output is [[-4.2, -3.2, 7], [-1.2, 9, 10]] + expected_ret3 = {} + expected_ret3["foo"] = np.array([[-4.2, -3.2, 7.0], [-1.2, 9.0, 10.0]]).reshape( + 2, 3, 1 + ) + expected_ret3["pix"] = ret0["pix"] + for kk in ["foo", "pix"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk]) + # bar is too complicated to be manually computed. + + def test_preset_bias_all_none(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = FooFitting().to(self.device) + type_map = ["foo", "bar"] + preset_out_bias = { + "foo": [None, None], + } + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + preset_out_bias=preset_out_bias, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + # nf x nloc + at = self.atype_ext[:, :nloc] + + def cvt_ret(x): + return {kk: _to_numpy(vv) for kk, vv in x.items()} + + # 1. test run without bias + # nf x na x odim + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + expected_ret0 = {} + expected_ret0["foo"] = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["foo"].shape]) + expected_ret0["pix"] = np.array( + [ + [3.0, 2.0, 1.0], + [6.0, 5.0, 4.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["pix"].shape]) + expected_ret0["bar"] = np.array( + [ + [1.0, 2.0, 3.0, 7.0, 8.0, 9.0], + [4.0, 5.0, 6.0, 10.0, 11.0, 12.0], + ] + ).reshape([nf, nloc, *md0.fitting_output_def()["bar"].shape]) + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret0[kk], expected_ret0[kk]) + + # 2. test bias is applied (all None preset = same as no preset) + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + # nt x odim + foo_bias = np.array([1.0, 3.0]).reshape(2, 1) + bar_bias = np.array([1.0, 5.0, 3.0, 2.0]).reshape(2, 1, 2) + expected_ret1 = {} + expected_ret1["foo"] = ret0["foo"] + foo_bias[at] + expected_ret1["pix"] = ret0["pix"] + expected_ret1["bar"] = ret0["bar"] + bar_bias[at] + for kk in ["foo", "pix", "bar"]: + np.testing.assert_almost_equal(ret1[kk], expected_ret1[kk]) + + def test_serialize(self) -> None: + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = InvarFitting( + "foo", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(self.device) + type_map = ["A", "B"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + + def cvt_ret(x): + return {kk: _to_numpy(vv) for kk, vv in x.items()} + + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + md1 = DPAtomicModel.deserialize(md0.serialize()) + ret1 = md1.forward_common_atomic(*args) + ret1 = cvt_ret(ret1) + + for kk in ["foo"]: + np.testing.assert_almost_equal(ret0[kk], ret1[kk]) + + md2 = DPDPAtomicModel.deserialize(md0.serialize()) + args_np = [self.coord_ext, self.atype_ext, self.nlist] + ret2 = md2.forward_common_atomic(*args_np) + for kk in ["foo"]: + np.testing.assert_almost_equal(ret0[kk], ret2[kk]) + + +class TestChangeByStatMixedLabels(unittest.TestCase, TestCaseSingleFrameWithNlist): + """Test change-by-statistic with mixed atomic and global labels.""" + + def tearDown(self) -> None: + self.tempdir.cleanup() + + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + self.merged_output_stat = [ + { + "coord": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "atype": torch.tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32), + device=self.device, + ), + "atype_ext": torch.tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32), + device=self.device, + ), + "box": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "natoms": torch.tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32), + device=self.device, + ), + # foo: atomic label + "atom_foo": torch.tensor( + np.array([[5.0, 5.0, 5.0], [5.0, 6.0, 7.0]]).reshape(2, 3, 1), + device=self.device, + ), + # pix: global label + "pix": torch.tensor( + np.array([5.0, 12.0]).reshape(2, 1), device=self.device + ), + # bar: global label + "bar": torch.tensor( + np.array([5.0, 12.0, 7.0, 9.0]).reshape(2, 1, 2), + device=self.device, + ), + "find_atom_foo": np.float32(1.0), + "find_pix": np.float32(1.0), + "find_bar": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_change_by_statistic(self) -> None: + """Test change-by-statistic with atomic foo + global pix + global bar.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = FooFitting().to(self.device) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + + def cvt_ret(x): + return {kk: _to_numpy(vv) for kk, vv in x.items()} + + ret0 = md0.forward_common_atomic(*args) + ret0 = cvt_ret(ret0) + + # set initial bias + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + + # change bias + md0.change_out_bias( + self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + # use atype_ext from merged_output_stat for inference + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + self.merged_output_stat[0]["atype_ext"].to( + dtype=torch.int64, device=self.device + ), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret3 = md0.forward_common_atomic(*args) + ret3 = cvt_ret(ret3) + # foo: atomic label, bias after set-by-stat: [5, 6] + # model output with bias [5,6], atype [[0,0,1],[0,1,1]]: + # [[6, 7, 9], [9, 11, 12]] + # atom_foo labels: [[5, 5, 5], [5, 6, 7]] + # per-atom delta: [[-1, -2, -4], [-4, -5, -5]] + # delta bias (mean per type): type0=-7/3, type1=-14/3 + # new bias = [5-7/3, 6-14/3] = [8/3, 4/3] + # new output: [[11/3, 14/3, 13/3], [20/3, 19/3, 22/3]] + expected_ret3 = {} + expected_ret3["foo"] = np.array( + [[3.6667, 4.6667, 4.3333], [6.6667, 6.3333, 7.3333]] + ).reshape(2, 3, 1) + # pix: global label, bias after set-by-stat: [-2/3, 19/3] + # model pix with bias, atype [[0,0,1],[0,1,1]]: + # [[7/3, 4/3, 22/3], [16/3, 34/3, 31/3]], sums [11, 27] + # labels [5, 12], delta [-6, -15] + # lstsq: delta bias [1, -8], new bias [1/3, -5/3] + # new output: [[10/3, 7/3, -2/3], [19/3, 10/3, 7/3]] + expected_ret3["pix"] = np.array( + [[3.3333, 2.3333, -0.6667], [6.3333, 3.3333, 2.3333]] + ).reshape(2, 3, 1) + for kk in ["foo", "pix"]: + np.testing.assert_almost_equal(ret3[kk], expected_ret3[kk], decimal=4) + # bar is too complicated to be manually computed. + + +class TestEnergyModelStat(unittest.TestCase, TestCaseSingleFrameWithNlist): + """Test statistics computation with real energy fitting net.""" + + def tearDown(self) -> None: + self.tempdir.cleanup() + + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + self.merged_output_stat = [ + { + "coord": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "atype": torch.tensor( + np.array([[0, 0, 1], [0, 1, 1]], dtype=np.int32), + device=self.device, + ), + "atype_ext": torch.tensor( + np.array([[0, 0, 1, 0], [0, 1, 1, 0]], dtype=np.int32), + device=self.device, + ), + "box": torch.tensor(np.zeros([2, 3, 3]), device=self.device), + "natoms": torch.tensor( + np.array([[3, 3, 2, 1], [3, 3, 1, 2]], dtype=np.int32), + device=self.device, + ), + # energy data + "energy": torch.tensor( + np.array([10.0, 20.0]).reshape(2, 1), device=self.device + ), + "find_energy": np.float32(1.0), + }, + ] + self.tempdir = tempfile.TemporaryDirectory() + h5file = str((Path(self.tempdir.name) / "testcase.h5").resolve()) + with h5py.File(h5file, "w") as f: + pass + self.stat_file_path = DPPath(h5file, "a") + + def test_energy_stat(self) -> None: + """Test energy statistics computation with real energy fitting net.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(self.device) + type_map = ["foo", "bar"] + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + + # test run without bias + ret0 = md0.forward_common_atomic(*args) + self.assertIn("energy", ret0) + + # compute statistics + md0.compute_or_load_out_stat( + self.merged_output_stat, stat_file_path=self.stat_file_path + ) + ret1 = md0.forward_common_atomic(*args) + self.assertIn("energy", ret1) + + # Check that bias was computed (out_bias should be non-zero) + self.assertFalse(torch.all(md0.out_bias == 0)) + + # test bias load from file + def raise_error() -> NoReturn: + raise RuntimeError + + md0.compute_or_load_out_stat(raise_error, stat_file_path=self.stat_file_path) + ret2 = md0.forward_common_atomic(*args) + np.testing.assert_allclose( + ret1["energy"].detach().cpu().numpy(), + ret2["energy"].detach().cpu().numpy(), + ) + + # test change bias + md0.change_out_bias( + self.merged_output_stat, bias_adjust_mode="change-by-statistic" + ) + ret3 = md0.forward_common_atomic(*args) + self.assertIn("energy", ret3) diff --git a/source/tests/pt_expt/atomic_model/test_dp_atomic_model.py b/source/tests/pt_expt/atomic_model/test_dp_atomic_model.py new file mode 100644 index 0000000000..49e60373d4 --- /dev/null +++ b/source/tests/pt_expt/atomic_model/test_dp_atomic_model.py @@ -0,0 +1,287 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.atomic_model import DPAtomicModel as DPDPAtomicModel +from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA +from deepmd.dpmodel.fitting import InvarFitting as DPInvarFitting +from deepmd.pt_expt.atomic_model import ( + DPAtomicModel, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + InvarFitting, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, + TestCaseSingleFrameWithNlistWithVirtual, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDPAtomicModel(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_self_consistency(self) -> None: + """Test that pt_expt atomic model serialize/deserialize preserves behavior.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(self.device) + type_map = ["foo", "bar"] + + # test the case of exclusion + for atom_excl, pair_excl in itertools.product([[], [1]], [[], [[0, 1]]]): + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + md0.reinit_atom_exclude(atom_excl) + md0.reinit_pair_exclude(pair_excl) + md1 = DPAtomicModel.deserialize(md0.serialize()).to(self.device) + + # Test forward pass + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret0 = md0.forward_common_atomic(*args) + ret1 = md1.forward_common_atomic(*args) + np.testing.assert_allclose( + ret0["energy"].detach().cpu().numpy(), + ret1["energy"].detach().cpu().numpy(), + ) + + def test_dp_consistency(self) -> None: + """Test numerical consistency between dpmodel and pt_expt atomic models.""" + nf, nloc, nnei = self.nlist.shape + ds = DPDescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft = DPInvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ) + type_map = ["foo", "bar"] + md0 = DPDPAtomicModel(ds, ft, type_map=type_map) + md1 = DPAtomicModel.deserialize(md0.serialize()).to(self.device) + + # dpmodel uses numpy arrays + args0 = [self.coord_ext, self.atype_ext, self.nlist] + # pt_expt uses torch tensors + args1 = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret0 = md0.forward_common_atomic(*args0) + ret1 = md1.forward_common_atomic(*args1) + np.testing.assert_allclose( + ret0["energy"], + ret1["energy"].detach().cpu().numpy(), + ) + + def test_exportable(self) -> None: + """Test that pt_expt atomic model can be exported with torch.export.""" + nf, nloc, nnei = self.nlist.shape + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(self.device) + type_map = ["foo", "bar"] + md0 = DPAtomicModel(ds, ft, type_map=type_map).to(self.device) + md0 = md0.eval() + + # Prepare inputs for export + coord = torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device) + atype = torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device) + nlist = torch.tensor(self.nlist, dtype=torch.int64, device=self.device) + + # Test forward pass + ret0 = md0(coord, atype, nlist) + self.assertIn("energy", ret0) + + # Test torch.export + exported = torch.export.export( + md0, + (coord, atype, nlist), + strict=False, + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + ret1 = exported.module()(coord, atype, nlist) + np.testing.assert_allclose( + ret0["energy"].detach().cpu().numpy(), + ret1["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + + def test_excl_consistency(self) -> None: + """Test that exclusion masks work correctly after serialize/deserialize.""" + type_map = ["foo", "bar"] + + # test the case of exclusion + for atom_excl, pair_excl in itertools.product([[], [1]], [[], [[0, 1]]]): + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(self.device) + md0 = DPAtomicModel( + ds, + ft, + type_map=type_map, + ).to(self.device) + md1 = DPAtomicModel.deserialize(md0.serialize()).to(self.device) + + md0.reinit_atom_exclude(atom_excl) + md0.reinit_pair_exclude(pair_excl) + # hacking! + md1.descriptor.reinit_exclude(pair_excl) + md1.fitting.reinit_exclude(atom_excl) + + # check energy consistency + args = [ + torch.tensor(self.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.nlist, dtype=torch.int64, device=self.device), + ] + ret0 = md0.forward_common_atomic(*args) + ret1 = md1.forward_common_atomic(*args) + np.testing.assert_allclose( + ret0["energy"].detach().cpu().numpy(), + ret1["energy"].detach().cpu().numpy(), + ) + + # check output def + out_names = [vv.name for vv in md0.atomic_output_def().get_data().values()] + self.assertEqual(out_names, ["energy", "mask"]) + if atom_excl != []: + for ii in md0.atomic_output_def().get_data().values(): + if ii.name == "mask": + self.assertEqual(ii.shape, [1]) + self.assertFalse(ii.reducible) + self.assertFalse(ii.r_differentiable) + self.assertFalse(ii.c_differentiable) + + # check mask + if atom_excl == []: + pass + elif atom_excl == [1]: + self.assertIn("mask", ret0.keys()) + expected = np.array([1, 1, 0], dtype=int) + expected = np.concatenate( + [expected, expected[self.perm[: self.nloc]]] + ).reshape(2, 3) + np.testing.assert_array_equal( + ret0["mask"].detach().cpu().numpy(), expected + ) + else: + raise ValueError(f"not expected atom_excl {atom_excl}") + + +class TestDPAtomicModelVirtualConsistency(unittest.TestCase): + def setUp(self) -> None: + self.case0 = TestCaseSingleFrameWithNlist() + self.case1 = TestCaseSingleFrameWithNlistWithVirtual() + self.case0.setUp() + self.case1.setUp() + self.device = env.DEVICE + + def test_virtual_consistency(self) -> None: + nf, _, _ = self.case0.nlist.shape + ds = DescrptSeA( + self.case0.rcut, + self.case0.rcut_smth, + self.case0.sel, + ) + ft = InvarFitting( + "energy", + self.case0.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ) + type_map = ["foo", "bar"] + md1 = DPAtomicModel(ds, ft, type_map=type_map).to(self.device) + + args0 = [ + torch.tensor(self.case0.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.case0.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.case0.nlist, dtype=torch.int64, device=self.device), + ] + args1 = [ + torch.tensor(self.case1.coord_ext, dtype=torch.float64, device=self.device), + torch.tensor(self.case1.atype_ext, dtype=torch.int64, device=self.device), + torch.tensor(self.case1.nlist, dtype=torch.int64, device=self.device), + ] + + ret0 = md1.forward_common_atomic(*args0) + ret1 = md1.forward_common_atomic(*args1) + + for dd in range(self.case0.nf): + np.testing.assert_allclose( + ret0["energy"][dd].detach().cpu().numpy(), + ret1["energy"][dd, self.case1.get_real_mapping[dd], :] + .detach() + .cpu() + .numpy(), + ) + expected_mask = np.array( + [ + [1, 0, 1, 1], + [1, 1, 0, 1], + ] + ) + np.testing.assert_equal(ret1["mask"].detach().cpu().numpy(), expected_mask) From 4ae2726cbb3cd216a34f1e087001bc7984182702 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 12 Feb 2026 19:13:39 +0800 Subject: [PATCH 39/60] feat: full energy model (but not exportable) --- deepmd/dpmodel/model/make_model.py | 62 ++++-- deepmd/dpmodel/model/transform_output.py | 34 +-- deepmd/dpmodel/utils/network.py | 6 +- deepmd/pt_expt/model/__init__.py | 8 + deepmd/pt_expt/model/ener_model.py | 61 ++++++ deepmd/pt_expt/model/make_model.py | 92 ++++++++ deepmd/pt_expt/model/transform_output.py | 198 ++++++++++++++++++ source/tests/consistent/model/common.py | 16 ++ source/tests/consistent/model/test_ener.py | 57 +++++ source/tests/pt_expt/model/test_autodiff.py | 185 ++++++++++++++++ source/tests/pt_expt/model/test_ener_model.py | 183 ++++++++++++++++ 11 files changed, 862 insertions(+), 40 deletions(-) create mode 100644 deepmd/pt_expt/model/__init__.py create mode 100644 deepmd/pt_expt/model/ener_model.py create mode 100644 deepmd/pt_expt/model/make_model.py create mode 100644 deepmd/pt_expt/model/transform_output.py create mode 100644 source/tests/pt_expt/model/test_autodiff.py create mode 100644 source/tests/pt_expt/model/test_ener_model.py diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index cc9dd12fc5..0a77549ca4 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -21,6 +21,7 @@ PRECISION_DICT, RESERVED_PRECISION_DICT, NativeOP, + get_xp_precision, ) from deepmd.dpmodel.model.base_model import ( BaseModel, @@ -103,7 +104,8 @@ def model_call_from_call_lower( bb.reshape(nframes, 3, 3), ) else: - coord_normalized = cc.copy() + xp = array_api_compat.array_namespace(cc) + coord_normalized = xp.reshape(cc, (nframes, nloc, 3)) extended_coord, extended_atype, mapping = extend_coord_with_ghosts( coord_normalized, atype, bb, rcut ) @@ -371,39 +373,57 @@ def input_type_cast( box: Array | None = None, fparam: Array | None = None, aparam: Array | None = None, - ) -> tuple[Array, Array, np.ndarray | None, np.ndarray | None, str]: + ) -> tuple[Array, Array | None, Array | None, Array | None, Any]: """Cast the input data to global float type.""" - input_prec = RESERVED_PRECISION_DICT[self.precision_dict[coord.dtype.name]] + xp = array_api_compat.array_namespace(coord) + input_dtype = coord.dtype + global_dtype = get_xp_precision( + xp, RESERVED_PRECISION_DICT[self.global_np_float_precision] + ) ### ### type checking would not pass jit, convert to coord prec anyway ### - _lst: list[np.ndarray | None] = [ - vv.astype(coord.dtype) if vv is not None else None + _lst: list[Array | None] = [ + xp.astype(vv, input_dtype) if vv is not None else None for vv in [box, fparam, aparam] ] box, fparam, aparam = _lst - if input_prec == RESERVED_PRECISION_DICT[self.global_np_float_precision]: - return coord, box, fparam, aparam, input_prec + if input_dtype == global_dtype: + return coord, box, fparam, aparam, input_dtype else: - pp = self.global_np_float_precision return ( - coord.astype(pp), - box.astype(pp) if box is not None else None, - fparam.astype(pp) if fparam is not None else None, - aparam.astype(pp) if aparam is not None else None, - input_prec, + xp.astype(coord, global_dtype), + xp.astype(box, global_dtype) if box is not None else None, + xp.astype(fparam, global_dtype) if fparam is not None else None, + xp.astype(aparam, global_dtype) if aparam is not None else None, + input_dtype, ) def output_type_cast( self, model_ret: dict[str, Array], - input_prec: str, + input_prec: Any, ) -> dict[str, Array]: - """Convert the model output to the input prec.""" - do_cast = ( - input_prec != RESERVED_PRECISION_DICT[self.global_np_float_precision] + """Convert the model output to the input prec. + + Parameters + ---------- + model_ret + The model output. + input_prec + The input dtype returned by ``input_type_cast``. + """ + model_ret_not_none = [vv for vv in model_ret.values() if vv is not None] + if not model_ret_not_none: + return model_ret + xp = array_api_compat.array_namespace(model_ret_not_none[0]) + global_dtype = get_xp_precision( + xp, RESERVED_PRECISION_DICT[self.global_np_float_precision] + ) + ener_dtype = get_xp_precision( + xp, RESERVED_PRECISION_DICT[self.global_ener_float_precision] ) - pp = self.precision_dict[input_prec] + do_cast = input_prec != global_dtype odef = self.model_output_def() for kk in odef.keys(): if kk not in model_ret.keys(): @@ -411,13 +431,15 @@ def output_type_cast( continue if check_operation_applied(odef[kk], OutputVariableOperation.REDU): model_ret[kk] = ( - model_ret[kk].astype(self.global_ener_float_precision) + xp.astype(model_ret[kk], ener_dtype) if model_ret[kk] is not None else None ) elif do_cast: model_ret[kk] = ( - model_ret[kk].astype(pp) if model_ret[kk] is not None else None + xp.astype(model_ret[kk], input_prec) + if model_ret[kk] is not None + else None ) return model_ret diff --git a/deepmd/dpmodel/model/transform_output.py b/deepmd/dpmodel/model/transform_output.py index d3315eda55..b697898896 100644 --- a/deepmd/dpmodel/model/transform_output.py +++ b/deepmd/dpmodel/model/transform_output.py @@ -98,6 +98,7 @@ def communicate_extended_output( """ xp = array_api_compat.get_namespace(mapping) + device = array_api_compat.device(mapping) mapping_ = mapping new_ret = {} for kk in model_output_def.keys_outp(): @@ -117,7 +118,9 @@ def communicate_extended_output( mapping, tuple(mldims + [1] * len(derv_r_ext_dims)) ) mapping = xp.tile(mapping, [1] * len(mldims) + derv_r_ext_dims) - force = xp.zeros(vldims + derv_r_ext_dims, dtype=vv.dtype) + force = xp.zeros( + vldims + derv_r_ext_dims, dtype=vv.dtype, device=device + ) force = xp_scatter_sum( force, 1, @@ -149,7 +152,9 @@ def communicate_extended_output( nall = hess_1.shape[1] # (1) -> [nf, nloc1, nall2, *def, 3(1), 3(2)] hessian1 = xp.zeros( - [*vldims, nall, *vdef.shape, 3, 3], dtype=vv.dtype + [*vldims, nall, *vdef.shape, 3, 3], + dtype=vv.dtype, + device=device, ) mapping_hess = xp.reshape( mapping_, (mldims + [1] * (len(vdef.shape) + 3)) @@ -172,7 +177,9 @@ def communicate_extended_output( nloc = hessian1.shape[2] # (2) -> [nf, nloc2, nloc1, *def, 3(1), 3(2)] hessian = xp.zeros( - [*vldims, nloc, *vdef.shape, 3, 3], dtype=vv.dtype + [*vldims, nloc, *vdef.shape, 3, 3], + dtype=vv.dtype, + device=device, ) mapping_hess = xp.reshape( mapping_, (mldims + [1] * (len(vdef.shape) + 3)) @@ -218,21 +225,14 @@ def communicate_extended_output( virial = xp.zeros( vldims + derv_c_ext_dims, dtype=vv.dtype, + device=device, + ) + virial = xp_scatter_sum( + virial, + 1, + mapping, + model_ret[kk_derv_c], ) - # jax only - if array_api_compat.is_jax_array(virial): - from deepmd.jax.common import ( - scatter_sum, - ) - - virial = scatter_sum( - virial, - 1, - mapping, - model_ret[kk_derv_c], - ) - else: - raise NotImplementedError("Only JAX arrays are supported.") new_ret[kk_derv_c] = virial new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1) else: diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 4679412d4b..da48583ac9 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -286,11 +286,11 @@ def call(self, x): # noqa: ANN001, ANN201 y = xp.astype(y, x.dtype) y = fn(y) if self.idt is not None: - y *= self.idt + y = y * self.idt if self.resnet and self.w.shape[1] == self.w.shape[0]: - y += x + y = y + x elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]: - y += xp.concat([x, x], axis=-1) + y = y + xp.concat([x, x], axis=-1) return y diff --git a/deepmd/pt_expt/model/__init__.py b/deepmd/pt_expt/model/__init__.py new file mode 100644 index 0000000000..5d1c5ffb5d --- /dev/null +++ b/deepmd/pt_expt/model/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .ener_model import ( + EnergyModel, +) + +__all__ = [ + "EnergyModel", +] diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py new file mode 100644 index 0000000000..8a68e57551 --- /dev/null +++ b/deepmd/pt_expt/model/ener_model.py @@ -0,0 +1,61 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.model.dp_model import ( + DPModelCommon, +) +from deepmd.pt_expt.atomic_model import ( + DPEnergyAtomicModel, +) + +from .make_model import ( + make_model, +) + +DPEnergyModel_ = make_model(DPEnergyAtomicModel) + + +class EnergyModel(DPModelCommon, DPEnergyModel_): + model_type = "ener" + + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + DPModelCommon.__init__(self) + DPEnergyModel_.__init__(self, *args, **kwargs) + + def forward( + self, + coord: torch.Tensor, + atype: torch.Tensor, + box: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + model_ret = self.call( + coord, + atype, + box, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy"): + model_predict["force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py new file mode 100644 index 0000000000..05a98982d3 --- /dev/null +++ b/deepmd/pt_expt/model/make_model.py @@ -0,0 +1,92 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.atomic_model.base_atomic_model import ( + BaseAtomicModel, +) +from deepmd.dpmodel.model.make_model import make_model as make_model_dp +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) + +from .transform_output import ( + fit_output_to_model_output, +) + + +def make_model(T_AtomicModel: type[BaseAtomicModel]) -> type: + """Make a model as a derived class of an atomic model. + + Wraps dpmodel's make_model with torch.nn.Module and overrides + forward_common_atomic to use autograd-based derivatives. + + Parameters + ---------- + T_AtomicModel + The atomic model. + + Returns + ------- + CM + The model. + + """ + DPModel = make_model_dp(T_AtomicModel) + + class CM(DPModel, torch.nn.Module): + def __init__( + self, + *args: Any, + **kwargs: Any, + ) -> None: + torch.nn.Module.__init__(self) + DPModel.__init__(self, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. + return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward(self, *args: Any, **kwargs: Any) -> dict[str, torch.Tensor]: + """Default forward delegates to call(). + + Subclasses (e.g. EnergyModel) override this with output translation. + """ + return self.call(*args, **kwargs) + + def forward_common_atomic( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + atomic_ret = self.atomic_model.forward_common_atomic( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + fparam=fparam, + aparam=aparam, + ) + return fit_output_to_model_output( + atomic_ret, + self.atomic_output_def(), + extended_coord, + do_atomic_virial=do_atomic_virial, + create_graph=self.training, + mask=atomic_ret.get("mask"), + ) + + return CM diff --git a/deepmd/pt_expt/model/transform_output.py b/deepmd/pt_expt/model/transform_output.py new file mode 100644 index 0000000000..829591983a --- /dev/null +++ b/deepmd/pt_expt/model/transform_output.py @@ -0,0 +1,198 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import torch + +from deepmd.dpmodel import ( + FittingOutputDef, + OutputVariableDef, + get_deriv_name, + get_reduce_name, +) +from deepmd.pt_expt.utils import ( + env, +) + + +def atomic_virial_corr( + extended_coord: torch.Tensor, + atom_energy: torch.Tensor, +) -> torch.Tensor: + nall = extended_coord.shape[1] + nloc = atom_energy.shape[1] + coord, _ = torch.split(extended_coord, [nloc, nall - nloc], dim=1) + # no derivative with respect to the loc coord. + coord = coord.detach() + ce = coord * atom_energy + sumce0, sumce1, sumce2 = torch.split(torch.sum(ce, dim=1), [1, 1, 1], dim=-1) + faked_grad = torch.ones_like(sumce0) + lst: list[torch.Tensor | None] = [faked_grad] + extended_virial_corr0 = torch.autograd.grad( + [sumce0], + [extended_coord], + grad_outputs=lst, + create_graph=False, + retain_graph=True, + )[0] + assert extended_virial_corr0 is not None + extended_virial_corr1 = torch.autograd.grad( + [sumce1], + [extended_coord], + grad_outputs=lst, + create_graph=False, + retain_graph=True, + )[0] + assert extended_virial_corr1 is not None + extended_virial_corr2 = torch.autograd.grad( + [sumce2], + [extended_coord], + grad_outputs=lst, + create_graph=False, + retain_graph=True, + )[0] + assert extended_virial_corr2 is not None + extended_virial_corr = torch.concat( + [ + extended_virial_corr0.unsqueeze(-1), + extended_virial_corr1.unsqueeze(-1), + extended_virial_corr2.unsqueeze(-1), + ], + dim=-1, + ) + return extended_virial_corr + + +def task_deriv_one( + atom_energy: torch.Tensor, + energy: torch.Tensor, + extended_coord: torch.Tensor, + do_virial: bool = True, + do_atomic_virial: bool = False, + create_graph: bool = True, +) -> tuple[torch.Tensor, torch.Tensor | None]: + faked_grad = torch.ones_like(energy) + lst: list[torch.Tensor | None] = [faked_grad] + extended_force = torch.autograd.grad( + [energy], + [extended_coord], + grad_outputs=lst, + create_graph=create_graph, + retain_graph=True, + )[0] + assert extended_force is not None + extended_force = -extended_force + if do_virial: + extended_virial = torch.einsum( + "...ik,...ij->...ikj", extended_force, extended_coord + ) + # the correction sums to zero, which does not contribute to global virial + if do_atomic_virial: + extended_virial_corr = atomic_virial_corr(extended_coord, atom_energy) + extended_virial = extended_virial + extended_virial_corr + # to [...,3,3] -> [...,9] + extended_virial = extended_virial.view(list(extended_virial.shape[:-2]) + [9]) # noqa:RUF005 + else: + extended_virial = None + return extended_force, extended_virial + + +def get_leading_dims( + vv: torch.Tensor, + vdef: OutputVariableDef, +) -> list[int]: + """Get the dimensions of nf x nloc.""" + vshape = vv.shape + return list(vshape[: (len(vshape) - len(vdef.shape))]) + + +def take_deriv( + vv: torch.Tensor, + svv: torch.Tensor, + vdef: OutputVariableDef, + coord_ext: torch.Tensor, + do_virial: bool = False, + do_atomic_virial: bool = False, + create_graph: bool = True, +) -> tuple[torch.Tensor, torch.Tensor | None]: + size = 1 + for ii in vdef.shape: + size *= ii + vv1 = vv.view(list(get_leading_dims(vv, vdef)) + [size]) # noqa: RUF005 + svv1 = svv.view(list(get_leading_dims(svv, vdef)) + [size]) # noqa: RUF005 + split_vv1 = torch.split(vv1, [1] * size, dim=-1) + split_svv1 = torch.split(svv1, [1] * size, dim=-1) + split_ff, split_avir = [], [] + for vvi, svvi in zip(split_vv1, split_svv1): + # nf x nloc x 3, nf x nloc x 9 + ffi, aviri = task_deriv_one( + vvi, + svvi, + coord_ext, + do_virial=do_virial, + do_atomic_virial=do_atomic_virial, + create_graph=create_graph, + ) + # nf x nloc x 1 x 3, nf x nloc x 1 x 9 + ffi = ffi.unsqueeze(-2) + split_ff.append(ffi) + if do_virial: + assert aviri is not None + aviri = aviri.unsqueeze(-2) + split_avir.append(aviri) + # nf x nall x v_dim x 3, nf x nall x v_dim x 9 + out_lead_shape = list(coord_ext.shape[:-1]) + vdef.shape + ff = torch.concat(split_ff, dim=-2).view(out_lead_shape + [3]) # noqa: RUF005 + if do_virial: + avir = torch.concat(split_avir, dim=-2).view(out_lead_shape + [9]) # noqa: RUF005 + else: + avir = None + return ff, avir + + +def fit_output_to_model_output( + fit_ret: dict[str, torch.Tensor], + fit_output_def: FittingOutputDef, + coord_ext: torch.Tensor, + do_atomic_virial: bool = False, + create_graph: bool = True, + mask: torch.Tensor | None = None, +) -> dict[str, torch.Tensor]: + """Transform the output of the fitting network to + the model output. + + """ + redu_prec = env.GLOBAL_PT_ENER_FLOAT_PRECISION + model_ret = dict(fit_ret.items()) + for kk, vv in fit_ret.items(): + vdef = fit_output_def[kk] + shap = vdef.shape + atom_axis = -(len(shap) + 1) + if vdef.reducible: + kk_redu = get_reduce_name(kk) + if vdef.intensive: + if mask is not None: + model_ret[kk_redu] = torch.sum( + vv.to(redu_prec), dim=atom_axis + ) / torch.sum(mask, dim=-1, keepdim=True) + else: + model_ret[kk_redu] = torch.mean(vv.to(redu_prec), dim=atom_axis) + else: + model_ret[kk_redu] = torch.sum(vv.to(redu_prec), dim=atom_axis) + if vdef.r_differentiable: + kk_derv_r, kk_derv_c = get_deriv_name(kk) + dr, dc = take_deriv( + vv, + model_ret[kk_redu], + vdef, + coord_ext, + do_virial=vdef.c_differentiable, + do_atomic_virial=do_atomic_virial, + create_graph=create_graph, + ) + model_ret[kk_derv_r] = dr + if vdef.c_differentiable: + assert dc is not None + model_ret[kk_derv_c] = dc + model_ret[kk_derv_c + "_redu"] = torch.sum( + model_ret[kk_derv_c].to(redu_prec), dim=1 + ) + return model_ret diff --git a/source/tests/consistent/model/common.py b/source/tests/consistent/model/common.py index 778ae519c6..04966c02a1 100644 --- a/source/tests/consistent/model/common.py +++ b/source/tests/consistent/model/common.py @@ -14,6 +14,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, ) @@ -30,6 +31,8 @@ from deepmd.jax.env import ( jnp, ) +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.common import to_torch_array as pt_expt_numpy_to_torch if INSTALLED_PD: from deepmd.pd.utils.utils import to_numpy_array as paddle_to_numpy from deepmd.pd.utils.utils import to_paddle_tensor as numpy_to_paddle @@ -104,6 +107,19 @@ def eval_pt_model(self, pt_obj: Any, natoms, coords, atype, box) -> Any: ).items() } + def eval_pt_expt_model(self, pt_expt_obj: Any, natoms, coords, atype, box) -> Any: + coord_tensor = pt_expt_numpy_to_torch(coords) + coord_tensor.requires_grad_(True) + return { + kk: vv.detach().cpu().numpy() + for kk, vv in pt_expt_obj( + coord_tensor, + pt_expt_numpy_to_torch(atype), + box=pt_expt_numpy_to_torch(box), + do_atomic_virial=True, + ).items() + } + def eval_jax_model(self, jax_obj: Any, natoms, coords, atype, box) -> Any: def assert_jax_array(arr): assert isinstance(arr, jnp.ndarray) or arr is None diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index d56b9a257b..c1ee630516 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -26,6 +26,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, SKIP_FLAG, CommonTest, @@ -53,6 +54,11 @@ from deepmd.pd.utils.utils import to_paddle_tensor as numpy_to_paddle else: EnergyModelPD = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.common import to_torch_array as pt_expt_numpy_to_torch + from deepmd.pt_expt.model import EnergyModel as EnergyModelPTExpt +else: + EnergyModelPTExpt = None from deepmd.utils.argcheck import ( model_args, ) @@ -115,6 +121,7 @@ def data(self) -> dict: dp_class = EnergyModelDP pt_class = EnergyModelPT pd_class = EnergyModelPD + pt_expt_class = EnergyModelPTExpt jax_class = EnergyModelJAX pd_class = EnergyModelPD args = model_args() @@ -128,6 +135,8 @@ def get_reference_backend(self): return self.RefBackend.PT if not self.skip_tf: return self.RefBackend.TF + if not self.skip_pt_expt and self.pt_expt_class is not None: + return self.RefBackend.PT_EXPT if not self.skip_jax: return self.RefBackend.JAX if not self.skip_pd: @@ -156,6 +165,9 @@ def pass_data_to_cls(self, cls, data) -> Any: model = get_model_pt(data) model.atomic_model.out_bias.uniform_() return model + elif cls is EnergyModelPTExpt: + dp_model = get_model_dp(data) + return EnergyModelPTExpt.deserialize(dp_model.serialize()) elif cls is EnergyModelJAX: return get_model_jax(data) elif cls is EnergyModelPD: @@ -229,6 +241,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_model( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_model( jax_obj, @@ -265,6 +286,14 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["virial"].ravel(), ret["atom_virial"].ravel(), ) + elif backend is self.RefBackend.PT_EXPT: + return ( + ret["energy"].ravel(), + ret["atom_energy"].ravel(), + ret["force"].ravel(), + ret["virial"].ravel(), + ret["atom_virial"].ravel(), + ) elif backend is self.RefBackend.TF: return ( ret[0].ravel(), @@ -339,6 +368,7 @@ def data(self) -> dict: tf_class = EnergyModelTF dp_class = EnergyModelDP pt_class = EnergyModelPT + pt_expt_class = EnergyModelPTExpt jax_class = EnergyModelJAX pd_class = EnergyModelPD args = model_args() @@ -350,6 +380,8 @@ def get_reference_backend(self): """ if not self.skip_pt: return self.RefBackend.PT + if not self.skip_pt_expt and self.pt_expt_class is not None: + return self.RefBackend.PT_EXPT if not self.skip_jax: return self.RefBackend.JAX if not self.skip_dp: @@ -374,6 +406,9 @@ def pass_data_to_cls(self, cls, data) -> Any: return get_model_dp(data) elif cls is EnergyModelPT: return get_model_pt(data) + elif cls is EnergyModelPTExpt: + dp_model = get_model_dp(data) + return EnergyModelPTExpt.deserialize(dp_model.serialize()) elif cls is EnergyModelJAX: return get_model_jax(data) elif cls is EnergyModelPD: @@ -460,6 +495,20 @@ def eval_pt(self, pt_obj: Any) -> Any: ).items() } + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + coord_tensor = pt_expt_numpy_to_torch(self.extended_coord) + coord_tensor.requires_grad_(True) + return { + kk: vv.detach().cpu().numpy() if vv is not None else None + for kk, vv in pt_expt_obj.call_lower( + coord_tensor, + pt_expt_numpy_to_torch(self.extended_atype), + pt_expt_numpy_to_torch(self.nlist), + pt_expt_numpy_to_torch(self.mapping), + do_atomic_virial=True, + ).items() + } + def eval_jax(self, jax_obj: Any) -> Any: return { kk: to_numpy_array(vv) @@ -502,6 +551,14 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["virial"].ravel(), ret["extended_virial"].ravel(), ) + elif backend is self.RefBackend.PT_EXPT: + return ( + ret["energy_redu"].ravel(), + ret["energy"].ravel(), + ret["energy_derv_r"].ravel(), + ret["energy_derv_c_redu"].ravel(), + ret["energy_derv_c"].ravel(), + ) elif backend is self.RefBackend.JAX: return ( ret["energy_redu"].ravel(), diff --git a/source/tests/pt_expt/model/test_autodiff.py b/source/tests/pt_expt/model/test_autodiff.py new file mode 100644 index 0000000000..de404b5b95 --- /dev/null +++ b/source/tests/pt_expt/model/test_autodiff.py @@ -0,0 +1,185 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + InvarFitting, +) +from deepmd.pt_expt.model import ( + EnergyModel, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...seed import ( + GLOBAL_SEED, +) + +dtype = torch.float64 + + +def finite_difference(f, x, delta=1e-6): + in_shape = x.shape + y0 = f(x) + out_shape = y0.shape + res = np.empty(out_shape + in_shape) + for idx in np.ndindex(*in_shape): + diff = np.zeros(in_shape) + diff[idx] += delta + y1p = f(x + diff) + y1n = f(x - diff) + res[(Ellipsis, *idx)] = (y1p - y1n) / (2 * delta) + return res + + +def stretch_box(old_coord, old_box, new_box): + ocoord = old_coord.reshape(-1, 3) + obox = old_box.reshape(3, 3) + nbox = new_box.reshape(3, 3) + ncoord = ocoord @ np.linalg.inv(obox) @ nbox + return ncoord.reshape(old_coord.shape) + + +def eval_model(model, coord, cell, atype): + """Evaluate the pt_expt EnergyModel. + + Parameters + ---------- + model : EnergyModel + The model to evaluate. + coord : torch.Tensor + Coordinates, shape [nf, natoms, 3]. + cell : torch.Tensor + Cell, shape [nf, 3, 3]. + atype : torch.Tensor + Atom types, shape [natoms]. + + Returns + ------- + dict + Model predictions with keys: energy, force, virial. + """ + nframes = coord.shape[0] + if len(atype.shape) == 1: + atype = atype.unsqueeze(0).expand(nframes, -1) + coord_input = coord.to(dtype=dtype, device=env.DEVICE) + cell_input = cell.reshape(nframes, 9).to(dtype=dtype, device=env.DEVICE) + atype_input = atype.to(dtype=torch.long, device=env.DEVICE) + coord_input.requires_grad_(True) + result = model(coord_input, atype_input, cell_input) + return result + + +class ForceTest: + def test(self) -> None: + places = 5 + delta = 1e-5 + natoms = 5 + generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED) + cell = torch.rand([3, 3], dtype=dtype, device="cpu", generator=generator) + cell = (cell + cell.T) + 5.0 * torch.eye(3, device="cpu") + coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) + coord = torch.matmul(coord, cell) + atype = torch.IntTensor([0, 0, 0, 1, 1]) + coord = coord.numpy() + + def np_infer_coord(coord): + result = eval_model( + self.model, + torch.tensor(coord, device=env.DEVICE).unsqueeze(0), + cell.unsqueeze(0), + atype, + ) + ret = { + key: result[key].squeeze(0).detach().cpu().numpy() + for key in ["energy", "force", "virial"] + } + return ret + + def ff_coord(_coord): + return np_infer_coord(_coord)["energy"] + + fdf = -finite_difference(ff_coord, coord, delta=delta).squeeze() + rff = np_infer_coord(coord)["force"] + np.testing.assert_almost_equal(fdf, rff, decimal=places) + + +class VirialTest: + def test(self) -> None: + places = 5 + delta = 1e-4 + natoms = 5 + generator = torch.Generator(device="cpu").manual_seed(GLOBAL_SEED) + cell = torch.rand([3, 3], dtype=dtype, device="cpu", generator=generator) + cell = (cell) + 5.0 * torch.eye(3, device="cpu") + coord = torch.rand([natoms, 3], dtype=dtype, device="cpu", generator=generator) + coord = torch.matmul(coord, cell) + atype = torch.IntTensor([0, 0, 0, 1, 1]) + coord = coord.numpy() + cell = cell.numpy() + + def np_infer(new_cell): + result = eval_model( + self.model, + torch.tensor( + stretch_box(coord, cell, new_cell), device="cpu" + ).unsqueeze(0), + torch.tensor(new_cell, device="cpu").unsqueeze(0), + atype, + ) + ret = { + key: result[key].squeeze(0).detach().cpu().numpy() + for key in ["energy", "force", "virial"] + } + return ret + + def ff(bb): + return np_infer(bb)["energy"] + + fdv = ( + -(finite_difference(ff, cell, delta=delta).transpose(0, 2, 1) @ cell) + .squeeze() + .reshape(9) + ) + rfv = np_infer(cell)["virial"] + np.testing.assert_almost_equal(fdv, rfv, decimal=places) + + +class TestEnergyModelSeAForce(unittest.TestCase, ForceTest): + def setUp(self) -> None: + ds = DescrptSeA(4.0, 0.5, [8, 6]).to(env.DEVICE) + ft = InvarFitting( + "energy", + 2, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(env.DEVICE) + self.model = EnergyModel(ds, ft, type_map=["foo", "bar"]).to(env.DEVICE) + self.model.eval() + + +class TestEnergyModelSeAVirial(unittest.TestCase, VirialTest): + def setUp(self) -> None: + ds = DescrptSeA(4.0, 0.5, [8, 6]).to(env.DEVICE) + ft = InvarFitting( + "energy", + 2, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(env.DEVICE) + self.model = EnergyModel(ds, ft, type_map=["foo", "bar"]).to(env.DEVICE) + self.model.eval() + + +if __name__ == "__main__": + unittest.main() diff --git a/source/tests/pt_expt/model/test_ener_model.py b/source/tests/pt_expt/model/test_ener_model.py new file mode 100644 index 0000000000..c243fb792d --- /dev/null +++ b/source/tests/pt_expt/model/test_ener_model.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA +from deepmd.dpmodel.fitting import InvarFitting as DPInvarFitting +from deepmd.dpmodel.model.ener_model import EnergyModel as DPEnergyModel +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + InvarFitting, +) +from deepmd.pt_expt.model import ( + EnergyModel, +) +from deepmd.pt_expt.utils import ( + env, +) + +from ...seed import ( + GLOBAL_SEED, +) + + +class TestEnergyModel(unittest.TestCase): + def setUp(self) -> None: + self.device = env.DEVICE + self.natoms = 5 + self.rcut = 4.0 + self.rcut_smth = 0.5 + self.sel = [8, 6] + self.nt = 2 + self.type_map = ["foo", "bar"] + + generator = torch.Generator(device=self.device).manual_seed(GLOBAL_SEED) + cell = torch.rand( + [3, 3], dtype=torch.float64, device=self.device, generator=generator + ) + cell = (cell + cell.T) + 5.0 * torch.eye(3, device=self.device) + self.cell = cell.unsqueeze(0) # [1, 3, 3] + coord = torch.rand( + [self.natoms, 3], + dtype=torch.float64, + device=self.device, + generator=generator, + ) + coord = torch.matmul(coord, cell) + self.coord = coord.unsqueeze(0).to(self.device) # [1, natoms, 3] + self.atype = torch.tensor( + [[0, 0, 0, 1, 1]], dtype=torch.int64, device=self.device + ) + + def _make_model(self) -> EnergyModel: + ds = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ).to(self.device) + ft = InvarFitting( + "energy", + self.nt, + ds.get_dim_out(), + 1, + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ).to(self.device) + return EnergyModel(ds, ft, type_map=self.type_map).to(self.device) + + def test_output_keys(self) -> None: + """Test that EnergyModel produces expected output keys.""" + md = self._make_model() + md.eval() + coord = self.coord.clone().requires_grad_(True) + ret = md(coord, self.atype, self.cell.reshape(1, 9)) + self.assertIn("energy", ret) + self.assertIn("atom_energy", ret) + self.assertIn("force", ret) + self.assertIn("virial", ret) + + def test_output_shapes(self) -> None: + """Test that output shapes are correct.""" + md = self._make_model() + md.eval() + coord = self.coord.clone().requires_grad_(True) + ret = md(coord, self.atype, self.cell.reshape(1, 9)) + self.assertEqual(ret["energy"].shape, (1, 1)) + self.assertEqual(ret["atom_energy"].shape, (1, self.natoms, 1)) + self.assertEqual(ret["force"].shape, (1, self.natoms, 3)) + self.assertEqual(ret["virial"].shape, (1, 9)) + + @unittest.expectedFailure + def test_exportable(self) -> None: + """Test that EnergyModel can be exported with torch.export. + + Currently expected to fail because the full model's call() path includes + extend_coord_with_ghosts and neighbor list building, which involve + data-dependent shapes (item() calls) that torch.export cannot trace. + Individual components (descriptor, fitting, atomic model) are exportable. + """ + md = self._make_model() + md.eval() + coord = self.coord.clone().requires_grad_(True) + cell = self.cell.reshape(1, 9) + + # Test forward pass + ret0 = md(coord, self.atype, cell) + self.assertIn("energy", ret0) + + # Test torch.export + exported = torch.export.export( + md, + (coord, self.atype, cell), + strict=False, + ) + self.assertIsNotNone(exported) + + # Test exported model produces same output + coord2 = self.coord.clone().requires_grad_(True) + ret1 = exported.module()(coord2, self.atype, cell) + np.testing.assert_allclose( + ret0["energy"].detach().cpu().numpy(), + ret1["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + ret0["force"].detach().cpu().numpy(), + ret1["force"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + + def test_dp_consistency(self) -> None: + """Test numerical consistency with dpmodel (energy values).""" + # Build dpmodel version + ds_dp = DPDescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + ) + ft_dp = DPInvarFitting( + "energy", + self.nt, + ds_dp.get_dim_out(), + 1, + mixed_types=ds_dp.mixed_types(), + seed=GLOBAL_SEED, + ) + md_dp = DPEnergyModel(ds_dp, ft_dp, type_map=self.type_map) + + # Build pt_expt version from serialized dpmodel + md_pt = EnergyModel.deserialize(md_dp.serialize()).to(self.device) + md_pt.eval() + + # dpmodel inference + coord_np = self.coord.detach().cpu().numpy() + atype_np = self.atype.detach().cpu().numpy() + cell_np = self.cell.reshape(1, 9).detach().cpu().numpy() + ret_dp = md_dp(coord_np.reshape(1, -1), atype_np, cell_np) + + # pt_expt inference + coord = self.coord.clone().requires_grad_(True) + ret_pt = md_pt(coord, self.atype, self.cell.reshape(1, 9)) + + np.testing.assert_allclose( + ret_dp["energy_redu"], + ret_pt["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + ret_dp["energy"], + ret_pt["atom_energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + + +if __name__ == "__main__": + unittest.main() From fb08ffca5bc081e6942d5f0cc34907f6842fddd6 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 12 Feb 2026 19:15:35 +0800 Subject: [PATCH 40/60] add missing file --- deepmd/dpmodel/utils/stat.py | 544 +++++++++++++++++++++++++++++++++++ 1 file changed, 544 insertions(+) create mode 100644 deepmd/dpmodel/utils/stat.py diff --git a/deepmd/dpmodel/utils/stat.py b/deepmd/dpmodel/utils/stat.py new file mode 100644 index 0000000000..1cbaad0275 --- /dev/null +++ b/deepmd/dpmodel/utils/stat.py @@ -0,0 +1,544 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Output statistics computation for dpmodel backend.""" + +import logging +from collections import ( + defaultdict, +) +from collections.abc import ( + Callable, +) + +import numpy as np + +from deepmd.dpmodel.common import ( + to_numpy_array, +) +from deepmd.utils.out_stat import ( + compute_stats_do_not_distinguish_types, + compute_stats_from_atomic, + compute_stats_from_redu, +) +from deepmd.utils.path import ( + DPPath, +) + +log = logging.getLogger(__name__) + + +def _restore_from_file( + stat_file_path: DPPath, + keys: list[str], +) -> tuple[dict | None, dict | None]: + """Restore bias and std from stat file.""" + if stat_file_path is None: + return None, None + stat_files = [stat_file_path / f"bias_atom_{kk}" for kk in keys] + if all(not (ii.is_file()) for ii in stat_files): + return None, None + stat_files = [stat_file_path / f"std_atom_{kk}" for kk in keys] + if all(not (ii.is_file()) for ii in stat_files): + return None, None + + ret_bias = {} + ret_std = {} + for kk in keys: + fp = stat_file_path / f"bias_atom_{kk}" + if fp.is_file(): + ret_bias[kk] = fp.load_numpy() + for kk in keys: + fp = stat_file_path / f"std_atom_{kk}" + if fp.is_file(): + ret_std[kk] = fp.load_numpy() + return ret_bias, ret_std + + +def _save_to_file( + stat_file_path: DPPath, + bias_out: dict, + std_out: dict, +) -> None: + """Save bias and std to stat file.""" + assert stat_file_path is not None + stat_file_path.mkdir(exist_ok=True, parents=True) + for kk, vv in bias_out.items(): + fp = stat_file_path / f"bias_atom_{kk}" + fp.save_numpy(vv) + for kk, vv in std_out.items(): + fp = stat_file_path / f"std_atom_{kk}" + fp.save_numpy(vv) + + +def _post_process_stat( + out_bias: dict, + out_std: dict, +) -> tuple[dict, dict]: + """Post process the statistics. + + For global statistics, we do not have the std for each type of atoms, + thus fake the output std by ones for all the types. + If the shape of out_std is already the same as out_bias, + we do not need to do anything. + """ + new_std = {} + for kk, vv in out_bias.items(): + if vv.shape == out_std[kk].shape: + new_std[kk] = out_std[kk] + else: + new_std[kk] = np.ones_like(vv) + return out_bias, new_std + + +def _make_preset_out_bias( + ntypes: int, + ibias: list[np.ndarray | None], +) -> np.ndarray | None: + """Make preset out bias. + + output: + a np array of shape [ntypes, *(odim0, odim1, ...)] is any item is not None + None if all items are None. + """ + if len(ibias) != ntypes: + raise ValueError("the length of preset bias list should be ntypes") + if all(ii is None for ii in ibias): + return None + for refb in ibias: + if refb is not None: + break + refb = np.array(refb) + nbias = [ + np.full_like(refb, np.nan, dtype=np.float64) if ii is None else ii + for ii in ibias + ] + return np.array(nbias) + + +def _fill_stat_with_global( + atomic_stat: np.ndarray | None, + global_stat: np.ndarray, +) -> np.ndarray | None: + """This function is used to fill atomic stat with global stat. + + Parameters + ---------- + atomic_stat : Union[np.ndarray, None] + The atomic stat. + global_stat : np.ndarray + The global stat. + if the atomic stat is None, use global stat. + if the atomic stat is not None, but has nan values (missing atypes), fill with global stat. + """ + if atomic_stat is None: + return global_stat + else: + atomic_stat = atomic_stat.reshape(*global_stat.shape) + return np.nan_to_num( + np.where( + np.isnan(atomic_stat) & ~np.isnan(global_stat), global_stat, atomic_stat + ) + ) + + +def _compute_model_predict( + sampled: list[dict], + keys: list[str], + model_forward: Callable, +) -> dict[str, list[np.ndarray]]: + """Compute model predictions for all samples.""" + model_predict = {kk: [] for kk in keys} + for system in sampled: + # Convert inputs to numpy to avoid gradient issues + coord = to_numpy_array(system["coord"]) + atype = to_numpy_array(system["atype"]) + box = to_numpy_array(system["box"]) + fparam = to_numpy_array(system.get("fparam", None)) + aparam = to_numpy_array(system.get("aparam", None)) + + sample_predict = model_forward(coord, atype, box, fparam=fparam, aparam=aparam) + for kk in keys: + model_predict[kk].append( + sample_predict[kk] # already numpy from model_forward + ) + return model_predict + + +def compute_output_stats( + merged: Callable[[], list[dict]] | list[dict], + ntypes: int, + keys: str | list[str], + stat_file_path: DPPath | None = None, + rcond: float | None = None, + preset_bias: dict[str, list[np.ndarray | None]] | None = None, + model_forward: Callable | None = None, + stats_distinguish_types: bool = True, + intensive: bool = False, +) -> tuple[dict, dict]: + """ + Compute the output statistics (e.g. energy bias) for the fitting net from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `np.ndarray` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + ntypes : int + The number of atom types. + keys : Union[str, list[str]] + The keys of the output properties to compute statistics for. + stat_file_path : DPPath, optional + The path to the stat file. + rcond : float, optional + The condition number for the regression of atomic energy. + preset_bias : dict[str, list[Optional[np.ndarray]]], optional + Specifying atomic energy contribution in vacuum. Given by key:value pairs. + The value is a list specifying the bias. the elements can be None or np.ndarray of output shape. + For example: [None, [2.]] means type 0 is not set, type 1 is set to [2.] + The `set_davg_zero` key in the descriptor should be set. + model_forward : Callable, optional + The wrapped forward function of atomic model. + If not None, the model will be utilized to generate the original energy prediction, + which will be subtracted from the energy label of the data. + The difference will then be used to calculate the delta complement energy bias for each type. + stats_distinguish_types : bool, optional + Whether to distinguish different element types in the statistics. + intensive : bool, optional + Whether the fitting target is intensive. + """ + # normalize keys to list + keys = [keys] if isinstance(keys, str) else keys + assert isinstance(keys, list) + + # try to restore the bias from stat file + bias_atom_e, std_atom_e = _restore_from_file(stat_file_path, keys) + + # failed to restore the bias from stat file. compute + if bias_atom_e is None: + # only get data once, sampled is a list of dict[str, np.ndarray] + sampled = merged() if callable(merged) else merged + + # remove the keys that are not in the sample + new_keys = [ + ii + for ii in keys + if (ii in sampled[0].keys()) or ("atom_" + ii in sampled[0].keys()) + ] + keys = new_keys + + # compute model predictions if model_forward is provided + if model_forward is not None: + model_pred = _compute_model_predict(sampled, keys, model_forward) + else: + model_pred = None + + # split system based on label + atomic_sampled_idx = defaultdict(list) + global_sampled_idx = defaultdict(list) + + for kk in keys: + for idx, system in enumerate(sampled): + if (("find_atom_" + kk) in system) and ( + system["find_atom_" + kk] > 0.0 + ): + atomic_sampled_idx[kk].append(idx) + elif (("find_" + kk) in system) and (system["find_" + kk] > 0.0): + global_sampled_idx[kk].append(idx) + else: + continue + + # use index to gather model predictions for the corresponding systems. + model_pred_g = ( + { + kk: [ + np.sum(vv[idx], axis=1) for idx in global_sampled_idx[kk] + ] # sum atomic dim + for kk, vv in model_pred.items() + } + if model_pred + else None + ) + model_pred_a = ( + { + kk: [vv[idx] for idx in atomic_sampled_idx[kk]] + for kk, vv in model_pred.items() + } + if model_pred + else None + ) + + # concat all frames within those systems + model_pred_g = ( + { + kk: np.concatenate(model_pred_g[kk]) + for kk in model_pred_g.keys() + if len(model_pred_g[kk]) > 0 + } + if model_pred + else None + ) + model_pred_a = ( + { + kk: np.concatenate(model_pred_a[kk]) + for kk in model_pred_a.keys() + if len(model_pred_a[kk]) > 0 + } + if model_pred + else None + ) + + # compute stat + bias_atom_g, std_atom_g = compute_output_stats_global( + sampled, + ntypes, + keys, + rcond, + preset_bias, + global_sampled_idx, + stats_distinguish_types, + intensive, + model_pred_g, + ) + bias_atom_a, std_atom_a = compute_output_stats_atomic( + sampled, + ntypes, + keys, + atomic_sampled_idx, + model_pred_a, + ) + + # merge global/atomic bias + bias_atom_e, std_atom_e = {}, {} + for kk in keys: + # use atomic bias whenever available + if kk in bias_atom_a: + bias_atom_e[kk] = bias_atom_a[kk] + std_atom_e[kk] = std_atom_a[kk] + else: + bias_atom_e[kk] = None + std_atom_e[kk] = None + # use global bias to fill missing atomic bias + if kk in bias_atom_g: + bias_atom_e[kk] = _fill_stat_with_global( + bias_atom_e[kk], bias_atom_g[kk] + ) + std_atom_e[kk] = _fill_stat_with_global(std_atom_e[kk], std_atom_g[kk]) + if (bias_atom_e[kk] is None) or (std_atom_e[kk] is None): + raise RuntimeError("Fail to compute stat.") + + if stat_file_path is not None: + _save_to_file(stat_file_path, bias_atom_e, std_atom_e) + + return bias_atom_e, std_atom_e + + +def compute_output_stats_global( + sampled: list[dict], + ntypes: int, + keys: list[str], + rcond: float | None = None, + preset_bias: dict[str, list[np.ndarray | None]] | None = None, + global_sampled_idx: dict | None = None, + stats_distinguish_types: bool = True, + intensive: bool = False, + model_pred: dict[str, np.ndarray] | None = None, +) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: + """This function only handle stat computation from reduced global labels.""" + # return directly if no global samples + if global_sampled_idx is None or all( + len(v) == 0 for v in global_sampled_idx.values() + ): + return {}, {} + + # get label dict from sample; for each key, only picking the system with global labels. + outputs = { + kk: [to_numpy_array(sampled[idx][kk]) for idx in global_sampled_idx.get(kk, [])] + for kk in keys + } + + natoms_key = "natoms" + input_natoms = { + kk: [ + to_numpy_array(sampled[idx][natoms_key]) + for idx in global_sampled_idx.get(kk, []) + ] + for kk in keys + } + + # shape: (nframes, ndim) + merged_output = { + kk: np.concatenate(outputs[kk]) for kk in keys if len(outputs[kk]) > 0 + } + # shape: (nframes, ntypes) + merged_natoms = { + kk: np.concatenate(input_natoms[kk])[:, 2:] + for kk in keys + if len(input_natoms[kk]) > 0 + } + nf = {kk: merged_natoms[kk].shape[0] for kk in keys if kk in merged_natoms} + + if preset_bias is not None: + assigned_atom_ener = { + kk: _make_preset_out_bias(ntypes, preset_bias[kk]) + if kk in preset_bias.keys() + else None + for kk in keys + } + else: + assigned_atom_ener = dict.fromkeys(keys) + + if model_pred is None: + stats_input = merged_output + else: + # subtract the model bias and output the delta bias + stats_input = { + kk: merged_output[kk] - model_pred[kk].reshape(merged_output[kk].shape) + for kk in keys + if kk in merged_output + } + + bias_atom_e = {} + std_atom_e = {} + for kk in keys: + if kk in stats_input: + if not stats_distinguish_types: + bias_atom_e[kk], std_atom_e[kk] = ( + compute_stats_do_not_distinguish_types( + stats_input[kk], + merged_natoms[kk], + assigned_bias=assigned_atom_ener[kk], + intensive=intensive, + ) + ) + else: + bias_atom_e[kk], std_atom_e[kk] = compute_stats_from_redu( + stats_input[kk], + merged_natoms[kk], + assigned_bias=assigned_atom_ener[kk], + rcond=rcond, + ) + else: + # this key does not have global labels, skip it. + continue + bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e) + + # compute and log rmse + def rmse(x: np.ndarray) -> float: + return np.sqrt(np.mean(np.square(x))) + + if model_pred is None: + unbias_e = { + kk: merged_natoms[kk] @ bias_atom_e[kk].reshape(ntypes, -1) + for kk in bias_atom_e.keys() + } + else: + unbias_e = { + kk: model_pred[kk].reshape(nf[kk], -1) + + merged_natoms[kk] @ bias_atom_e[kk].reshape(ntypes, -1) + for kk in bias_atom_e.keys() + } + atom_numbs = {kk: merged_natoms[kk].sum(-1) for kk in bias_atom_e.keys()} + + for kk in bias_atom_e.keys(): + rmse_ae = rmse( + (unbias_e[kk].reshape(nf[kk], -1) - merged_output[kk].reshape(nf[kk], -1)) + / atom_numbs[kk][:, None] + ) + log.info( + f"RMSE of {kk} per atom after linear regression is: {rmse_ae} in the unit of {kk}." + ) + return bias_atom_e, std_atom_e + + +def compute_output_stats_atomic( + sampled: list[dict], + ntypes: int, + keys: list[str], + atomic_sampled_idx: dict | None = None, + model_pred: dict[str, np.ndarray] | None = None, +) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: + """Compute output statistics from atomic labels.""" + # return directly if no atomic samples + if atomic_sampled_idx is None or all( + len(v) == 0 for v in atomic_sampled_idx.values() + ): + return {}, {} + + # get label dict from sample; for each key, only picking the system with atomic labels. + outputs = { + kk: [ + to_numpy_array(sampled[idx]["atom_" + kk]) + for idx in atomic_sampled_idx.get(kk, []) + ] + for kk in keys + } + natoms = { + kk: [ + to_numpy_array(sampled[idx]["atype"]) + for idx in atomic_sampled_idx.get(kk, []) + ] + for kk in keys + } + + # reshape outputs [nframes, nloc * ndim] --> reshape to [nframes * nloc, 1, ndim] for concatenation + # reshape natoms [nframes, nloc] --> reshape to [nframes * nolc, 1] for concatenation + natoms = {k: [sys_v.reshape(-1, 1) for sys_v in v] for k, v in natoms.items()} + outputs = { + k: [ + sys.reshape(natoms[k][sys_idx].shape[0], 1, -1) + for sys_idx, sys in enumerate(v) + ] + for k, v in outputs.items() + } + + merged_output = { + kk: np.concatenate(outputs[kk]) for kk in keys if len(outputs[kk]) > 0 + } + merged_natoms = { + kk: np.concatenate(natoms[kk]) for kk in keys if len(natoms[kk]) > 0 + } + # reshape merged data to [nf, nloc, ndim] + merged_output = { + kk: merged_output[kk].reshape((*merged_natoms[kk].shape, -1)) + for kk in merged_output + } + + if model_pred is None: + stats_input = merged_output + else: + # subtract the model bias and output the delta bias + stats_input = { + kk: merged_output[kk] - model_pred[kk].reshape(*merged_output[kk].shape) + for kk in keys + if kk in merged_output + } + + bias_atom_e = {} + std_atom_e = {} + + for kk in keys: + if kk in stats_input: + bias_atom_e[kk], std_atom_e[kk] = compute_stats_from_atomic( + stats_input[kk], + merged_natoms[kk], + ) + # correction for missing types + missing_types = ntypes - merged_natoms[kk].max() - 1 + if missing_types > 0: + assert bias_atom_e[kk].dtype is std_atom_e[kk].dtype, ( + "bias and std should be of the same dtypes" + ) + nan_padding = np.empty( + (missing_types, bias_atom_e[kk].shape[1]), + dtype=bias_atom_e[kk].dtype, + ) + nan_padding.fill(np.nan) + bias_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding], axis=0) + std_atom_e[kk] = np.concatenate([std_atom_e[kk], nan_padding], axis=0) + else: + # this key does not have atomic labels, skip it. + continue + return bias_atom_e, std_atom_e From ed460a59dee7c23873ba4ba84dced9914f4df39f Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 12 Feb 2026 19:15:35 +0800 Subject: [PATCH 41/60] add missing file --- deepmd/dpmodel/utils/stat.py | 544 +++++++++++++++++++++++++++++++++++ 1 file changed, 544 insertions(+) create mode 100644 deepmd/dpmodel/utils/stat.py diff --git a/deepmd/dpmodel/utils/stat.py b/deepmd/dpmodel/utils/stat.py new file mode 100644 index 0000000000..1cbaad0275 --- /dev/null +++ b/deepmd/dpmodel/utils/stat.py @@ -0,0 +1,544 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Output statistics computation for dpmodel backend.""" + +import logging +from collections import ( + defaultdict, +) +from collections.abc import ( + Callable, +) + +import numpy as np + +from deepmd.dpmodel.common import ( + to_numpy_array, +) +from deepmd.utils.out_stat import ( + compute_stats_do_not_distinguish_types, + compute_stats_from_atomic, + compute_stats_from_redu, +) +from deepmd.utils.path import ( + DPPath, +) + +log = logging.getLogger(__name__) + + +def _restore_from_file( + stat_file_path: DPPath, + keys: list[str], +) -> tuple[dict | None, dict | None]: + """Restore bias and std from stat file.""" + if stat_file_path is None: + return None, None + stat_files = [stat_file_path / f"bias_atom_{kk}" for kk in keys] + if all(not (ii.is_file()) for ii in stat_files): + return None, None + stat_files = [stat_file_path / f"std_atom_{kk}" for kk in keys] + if all(not (ii.is_file()) for ii in stat_files): + return None, None + + ret_bias = {} + ret_std = {} + for kk in keys: + fp = stat_file_path / f"bias_atom_{kk}" + if fp.is_file(): + ret_bias[kk] = fp.load_numpy() + for kk in keys: + fp = stat_file_path / f"std_atom_{kk}" + if fp.is_file(): + ret_std[kk] = fp.load_numpy() + return ret_bias, ret_std + + +def _save_to_file( + stat_file_path: DPPath, + bias_out: dict, + std_out: dict, +) -> None: + """Save bias and std to stat file.""" + assert stat_file_path is not None + stat_file_path.mkdir(exist_ok=True, parents=True) + for kk, vv in bias_out.items(): + fp = stat_file_path / f"bias_atom_{kk}" + fp.save_numpy(vv) + for kk, vv in std_out.items(): + fp = stat_file_path / f"std_atom_{kk}" + fp.save_numpy(vv) + + +def _post_process_stat( + out_bias: dict, + out_std: dict, +) -> tuple[dict, dict]: + """Post process the statistics. + + For global statistics, we do not have the std for each type of atoms, + thus fake the output std by ones for all the types. + If the shape of out_std is already the same as out_bias, + we do not need to do anything. + """ + new_std = {} + for kk, vv in out_bias.items(): + if vv.shape == out_std[kk].shape: + new_std[kk] = out_std[kk] + else: + new_std[kk] = np.ones_like(vv) + return out_bias, new_std + + +def _make_preset_out_bias( + ntypes: int, + ibias: list[np.ndarray | None], +) -> np.ndarray | None: + """Make preset out bias. + + output: + a np array of shape [ntypes, *(odim0, odim1, ...)] is any item is not None + None if all items are None. + """ + if len(ibias) != ntypes: + raise ValueError("the length of preset bias list should be ntypes") + if all(ii is None for ii in ibias): + return None + for refb in ibias: + if refb is not None: + break + refb = np.array(refb) + nbias = [ + np.full_like(refb, np.nan, dtype=np.float64) if ii is None else ii + for ii in ibias + ] + return np.array(nbias) + + +def _fill_stat_with_global( + atomic_stat: np.ndarray | None, + global_stat: np.ndarray, +) -> np.ndarray | None: + """This function is used to fill atomic stat with global stat. + + Parameters + ---------- + atomic_stat : Union[np.ndarray, None] + The atomic stat. + global_stat : np.ndarray + The global stat. + if the atomic stat is None, use global stat. + if the atomic stat is not None, but has nan values (missing atypes), fill with global stat. + """ + if atomic_stat is None: + return global_stat + else: + atomic_stat = atomic_stat.reshape(*global_stat.shape) + return np.nan_to_num( + np.where( + np.isnan(atomic_stat) & ~np.isnan(global_stat), global_stat, atomic_stat + ) + ) + + +def _compute_model_predict( + sampled: list[dict], + keys: list[str], + model_forward: Callable, +) -> dict[str, list[np.ndarray]]: + """Compute model predictions for all samples.""" + model_predict = {kk: [] for kk in keys} + for system in sampled: + # Convert inputs to numpy to avoid gradient issues + coord = to_numpy_array(system["coord"]) + atype = to_numpy_array(system["atype"]) + box = to_numpy_array(system["box"]) + fparam = to_numpy_array(system.get("fparam", None)) + aparam = to_numpy_array(system.get("aparam", None)) + + sample_predict = model_forward(coord, atype, box, fparam=fparam, aparam=aparam) + for kk in keys: + model_predict[kk].append( + sample_predict[kk] # already numpy from model_forward + ) + return model_predict + + +def compute_output_stats( + merged: Callable[[], list[dict]] | list[dict], + ntypes: int, + keys: str | list[str], + stat_file_path: DPPath | None = None, + rcond: float | None = None, + preset_bias: dict[str, list[np.ndarray | None]] | None = None, + model_forward: Callable | None = None, + stats_distinguish_types: bool = True, + intensive: bool = False, +) -> tuple[dict, dict]: + """ + Compute the output statistics (e.g. energy bias) for the fitting net from packed data. + + Parameters + ---------- + merged : Union[Callable[[], list[dict]], list[dict]] + - list[dict]: A list of data samples from various data systems. + Each element, `merged[i]`, is a data dictionary containing `keys`: `np.ndarray` + originating from the `i`-th data system. + - Callable[[], list[dict]]: A lazy function that returns data samples in the above format + only when needed. Since the sampling process can be slow and memory-intensive, + the lazy function helps by only sampling once. + ntypes : int + The number of atom types. + keys : Union[str, list[str]] + The keys of the output properties to compute statistics for. + stat_file_path : DPPath, optional + The path to the stat file. + rcond : float, optional + The condition number for the regression of atomic energy. + preset_bias : dict[str, list[Optional[np.ndarray]]], optional + Specifying atomic energy contribution in vacuum. Given by key:value pairs. + The value is a list specifying the bias. the elements can be None or np.ndarray of output shape. + For example: [None, [2.]] means type 0 is not set, type 1 is set to [2.] + The `set_davg_zero` key in the descriptor should be set. + model_forward : Callable, optional + The wrapped forward function of atomic model. + If not None, the model will be utilized to generate the original energy prediction, + which will be subtracted from the energy label of the data. + The difference will then be used to calculate the delta complement energy bias for each type. + stats_distinguish_types : bool, optional + Whether to distinguish different element types in the statistics. + intensive : bool, optional + Whether the fitting target is intensive. + """ + # normalize keys to list + keys = [keys] if isinstance(keys, str) else keys + assert isinstance(keys, list) + + # try to restore the bias from stat file + bias_atom_e, std_atom_e = _restore_from_file(stat_file_path, keys) + + # failed to restore the bias from stat file. compute + if bias_atom_e is None: + # only get data once, sampled is a list of dict[str, np.ndarray] + sampled = merged() if callable(merged) else merged + + # remove the keys that are not in the sample + new_keys = [ + ii + for ii in keys + if (ii in sampled[0].keys()) or ("atom_" + ii in sampled[0].keys()) + ] + keys = new_keys + + # compute model predictions if model_forward is provided + if model_forward is not None: + model_pred = _compute_model_predict(sampled, keys, model_forward) + else: + model_pred = None + + # split system based on label + atomic_sampled_idx = defaultdict(list) + global_sampled_idx = defaultdict(list) + + for kk in keys: + for idx, system in enumerate(sampled): + if (("find_atom_" + kk) in system) and ( + system["find_atom_" + kk] > 0.0 + ): + atomic_sampled_idx[kk].append(idx) + elif (("find_" + kk) in system) and (system["find_" + kk] > 0.0): + global_sampled_idx[kk].append(idx) + else: + continue + + # use index to gather model predictions for the corresponding systems. + model_pred_g = ( + { + kk: [ + np.sum(vv[idx], axis=1) for idx in global_sampled_idx[kk] + ] # sum atomic dim + for kk, vv in model_pred.items() + } + if model_pred + else None + ) + model_pred_a = ( + { + kk: [vv[idx] for idx in atomic_sampled_idx[kk]] + for kk, vv in model_pred.items() + } + if model_pred + else None + ) + + # concat all frames within those systems + model_pred_g = ( + { + kk: np.concatenate(model_pred_g[kk]) + for kk in model_pred_g.keys() + if len(model_pred_g[kk]) > 0 + } + if model_pred + else None + ) + model_pred_a = ( + { + kk: np.concatenate(model_pred_a[kk]) + for kk in model_pred_a.keys() + if len(model_pred_a[kk]) > 0 + } + if model_pred + else None + ) + + # compute stat + bias_atom_g, std_atom_g = compute_output_stats_global( + sampled, + ntypes, + keys, + rcond, + preset_bias, + global_sampled_idx, + stats_distinguish_types, + intensive, + model_pred_g, + ) + bias_atom_a, std_atom_a = compute_output_stats_atomic( + sampled, + ntypes, + keys, + atomic_sampled_idx, + model_pred_a, + ) + + # merge global/atomic bias + bias_atom_e, std_atom_e = {}, {} + for kk in keys: + # use atomic bias whenever available + if kk in bias_atom_a: + bias_atom_e[kk] = bias_atom_a[kk] + std_atom_e[kk] = std_atom_a[kk] + else: + bias_atom_e[kk] = None + std_atom_e[kk] = None + # use global bias to fill missing atomic bias + if kk in bias_atom_g: + bias_atom_e[kk] = _fill_stat_with_global( + bias_atom_e[kk], bias_atom_g[kk] + ) + std_atom_e[kk] = _fill_stat_with_global(std_atom_e[kk], std_atom_g[kk]) + if (bias_atom_e[kk] is None) or (std_atom_e[kk] is None): + raise RuntimeError("Fail to compute stat.") + + if stat_file_path is not None: + _save_to_file(stat_file_path, bias_atom_e, std_atom_e) + + return bias_atom_e, std_atom_e + + +def compute_output_stats_global( + sampled: list[dict], + ntypes: int, + keys: list[str], + rcond: float | None = None, + preset_bias: dict[str, list[np.ndarray | None]] | None = None, + global_sampled_idx: dict | None = None, + stats_distinguish_types: bool = True, + intensive: bool = False, + model_pred: dict[str, np.ndarray] | None = None, +) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: + """This function only handle stat computation from reduced global labels.""" + # return directly if no global samples + if global_sampled_idx is None or all( + len(v) == 0 for v in global_sampled_idx.values() + ): + return {}, {} + + # get label dict from sample; for each key, only picking the system with global labels. + outputs = { + kk: [to_numpy_array(sampled[idx][kk]) for idx in global_sampled_idx.get(kk, [])] + for kk in keys + } + + natoms_key = "natoms" + input_natoms = { + kk: [ + to_numpy_array(sampled[idx][natoms_key]) + for idx in global_sampled_idx.get(kk, []) + ] + for kk in keys + } + + # shape: (nframes, ndim) + merged_output = { + kk: np.concatenate(outputs[kk]) for kk in keys if len(outputs[kk]) > 0 + } + # shape: (nframes, ntypes) + merged_natoms = { + kk: np.concatenate(input_natoms[kk])[:, 2:] + for kk in keys + if len(input_natoms[kk]) > 0 + } + nf = {kk: merged_natoms[kk].shape[0] for kk in keys if kk in merged_natoms} + + if preset_bias is not None: + assigned_atom_ener = { + kk: _make_preset_out_bias(ntypes, preset_bias[kk]) + if kk in preset_bias.keys() + else None + for kk in keys + } + else: + assigned_atom_ener = dict.fromkeys(keys) + + if model_pred is None: + stats_input = merged_output + else: + # subtract the model bias and output the delta bias + stats_input = { + kk: merged_output[kk] - model_pred[kk].reshape(merged_output[kk].shape) + for kk in keys + if kk in merged_output + } + + bias_atom_e = {} + std_atom_e = {} + for kk in keys: + if kk in stats_input: + if not stats_distinguish_types: + bias_atom_e[kk], std_atom_e[kk] = ( + compute_stats_do_not_distinguish_types( + stats_input[kk], + merged_natoms[kk], + assigned_bias=assigned_atom_ener[kk], + intensive=intensive, + ) + ) + else: + bias_atom_e[kk], std_atom_e[kk] = compute_stats_from_redu( + stats_input[kk], + merged_natoms[kk], + assigned_bias=assigned_atom_ener[kk], + rcond=rcond, + ) + else: + # this key does not have global labels, skip it. + continue + bias_atom_e, std_atom_e = _post_process_stat(bias_atom_e, std_atom_e) + + # compute and log rmse + def rmse(x: np.ndarray) -> float: + return np.sqrt(np.mean(np.square(x))) + + if model_pred is None: + unbias_e = { + kk: merged_natoms[kk] @ bias_atom_e[kk].reshape(ntypes, -1) + for kk in bias_atom_e.keys() + } + else: + unbias_e = { + kk: model_pred[kk].reshape(nf[kk], -1) + + merged_natoms[kk] @ bias_atom_e[kk].reshape(ntypes, -1) + for kk in bias_atom_e.keys() + } + atom_numbs = {kk: merged_natoms[kk].sum(-1) for kk in bias_atom_e.keys()} + + for kk in bias_atom_e.keys(): + rmse_ae = rmse( + (unbias_e[kk].reshape(nf[kk], -1) - merged_output[kk].reshape(nf[kk], -1)) + / atom_numbs[kk][:, None] + ) + log.info( + f"RMSE of {kk} per atom after linear regression is: {rmse_ae} in the unit of {kk}." + ) + return bias_atom_e, std_atom_e + + +def compute_output_stats_atomic( + sampled: list[dict], + ntypes: int, + keys: list[str], + atomic_sampled_idx: dict | None = None, + model_pred: dict[str, np.ndarray] | None = None, +) -> tuple[dict[str, np.ndarray], dict[str, np.ndarray]]: + """Compute output statistics from atomic labels.""" + # return directly if no atomic samples + if atomic_sampled_idx is None or all( + len(v) == 0 for v in atomic_sampled_idx.values() + ): + return {}, {} + + # get label dict from sample; for each key, only picking the system with atomic labels. + outputs = { + kk: [ + to_numpy_array(sampled[idx]["atom_" + kk]) + for idx in atomic_sampled_idx.get(kk, []) + ] + for kk in keys + } + natoms = { + kk: [ + to_numpy_array(sampled[idx]["atype"]) + for idx in atomic_sampled_idx.get(kk, []) + ] + for kk in keys + } + + # reshape outputs [nframes, nloc * ndim] --> reshape to [nframes * nloc, 1, ndim] for concatenation + # reshape natoms [nframes, nloc] --> reshape to [nframes * nolc, 1] for concatenation + natoms = {k: [sys_v.reshape(-1, 1) for sys_v in v] for k, v in natoms.items()} + outputs = { + k: [ + sys.reshape(natoms[k][sys_idx].shape[0], 1, -1) + for sys_idx, sys in enumerate(v) + ] + for k, v in outputs.items() + } + + merged_output = { + kk: np.concatenate(outputs[kk]) for kk in keys if len(outputs[kk]) > 0 + } + merged_natoms = { + kk: np.concatenate(natoms[kk]) for kk in keys if len(natoms[kk]) > 0 + } + # reshape merged data to [nf, nloc, ndim] + merged_output = { + kk: merged_output[kk].reshape((*merged_natoms[kk].shape, -1)) + for kk in merged_output + } + + if model_pred is None: + stats_input = merged_output + else: + # subtract the model bias and output the delta bias + stats_input = { + kk: merged_output[kk] - model_pred[kk].reshape(*merged_output[kk].shape) + for kk in keys + if kk in merged_output + } + + bias_atom_e = {} + std_atom_e = {} + + for kk in keys: + if kk in stats_input: + bias_atom_e[kk], std_atom_e[kk] = compute_stats_from_atomic( + stats_input[kk], + merged_natoms[kk], + ) + # correction for missing types + missing_types = ntypes - merged_natoms[kk].max() - 1 + if missing_types > 0: + assert bias_atom_e[kk].dtype is std_atom_e[kk].dtype, ( + "bias and std should be of the same dtypes" + ) + nan_padding = np.empty( + (missing_types, bias_atom_e[kk].shape[1]), + dtype=bias_atom_e[kk].dtype, + ) + nan_padding.fill(np.nan) + bias_atom_e[kk] = np.concatenate([bias_atom_e[kk], nan_padding], axis=0) + std_atom_e[kk] = np.concatenate([std_atom_e[kk], nan_padding], axis=0) + else: + # this key does not have atomic labels, skip it. + continue + return bias_atom_e, std_atom_e From a59c18db8304dffd2a64c4bc569ebc19a58cc345 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 12 Feb 2026 19:27:37 +0800 Subject: [PATCH 42/60] fix test --- .../common/dpmodel/test_fitting_invar_fitting.py | 9 ++++++--- .../pt_expt/fitting/test_fitting_invar_fitting.py | 11 ++++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/source/tests/common/dpmodel/test_fitting_invar_fitting.py b/source/tests/common/dpmodel/test_fitting_invar_fitting.py index d82d3f43d8..c93e46f66b 100644 --- a/source/tests/common/dpmodel/test_fitting_invar_fitting.py +++ b/source/tests/common/dpmodel/test_fitting_invar_fitting.py @@ -142,19 +142,22 @@ def test_self_exception( iap = None with self.assertRaises(ValueError) as context: ret0 = ifn0(dd[0][:, :, :-2], atype, fparam=ifp, aparam=iap) - self.assertIn("input descriptor", context.exception) + self.assertIn("input descriptor", str(context.exception)) if nfp > 0: ifp = rng.normal(size=(self.nf, nfp - 1)) with self.assertRaises(ValueError) as context: ret0 = ifn0(dd[0], atype, fparam=ifp, aparam=iap) - self.assertIn("input fparam", context.exception) + self.assertIn("input fparam", str(context.exception)) if nap > 0: + # restore correct ifp before testing aparam + if nfp > 0: + ifp = rng.normal(size=(self.nf, nfp)) iap = rng.normal(size=(self.nf, self.nloc, nap - 1)) with self.assertRaises(ValueError) as context: ifn0(dd[0], atype, fparam=ifp, aparam=iap) - self.assertIn("input aparam", context.exception) + self.assertIn("input aparam", str(context.exception)) def test_get_set(self) -> None: ifn0 = InvarFitting( diff --git a/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py b/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py index d682b37145..30cbe84401 100644 --- a/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py +++ b/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py @@ -172,7 +172,7 @@ def test_self_exception( fparam=ifp, aparam=iap, ) - self.assertIn("input descriptor", str(context.exception)) + self.assertIn("input descriptor", str(context.exception)) if nfp > 0: ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp - 1))).to( @@ -185,9 +185,14 @@ def test_self_exception( fparam=ifp, aparam=iap, ) - self.assertIn("input fparam", str(context.exception)) + self.assertIn("input fparam", str(context.exception)) if nap > 0: + # restore correct ifp before testing aparam + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to( + self.device + ) iap = torch.from_numpy( rng.normal(size=(self.nf, self.nloc, nap - 1)) ).to(self.device) @@ -198,7 +203,7 @@ def test_self_exception( fparam=ifp, aparam=iap, ) - self.assertIn("input aparam", str(context.exception)) + self.assertIn("input aparam", str(context.exception)) def test_get_set(self) -> None: ifn0 = InvarFitting( From d057ca1d7ffd0972cbe0f63dc979fe64d91ceee6 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 12 Feb 2026 19:27:37 +0800 Subject: [PATCH 43/60] fix test --- .../common/dpmodel/test_fitting_invar_fitting.py | 9 ++++++--- .../pt_expt/fitting/test_fitting_invar_fitting.py | 11 ++++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/source/tests/common/dpmodel/test_fitting_invar_fitting.py b/source/tests/common/dpmodel/test_fitting_invar_fitting.py index d82d3f43d8..c93e46f66b 100644 --- a/source/tests/common/dpmodel/test_fitting_invar_fitting.py +++ b/source/tests/common/dpmodel/test_fitting_invar_fitting.py @@ -142,19 +142,22 @@ def test_self_exception( iap = None with self.assertRaises(ValueError) as context: ret0 = ifn0(dd[0][:, :, :-2], atype, fparam=ifp, aparam=iap) - self.assertIn("input descriptor", context.exception) + self.assertIn("input descriptor", str(context.exception)) if nfp > 0: ifp = rng.normal(size=(self.nf, nfp - 1)) with self.assertRaises(ValueError) as context: ret0 = ifn0(dd[0], atype, fparam=ifp, aparam=iap) - self.assertIn("input fparam", context.exception) + self.assertIn("input fparam", str(context.exception)) if nap > 0: + # restore correct ifp before testing aparam + if nfp > 0: + ifp = rng.normal(size=(self.nf, nfp)) iap = rng.normal(size=(self.nf, self.nloc, nap - 1)) with self.assertRaises(ValueError) as context: ifn0(dd[0], atype, fparam=ifp, aparam=iap) - self.assertIn("input aparam", context.exception) + self.assertIn("input aparam", str(context.exception)) def test_get_set(self) -> None: ifn0 = InvarFitting( diff --git a/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py b/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py index d682b37145..30cbe84401 100644 --- a/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py +++ b/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py @@ -172,7 +172,7 @@ def test_self_exception( fparam=ifp, aparam=iap, ) - self.assertIn("input descriptor", str(context.exception)) + self.assertIn("input descriptor", str(context.exception)) if nfp > 0: ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp - 1))).to( @@ -185,9 +185,14 @@ def test_self_exception( fparam=ifp, aparam=iap, ) - self.assertIn("input fparam", str(context.exception)) + self.assertIn("input fparam", str(context.exception)) if nap > 0: + # restore correct ifp before testing aparam + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to( + self.device + ) iap = torch.from_numpy( rng.normal(size=(self.nf, self.nloc, nap - 1)) ).to(self.device) @@ -198,7 +203,7 @@ def test_self_exception( fparam=ifp, aparam=iap, ) - self.assertIn("input aparam", str(context.exception)) + self.assertIn("input aparam", str(context.exception)) def test_get_set(self) -> None: ifn0 = InvarFitting( From b3d22dadcc948466e54a23d7b61d5356c7c67c53 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 12 Feb 2026 19:27:37 +0800 Subject: [PATCH 44/60] fix test --- .../common/dpmodel/test_fitting_invar_fitting.py | 9 ++++++--- .../pt_expt/fitting/test_fitting_invar_fitting.py | 11 ++++++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/source/tests/common/dpmodel/test_fitting_invar_fitting.py b/source/tests/common/dpmodel/test_fitting_invar_fitting.py index d82d3f43d8..c93e46f66b 100644 --- a/source/tests/common/dpmodel/test_fitting_invar_fitting.py +++ b/source/tests/common/dpmodel/test_fitting_invar_fitting.py @@ -142,19 +142,22 @@ def test_self_exception( iap = None with self.assertRaises(ValueError) as context: ret0 = ifn0(dd[0][:, :, :-2], atype, fparam=ifp, aparam=iap) - self.assertIn("input descriptor", context.exception) + self.assertIn("input descriptor", str(context.exception)) if nfp > 0: ifp = rng.normal(size=(self.nf, nfp - 1)) with self.assertRaises(ValueError) as context: ret0 = ifn0(dd[0], atype, fparam=ifp, aparam=iap) - self.assertIn("input fparam", context.exception) + self.assertIn("input fparam", str(context.exception)) if nap > 0: + # restore correct ifp before testing aparam + if nfp > 0: + ifp = rng.normal(size=(self.nf, nfp)) iap = rng.normal(size=(self.nf, self.nloc, nap - 1)) with self.assertRaises(ValueError) as context: ifn0(dd[0], atype, fparam=ifp, aparam=iap) - self.assertIn("input aparam", context.exception) + self.assertIn("input aparam", str(context.exception)) def test_get_set(self) -> None: ifn0 = InvarFitting( diff --git a/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py b/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py index d682b37145..30cbe84401 100644 --- a/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py +++ b/source/tests/pt_expt/fitting/test_fitting_invar_fitting.py @@ -172,7 +172,7 @@ def test_self_exception( fparam=ifp, aparam=iap, ) - self.assertIn("input descriptor", str(context.exception)) + self.assertIn("input descriptor", str(context.exception)) if nfp > 0: ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp - 1))).to( @@ -185,9 +185,14 @@ def test_self_exception( fparam=ifp, aparam=iap, ) - self.assertIn("input fparam", str(context.exception)) + self.assertIn("input fparam", str(context.exception)) if nap > 0: + # restore correct ifp before testing aparam + if nfp > 0: + ifp = torch.from_numpy(rng.normal(size=(self.nf, nfp))).to( + self.device + ) iap = torch.from_numpy( rng.normal(size=(self.nf, self.nloc, nap - 1)) ).to(self.device) @@ -198,7 +203,7 @@ def test_self_exception( fparam=ifp, aparam=iap, ) - self.assertIn("input aparam", str(context.exception)) + self.assertIn("input aparam", str(context.exception)) def test_get_set(self) -> None: ifn0 = InvarFitting( From a920ef6b40d9dbb2c5ca50b3e51d72fb5d94cde6 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 12 Feb 2026 19:42:04 +0800 Subject: [PATCH 45/60] use torch_module to simplify the def of modules. --- .../pt_expt/atomic_model/dp_atomic_model.py | 36 ++----------------- deepmd/pt_expt/fitting/ener_fitting.py | 26 ++------------ deepmd/pt_expt/fitting/invar_fitting.py | 26 ++------------ deepmd/pt_expt/model/make_model.py | 22 ++---------- 4 files changed, 12 insertions(+), 98 deletions(-) diff --git a/deepmd/pt_expt/atomic_model/dp_atomic_model.py b/deepmd/pt_expt/atomic_model/dp_atomic_model.py index 5c00192661..75604c2d97 100644 --- a/deepmd/pt_expt/atomic_model/dp_atomic_model.py +++ b/deepmd/pt_expt/atomic_model/dp_atomic_model.py @@ -1,18 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) import torch from deepmd.dpmodel.atomic_model.dp_atomic_model import DPAtomicModel as DPAtomicModelDP from deepmd.pt_expt.common import ( - dpmodel_setattr, register_dpmodel_mapping, + torch_module, ) -class DPAtomicModel(DPAtomicModelDP, torch.nn.Module): +@torch_module +class DPAtomicModel(DPAtomicModelDP): # Import at class level to set base classes for deserialization # These will be used by the dpmodel deserialize method to create pt_expt instances from deepmd.pt_expt.descriptor.base_descriptor import ( @@ -25,34 +23,6 @@ class DPAtomicModel(DPAtomicModelDP, torch.nn.Module): base_descriptor_cls = BaseDescriptor base_fitting_cls = BaseFitting - def __init__( - self, descriptor: Any, fitting: Any, *args: Any, **kwargs: Any - ) -> None: - torch.nn.Module.__init__(self) - # Convert descriptor and fitting to pt_expt versions if they are dpmodel instances - # The dpmodel_setattr mechanism will handle this automatically via registry - from deepmd.pt_expt.common import ( - try_convert_module, - ) - - descriptor_pt = try_convert_module(descriptor) - fitting_pt = try_convert_module(fitting) - # If conversion failed (not registered), use original (assume already pt_expt) - if descriptor_pt is None: - descriptor_pt = descriptor - if fitting_pt is None: - fitting_pt = fitting - DPAtomicModelDP.__init__(self, descriptor_pt, fitting_pt, *args, **kwargs) - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. - return torch.nn.Module.__call__(self, *args, **kwargs) - - def __setattr__(self, name: str, value: Any) -> None: - handled, value = dpmodel_setattr(self, name, value) - if not handled: - super().__setattr__(name, value) - def forward( self, extended_coord: torch.Tensor, diff --git a/deepmd/pt_expt/fitting/ener_fitting.py b/deepmd/pt_expt/fitting/ener_fitting.py index 425040ae75..1c91f09526 100644 --- a/deepmd/pt_expt/fitting/ener_fitting.py +++ b/deepmd/pt_expt/fitting/ener_fitting.py @@ -1,17 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) import torch from deepmd.dpmodel.fitting.ener_fitting import EnergyFittingNet as EnergyFittingNetDP from deepmd.pt_expt.common import ( - dpmodel_setattr, register_dpmodel_mapping, -) -from deepmd.pt_expt.utils.network import ( - NetworkCollection, + torch_module, ) from .base_fitting import ( @@ -20,27 +14,13 @@ @BaseFitting.register("ener") -class EnergyFittingNet(EnergyFittingNetDP, torch.nn.Module): +@torch_module +class EnergyFittingNet(EnergyFittingNetDP): """Energy fitting net for pt_expt backend. This inherits from dpmodel EnergyFittingNet to get the correct serialize() method. """ - def __init__(self, *args: Any, **kwargs: Any) -> None: - torch.nn.Module.__init__(self) - EnergyFittingNetDP.__init__(self, *args, **kwargs) - # Convert dpmodel NetworkCollection to pt_expt NetworkCollection - self.nets = NetworkCollection.deserialize(self.nets.serialize()) - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. - return torch.nn.Module.__call__(self, *args, **kwargs) - - def __setattr__(self, name: str, value: Any) -> None: - handled, value = dpmodel_setattr(self, name, value) - if not handled: - super().__setattr__(name, value) - def forward( self, descriptor: torch.Tensor, diff --git a/deepmd/pt_expt/fitting/invar_fitting.py b/deepmd/pt_expt/fitting/invar_fitting.py index aa37026284..640afe232e 100644 --- a/deepmd/pt_expt/fitting/invar_fitting.py +++ b/deepmd/pt_expt/fitting/invar_fitting.py @@ -1,40 +1,20 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from typing import ( - Any, -) import torch from deepmd.dpmodel.fitting.invar_fitting import InvarFitting as InvarFittingDP from deepmd.pt_expt.common import ( - dpmodel_setattr, register_dpmodel_mapping, + torch_module, ) from deepmd.pt_expt.fitting.base_fitting import ( BaseFitting, ) -from deepmd.pt_expt.utils.network import ( - NetworkCollection, -) @BaseFitting.register("invar") -class InvarFitting(InvarFittingDP, torch.nn.Module): - def __init__(self, *args: Any, **kwargs: Any) -> None: - torch.nn.Module.__init__(self) - InvarFittingDP.__init__(self, *args, **kwargs) - # Convert dpmodel NetworkCollection to pt_expt NetworkCollection - self.nets = NetworkCollection.deserialize(self.nets.serialize()) - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. - return torch.nn.Module.__call__(self, *args, **kwargs) - - def __setattr__(self, name: str, value: Any) -> None: - handled, value = dpmodel_setattr(self, name, value) - if not handled: - super().__setattr__(name, value) - +@torch_module +class InvarFitting(InvarFittingDP): def forward( self, descriptor: torch.Tensor, diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py index 05a98982d3..d26733696d 100644 --- a/deepmd/pt_expt/model/make_model.py +++ b/deepmd/pt_expt/model/make_model.py @@ -10,7 +10,7 @@ ) from deepmd.dpmodel.model.make_model import make_model as make_model_dp from deepmd.pt_expt.common import ( - dpmodel_setattr, + torch_module, ) from .transform_output import ( @@ -37,24 +37,8 @@ def make_model(T_AtomicModel: type[BaseAtomicModel]) -> type: """ DPModel = make_model_dp(T_AtomicModel) - class CM(DPModel, torch.nn.Module): - def __init__( - self, - *args: Any, - **kwargs: Any, - ) -> None: - torch.nn.Module.__init__(self) - DPModel.__init__(self, *args, **kwargs) - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. - return torch.nn.Module.__call__(self, *args, **kwargs) - - def __setattr__(self, name: str, value: Any) -> None: - handled, value = dpmodel_setattr(self, name, value) - if not handled: - super().__setattr__(name, value) - + @torch_module + class CM(DPModel): def forward(self, *args: Any, **kwargs: Any) -> dict[str, torch.Tensor]: """Default forward delegates to call(). From 56cbe2dc363a5cfdabf0386e8c1727e71b570f89 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 12 Feb 2026 20:01:05 +0800 Subject: [PATCH 46/60] simplify three autograd to one by vmap, which was made inpossible by jit --- deepmd/pt_expt/model/transform_output.py | 58 +++++++++--------------- 1 file changed, 22 insertions(+), 36 deletions(-) diff --git a/deepmd/pt_expt/model/transform_output.py b/deepmd/pt_expt/model/transform_output.py index 829591983a..5fb1ac4e46 100644 --- a/deepmd/pt_expt/model/transform_output.py +++ b/deepmd/pt_expt/model/transform_output.py @@ -18,47 +18,33 @@ def atomic_virial_corr( atom_energy: torch.Tensor, ) -> torch.Tensor: nall = extended_coord.shape[1] + nf = extended_coord.shape[0] nloc = atom_energy.shape[1] coord, _ = torch.split(extended_coord, [nloc, nall - nloc], dim=1) # no derivative with respect to the loc coord. coord = coord.detach() ce = coord * atom_energy - sumce0, sumce1, sumce2 = torch.split(torch.sum(ce, dim=1), [1, 1, 1], dim=-1) - faked_grad = torch.ones_like(sumce0) - lst: list[torch.Tensor | None] = [faked_grad] - extended_virial_corr0 = torch.autograd.grad( - [sumce0], - [extended_coord], - grad_outputs=lst, - create_graph=False, - retain_graph=True, - )[0] - assert extended_virial_corr0 is not None - extended_virial_corr1 = torch.autograd.grad( - [sumce1], - [extended_coord], - grad_outputs=lst, - create_graph=False, - retain_graph=True, - )[0] - assert extended_virial_corr1 is not None - extended_virial_corr2 = torch.autograd.grad( - [sumce2], - [extended_coord], - grad_outputs=lst, - create_graph=False, - retain_graph=True, - )[0] - assert extended_virial_corr2 is not None - extended_virial_corr = torch.concat( - [ - extended_virial_corr0.unsqueeze(-1), - extended_virial_corr1.unsqueeze(-1), - extended_virial_corr2.unsqueeze(-1), - ], - dim=-1, - ) - return extended_virial_corr + sumce = torch.sum(ce, dim=1) # [nf, 3] + + # Use vmap to batch the 3 backward passes (one per spatial component) + basis = torch.eye(3, dtype=sumce.dtype, device=sumce.device) # [3, 3] + basis = basis.unsqueeze(1).expand(3, nf, 3) # [3, nf, 3] + + def grad_fn(grad_output: torch.Tensor) -> torch.Tensor: + result = torch.autograd.grad( + [sumce], + [extended_coord], + grad_outputs=[grad_output], + create_graph=False, + retain_graph=True, + )[0] + assert result is not None + return result + + # [3, nf, nall, 3] — batched over the 3 spatial components + extended_virial_corr = torch.vmap(grad_fn)(basis) + # [3, nf, nall, 3] -> [nf, nall, 3, 3] + return extended_virial_corr.permute(1, 2, 3, 0) def task_deriv_one( From bcb4008af507c95e322259b1fed8e98a1aac6d3b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 12 Feb 2026 20:55:39 +0800 Subject: [PATCH 47/60] export forward_lower, but not successful --- deepmd/pt_expt/model/ener_model.py | 34 +++++ source/tests/pt_expt/model/test_ener_model.py | 125 +++++++++++++++--- 2 files changed, 142 insertions(+), 17 deletions(-) diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py index 8a68e57551..a9b9792a62 100644 --- a/deepmd/pt_expt/model/ener_model.py +++ b/deepmd/pt_expt/model/ener_model.py @@ -59,3 +59,37 @@ def forward( if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] return model_predict + + def forward_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + model_ret = self.call_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + model_predict = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if self.do_grad_r("energy"): + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + if self.do_grad_c("energy"): + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( + -3 + ) + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict diff --git a/source/tests/pt_expt/model/test_ener_model.py b/source/tests/pt_expt/model/test_ener_model.py index c243fb792d..df7ee5b069 100644 --- a/source/tests/pt_expt/model/test_ener_model.py +++ b/source/tests/pt_expt/model/test_ener_model.py @@ -7,6 +7,13 @@ from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA from deepmd.dpmodel.fitting import InvarFitting as DPInvarFitting from deepmd.dpmodel.model.ener_model import EnergyModel as DPEnergyModel +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) from deepmd.pt_expt.descriptor.se_e2_a import ( DescrptSeA, ) @@ -92,34 +99,106 @@ def test_output_shapes(self) -> None: self.assertEqual(ret["virial"].shape, (1, 9)) @unittest.expectedFailure - def test_exportable(self) -> None: - """Test that EnergyModel can be exported with torch.export. + def test_forward_lower_exportable(self) -> None: + """Test that EnergyModel.forward_lower can be exported with torch.export. - Currently expected to fail because the full model's call() path includes - extend_coord_with_ghosts and neighbor list building, which involve - data-dependent shapes (item() calls) that torch.export cannot trace. - Individual components (descriptor, fitting, atomic model) are exportable. + The full model's forward() is not exportable because extend_coord_with_ghosts + involves data-dependent shapes. forward_lower(), which takes pre-computed + extended coords and neighbor lists, bypasses this. + + However, force/virial computation via torch.autograd.grad is currently + incompatible with torch.export (PyTorch 2.10). During export tracing, + the backward pass captures forward-pass intermediates (e.g. tanh outputs) + as FakeTensor constants instead of reconnecting them to the forward + computation graph. This causes the exported model to fail at runtime. + See https://github.com/pytorch/pytorch/issues/153251 for context. + When PyTorch fixes autograd.grad tracing in torch.export, this test + should be updated to remove @unittest.expectedFailure. """ md = self._make_model() md.eval() - coord = self.coord.clone().requires_grad_(True) - cell = self.cell.reshape(1, 9) - # Test forward pass - ret0 = md(coord, self.atype, cell) + # Prepare extended coords and neighbor list using dpmodel utilities + coord_np = self.coord.detach().cpu().numpy() + atype_np = self.atype.detach().cpu().numpy() + cell_np = self.cell.reshape(1, 9).detach().cpu().numpy() + coord_normalized = normalize_coord( + coord_np.reshape(1, self.natoms, 3), + cell_np.reshape(1, 3, 3), + ) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, + atype_np, + cell_np, + self.rcut, + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + self.natoms, + self.rcut, + self.sel, + distinguish_types=True, + ) + extended_coord = extended_coord.reshape(1, -1, 3) + + # Convert to torch tensors + ext_coord = torch.tensor( + extended_coord, + dtype=torch.float64, + device=self.device, + ).requires_grad_(True) + ext_atype = torch.tensor( + extended_atype, + dtype=torch.int64, + device=self.device, + ) + nlist_t = torch.tensor(nlist, dtype=torch.int64, device=self.device) + mapping_t = torch.tensor(mapping, dtype=torch.int64, device=self.device) + + # Test forward_lower pass with atomic virial + ret0 = md.forward_lower( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + do_atomic_virial=True, + ) self.assertIn("energy", ret0) + self.assertIn("extended_force", ret0) + self.assertIn("virial", ret0) + self.assertIn("extended_virial", ret0) - # Test torch.export + # Export forward_lower via a wrapper module + class ForwardLowerWrapper(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, extended_coord, extended_atype, nlist, mapping): + return self.model.forward_lower( + extended_coord, + extended_atype, + nlist, + mapping, + do_atomic_virial=True, + ) + + wrapper = ForwardLowerWrapper(md) exported = torch.export.export( - md, - (coord, self.atype, cell), + wrapper, + (ext_coord, ext_atype, nlist_t, mapping_t), strict=False, ) self.assertIsNotNone(exported) # Test exported model produces same output - coord2 = self.coord.clone().requires_grad_(True) - ret1 = exported.module()(coord2, self.atype, cell) + ext_coord2 = torch.tensor( + extended_coord, + dtype=torch.float64, + device=self.device, + ).requires_grad_(True) + ret1 = exported.module()(ext_coord2, ext_atype, nlist_t, mapping_t) np.testing.assert_allclose( ret0["energy"].detach().cpu().numpy(), ret1["energy"].detach().cpu().numpy(), @@ -127,8 +206,20 @@ def test_exportable(self) -> None: atol=1e-10, ) np.testing.assert_allclose( - ret0["force"].detach().cpu().numpy(), - ret1["force"].detach().cpu().numpy(), + ret0["extended_force"].detach().cpu().numpy(), + ret1["extended_force"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + ret0["virial"].detach().cpu().numpy(), + ret1["virial"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + ret0["extended_virial"].detach().cpu().numpy(), + ret1["extended_virial"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10, ) From d0e22d2d81c6778322d754a3c57189bd2a49ee1b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 13 Feb 2026 15:57:00 +0800 Subject: [PATCH 48/60] make forward_lower exportable --- deepmd/pt_expt/model/ener_model.py | 54 ++++++++- deepmd/pt_expt/utils/network.py | 112 +++++++++++++++++- source/tests/pt_expt/model/test_ener_model.py | 62 ++++------ 3 files changed, 184 insertions(+), 44 deletions(-) diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py index a9b9792a62..a5a774a6da 100644 --- a/deepmd/pt_expt/model/ener_model.py +++ b/deepmd/pt_expt/model/ener_model.py @@ -4,6 +4,9 @@ ) import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) from deepmd.dpmodel.model.dp_model import ( DPModelCommon, @@ -60,7 +63,7 @@ def forward( model_predict["mask"] = model_ret["mask"] return model_predict - def forward_lower( + def _forward_lower( self, extended_coord: torch.Tensor, extended_atype: torch.Tensor, @@ -93,3 +96,52 @@ def forward_lower( if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] return model_predict + + def forward_lower( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> torch.nn.Module: + """Trace ``_forward_lower`` into an exportable module. + + Uses ``make_fx`` to trace through ``torch.autograd.grad``, + decomposing the backward pass into primitive ops. The returned + module can be passed directly to ``torch.export.export``. + + Parameters + ---------- + extended_coord, extended_atype, nlist, mapping, fparam, aparam, do_atomic_virial + Sample inputs with representative shapes (used for tracing). + + Returns + ------- + torch.nn.Module + A traced module whose ``forward`` accepts + ``(extended_coord, extended_atype, nlist, mapping)`` and + returns a dict with the same keys as ``_forward_lower``. + """ + model = self + + def fn( + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + extended_coord = extended_coord.detach().requires_grad_(True) + return model._forward_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + return make_fx(fn)(extended_coord, extended_atype, nlist, mapping) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 1611ab53d2..929907c2f3 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -26,6 +26,21 @@ class TorchArrayParam(torch.nn.Parameter): + """Parameter subclass that supports ``np.array(param)`` conversion. + + Note: this class is intentionally NOT used for model parameters. + ``make_fx`` (``torch.fx.experimental.proxy_tensor``) uses + ``ProxyTorchDispatchMode`` to intercept tensor operations. When an + operand is a *subclass* of ``torch.Tensor`` (including subclasses of + ``torch.nn.Parameter``), PyTorch invokes the ``__torch_function__`` + protocol which the proxy dispatch mode does not handle, causing + ``aten.mm`` and other ops to fail with "Multiple dispatch failed … + returned NotImplemented". Using plain ``torch.nn.Parameter`` avoids + this because the proxy mode is designed to work with the base + ``Parameter`` type. ``TorchArrayParam`` is kept only for backward + compatibility and should not be used for new code. + """ + def __new__( # noqa: PYI034 cls, data: Any = None, requires_grad: bool = True ) -> "TorchArrayParam": @@ -40,6 +55,31 @@ def __array__(self, dtype: Any | None = None) -> np.ndarray: # do not apply torch_module until its setattr working to register parameters class NativeLayer(NativeLayerDP, torch.nn.Module): + """PyTorch layer wrapping dpmodel's ``NativeLayer``. + + Two aspects of the inherited dpmodel ``call()`` are incompatible with + ``make_fx`` tracing (used to export ``forward_lower`` with + ``autograd.grad``-based force/virial computation): + + 1. **Ellipsis indexing** (``self.w[...]``): On a ``torch.Tensor`` + this triggers ``aten.alias``, an op that ``ProxyTorchDispatchMode`` + does not support, resulting in "Multiple dispatch failed for + ``aten.alias.default``". + 2. **``array_api_compat`` wrappers** (``xp = array_api_compat + .array_namespace(x); xp.matmul(…)``): The wrappers re-enter + ``torch.matmul`` through Python, which goes through the + ``__torch_function__`` protocol. Under the proxy dispatch mode + this path also fails with "Multiple dispatch failed". + + This class therefore overrides ``call()`` with an implementation that + uses plain ``torch`` ops exclusively (``torch.matmul``, ``torch.tanh``, + etc.), avoiding both issues. + + Trainable weights are stored as plain ``torch.nn.Parameter`` (not + ``TorchArrayParam``) for the same ``make_fx`` compatibility reason — + see the ``TorchArrayParam`` docstring. + """ + def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) NativeLayerDP.__init__(self, *args, **kwargs) @@ -61,8 +101,8 @@ def __setattr__(self, name: str, value: Any) -> None: if getattr(self, "trainable", False): param = ( value - if isinstance(value, TorchArrayParam) - else TorchArrayParam(val, requires_grad=True) + if isinstance(value, torch.nn.Parameter) + else torch.nn.Parameter(val, requires_grad=True) ) if name in self._parameters: self._parameters[name] = param @@ -76,10 +116,78 @@ def __setattr__(self, name: str, value: Any) -> None: return return super().__setattr__(name, value) + def call(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass using pure torch ops. + + Overrides dpmodel's ``call()`` to ensure compatibility with + ``make_fx`` (``torch.fx.experimental.proxy_tensor``). + + The dpmodel implementation uses ``self.w[...]`` and + ``array_api_compat.array_namespace(x).matmul(…)`` for + backend-agnostic array operations. Both patterns break under + ``make_fx``'s ``ProxyTorchDispatchMode``: + + - ``self.w[...]`` emits ``aten.alias`` which the proxy mode + cannot dispatch. + - ``array_api_compat`` re-enters ``torch.matmul`` via Python, + hitting ``__torch_function__`` which the proxy mode returns + ``NotImplemented`` for. + + This override uses ``torch.matmul``, ``torch.cat``, and + ``_torch_activation`` directly, sidestepping both issues. + """ + if self.w is None or self.activation_function is None: + raise ValueError("w, b, and activation_function must be set") + y = ( + torch.matmul(x, self.w) + self.b + if self.b is not None + else torch.matmul(x, self.w) + ) + if y.dtype != x.dtype: + y = y.to(x.dtype) + y = _torch_activation(y, self.activation_function) + if self.idt is not None: + y = y * self.idt + if self.resnet and self.w.shape[1] == self.w.shape[0]: + y = y + x + elif self.resnet and self.w.shape[1] == 2 * self.w.shape[0]: + y = y + torch.cat([x, x], dim=-1) + return y + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.call(x) +def _torch_activation(x: torch.Tensor, name: str) -> torch.Tensor: + """Apply activation function using native torch ops. + + The dpmodel ``get_activation_fn`` returns closures that call + ``array_api_compat.array_namespace(x).tanh(x)`` etc. Under + ``make_fx`` proxy tracing, the ``array_api_compat`` indirection + triggers ``__torch_function__`` dispatch failures. This function + calls ``torch.tanh`` and friends directly to avoid the issue. + """ + name = name.lower() + if name == "tanh": + return torch.tanh(x) + elif name == "relu": + return torch.relu(x) + elif name in ("gelu", "gelu_tf"): + return torch.nn.functional.gelu(x, approximate="tanh") + elif name == "relu6": + return torch.clamp(x, min=0.0, max=6.0) + elif name == "softplus": + return torch.nn.functional.softplus(x) + elif name == "sigmoid": + return torch.sigmoid(x) + elif name == "silu": + return torch.nn.functional.silu(x) + elif name in ("none", "linear"): + return x + else: + raise NotImplementedError(name) + + @torch_module class NativeNet(make_multilayer_network(NativeLayer, NativeOP)): def __init__(self, layers: list[dict] | None = None) -> None: diff --git a/source/tests/pt_expt/model/test_ener_model.py b/source/tests/pt_expt/model/test_ener_model.py index df7ee5b069..7ea8f96bfc 100644 --- a/source/tests/pt_expt/model/test_ener_model.py +++ b/source/tests/pt_expt/model/test_ener_model.py @@ -98,22 +98,12 @@ def test_output_shapes(self) -> None: self.assertEqual(ret["force"].shape, (1, self.natoms, 3)) self.assertEqual(ret["virial"].shape, (1, 9)) - @unittest.expectedFailure def test_forward_lower_exportable(self) -> None: - """Test that EnergyModel.forward_lower can be exported with torch.export. + """Test that EnergyModel.forward_lower returns an exportable module. - The full model's forward() is not exportable because extend_coord_with_ghosts - involves data-dependent shapes. forward_lower(), which takes pre-computed - extended coords and neighbor lists, bypasses this. - - However, force/virial computation via torch.autograd.grad is currently - incompatible with torch.export (PyTorch 2.10). During export tracing, - the backward pass captures forward-pass intermediates (e.g. tanh outputs) - as FakeTensor constants instead of reconnecting them to the forward - computation graph. This causes the exported model to fail at runtime. - See https://github.com/pytorch/pytorch/issues/153251 for context. - When PyTorch fixes autograd.grad tracing in torch.export, this test - should be updated to remove @unittest.expectedFailure. + forward_lower() uses make_fx to trace through torch.autograd.grad, + decomposing the backward pass into primitive ops. The returned module + can be passed directly to torch.export.export. """ md = self._make_model() md.eval() @@ -147,7 +137,7 @@ def test_forward_lower_exportable(self) -> None: extended_coord, dtype=torch.float64, device=self.device, - ).requires_grad_(True) + ) ext_atype = torch.tensor( extended_atype, dtype=torch.int64, @@ -156,9 +146,9 @@ def test_forward_lower_exportable(self) -> None: nlist_t = torch.tensor(nlist, dtype=torch.int64, device=self.device) mapping_t = torch.tensor(mapping, dtype=torch.int64, device=self.device) - # Test forward_lower pass with atomic virial - ret0 = md.forward_lower( - ext_coord, + # Eager reference via _forward_lower + ret0 = md._forward_lower( + ext_coord.requires_grad_(True), ext_atype, nlist_t, mapping_t, @@ -169,36 +159,26 @@ def test_forward_lower_exportable(self) -> None: self.assertIn("virial", ret0) self.assertIn("extended_virial", ret0) - # Export forward_lower via a wrapper module - class ForwardLowerWrapper(torch.nn.Module): - def __init__(self, model): - super().__init__() - self.model = model - - def forward(self, extended_coord, extended_atype, nlist, mapping): - return self.model.forward_lower( - extended_coord, - extended_atype, - nlist, - mapping, - do_atomic_virial=True, - ) + # forward_lower returns a traced module + traced = md.forward_lower( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + do_atomic_virial=True, + ) + self.assertIsInstance(traced, torch.nn.Module) - wrapper = ForwardLowerWrapper(md) + # The traced module should be directly exportable exported = torch.export.export( - wrapper, + traced, (ext_coord, ext_atype, nlist_t, mapping_t), strict=False, ) self.assertIsNotNone(exported) - # Test exported model produces same output - ext_coord2 = torch.tensor( - extended_coord, - dtype=torch.float64, - device=self.device, - ).requires_grad_(True) - ret1 = exported.module()(ext_coord2, ext_atype, nlist_t, mapping_t) + # Verify exported model produces same output + ret1 = exported.module()(ext_coord, ext_atype, nlist_t, mapping_t) np.testing.assert_allclose( ret0["energy"].detach().cpu().numpy(), ret1["energy"].detach().cpu().numpy(), From 1f7bb6c1ebd49ff95d4b49e37dd36977fb37e5db Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 13 Feb 2026 20:48:30 +0800 Subject: [PATCH 49/60] implement all EnergyModel APIs in pt but not in dpmodel. add a comprehensive UT to make sure the consistency between pt and dpmodel --- .../dpmodel/atomic_model/dp_atomic_model.py | 34 + deepmd/dpmodel/fitting/general_fitting.py | 31 +- deepmd/dpmodel/model/dp_model.py | 23 + deepmd/dpmodel/model/ener_model.py | 25 + deepmd/dpmodel/model/make_model.py | 41 +- deepmd/pd/model/model/make_model.py | 12 +- deepmd/pt/model/model/make_model.py | 12 +- source/tests/consistent/model/test_ener.py | 639 ++++++++++++++++++ 8 files changed, 797 insertions(+), 20 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 0f5b12bc9c..86e13e14df 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -3,6 +3,8 @@ Any, ) +import array_api_compat + from deepmd.dpmodel.array_api import ( Array, ) @@ -54,6 +56,10 @@ def __init__( if hasattr(self.fitting, "reinit_exclude"): self.fitting.reinit_exclude(self.atom_exclude_types) self.type_map = type_map + self.enable_eval_descriptor_hook = False + self.enable_eval_fitting_last_layer_hook = False + self.eval_descriptor_list: list[Array] = [] + self.eval_fitting_last_layer_list: list[Array] = [] super().init_out_stat() def fitting_output_def(self) -> FittingOutputDef: @@ -126,6 +132,27 @@ def enable_compression( check_frequency, ) + def set_eval_descriptor_hook(self, enable: bool) -> None: + """Set the hook for evaluating descriptor and clear the cache.""" + self.enable_eval_descriptor_hook = enable + self.eval_descriptor_list.clear() + + def eval_descriptor(self) -> Array: + """Evaluate the descriptor by concatenating cached results.""" + xp = array_api_compat.array_namespace(self.eval_descriptor_list[0]) + return xp.concat(self.eval_descriptor_list, axis=0) + + def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: + """Set the hook for evaluating fitting last layer output and clear the cache.""" + self.enable_eval_fitting_last_layer_hook = enable + self.fitting.set_return_middle_output(enable) + self.eval_fitting_last_layer_list.clear() + + def eval_fitting_last_layer(self) -> Array: + """Evaluate the fitting last layer output by concatenating cached results.""" + xp = array_api_compat.array_namespace(self.eval_fitting_last_layer_list[0]) + return xp.concat(self.eval_fitting_last_layer_list, axis=0) + def forward_atomic( self, extended_coord: Array, @@ -166,6 +193,8 @@ def forward_atomic( nlist, mapping=mapping, ) + if self.enable_eval_descriptor_hook: + self.eval_descriptor_list.append(descriptor) ret = self.fitting( descriptor, atype, @@ -175,6 +204,11 @@ def forward_atomic( fparam=fparam, aparam=aparam, ) + if self.enable_eval_fitting_last_layer_hook: + assert "middle_output" in ret, ( + "eval_fitting_last_layer not supported for this fitting net!" + ) + self.eval_fitting_last_layer_list.append(ret.pop("middle_output")) return ret def change_type_map( diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index fabc39ae96..cedc0eb916 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -195,6 +195,7 @@ def __init__( self.default_fparam_tensor = np.array(self.default_fparam, dtype=self.prec) else: self.default_fparam_tensor = None + self.eval_return_middle_output = False # init networks in_dim = ( self.dim_descrpt @@ -410,6 +411,10 @@ def __getitem__(self, key: str) -> Any: else: raise KeyError(key) + def set_return_middle_output(self, return_middle_output: bool = True) -> None: + """Set whether to return the output of the last hidden layer.""" + self.eval_return_middle_output = return_middle_output + def reinit_exclude( self, exclude_types: list[int] = [], @@ -584,12 +589,19 @@ def _call_common( ) # calculate the prediction + results: dict[str, Array] = {} if not self.mixed_types: outs = xp.zeros( [nf, nloc, net_dim_out], dtype=get_xp_precision(xp, self.precision), device=array_api_compat.device(descriptor), ) + if self.eval_return_middle_output: + outs_middle = xp.zeros( + [nf, nloc, self.neuron[-1]], + dtype=get_xp_precision(xp, self.precision), + device=array_api_compat.device(descriptor), + ) for type_i in range(self.ntypes): mask = xp.tile( xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out) @@ -605,10 +617,26 @@ def _call_common( mask, atom_property, xp.zeros_like(atom_property) ) outs = outs + atom_property # Shape is [nframes, natoms[0], 1] + if self.eval_return_middle_output: + middle_output_type = self.nets[(type_i,)].call_until_last(xx) + middle_mask = xp.tile( + xp.reshape((atype == type_i), (nf, nloc, 1)), + (1, 1, self.neuron[-1]), + ) + middle_output_type = xp.where( + middle_mask, + middle_output_type, + xp.zeros_like(middle_output_type), + ) + outs_middle = outs_middle + middle_output_type + if self.eval_return_middle_output: + results["middle_output"] = outs_middle else: outs = self.nets[()](xx) if xx_zeros is not None: outs -= self.nets[()](xx_zeros) + if self.eval_return_middle_output: + results["middle_output"] = self.nets[()].call_until_last(xx) outs += xp.reshape( xp.take( xp.astype(self.bias_atom_e[...], outs.dtype), @@ -622,4 +650,5 @@ def _call_common( exclude_mask = xp.astype(exclude_mask, xp.bool) # nf x nloc x nod outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs)) - return {self.var_name: outs} + results[self.var_name] = outs + return results diff --git a/deepmd/dpmodel/model/dp_model.py b/deepmd/dpmodel/model/dp_model.py index 063533f2a7..fcf7c88b86 100644 --- a/deepmd/dpmodel/model/dp_model.py +++ b/deepmd/dpmodel/model/dp_model.py @@ -1,6 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from deepmd.dpmodel.array_api import ( + Array, +) from deepmd.dpmodel.descriptor.base_descriptor import ( BaseDescriptor, ) @@ -48,3 +51,23 @@ def update_sel( def get_fitting_net(self) -> BaseFitting: """Get the fitting network.""" return self.atomic_model.fitting + + def get_descriptor(self) -> BaseDescriptor: + """Get the descriptor.""" + return self.atomic_model.descriptor + + def set_eval_descriptor_hook(self, enable: bool) -> None: + """Set the hook for evaluating descriptor.""" + self.atomic_model.set_eval_descriptor_hook(enable) + + def eval_descriptor(self) -> Array: + """Evaluate the descriptor.""" + return self.atomic_model.eval_descriptor() + + def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: + """Set the hook for evaluating fitting last layer output.""" + self.atomic_model.set_eval_fitting_last_layer_hook(enable) + + def eval_fitting_last_layer(self) -> Array: + """Evaluate the fitting last layer output.""" + return self.atomic_model.eval_fitting_last_layer() diff --git a/deepmd/dpmodel/model/ener_model.py b/deepmd/dpmodel/model/ener_model.py index 9d38a17513..ef1e986f26 100644 --- a/deepmd/dpmodel/model/ener_model.py +++ b/deepmd/dpmodel/model/ener_model.py @@ -47,3 +47,28 @@ def atomic_output_def(self) -> FittingOutputDef: if self._enable_hessian: return self.hess_fitting_def return super().atomic_output_def() + + def translated_output_def(self) -> dict[str, Any]: + """Get the translated output definition. + + Maps internal output names to user-facing names, e.g. + ``energy_redu`` -> ``energy``, ``energy_derv_r`` -> ``force``. + """ + out_def_data = self.model_output_def().get_data() + output_def = { + "atom_energy": out_def_data["energy"], + "energy": out_def_data["energy_redu"], + } + if self.do_grad_r("energy"): + output_def["force"] = out_def_data["energy_derv_r"] + output_def["force"].squeeze(-2) + if self.do_grad_c("energy"): + output_def["virial"] = out_def_data["energy_derv_c_redu"] + output_def["virial"].squeeze(-2) + output_def["atom_virial"] = out_def_data["energy_derv_c"] + output_def["atom_virial"].squeeze(-3) + if "mask" in out_def_data: + output_def["mask"] = out_def_data["mask"] + if self._enable_hessian: + output_def["hessian"] = out_def_data["energy_derv_r_derv_r"] + return output_def diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index 0a77549ca4..e115478df5 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -257,7 +257,7 @@ def call( The keys are defined by the `ModelOutputDef`. """ - cc, bb, fp, ap, input_prec = self.input_type_cast( + cc, bb, fp, ap, input_prec = self._input_type_cast( coord, box=box, fparam=fparam, aparam=aparam ) del coord, box, fparam, aparam @@ -274,7 +274,7 @@ def call( aparam=ap, do_atomic_virial=do_atomic_virial, ) - model_predict = self.output_type_cast(model_predict, input_prec) + model_predict = self._output_type_cast(model_predict, input_prec) return model_predict def call_lower( @@ -323,7 +323,7 @@ def call_lower( nlist, extra_nlist_sort=self.need_sorted_nlist_for_lower(), ) - cc_ext, _, fp, ap, input_prec = self.input_type_cast( + cc_ext, _, fp, ap, input_prec = self._input_type_cast( extended_coord, fparam=fparam, aparam=aparam ) del extended_coord, fparam, aparam @@ -336,7 +336,7 @@ def call_lower( aparam=ap, do_atomic_virial=do_atomic_virial, ) - model_predict = self.output_type_cast(model_predict, input_prec) + model_predict = self._output_type_cast(model_predict, input_prec) return model_predict def forward_common_atomic( @@ -366,8 +366,35 @@ def forward_common_atomic( ) forward_lower = call_lower + forward_common = call + forward_common_lower = call_lower - def input_type_cast( + def get_out_bias(self) -> Array: + """Get the output bias.""" + return self.atomic_model.out_bias + + def set_out_bias(self, out_bias: Array) -> None: + """Set the output bias.""" + self.atomic_model.out_bias = out_bias + + def change_out_bias( + self, + merged: Any, + bias_adjust_mode: str = "change-by-statistic", + ) -> None: + """Change the output bias according to the input data and the pretrained model. + + Parameters + ---------- + merged + The merged data samples. + bias_adjust_mode : str + The mode for changing output bias: + 'change-by-statistic' or 'set-by-statistic'. + """ + self.atomic_model.change_out_bias(merged, bias_adjust_mode=bias_adjust_mode) + + def _input_type_cast( self, coord: Array, box: Array | None = None, @@ -399,7 +426,7 @@ def input_type_cast( input_dtype, ) - def output_type_cast( + def _output_type_cast( self, model_ret: dict[str, Array], input_prec: Any, @@ -411,7 +438,7 @@ def output_type_cast( model_ret The model output. input_prec - The input dtype returned by ``input_type_cast``. + The input dtype returned by ``_input_type_cast``. """ model_ret_not_none = [vv for vv in model_ret.values() if vv is not None] if not model_ret_not_none: diff --git a/deepmd/pd/model/model/make_model.py b/deepmd/pd/model/model/make_model.py index 72811c9e1c..321c939061 100644 --- a/deepmd/pd/model/model/make_model.py +++ b/deepmd/pd/model/model/make_model.py @@ -162,7 +162,7 @@ def forward_common( The keys are defined by the `ModelOutputDef`. """ - cc, bb, fp, ap, input_prec = self.input_type_cast( + cc, bb, fp, ap, input_prec = self._input_type_cast( coord, box=box, fparam=fparam, aparam=aparam ) del coord, box, fparam, aparam @@ -196,7 +196,7 @@ def forward_common( mapping, do_atomic_virial=do_atomic_virial, ) - model_predict = self.output_type_cast(model_predict, input_prec) + model_predict = self._output_type_cast(model_predict, input_prec) return model_predict def get_out_bias(self) -> paddle.Tensor: @@ -283,7 +283,7 @@ def forward_common_lower( nlist = self.format_nlist( extended_coord, extended_atype, nlist, extra_nlist_sort=extra_nlist_sort ) - cc_ext, _, fp, ap, input_prec = self.input_type_cast( + cc_ext, _, fp, ap, input_prec = self._input_type_cast( extended_coord, fparam=fparam, aparam=aparam ) del extended_coord, fparam, aparam @@ -303,10 +303,10 @@ def forward_common_lower( do_atomic_virial=do_atomic_virial, create_graph=self.training, ) - model_predict = self.output_type_cast(model_predict, input_prec) + model_predict = self._output_type_cast(model_predict, input_prec) return model_predict - def input_type_cast( + def _input_type_cast( self, coord: paddle.Tensor, box: paddle.Tensor | None = None, @@ -351,7 +351,7 @@ def input_type_cast( input_prec, ) - def output_type_cast( + def _output_type_cast( self, model_ret: dict[str, paddle.Tensor], input_prec: str, diff --git a/deepmd/pt/model/model/make_model.py b/deepmd/pt/model/model/make_model.py index c958a62bf6..87a1d6b9c5 100644 --- a/deepmd/pt/model/model/make_model.py +++ b/deepmd/pt/model/model/make_model.py @@ -164,7 +164,7 @@ def forward_common( The keys are defined by the `ModelOutputDef`. """ - cc, bb, fp, ap, input_prec = self.input_type_cast( + cc, bb, fp, ap, input_prec = self._input_type_cast( coord, box=box, fparam=fparam, aparam=aparam ) del coord, box, fparam, aparam @@ -198,7 +198,7 @@ def forward_common( mapping, do_atomic_virial=do_atomic_virial, ) - model_predict = self.output_type_cast(model_predict, input_prec) + model_predict = self._output_type_cast(model_predict, input_prec) return model_predict def get_out_bias(self) -> torch.Tensor: @@ -285,7 +285,7 @@ def forward_common_lower( nlist = self.format_nlist( extended_coord, extended_atype, nlist, extra_nlist_sort=extra_nlist_sort ) - cc_ext, _, fp, ap, input_prec = self.input_type_cast( + cc_ext, _, fp, ap, input_prec = self._input_type_cast( extended_coord, fparam=fparam, aparam=aparam ) del extended_coord, fparam, aparam @@ -306,10 +306,10 @@ def forward_common_lower( create_graph=self.training, mask=atomic_ret["mask"] if "mask" in atomic_ret else None, ) - model_predict = self.output_type_cast(model_predict, input_prec) + model_predict = self._output_type_cast(model_predict, input_prec) return model_predict - def input_type_cast( + def _input_type_cast( self, coord: torch.Tensor, box: torch.Tensor | None = None, @@ -354,7 +354,7 @@ def input_type_cast( input_prec, ) - def output_type_cast( + def _output_type_cast( self, model_ret: dict[str, torch.Tensor], input_prec: str, diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index c1ee630516..dd739b11d5 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -576,3 +576,642 @@ def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: ret["extended_virial"].flatten(), ) raise ValueError(f"Unknown backend: {backend}") + + +@unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed") +class TestEnerModelAPIs(unittest.TestCase): + """Test consistency of model-level APIs between pt and dpmodel backends. + + Both models are constructed from the same serialized weights + (dpmodel -> serialize -> pt deserialize) so that numerical outputs + can be compared directly. + """ + + def setUp(self) -> None: + from deepmd.utils.argcheck import ( + model_args, + ) + + data = model_args().normalize_value( + { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": [20, 20], + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [3, 6], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "type_one_side": True, + "seed": 1, + }, + "fitting_net": { + "neuron": [5, 5], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + }, + trim_pattern="_*", + ) + # Build dpmodel first, then deserialize into pt to share weights + self.dp_model = get_model_dp(data) + serialized = self.dp_model.serialize() + self.pt_model = EnergyModelPT.deserialize(serialized) + + # Coords / atype / box + self.coords = np.array( + [ + 12.83, + 2.56, + 2.18, + 12.09, + 2.87, + 2.74, + 00.25, + 3.32, + 1.68, + 3.36, + 3.00, + 1.81, + 3.51, + 2.51, + 2.60, + 4.27, + 3.22, + 1.56, + ], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, -1, 3) + self.atype = np.array([0, 1, 1, 0, 1, 1], dtype=np.int32).reshape(1, -1) + self.box = np.array( + [13.0, 0.0, 0.0, 0.0, 13.0, 0.0, 0.0, 0.0, 13.0], + dtype=GLOBAL_NP_FLOAT_PRECISION, + ).reshape(1, 9) + + # Build extended coords + nlist for lower-level calls + rcut = 6.0 + nframes, nloc = self.atype.shape[:2] + coord_normalized = normalize_coord( + self.coords.reshape(nframes, nloc, 3), + self.box.reshape(nframes, 3, 3), + ) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, self.atype, self.box, rcut + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + [20, 20], + distinguish_types=True, + ) + self.extended_coord = extended_coord.reshape(nframes, -1, 3) + self.extended_atype = extended_atype + self.mapping = mapping + self.nlist = nlist + + def test_translated_output_def(self) -> None: + """translated_output_def should return the same keys on dp and pt.""" + dp_def = self.dp_model.translated_output_def() + pt_def = self.pt_model.translated_output_def() + self.assertEqual(set(dp_def.keys()), set(pt_def.keys())) + for key in dp_def: + self.assertEqual(dp_def[key].shape, pt_def[key].shape) + + def test_get_descriptor(self) -> None: + """get_descriptor should return a non-None object on both backends.""" + self.assertIsNotNone(self.dp_model.get_descriptor()) + self.assertIsNotNone(self.pt_model.get_descriptor()) + + def test_get_fitting_net(self) -> None: + """get_fitting_net should return a non-None object on both backends.""" + self.assertIsNotNone(self.dp_model.get_fitting_net()) + self.assertIsNotNone(self.pt_model.get_fitting_net()) + + def test_get_out_bias(self) -> None: + """get_out_bias should return numerically equal values on dp and pt. + + Freshly constructed models have zero bias; the shape (n_output x ntypes x odim) + is verified. Non-zero bias round-trip is covered by test_set_out_bias. + """ + dp_bias = to_numpy_array(self.dp_model.get_out_bias()) + pt_bias = torch_to_numpy(self.pt_model.get_out_bias()) + np.testing.assert_allclose(dp_bias, pt_bias, rtol=1e-10, atol=1e-10) + # Verify shape is sensible (n_output_keys x ntypes x odim) + self.assertEqual(dp_bias.shape[1], 2) # ntypes + self.assertGreater(dp_bias.shape[0], 0) # at least one output key + + def test_set_out_bias(self) -> None: + """set_out_bias should update the bias on both backends.""" + dp_bias = to_numpy_array(self.dp_model.get_out_bias()) + new_bias = dp_bias + 1.0 + # dp + self.dp_model.set_out_bias(new_bias) + np.testing.assert_allclose( + to_numpy_array(self.dp_model.get_out_bias()), + new_bias, + rtol=1e-10, + atol=1e-10, + ) + # pt + self.pt_model.set_out_bias(numpy_to_torch(new_bias)) + np.testing.assert_allclose( + torch_to_numpy(self.pt_model.get_out_bias()), + new_bias, + rtol=1e-10, + atol=1e-10, + ) + + def test_forward_common_alias(self) -> None: + """forward_common should be the same as call on dpmodel.""" + ret_call = self.dp_model.call( + self.coords, + self.atype, + box=self.box, + ) + ret_fc = self.dp_model.forward_common( + self.coords, + self.atype, + box=self.box, + ) + for key in ret_call: + np.testing.assert_equal(ret_call[key], ret_fc[key]) + + def test_forward_common_lower_alias(self) -> None: + """forward_common_lower should be the same as call_lower on dpmodel.""" + ret_call = self.dp_model.call_lower( + self.extended_coord, + self.extended_atype, + self.nlist, + self.mapping, + ) + ret_fc = self.dp_model.forward_common_lower( + self.extended_coord, + self.extended_atype, + self.nlist, + self.mapping, + ) + for key in ret_call: + np.testing.assert_equal(ret_call[key], ret_fc[key]) + + def test_eval_descriptor(self) -> None: + """eval_descriptor should produce consistent results across dp and pt.""" + # dpmodel + self.dp_model.set_eval_descriptor_hook(True) + self.dp_model.call_lower( + self.extended_coord, + self.extended_atype, + self.nlist, + self.mapping, + ) + dp_desc = self.dp_model.eval_descriptor() + + # pt + self.pt_model.set_eval_descriptor_hook(True) + self.pt_model.forward_common_lower( + numpy_to_torch(self.extended_coord), + numpy_to_torch(self.extended_atype), + numpy_to_torch(self.nlist), + numpy_to_torch(self.mapping), + ) + pt_desc = torch_to_numpy(self.pt_model.eval_descriptor()) + + np.testing.assert_allclose(dp_desc, pt_desc, rtol=1e-10, atol=1e-10) + + def test_eval_fitting_last_layer(self) -> None: + """eval_fitting_last_layer should produce consistent results across dp and pt.""" + # dpmodel + self.dp_model.set_eval_fitting_last_layer_hook(True) + self.dp_model.call_lower( + self.extended_coord, + self.extended_atype, + self.nlist, + self.mapping, + ) + dp_fl = self.dp_model.eval_fitting_last_layer() + + # pt + self.pt_model.set_eval_fitting_last_layer_hook(True) + self.pt_model.forward_common_lower( + numpy_to_torch(self.extended_coord), + numpy_to_torch(self.extended_atype), + numpy_to_torch(self.nlist), + numpy_to_torch(self.mapping), + ) + pt_fl = torch_to_numpy(self.pt_model.eval_fitting_last_layer()) + + np.testing.assert_allclose(dp_fl, pt_fl, rtol=1e-10, atol=1e-10) + + def test_model_output_def(self) -> None: + """model_output_def should return the same keys and shapes on dp and pt.""" + dp_def = self.dp_model.model_output_def().get_data() + pt_def = self.pt_model.model_output_def().get_data() + self.assertEqual(set(dp_def.keys()), set(pt_def.keys())) + for key in dp_def: + self.assertEqual(dp_def[key].shape, pt_def[key].shape) + + def test_model_output_type(self) -> None: + """model_output_type should return the same list on dp and pt.""" + self.assertEqual( + self.dp_model.model_output_type(), + self.pt_model.model_output_type(), + ) + + def test_do_grad_r(self) -> None: + """do_grad_r should return the same value on dp and pt.""" + self.assertEqual( + self.dp_model.do_grad_r("energy"), + self.pt_model.do_grad_r("energy"), + ) + self.assertTrue(self.dp_model.do_grad_r("energy")) + + def test_do_grad_c(self) -> None: + """do_grad_c should return the same value on dp and pt.""" + self.assertEqual( + self.dp_model.do_grad_c("energy"), + self.pt_model.do_grad_c("energy"), + ) + self.assertTrue(self.dp_model.do_grad_c("energy")) + + def test_get_rcut(self) -> None: + """get_rcut should return the same value on dp and pt.""" + self.assertEqual(self.dp_model.get_rcut(), self.pt_model.get_rcut()) + self.assertAlmostEqual(self.dp_model.get_rcut(), 6.0) + + def test_get_type_map(self) -> None: + """get_type_map should return the same list on dp and pt.""" + self.assertEqual(self.dp_model.get_type_map(), self.pt_model.get_type_map()) + self.assertEqual(self.dp_model.get_type_map(), ["O", "H"]) + + def test_get_sel(self) -> None: + """get_sel should return the same list on dp and pt.""" + self.assertEqual(self.dp_model.get_sel(), self.pt_model.get_sel()) + self.assertEqual(self.dp_model.get_sel(), [20, 20]) + + def test_get_nsel(self) -> None: + """get_nsel should return the same value on dp and pt.""" + self.assertEqual(self.dp_model.get_nsel(), self.pt_model.get_nsel()) + self.assertEqual(self.dp_model.get_nsel(), 40) + + def test_get_nnei(self) -> None: + """get_nnei should return the same value on dp and pt.""" + self.assertEqual(self.dp_model.get_nnei(), self.pt_model.get_nnei()) + self.assertEqual(self.dp_model.get_nnei(), 40) + + def test_mixed_types(self) -> None: + """mixed_types should return the same value on dp and pt.""" + self.assertEqual(self.dp_model.mixed_types(), self.pt_model.mixed_types()) + # se_e2_a is not mixed-types + self.assertFalse(self.dp_model.mixed_types()) + + def test_has_message_passing(self) -> None: + """has_message_passing should return the same value on dp and pt.""" + self.assertEqual( + self.dp_model.has_message_passing(), + self.pt_model.has_message_passing(), + ) + self.assertFalse(self.dp_model.has_message_passing()) + + def test_need_sorted_nlist_for_lower(self) -> None: + """need_sorted_nlist_for_lower should return the same value on dp and pt.""" + self.assertEqual( + self.dp_model.need_sorted_nlist_for_lower(), + self.pt_model.need_sorted_nlist_for_lower(), + ) + self.assertFalse(self.dp_model.need_sorted_nlist_for_lower()) + + def test_get_dim_fparam(self) -> None: + """get_dim_fparam should return the same value on dp and pt.""" + self.assertEqual(self.dp_model.get_dim_fparam(), self.pt_model.get_dim_fparam()) + self.assertEqual(self.dp_model.get_dim_fparam(), 0) + + def test_get_dim_aparam(self) -> None: + """get_dim_aparam should return the same value on dp and pt.""" + self.assertEqual(self.dp_model.get_dim_aparam(), self.pt_model.get_dim_aparam()) + self.assertEqual(self.dp_model.get_dim_aparam(), 0) + + def test_get_sel_type(self) -> None: + """get_sel_type should return the same list on dp and pt.""" + self.assertEqual(self.dp_model.get_sel_type(), self.pt_model.get_sel_type()) + # For this model config, all types are selected (empty list) + self.assertEqual(self.dp_model.get_sel_type(), [0, 1]) + + def test_is_aparam_nall(self) -> None: + """is_aparam_nall should return the same value on dp and pt.""" + self.assertEqual(self.dp_model.is_aparam_nall(), self.pt_model.is_aparam_nall()) + self.assertFalse(self.dp_model.is_aparam_nall()) + + def test_atomic_output_def(self) -> None: + """atomic_output_def should return the same keys and shapes on dp and pt.""" + dp_def = self.dp_model.atomic_output_def() + pt_def = self.pt_model.atomic_output_def() + self.assertEqual(set(dp_def.keys()), set(pt_def.keys())) + for key in dp_def.keys(): + self.assertEqual(dp_def[key].shape, pt_def[key].shape) + + def test_format_nlist(self) -> None: + """format_nlist should produce the same result on dp and pt.""" + dp_nlist = self.dp_model.format_nlist( + self.extended_coord, + self.extended_atype, + self.nlist, + ) + pt_nlist = torch_to_numpy( + self.pt_model.format_nlist( + numpy_to_torch(self.extended_coord), + numpy_to_torch(self.extended_atype), + numpy_to_torch(self.nlist), + ) + ) + np.testing.assert_equal(dp_nlist, pt_nlist) + + def test_forward_common_atomic(self) -> None: + """forward_common_atomic should produce consistent results on dp and pt. + + Compares at the atomic_model level, where both backends define this method. + """ + dp_ret = self.dp_model.atomic_model.forward_common_atomic( + self.extended_coord, + self.extended_atype, + self.nlist, + mapping=self.mapping, + ) + pt_ret = self.pt_model.atomic_model.forward_common_atomic( + numpy_to_torch(self.extended_coord), + numpy_to_torch(self.extended_atype), + numpy_to_torch(self.nlist), + mapping=numpy_to_torch(self.mapping), + ) + # Compare the common keys + common_keys = set(dp_ret.keys()) & set(pt_ret.keys()) + self.assertTrue(len(common_keys) > 0) + for key in common_keys: + if dp_ret[key] is not None and pt_ret[key] is not None: + np.testing.assert_allclose( + dp_ret[key], + torch_to_numpy(pt_ret[key]), + rtol=1e-10, + atol=1e-10, + err_msg=f"Mismatch in forward_common_atomic key '{key}'", + ) + + def test_has_default_fparam(self) -> None: + """has_default_fparam should return the same value on dp and pt.""" + self.assertEqual( + self.dp_model.has_default_fparam(), + self.pt_model.has_default_fparam(), + ) + self.assertFalse(self.dp_model.has_default_fparam()) + + def test_get_default_fparam(self) -> None: + """get_default_fparam should return None on both dp and pt (no fparam configured).""" + dp_val = self.dp_model.get_default_fparam() + pt_val = self.pt_model.get_default_fparam() + self.assertIsNone(dp_val) + self.assertIsNone(pt_val) + # Note: both return None because no default_fparam is configured. + # A non-trivial return requires configuring default_fparam in the fitting net. + + def test_change_out_bias(self) -> None: + """change_out_bias should produce consistent bias on dp and pt.""" + nframes = 2 + nloc = 6 + # Use realistic coords (from setUp, tiled for 2 frames) + coords_2f = np.tile(self.coords, (nframes, 1, 1)) # (2, 6, 3) + atype_2f = np.array([[0, 0, 1, 1, 1, 1], [0, 1, 1, 0, 1, 1]], dtype=np.int32) + box_2f = np.tile(self.box.reshape(1, 3, 3), (nframes, 1, 1)) + natoms_data = np.array([[6, 6, 2, 4], [6, 6, 2, 4]], dtype=np.int32) + energy_data = np.array([10.0, 20.0]).reshape(nframes, 1) + + # dpmodel stat data (numpy) + dp_merged = [ + { + "coord": coords_2f, + "atype": atype_2f, + "atype_ext": atype_2f, + "box": box_2f, + "natoms": natoms_data, + "energy": energy_data, + "find_energy": np.float32(1.0), + } + ] + # pt stat data (torch tensors) + pt_merged = [ + { + "coord": numpy_to_torch(coords_2f), + "atype": numpy_to_torch(atype_2f), + "atype_ext": numpy_to_torch(atype_2f), + "box": numpy_to_torch(box_2f), + "natoms": numpy_to_torch(natoms_data), + "energy": numpy_to_torch(energy_data), + "find_energy": np.float32(1.0), + } + ] + + # Save initial (zero) bias + dp_bias_init = to_numpy_array(self.dp_model.get_out_bias()).copy() + + # Test "set-by-statistic" mode + self.dp_model.change_out_bias(dp_merged, bias_adjust_mode="set-by-statistic") + self.pt_model.change_out_bias(pt_merged, bias_adjust_mode="set-by-statistic") + dp_bias = to_numpy_array(self.dp_model.get_out_bias()) + pt_bias = torch_to_numpy(self.pt_model.get_out_bias()) + np.testing.assert_allclose(dp_bias, pt_bias, rtol=1e-10, atol=1e-10) + # Verify bias actually changed from initial zeros + self.assertFalse( + np.allclose(dp_bias, dp_bias_init), + "set-by-statistic did not change the bias from initial values", + ) + + # Test "change-by-statistic" mode (adjusts bias based on model predictions) + dp_bias_before = dp_bias.copy() + self.dp_model.change_out_bias(dp_merged, bias_adjust_mode="change-by-statistic") + self.pt_model.change_out_bias(pt_merged, bias_adjust_mode="change-by-statistic") + dp_bias2 = to_numpy_array(self.dp_model.get_out_bias()) + pt_bias2 = torch_to_numpy(self.pt_model.get_out_bias()) + np.testing.assert_allclose(dp_bias2, pt_bias2, rtol=1e-10, atol=1e-10) + # Verify change-by-statistic further modified the bias + self.assertFalse( + np.allclose(dp_bias2, dp_bias_before), + "change-by-statistic did not further change the bias", + ) + + def test_change_type_map(self) -> None: + """change_type_map should produce consistent results on dp and pt. + + Uses a DPA1 (se_atten) descriptor since se_e2_a does not support + change_type_map (non-mixed-types descriptors raise NotImplementedError). + """ + from deepmd.utils.argcheck import model_args as model_args_fn + + data = model_args_fn().normalize_value( + { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_atten", + "sel": 20, + "rcut_smth": 0.50, + "rcut": 6.00, + "neuron": [3, 6], + "resnet_dt": False, + "axis_neuron": 2, + "precision": "float64", + "seed": 1, + "attn": 6, + "attn_layer": 0, + }, + "fitting_net": { + "neuron": [5, 5], + "resnet_dt": True, + "precision": "float64", + "seed": 1, + }, + }, + trim_pattern="_*", + ) + dp_model = get_model_dp(data) + pt_model = EnergyModelPT.deserialize(dp_model.serialize()) + + # Set non-zero out_bias so the swap is non-trivial + dp_bias_orig = to_numpy_array(dp_model.get_out_bias()).copy() + new_bias = dp_bias_orig.copy() + new_bias[:, 0, :] = 1.5 # type 0 ("O") + new_bias[:, 1, :] = -3.7 # type 1 ("H") + dp_model.set_out_bias(new_bias) + pt_model.set_out_bias(numpy_to_torch(new_bias)) + + new_type_map = ["H", "O"] + dp_model.change_type_map(new_type_map) + pt_model.change_type_map(new_type_map) + + # Both should have the new type_map + self.assertEqual(dp_model.get_type_map(), new_type_map) + self.assertEqual(pt_model.get_type_map(), new_type_map) + + # Out_bias should be reordered consistently between backends + dp_bias_new = to_numpy_array(dp_model.get_out_bias()) + pt_bias_new = torch_to_numpy(pt_model.get_out_bias()) + np.testing.assert_allclose(dp_bias_new, pt_bias_new, rtol=1e-10, atol=1e-10) + + # Verify the reorder is correct: old type 0 -> new type 1, old type 1 -> new type 0 + np.testing.assert_allclose( + dp_bias_new[:, 0, :], + new_bias[:, 1, :], + rtol=1e-10, + atol=1e-10, + ) + np.testing.assert_allclose( + dp_bias_new[:, 1, :], + new_bias[:, 0, :], + rtol=1e-10, + atol=1e-10, + ) + + def test_update_sel(self) -> None: + """update_sel should return the same result on dp and pt.""" + from unittest.mock import ( + patch, + ) + + from deepmd.dpmodel.model.dp_model import DPModelCommon as DPModelCommonDP + from deepmd.pt.model.model.dp_model import DPModelCommon as DPModelCommonPT + + mock_min_nbor_dist = 0.5 + mock_sel = [10, 20] + local_jdata = { + "type_map": ["O", "H"], + "descriptor": { + "type": "se_e2_a", + "sel": "auto", + "rcut_smth": 0.50, + "rcut": 6.00, + }, + "fitting_net": { + "neuron": [5, 5], + }, + } + type_map = ["O", "H"] + + with patch( + "deepmd.dpmodel.utils.update_sel.UpdateSel.get_nbor_stat", + return_value=(mock_min_nbor_dist, mock_sel), + ): + dp_result, dp_min_dist = DPModelCommonDP.update_sel( + None, type_map, local_jdata + ) + + with patch( + "deepmd.pt.utils.update_sel.UpdateSel.get_nbor_stat", + return_value=(mock_min_nbor_dist, mock_sel), + ): + pt_result, pt_min_dist = DPModelCommonPT.update_sel( + None, type_map, local_jdata + ) + + self.assertEqual(dp_result, pt_result) + self.assertEqual(dp_min_dist, pt_min_dist) + # Verify sel was actually updated (not still "auto") + self.assertIsInstance(dp_result["descriptor"]["sel"], list) + self.assertNotEqual(dp_result["descriptor"]["sel"], "auto") + + def test_get_ntypes(self) -> None: + """get_ntypes should return the same value on dp and pt.""" + self.assertEqual(self.dp_model.get_ntypes(), self.pt_model.get_ntypes()) + self.assertEqual(self.dp_model.get_ntypes(), 2) + + def test_compute_or_load_out_stat(self) -> None: + """compute_or_load_out_stat should produce consistent bias on dp and pt.""" + nframes = 2 + nloc = 6 + coords_2f = np.tile(self.coords, (nframes, 1, 1)) + atype_2f = np.array([[0, 0, 1, 1, 1, 1], [0, 1, 1, 0, 1, 1]], dtype=np.int32) + box_2f = np.tile(self.box.reshape(1, 3, 3), (nframes, 1, 1)) + natoms_data = np.array([[6, 6, 2, 4], [6, 6, 2, 4]], dtype=np.int32) + energy_data = np.array([10.0, 20.0]).reshape(nframes, 1) + + dp_merged = [ + { + "coord": coords_2f, + "atype": atype_2f, + "atype_ext": atype_2f, + "box": box_2f, + "natoms": natoms_data, + "energy": energy_data, + "find_energy": np.float32(1.0), + } + ] + pt_merged = [ + { + "coord": numpy_to_torch(coords_2f), + "atype": numpy_to_torch(atype_2f), + "atype_ext": numpy_to_torch(atype_2f), + "box": numpy_to_torch(box_2f), + "natoms": numpy_to_torch(natoms_data), + "energy": numpy_to_torch(energy_data), + "find_energy": np.float32(1.0), + } + ] + + # Verify bias is initially zero (or at least identical) + dp_bias_before = to_numpy_array(self.dp_model.get_out_bias()).copy() + pt_bias_before = torch_to_numpy(self.pt_model.get_out_bias()).copy() + np.testing.assert_allclose( + dp_bias_before, pt_bias_before, rtol=1e-10, atol=1e-10 + ) + + self.dp_model.atomic_model.compute_or_load_out_stat(dp_merged) + self.pt_model.atomic_model.compute_or_load_out_stat(pt_merged) + + dp_bias_after = to_numpy_array(self.dp_model.get_out_bias()) + pt_bias_after = torch_to_numpy(self.pt_model.get_out_bias()) + np.testing.assert_allclose(dp_bias_after, pt_bias_after, rtol=1e-10, atol=1e-10) + + # Verify bias actually changed (not still all zeros) + self.assertFalse( + np.allclose(dp_bias_after, dp_bias_before), + "compute_or_load_out_stat did not change the bias", + ) From 2ea6b748518a4aa62539cbc427918d8e6d2c7447 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 14 Feb 2026 11:13:37 +0800 Subject: [PATCH 50/60] simplify the code --- .../dpmodel/atomic_model/base_atomic_model.py | 40 ++++++------------- 1 file changed, 12 insertions(+), 28 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index 9866ddbc3a..7b02700258 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -387,34 +387,18 @@ def model_forward( xp = array_api_compat.array_namespace(ref_array) # Convert numpy inputs to the model's array type with correct device - device = getattr(ref_array, "device", None) - if device is not None: - # For torch tensors - coord = xp.asarray(coord, device=device) - atype = xp.asarray(atype, device=device) - if box is not None: - # Check if box is all zeros before converting - if np.allclose(box, 0.0): - box = None - else: - box = xp.asarray(box, device=device) - if fparam is not None: - fparam = xp.asarray(fparam, device=device) - if aparam is not None: - aparam = xp.asarray(aparam, device=device) - else: - # For numpy arrays - coord = xp.asarray(coord) - atype = xp.asarray(atype) - if box is not None: - if np.allclose(box, 0.0): - box = None - else: - box = xp.asarray(box) - if fparam is not None: - fparam = xp.asarray(fparam) - if aparam is not None: - aparam = xp.asarray(aparam) + device = array_api_compat.device(ref_array) + coord = xp.asarray(coord, device=device) + atype = xp.asarray(atype, device=device) + if box is not None: + if np.allclose(box, 0.0): + box = None + else: + box = xp.asarray(box, device=device) + if fparam is not None: + fparam = xp.asarray(fparam, device=device) + if aparam is not None: + aparam = xp.asarray(aparam, device=device) ( extended_coord, From 21077bc61ef788d9f667636428818e8a8c3c24a8 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 14 Feb 2026 11:26:56 +0800 Subject: [PATCH 51/60] fix bug --- deepmd/dpmodel/atomic_model/base_atomic_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index 7b02700258..3c9c595221 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -352,8 +352,8 @@ def _store_out_stat( ) -> None: """Store output bias and std into the model.""" ntypes = self.get_ntypes() - out_bias_data = np.copy(self.out_bias) - out_std_data = np.copy(self.out_std) + out_bias_data = np.array(to_numpy_array(self.out_bias)) + out_std_data = np.array(to_numpy_array(self.out_std)) for kk in out_bias.keys(): assert kk in out_std.keys() idx = self._get_bias_index(kk) From bdd015c1d1d0d316e24231e703e1c31a35baee2a Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 14 Feb 2026 11:36:32 +0800 Subject: [PATCH 52/60] fix issues --- deepmd/dpmodel/utils/type_embed.py | 1 - source/tests/pt_expt/atomic_model/test_dp_atomic_model.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/utils/type_embed.py b/deepmd/dpmodel/utils/type_embed.py index bc1146203d..d5173c08d1 100644 --- a/deepmd/dpmodel/utils/type_embed.py +++ b/deepmd/dpmodel/utils/type_embed.py @@ -210,7 +210,6 @@ def change_type_map( # Create random params with same dtype and device as first_layer_matrix extend_type_params = np.random.default_rng().random( [len(type_map), first_layer_matrix.shape[-1]], - dtype=PRECISION_DICT[self.precision], ) extend_type_params = xp.asarray( extend_type_params, diff --git a/source/tests/pt_expt/atomic_model/test_dp_atomic_model.py b/source/tests/pt_expt/atomic_model/test_dp_atomic_model.py index 49e60373d4..0196170cd0 100644 --- a/source/tests/pt_expt/atomic_model/test_dp_atomic_model.py +++ b/source/tests/pt_expt/atomic_model/test_dp_atomic_model.py @@ -142,6 +142,7 @@ def test_exportable(self) -> None: self.assertIn("energy", ret0) # Test torch.export + # Use strict=False for now to handle dynamic shapes exported = torch.export.export( md0, (coord, atype, nlist), From 77609c131231f1cb552351ad785c89f6c490016a Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 14 Feb 2026 13:33:30 +0800 Subject: [PATCH 53/60] more careful check on the compute_or_load_stat --- source/tests/consistent/model/test_ener.py | 98 +++++++++++++++++++--- 1 file changed, 87 insertions(+), 11 deletions(-) diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index dd739b11d5..3f79ebdc64 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -1164,7 +1164,23 @@ def test_get_ntypes(self) -> None: self.assertEqual(self.dp_model.get_ntypes(), 2) def test_compute_or_load_out_stat(self) -> None: - """compute_or_load_out_stat should produce consistent bias on dp and pt.""" + """compute_or_load_out_stat should produce consistent bias on dp and pt. + + Tests both the compute path (from data) and the load path (from file). + Both backends should save the same stat file content and load identical + biases from file. + """ + import tempfile + from pathlib import ( + Path, + ) + + import h5py + + from deepmd.utils.path import ( + DPPath, + ) + nframes = 2 nloc = 6 coords_2f = np.tile(self.coords, (nframes, 1, 1)) @@ -1203,15 +1219,75 @@ def test_compute_or_load_out_stat(self) -> None: dp_bias_before, pt_bias_before, rtol=1e-10, atol=1e-10 ) - self.dp_model.atomic_model.compute_or_load_out_stat(dp_merged) - self.pt_model.atomic_model.compute_or_load_out_stat(pt_merged) + with tempfile.TemporaryDirectory() as tmpdir: + # Create separate h5 files for dp and pt + dp_h5 = str((Path(tmpdir) / "dp_stat.h5").resolve()) + pt_h5 = str((Path(tmpdir) / "pt_stat.h5").resolve()) + with h5py.File(dp_h5, "w"): + pass + with h5py.File(pt_h5, "w"): + pass + dp_stat_path = DPPath(dp_h5, "a") + pt_stat_path = DPPath(pt_h5, "a") + + # 1. Compute stats and save to file + self.dp_model.atomic_model.compute_or_load_out_stat( + dp_merged, stat_file_path=dp_stat_path + ) + self.pt_model.atomic_model.compute_or_load_out_stat( + pt_merged, stat_file_path=pt_stat_path + ) + + dp_bias_after = to_numpy_array(self.dp_model.get_out_bias()) + pt_bias_after = torch_to_numpy(self.pt_model.get_out_bias()) + np.testing.assert_allclose( + dp_bias_after, pt_bias_after, rtol=1e-10, atol=1e-10 + ) - dp_bias_after = to_numpy_array(self.dp_model.get_out_bias()) - pt_bias_after = torch_to_numpy(self.pt_model.get_out_bias()) - np.testing.assert_allclose(dp_bias_after, pt_bias_after, rtol=1e-10, atol=1e-10) + # Verify bias actually changed (not still all zeros) + self.assertFalse( + np.allclose(dp_bias_after, dp_bias_before), + "compute_or_load_out_stat did not change the bias", + ) - # Verify bias actually changed (not still all zeros) - self.assertFalse( - np.allclose(dp_bias_after, dp_bias_before), - "compute_or_load_out_stat did not change the bias", - ) + # 2. Verify both backends saved the same file content + with h5py.File(dp_h5, "r") as dp_f, h5py.File(pt_h5, "r") as pt_f: + dp_keys = sorted(dp_f.keys()) + pt_keys = sorted(pt_f.keys()) + self.assertEqual(dp_keys, pt_keys) + for key in dp_keys: + np.testing.assert_allclose( + np.array(dp_f[key]), + np.array(pt_f[key]), + rtol=1e-10, + atol=1e-10, + err_msg=f"Stat file content mismatch for key {key}", + ) + + # 3. Reset biases to zero, then load from file + zero_bias = np.zeros_like(dp_bias_after) + self.dp_model.set_out_bias(zero_bias) + self.pt_model.set_out_bias(numpy_to_torch(zero_bias)) + + # Use a callable that raises to ensure it loads from file, not recomputes + def raise_error(): + raise RuntimeError("Should not recompute — should load from file") + + self.dp_model.atomic_model.compute_or_load_out_stat( + raise_error, stat_file_path=dp_stat_path + ) + self.pt_model.atomic_model.compute_or_load_out_stat( + raise_error, stat_file_path=pt_stat_path + ) + + dp_bias_loaded = to_numpy_array(self.dp_model.get_out_bias()) + pt_bias_loaded = torch_to_numpy(self.pt_model.get_out_bias()) + + # Loaded biases should match between backends + np.testing.assert_allclose( + dp_bias_loaded, pt_bias_loaded, rtol=1e-10, atol=1e-10 + ) + # Loaded biases should match the originally computed biases + np.testing.assert_allclose( + dp_bias_loaded, dp_bias_after, rtol=1e-10, atol=1e-10 + ) From b67accc42183f439e124764e16790c9ec808a716 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 15 Feb 2026 13:25:53 +0800 Subject: [PATCH 54/60] add guard for eval_descriptor and eval_fitting_last_layer --- deepmd/dpmodel/atomic_model/dp_atomic_model.py | 10 ++++++++++ deepmd/pt/model/atomic_model/dp_atomic_model.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 86e13e14df..1b5498e347 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -139,6 +139,11 @@ def set_eval_descriptor_hook(self, enable: bool) -> None: def eval_descriptor(self) -> Array: """Evaluate the descriptor by concatenating cached results.""" + if not self.eval_descriptor_list: + raise RuntimeError( + "eval_descriptor_list is empty. " + "Call set_eval_descriptor_hook(True) and perform a forward pass first." + ) xp = array_api_compat.array_namespace(self.eval_descriptor_list[0]) return xp.concat(self.eval_descriptor_list, axis=0) @@ -150,6 +155,11 @@ def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: def eval_fitting_last_layer(self) -> Array: """Evaluate the fitting last layer output by concatenating cached results.""" + if not self.eval_fitting_last_layer_list: + raise RuntimeError( + "eval_fitting_last_layer_list is empty. " + "Call set_eval_fitting_last_layer_hook(True) and perform a forward pass first." + ) xp = array_api_compat.array_namespace(self.eval_fitting_last_layer_list[0]) return xp.concat(self.eval_fitting_last_layer_list, axis=0) diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index af2e8954df..a71427d5e9 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -83,6 +83,11 @@ def set_eval_descriptor_hook(self, enable: bool) -> None: def eval_descriptor(self) -> torch.Tensor: """Evaluate the descriptor.""" + if not self.eval_descriptor_list: + raise RuntimeError( + "eval_descriptor_list is empty. " + "Call set_eval_descriptor_hook(True) and perform a forward pass first." + ) return torch.concat(self.eval_descriptor_list) def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: @@ -94,6 +99,11 @@ def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: def eval_fitting_last_layer(self) -> torch.Tensor: """Evaluate the fitting last layer output.""" + if not self.eval_fitting_last_layer_list: + raise RuntimeError( + "eval_fitting_last_layer_list is empty. " + "Call set_eval_fitting_last_layer_hook(True) and perform a forward pass first." + ) return torch.concat(self.eval_fitting_last_layer_list) @torch.jit.export From 19df985e3f651ed3c64c66809332320dfddc4cc6 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 15 Feb 2026 15:43:25 +0800 Subject: [PATCH 55/60] fix issues --- deepmd/pt_expt/model/ener_model.py | 10 +- source/tests/consistent/model/test_ener.py | 3 - source/tests/pt_expt/model/test_ener_model.py | 217 ++++++++++++++---- 3 files changed, 179 insertions(+), 51 deletions(-) diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py index a5a774a6da..38e3976303 100644 --- a/deepmd/pt_expt/model/ener_model.py +++ b/deepmd/pt_expt/model/ener_model.py @@ -122,8 +122,8 @@ def forward_lower( ------- torch.nn.Module A traced module whose ``forward`` accepts - ``(extended_coord, extended_atype, nlist, mapping)`` and - returns a dict with the same keys as ``_forward_lower``. + ``(extended_coord, extended_atype, nlist, mapping, fparam, aparam)`` + and returns a dict with the same keys as ``_forward_lower``. """ model = self @@ -132,6 +132,8 @@ def fn( extended_atype: torch.Tensor, nlist: torch.Tensor, mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: extended_coord = extended_coord.detach().requires_grad_(True) return model._forward_lower( @@ -144,4 +146,6 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn)(extended_coord, extended_atype, nlist, mapping) + return make_fx(fn)( + extended_coord, extended_atype, nlist, mapping, fparam, aparam + ) diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 3f79ebdc64..5c72a4f553 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -123,7 +123,6 @@ def data(self) -> dict: pd_class = EnergyModelPD pt_expt_class = EnergyModelPTExpt jax_class = EnergyModelJAX - pd_class = EnergyModelPD args = model_args() def get_reference_backend(self): @@ -979,7 +978,6 @@ def test_get_default_fparam(self) -> None: def test_change_out_bias(self) -> None: """change_out_bias should produce consistent bias on dp and pt.""" nframes = 2 - nloc = 6 # Use realistic coords (from setUp, tiled for 2 frames) coords_2f = np.tile(self.coords, (nframes, 1, 1)) # (2, 6, 3) atype_2f = np.array([[0, 0, 1, 1, 1, 1], [0, 1, 1, 0, 1, 1]], dtype=np.int32) @@ -1182,7 +1180,6 @@ def test_compute_or_load_out_stat(self) -> None: ) nframes = 2 - nloc = 6 coords_2f = np.tile(self.coords, (nframes, 1, 1)) atype_2f = np.array([[0, 0, 1, 1, 1, 1], [0, 1, 1, 0, 1, 1]], dtype=np.int32) box_2f = np.tile(self.box.reshape(1, 3, 3), (nframes, 1, 1)) diff --git a/source/tests/pt_expt/model/test_ener_model.py b/source/tests/pt_expt/model/test_ener_model.py index 7ea8f96bfc..d91177ab51 100644 --- a/source/tests/pt_expt/model/test_ener_model.py +++ b/source/tests/pt_expt/model/test_ener_model.py @@ -60,7 +60,11 @@ def setUp(self) -> None: [[0, 0, 0, 1, 1]], dtype=torch.int64, device=self.device ) - def _make_model(self) -> EnergyModel: + def _make_model( + self, + numb_fparam: int = 0, + numb_aparam: int = 0, + ) -> EnergyModel: ds = DescrptSeA( self.rcut, self.rcut_smth, @@ -71,6 +75,8 @@ def _make_model(self) -> EnergyModel: self.nt, ds.get_dim_out(), 1, + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, mixed_types=ds.mixed_types(), seed=GLOBAL_SEED, ).to(self.device) @@ -98,17 +104,8 @@ def test_output_shapes(self) -> None: self.assertEqual(ret["force"].shape, (1, self.natoms, 3)) self.assertEqual(ret["virial"].shape, (1, 9)) - def test_forward_lower_exportable(self) -> None: - """Test that EnergyModel.forward_lower returns an exportable module. - - forward_lower() uses make_fx to trace through torch.autograd.grad, - decomposing the backward pass into primitive ops. The returned module - can be passed directly to torch.export.export. - """ - md = self._make_model() - md.eval() - - # Prepare extended coords and neighbor list using dpmodel utilities + def _prepare_lower_inputs(self): + """Build extended coords, atype, nlist, mapping as torch tensors.""" coord_np = self.coord.detach().cpu().numpy() atype_np = self.atype.detach().cpu().numpy() cell_np = self.cell.reshape(1, 9).detach().cpu().numpy() @@ -131,8 +128,6 @@ def test_forward_lower_exportable(self) -> None: distinguish_types=True, ) extended_coord = extended_coord.reshape(1, -1, 3) - - # Convert to torch tensors ext_coord = torch.tensor( extended_coord, dtype=torch.float64, @@ -145,63 +140,195 @@ def test_forward_lower_exportable(self) -> None: ) nlist_t = torch.tensor(nlist, dtype=torch.int64, device=self.device) mapping_t = torch.tensor(mapping, dtype=torch.int64, device=self.device) + return ext_coord, ext_atype, nlist_t, mapping_t + + def test_forward_lower_exportable(self) -> None: + """Test that EnergyModel.forward_lower returns an exportable module. + + forward_lower() uses make_fx to trace through torch.autograd.grad, + decomposing the backward pass into primitive ops. The returned module + can be passed directly to torch.export.export. + + The test builds a model with numb_fparam > 0 and numb_aparam > 0 and + verifies that: + 1. The traced / exported module reproduces eager results (zero params). + 2. The traced / exported module reproduces eager results with non-zero + fparam and aparam (ruling out baked-in constants). + 3. Changing fparam or aparam at runtime actually changes the output. + """ + numb_fparam = 2 + numb_aparam = 3 + md = self._make_model( + numb_fparam=numb_fparam, + numb_aparam=numb_aparam, + ) + md.eval() - # Eager reference via _forward_lower - ret0 = md._forward_lower( + ext_coord, ext_atype, nlist_t, mapping_t = self._prepare_lower_inputs() + nframes = ext_coord.shape[0] + nloc = self.natoms + output_keys = ("energy", "extended_force", "virial", "extended_virial") + + fparam_zero = torch.zeros( + nframes, + numb_fparam, + dtype=torch.float64, + device=self.device, + ) + aparam_zero = torch.zeros( + nframes, + nloc, + numb_aparam, + dtype=torch.float64, + device=self.device, + ) + + # --- eager reference with zero params --- + ret_eager_zero = md._forward_lower( ext_coord.requires_grad_(True), ext_atype, nlist_t, mapping_t, + fparam=fparam_zero, + aparam=aparam_zero, do_atomic_virial=True, ) - self.assertIn("energy", ret0) - self.assertIn("extended_force", ret0) - self.assertIn("virial", ret0) - self.assertIn("extended_virial", ret0) + for key in output_keys: + self.assertIn(key, ret_eager_zero) - # forward_lower returns a traced module + # --- trace and export --- traced = md.forward_lower( ext_coord, ext_atype, nlist_t, mapping_t, + fparam=fparam_zero, + aparam=aparam_zero, do_atomic_virial=True, ) self.assertIsInstance(traced, torch.nn.Module) - # The traced module should be directly exportable exported = torch.export.export( traced, - (ext_coord, ext_atype, nlist_t, mapping_t), + (ext_coord, ext_atype, nlist_t, mapping_t, fparam_zero, aparam_zero), strict=False, ) self.assertIsNotNone(exported) - # Verify exported model produces same output - ret1 = exported.module()(ext_coord, ext_atype, nlist_t, mapping_t) - np.testing.assert_allclose( - ret0["energy"].detach().cpu().numpy(), - ret1["energy"].detach().cpu().numpy(), - rtol=1e-10, - atol=1e-10, + # --- verify traced/exported match eager (zero params) --- + ret_traced_zero = traced( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam_zero, + aparam_zero, ) - np.testing.assert_allclose( - ret0["extended_force"].detach().cpu().numpy(), - ret1["extended_force"].detach().cpu().numpy(), - rtol=1e-10, - atol=1e-10, + ret_exported_zero = exported.module()( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam_zero, + aparam_zero, ) - np.testing.assert_allclose( - ret0["virial"].detach().cpu().numpy(), - ret1["virial"].detach().cpu().numpy(), - rtol=1e-10, - atol=1e-10, + for key in output_keys: + np.testing.assert_allclose( + ret_eager_zero[key].detach().cpu().numpy(), + ret_traced_zero[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"traced vs eager (zero params): {key}", + ) + np.testing.assert_allclose( + ret_eager_zero[key].detach().cpu().numpy(), + ret_exported_zero[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"exported vs eager (zero params): {key}", + ) + + # --- verify traced/exported match eager (non-zero params) --- + fparam_nz = torch.ones( + nframes, + numb_fparam, + dtype=torch.float64, + device=self.device, ) - np.testing.assert_allclose( - ret0["extended_virial"].detach().cpu().numpy(), - ret1["extended_virial"].detach().cpu().numpy(), - rtol=1e-10, - atol=1e-10, + aparam_nz = torch.ones( + nframes, + nloc, + numb_aparam, + dtype=torch.float64, + device=self.device, + ) + ret_eager_nz = md._forward_lower( + ext_coord.requires_grad_(True), + ext_atype, + nlist_t, + mapping_t, + fparam=fparam_nz, + aparam=aparam_nz, + do_atomic_virial=True, + ) + ret_traced_nz = traced( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam_nz, + aparam_nz, + ) + ret_exported_nz = exported.module()( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam_nz, + aparam_nz, + ) + for key in output_keys: + np.testing.assert_allclose( + ret_eager_nz[key].detach().cpu().numpy(), + ret_traced_nz[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"traced vs eager (non-zero params): {key}", + ) + np.testing.assert_allclose( + ret_eager_nz[key].detach().cpu().numpy(), + ret_exported_nz[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"exported vs eager (non-zero params): {key}", + ) + + # --- verify fparam is dynamic (changing it changes the output) --- + self.assertFalse( + np.allclose( + ret_traced_zero["energy"].detach().cpu().numpy(), + ret_traced_nz["energy"].detach().cpu().numpy(), + ), + "Changing fparam did not change output — " + "fparam may be baked in as a constant", + ) + + # --- verify aparam is dynamic (changing it changes the output) --- + ret_traced_ap = traced( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam_zero, + aparam_nz, + ) + self.assertFalse( + np.allclose( + ret_traced_zero["energy"].detach().cpu().numpy(), + ret_traced_ap["energy"].detach().cpu().numpy(), + ), + "Changing aparam did not change output — " + "aparam may be baked in as a constant", ) def test_dp_consistency(self) -> None: From fc0be626d7014b46c0db6768550e210b336325ad Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 15 Feb 2026 16:10:34 +0800 Subject: [PATCH 56/60] remove eval_ hooks --- .../dpmodel/atomic_model/dp_atomic_model.py | 44 ----------------- deepmd/dpmodel/model/dp_model.py | 19 -------- source/tests/consistent/model/test_ener.py | 48 ------------------- 3 files changed, 111 deletions(-) diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 1b5498e347..0f5b12bc9c 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -3,8 +3,6 @@ Any, ) -import array_api_compat - from deepmd.dpmodel.array_api import ( Array, ) @@ -56,10 +54,6 @@ def __init__( if hasattr(self.fitting, "reinit_exclude"): self.fitting.reinit_exclude(self.atom_exclude_types) self.type_map = type_map - self.enable_eval_descriptor_hook = False - self.enable_eval_fitting_last_layer_hook = False - self.eval_descriptor_list: list[Array] = [] - self.eval_fitting_last_layer_list: list[Array] = [] super().init_out_stat() def fitting_output_def(self) -> FittingOutputDef: @@ -132,37 +126,6 @@ def enable_compression( check_frequency, ) - def set_eval_descriptor_hook(self, enable: bool) -> None: - """Set the hook for evaluating descriptor and clear the cache.""" - self.enable_eval_descriptor_hook = enable - self.eval_descriptor_list.clear() - - def eval_descriptor(self) -> Array: - """Evaluate the descriptor by concatenating cached results.""" - if not self.eval_descriptor_list: - raise RuntimeError( - "eval_descriptor_list is empty. " - "Call set_eval_descriptor_hook(True) and perform a forward pass first." - ) - xp = array_api_compat.array_namespace(self.eval_descriptor_list[0]) - return xp.concat(self.eval_descriptor_list, axis=0) - - def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: - """Set the hook for evaluating fitting last layer output and clear the cache.""" - self.enable_eval_fitting_last_layer_hook = enable - self.fitting.set_return_middle_output(enable) - self.eval_fitting_last_layer_list.clear() - - def eval_fitting_last_layer(self) -> Array: - """Evaluate the fitting last layer output by concatenating cached results.""" - if not self.eval_fitting_last_layer_list: - raise RuntimeError( - "eval_fitting_last_layer_list is empty. " - "Call set_eval_fitting_last_layer_hook(True) and perform a forward pass first." - ) - xp = array_api_compat.array_namespace(self.eval_fitting_last_layer_list[0]) - return xp.concat(self.eval_fitting_last_layer_list, axis=0) - def forward_atomic( self, extended_coord: Array, @@ -203,8 +166,6 @@ def forward_atomic( nlist, mapping=mapping, ) - if self.enable_eval_descriptor_hook: - self.eval_descriptor_list.append(descriptor) ret = self.fitting( descriptor, atype, @@ -214,11 +175,6 @@ def forward_atomic( fparam=fparam, aparam=aparam, ) - if self.enable_eval_fitting_last_layer_hook: - assert "middle_output" in ret, ( - "eval_fitting_last_layer not supported for this fitting net!" - ) - self.eval_fitting_last_layer_list.append(ret.pop("middle_output")) return ret def change_type_map( diff --git a/deepmd/dpmodel/model/dp_model.py b/deepmd/dpmodel/model/dp_model.py index fcf7c88b86..0dcf6358f9 100644 --- a/deepmd/dpmodel/model/dp_model.py +++ b/deepmd/dpmodel/model/dp_model.py @@ -1,9 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -from deepmd.dpmodel.array_api import ( - Array, -) from deepmd.dpmodel.descriptor.base_descriptor import ( BaseDescriptor, ) @@ -55,19 +52,3 @@ def get_fitting_net(self) -> BaseFitting: def get_descriptor(self) -> BaseDescriptor: """Get the descriptor.""" return self.atomic_model.descriptor - - def set_eval_descriptor_hook(self, enable: bool) -> None: - """Set the hook for evaluating descriptor.""" - self.atomic_model.set_eval_descriptor_hook(enable) - - def eval_descriptor(self) -> Array: - """Evaluate the descriptor.""" - return self.atomic_model.eval_descriptor() - - def set_eval_fitting_last_layer_hook(self, enable: bool) -> None: - """Set the hook for evaluating fitting last layer output.""" - self.atomic_model.set_eval_fitting_last_layer_hook(enable) - - def eval_fitting_last_layer(self) -> Array: - """Evaluate the fitting last layer output.""" - return self.atomic_model.eval_fitting_last_layer() diff --git a/source/tests/consistent/model/test_ener.py b/source/tests/consistent/model/test_ener.py index 5c72a4f553..6cceb1c640 100644 --- a/source/tests/consistent/model/test_ener.py +++ b/source/tests/consistent/model/test_ener.py @@ -757,54 +757,6 @@ def test_forward_common_lower_alias(self) -> None: for key in ret_call: np.testing.assert_equal(ret_call[key], ret_fc[key]) - def test_eval_descriptor(self) -> None: - """eval_descriptor should produce consistent results across dp and pt.""" - # dpmodel - self.dp_model.set_eval_descriptor_hook(True) - self.dp_model.call_lower( - self.extended_coord, - self.extended_atype, - self.nlist, - self.mapping, - ) - dp_desc = self.dp_model.eval_descriptor() - - # pt - self.pt_model.set_eval_descriptor_hook(True) - self.pt_model.forward_common_lower( - numpy_to_torch(self.extended_coord), - numpy_to_torch(self.extended_atype), - numpy_to_torch(self.nlist), - numpy_to_torch(self.mapping), - ) - pt_desc = torch_to_numpy(self.pt_model.eval_descriptor()) - - np.testing.assert_allclose(dp_desc, pt_desc, rtol=1e-10, atol=1e-10) - - def test_eval_fitting_last_layer(self) -> None: - """eval_fitting_last_layer should produce consistent results across dp and pt.""" - # dpmodel - self.dp_model.set_eval_fitting_last_layer_hook(True) - self.dp_model.call_lower( - self.extended_coord, - self.extended_atype, - self.nlist, - self.mapping, - ) - dp_fl = self.dp_model.eval_fitting_last_layer() - - # pt - self.pt_model.set_eval_fitting_last_layer_hook(True) - self.pt_model.forward_common_lower( - numpy_to_torch(self.extended_coord), - numpy_to_torch(self.extended_atype), - numpy_to_torch(self.nlist), - numpy_to_torch(self.mapping), - ) - pt_fl = torch_to_numpy(self.pt_model.eval_fitting_last_layer()) - - np.testing.assert_allclose(dp_fl, pt_fl, rtol=1e-10, atol=1e-10) - def test_model_output_def(self) -> None: """model_output_def should return the same keys and shapes on dp and pt.""" dp_def = self.dp_model.model_output_def().get_data() From c15212d7ac6d947f43f92af6ffa2b2e3b7f8339f Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 15 Feb 2026 16:39:40 +0800 Subject: [PATCH 57/60] rm eval_return_middle_output --- deepmd/dpmodel/fitting/general_fitting.py | 27 ----------------------- 1 file changed, 27 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index cedc0eb916..260be619fd 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -195,7 +195,6 @@ def __init__( self.default_fparam_tensor = np.array(self.default_fparam, dtype=self.prec) else: self.default_fparam_tensor = None - self.eval_return_middle_output = False # init networks in_dim = ( self.dim_descrpt @@ -411,10 +410,6 @@ def __getitem__(self, key: str) -> Any: else: raise KeyError(key) - def set_return_middle_output(self, return_middle_output: bool = True) -> None: - """Set whether to return the output of the last hidden layer.""" - self.eval_return_middle_output = return_middle_output - def reinit_exclude( self, exclude_types: list[int] = [], @@ -596,12 +591,6 @@ def _call_common( dtype=get_xp_precision(xp, self.precision), device=array_api_compat.device(descriptor), ) - if self.eval_return_middle_output: - outs_middle = xp.zeros( - [nf, nloc, self.neuron[-1]], - dtype=get_xp_precision(xp, self.precision), - device=array_api_compat.device(descriptor), - ) for type_i in range(self.ntypes): mask = xp.tile( xp.reshape((atype == type_i), (nf, nloc, 1)), (1, 1, net_dim_out) @@ -617,26 +606,10 @@ def _call_common( mask, atom_property, xp.zeros_like(atom_property) ) outs = outs + atom_property # Shape is [nframes, natoms[0], 1] - if self.eval_return_middle_output: - middle_output_type = self.nets[(type_i,)].call_until_last(xx) - middle_mask = xp.tile( - xp.reshape((atype == type_i), (nf, nloc, 1)), - (1, 1, self.neuron[-1]), - ) - middle_output_type = xp.where( - middle_mask, - middle_output_type, - xp.zeros_like(middle_output_type), - ) - outs_middle = outs_middle + middle_output_type - if self.eval_return_middle_output: - results["middle_output"] = outs_middle else: outs = self.nets[()](xx) if xx_zeros is not None: outs -= self.nets[()](xx_zeros) - if self.eval_return_middle_output: - results["middle_output"] = self.nets[()].call_until_last(xx) outs += xp.reshape( xp.take( xp.astype(self.bias_atom_e[...], outs.dtype), From 7e51e9d4fb3a8fd6fafd8bf892926f767f50a341 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 16 Feb 2026 20:01:58 +0800 Subject: [PATCH 58/60] change forward_lower to forward_lower_exportable --- deepmd/pt_expt/model/ener_model.py | 22 ++++++++++++++++++- source/tests/pt_expt/model/test_ener_model.py | 2 +- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py index 38e3976303..e38365f8f1 100644 --- a/deepmd/pt_expt/model/ener_model.py +++ b/deepmd/pt_expt/model/ener_model.py @@ -106,6 +106,26 @@ def forward_lower( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, + ) -> dict[str, torch.Tensor]: + return self._forward_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + def forward_lower_exportable( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, ) -> torch.nn.Module: """Trace ``_forward_lower`` into an exportable module. @@ -123,7 +143,7 @@ def forward_lower( torch.nn.Module A traced module whose ``forward`` accepts ``(extended_coord, extended_atype, nlist, mapping, fparam, aparam)`` - and returns a dict with the same keys as ``_forward_lower``. + and returns a dict with the same keys as ``forward_lower``. """ model = self diff --git a/source/tests/pt_expt/model/test_ener_model.py b/source/tests/pt_expt/model/test_ener_model.py index d91177ab51..6e5006661e 100644 --- a/source/tests/pt_expt/model/test_ener_model.py +++ b/source/tests/pt_expt/model/test_ener_model.py @@ -197,7 +197,7 @@ def test_forward_lower_exportable(self) -> None: self.assertIn(key, ret_eager_zero) # --- trace and export --- - traced = md.forward_lower( + traced = md.forward_lower_exportable( ext_coord, ext_atype, nlist_t, From 7add23856506e8c6084a00a72cbe57c4c0943f5c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 16 Feb 2026 20:26:37 +0800 Subject: [PATCH 59/60] fix the squeeze issue for atomic virial, pt_expt backend only --- deepmd/pt_expt/model/ener_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py index e38365f8f1..5547543d27 100644 --- a/deepmd/pt_expt/model/ener_model.py +++ b/deepmd/pt_expt/model/ener_model.py @@ -58,7 +58,7 @@ def forward( if self.do_grad_c("energy"): model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: - model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2) if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] return model_predict @@ -91,7 +91,7 @@ def _forward_lower( model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: model_predict["extended_virial"] = model_ret["energy_derv_c"].squeeze( - -3 + -2 ) if "mask" in model_ret: model_predict["mask"] = model_ret["mask"] From 3600076e0ad1151eedd97f19328795cf675bdfef Mon Sep 17 00:00:00 2001 From: Han Wang Date: Mon, 16 Feb 2026 20:43:25 +0800 Subject: [PATCH 60/60] fix squeeze bug --- deepmd/dpmodel/model/ener_model.py | 2 +- deepmd/pd/model/model/ener_model.py | 4 ++-- deepmd/pt/model/model/dipole_model.py | 4 ++-- deepmd/pt/model/model/dp_linear_model.py | 4 ++-- deepmd/pt/model/model/dp_zbl_model.py | 4 ++-- deepmd/pt/model/model/ener_model.py | 6 +++--- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/deepmd/dpmodel/model/ener_model.py b/deepmd/dpmodel/model/ener_model.py index ef1e986f26..27d6db811e 100644 --- a/deepmd/dpmodel/model/ener_model.py +++ b/deepmd/dpmodel/model/ener_model.py @@ -66,7 +66,7 @@ def translated_output_def(self) -> dict[str, Any]: output_def["virial"] = out_def_data["energy_derv_c_redu"] output_def["virial"].squeeze(-2) output_def["atom_virial"] = out_def_data["energy_derv_c"] - output_def["atom_virial"].squeeze(-3) + output_def["atom_virial"].squeeze(-2) if "mask" in out_def_data: output_def["mask"] = out_def_data["mask"] if self._enable_hessian: diff --git a/deepmd/pd/model/model/ener_model.py b/deepmd/pd/model/model/ener_model.py index 3a57e79d3a..072c60793e 100644 --- a/deepmd/pd/model/model/ener_model.py +++ b/deepmd/pd/model/model/ener_model.py @@ -60,7 +60,7 @@ def translated_output_def(self) -> dict: output_def["virial"] = out_def_data["energy_derv_c_redu"] output_def["virial"].squeeze(-2) output_def["atom_virial"] = out_def_data["energy_derv_c"] - output_def["atom_virial"].squeeze(-3) + output_def["atom_virial"].squeeze(-2) if "mask" in out_def_data: output_def["mask"] = out_def_data["mask"] return output_def @@ -140,7 +140,7 @@ def forward_lower( if do_atomic_virial: model_predict["extended_virial"] = model_ret[ "energy_derv_c" - ].squeeze(-3) + ].squeeze(-2) else: model_predict["extended_virial"] = paddle.zeros( [model_predict["energy"].shape[0], 1, 9], dtype=paddle.float64 diff --git a/deepmd/pt/model/model/dipole_model.py b/deepmd/pt/model/model/dipole_model.py index c6813ce079..5cfebb4b03 100644 --- a/deepmd/pt/model/model/dipole_model.py +++ b/deepmd/pt/model/model/dipole_model.py @@ -47,7 +47,7 @@ def translated_output_def(self) -> dict[str, Any]: output_def["virial"] = out_def_data["dipole_derv_c_redu"] output_def["virial"].squeeze(-2) output_def["atom_virial"] = out_def_data["dipole_derv_c"] - output_def["atom_virial"].squeeze(-3) + output_def["atom_virial"].squeeze(-2) if "mask" in out_def_data: output_def["mask"] = out_def_data["mask"] return output_def @@ -122,7 +122,7 @@ def forward_lower( if do_atomic_virial: model_predict["extended_virial"] = model_ret[ "dipole_derv_c" - ].squeeze(-3) + ].squeeze(-2) else: model_predict = model_ret return model_predict diff --git a/deepmd/pt/model/model/dp_linear_model.py b/deepmd/pt/model/model/dp_linear_model.py index b43f849258..b95f568cb1 100644 --- a/deepmd/pt/model/model/dp_linear_model.py +++ b/deepmd/pt/model/model/dp_linear_model.py @@ -52,7 +52,7 @@ def translated_output_def(self) -> dict[str, OutputVariableDef]: output_def["virial"] = out_def_data["energy_derv_c_redu"] output_def["virial"].squeeze(-2) output_def["atom_virial"] = out_def_data["energy_derv_c"] - output_def["atom_virial"].squeeze(-3) + output_def["atom_virial"].squeeze(-2) if "mask" in out_def_data: output_def["mask"] = out_def_data["mask"] return output_def @@ -83,7 +83,7 @@ def forward( if self.do_grad_c("energy"): model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: - model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2) else: model_predict["force"] = model_ret["dforce"] if "mask" in model_ret: diff --git a/deepmd/pt/model/model/dp_zbl_model.py b/deepmd/pt/model/model/dp_zbl_model.py index d533cbe125..07f0732687 100644 --- a/deepmd/pt/model/model/dp_zbl_model.py +++ b/deepmd/pt/model/model/dp_zbl_model.py @@ -49,7 +49,7 @@ def translated_output_def(self) -> dict[str, Any]: output_def["virial"] = out_def_data["energy_derv_c_redu"] output_def["virial"].squeeze(-2) output_def["atom_virial"] = out_def_data["energy_derv_c"] - output_def["atom_virial"].squeeze(-3) + output_def["atom_virial"].squeeze(-2) if "mask" in out_def_data: output_def["mask"] = out_def_data["mask"] return output_def @@ -80,7 +80,7 @@ def forward( if self.do_grad_c("energy"): model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: - model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-3) + model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze(-2) else: model_predict["force"] = model_ret["dforce"] if "mask" in model_ret: diff --git a/deepmd/pt/model/model/ener_model.py b/deepmd/pt/model/model/ener_model.py index 8f8a3cbad7..36beb33ff6 100644 --- a/deepmd/pt/model/model/ener_model.py +++ b/deepmd/pt/model/model/ener_model.py @@ -83,7 +83,7 @@ def translated_output_def(self) -> dict[str, Any]: output_def["virial"] = out_def_data["energy_derv_c_redu"] output_def["virial"].squeeze(-2) output_def["atom_virial"] = out_def_data["energy_derv_c"] - output_def["atom_virial"].squeeze(-3) + output_def["atom_virial"].squeeze(-2) if "mask" in out_def_data: output_def["mask"] = out_def_data["mask"] if self._hessian_enabled: @@ -117,7 +117,7 @@ def forward( model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) if do_atomic_virial: model_predict["atom_virial"] = model_ret["energy_derv_c"].squeeze( - -3 + -2 ) else: model_predict["force"] = model_ret["dforce"] @@ -164,7 +164,7 @@ def forward_lower( if do_atomic_virial: model_predict["extended_virial"] = model_ret[ "energy_derv_c" - ].squeeze(-3) + ].squeeze(-2) else: assert model_ret["dforce"] is not None model_predict["dforce"] = model_ret["dforce"]