From ec2e031e4384a66e96c2d64d45f4dec897bb769d Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 07:48:20 +0800 Subject: [PATCH 01/35] implement pytorch-exportable for se_e2_a descriptor --- deepmd/backend/pt_expt.py | 126 ++++++++++++++++ deepmd/dpmodel/descriptor/se_e2_a.py | 6 +- deepmd/pt_expt/__init__.py | 1 + deepmd/pt_expt/descriptor/__init__.py | 8 ++ deepmd/pt_expt/descriptor/se_e2_a.py | 101 +++++++++++++ deepmd/pt_expt/utils/__init__.py | 1 + deepmd/pt_expt/utils/network.py | 130 +++++++++++++++++ source/tests/consistent/common.py | 80 ++++++++++- source/tests/consistent/descriptor/common.py | 31 +++- .../consistent/descriptor/test_se_e2_a.py | 60 ++++++++ source/tests/pt_expt/__init__.py | 1 + source/tests/pt_expt/model/__init__.py | 1 + source/tests/pt_expt/model/test_se_e2_a.py | 135 ++++++++++++++++++ 13 files changed, 676 insertions(+), 5 deletions(-) create mode 100644 deepmd/backend/pt_expt.py create mode 100644 deepmd/pt_expt/__init__.py create mode 100644 deepmd/pt_expt/descriptor/__init__.py create mode 100644 deepmd/pt_expt/descriptor/se_e2_a.py create mode 100644 deepmd/pt_expt/utils/__init__.py create mode 100644 deepmd/pt_expt/utils/network.py create mode 100644 source/tests/pt_expt/__init__.py create mode 100644 source/tests/pt_expt/model/__init__.py create mode 100644 source/tests/pt_expt/model/test_se_e2_a.py diff --git a/deepmd/backend/pt_expt.py b/deepmd/backend/pt_expt.py new file mode 100644 index 0000000000..38745c690c --- /dev/null +++ b/deepmd/backend/pt_expt.py @@ -0,0 +1,126 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from collections.abc import ( + Callable, +) +from importlib.util import ( + find_spec, +) +from typing import ( + TYPE_CHECKING, + ClassVar, +) + +from deepmd.backend.backend import ( + Backend, +) + +if TYPE_CHECKING: + from argparse import ( + Namespace, + ) + + from deepmd.infer.deep_eval import ( + DeepEvalBackend, + ) + from deepmd.utils.neighbor_stat import ( + NeighborStat, + ) + + 
+@Backend.register("pt-expt") +@Backend.register("pytorch-exportable") +class PyTorchExportableBackend(Backend): + """PyTorch exportable backend.""" + + name = "PyTorch Exportable" + """The formal name of the backend.""" + features: ClassVar[Backend.Feature] = ( + Backend.Feature.ENTRY_POINT + | Backend.Feature.DEEP_EVAL + | Backend.Feature.NEIGHBOR_STAT + | Backend.Feature.IO + ) + """The features of the backend.""" + suffixes: ClassVar[list[str]] = [".pth", ".pt"] + """The suffixes of the backend.""" + + def is_available(self) -> bool: + """Check if the backend is available. + + Returns + ------- + bool + Whether the backend is available. + """ + return find_spec("torch") is not None + + @property + def entry_point_hook(self) -> Callable[["Namespace"], None]: + """The entry point hook of the backend. + + Returns + ------- + Callable[[Namespace], None] + The entry point hook of the backend. + """ + from deepmd.pt.entrypoints.main import main as deepmd_main + + return deepmd_main + + @property + def deep_eval(self) -> type["DeepEvalBackend"]: + """The Deep Eval backend of the backend. + + Returns + ------- + type[DeepEvalBackend] + The Deep Eval backend of the backend. + """ + from deepmd.pt.infer.deep_eval import DeepEval as DeepEvalPT + + return DeepEvalPT + + @property + def neighbor_stat(self) -> type["NeighborStat"]: + """The neighbor statistics of the backend. + + Returns + ------- + type[NeighborStat] + The neighbor statistics of the backend. + """ + from deepmd.pt.utils.neighbor_stat import ( + NeighborStat, + ) + + return NeighborStat + + @property + def serialize_hook(self) -> Callable[[str], dict]: + """The serialize hook to convert the model file to a dictionary. + + Returns + ------- + Callable[[str], dict] + The serialize hook of the backend. 
+ """ + from deepmd.pt.utils.serialization import ( + serialize_from_file, + ) + + return serialize_from_file + + @property + def deserialize_hook(self) -> Callable[[str, dict], None]: + """The deserialize hook to convert the dictionary to a model file. + + Returns + ------- + Callable[[str, dict], None] + The deserialize hook of the backend. + """ + from deepmd.pt.utils.serialization import ( + deserialize_to_file, + ) + + return deserialize_to_file diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index c09a6cbdc3..a6b17bf69a 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -607,7 +607,11 @@ def call( sec = self.sel_cumsum ng = self.neuron[-1] - gr = xp.zeros([nf * nloc, ng, 4], dtype=self.dstd.dtype) + gr = xp.zeros( + [nf * nloc, ng, 4], + dtype=self.dstd.dtype, + device=array_api_compat.device(coord_ext), + ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # merge nf and nloc axis, so for type_one_side == False, # we don't require atype is the same in all frames diff --git a/deepmd/pt_expt/__init__.py b/deepmd/pt_expt/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt_expt/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py new file mode 100644 index 0000000000..fdac48ed41 --- /dev/null +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -0,0 +1,8 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from .se_e2_a import ( + DescrptSeA, +) + +__all__ = [ + "DescrptSeA", +] diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py new file mode 100644 index 0000000000..4334011ec3 --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -0,0 +1,101 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch # noqa: TID253 + +from 
deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP +from deepmd.pt.model.descriptor.base_descriptor import ( # noqa: TID253 + BaseDescriptor, +) +from deepmd.pt.utils import ( # noqa: TID253 + env, +) +from deepmd.pt.utils.exclude_mask import ( # noqa: TID253 + PairExcludeMask, +) +from deepmd.pt_expt.utils.network import ( + NetworkCollection, +) + + +@BaseDescriptor.register("se_e2_a_expt") +@BaseDescriptor.register("se_a_expt") +class DescrptSeA(DescrptSeADP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptSeADP.__init__(self, *args, **kwargs) + self._convert_state() + + def __setattr__(self, name: str, value: Any) -> None: + if name in {"davg", "dstd"} and "_buffers" in self.__dict__: + tensor = ( + None if value is None else torch.as_tensor(value, device=env.DEVICE) + ) + if name in self._buffers: + self._buffers[name] = tensor + return + return super().__setattr__(name, tensor) + if name == "embeddings" and "_modules" in self.__dict__: + if value is not None and not isinstance(value, torch.nn.Module): + if hasattr(value, "serialize"): + value = NetworkCollection.deserialize(value.serialize()) + elif isinstance(value, dict): + value = NetworkCollection.deserialize(value) + return super().__setattr__(name, value) + if name == "emask" and "_modules" in self.__dict__: + if value is not None and not isinstance(value, torch.nn.Module): + value = PairExcludeMask( + self.ntypes, exclude_types=list(value.get_exclude_types()) + ) + return super().__setattr__(name, value) + return super().__setattr__(name, value) + + def _convert_state(self) -> None: + if self.davg is not None: + davg = torch.as_tensor(self.davg, device=env.DEVICE) + if "davg" in self._buffers: + self._buffers["davg"] = davg + else: + if hasattr(self, "davg"): + delattr(self, "davg") + self.register_buffer("davg", davg) + if self.dstd is not None: + dstd = torch.as_tensor(self.dstd, device=env.DEVICE) + if 
"dstd" in self._buffers: + self._buffers["dstd"] = dstd + else: + if hasattr(self, "dstd"): + delattr(self, "dstd") + self.register_buffer("dstd", dstd) + if self.embeddings is not None: + self.embeddings = NetworkCollection.deserialize(self.embeddings.serialize()) + if self.emask is not None: + self.emask = PairExcludeMask( + self.ntypes, exclude_types=list(self.emask.get_exclude_types()) + ) + + def forward( + self, + nlist: torch.Tensor, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + extended_atype_embd: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + type_embedding: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + del extended_atype_embd, type_embedding + descrpt, rot_mat, g2, h2, sw = self.call( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + return descrpt, rot_mat, g2, h2, sw diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt_expt/utils/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py new file mode 100644 index 0000000000..f29d8970b3 --- /dev/null +++ b/deepmd/pt_expt/utils/network.py @@ -0,0 +1,130 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, + ClassVar, + Self, +) + +import numpy as np +import torch # noqa: TID253 + +from deepmd.dpmodel.common import ( + NativeOP, +) +from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP +from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP +from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP +from deepmd.dpmodel.utils.network import ( + make_embedding_network, + make_fitting_network, + make_multilayer_network, +) +from deepmd.pt.utils import ( # noqa: 
TID253 + env, +) + + +def _to_torch_array(value: Any) -> torch.Tensor | None: + if value is None: + return None + if torch.is_tensor(value): + return value + return torch.as_tensor(value, device=env.DEVICE) + + +class TorchArrayParam(torch.nn.Parameter): + def __new__(cls, data: Any = None, requires_grad: bool = True) -> Self: + return torch.nn.Parameter.__new__(cls, data, requires_grad) + + def __array__(self, dtype: Any | None = None) -> np.ndarray: + arr = self.detach().cpu().numpy() + if dtype is None: + return arr + return arr.astype(dtype) + + +class NativeLayer(NativeLayerDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + NativeLayerDP.__init__(self, *args, **kwargs) + for name in ("w", "b", "idt"): + if name in self._parameters or name in self._buffers: + continue + val = _to_torch_array(getattr(self, name)) + if val is None: + continue + if self.trainable: + if hasattr(self, name) and name not in self._parameters: + delattr(self, name) + self.register_parameter(name, TorchArrayParam(val, requires_grad=True)) + else: + if hasattr(self, name) and name not in self._buffers: + delattr(self, name) + self.register_buffer(name, val) + + def __setattr__(self, name: str, value: Any) -> None: + if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: + val = _to_torch_array(value) + if val is None: + return super().__setattr__(name, None) + if getattr(self, "trainable", False): + param = ( + value + if isinstance(value, TorchArrayParam) + else TorchArrayParam(val, requires_grad=True) + ) + if name in self._parameters: + self._parameters[name] = param + return + return super().__setattr__(name, param) + if name in self._buffers: + self._buffers[name] = val + return + return super().__setattr__(name, val) + return super().__setattr__(name, value) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.call(x) + + +class NativeNet(make_multilayer_network(NativeLayer, NativeOP), 
torch.nn.Module): + def __init__(self, layers: list[dict] | None = None) -> None: + torch.nn.Module.__init__(self) + super().__init__(layers) + self.layers = torch.nn.ModuleList(self.layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.call(x) + + +class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)): + pass + + +class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): + pass + + +class NetworkCollection(NetworkCollectionDP, torch.nn.Module): + NETWORK_TYPE_MAP: ClassVar[dict[str, type]] = { + "network": NativeNet, + "embedding_network": EmbeddingNet, + "fitting_network": FittingNet, + } + + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + super().__init__(*args, **kwargs) + self._module_networks = torch.nn.ModuleDict() + for idx, net in enumerate(self._networks): + if isinstance(net, torch.nn.Module): + self._module_networks[str(idx)] = net + + def __setitem__(self, key: int | tuple, value: Any) -> None: + super().__setitem__(key, value) + if isinstance(value, torch.nn.Module): + self._module_networks[str(self._convert_key(key))] = value + + +class LayerNorm(LayerNormDP, NativeLayer): + pass diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index 88fad4e10b..3d60f6def0 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -41,6 +41,11 @@ INSTALLED_TF = Backend.get_backend("tensorflow")().is_available() INSTALLED_PT = Backend.get_backend("pytorch")().is_available() +try: + _PT_EXPT_BACKEND = Backend.get_backend("pytorch-exportable") +except (KeyError, RuntimeError): + _PT_EXPT_BACKEND = None +INSTALLED_PT_EXPT = _PT_EXPT_BACKEND is not None and _PT_EXPT_BACKEND().is_available() INSTALLED_JAX = Backend.get_backend("jax")().is_available() INSTALLED_PD = Backend.get_backend("paddle")().is_available() INSTALLED_ARRAY_API_STRICT = find_spec("array_api_strict") is not None @@ -67,6 +72,7 @@ 
"INSTALLED_JAX", "INSTALLED_PD", "INSTALLED_PT", + "INSTALLED_PT_EXPT", "INSTALLED_TF", "CommonTest", "CommonTest", @@ -86,6 +92,8 @@ class CommonTest(ABC): """Native DP model class.""" pt_class: ClassVar[type | None] """PyTorch model class.""" + pt_expt_class: ClassVar[type | None] + """PyTorch exportable model class.""" jax_class: ClassVar[type | None] """JAX model class.""" pd_class: ClassVar[type | None] @@ -99,6 +107,8 @@ class CommonTest(ABC): """Whether to skip the TensorFlow model.""" skip_pt: ClassVar[bool] = not INSTALLED_PT """Whether to skip the PyTorch model.""" + skip_pt_expt: ClassVar[bool] = not INSTALLED_PT_EXPT + """Whether to skip the PyTorch exportable model.""" # we may usually skip jax before jax is fully supported skip_jax: ClassVar[bool] = True """Whether to skip the JAX model.""" @@ -176,6 +186,16 @@ def eval_pt(self, pt_obj: Any) -> Any: The object of PT """ + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + """Evaluate the return value of PT exportable. + + Parameters + ---------- + pt_expt_obj : Any + The object of PT exportable + """ + raise NotImplementedError("Not implemented") + def eval_jax(self, jax_obj: Any) -> Any: """Evaluate the return value of JAX. 
@@ -212,9 +232,10 @@ class RefBackend(Enum): TF = 1 DP = 2 PT = 3 - PD = 4 - JAX = 5 - ARRAY_API_STRICT = 6 + PT_EXPT = 4 + PD = 5 + JAX = 6 + ARRAY_API_STRICT = 7 @abstractmethod def extract_ret(self, ret: Any, backend: RefBackend) -> tuple[np.ndarray, ...]: @@ -275,6 +296,11 @@ def get_dp_ret_serialization_from_cls(self, obj): data = obj.serialize() return ret, data + def get_pt_expt_ret_serialization_from_cls(self, obj): + ret = self.eval_pt_expt(obj) + data = obj.serialize() + return ret, data + def get_jax_ret_serialization_from_cls(self, obj): ret = self.eval_jax(obj) data = obj.serialize() @@ -301,6 +327,8 @@ def get_reference_backend(self): return self.RefBackend.TF if not self.skip_pt: return self.RefBackend.PT + if not self.skip_pt_expt and self.pt_expt_class is not None: + return self.RefBackend.PT_EXPT if not self.skip_jax: return self.RefBackend.JAX if not self.skip_pd: @@ -320,6 +348,11 @@ def get_reference_ret_serialization(self, ref: RefBackend): if ref == self.RefBackend.PT: obj = self.init_backend_cls(self.pt_class) return self.get_pt_ret_serialization_from_cls(obj) + if ref == self.RefBackend.PT_EXPT: + if self.pt_expt_class is None: + raise ValueError("PT exportable class is not set") + obj = self.init_backend_cls(self.pt_expt_class) + return self.get_pt_expt_ret_serialization_from_cls(obj) if ref == self.RefBackend.JAX: obj = self.init_backend_cls(self.jax_class) return self.get_jax_ret_serialization_from_cls(obj) @@ -456,6 +489,47 @@ def test_pt_self_consistent(self) -> None: else: self.assertEqual(rr1, rr2) + def test_pt_expt_consistent_with_ref(self) -> None: + """Test whether PT exportable and reference are consistent.""" + if self.skip_pt_expt or self.pt_expt_class is None: + self.skipTest("Unsupported backend") + ref_backend = self.get_reference_backend() + if ref_backend == self.RefBackend.PT_EXPT: + self.skipTest("Reference is self") + ret1, data1 = self.get_reference_ret_serialization(ref_backend) + ret1 = self.extract_ret(ret1, 
ref_backend) + obj = self.pt_expt_class.deserialize(data1) + ret2 = self.eval_pt_expt(obj) + ret2 = self.extract_ret(ret2, self.RefBackend.PT_EXPT) + data2 = obj.serialize() + if obj.__class__.__name__.startswith(("Polar", "Dipole", "DOS")): + common_keys = set(data1.keys()) & set(data2.keys()) + data1 = {k: data1[k] for k in common_keys} + data2 = {k: data2[k] for k in common_keys} + # drop @variables since they are not equal + data1.pop("@variables", None) + data2.pop("@variables", None) + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2, strict=True): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + + def test_pt_expt_self_consistent(self) -> None: + """Test whether PT exportable is self consistent.""" + if self.skip_pt_expt or self.pt_expt_class is None: + self.skipTest("Unsupported backend") + obj1 = self.init_backend_cls(self.pt_expt_class) + ret1, data1 = self.get_pt_expt_ret_serialization_from_cls(obj1) + obj2 = self.pt_expt_class.deserialize(data1) + ret2, data2 = self.get_pt_expt_ret_serialization_from_cls(obj2) + np.testing.assert_equal(data1, data2) + for rr1, rr2 in zip(ret1, ret2, strict=True): + if isinstance(rr1, np.ndarray) and isinstance(rr2, np.ndarray): + np.testing.assert_allclose(rr1, rr2, rtol=self.rtol, atol=self.atol) + assert rr1.dtype == rr2.dtype, f"{rr1.dtype} != {rr2.dtype}" + else: + self.assertEqual(rr1, rr2) + def test_jax_consistent_with_ref(self) -> None: """Test whether JAX and reference are consistent.""" if self.skip_jax: diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 8af1c7ea64..7c8cbce744 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -21,10 +21,11 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, ) -if INSTALLED_PT: +if INSTALLED_PT or INSTALLED_PT_EXPT: import 
torch from deepmd.pt.utils.env import DEVICE as PT_DEVICE @@ -143,6 +144,34 @@ def eval_pt_descriptor( for x in pt_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) ] + def eval_pt_expt_descriptor( + self, + pt_expt_obj: Any, + natoms: np.ndarray, + coords: np.ndarray, + atype: np.ndarray, + box: np.ndarray, + mixed_types: bool = False, + ) -> Any: + ext_coords, ext_atype, mapping = extend_coord_with_ghosts( + torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), + torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), + torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), + pt_expt_obj.get_rcut(), + ) + nlist = build_neighbor_list( + ext_coords, + ext_atype, + natoms[0], + pt_expt_obj.get_rcut(), + pt_expt_obj.get_sel(), + distinguish_types=(not mixed_types), + ) + return [ + x.detach().cpu().numpy() if torch.is_tensor(x) else x + for x in pt_expt_obj(ext_coords, ext_atype, nlist=nlist, mapping=mapping) + ] + def eval_jax_descriptor( self, jax_obj: Any, diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index b345a61ed3..68a0068965 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -16,6 +16,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -33,6 +34,10 @@ ) else: DescrptSeAPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.se_e2_a import DescrptSeA as DescrptSeAPTExpt +else: + DescrptSeAPTExpt = None if INSTALLED_TF: from deepmd.tf.descriptor.se_a import DescrptSeA as DescrptSeATF else: @@ -107,6 +112,17 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + env_protection, + ) = self.param + return (not type_one_side) or CommonTest.skip_pt_expt + @property def skip_dp(self) -> bool: ( @@ -165,6 +181,7 @@ def 
skip_array_api_strict(self) -> bool: tf_class = DescrptSeATF dp_class = DescrptSeADP pt_class = DescrptSeAPT + pt_expt_class = DescrptSeAPTExpt jax_class = DescrptSeAJAX pd_class = DescrptSeAPD array_api_strict_class = DescrptSeAArrayAPIStrict @@ -244,6 +261,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, @@ -351,6 +377,17 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + env_protection, + ) = self.param + return (not type_one_side) or CommonTest.skip_pt_expt + @property def skip_dp(self) -> bool: ( @@ -402,6 +439,7 @@ def skip_array_api_strict(self) -> bool: tf_class = DescrptSeATF dp_class = DescrptSeADP pt_class = DescrptSeAPT + pt_expt_class = DescrptSeAPTExpt jax_class = DescrptSeAJAX pd_class = DescrptSeAPD array_api_strict_class = DescrptSeAArrayAPIStrict @@ -505,6 +543,28 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + pt_expt_obj.compute_input_stats( + [ + { + "r0": None, + "coord": torch.from_numpy(self.coords) + .reshape(-1, self.natoms[0], 3) + .to(env.DEVICE), + "atype": torch.from_numpy(self.atype.reshape(1, -1)).to(env.DEVICE), + "box": torch.from_numpy(self.box.reshape(1, 3, 3)).to(env.DEVICE), + "natoms": self.natoms[0], + } + ] + ) + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: jax_obj.compute_input_stats( [ diff --git a/source/tests/pt_expt/__init__.py b/source/tests/pt_expt/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/__init__.py @@ -0,0 
+1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/model/__init__.py b/source/tests/pt_expt/model/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/model/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/model/test_se_e2_a.py b/source/tests/pt_expt/model/test_se_e2_a.py new file mode 100644 index 0000000000..b9b834849f --- /dev/null +++ b/source/tests/pt_expt/model/test_se_e2_a.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch # noqa: TID253 + +from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA +from deepmd.pt.utils import ( # noqa: TID253 + env, +) +from deepmd.pt.utils.env import ( # noqa: TID253 + PRECISION_DICT, +) +from deepmd.pt.utils.exclude_mask import ( # noqa: TID253 + PairExcludeMask, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec, em in itertools.product( + [False, True], + ["float64", "float32"], + [[], [[0, 1]], [[1, 1]]], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + exclude_types=em, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = 
torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + dd1 = DescrptSeA.deserialize(dd0.serialize()) + rd1, gr1, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeA.deserialize(dd0.serialize()) + rd2, gr2, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + for aa, bb in zip([rd1, gr1, sw1], [rd2, gr2, sw2], strict=True): + np.testing.assert_allclose( + aa.detach().cpu().numpy(), + bb, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + if em: + dd1.reinit_exclude([tuple(x) for x in em]) + self.assertIsInstance(dd1.emask, PairExcludeMask) + + def test_exportable(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + 
torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) From b8a48ffe6bdd97ff3dbee7ef13182aa4bcf03a87 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 07:53:13 +0800 Subject: [PATCH 02/35] better type for xp.zeros --- deepmd/dpmodel/descriptor/se_e2_a.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index a6b17bf69a..3ca28ba556 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -609,7 +609,7 @@ def call( ng = self.neuron[-1] gr = xp.zeros( [nf * nloc, ng, 4], - dtype=self.dstd.dtype, + dtype=input_dtype, device=array_api_compat.device(coord_ext), ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) From 1cc001f7f262e6d76f40a8c7d36d9e80a9f1dd19 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 09:51:22 +0800 Subject: [PATCH 03/35] implement env, base_descriptor and exclude_mask, remove the dependency on pt backend. 
--- deepmd/pt_expt/descriptor/__init__.py | 4 + deepmd/pt_expt/descriptor/base_descriptor.py | 10 ++ deepmd/pt_expt/descriptor/se_e2_a.py | 11 +- deepmd/pt_expt/utils/__init__.py | 10 ++ deepmd/pt_expt/utils/env.py | 117 +++++++++++++++++++ deepmd/pt_expt/utils/exclude_mask.py | 27 +++++ deepmd/pt_expt/utils/network.py | 6 +- source/tests/pt_expt/model/test_se_e2_a.py | 16 +-- 8 files changed, 187 insertions(+), 14 deletions(-) create mode 100644 deepmd/pt_expt/descriptor/base_descriptor.py create mode 100644 deepmd/pt_expt/utils/env.py create mode 100644 deepmd/pt_expt/utils/exclude_mask.py diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py index fdac48ed41..089e5619e0 100644 --- a/deepmd/pt_expt/descriptor/__init__.py +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -1,8 +1,12 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from .base_descriptor import ( + BaseDescriptor, +) from .se_e2_a import ( DescrptSeA, ) __all__ = [ + "BaseDescriptor", "DescrptSeA", ] diff --git a/deepmd/pt_expt/descriptor/base_descriptor.py b/deepmd/pt_expt/descriptor/base_descriptor.py new file mode 100644 index 0000000000..51e9325bba --- /dev/null +++ b/deepmd/pt_expt/descriptor/base_descriptor.py @@ -0,0 +1,10 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib + +from deepmd.dpmodel.descriptor import ( + make_base_descriptor, +) + +torch = importlib.import_module("torch") + +BaseDescriptor = make_base_descriptor(torch.Tensor, "forward") diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 4334011ec3..bb0c0cb2bd 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -1,24 +1,25 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import importlib from typing import ( Any, ) -import torch # noqa: TID253 - from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP -from deepmd.pt.model.descriptor.base_descriptor import ( # noqa: TID253 
+from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.pt.utils import ( # noqa: TID253 +from deepmd.pt_expt.utils import ( env, ) -from deepmd.pt.utils.exclude_mask import ( # noqa: TID253 +from deepmd.pt_expt.utils.exclude_mask import ( PairExcludeMask, ) from deepmd.pt_expt.utils.network import ( NetworkCollection, ) +torch = importlib.import_module("torch") + @BaseDescriptor.register("se_e2_a_expt") @BaseDescriptor.register("se_a_expt") diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index 6ceb116d85..f90cf82249 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ b/deepmd/pt_expt/utils/__init__.py @@ -1 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later + +from .exclude_mask import ( + AtomExcludeMask, + PairExcludeMask, +) + +__all__ = [ + "AtomExcludeMask", + "PairExcludeMask", +] diff --git a/deepmd/pt_expt/utils/env.py b/deepmd/pt_expt/utils/env.py new file mode 100644 index 0000000000..bd644e7206 --- /dev/null +++ b/deepmd/pt_expt/utils/env.py @@ -0,0 +1,117 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import logging +import multiprocessing +import os +import sys + +import numpy as np + +from deepmd.common import ( + VALID_PRECISION, +) +from deepmd.env import ( + GLOBAL_ENER_FLOAT_PRECISION, + GLOBAL_NP_FLOAT_PRECISION, + get_default_nthreads, + set_default_nthreads, +) + +log = logging.getLogger(__name__) +torch = importlib.import_module("torch") + +if sys.platform != "win32": + try: + multiprocessing.set_start_method("fork", force=True) + log.debug("Successfully set multiprocessing start method to 'fork'.") + except (RuntimeError, ValueError) as err: + log.warning(f"Could not set multiprocessing start method: {err}") +else: + log.debug("Skipping fork start method on Windows (not supported).") + +SAMPLER_RECORD = os.environ.get("SAMPLER_RECORD", False) +DP_DTYPE_PROMOTION_STRICT = os.environ.get("DP_DTYPE_PROMOTION_STRICT", "0") == "1" +try: + # only linux + ncpus = 
len(os.sched_getaffinity(0)) +except AttributeError: + ncpus = os.cpu_count() +NUM_WORKERS = int(os.environ.get("NUM_WORKERS", min(4, ncpus))) +if multiprocessing.get_start_method() != "fork": + # spawn or forkserver does not support NUM_WORKERS > 0 for DataLoader + log.warning( + "NUM_WORKERS > 0 is not supported with spawn or forkserver start method. " + "Setting NUM_WORKERS to 0." + ) + NUM_WORKERS = 0 + +# Make sure DDP uses correct device if applicable +LOCAL_RANK = os.environ.get("LOCAL_RANK") +LOCAL_RANK = int(0 if LOCAL_RANK is None else LOCAL_RANK) + +if os.environ.get("DEVICE") == "cpu" or torch.cuda.is_available() is False: + DEVICE = torch.device("cpu") +else: + DEVICE = torch.device(f"cuda:{LOCAL_RANK}") + +JIT = False +CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory +ENERGY_BIAS_TRAINABLE = True +CUSTOM_OP_USE_JIT = False + +PRECISION_DICT = { + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + "half": torch.float16, + "single": torch.float32, + "double": torch.float64, + "int32": torch.int32, + "int64": torch.int64, + "bfloat16": torch.bfloat16, + "bool": torch.bool, +} +GLOBAL_PT_FLOAT_PRECISION = PRECISION_DICT[np.dtype(GLOBAL_NP_FLOAT_PRECISION).name] +GLOBAL_PT_ENER_FLOAT_PRECISION = PRECISION_DICT[ + np.dtype(GLOBAL_ENER_FLOAT_PRECISION).name +] +PRECISION_DICT["default"] = GLOBAL_PT_FLOAT_PRECISION +assert VALID_PRECISION.issubset(PRECISION_DICT.keys()) +# cannot automatically generated +RESERVED_PRECISION_DICT = { + torch.float16: "float16", + torch.float32: "float32", + torch.float64: "float64", + torch.int32: "int32", + torch.int64: "int64", + torch.bfloat16: "bfloat16", + torch.bool: "bool", +} +assert set(PRECISION_DICT.values()) == set(RESERVED_PRECISION_DICT.keys()) +DEFAULT_PRECISION = "float64" + +# throw warnings if threads not set +set_default_nthreads() +inter_nthreads, intra_nthreads = get_default_nthreads() +if inter_nthreads > 0: # the behavior of 0 is not documented + 
torch.set_num_interop_threads(inter_nthreads) +if intra_nthreads > 0: + torch.set_num_threads(intra_nthreads) + +__all__ = [ + "CACHE_PER_SYS", + "CUSTOM_OP_USE_JIT", + "DEFAULT_PRECISION", + "DEVICE", + "ENERGY_BIAS_TRAINABLE", + "GLOBAL_ENER_FLOAT_PRECISION", + "GLOBAL_NP_FLOAT_PRECISION", + "GLOBAL_PT_ENER_FLOAT_PRECISION", + "GLOBAL_PT_FLOAT_PRECISION", + "JIT", + "LOCAL_RANK", + "NUM_WORKERS", + "PRECISION_DICT", + "RESERVED_PRECISION_DICT", + "SAMPLER_RECORD", +] diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py new file mode 100644 index 0000000000..ed296c9f98 --- /dev/null +++ b/deepmd/pt_expt/utils/exclude_mask.py @@ -0,0 +1,27 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +from typing import ( + Any, +) + +from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP +from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP +from deepmd.pt_expt.utils import ( + env, +) + +torch = importlib.import_module("torch") + + +class AtomExcludeMask(AtomExcludeMaskDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "type_mask": + value = None if value is None else torch.as_tensor(value, device=env.DEVICE) + return super().__setattr__(name, value) + + +class PairExcludeMask(PairExcludeMaskDP): + def __setattr__(self, name: str, value: Any) -> None: + if name == "type_mask": + value = None if value is None else torch.as_tensor(value, device=env.DEVICE) + return super().__setattr__(name, value) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index f29d8970b3..91a6999766 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import importlib from typing import ( Any, ClassVar, @@ -6,7 +7,6 @@ ) import numpy as np -import torch # noqa: TID253 from deepmd.dpmodel.common import ( NativeOP, @@ -19,10 +19,12 @@ 
make_fitting_network, make_multilayer_network, ) -from deepmd.pt.utils import ( # noqa: TID253 +from deepmd.pt_expt.utils import ( env, ) +torch = importlib.import_module("torch") + def _to_torch_array(value: Any) -> torch.Tensor | None: if value is None: diff --git a/source/tests/pt_expt/model/test_se_e2_a.py b/source/tests/pt_expt/model/test_se_e2_a.py index b9b834849f..57923b97a3 100644 --- a/source/tests/pt_expt/model/test_se_e2_a.py +++ b/source/tests/pt_expt/model/test_se_e2_a.py @@ -1,23 +1,23 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import importlib import itertools import unittest import numpy as np -import torch # noqa: TID253 from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA -from deepmd.pt.utils import ( # noqa: TID253 +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.utils import ( env, ) -from deepmd.pt.utils.env import ( # noqa: TID253 +from deepmd.pt_expt.utils.env import ( PRECISION_DICT, ) -from deepmd.pt.utils.exclude_mask import ( # noqa: TID253 +from deepmd.pt_expt.utils.exclude_mask import ( PairExcludeMask, ) -from deepmd.pt_expt.descriptor.se_e2_a import ( - DescrptSeA, -) from ...pt.model.test_env_mat import ( TestCaseSingleFrameWithNlist, @@ -29,6 +29,8 @@ GLOBAL_SEED, ) +torch = importlib.import_module("torch") + class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self) -> None: From f2fbe8884fdacf4369dbec847c06ebc54b453d02 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 10:08:29 +0800 Subject: [PATCH 04/35] mv to_torch_tensor to common --- deepmd/pt_expt/common.py | 35 +++++++++++++++++++++++++++++++++ deepmd/pt_expt/utils/network.py | 16 ++++----------- 2 files changed, 39 insertions(+), 12 deletions(-) create mode 100644 deepmd/pt_expt/common.py diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py new file mode 100644 index 0000000000..f065eeb76d --- /dev/null +++ b/deepmd/pt_expt/common.py @@ -0,0 +1,35 @@ +# 
SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +from typing import ( + Any, + overload, +) + +import numpy as np + +from deepmd.pt_expt.utils import ( + env, +) + +torch = importlib.import_module("torch") + + +@overload +def to_torch_array(array: np.ndarray) -> torch.Tensor: ... + + +@overload +def to_torch_array(array: None) -> None: ... + + +@overload +def to_torch_array(array: torch.Tensor) -> torch.Tensor: ... + + +def to_torch_array(array: Any) -> torch.Tensor | None: + """Convert input to a torch tensor on the pt-expt device.""" + if array is None: + return None + if torch.is_tensor(array): + return array + return torch.as_tensor(array, device=env.DEVICE) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 91a6999766..18840200be 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -19,21 +19,13 @@ make_fitting_network, make_multilayer_network, ) -from deepmd.pt_expt.utils import ( - env, +from deepmd.pt_expt.common import ( + to_torch_array, ) torch = importlib.import_module("torch") -def _to_torch_array(value: Any) -> torch.Tensor | None: - if value is None: - return None - if torch.is_tensor(value): - return value - return torch.as_tensor(value, device=env.DEVICE) - - class TorchArrayParam(torch.nn.Parameter): def __new__(cls, data: Any = None, requires_grad: bool = True) -> Self: return torch.nn.Parameter.__new__(cls, data, requires_grad) @@ -52,7 +44,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: for name in ("w", "b", "idt"): if name in self._parameters or name in self._buffers: continue - val = _to_torch_array(getattr(self, name)) + val = to_torch_array(getattr(self, name)) if val is None: continue if self.trainable: @@ -66,7 +58,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: - val = _to_torch_array(value) + val = 
to_torch_array(value) if val is None: return super().__setattr__(name, None) if getattr(self, "trainable", False): From e2afbe9c190ffef45315cac5089e067c7da800c5 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 10:26:15 +0800 Subject: [PATCH 05/35] simplify __init__ of the NaiveLayer --- deepmd/pt_expt/utils/network.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 18840200be..5708197c66 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -41,20 +41,6 @@ class NativeLayer(NativeLayerDP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) NativeLayerDP.__init__(self, *args, **kwargs) - for name in ("w", "b", "idt"): - if name in self._parameters or name in self._buffers: - continue - val = to_torch_array(getattr(self, name)) - if val is None: - continue - if self.trainable: - if hasattr(self, name) and name not in self._parameters: - delattr(self, name) - self.register_parameter(name, TorchArrayParam(val, requires_grad=True)) - else: - if hasattr(self, name) and name not in self._buffers: - delattr(self, name) - self.register_buffer(name, val) def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: From 4ba511ac49d1093b3cabced160a953c90e8e0f81 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 10:32:09 +0800 Subject: [PATCH 06/35] fix bug --- deepmd/pt_expt/utils/network.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 5708197c66..f2230383de 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -3,7 +3,6 @@ from typing import ( Any, ClassVar, - Self, ) import numpy as np @@ -27,7 +26,9 @@ class TorchArrayParam(torch.nn.Parameter): - def __new__(cls, data: Any = None, requires_grad: 
bool = True) -> Self: + def __new__( # noqa: PYI034 + cls, data: Any = None, requires_grad: bool = True + ) -> "TorchArrayParam": return torch.nn.Parameter.__new__(cls, data, requires_grad) def __array__(self, dtype: Any | None = None) -> np.ndarray: From fb9598a68d13b1712b11b9ea481fd3e2ca85b502 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 10:43:19 +0800 Subject: [PATCH 07/35] fix bug --- deepmd/pt_expt/utils/network.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index f2230383de..7a85634dca 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -95,16 +95,18 @@ class NetworkCollection(NetworkCollectionDP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) - super().__init__(*args, **kwargs) self._module_networks = torch.nn.ModuleDict() - for idx, net in enumerate(self._networks): - if isinstance(net, torch.nn.Module): - self._module_networks[str(idx)] = net + super().__init__(*args, **kwargs) def __setitem__(self, key: int | tuple, value: Any) -> None: + idx = self._convert_key(key) super().__setitem__(key, value) - if isinstance(value, torch.nn.Module): - self._module_networks[str(self._convert_key(key))] = value + net = self._networks[idx] + key_str = str(idx) + if isinstance(net, torch.nn.Module): + self._module_networks[key_str] = net + elif key_str in self._module_networks: + del self._module_networks[key_str] class LayerNorm(LayerNormDP, NativeLayer): From fa03351be77fe9b1f1e5173d0855d6e6d912ab9e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 11:25:14 +0800 Subject: [PATCH 08/35] simplify init method of se_e2_a descriptor. 
fix bug in consistent UT --- deepmd/pt_expt/descriptor/se_e2_a.py | 25 ------------------------- source/tests/consistent/common.py | 2 +- 2 files changed, 1 insertion(+), 26 deletions(-) diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index bb0c0cb2bd..19a0d56734 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -27,7 +27,6 @@ class DescrptSeA(DescrptSeADP, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) DescrptSeADP.__init__(self, *args, **kwargs) - self._convert_state() def __setattr__(self, name: str, value: Any) -> None: if name in {"davg", "dstd"} and "_buffers" in self.__dict__: @@ -53,30 +52,6 @@ def __setattr__(self, name: str, value: Any) -> None: return super().__setattr__(name, value) return super().__setattr__(name, value) - def _convert_state(self) -> None: - if self.davg is not None: - davg = torch.as_tensor(self.davg, device=env.DEVICE) - if "davg" in self._buffers: - self._buffers["davg"] = davg - else: - if hasattr(self, "davg"): - delattr(self, "davg") - self.register_buffer("davg", davg) - if self.dstd is not None: - dstd = torch.as_tensor(self.dstd, device=env.DEVICE) - if "dstd" in self._buffers: - self._buffers["dstd"] = dstd - else: - if hasattr(self, "dstd"): - delattr(self, "dstd") - self.register_buffer("dstd", dstd) - if self.embeddings is not None: - self.embeddings = NetworkCollection.deserialize(self.embeddings.serialize()) - if self.emask is not None: - self.emask = PairExcludeMask( - self.ntypes, exclude_types=list(self.emask.get_exclude_types()) - ) - def forward( self, nlist: torch.Tensor, diff --git a/source/tests/consistent/common.py b/source/tests/consistent/common.py index 3d60f6def0..76b7e9cb53 100644 --- a/source/tests/consistent/common.py +++ b/source/tests/consistent/common.py @@ -92,7 +92,7 @@ class CommonTest(ABC): """Native DP model class.""" pt_class: ClassVar[type | None] 
"""PyTorch model class.""" - pt_expt_class: ClassVar[type | None] + pt_expt_class: ClassVar[type | None] = None """PyTorch exportable model class.""" jax_class: ClassVar[type | None] """JAX model class.""" From 09b33f19daef30dad5dcd27be4811cde994b8bad Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 11:34:44 +0800 Subject: [PATCH 09/35] restructure the test folders. add test_common. --- deepmd/pt_expt/common.py | 2 +- source/tests/pt_expt/descriptor/__init__.py | 1 + .../{model => descriptor}/test_se_e2_a.py | 0 source/tests/pt_expt/utils/__init__.py | 1 + source/tests/pt_expt/utils/test_common.py | 25 +++++++++++++++++++ 5 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 source/tests/pt_expt/descriptor/__init__.py rename source/tests/pt_expt/{model => descriptor}/test_se_e2_a.py (100%) create mode 100644 source/tests/pt_expt/utils/__init__.py create mode 100644 source/tests/pt_expt/utils/test_common.py diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index f065eeb76d..b66c0ff66d 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -31,5 +31,5 @@ def to_torch_array(array: Any) -> torch.Tensor | None: if array is None: return None if torch.is_tensor(array): - return array + return array.to(device=env.DEVICE) return torch.as_tensor(array, device=env.DEVICE) diff --git a/source/tests/pt_expt/descriptor/__init__.py b/source/tests/pt_expt/descriptor/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/descriptor/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/model/test_se_e2_a.py b/source/tests/pt_expt/descriptor/test_se_e2_a.py similarity index 100% rename from source/tests/pt_expt/model/test_se_e2_a.py rename to source/tests/pt_expt/descriptor/test_se_e2_a.py diff --git a/source/tests/pt_expt/utils/__init__.py b/source/tests/pt_expt/utils/__init__.py new file mode 100644 index 
0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/utils/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/utils/test_common.py b/source/tests/pt_expt/utils/test_common.py new file mode 100644 index 0000000000..63c4983f23 --- /dev/null +++ b/source/tests/pt_expt/utils/test_common.py @@ -0,0 +1,25 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib + +import numpy as np + +from deepmd.pt_expt.common import ( + to_torch_array, +) +from deepmd.pt_expt.utils import ( + env, +) + +torch = importlib.import_module("torch") + + +def test_to_torch_array_moves_device() -> None: + arr = np.arange(6, dtype=np.float32).reshape(2, 3) + tensor = to_torch_array(arr) + assert torch.is_tensor(tensor) + assert tensor.device == env.DEVICE + + input_tensor = torch.as_tensor(arr, device=torch.device("cpu")) + output_tensor = to_torch_array(input_tensor) + assert torch.is_tensor(output_tensor) + assert output_tensor.device == env.DEVICE From 67f2e544a6d228652ba9ea311a52a5c8defec210 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 11:54:41 +0800 Subject: [PATCH 10/35] add test_exclusion_mask.py --- .../pt_expt/utils/test_exclusion_mask.py | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 source/tests/pt_expt/utils/test_exclusion_mask.py diff --git a/source/tests/pt_expt/utils/test_exclusion_mask.py b/source/tests/pt_expt/utils/test_exclusion_mask.py new file mode 100644 index 0000000000..7168579052 --- /dev/null +++ b/source/tests/pt_expt/utils/test_exclusion_mask.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import unittest + +import numpy as np + +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.exclude_mask import ( + AtomExcludeMask, + PairExcludeMask, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) + +torch = importlib.import_module("torch") + + +class 
TestAtomExcludeMask(unittest.TestCase): + def test_build_type_exclude_mask(self) -> None: + nf = 2 + nt = 3 + exclude_types = [0, 2] + atype = np.array( + [ + [0, 2, 1, 2, 0, 1, 0], + [1, 2, 0, 0, 2, 2, 1], + ], + dtype=np.int32, + ).reshape([nf, -1]) + expected_mask = np.array( + [ + [0, 0, 1, 0, 0, 1, 0], + [1, 0, 0, 0, 0, 0, 1], + ] + ).reshape([nf, -1]) + des = AtomExcludeMask(nt, exclude_types=exclude_types) + mask = des.build_type_exclude_mask(torch.as_tensor(atype, device=env.DEVICE)) + np.testing.assert_equal(mask.detach().cpu().numpy(), expected_mask) + + +class TestPairExcludeMask(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + + def test_build_type_exclude_mask(self) -> None: + exclude_types = [[0, 1]] + expected_mask = np.array( + [ + [1, 1, 1, 1, 1, 0, 1], + [1, 1, 1, 1, 1, 0, 1], + [0, 0, 1, 1, 1, 1, 1], + [0, 0, 1, 1, 1, 1, 1], + [1, 1, 1, 1, 1, 0, 1], + [1, 1, 1, 1, 1, 0, 1], + ] + ).reshape(self.nf, self.nloc, sum(self.sel)) + des = PairExcludeMask(self.nt, exclude_types=exclude_types) + mask = des.build_type_exclude_mask( + torch.as_tensor(self.nlist, device=env.DEVICE), + torch.as_tensor(self.atype_ext, device=env.DEVICE), + ) + np.testing.assert_equal(mask.detach().cpu().numpy(), expected_mask) From f7d83ddfae60920d05b76f8a13fd4a35bb350a79 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 11:58:38 +0800 Subject: [PATCH 11/35] fix potential import issue in test. 
--- source/tests/pt_expt/conftest.py | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 source/tests/pt_expt/conftest.py diff --git a/source/tests/pt_expt/conftest.py b/source/tests/pt_expt/conftest.py new file mode 100644 index 0000000000..ec025c2202 --- /dev/null +++ b/source/tests/pt_expt/conftest.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import pytest + +pytest.importorskip("torch") From 0c96bb6fecf1564433d0f98d00d05866dcb9fbd5 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 12:11:12 +0800 Subject: [PATCH 12/35] correct __call__(). fix bug --- deepmd/pt_expt/descriptor/se_e2_a.py | 6 +++++- deepmd/pt_expt/utils/network.py | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 19a0d56734..7a4d4a71d9 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -28,6 +28,10 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) DescrptSeADP.__init__(self, *args, **kwargs) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. 
+ return torch.nn.Module.__call__(self, *args, **kwargs) + def __setattr__(self, name: str, value: Any) -> None: if name in {"davg", "dstd"} and "_buffers" in self.__dict__: tensor = ( @@ -54,9 +58,9 @@ def __setattr__(self, name: str, value: Any) -> None: def forward( self, - nlist: torch.Tensor, extended_coord: torch.Tensor, extended_atype: torch.Tensor, + nlist: torch.Tensor, extended_atype_embd: torch.Tensor | None = None, mapping: torch.Tensor | None = None, type_embedding: torch.Tensor | None = None, diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 7a85634dca..fffb98b1ef 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -43,6 +43,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) NativeLayerDP.__init__(self, *args, **kwargs) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return torch.nn.Module.__call__(self, *args, **kwargs) + def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: val = to_torch_array(value) @@ -74,6 +77,9 @@ def __init__(self, layers: list[dict] | None = None) -> None: super().__init__(layers) self.layers = torch.nn.ModuleList(self.layers) + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return torch.nn.Module.__call__(self, *args, **kwargs) + def forward(self, x: torch.Tensor) -> torch.Tensor: return self.call(x) From 9dca9128b50ecf7a92f902e6a6dca0b475ac1e9c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 12:54:28 +0800 Subject: [PATCH 13/35] fix registration issue --- deepmd/pt_expt/descriptor/se_e2_a.py | 4 +++- deepmd/pt_expt/utils/network.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 7a4d4a71d9..7df1148e38 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -40,7 +40,9 @@ def 
__setattr__(self, name: str, value: Any) -> None: if name in self._buffers: self._buffers[name] = tensor return - return super().__setattr__(name, tensor) + # Register on first assignment so buffers are in state_dict and moved by .to(). + self.register_buffer(name, tensor) + return if name == "embeddings" and "_modules" in self.__dict__: if value is not None and not isinstance(value, torch.nn.Module): if hasattr(value, "serialize"): diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index fffb98b1ef..5f66959d16 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -64,7 +64,9 @@ def __setattr__(self, name: str, value: Any) -> None: if name in self._buffers: self._buffers[name] = val return - return super().__setattr__(name, val) + # Register on first assignment so tensors are in state_dict and moved by .to(). + self.register_buffer(name, val) + return return super().__setattr__(name, value) def forward(self, x: torch.Tensor) -> torch.Tensor: From 17f0a5d1ae06eecfa25b4f2e3e2e09a4570fc0ef Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 13:00:54 +0800 Subject: [PATCH 14/35] fix pt-expt file extension --- deepmd/backend/pt_expt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/backend/pt_expt.py b/deepmd/backend/pt_expt.py index 38745c690c..e651332e2b 100644 --- a/deepmd/backend/pt_expt.py +++ b/deepmd/backend/pt_expt.py @@ -41,7 +41,7 @@ class PyTorchExportableBackend(Backend): | Backend.Feature.IO ) """The features of the backend.""" - suffixes: ClassVar[list[str]] = [".pth", ".pt"] + suffixes: ClassVar[list[str]] = [".pte"] """The suffixes of the backend.""" def is_available(self) -> bool: From 8ce93baafa42c0c51046152084765f44101809b9 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 13:03:21 +0800 Subject: [PATCH 15/35] fix(pt): expansion of get_default_nthreads() --- deepmd/pt/utils/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index 90d0d536c1..aa384b31b5 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -91,7 +91,7 @@ # throw warnings if threads not set set_default_nthreads() -inter_nthreads, intra_nthreads = get_default_nthreads() +intra_nthreads, inter_nthreads = get_default_nthreads() if inter_nthreads > 0: # the behavior of 0 is not documented torch.set_num_interop_threads(inter_nthreads) if intra_nthreads > 0: From 309198894d37ad0f0417fc3475919fb02e3fc63c Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 13:08:58 +0800 Subject: [PATCH 16/35] fix bug of intra-inter --- deepmd/pt_expt/utils/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt_expt/utils/env.py b/deepmd/pt_expt/utils/env.py index bd644e7206..b5042f6f2a 100644 --- a/deepmd/pt_expt/utils/env.py +++ b/deepmd/pt_expt/utils/env.py @@ -92,7 +92,7 @@ # throw warnings if threads not set set_default_nthreads() -inter_nthreads, intra_nthreads = get_default_nthreads() +intra_nthreads, inter_nthreads = get_default_nthreads() if inter_nthreads > 0: # the behavior of 0 is not documented torch.set_num_interop_threads(inter_nthreads) if intra_nthreads > 0: From 85f05833353a8bb388ce0fd16cdf0059d579b5e4 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 13:13:42 +0800 Subject: [PATCH 17/35] fix bug of default dp inter value --- deepmd/env.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/env.py b/deepmd/env.py index 7b29a338f1..c9d0fb241f 100644 --- a/deepmd/env.py +++ b/deepmd/env.py @@ -138,7 +138,7 @@ def get_default_nthreads() -> tuple[int, int]: ), int( os.environ.get( "DP_INTER_OP_PARALLELISM_THREADS", - os.environ.get("TF_INTRA_OP_PARALLELISM_THREADS", "0"), + os.environ.get("TF_INTER_OP_PARALLELISM_THREADS", "0"), ) ) From d33324de397d728d15fa773bdb828783085e4156 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 15:15:07 +0800 Subject: [PATCH 18/35] fix cicd --- 
source/tests/consistent/descriptor/common.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 7c8cbce744..50efe32a08 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -153,13 +153,17 @@ def eval_pt_expt_descriptor( box: np.ndarray, mixed_types: bool = False, ) -> Any: - ext_coords, ext_atype, mapping = extend_coord_with_ghosts( + # Use the torch-native neighbor list utilities to avoid array_api_compat + # allocations on CUDA. The array_api path can hit torch empty/ones/eye/etc + # on CUDA, which all rely on aten::empty_strided and fail in CI builds + # where that CUDA kernel is not available. + ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), pt_expt_obj.get_rcut(), ) - nlist = build_neighbor_list( + nlist = build_neighbor_list_pt( ext_coords, ext_atype, natoms[0], From 4de9a565c01c6193a94eaecbc18475c4be04dc08 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 15:57:27 +0800 Subject: [PATCH 19/35] feat: add support for se_r --- deepmd/dpmodel/descriptor/se_r.py | 4 +- deepmd/pt_expt/descriptor/__init__.py | 4 + deepmd/pt_expt/descriptor/se_r.py | 83 +++++++++++ .../tests/consistent/descriptor/test_se_r.py | 25 ++++ source/tests/pt_expt/descriptor/test_se_r.py | 132 ++++++++++++++++++ 5 files changed, 247 insertions(+), 1 deletion(-) create mode 100644 deepmd/pt_expt/descriptor/se_r.py create mode 100644 source/tests/pt_expt/descriptor/test_se_r.py diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index 6decd91a23..b38d561e95 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -391,7 +391,9 @@ def call( ng = self.neuron[-1] 
xyz_scatter = xp.zeros( - [nf, nloc, ng], dtype=get_xp_precision(xp, self.precision) + [nf, nloc, ng], + dtype=get_xp_precision(xp, self.precision), + device=array_api_compat.device(coord_ext), ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) rr = xp.astype(rr, xyz_scatter.dtype) diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py index 089e5619e0..4d9469a93a 100644 --- a/deepmd/pt_expt/descriptor/__init__.py +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -5,8 +5,12 @@ from .se_e2_a import ( DescrptSeA, ) +from .se_r import ( + DescrptSeR, +) __all__ = [ "BaseDescriptor", "DescrptSeA", + "DescrptSeR", ] diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py new file mode 100644 index 0000000000..f4969ce927 --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -0,0 +1,83 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.exclude_mask import ( + PairExcludeMask, +) +from deepmd.pt_expt.utils.network import ( + NetworkCollection, +) + +torch = importlib.import_module("torch") + + +@BaseDescriptor.register("se_e2_r_expt") +@BaseDescriptor.register("se_r_expt") +class DescrptSeR(DescrptSeRDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptSeRDP.__init__(self, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. 
+ return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + if name in {"davg", "dstd"} and "_buffers" in self.__dict__: + tensor = ( + None if value is None else torch.as_tensor(value, device=env.DEVICE) + ) + if name in self._buffers: + self._buffers[name] = tensor + return + # Register on first assignment so buffers are in state_dict and moved by .to(). + self.register_buffer(name, tensor) + return + if name == "embeddings" and "_modules" in self.__dict__: + if value is not None and not isinstance(value, torch.nn.Module): + if hasattr(value, "serialize"): + value = NetworkCollection.deserialize(value.serialize()) + elif isinstance(value, dict): + value = NetworkCollection.deserialize(value) + return super().__setattr__(name, value) + if name == "emask" and "_modules" in self.__dict__: + if value is not None and not isinstance(value, torch.nn.Module): + value = PairExcludeMask( + self.ntypes, exclude_types=list(value.get_exclude_types()) + ) + return super().__setattr__(name, value) + return super().__setattr__(name, value) + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + extended_atype_embd: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + type_embedding: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + del extended_atype_embd, type_embedding + descrpt, rot_mat, g2, h2, sw = self.call( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + return descrpt, rot_mat, g2, h2, sw diff --git a/source/tests/consistent/descriptor/test_se_r.py b/source/tests/consistent/descriptor/test_se_r.py index 3420c5592f..9aafea1578 100644 --- a/source/tests/consistent/descriptor/test_se_r.py +++ b/source/tests/consistent/descriptor/test_se_r.py @@ -15,6 +15,7 @@ INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, INSTALLED_PT, + 
INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -27,6 +28,10 @@ from deepmd.pt.model.descriptor.se_r import DescrptSeR as DescrptSeRPT else: DescrptSeAPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.se_r import DescrptSeR as DescrptSeRPTExpt +else: + DescrptSeRPTExpt = None if INSTALLED_TF: from deepmd.tf.descriptor.se_r import DescrptSeR as DescrptSeRTF else: @@ -84,6 +89,16 @@ def skip_pt(self) -> bool: ) = self.param return not type_one_side or CommonTest.skip_pt + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + ) = self.param + return not type_one_side or CommonTest.skip_pt_expt + @property def skip_dp(self) -> bool: ( @@ -117,6 +132,7 @@ def skip_array_api_strict(self) -> bool: tf_class = DescrptSeRTF dp_class = DescrptSeRDP pt_class = DescrptSeRPT + pt_expt_class = DescrptSeRPTExpt jax_class = DescrptSeRJAX array_api_strict_class = DescrptSeRArrayAPIStrict args = descrpt_se_r_args() @@ -183,6 +199,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, diff --git a/source/tests/pt_expt/descriptor/test_se_r.py b/source/tests/pt_expt/descriptor/test_se_r.py new file mode 100644 index 0000000000..6e7339801c --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_se_r.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import itertools +import unittest + +import numpy as np + +from deepmd.dpmodel.descriptor import DescrptSeR as DPDescrptSeR +from deepmd.pt_expt.descriptor.se_r import ( + DescrptSeR, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from 
...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + +torch = importlib.import_module("torch") + + +class TestDescrptSeR(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 1)) + dstd = rng.normal(size=(self.nt, nnei, 1)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec, em in itertools.product( + [False, True], + ["float64", "float32"], + [[], [[0, 1]], [[1, 1]]], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + exclude_types=em, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + dd1 = DescrptSeR.deserialize(dd0.serialize()) + rd1, _, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeR.deserialize(dd0.serialize()) + rd2, _, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + for aa, bb in zip([rd1, 
sw1], [rd2, sw2], strict=True): + np.testing.assert_allclose( + aa.detach().cpu().numpy(), + bb, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_exportable(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 1)) + dstd = rng.normal(size=(self.nt, nnei, 1)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) From f4dc0afec4909dd4c052e9dcd565b4be01b6ed92 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 17:50:35 +0800 Subject: [PATCH 20/35] fix device of xp array --- deepmd/dpmodel/descriptor/se_r.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/descriptor/se_r.py b/deepmd/dpmodel/descriptor/se_r.py index b38d561e95..4fdf50beba 100644 --- a/deepmd/dpmodel/descriptor/se_r.py +++ b/deepmd/dpmodel/descriptor/se_r.py @@ -309,9 +309,12 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.dstd) + device = array_api_compat.device(self.dstd) if not self.set_davg_zero: - self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True) - self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True) + self.davg = xp.asarray( + mean, dtype=self.davg.dtype, copy=True, device=device + ) + self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, 
copy=True, device=device) def set_stat_mean_and_stddev( self, From 238483531cd8455219ad5d21065f516f0486c97b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 17:51:50 +0800 Subject: [PATCH 21/35] fix device of xp array --- deepmd/dpmodel/descriptor/dpa1.py | 9 +++++++-- deepmd/dpmodel/descriptor/repflows.py | 9 +++++++-- deepmd/dpmodel/descriptor/repformers.py | 9 +++++++-- deepmd/dpmodel/descriptor/se_e2_a.py | 7 +++++-- deepmd/dpmodel/descriptor/se_t.py | 7 +++++-- deepmd/dpmodel/descriptor/se_t_tebd.py | 9 +++++++-- 6 files changed, 38 insertions(+), 12 deletions(-) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index 5228ba55b2..f09ab24dfe 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -909,9 +909,14 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.stddev) + device = array_api_compat.device(self.stddev) if not self.set_davg_zero: - self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True) - self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True) + self.mean = xp.asarray( + mean, dtype=self.mean.dtype, copy=True, device=device + ) + self.stddev = xp.asarray( + stddev, dtype=self.stddev.dtype, copy=True, device=device + ) def get_stats(self) -> dict[str, StatItem]: """Get the statistics of the descriptor.""" diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 706fc690e4..7ba4f92662 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -453,9 +453,14 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.stddev) + device = array_api_compat.device(self.stddev) if not self.set_davg_zero: - self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True) - self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, 
copy=True) + self.mean = xp.asarray( + mean, dtype=self.mean.dtype, copy=True, device=device + ) + self.stddev = xp.asarray( + stddev, dtype=self.stddev.dtype, copy=True, device=device + ) def get_stats(self) -> dict[str, StatItem]: """Get the statistics of the descriptor.""" diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index 79d4f9228f..06f5c1c943 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -417,9 +417,14 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.stddev) + device = array_api_compat.device(self.stddev) if not self.set_davg_zero: - self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True) - self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True) + self.mean = xp.asarray( + mean, dtype=self.mean.dtype, copy=True, device=device + ) + self.stddev = xp.asarray( + stddev, dtype=self.stddev.dtype, copy=True, device=device + ) def get_stats(self) -> dict[str, StatItem]: """Get the statistics of the descriptor.""" diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 3ca28ba556..77afb110e9 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -350,9 +350,12 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.dstd) + device = array_api_compat.device(self.dstd) if not self.set_davg_zero: - self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True) - self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True) + self.davg = xp.asarray( + mean, dtype=self.davg.dtype, copy=True, device=device + ) + self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device) def set_stat_mean_and_stddev( self, diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py 
index 863187dd4c..749a5da188 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -290,9 +290,12 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.dstd) + device = array_api_compat.device(self.dstd) if not self.set_davg_zero: - self.davg = xp.asarray(mean, dtype=self.davg.dtype, copy=True) - self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True) + self.davg = xp.asarray( + mean, dtype=self.davg.dtype, copy=True, device=device + ) + self.dstd = xp.asarray(stddev, dtype=self.dstd.dtype, copy=True, device=device) def set_stat_mean_and_stddev( self, diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index e118d5abd4..0a2d46c015 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -694,9 +694,14 @@ def compute_input_stats( self.stats = env_mat_stat.stats mean, stddev = env_mat_stat() xp = array_api_compat.array_namespace(self.stddev) + device = array_api_compat.device(self.stddev) if not self.set_davg_zero: - self.mean = xp.asarray(mean, dtype=self.mean.dtype, copy=True) - self.stddev = xp.asarray(stddev, dtype=self.stddev.dtype, copy=True) + self.mean = xp.asarray( + mean, dtype=self.mean.dtype, copy=True, device=device + ) + self.stddev = xp.asarray( + stddev, dtype=self.stddev.dtype, copy=True, device=device + ) def get_stats(self) -> dict[str, StatItem]: """Get the statistics of the descriptor.""" From 9646d71b6d05d7c30fb70ec7e0be708e122bdbde Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 17:52:16 +0800 Subject: [PATCH 22/35] revert extend_coord_with_ghosts --- source/tests/consistent/descriptor/common.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 50efe32a08..7c8cbce744 100644 --- 
a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -153,17 +153,13 @@ def eval_pt_expt_descriptor( box: np.ndarray, mixed_types: bool = False, ) -> Any: - # Use the torch-native neighbor list utilities to avoid array_api_compat - # allocations on CUDA. The array_api path can hit torch empty/ones/eye/etc - # on CUDA, which all rely on aten::empty_strided and fail in CI builds - # where that CUDA kernel is not available. - ext_coords, ext_atype, mapping = extend_coord_with_ghosts_pt( + ext_coords, ext_atype, mapping = extend_coord_with_ghosts( torch.from_numpy(coords).to(PT_DEVICE).reshape(1, -1, 3), torch.from_numpy(atype).to(PT_DEVICE).reshape(1, -1), torch.from_numpy(box).to(PT_DEVICE).reshape(1, 3, 3), pt_expt_obj.get_rcut(), ) - nlist = build_neighbor_list_pt( + nlist = build_neighbor_list( ext_coords, ext_atype, natoms[0], From f270069dd07c10bb9d908bd203de0e6e7ba72412 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 18:11:20 +0800 Subject: [PATCH 23/35] raise error for non-implemented methods --- deepmd/backend/pt_expt.py | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/deepmd/backend/pt_expt.py b/deepmd/backend/pt_expt.py index e651332e2b..ade9eb51f3 100644 --- a/deepmd/backend/pt_expt.py +++ b/deepmd/backend/pt_expt.py @@ -76,9 +76,7 @@ def deep_eval(self) -> type["DeepEvalBackend"]: type[DeepEvalBackend] The Deep Eval backend of the backend. """ - from deepmd.pt.infer.deep_eval import DeepEval as DeepEvalPT - - return DeepEvalPT + raise NotImplementedError @property def neighbor_stat(self) -> type["NeighborStat"]: @@ -89,11 +87,7 @@ def neighbor_stat(self) -> type["NeighborStat"]: type[NeighborStat] The neighbor statistics of the backend. 
""" - from deepmd.pt.utils.neighbor_stat import ( - NeighborStat, - ) - - return NeighborStat + raise NotImplementedError @property def serialize_hook(self) -> Callable[[str], dict]: @@ -104,11 +98,7 @@ def serialize_hook(self) -> Callable[[str], dict]: Callable[[str], dict] The serialize hook of the backend. """ - from deepmd.pt.utils.serialization import ( - serialize_from_file, - ) - - return serialize_from_file + raise NotImplementedError @property def deserialize_hook(self) -> Callable[[str, dict], None]: @@ -119,8 +109,4 @@ def deserialize_hook(self) -> Callable[[str, dict], None]: Callable[[str, dict], None] The deserialize hook of the backend. """ - from deepmd.pt.utils.serialization import ( - deserialize_to_file, - ) - - return deserialize_to_file + raise NotImplementedError From 57433d3e1e82b00006a9062c4a3610bbf3b52c45 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 22:17:46 +0800 Subject: [PATCH 24/35] restore import torch --- deepmd/pt_expt/common.py | 4 +--- deepmd/pt_expt/descriptor/base_descriptor.py | 5 ++--- deepmd/pt_expt/descriptor/se_e2_a.py | 5 ++--- deepmd/pt_expt/descriptor/se_r.py | 5 ++--- deepmd/pt_expt/utils/env.py | 3 +-- deepmd/pt_expt/utils/exclude_mask.py | 5 ++--- deepmd/pt_expt/utils/network.py | 4 +--- pyproject.toml | 3 +++ source/tests/pt_expt/descriptor/test_se_e2_a.py | 4 +--- source/tests/pt_expt/descriptor/test_se_r.py | 4 +--- source/tests/pt_expt/utils/test_common.py | 4 +--- source/tests/pt_expt/utils/test_exclusion_mask.py | 4 +--- 12 files changed, 18 insertions(+), 32 deletions(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index b66c0ff66d..db8b94989b 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -1,18 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, overload, ) import numpy as np +import torch from deepmd.pt_expt.utils import ( env, ) -torch = importlib.import_module("torch") - @overload def 
to_torch_array(array: np.ndarray) -> torch.Tensor: ... diff --git a/deepmd/pt_expt/descriptor/base_descriptor.py b/deepmd/pt_expt/descriptor/base_descriptor.py index 51e9325bba..986435205a 100644 --- a/deepmd/pt_expt/descriptor/base_descriptor.py +++ b/deepmd/pt_expt/descriptor/base_descriptor.py @@ -1,10 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib + +import torch from deepmd.dpmodel.descriptor import ( make_base_descriptor, ) -torch = importlib.import_module("torch") - BaseDescriptor = make_base_descriptor(torch.Tensor, "forward") diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 7df1148e38..21c0a4eeb7 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, ) +import torch + from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, @@ -18,8 +19,6 @@ NetworkCollection, ) -torch = importlib.import_module("torch") - @BaseDescriptor.register("se_e2_a_expt") @BaseDescriptor.register("se_a_expt") diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index f4969ce927..508785949c 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -1,9 +1,10 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, ) +import torch + from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, @@ -18,8 +19,6 @@ NetworkCollection, ) -torch = importlib.import_module("torch") - @BaseDescriptor.register("se_e2_r_expt") @BaseDescriptor.register("se_r_expt") diff --git a/deepmd/pt_expt/utils/env.py b/deepmd/pt_expt/utils/env.py index b5042f6f2a..ce13e4ef42 100644 --- a/deepmd/pt_expt/utils/env.py +++ 
b/deepmd/pt_expt/utils/env.py @@ -1,5 +1,4 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib import logging import multiprocessing import os @@ -18,7 +17,7 @@ ) log = logging.getLogger(__name__) -torch = importlib.import_module("torch") +import torch if sys.platform != "win32": try: diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py index ed296c9f98..15fbbc8e34 100644 --- a/deepmd/pt_expt/utils/exclude_mask.py +++ b/deepmd/pt_expt/utils/exclude_mask.py @@ -1,17 +1,16 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, ) +import torch + from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP from deepmd.pt_expt.utils import ( env, ) -torch = importlib.import_module("torch") - class AtomExcludeMask(AtomExcludeMaskDP): def __setattr__(self, name: str, value: Any) -> None: diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 5f66959d16..3effcfc488 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib from typing import ( Any, ClassVar, ) import numpy as np +import torch from deepmd.dpmodel.common import ( NativeOP, @@ -22,8 +22,6 @@ to_torch_array, ) -torch = importlib.import_module("torch") - class TorchArrayParam(torch.nn.Parameter): def __new__( # noqa: PYI034 diff --git a/pyproject.toml b/pyproject.toml index bd403dfaf2..15eb0b2ae5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -411,6 +411,7 @@ convention = "numpy" banned-module-level-imports = [ "deepmd.tf", "deepmd.pt", + "deepmd.pt_expt", "deepmd.pd", "deepmd.jax", "tensorflow", @@ -432,12 +433,14 @@ runtime-evaluated-base-classes = ["torch.nn.Module"] "data/**" = ["ANN"] "deepmd/tf/**" = ["TID253", "ANN"] "deepmd/pt/**" = ["TID253", "B905"] +"deepmd/pt_expt/**" = 
["TID253", "B905"] "deepmd/jax/**" = ["TID253"] "deepmd/pd/**" = ["TID253", "B905"] "source/**" = ["ANN"] "source/tests/tf/**" = ["TID253", "ANN"] "source/tests/pt/**" = ["TID253", "ANN"] +"source/tests/pt_expt/**" = ["TID253", "ANN"] "source/tests/jax/**" = ["TID253", "ANN"] "source/tests/pd/**" = ["TID253", "ANN"] "source/tests/universal/pt/**" = ["TID253", "ANN"] diff --git a/source/tests/pt_expt/descriptor/test_se_e2_a.py b/source/tests/pt_expt/descriptor/test_se_e2_a.py index 57923b97a3..e63138e43b 100644 --- a/source/tests/pt_expt/descriptor/test_se_e2_a.py +++ b/source/tests/pt_expt/descriptor/test_se_e2_a.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib import itertools import unittest import numpy as np +import torch from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA from deepmd.pt_expt.descriptor.se_e2_a import ( @@ -29,8 +29,6 @@ GLOBAL_SEED, ) -torch = importlib.import_module("torch") - class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self) -> None: diff --git a/source/tests/pt_expt/descriptor/test_se_r.py b/source/tests/pt_expt/descriptor/test_se_r.py index 6e7339801c..c789b13652 100644 --- a/source/tests/pt_expt/descriptor/test_se_r.py +++ b/source/tests/pt_expt/descriptor/test_se_r.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib import itertools import unittest import numpy as np +import torch from deepmd.dpmodel.descriptor import DescrptSeR as DPDescrptSeR from deepmd.pt_expt.descriptor.se_r import ( @@ -26,8 +26,6 @@ GLOBAL_SEED, ) -torch = importlib.import_module("torch") - class TestDescrptSeR(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self) -> None: diff --git a/source/tests/pt_expt/utils/test_common.py b/source/tests/pt_expt/utils/test_common.py index 63c4983f23..ee8a7ca324 100644 --- a/source/tests/pt_expt/utils/test_common.py +++ b/source/tests/pt_expt/utils/test_common.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: 
LGPL-3.0-or-later -import importlib import numpy as np +import torch from deepmd.pt_expt.common import ( to_torch_array, @@ -10,8 +10,6 @@ env, ) -torch = importlib.import_module("torch") - def test_to_torch_array_moves_device() -> None: arr = np.arange(6, dtype=np.float32).reshape(2, 3) diff --git a/source/tests/pt_expt/utils/test_exclusion_mask.py b/source/tests/pt_expt/utils/test_exclusion_mask.py index 7168579052..b3707ef69d 100644 --- a/source/tests/pt_expt/utils/test_exclusion_mask.py +++ b/source/tests/pt_expt/utils/test_exclusion_mask.py @@ -1,8 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import importlib import unittest import numpy as np +import torch from deepmd.pt_expt.utils import ( env, @@ -16,8 +16,6 @@ TestCaseSingleFrameWithNlist, ) -torch = importlib.import_module("torch") - class TestAtomExcludeMask(unittest.TestCase): def test_build_type_exclude_mask(self) -> None: From eedcbaf4f67ff6a9dface303450c4110b6c139b2 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 22:31:03 +0800 Subject: [PATCH 25/35] fix(pt,pt-expt): guard thread setters --- deepmd/pt/utils/env.py | 15 ++++++++++++-- deepmd/pt_expt/utils/env.py | 15 ++++++++++++-- source/tests/pt/test_env_threads.py | 28 ++++++++++++++++++++++++++ source/tests/pt_expt/utils/test_env.py | 28 ++++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 4 deletions(-) create mode 100644 source/tests/pt/test_env_threads.py create mode 100644 source/tests/pt_expt/utils/test_env.py diff --git a/deepmd/pt/utils/env.py b/deepmd/pt/utils/env.py index aa384b31b5..9f453c895c 100644 --- a/deepmd/pt/utils/env.py +++ b/deepmd/pt/utils/env.py @@ -93,9 +93,20 @@ set_default_nthreads() intra_nthreads, inter_nthreads = get_default_nthreads() if inter_nthreads > 0: # the behavior of 0 is not documented - torch.set_num_interop_threads(inter_nthreads) + # torch.set_num_interop_threads can only be called once per process. + # Guard to avoid RuntimeError when multiple backends are imported. 
+ try: + if torch.get_num_interop_threads() != inter_nthreads: + torch.set_num_interop_threads(inter_nthreads) + except RuntimeError as err: + log.warning(f"Could not set torch interop threads: {err}") if intra_nthreads > 0: - torch.set_num_threads(intra_nthreads) + # torch.set_num_threads can also fail if called after threads are created. + try: + if torch.get_num_threads() != intra_nthreads: + torch.set_num_threads(intra_nthreads) + except RuntimeError as err: + log.warning(f"Could not set torch intra threads: {err}") __all__ = [ "CACHE_PER_SYS", diff --git a/deepmd/pt_expt/utils/env.py b/deepmd/pt_expt/utils/env.py index ce13e4ef42..56cec25d49 100644 --- a/deepmd/pt_expt/utils/env.py +++ b/deepmd/pt_expt/utils/env.py @@ -93,9 +93,20 @@ set_default_nthreads() intra_nthreads, inter_nthreads = get_default_nthreads() if inter_nthreads > 0: # the behavior of 0 is not documented - torch.set_num_interop_threads(inter_nthreads) + # torch.set_num_interop_threads can only be called once per process. + # Guard to avoid RuntimeError when both pt and pt_expt env modules are imported. + try: + if torch.get_num_interop_threads() != inter_nthreads: + torch.set_num_interop_threads(inter_nthreads) + except RuntimeError as err: + log.warning(f"Could not set torch interop threads: {err}") if intra_nthreads > 0: - torch.set_num_threads(intra_nthreads) + # torch.set_num_threads can also fail if called after threads are created. 
+ try: + if torch.get_num_threads() != intra_nthreads: + torch.set_num_threads(intra_nthreads) + except RuntimeError as err: + log.warning(f"Could not set torch intra threads: {err}") __all__ = [ "CACHE_PER_SYS", diff --git a/source/tests/pt/test_env_threads.py b/source/tests/pt/test_env_threads.py new file mode 100644 index 0000000000..eb6604ceb8 --- /dev/null +++ b/source/tests/pt/test_env_threads.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import logging + +import torch + +import deepmd.env as common_env + + +def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None: + def raise_err(*_args, **_kwargs) -> None: + raise RuntimeError("boom") + + monkeypatch.setattr(common_env, "set_default_nthreads", lambda: None) + monkeypatch.setattr(common_env, "get_default_nthreads", lambda: (1, 1)) + monkeypatch.setattr(torch, "get_num_interop_threads", lambda: 2) + monkeypatch.setattr(torch, "set_num_interop_threads", raise_err) + monkeypatch.setattr(torch, "get_num_threads", lambda: 2) + monkeypatch.setattr(torch, "set_num_threads", raise_err) + + caplog.set_level(logging.WARNING, logger="deepmd.pt.utils.env") + import deepmd.pt.utils.env as env + + importlib.reload(env) + + messages = [record.getMessage() for record in caplog.records] + assert any("Could not set torch interop threads" in msg for msg in messages) + assert any("Could not set torch intra threads" in msg for msg in messages) diff --git a/source/tests/pt_expt/utils/test_env.py b/source/tests/pt_expt/utils/test_env.py new file mode 100644 index 0000000000..bbdc696aea --- /dev/null +++ b/source/tests/pt_expt/utils/test_env.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import importlib +import logging + +import torch + +import deepmd.env as common_env + + +def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None: + def raise_err(*_args, **_kwargs) -> None: + raise RuntimeError("boom") + + 
monkeypatch.setattr(common_env, "set_default_nthreads", lambda: None) + monkeypatch.setattr(common_env, "get_default_nthreads", lambda: (1, 1)) + monkeypatch.setattr(torch, "get_num_interop_threads", lambda: 2) + monkeypatch.setattr(torch, "set_num_interop_threads", raise_err) + monkeypatch.setattr(torch, "get_num_threads", lambda: 2) + monkeypatch.setattr(torch, "set_num_threads", raise_err) + + caplog.set_level(logging.WARNING, logger="deepmd.pt_expt.utils.env") + import deepmd.pt_expt.utils.env as env + + importlib.reload(env) + + messages = [record.getMessage() for record in caplog.records] + assert any("Could not set torch interop threads" in msg for msg in messages) + assert any("Could not set torch intra threads" in msg for msg in messages) From d8b2cf43faa618a15e1e590688bd002bf75abe28 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 22:54:34 +0800 Subject: [PATCH 26/35] make exclusion mask modules --- deepmd/pt_expt/utils/exclude_mask.py | 26 ++++++++++++++++--- .../pt_expt/utils/test_exclusion_mask.py | 8 ++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py index 15fbbc8e34..e757283e1c 100644 --- a/deepmd/pt_expt/utils/exclude_mask.py +++ b/deepmd/pt_expt/utils/exclude_mask.py @@ -12,15 +12,33 @@ ) -class AtomExcludeMask(AtomExcludeMaskDP): +class AtomExcludeMask(AtomExcludeMaskDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + AtomExcludeMaskDP.__init__(self, *args, **kwargs) + def __setattr__(self, name: str, value: Any) -> None: - if name == "type_mask": + if name == "type_mask" and "_buffers" in self.__dict__: value = None if value is None else torch.as_tensor(value, device=env.DEVICE) + if name in self._buffers: + self._buffers[name] = value + return + self.register_buffer(name, value) + return return super().__setattr__(name, value) -class PairExcludeMask(PairExcludeMaskDP): +class 
PairExcludeMask(PairExcludeMaskDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + PairExcludeMaskDP.__init__(self, *args, **kwargs) + def __setattr__(self, name: str, value: Any) -> None: - if name == "type_mask": + if name == "type_mask" and "_buffers" in self.__dict__: value = None if value is None else torch.as_tensor(value, device=env.DEVICE) + if name in self._buffers: + self._buffers[name] = value + return + self.register_buffer(name, value) + return return super().__setattr__(name, value) diff --git a/source/tests/pt_expt/utils/test_exclusion_mask.py b/source/tests/pt_expt/utils/test_exclusion_mask.py index b3707ef69d..6f836913af 100644 --- a/source/tests/pt_expt/utils/test_exclusion_mask.py +++ b/source/tests/pt_expt/utils/test_exclusion_mask.py @@ -39,6 +39,10 @@ def test_build_type_exclude_mask(self) -> None: mask = des.build_type_exclude_mask(torch.as_tensor(atype, device=env.DEVICE)) np.testing.assert_equal(mask.detach().cpu().numpy(), expected_mask) + def test_type_mask_is_buffer(self) -> None: + des = AtomExcludeMask(3, exclude_types=[0]) + assert "type_mask" in des.state_dict() + class TestPairExcludeMask(unittest.TestCase, TestCaseSingleFrameWithNlist): def setUp(self) -> None: @@ -62,3 +66,7 @@ def test_build_type_exclude_mask(self) -> None: torch.as_tensor(self.atype_ext, device=env.DEVICE), ) np.testing.assert_equal(mask.detach().cpu().numpy(), expected_mask) + + def test_type_mask_is_buffer(self) -> None: + des = PairExcludeMask(self.nt, exclude_types=[[0, 1]]) + assert "type_mask" in des.state_dict() From aeef15a99d6e9ccf8f9cbb93a149451adf7b615e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Fri, 6 Feb 2026 23:46:24 +0800 Subject: [PATCH 27/35] fix(pt-expt): clear params on None --- deepmd/pt_expt/utils/network.py | 6 ++++++ source/tests/pt_expt/utils/test_network.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 
source/tests/pt_expt/utils/test_network.py diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 3effcfc488..721a511f5f 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -48,6 +48,12 @@ def __setattr__(self, name: str, value: Any) -> None: if name in {"w", "b", "idt"} and "_parameters" in self.__dict__: val = to_torch_array(value) if val is None: + if name in self._parameters: + self._parameters[name] = None + return + if name in self._buffers: + self._buffers[name] = None + return return super().__setattr__(name, None) if getattr(self, "trainable", False): param = ( diff --git a/source/tests/pt_expt/utils/test_network.py b/source/tests/pt_expt/utils/test_network.py new file mode 100644 index 0000000000..ad7c2a7e3d --- /dev/null +++ b/source/tests/pt_expt/utils/test_network.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +from deepmd.pt_expt.utils.network import ( + NativeLayer, +) + + +def test_native_layer_clears_parameter_on_none() -> None: + layer = NativeLayer(2, 3, trainable=True) + assert layer.w is not None + layer.w = None + assert layer.w is None + assert layer._parameters.get("w") is None + + +def test_native_layer_clears_buffer_on_none() -> None: + layer = NativeLayer(2, 3, trainable=False) + assert layer.w is not None + layer.w = None + assert layer.w is None + assert layer._buffers.get("w") is None From 8bdb1f89eb509efc8d6133812ec5e9a2c678ae7b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sat, 7 Feb 2026 18:51:50 +0800 Subject: [PATCH 28/35] fix bug --- source/tests/pt/test_env_threads.py | 12 +++++++++--- source/tests/pt_expt/utils/test_env.py | 12 +++++++++--- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/source/tests/pt/test_env_threads.py b/source/tests/pt/test_env_threads.py index eb6604ceb8..50de1996d8 100644 --- a/source/tests/pt/test_env_threads.py +++ b/source/tests/pt/test_env_threads.py @@ -7,7 +7,7 @@ import deepmd.env as common_env -def 
test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None: +def test_env_threads_guard_handles_runtimeerror(monkeypatch) -> None: def raise_err(*_args, **_kwargs) -> None: raise RuntimeError("boom") @@ -18,11 +18,17 @@ def raise_err(*_args, **_kwargs) -> None: monkeypatch.setattr(torch, "get_num_threads", lambda: 2) monkeypatch.setattr(torch, "set_num_threads", raise_err) - caplog.set_level(logging.WARNING, logger="deepmd.pt.utils.env") + messages: list[str] = [] + original_warning = logging.Logger.warning + + def capture_warning(self, msg, *args, **kwargs): # type: ignore[no-untyped-def] + messages.append(str(msg)) + return original_warning(self, msg, *args, **kwargs) + + monkeypatch.setattr(logging.Logger, "warning", capture_warning) import deepmd.pt.utils.env as env importlib.reload(env) - messages = [record.getMessage() for record in caplog.records] assert any("Could not set torch interop threads" in msg for msg in messages) assert any("Could not set torch intra threads" in msg for msg in messages) diff --git a/source/tests/pt_expt/utils/test_env.py b/source/tests/pt_expt/utils/test_env.py index bbdc696aea..a589c80ae1 100644 --- a/source/tests/pt_expt/utils/test_env.py +++ b/source/tests/pt_expt/utils/test_env.py @@ -7,7 +7,7 @@ import deepmd.env as common_env -def test_env_threads_guard_handles_runtimeerror(monkeypatch, caplog) -> None: +def test_env_threads_guard_handles_runtimeerror(monkeypatch) -> None: def raise_err(*_args, **_kwargs) -> None: raise RuntimeError("boom") @@ -18,11 +18,17 @@ def raise_err(*_args, **_kwargs) -> None: monkeypatch.setattr(torch, "get_num_threads", lambda: 2) monkeypatch.setattr(torch, "set_num_threads", raise_err) - caplog.set_level(logging.WARNING, logger="deepmd.pt_expt.utils.env") + messages: list[str] = [] + original_warning = logging.Logger.warning + + def capture_warning(self, msg, *args, **kwargs): # type: ignore[no-untyped-def] + messages.append(str(msg)) + return original_warning(self, msg, *args, 
**kwargs) + + monkeypatch.setattr(logging.Logger, "warning", capture_warning) import deepmd.pt_expt.utils.env as env importlib.reload(env) - messages = [record.getMessage() for record in caplog.records] assert any("Could not set torch interop threads" in msg for msg in messages) assert any("Could not set torch intra threads" in msg for msg in messages) From d3b01da5075898e807aa08a34e2fcb49d3b47749 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 16:36:56 +0800 Subject: [PATCH 29/35] utility to handel dpmodel -> pt_expt conversion --- deepmd/pt_expt/common.py | 253 ++++++++++++++++++++++++++- deepmd/pt_expt/descriptor/se_e2_a.py | 39 +---- deepmd/pt_expt/descriptor/se_r.py | 39 +---- deepmd/pt_expt/utils/__init__.py | 4 + deepmd/pt_expt/utils/exclude_mask.py | 39 +++-- deepmd/pt_expt/utils/network.py | 7 + 6 files changed, 293 insertions(+), 88 deletions(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index db8b94989b..e687fa8e48 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -1,4 +1,22 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +"""Common utilities for the pt_expt backend. + +This module provides the core infrastructure for automatically wrapping dpmodel +classes (array_api_compat-based) as PyTorch modules. The key insight is to +detect attributes by their **value type** rather than by hard-coded names: + +- numpy arrays → torch buffers (persistent state like statistics, masks) +- dpmodel objects → pt_expt torch.nn.Module wrappers (via registry lookup) +- None values → clear existing buffers + +This eliminates the need to manually enumerate attribute names in each wrapper's +__setattr__ method, making the codebase more maintainable when dpmodel adds +new attributes. 
+""" + +from collections.abc import ( + Callable, +) from typing import ( Any, overload, @@ -7,11 +25,203 @@ import numpy as np import torch -from deepmd.pt_expt.utils import ( - env, -) +# --------------------------------------------------------------------------- +# dpmodel → pt_expt converter registry +# --------------------------------------------------------------------------- +_DPMODEL_TO_PT_EXPT: dict[type, Callable[[Any], torch.nn.Module]] = {} +"""Registry mapping dpmodel classes to their pt_expt converter functions. + +This registry is populated at module import time via `register_dpmodel_mapping` +calls in each pt_expt wrapper module (e.g., exclude_mask.py, network.py). When +dpmodel_setattr encounters a dpmodel object, it looks up the object's type in +this registry to find the appropriate converter. + +Examples of registered mappings: +- AtomExcludeMaskDP → lambda v: AtomExcludeMask(v.ntypes, exclude_types=...) +- NetworkCollectionDP → lambda v: NetworkCollection.deserialize(v.serialize()) +""" + + +def register_dpmodel_mapping( + dpmodel_cls: type, converter: Callable[[Any], torch.nn.Module] +) -> None: + """Register a converter that turns a dpmodel instance into a pt_expt Module. + + This function is called at module import time by each pt_expt wrapper to + register how dpmodel objects should be converted when they're assigned as + attributes. The converter is a callable that takes a dpmodel instance and + returns the corresponding pt_expt torch.nn.Module wrapper. + + Parameters + ---------- + dpmodel_cls : type + The dpmodel class to register (e.g., AtomExcludeMaskDP, NetworkCollectionDP). + This is the key used for lookup in dpmodel_setattr. + converter : Callable[[Any], torch.nn.Module] + A callable that converts a dpmodel instance to a pt_expt module. + Common patterns: + - Reconstruct from constructor args: lambda v: PtExptClass(v.ntypes, ...) 
+ - Round-trip via serialization: lambda v: PtExptClass.deserialize(v.serialize()) + + Notes + ----- + This function must be called AFTER the pt_expt wrapper class is defined but + BEFORE dpmodel_setattr might encounter instances of dpmodel_cls. In practice, + this means calling it immediately after the wrapper class definition at module + import time. + + Examples + -------- + >>> register_dpmodel_mapping( + ... AtomExcludeMaskDP, + ... lambda v: AtomExcludeMask( + ... v.ntypes, exclude_types=list(v.get_exclude_types()) + ... ), + ... ) + """ + _DPMODEL_TO_PT_EXPT[dpmodel_cls] = converter + + +def try_convert_module(value: Any) -> torch.nn.Module | None: + """Convert a dpmodel object to its pt_expt wrapper if a converter is registered. + + This function looks up the exact type of *value* in the _DPMODEL_TO_PT_EXPT + registry. If a converter is found, it invokes it to produce a torch.nn.Module + wrapper; otherwise it returns None. + + Parameters + ---------- + value : Any + The value to potentially convert. Typically a dpmodel object like + AtomExcludeMaskDP or NetworkCollectionDP. + + Returns + ------- + torch.nn.Module or None + The converted pt_expt module if a converter is registered for value's + type, otherwise None. + + Notes + ----- + This function uses exact type matching (not isinstance checks) to ensure + predictable behavior. Each dpmodel class must be explicitly registered via + register_dpmodel_mapping. + + The function is called by dpmodel_setattr when it encounters an object that + might be a dpmodel instance. If conversion succeeds, the caller should use + the converted module instead of the original value. + """ + converter = _DPMODEL_TO_PT_EXPT.get(type(value)) + if converter is not None: + return converter(value) + return None + + +def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool, Any]: + """Common __setattr__ logic for pt_expt wrappers around dpmodel classes. 
+ This function implements automatic attribute detection by value type, eliminating + the need to hard-code attribute names in each wrapper's __setattr__ method. It + handles three cases: + 1. **numpy arrays → torch buffers**: Persistent state like statistics (davg, dstd) + or masks that should be saved in state_dict and moved with .to(device). + 2. **None values → clear buffers**: Setting an existing buffer to None. + 3. **dpmodel objects → pt_expt modules**: Nested dpmodel objects like + AtomExcludeMaskDP or NetworkCollectionDP are converted to their pt_expt + wrappers via the registry. + + Parameters + ---------- + obj : torch.nn.Module + The pt_expt wrapper object whose attribute is being set. Must be a + torch.nn.Module (caller should verify this). + name : str + The attribute name being set. + value : Any + The value being assigned. This function inspects the type to determine + how to handle it. + + Returns + ------- + handled : bool + True if the attribute has been fully set (caller should NOT call + super().__setattr__). False if the caller should forward the (possibly + converted) value to super().__setattr__(name, value). + value : Any + The value to use. May be converted (e.g., dpmodel object → pt_expt module) + or unchanged (e.g., scalar, list, or unregistered object). + + Notes + ----- + **Why this design is safe:** + + - In dpmodel, all persistent arrays use `self.xxx = np.array(...)`. Scalars + use `.item()`, lists use `.tolist()`. So `isinstance(value, np.ndarray)` + reliably identifies buffer-worthy attributes. + - torch.Tensor values assigned to existing buffers fall through to + torch.nn.Module.__setattr__, which correctly updates them. + - dpmodel objects are identified by registry lookup (exact type match), so + only explicitly registered types are converted. + - The function checks `"_buffers" in obj.__dict__` to ensure the object has + been initialized as a torch.nn.Module before attempting buffer operations. 
+ + **Circular import resolution:** + + The function uses a deferred import `from deepmd.pt_expt.utils import env` + inside the function body. This breaks the circular dependency chain: + common.py → utils/__init__.py → exclude_mask.py → common.py. The import is + cached by Python after the first call, so there's no performance penalty. + + **Usage pattern:** + + Typical wrapper classes use this three-line pattern: + + >>> class MyWrapper(MyDPModel, torch.nn.Module): + ... def __setattr__(self, name, value): + ... handled, value = dpmodel_setattr(self, name, value) + ... if not handled: + ... super().__setattr__(name, value) + + Examples + -------- + >>> # Case 1: numpy array → buffer + >>> obj.davg = np.array([1.0, 2.0]) # becomes torch.Tensor buffer + >>> + >>> # Case 2: clear buffer + >>> obj.davg = None # sets buffer to None + >>> + >>> # Case 3: dpmodel object → pt_expt module + >>> obj.emask = AtomExcludeMaskDP(...) # becomes AtomExcludeMask module + """ + from deepmd.pt_expt.utils import env # deferred - avoids circular import + + # numpy array → torch buffer + if isinstance(value, np.ndarray) and "_buffers" in obj.__dict__: + tensor = torch.as_tensor(value, device=env.DEVICE) + if name in obj._buffers: + obj._buffers[name] = tensor + return True, tensor + obj.register_buffer(name, tensor) + return True, tensor + + # clear an existing buffer to None + if value is None and "_buffers" in obj.__dict__ and name in obj._buffers: + obj._buffers[name] = None + return True, None + + # dpmodel object → pt_expt module + if "_modules" in obj.__dict__: + converted = try_convert_module(value) + if converted is not None: + return False, converted + + return False, value + + +# --------------------------------------------------------------------------- +# Utility +# --------------------------------------------------------------------------- @overload def to_torch_array(array: np.ndarray) -> torch.Tensor: ... 
@@ -25,7 +235,42 @@ def to_torch_array(array: torch.Tensor) -> torch.Tensor: ... def to_torch_array(array: Any) -> torch.Tensor | None: - """Convert input to a torch tensor on the pt-expt device.""" + """Convert input to a torch tensor on the pt_expt device. + + This utility function handles conversion from various array-like types (numpy + arrays, torch tensors on different devices, etc.) to torch tensors on the + pt_expt backend's configured device. + + Parameters + ---------- + array : Any + The input to convert. Can be: + - None (returns None) + - torch.Tensor (moves to pt_expt device) + - numpy array or array-like (converts to torch.Tensor on pt_expt device) + + Returns + ------- + torch.Tensor or None + The input as a torch tensor on the pt_expt device (env.DEVICE), or None + if the input was None. + + Notes + ----- + This function uses the same deferred import pattern as dpmodel_setattr to + avoid circular dependencies. The env module determines the target device + (typically CPU for pt_expt). 
+ + Examples + -------- + >>> import numpy as np + >>> arr = np.array([1.0, 2.0, 3.0]) + >>> tensor = to_torch_array(arr) + >>> tensor.device + device(type='cpu') # or whatever env.DEVICE is set to + """ + from deepmd.pt_expt.utils import env # deferred - avoids circular import + if array is None: return None if torch.is_tensor(array): diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 21c0a4eeb7..1ccb4d2dda 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -6,18 +6,12 @@ import torch from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.pt_expt.utils import ( - env, -) -from deepmd.pt_expt.utils.exclude_mask import ( - PairExcludeMask, -) -from deepmd.pt_expt.utils.network import ( - NetworkCollection, -) @BaseDescriptor.register("se_e2_a_expt") @@ -32,30 +26,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return torch.nn.Module.__call__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name in {"davg", "dstd"} and "_buffers" in self.__dict__: - tensor = ( - None if value is None else torch.as_tensor(value, device=env.DEVICE) - ) - if name in self._buffers: - self._buffers[name] = tensor - return - # Register on first assignment so buffers are in state_dict and moved by .to(). 
- self.register_buffer(name, tensor) - return - if name == "embeddings" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - if hasattr(value, "serialize"): - value = NetworkCollection.deserialize(value.serialize()) - elif isinstance(value, dict): - value = NetworkCollection.deserialize(value) - return super().__setattr__(name, value) - if name == "emask" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - value = PairExcludeMask( - self.ntypes, exclude_types=list(value.get_exclude_types()) - ) - return super().__setattr__(name, value) - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) def forward( self, diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index 508785949c..7a406fb499 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -6,18 +6,12 @@ import torch from deepmd.dpmodel.descriptor.se_r import DescrptSeR as DescrptSeRDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) from deepmd.pt_expt.descriptor.base_descriptor import ( BaseDescriptor, ) -from deepmd.pt_expt.utils import ( - env, -) -from deepmd.pt_expt.utils.exclude_mask import ( - PairExcludeMask, -) -from deepmd.pt_expt.utils.network import ( - NetworkCollection, -) @BaseDescriptor.register("se_e2_r_expt") @@ -32,30 +26,9 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: return torch.nn.Module.__call__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name in {"davg", "dstd"} and "_buffers" in self.__dict__: - tensor = ( - None if value is None else torch.as_tensor(value, device=env.DEVICE) - ) - if name in self._buffers: - self._buffers[name] = tensor - return - # Register on first assignment so buffers are in state_dict and moved by .to(). 
- self.register_buffer(name, tensor) - return - if name == "embeddings" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - if hasattr(value, "serialize"): - value = NetworkCollection.deserialize(value.serialize()) - elif isinstance(value, dict): - value = NetworkCollection.deserialize(value) - return super().__setattr__(name, value) - if name == "emask" and "_modules" in self.__dict__: - if value is not None and not isinstance(value, torch.nn.Module): - value = PairExcludeMask( - self.ntypes, exclude_types=list(value.get_exclude_types()) - ) - return super().__setattr__(name, value) - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) def forward( self, diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index f90cf82249..93f765a27c 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ b/deepmd/pt_expt/utils/__init__.py @@ -4,8 +4,12 @@ AtomExcludeMask, PairExcludeMask, ) +from .network import ( + NetworkCollection, +) __all__ = [ "AtomExcludeMask", + "NetworkCollection", "PairExcludeMask", ] diff --git a/deepmd/pt_expt/utils/exclude_mask.py b/deepmd/pt_expt/utils/exclude_mask.py index e757283e1c..4060b8c446 100644 --- a/deepmd/pt_expt/utils/exclude_mask.py +++ b/deepmd/pt_expt/utils/exclude_mask.py @@ -7,8 +7,9 @@ from deepmd.dpmodel.utils.exclude_mask import AtomExcludeMask as AtomExcludeMaskDP from deepmd.dpmodel.utils.exclude_mask import PairExcludeMask as PairExcludeMaskDP -from deepmd.pt_expt.utils import ( - env, +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, ) @@ -18,14 +19,15 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: AtomExcludeMaskDP.__init__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name == "type_mask" and "_buffers" in self.__dict__: - value = None if value is None else torch.as_tensor(value, 
device=env.DEVICE) - if name in self._buffers: - self._buffers[name] = value - return - self.register_buffer(name, value) - return - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + +register_dpmodel_mapping( + AtomExcludeMaskDP, + lambda v: AtomExcludeMask(v.ntypes, exclude_types=list(v.get_exclude_types())), +) class PairExcludeMask(PairExcludeMaskDP, torch.nn.Module): @@ -34,11 +36,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: PairExcludeMaskDP.__init__(self, *args, **kwargs) def __setattr__(self, name: str, value: Any) -> None: - if name == "type_mask" and "_buffers" in self.__dict__: - value = None if value is None else torch.as_tensor(value, device=env.DEVICE) - if name in self._buffers: - self._buffers[name] = value - return - self.register_buffer(name, value) - return - return super().__setattr__(name, value) + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + +register_dpmodel_mapping( + PairExcludeMaskDP, + lambda v: PairExcludeMask(v.ntypes, exclude_types=list(v.get_exclude_types())), +) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 721a511f5f..84d0024a85 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -19,6 +19,7 @@ make_multilayer_network, ) from deepmd.pt_expt.common import ( + register_dpmodel_mapping, to_torch_array, ) @@ -121,5 +122,11 @@ def __setitem__(self, key: int | tuple, value: Any) -> None: del self._module_networks[key_str] +register_dpmodel_mapping( + NetworkCollectionDP, + lambda v: NetworkCollection.deserialize(v.serialize()), +) + + class LayerNorm(LayerNormDP, NativeLayer): pass From 3452a2a8c0ef161d4bd827066c77e2b1bda7be77 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 17:06:19 +0800 Subject: [PATCH 30/35] fix to_numpy_array device --- deepmd/dpmodel/common.py | 5 +++-- 
1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/common.py b/deepmd/dpmodel/common.py index bd6f7dac49..dabbc34e01 100644 --- a/deepmd/dpmodel/common.py +++ b/deepmd/dpmodel/common.py @@ -121,10 +121,11 @@ def to_numpy_array(x: Optional["Array"]) -> np.ndarray | None: try: # asarray is not within Array API standard, so may fail return np.asarray(x) - except (ValueError, AttributeError): + except (ValueError, AttributeError, TypeError): xp = array_api_compat.array_namespace(x) # to fix BufferError: Cannot export readonly array since signalling readonly is unsupported by DLPack. - x = xp.asarray(x, copy=True) + # Move to CPU device to ensure numpy compatibility + x = xp.asarray(x, device="cpu", copy=True) return np.from_dlpack(x) From ba8e7abfaa6a364e14a3f881b9eadeef4b1ca3b6 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 21:17:53 +0800 Subject: [PATCH 31/35] chore(dpmodel,pt_expt): refactorize the implementation of embedding net --- deepmd/dpmodel/utils/network.py | 107 +++++++- deepmd/pt_expt/utils/network.py | 25 +- source/tests/common/dpmodel/test_network.py | 70 ++++++ source/tests/pt_expt/utils/test_network.py | 259 ++++++++++++++++++++ 4 files changed, 457 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index e712adfdd8..a5502b94cd 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -788,7 +788,112 @@ def deserialize(cls, data: dict) -> "EmbeddingNet": return EN -EmbeddingNet = make_embedding_network(NativeNet, NativeLayer) +class EmbeddingNet(NativeNet): + """The embedding network. + + Parameters + ---------- + in_dim + Input dimension. + neuron + The number of neurons in each layer. The output dimension + is the same as the dimension of the last layer. + activation_function + The activation function. + resnet_dt + Use time step at the resnet architecture. + precision + Floating point precision for the model parameters. 
+ seed : int, optional + Random seed. + bias : bool, Optional + Whether to use bias in the embedding layer. + trainable : bool or list[bool], Optional + Whether the weights are trainable. If a list, each element + corresponds to a layer. + """ + + def __init__( + self, + in_dim: int, + neuron: list[int] = [24, 48, 96], + activation_function: str = "tanh", + resnet_dt: bool = False, + precision: str = DEFAULT_PRECISION, + seed: int | list[int] | None = None, + bias: bool = True, + trainable: bool | list[bool] = True, + ) -> None: + layers = [] + i_in = in_dim + if isinstance(trainable, bool): + trainable = [trainable] * len(neuron) + for idx, ii in enumerate(neuron): + i_ot = ii + layers.append( + NativeLayer( + i_in, + i_ot, + bias=bias, + use_timestep=resnet_dt, + activation_function=activation_function, + resnet=True, + precision=precision, + seed=child_seed(seed, idx), + trainable=trainable[idx], + ).serialize() + ) + i_in = i_ot + super().__init__(layers) + self.in_dim = in_dim + self.neuron = neuron + self.activation_function = activation_function + self.resnet_dt = resnet_dt + self.precision = precision + self.bias = bias + + def serialize(self) -> dict: + """Serialize the network to a dict. + + Returns + ------- + dict + The serialized network. + """ + return { + "@class": "EmbeddingNetwork", + "@version": 2, + "in_dim": self.in_dim, + "neuron": self.neuron.copy(), + "activation_function": self.activation_function, + "resnet_dt": self.resnet_dt, + "bias": self.bias, + # make deterministic + "precision": np.dtype(PRECISION_DICT[self.precision]).name, + "layers": [layer.serialize() for layer in self.layers], + } + + @classmethod + def deserialize(cls, data: dict) -> "EmbeddingNet": + """Deserialize the network from a dict. + + Parameters + ---------- + data : dict + The dict to deserialize from. 
+ """ + data = data.copy() + check_version_compatibility(data.pop("@version", 1), 2, 1) + data.pop("@class", None) + layers = data.pop("layers") + obj = cls(**data) + # Reinitialize layers from serialized data, using the same layer type + # that __init__ created (respects subclass overrides via MRO). + layer_type = type(obj.layers[0]) + obj.layers = type(obj.layers)( + [layer_type.deserialize(layer) for layer in layers] + ) + return obj def make_fitting_network( diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 84d0024a85..b115214056 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -10,11 +10,11 @@ from deepmd.dpmodel.common import ( NativeOP, ) +from deepmd.dpmodel.utils.network import EmbeddingNet as EmbeddingNetDP from deepmd.dpmodel.utils.network import LayerNorm as LayerNormDP from deepmd.dpmodel.utils.network import NativeLayer as NativeLayerDP from deepmd.dpmodel.utils.network import NetworkCollection as NetworkCollectionDP from deepmd.dpmodel.utils.network import ( - make_embedding_network, make_fitting_network, make_multilayer_network, ) @@ -91,8 +91,27 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.call(x) -class EmbeddingNet(make_embedding_network(NativeNet, NativeLayer)): - pass +class EmbeddingNet(EmbeddingNetDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + EmbeddingNetDP.__init__(self, *args, **kwargs) + # EmbeddingNetDP.__init__ creates dpmodel NativeLayer instances. + # Convert to pt_expt NativeLayer and wrap in ModuleList. 
+ self.layers = torch.nn.ModuleList( + [NativeLayer.deserialize(layer.serialize()) for layer in self.layers] + ) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return torch.nn.Module.__call__(self, *args, **kwargs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.call(x) + + +register_dpmodel_mapping( + EmbeddingNetDP, + lambda v: EmbeddingNet.deserialize(v.serialize()), +) class FittingNet(make_fitting_network(EmbeddingNet, NativeNet, NativeLayer)): diff --git a/source/tests/common/dpmodel/test_network.py b/source/tests/common/dpmodel/test_network.py index 1ea5b1fdf9..3a95dd7af0 100644 --- a/source/tests/common/dpmodel/test_network.py +++ b/source/tests/common/dpmodel/test_network.py @@ -180,6 +180,76 @@ def test_embedding_net(self) -> None: inp = np.ones([ni], dtype=get_xp_precision(np, prec)) np.testing.assert_allclose(en0.call(inp), en1.call(inp)) + def test_is_concrete_class(self) -> None: + """Verify EmbeddingNet is a concrete class, not factory-generated.""" + in_dim = 4 + neuron = [8, 16, 32] + net = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + activation_function="tanh", + resnet_dt=True, + precision="float64", + ) + # Check it's the actual EmbeddingNet class, not a dynamic class + self.assertEqual(net.__class__.__name__, "EmbeddingNet") + self.assertEqual(net.__class__.__module__, "deepmd.dpmodel.utils.network") + # Verify it has the expected attributes + self.assertEqual(net.in_dim, in_dim) + self.assertEqual(net.neuron, neuron) + self.assertEqual(net.activation_function, "tanh") + self.assertEqual(net.resnet_dt, True) + self.assertEqual(len(net.layers), len(neuron)) + + def test_forward_pass(self) -> None: + """Test EmbeddingNet forward pass produces correct shapes.""" + in_dim = 4 + neuron = [8, 16, 32] + net = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + activation_function="tanh", + resnet_dt=True, + precision="float64", + ) + rng = np.random.default_rng() + x = rng.standard_normal((5, in_dim)) + out = 
net.call(x) + self.assertEqual(out.shape, (5, neuron[-1])) + self.assertEqual(out.dtype, np.float64) + + def test_trainable_parameter_variants(self) -> None: + """Test EmbeddingNet with different trainable configurations.""" + in_dim = 4 + neuron = [8, 16] + + # All trainable + net_trainable = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + trainable=True, + ) + for layer in net_trainable.layers: + self.assertTrue(layer.trainable) + + # All frozen + net_frozen = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + trainable=False, + ) + for layer in net_frozen.layers: + self.assertFalse(layer.trainable) + + # Mixed trainable + net_mixed = EmbeddingNet( + in_dim=in_dim, + neuron=neuron, + trainable=[True, False], + ) + self.assertTrue(net_mixed.layers[0].trainable) + self.assertFalse(net_mixed.layers[1].trainable) + class TestFittingNet(unittest.TestCase): def test_fitting_net(self) -> None: diff --git a/source/tests/pt_expt/utils/test_network.py b/source/tests/pt_expt/utils/test_network.py index ad7c2a7e3d..1510f28bd3 100644 --- a/source/tests/pt_expt/utils/test_network.py +++ b/source/tests/pt_expt/utils/test_network.py @@ -1,9 +1,19 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +import unittest +import numpy as np +import torch + +from deepmd.dpmodel.utils.network import EmbeddingNet as DPEmbeddingNet from deepmd.pt_expt.utils.network import ( + EmbeddingNet, NativeLayer, ) +from ...seed import ( + GLOBAL_SEED, +) + def test_native_layer_clears_parameter_on_none() -> None: layer = NativeLayer(2, 3, trainable=True) @@ -19,3 +29,252 @@ def test_native_layer_clears_buffer_on_none() -> None: layer.w = None assert layer.w is None assert layer._buffers.get("w") is None + + +class TestEmbeddingNetRefactor(unittest.TestCase): + """Tests for the refactored EmbeddingNet pt_expt wrapper and integration.""" + + def setUp(self) -> None: + self.in_dim = 4 + self.neuron = [8, 16, 32] + self.activation = "tanh" + self.resnet_dt = True + self.precision = "float64" + + def 
test_pt_expt_embedding_net_wraps_dpmodel(self) -> None: + """Verify pt_expt EmbeddingNet correctly wraps dpmodel.""" + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + # Check it's a torch.nn.Module + self.assertIsInstance(net, torch.nn.Module) + # Check it's also a DPEmbeddingNet + self.assertIsInstance(net, DPEmbeddingNet) + # Check layers are converted to pt_expt NativeLayer (torch modules) + self.assertIsInstance(net.layers, torch.nn.ModuleList) + for layer in net.layers: + self.assertIsInstance(layer, NativeLayer) + self.assertIsInstance(layer, torch.nn.Module) + + def test_pt_expt_embedding_net_forward(self) -> None: + """Test pt_expt EmbeddingNet forward pass returns torch.Tensor.""" + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + x = torch.randn(5, self.in_dim, dtype=torch.float64) + out = net(x) + self.assertIsInstance(out, torch.Tensor) + self.assertEqual(out.shape, (5, self.neuron[-1])) + self.assertEqual(out.dtype, torch.float64) + + def test_serialization_round_trip_pt_expt(self) -> None: + """Test pt_expt EmbeddingNet serialization/deserialization.""" + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + x = torch.randn(5, self.in_dim, dtype=torch.float64) + out1 = net(x) + + # Serialize and deserialize + serialized = net.serialize() + net2 = EmbeddingNet.deserialize(serialized) + + # Verify layers are still pt_expt NativeLayer modules + self.assertIsInstance(net2.layers, torch.nn.ModuleList) + for layer in net2.layers: + self.assertIsInstance(layer, NativeLayer) + + out2 = net2(x) + np.testing.assert_allclose( + out1.detach().cpu().numpy(), + 
out2.detach().cpu().numpy(), + ) + + def test_deserialize_preserves_layer_type(self) -> None: + """Test that deserialize uses type(obj.layers[0]) to preserve subclass layers. + + This is the key fix: dpmodel's deserialize no longer hardcodes + super(EmbeddingNet, obj).__init__(layers), which would overwrite + pt_expt's converted layers. Instead it uses type(obj.layers[0]) + to respect the subclass's layer type. + """ + # Create pt_expt EmbeddingNet + net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Verify layers are pt_expt NativeLayer (torch modules) + for layer in net.layers: + self.assertIsInstance(layer, torch.nn.Module) + self.assertTrue(hasattr(layer, "_parameters")) + + # Deserialize + serialized = net.serialize() + net2 = EmbeddingNet.deserialize(serialized) + + # Verify deserialized layers are STILL pt_expt NativeLayer, not dpmodel + for layer in net2.layers: + self.assertIsInstance(layer, torch.nn.Module) + self.assertTrue(hasattr(layer, "_parameters")) + # This would fail if deserialize used hardcoded dpmodel layers + self.assertIsInstance(layer, NativeLayer) + + def test_cross_backend_consistency(self) -> None: + """Test numerical consistency between dpmodel and pt_expt EmbeddingNet.""" + # Create both with same seed + dp_net = DPEmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + pt_net = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Test forward pass + rng = np.random.default_rng() + x_np = rng.standard_normal((5, self.in_dim)) + x_torch = torch.from_numpy(x_np) + + out_dp = dp_net.call(x_np) + out_pt = pt_net(x_torch).detach().cpu().numpy() + + 
np.testing.assert_allclose(out_dp, out_pt, rtol=1e-10, atol=1e-10) + + def test_registry_converts_dpmodel_to_pt_expt(self) -> None: + """Test that the registry auto-converts dpmodel EmbeddingNet to pt_expt.""" + from deepmd.pt_expt.common import ( + try_convert_module, + ) + + # Create dpmodel EmbeddingNet + dp_net = DPEmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Try to convert via registry + converted = try_convert_module(dp_net) + + # Should return pt_expt EmbeddingNet + self.assertIsNotNone(converted) + self.assertIsInstance(converted, torch.nn.Module) + self.assertIsInstance(converted, EmbeddingNet) + + # Verify layers are pt_expt NativeLayer + for layer in converted.layers: + self.assertIsInstance(layer, NativeLayer) + self.assertIsInstance(layer, torch.nn.Module) + + def test_auto_conversion_in_setattr(self) -> None: + """Test that dpmodel_setattr auto-converts EmbeddingNet attributes.""" + from deepmd.pt_expt.common import ( + dpmodel_setattr, + ) + + # Create a simple torch module + class TestModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.dummy = None + + obj = TestModule() + + # Create dpmodel EmbeddingNet + dp_net = DPEmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + seed=GLOBAL_SEED, + ) + + # Use dpmodel_setattr to set it + handled, value = dpmodel_setattr(obj, "embedding_net", dp_net) + + # Should not be handled (returns converted value for caller to set) + self.assertFalse(handled) + # Value should be converted to pt_expt EmbeddingNet + self.assertIsInstance(value, torch.nn.Module) + self.assertIsInstance(value, EmbeddingNet) + + def test_trainable_parameter_handling(self) -> None: + """Test that trainable parameters work correctly in pt_expt.""" + # Test with trainable=True + 
net_trainable = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + trainable=True, + seed=GLOBAL_SEED, + ) + + # Count trainable parameters + param_count = sum( + p.numel() for p in net_trainable.parameters() if p.requires_grad + ) + self.assertGreater(param_count, 0) + + # Check all layer parameters are trainable + for layer in net_trainable.layers: + if layer.w is not None: + self.assertTrue(layer.w.requires_grad) + if layer.b is not None: + self.assertTrue(layer.b.requires_grad) + + # Test with trainable=False + net_frozen = EmbeddingNet( + in_dim=self.in_dim, + neuron=self.neuron, + activation_function=self.activation, + resnet_dt=self.resnet_dt, + precision=self.precision, + trainable=False, + seed=GLOBAL_SEED, + ) + + # Count trainable parameters (should be 0) + param_count_frozen = sum( + p.numel() for p in net_frozen.parameters() if p.requires_grad + ) + self.assertEqual(param_count_frozen, 0) + + # Check all layer weights are buffers, not parameters + for layer in net_frozen.layers: + if layer.w is not None: + self.assertFalse(layer.w.requires_grad) From 621c7ccbbcc31c064b4f186af9bd27fb60211787 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 22:15:11 +0800 Subject: [PATCH 32/35] feat: se_t and se_t_tebd descriptors for the pytorch exportable backend. 
--- deepmd/dpmodel/descriptor/descriptor.py | 21 ++- deepmd/dpmodel/descriptor/dpa1.py | 2 + deepmd/dpmodel/descriptor/se_t.py | 6 +- deepmd/dpmodel/descriptor/se_t_tebd.py | 6 +- deepmd/dpmodel/utils/type_embed.py | 51 ++++-- deepmd/pt_expt/common.py | 7 +- deepmd/pt_expt/descriptor/__init__.py | 10 ++ deepmd/pt_expt/descriptor/se_t.py | 56 ++++++ deepmd/pt_expt/descriptor/se_t_tebd.py | 54 ++++++ deepmd/pt_expt/descriptor/se_t_tebd_block.py | 31 ++++ deepmd/pt_expt/utils/__init__.py | 4 + deepmd/pt_expt/utils/type_embed.py | 41 +++++ .../tests/consistent/descriptor/test_se_t.py | 25 +++ .../consistent/descriptor/test_se_t_tebd.py | 35 ++++ source/tests/pt_expt/descriptor/test_se_t.py | 134 +++++++++++++++ .../pt_expt/descriptor/test_se_t_tebd.py | 159 ++++++++++++++++++ 16 files changed, 620 insertions(+), 22 deletions(-) create mode 100644 deepmd/pt_expt/descriptor/se_t.py create mode 100644 deepmd/pt_expt/descriptor/se_t_tebd.py create mode 100644 deepmd/pt_expt/descriptor/se_t_tebd_block.py create mode 100644 deepmd/pt_expt/utils/type_embed.py create mode 100644 source/tests/pt_expt/descriptor/test_se_t.py create mode 100644 source/tests/pt_expt/descriptor/test_se_t_tebd.py diff --git a/deepmd/dpmodel/descriptor/descriptor.py b/deepmd/dpmodel/descriptor/descriptor.py index 9b0e067972..ad49a7cb8d 100644 --- a/deepmd/dpmodel/descriptor/descriptor.py +++ b/deepmd/dpmodel/descriptor/descriptor.py @@ -12,7 +12,7 @@ NoReturn, ) -import numpy as np +import array_api_compat from deepmd.dpmodel.array_api import ( Array, @@ -173,7 +173,18 @@ def extend_descrpt_stat( extend_dstd = des_with_stat["dstd"] else: extend_shape = [len(type_map), *list(des["davg"].shape[1:])] - extend_davg = np.zeros(extend_shape, dtype=des["davg"].dtype) - extend_dstd = np.ones(extend_shape, dtype=des["dstd"].dtype) - des["davg"] = np.concatenate([des["davg"], extend_davg], axis=0) - des["dstd"] = np.concatenate([des["dstd"], extend_dstd], axis=0) + # Use array_api_compat to infer device and 
dtype from context + xp = array_api_compat.array_namespace(des["davg"]) + extend_davg = xp.zeros( + extend_shape, + dtype=des["davg"].dtype, + device=array_api_compat.device(des["davg"]), + ) + extend_dstd = xp.ones( + extend_shape, + dtype=des["dstd"].dtype, + device=array_api_compat.device(des["dstd"]), + ) + xp = array_api_compat.array_namespace(des["davg"]) + des["davg"] = xp.concat([des["davg"], extend_davg], axis=0) + des["dstd"] = xp.concat([des["dstd"], extend_dstd], axis=0) diff --git a/deepmd/dpmodel/descriptor/dpa1.py b/deepmd/dpmodel/descriptor/dpa1.py index f09ab24dfe..2f9aa69b62 100644 --- a/deepmd/dpmodel/descriptor/dpa1.py +++ b/deepmd/dpmodel/descriptor/dpa1.py @@ -1049,6 +1049,8 @@ def call( idx_j = xp.reshape(nei_type, (-1,)) # (nf x nl x nnei) x ng idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng)) + # Cast to int64 for PyTorch backend (take_along_dim requires Long indices) + idx = xp.astype(idx, xp.int64) # (ntypes) * ntypes * nt type_embedding_nei = xp.tile( xp.reshape(type_embedding, (1, ntypes_with_padding, nt)), diff --git a/deepmd/dpmodel/descriptor/se_t.py b/deepmd/dpmodel/descriptor/se_t.py index 749a5da188..95b66759de 100644 --- a/deepmd/dpmodel/descriptor/se_t.py +++ b/deepmd/dpmodel/descriptor/se_t.py @@ -369,7 +369,11 @@ def call( sec = self.sel_cumsum ng = self.neuron[-1] - result = xp.zeros([nf * nloc, ng], dtype=get_xp_precision(xp, self.precision)) + result = xp.zeros( + [nf * nloc, ng], + dtype=get_xp_precision(xp, self.precision), + device=array_api_compat.device(coord_ext), + ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # merge nf and nloc axis, so for type_one_side == False, # we don't require atype is the same in all frames diff --git a/deepmd/dpmodel/descriptor/se_t_tebd.py b/deepmd/dpmodel/descriptor/se_t_tebd.py index 0a2d46c015..05c9dc77af 100644 --- a/deepmd/dpmodel/descriptor/se_t_tebd.py +++ b/deepmd/dpmodel/descriptor/se_t_tebd.py @@ -769,7 +769,9 @@ def call( sw = xp.where( 
nlist_mask[:, :, None], xp.reshape(sw, (nf * nloc, nnei, 1)), - xp.zeros((nf * nloc, nnei, 1), dtype=sw.dtype), + xp.zeros( + (nf * nloc, nnei, 1), dtype=sw.dtype, device=array_api_compat.device(sw) + ), ) # nfnl x nnei x 4 @@ -832,6 +834,8 @@ def call( # (nf x nl x nt_i x nt_j) x ng idx = xp.tile(xp.reshape((idx_i + idx_j), (-1, 1)), (1, ng)) + # Cast to int64 for PyTorch backend (take_along_dim requires Long indices) + idx = xp.astype(idx, xp.int64) # ntypes * (ntypes) * nt type_embedding_i = xp.tile( diff --git a/deepmd/dpmodel/utils/type_embed.py b/deepmd/dpmodel/utils/type_embed.py index a1b698b698..0d01ccc9d8 100644 --- a/deepmd/dpmodel/utils/type_embed.py +++ b/deepmd/dpmodel/utils/type_embed.py @@ -102,11 +102,21 @@ def call(self) -> Array: sample_array = self.embedding_net[0]["w"] xp = array_api_compat.array_namespace(sample_array) if not self.use_econf_tebd: - embed = self.embedding_net(xp.eye(self.ntypes, dtype=sample_array.dtype)) + embed = self.embedding_net( + xp.eye( + self.ntypes, + dtype=sample_array.dtype, + device=array_api_compat.device(sample_array), + ) + ) else: embed = self.embedding_net(self.econf_tebd) if self.padding: - embed_pad = xp.zeros((1, embed.shape[-1]), dtype=embed.dtype) + embed_pad = xp.zeros( + (1, embed.shape[-1]), + dtype=embed.dtype, + device=array_api_compat.device(embed), + ) embed = xp.concat([embed, embed_pad], axis=0) return embed @@ -182,32 +192,51 @@ def change_type_map( "'activation_function' must be 'Linear' when performing type changing on resnet structure!" 
) first_layer_matrix = self.embedding_net.layers[0].w - eye_vector = np.eye(self.ntypes, dtype=PRECISION_DICT[self.precision]) + # Use array_api_compat to handle both numpy and torch + xp = array_api_compat.array_namespace(first_layer_matrix) + eye_vector = xp.eye( + self.ntypes, + dtype=first_layer_matrix.dtype, + device=array_api_compat.device(first_layer_matrix), + ) # preprocess for resnet connection if self.neuron[0] == self.ntypes: - first_layer_matrix += eye_vector + first_layer_matrix = first_layer_matrix + eye_vector elif self.neuron[0] == self.ntypes * 2: - first_layer_matrix += np.concatenate([eye_vector, eye_vector], axis=-1) + first_layer_matrix = first_layer_matrix + xp.concat( + [eye_vector, eye_vector], axis=-1 + ) # randomly initialize params for the unseen types - rng = np.random.default_rng() if has_new_type: - extend_type_params = rng.random( + # Create random params with same dtype and device as first_layer_matrix + extend_type_params = np.random.default_rng().random( [len(type_map), first_layer_matrix.shape[-1]], + dtype=PRECISION_DICT[self.precision], + ) + extend_type_params = xp.asarray( + extend_type_params, dtype=first_layer_matrix.dtype, + device=array_api_compat.device(first_layer_matrix), ) - first_layer_matrix = np.concatenate( + first_layer_matrix = xp.concat( [first_layer_matrix, extend_type_params], axis=0 ) first_layer_matrix = first_layer_matrix[remap_index] new_ntypes = len(type_map) - eye_vector = np.eye(new_ntypes, dtype=PRECISION_DICT[self.precision]) + eye_vector = xp.eye( + new_ntypes, + dtype=first_layer_matrix.dtype, + device=array_api_compat.device(first_layer_matrix), + ) if self.neuron[0] == new_ntypes: - first_layer_matrix -= eye_vector + first_layer_matrix = first_layer_matrix - eye_vector elif self.neuron[0] == new_ntypes * 2: - first_layer_matrix -= np.concatenate([eye_vector, eye_vector], axis=-1) + first_layer_matrix = first_layer_matrix - xp.concat( + [eye_vector, eye_vector], axis=-1 + ) 
self.embedding_net.layers[0].num_in = new_ntypes self.embedding_net.layers[0].w = first_layer_matrix diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index e687fa8e48..c005804d74 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -85,7 +85,7 @@ def register_dpmodel_mapping( def try_convert_module(value: Any) -> torch.nn.Module | None: """Convert a dpmodel object to its pt_expt wrapper if a converter is registered. - This function looks up the exact type of *value* in the _DPMODEL_TO_PT_EXPT + This function looks up the type of *value* in the _DPMODEL_TO_PT_EXPT registry. If a converter is found, it invokes it to produce a torch.nn.Module wrapper; otherwise it returns None. @@ -103,9 +103,8 @@ def try_convert_module(value: Any) -> torch.nn.Module | None: Notes ----- - This function uses exact type matching (not isinstance checks) to ensure - predictable behavior. Each dpmodel class must be explicitly registered via - register_dpmodel_mapping. + This function uses exact type matching. Each dpmodel class must be explicitly + registered via register_dpmodel_mapping. The function is called by dpmodel_setattr when it encounters an object that might be a dpmodel instance. If conversion succeeds, the caller should use diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py index 4d9469a93a..7feda7d703 100644 --- a/deepmd/pt_expt/descriptor/__init__.py +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -1,4 +1,6 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +# Import to register converters +from . 
import se_t_tebd_block # noqa: F401 from .base_descriptor import ( BaseDescriptor, ) @@ -8,9 +10,17 @@ from .se_r import ( DescrptSeR, ) +from .se_t import ( + DescrptSeT, +) +from .se_t_tebd import ( + DescrptSeTTebd, +) __all__ = [ "BaseDescriptor", "DescrptSeA", "DescrptSeR", + "DescrptSeT", + "DescrptSeTTebd", ] diff --git a/deepmd/pt_expt/descriptor/se_t.py b/deepmd/pt_expt/descriptor/se_t.py new file mode 100644 index 0000000000..604dd6a5c0 --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_t.py @@ -0,0 +1,56 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.descriptor.se_t import DescrptSeT as DescrptSeTDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("se_e3_expt") +@BaseDescriptor.register("se_at_expt") +@BaseDescriptor.register("se_a_3be_expt") +class DescrptSeT(DescrptSeTDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptSeTDP.__init__(self, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. 
+ return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + extended_atype_embd: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + type_embedding: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + del extended_atype_embd, type_embedding + descrpt, rot_mat, g2, h2, sw = self.call( + extended_coord, + extended_atype, + nlist, + mapping=mapping, + ) + return descrpt, rot_mat, g2, h2, sw diff --git a/deepmd/pt_expt/descriptor/se_t_tebd.py b/deepmd/pt_expt/descriptor/se_t_tebd.py new file mode 100644 index 0000000000..235ba1bfe9 --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_t_tebd.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, +) +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("se_e3_tebd_expt") +class DescrptSeTTebd(DescrptSeTTebdDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptSeTTebdDP.__init__(self, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. 
+ return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + extended_atype_embd: torch.Tensor | None = None, + mapping: torch.Tensor | None = None, + type_embedding: torch.Tensor | None = None, + ) -> tuple[ + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + torch.Tensor | None, + ]: + del extended_atype_embd, mapping, type_embedding + descrpt, rot_mat, g2, h2, sw = self.call( + extended_coord, + extended_atype, + nlist, + mapping=None, + ) + return descrpt, rot_mat, g2, h2, sw diff --git a/deepmd/pt_expt/descriptor/se_t_tebd_block.py b/deepmd/pt_expt/descriptor/se_t_tebd_block.py new file mode 100644 index 0000000000..7a0faaf170 --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_t_tebd_block.py @@ -0,0 +1,31 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.descriptor.se_t_tebd import ( + DescrptBlockSeTTebd as DescrptBlockSeTTebdDP, +) +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, +) + + +class DescrptBlockSeTTebd(DescrptBlockSeTTebdDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + DescrptBlockSeTTebdDP.__init__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + +register_dpmodel_mapping( + DescrptBlockSeTTebdDP, + lambda v: DescrptBlockSeTTebd.deserialize(v.serialize()), +) diff --git a/deepmd/pt_expt/utils/__init__.py b/deepmd/pt_expt/utils/__init__.py index 93f765a27c..99ae559e30 100644 --- a/deepmd/pt_expt/utils/__init__.py +++ 
b/deepmd/pt_expt/utils/__init__.py @@ -7,9 +7,13 @@ from .network import ( NetworkCollection, ) +from .type_embed import ( + TypeEmbedNet, +) __all__ = [ "AtomExcludeMask", "NetworkCollection", "PairExcludeMask", + "TypeEmbedNet", ] diff --git a/deepmd/pt_expt/utils/type_embed.py b/deepmd/pt_expt/utils/type_embed.py new file mode 100644 index 0000000000..da4cf09028 --- /dev/null +++ b/deepmd/pt_expt/utils/type_embed.py @@ -0,0 +1,41 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +import torch + +from deepmd.dpmodel.utils.type_embed import TypeEmbedNet as TypeEmbedNetDP +from deepmd.pt_expt.common import ( + dpmodel_setattr, + register_dpmodel_mapping, +) + +# Import network to ensure EmbeddingNet is registered before TypeEmbedNet is used +from deepmd.pt_expt.utils import network # noqa: F401 + + +class TypeEmbedNet(TypeEmbedNetDP, torch.nn.Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + torch.nn.Module.__init__(self) + TypeEmbedNetDP.__init__(self, *args, **kwargs) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. 
+ return torch.nn.Module.__call__(self, *args, **kwargs) + + def __setattr__(self, name: str, value: Any) -> None: + # Use common dpmodel_setattr which handles embedding_net conversion via registry + handled, value = dpmodel_setattr(self, name, value) + if not handled: + super().__setattr__(name, value) + + def forward(self) -> torch.Tensor: + # Call dpmodel's implementation (now with proper device handling) + return self.call() + + +register_dpmodel_mapping( + TypeEmbedNetDP, + lambda v: TypeEmbedNet.deserialize(v.serialize()), +) diff --git a/source/tests/consistent/descriptor/test_se_t.py b/source/tests/consistent/descriptor/test_se_t.py index 56655cabe1..2ef9864f47 100644 --- a/source/tests/consistent/descriptor/test_se_t.py +++ b/source/tests/consistent/descriptor/test_se_t.py @@ -15,6 +15,7 @@ INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, @@ -27,6 +28,10 @@ from deepmd.pt.model.descriptor.se_t import DescrptSeT as DescrptSeTPT else: DescrptSeTPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.se_t import DescrptSeT as DescrptSeTPTExpt +else: + DescrptSeTPTExpt = None if INSTALLED_TF: from deepmd.tf.descriptor.se_t import DescrptSeT as DescrptSeTTF else: @@ -91,6 +96,16 @@ def skip_dp(self) -> bool: ) = self.param return CommonTest.skip_dp + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + excluded_types, + precision, + env_protection, + ) = self.param + return CommonTest.skip_pt_expt + @property def skip_tf(self) -> bool: ( @@ -107,6 +122,7 @@ def skip_tf(self) -> bool: tf_class = DescrptSeTTF dp_class = DescrptSeTDP pt_class = DescrptSeTPT + pt_expt_class = DescrptSeTPTExpt jax_class = DescrptSeTJAX array_api_strict_class = DescrptSeTStrict args = descrpt_se_t_args() @@ -183,6 +199,15 @@ def eval_pt(self, pt_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + 
self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, diff --git a/source/tests/consistent/descriptor/test_se_t_tebd.py b/source/tests/consistent/descriptor/test_se_t_tebd.py index 9cdca9bde3..a60ff1fdad 100644 --- a/source/tests/consistent/descriptor/test_se_t_tebd.py +++ b/source/tests/consistent/descriptor/test_se_t_tebd.py @@ -19,6 +19,7 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, CommonTest, parameterized, ) @@ -30,6 +31,12 @@ from deepmd.pt.model.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdPT else: DescrptSeTTebdPT = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.se_t_tebd import ( + DescrptSeTTebd as DescrptSeTTebdPTExpt, + ) +else: + DescrptSeTTebdPTExpt = None DescrptSeTTebdTF = None if INSTALLED_JAX: from deepmd.jax.descriptor.se_t_tebd import DescrptSeTTebd as DescrptSeTTebdJAX @@ -117,6 +124,23 @@ def skip_pt(self) -> bool: ) = self.param return CommonTest.skip_pt + @property + def skip_pt_expt(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + excluded_types, + env_protection, + set_davg_zero, + smooth, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return CommonTest.skip_pt_expt + @property def skip_dp(self) -> bool: ( @@ -158,6 +182,7 @@ def skip_tf(self) -> bool: tf_class = DescrptSeTTebdTF dp_class = DescrptSeTTebdDP pt_class = DescrptSeTTebdPT + pt_expt_class = DescrptSeTTebdPTExpt pd_class = DescrptSeTTebdPD jax_class = DescrptSeTTebdJAX array_api_strict_class = DescrptSeTTebdStrict @@ -240,6 +265,16 @@ def eval_pt(self, pt_obj: Any) -> Any: mixed_types=True, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, diff --git 
a/source/tests/pt_expt/descriptor/test_se_t.py b/source/tests/pt_expt/descriptor/test_se_t.py new file mode 100644 index 0000000000..921f10a54a --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_se_t.py @@ -0,0 +1,134 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import DescrptSeT as DPDescrptSeT +from deepmd.pt_expt.descriptor.se_t import ( + DescrptSeT, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDescrptSeT(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeT( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + dd1 = DescrptSeT.deserialize(dd0.serialize()) + rd1, gr1, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, 
dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeT.deserialize(dd0.serialize()) + rd2, gr2, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + # se_t returns None for gr/g2/h2, only compare rd and sw + np.testing.assert_allclose( + rd1.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + sw1.detach().cpu().numpy(), + sw2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_exportable(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [False, True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeT( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) diff --git a/source/tests/pt_expt/descriptor/test_se_t_tebd.py b/source/tests/pt_expt/descriptor/test_se_t_tebd.py new file mode 100644 index 0000000000..e774cda6d5 --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_se_t_tebd.py @@ -0,0 +1,159 @@ +# SPDX-License-Identifier: 
LGPL-3.0-or-later +import itertools +import unittest + +import numpy as np +import torch + +from deepmd.dpmodel.descriptor import DescrptSeTTebd as DPDescrptSeTTebd +from deepmd.pt_expt.descriptor.se_t_tebd import ( + DescrptSeTTebd, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDescrptSeTTebd(unittest.TestCase, TestCaseSingleFrameWithNlist): + def setUp(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + def test_consistency(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [True], # SeTTebd typically uses resnet_dt=True + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeTTebd( + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + + # Type embedding input + type_embedding = torch.randn( + [self.nt, dd0.tebd_dim], dtype=dtype, device=self.device + ) + + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + type_embedding=type_embedding, + ) + dd1 = DescrptSeTTebd.deserialize(dd0.serialize()) + rd1, gr1, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, 
device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + type_embedding=type_embedding, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeTTebd.deserialize(dd0.serialize()) + rd2, gr2, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + # se_t_tebd should return gr and sw, compare only descriptor and sw for now + # TODO: investigate why gr is None + np.testing.assert_allclose( + rd1.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + if gr1 is not None and gr2 is not None: + np.testing.assert_allclose( + gr1.detach().cpu().numpy(), + gr2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + sw1.detach().cpu().numpy(), + sw2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + def test_exportable(self) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + for idt, prec in itertools.product( + [True], + ["float64", "float32"], + ): + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeTTebd( + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + + type_embedding = torch.randn( + [self.nt, dd0.tebd_dim], dtype=dtype, device=self.device + ) + + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, 
device=self.device), + type_embedding, + ) + torch.export.export(dd0, inputs) From faa4026da8ee7aea5430aa2c6b992f6c0b16d6bf Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 22:21:28 +0800 Subject: [PATCH 33/35] fix bug --- source/tests/pt_expt/utils/test_network.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/source/tests/pt_expt/utils/test_network.py b/source/tests/pt_expt/utils/test_network.py index 1510f28bd3..c80946a018 100644 --- a/source/tests/pt_expt/utils/test_network.py +++ b/source/tests/pt_expt/utils/test_network.py @@ -5,6 +5,9 @@ import torch from deepmd.dpmodel.utils.network import EmbeddingNet as DPEmbeddingNet +from deepmd.pt_expt.utils import ( + env, +) from deepmd.pt_expt.utils.network import ( EmbeddingNet, NativeLayer, @@ -87,7 +90,7 @@ def test_serialization_round_trip_pt_expt(self) -> None: precision=self.precision, seed=GLOBAL_SEED, ) - x = torch.randn(5, self.in_dim, dtype=torch.float64) + x = torch.randn(5, self.in_dim, dtype=torch.float64, device=env.DEVICE) out1 = net(x) # Serialize and deserialize From b2220734371c30290ddb7ac0c9fe7e53dd1878b7 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Sun, 8 Feb 2026 23:09:19 +0800 Subject: [PATCH 34/35] fix bug --- source/tests/pt_expt/utils/test_network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/source/tests/pt_expt/utils/test_network.py b/source/tests/pt_expt/utils/test_network.py index c80946a018..083ea2fbd9 100644 --- a/source/tests/pt_expt/utils/test_network.py +++ b/source/tests/pt_expt/utils/test_network.py @@ -74,7 +74,7 @@ def test_pt_expt_embedding_net_forward(self) -> None: precision=self.precision, seed=GLOBAL_SEED, ) - x = torch.randn(5, self.in_dim, dtype=torch.float64) + x = torch.randn(5, self.in_dim, dtype=torch.float64, device=env.DEVICE) out = net(x) self.assertIsInstance(out, torch.Tensor) self.assertEqual(out.shape, (5, self.neuron[-1])) From baa1be4ba9146febcf2832c43397d6c67582c49a Mon Sep 17 00:00:00 2001 From: 
Han Wang Date: Wed, 11 Feb 2026 18:25:50 +0800 Subject: [PATCH 35/35] fix the API consistency issue in descriptors --- deepmd/pt_expt/descriptor/se_e2_a.py | 3 --- deepmd/pt_expt/descriptor/se_r.py | 3 --- deepmd/pt_expt/descriptor/se_t.py | 3 --- deepmd/pt_expt/descriptor/se_t_tebd.py | 5 +---- source/tests/pt_expt/descriptor/test_se_t_tebd.py | 12 ------------ 5 files changed, 1 insertion(+), 25 deletions(-) diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 1ccb4d2dda..c9503ec413 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -35,9 +35,7 @@ def forward( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - extended_atype_embd: torch.Tensor | None = None, mapping: torch.Tensor | None = None, - type_embedding: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -45,7 +43,6 @@ def forward( torch.Tensor | None, torch.Tensor | None, ]: - del extended_atype_embd, type_embedding descrpt, rot_mat, g2, h2, sw = self.call( extended_coord, extended_atype, diff --git a/deepmd/pt_expt/descriptor/se_r.py b/deepmd/pt_expt/descriptor/se_r.py index 7a406fb499..f3006d38a5 100644 --- a/deepmd/pt_expt/descriptor/se_r.py +++ b/deepmd/pt_expt/descriptor/se_r.py @@ -35,9 +35,7 @@ def forward( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - extended_atype_embd: torch.Tensor | None = None, mapping: torch.Tensor | None = None, - type_embedding: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -45,7 +43,6 @@ def forward( torch.Tensor | None, torch.Tensor | None, ]: - del extended_atype_embd, type_embedding descrpt, rot_mat, g2, h2, sw = self.call( extended_coord, extended_atype, diff --git a/deepmd/pt_expt/descriptor/se_t.py b/deepmd/pt_expt/descriptor/se_t.py index 604dd6a5c0..9a438c0140 100644 --- a/deepmd/pt_expt/descriptor/se_t.py +++ b/deepmd/pt_expt/descriptor/se_t.py @@ 
-36,9 +36,7 @@ def forward( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - extended_atype_embd: torch.Tensor | None = None, mapping: torch.Tensor | None = None, - type_embedding: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -46,7 +44,6 @@ def forward( torch.Tensor | None, torch.Tensor | None, ]: - del extended_atype_embd, type_embedding descrpt, rot_mat, g2, h2, sw = self.call( extended_coord, extended_atype, diff --git a/deepmd/pt_expt/descriptor/se_t_tebd.py b/deepmd/pt_expt/descriptor/se_t_tebd.py index 235ba1bfe9..7545f8c6fe 100644 --- a/deepmd/pt_expt/descriptor/se_t_tebd.py +++ b/deepmd/pt_expt/descriptor/se_t_tebd.py @@ -34,9 +34,7 @@ def forward( extended_coord: torch.Tensor, extended_atype: torch.Tensor, nlist: torch.Tensor, - extended_atype_embd: torch.Tensor | None = None, mapping: torch.Tensor | None = None, - type_embedding: torch.Tensor | None = None, ) -> tuple[ torch.Tensor, torch.Tensor | None, @@ -44,11 +42,10 @@ def forward( torch.Tensor | None, torch.Tensor | None, ]: - del extended_atype_embd, mapping, type_embedding descrpt, rot_mat, g2, h2, sw = self.call( extended_coord, extended_atype, nlist, - mapping=None, + mapping=mapping, ) return descrpt, rot_mat, g2, h2, sw diff --git a/source/tests/pt_expt/descriptor/test_se_t_tebd.py b/source/tests/pt_expt/descriptor/test_se_t_tebd.py index e774cda6d5..e84080882a 100644 --- a/source/tests/pt_expt/descriptor/test_se_t_tebd.py +++ b/source/tests/pt_expt/descriptor/test_se_t_tebd.py @@ -58,23 +58,16 @@ def test_consistency(self) -> None: dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) - # Type embedding input - type_embedding = torch.randn( - [self.nt, dd0.tebd_dim], dtype=dtype, device=self.device - ) - rd0, _, _, _, _ = dd0( torch.tensor(self.coord_ext, dtype=dtype, device=self.device), torch.tensor(self.atype_ext, dtype=int, device=self.device), 
torch.tensor(self.nlist, dtype=int, device=self.device), - type_embedding=type_embedding, ) dd1 = DescrptSeTTebd.deserialize(dd0.serialize()) rd1, gr1, _, _, sw1 = dd1( torch.tensor(self.coord_ext, dtype=dtype, device=self.device), torch.tensor(self.atype_ext, dtype=int, device=self.device), torch.tensor(self.nlist, dtype=int, device=self.device), - type_embedding=type_embedding, ) np.testing.assert_allclose( rd0.detach().cpu().numpy(), @@ -146,14 +139,9 @@ def test_exportable(self) -> None: dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) dd0 = dd0.eval() - type_embedding = torch.randn( - [self.nt, dd0.tebd_dim], dtype=dtype, device=self.device - ) - inputs = ( torch.tensor(self.coord_ext, dtype=dtype, device=self.device), torch.tensor(self.atype_ext, dtype=int, device=self.device), torch.tensor(self.nlist, dtype=int, device=self.device), - type_embedding, ) torch.export.export(dd0, inputs)