From 3a286e55f46db7e5d5241ba6f4b2ccac83a77413 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 18 Feb 2026 21:18:18 +0800 Subject: [PATCH 1/9] feat(pt_expt): add descriptors dpa1 dpa2 dpa3 and hybrid --- deepmd/dpmodel/array_api.py | 3 + deepmd/dpmodel/descriptor/dpa2.py | 1 + deepmd/dpmodel/descriptor/repflows.py | 10 +- deepmd/dpmodel/descriptor/repformers.py | 12 +- deepmd/dpmodel/descriptor/se_e2_a.py | 141 ++------ deepmd/pt_expt/common.py | 118 ++++++- deepmd/pt_expt/descriptor/__init__.py | 20 ++ deepmd/pt_expt/descriptor/dpa1.py | 20 ++ deepmd/pt_expt/descriptor/dpa2.py | 19 + deepmd/pt_expt/descriptor/dpa3.py | 19 + deepmd/pt_expt/descriptor/hybrid.py | 19 + deepmd/pt_expt/descriptor/se_atten_v2.py | 19 + deepmd/pt_expt/descriptor/se_e2_a.py | 2 +- deepmd/pt_expt/utils/network.py | 26 ++ source/tests/consistent/descriptor/common.py | 104 ++++++ .../tests/consistent/descriptor/test_dpa1.py | 197 +++++++++++ .../tests/consistent/descriptor/test_dpa2.py | 156 +++++++++ .../tests/consistent/descriptor/test_dpa3.py | 102 ++++++ .../consistent/descriptor/test_hybrid.py | 58 ++++ .../consistent/descriptor/test_se_atten_v2.py | 194 +++++++++++ .../consistent/descriptor/test_se_e2_a.py | 54 ++- .../tests/consistent/descriptor/test_se_r.py | 61 +++- .../tests/consistent/descriptor/test_se_t.py | 39 ++- .../consistent/descriptor/test_se_t_tebd.py | 60 ++++ source/tests/pt_expt/descriptor/test_dpa1.py | 181 ++++++++++ source/tests/pt_expt/descriptor/test_dpa2.py | 326 ++++++++++++++++++ source/tests/pt_expt/descriptor/test_dpa3.py | 248 +++++++++++++ .../tests/pt_expt/descriptor/test_hybrid.py | 219 ++++++++++++ .../pt_expt/descriptor/test_se_atten_v2.py | 177 ++++++++++ .../tests/pt_expt/descriptor/test_se_e2_a.py | 212 +++++++----- source/tests/pt_expt/descriptor/test_se_r.py | 206 ++++++----- source/tests/pt_expt/descriptor/test_se_t.py | 226 +++++++----- .../pt_expt/descriptor/test_se_t_tebd.py | 241 +++++++------ 33 files changed, 3004 insertions(+), 486 
deletions(-) create mode 100644 deepmd/pt_expt/descriptor/dpa1.py create mode 100644 deepmd/pt_expt/descriptor/dpa2.py create mode 100644 deepmd/pt_expt/descriptor/dpa3.py create mode 100644 deepmd/pt_expt/descriptor/hybrid.py create mode 100644 deepmd/pt_expt/descriptor/se_atten_v2.py create mode 100644 source/tests/pt_expt/descriptor/test_dpa1.py create mode 100644 source/tests/pt_expt/descriptor/test_dpa2.py create mode 100644 source/tests/pt_expt/descriptor/test_dpa3.py create mode 100644 source/tests/pt_expt/descriptor/test_hybrid.py create mode 100644 source/tests/pt_expt/descriptor/test_se_atten_v2.py diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index a0da6193c9..4a7fa9a45c 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -29,6 +29,9 @@ def xp_swapaxes(a: Array, axis1: int, axis2: int) -> Array: def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array: xp = array_api_compat.array_namespace(arr) + # torch.take_along_dim requires int64 indices + if array_api_compat.is_torch_array(indices): + indices = xp.astype(indices, xp.int64) if Version(xp.__array_api_version__) >= Version("2024.12"): # see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39 return xp.take_along_axis(arr, indices, axis=axis) diff --git a/deepmd/dpmodel/descriptor/dpa2.py b/deepmd/dpmodel/descriptor/dpa2.py index e5e02d312c..9a3be982f1 100644 --- a/deepmd/dpmodel/descriptor/dpa2.py +++ b/deepmd/dpmodel/descriptor/dpa2.py @@ -573,6 +573,7 @@ def init_subclass_params(sub_data: dict | Any, sub_class: type) -> Any: self.smooth = smooth self.exclude_types = exclude_types self.env_protection = env_protection + self.rcut_smth = self.repinit.get_rcut_smth() self.trainable = trainable self.add_tebd_to_repinit_out = add_tebd_to_repinit_out diff --git a/deepmd/dpmodel/descriptor/repflows.py b/deepmd/dpmodel/descriptor/repflows.py index 
7ba4f92662..3188bbfee5 100644 --- a/deepmd/dpmodel/descriptor/repflows.py +++ b/deepmd/dpmodel/descriptor/repflows.py @@ -595,8 +595,12 @@ def call( # n_angle x 1 a_sw = (a_sw[:, :, :, None] * a_sw[:, :, None, :])[a_nlist_mask] else: - edge_index = xp.zeros([2, 1], dtype=nlist.dtype) - angle_index = xp.zeros([3, 1], dtype=nlist.dtype) + edge_index = xp.zeros( + [2, 1], dtype=nlist.dtype, device=array_api_compat.device(nlist) + ) + angle_index = xp.zeros( + [3, 1], dtype=nlist.dtype, device=array_api_compat.device(nlist) + ) # get edge and angle embedding # nb x nloc x nnei x e_dim [OR] n_edge x e_dim @@ -1711,6 +1715,7 @@ def call( xp.zeros( (nb, nloc, self.nnei - self.a_sel, self.e_dim), dtype=edge_ebd.dtype, + device=array_api_compat.device(edge_ebd), ), ], axis=2, @@ -1741,6 +1746,7 @@ def call( xp.zeros( (nb, nloc, self.nnei - self.a_sel), dtype=a_nlist_mask.dtype, + device=array_api_compat.device(a_nlist_mask), ), ], axis=-1, diff --git a/deepmd/dpmodel/descriptor/repformers.py b/deepmd/dpmodel/descriptor/repformers.py index 06f5c1c943..65248ab88d 100644 --- a/deepmd/dpmodel/descriptor/repformers.py +++ b/deepmd/dpmodel/descriptor/repformers.py @@ -475,7 +475,7 @@ def call( g1 = self.act(atype_embd) # nf x nloc x nnei x 1, nf x nloc x nnei x 3 if not self.direct_dist: - g2, h2 = xp.split(dmatrix, [1], axis=-1) + g2, h2 = dmatrix[..., :1], dmatrix[..., 1:] else: g2, h2 = safe_for_vector_norm(diff, axis=-1, keepdims=True), diff g2 = g2 / self.rcut @@ -756,10 +756,12 @@ def _cal_hg( else: g = _apply_switch(g, sw) if not use_sqrt_nnei: - invnnei = (1.0 / float(nnei)) * xp.ones((nf, nloc, 1, 1), dtype=g.dtype) + invnnei = (1.0 / float(nnei)) * xp.ones( + (nf, nloc, 1, 1), dtype=g.dtype, device=array_api_compat.device(g) + ) else: invnnei = (1.0 / (float(nnei) ** 0.5)) * xp.ones( - (nf, nloc, 1, 1), dtype=g.dtype + (nf, nloc, 1, 1), dtype=g.dtype, device=array_api_compat.device(g) ) # nf x nloc x 3 x ng hg = xp.matmul(xp.matrix_transpose(h), g) * invnnei @@ -1655,7 
+1657,9 @@ def _update_g1_conv( invnnei = invnnei[:, :, xp.newaxis] else: gg1 = _apply_switch(gg1, sw) - invnnei = (1.0 / float(nnei)) * xp.ones((nf, nloc, 1), dtype=gg1.dtype) + invnnei = (1.0 / float(nnei)) * xp.ones( + (nf, nloc, 1), dtype=gg1.dtype, device=array_api_compat.device(gg1) + ) if not self.g1_out_conv: # nf x nloc x ng2 g1_11 = xp.sum(g2 * gg1, axis=2) * invnnei diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 77afb110e9..e949a4946b 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -427,45 +427,64 @@ def call( The smooth switch function. """ del mapping + xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) + input_dtype = coord_ext.dtype # nf x nloc x nnei x 4 rr, diff, ww = self.env_mat.call( - coord_ext, atype_ext, nlist, self.davg, self.dstd + coord_ext, + atype_ext, + nlist, + self.davg[...], + self.dstd[...], ) nf, nloc, nnei, _ = rr.shape - sec = np.append([0], np.cumsum(self.sel)) + sec = self.sel_cumsum ng = self.neuron[-1] - gr = np.zeros([nf * nloc, ng, 4], dtype=PRECISION_DICT[self.precision]) + gr = xp.zeros( + [nf * nloc, ng, 4], + dtype=input_dtype, + device=array_api_compat.device(coord_ext), + ) exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) # merge nf and nloc axis, so for type_one_side == False, # we don't require atype is the same in all frames - exclude_mask = exclude_mask.reshape(nf * nloc, nnei) - rr = rr.reshape(nf * nloc, nnei, 4) + exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) + rr = xp.reshape(rr, (nf * nloc, nnei, 4)) + rr = xp.astype(rr, self.dstd.dtype) + + if not self.type_one_side: + # nf x nloc -> (nf * nloc) + atype_loc = xp.reshape(atype_ext[:, :nloc], (nf * nloc,)) for embedding_idx in itertools.product( range(self.ntypes), repeat=self.embeddings.ndim ): if self.type_one_side: (tt,) = embedding_idx - ti_mask = np.s_[:] else: ti, tt = embedding_idx - ti_mask = atype_ext[:, 
:nloc].ravel() == ti - mm = exclude_mask[ti_mask, sec[tt] : sec[tt + 1]] - tr = rr[ti_mask, sec[tt] : sec[tt + 1], :] - tr = tr * mm[:, :, None] + mm = exclude_mask[:, sec[tt] : sec[tt + 1]] + tr = rr[:, sec[tt] : sec[tt + 1], :] + tr = tr * xp.astype(mm[:, :, None], tr.dtype) ss = tr[..., 0:1] gg = self.cal_g(ss, embedding_idx) - gr_tmp = np.einsum("lni,lnj->lij", gg, tr) - gr[ti_mask] += gr_tmp - gr = gr.reshape(nf, nloc, ng, 4) + gr_tmp = xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1) + if not self.type_one_side: + # (nf * nloc) x 1 x 1 + ti_mask = xp.astype( + xp.reshape(atype_loc == ti, (nf * nloc, 1, 1)), gr_tmp.dtype + ) + gr_tmp = gr_tmp * ti_mask + gr += gr_tmp + gr = xp.reshape(gr, (nf, nloc, ng, 4)) # nf x nloc x ng x 4 gr /= self.nnei gr1 = gr[:, :, : self.axis_neuron, :] # nf x nloc x ng x ng1 - grrg = np.einsum("flid,fljd->flij", gr, gr1) + grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4) # nf x nloc x (ng x ng1) - grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron) + grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)) return grrg, gr[..., 1:], None, None, ww def serialize(self) -> dict: @@ -553,94 +572,4 @@ def update_sel( return local_jdata_cpy, min_nbor_dist -class DescrptSeAArrayAPI(DescrptSeA): - @cast_precision - def call( - self, - coord_ext: Array, - atype_ext: Array, - nlist: Array, - mapping: Array | None = None, - ) -> Array: - """Compute the descriptor. - - Parameters - ---------- - coord_ext - The extended coordinates of atoms. shape: nf x (nallx3) - atype_ext - The extended aotm types. shape: nf x nall - nlist - The neighbor list. shape: nf x nloc x nnei - mapping - The index mapping from extended to local region. not used by this descriptor. - - Returns - ------- - descriptor - The descriptor. shape: nf x nloc x (ng x axis_neuron) - gr - The rotationally equivariant and permutationally invariant single particle - representation. 
shape: nf x nloc x ng x 3 - g2 - The rotationally invariant pair-partical representation. - this descriptor returns None - h2 - The rotationally equivariant pair-partical representation. - this descriptor returns None - sw - The smooth switch function. - """ - if not self.type_one_side: - raise NotImplementedError( - "type_one_side == False is not supported in DescrptSeAArrayAPI" - ) - del mapping - xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) - input_dtype = coord_ext.dtype - # nf x nloc x nnei x 4 - rr, diff, ww = self.env_mat.call( - coord_ext, - atype_ext, - nlist, - self.davg[...], - self.dstd[...], - ) - nf, nloc, nnei, _ = rr.shape - sec = self.sel_cumsum - - ng = self.neuron[-1] - gr = xp.zeros( - [nf * nloc, ng, 4], - dtype=input_dtype, - device=array_api_compat.device(coord_ext), - ) - exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext) - # merge nf and nloc axis, so for type_one_side == False, - # we don't require atype is the same in all frames - exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei)) - rr = xp.reshape(rr, (nf * nloc, nnei, 4)) - rr = xp.astype(rr, self.dstd.dtype) - - for embedding_idx in itertools.product( - range(self.ntypes), repeat=self.embeddings.ndim - ): - (tt,) = embedding_idx - mm = exclude_mask[:, sec[tt] : sec[tt + 1]] - tr = rr[:, sec[tt] : sec[tt + 1], :] - tr = tr * xp.astype(mm[:, :, None], tr.dtype) - ss = tr[..., 0:1] - gg = self.cal_g(ss, embedding_idx) - # gr_tmp = xp.einsum("lni,lnj->lij", gg, tr) - gr_tmp = xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1) - gr += gr_tmp - gr = xp.reshape(gr, (nf, nloc, ng, 4)) - # nf x nloc x ng x 4 - gr /= self.nnei - gr1 = gr[:, :, : self.axis_neuron, :] - # nf x nloc x ng x ng1 - # grrg = xp.einsum("flid,fljd->flij", gr, gr1) - grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4) - # nf x nloc x (ng x ng1) - grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)) - return grrg, gr[..., 1:], None, None, ww 
+DescrptSeAArrayAPI = DescrptSeA diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index c7375119e2..a29f61f3f8 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -90,11 +90,16 @@ def register_dpmodel_mapping( def try_convert_module(value: Any) -> torch.nn.Module | None: - """Convert a dpmodel object to its pt_expt wrapper if a converter is registered. + """Convert a dpmodel object to its pt_expt wrapper. This function looks up the exact type of *value* in the _DPMODEL_TO_PT_EXPT registry. If a converter is found, it invokes it to produce a torch.nn.Module - wrapper; otherwise it returns None. + wrapper. Otherwise, if *value* is a ``NativeOP``, it is automatically + wrapped via ``_auto_wrap_native_op`` so that internal helper classes + (e.g. ``RepformerLayer``, ``DescrptBlockRepformers``) don't need explicit + registrations. + + Returns None only for non-NativeOP values. Parameters ---------- @@ -106,13 +111,13 @@ def try_convert_module(value: Any) -> torch.nn.Module | None: ------- torch.nn.Module or None The converted pt_expt module if a converter is registered for value's - type, otherwise None. + type or if the value is a NativeOP (auto-wrapped), otherwise None. Notes ----- - This function uses exact type matching (not isinstance checks) to ensure - predictable behavior. Each dpmodel class must be explicitly registered via - register_dpmodel_mapping. + For explicitly registered types, exact type matching is used (not isinstance + checks). For unregistered types, isinstance(value, NativeOP) triggers the + auto-wrap fallback. The function is called by dpmodel_setattr when it encounters an object that might be a dpmodel instance. 
If conversion succeeds, the caller should use @@ -121,6 +126,81 @@ def try_convert_module(value: Any) -> torch.nn.Module | None: converter = _DPMODEL_TO_PT_EXPT.get(type(value)) if converter is not None: return converter(value) + if isinstance(value, NativeOP): + return _auto_wrap_native_op(value) + return None + + +# Cache of auto-wrapped classes so each dpmodel class is wrapped at most once. +_AUTO_WRAPPED_CLASSES: dict[type, type] = {} + + +def _auto_wrap_native_op(value: NativeOP) -> torch.nn.Module: + """Auto-wrap any NativeOP as a torch.nn.Module via ``torch_module``. + + Creates a subclass with a generic ``forward`` that delegates to ``call``, + then applies ``torch_module`` to get full ``__setattr__`` / post-init + list conversion. The wrapped class is cached per dpmodel type. + + Parameters + ---------- + value : NativeOP + The dpmodel object to wrap. + + Returns + ------- + torch.nn.Module + The wrapped pt_expt module, deserialized from value's serialized state. + """ + cls = type(value) + if cls not in _AUTO_WRAPPED_CLASSES: + wrapped = type( + cls.__name__, + (cls,), + {"forward": lambda self, *args, **kwargs: self.call(*args, **kwargs)}, + ) + _AUTO_WRAPPED_CLASSES[cls] = torch_module(wrapped) + return _AUTO_WRAPPED_CLASSES[cls].deserialize(value.serialize()) + + +def _try_convert_list(name: str, value: list) -> torch.nn.Module | None: + """Try to convert a plain list to ModuleList or ParameterList. + + Returns the converted container, or None if no conversion is needed + (e.g., empty list, list of scalars/strings). 
+ """ + if not value: + return None + # List of torch.nn.Module → ModuleList + if all(isinstance(v, torch.nn.Module) for v in value): + return torch.nn.ModuleList(value) + # List of NativeOP (not yet Module) → convert each + ModuleList + if all( + isinstance(v, NativeOP) and not isinstance(v, torch.nn.Module) for v in value + ): + converted = [] + for v in value: + c = try_convert_module(v) + if c is None: + raise TypeError( + f"Failed to convert {type(v).__name__} " + f"in list attribute '{name}'. Please call " + f"register_dpmodel_mapping for this type." + ) + converted.append(c) + return torch.nn.ModuleList(converted) + # List of numpy arrays → ParameterList (non-trainable) + if all(isinstance(v, np.ndarray) for v in value): + from deepmd.pt_expt.utils import env # deferred - avoids circular import + + return torch.nn.ParameterList( + [ + torch.nn.Parameter( + torch.as_tensor(v, device=env.DEVICE), requires_grad=False + ) + for v in value + ] + ) return None @@ -217,21 +297,18 @@ def dpmodel_setattr(obj: torch.nn.Module, name: str, value: Any) -> tuple[bool, obj._buffers[name] = None return True, None - # dpmodel object → pt_expt module + # list of modules / NativeOP / numpy arrays → ModuleList / ParameterList + if isinstance(value, list) and "_modules" in obj.__dict__: + converted_list = _try_convert_list(name, value) + if converted_list is not None: + return False, converted_list + + # dpmodel object → pt_expt module (uses auto-wrap for unregistered NativeOP) if "_modules" in obj.__dict__: - # Try to convert dpmodel objects that aren't already torch.nn.Modules if not isinstance(value, torch.nn.Module): converted = try_convert_module(value) if converted is not None: return False, converted - # If this is a NativeOP that should have been registered but wasn't, raise error - if isinstance(value, NativeOP): - raise TypeError( - f"Attempted to assign a dpmodel object of type {type(value).__name__} " - f"but no converter is registered. 
Please call register_dpmodel_mapping " - f"for this type. If this object doesn't need conversion, register it " - f"with an identity converter: lambda v: v" - ) return False, value @@ -322,6 +399,15 @@ class TorchModule(module, torch.nn.Module): def __init__(self, *args: Any, **kwargs: Any) -> None: torch.nn.Module.__init__(self) module.__init__(self, *args, **kwargs) + # Convert any plain lists built incrementally during __init__. + # (list.append() bypasses __setattr__, so dpmodel_setattr never + # sees the complete list; we scan __dict__ here to catch them.) + for name in list(self.__dict__): + value = self.__dict__[name] + if isinstance(value, list): + converted = _try_convert_list(name, value) + if converted is not None: + setattr(self, name, converted) def __call__(self, *args: Any, **kwargs: Any) -> Any: # Ensure torch.nn.Module.__call__ drives forward() for export/tracing. diff --git a/deepmd/pt_expt/descriptor/__init__.py b/deepmd/pt_expt/descriptor/__init__.py index 7feda7d703..1667182d84 100644 --- a/deepmd/pt_expt/descriptor/__init__.py +++ b/deepmd/pt_expt/descriptor/__init__.py @@ -4,6 +4,21 @@ from .base_descriptor import ( BaseDescriptor, ) +from .dpa1 import ( + DescrptDPA1, +) +from .dpa2 import ( + DescrptDPA2, +) +from .dpa3 import ( + DescrptDPA3, +) +from .hybrid import ( + DescrptHybrid, +) +from .se_atten_v2 import ( + DescrptSeAttenV2, +) from .se_e2_a import ( DescrptSeA, ) @@ -19,7 +34,12 @@ __all__ = [ "BaseDescriptor", + "DescrptDPA1", + "DescrptDPA2", + "DescrptDPA3", + "DescrptHybrid", "DescrptSeA", + "DescrptSeAttenV2", "DescrptSeR", "DescrptSeT", "DescrptSeTTebd", diff --git a/deepmd/pt_expt/descriptor/dpa1.py b/deepmd/pt_expt/descriptor/dpa1.py new file mode 100644 index 0000000000..8af613eaf3 --- /dev/null +++ b/deepmd/pt_expt/descriptor/dpa1.py @@ -0,0 +1,20 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1DP +from 
deepmd.pt_expt.common import ( + torch_module, +) +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("se_atten") +@BaseDescriptor.register("dpa1") +@torch_module +class DescrptDPA1(DescrptDPA1DP): + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.call(*args, **kwargs) diff --git a/deepmd/pt_expt/descriptor/dpa2.py b/deepmd/pt_expt/descriptor/dpa2.py new file mode 100644 index 0000000000..e35eb22336 --- /dev/null +++ b/deepmd/pt_expt/descriptor/dpa2.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2DP +from deepmd.pt_expt.common import ( + torch_module, +) +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("dpa2") +@torch_module +class DescrptDPA2(DescrptDPA2DP): + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.call(*args, **kwargs) diff --git a/deepmd/pt_expt/descriptor/dpa3.py b/deepmd/pt_expt/descriptor/dpa3.py new file mode 100644 index 0000000000..05d2c4277a --- /dev/null +++ b/deepmd/pt_expt/descriptor/dpa3.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3DP +from deepmd.pt_expt.common import ( + torch_module, +) +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("dpa3") +@torch_module +class DescrptDPA3(DescrptDPA3DP): + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.call(*args, **kwargs) diff --git a/deepmd/pt_expt/descriptor/hybrid.py b/deepmd/pt_expt/descriptor/hybrid.py new file mode 100644 index 0000000000..7da3c1da21 --- /dev/null +++ b/deepmd/pt_expt/descriptor/hybrid.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from 
deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DescrptHybridDP +from deepmd.pt_expt.common import ( + torch_module, +) +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("hybrid") +@torch_module +class DescrptHybrid(DescrptHybridDP): + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.call(*args, **kwargs) diff --git a/deepmd/pt_expt/descriptor/se_atten_v2.py b/deepmd/pt_expt/descriptor/se_atten_v2.py new file mode 100644 index 0000000000..eb7c464f2c --- /dev/null +++ b/deepmd/pt_expt/descriptor/se_atten_v2.py @@ -0,0 +1,19 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from typing import ( + Any, +) + +from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DescrptSeAttenV2DP +from deepmd.pt_expt.common import ( + torch_module, +) +from deepmd.pt_expt.descriptor.base_descriptor import ( + BaseDescriptor, +) + + +@BaseDescriptor.register("se_atten_v2") +@torch_module +class DescrptSeAttenV2(DescrptSeAttenV2DP): + def forward(self, *args: Any, **kwargs: Any) -> Any: + return self.call(*args, **kwargs) diff --git a/deepmd/pt_expt/descriptor/se_e2_a.py b/deepmd/pt_expt/descriptor/se_e2_a.py index 894f175764..f177c8c07f 100644 --- a/deepmd/pt_expt/descriptor/se_e2_a.py +++ b/deepmd/pt_expt/descriptor/se_e2_a.py @@ -2,7 +2,7 @@ import torch -from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeAArrayAPI as DescrptSeADP +from deepmd.dpmodel.descriptor.se_e2_a import DescrptSeA as DescrptSeADP from deepmd.pt_expt.common import ( torch_module, ) diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 929907c2f3..79f51a5ca2 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -274,3 +274,29 @@ def __setitem__(self, key: int | tuple, value: Any) -> None: class LayerNorm(LayerNormDP, NativeLayer): pass + + +register_dpmodel_mapping( + NativeLayerDP, + lambda v: NativeLayer.deserialize(v.serialize()), +) + 
+register_dpmodel_mapping( + LayerNormDP, + lambda v: LayerNorm.deserialize(v.serialize()), +) + + +from deepmd.dpmodel.utils.network import Identity as IdentityDP + + +@torch_module +class Identity(IdentityDP): + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.call(x) + + +register_dpmodel_mapping( + IdentityDP, + lambda v: Identity.deserialize(v.serialize()), +) diff --git a/source/tests/consistent/descriptor/common.py b/source/tests/consistent/descriptor/common.py index 7c8cbce744..e82fb0dda8 100644 --- a/source/tests/consistent/descriptor/common.py +++ b/source/tests/consistent/descriptor/common.py @@ -4,6 +4,9 @@ ) import numpy as np +from dargs import ( + Argument, +) from deepmd.common import ( make_default_mesh, @@ -257,3 +260,104 @@ def eval_array_api_strict_descriptor( ext_coords, ext_atype, nlist=nlist, mapping=mapping ) ] + + +class DescriptorAPITest: + """Base mixin class for testing consistency of BaseDescriptor API methods across backends. + + Subclasses should set dp_class, pt_class, pt_expt_class, args, and + provide a ``data`` property returning the constructor kwargs dict. 
+ """ + + dp_class = None + pt_class = None + pt_expt_class = None + args = None + ntypes = 2 + + @property + def data(self) -> dict: + raise NotImplementedError + + @property + def skip_pt(self) -> bool: + return not INSTALLED_PT + + @property + def skip_pt_expt(self) -> bool: + return not INSTALLED_PT_EXPT + + def _init_descriptor(self, cls): + """Initialize a descriptor from the test data.""" + if self.args is None: + data = self.data + else: + if isinstance(self.args, list): + base = Argument("arg", dict, sub_fields=self.args) + elif isinstance(self.args, Argument): + base = self.args + else: + raise ValueError("Invalid type for args") + data = base.normalize_value(self.data, trim_pattern="_*") + base.check_value(data, strict=True) + return cls(**data) + + def _assert_method_equal(self, ref_obj, obj, method_name, err_msg=""): + """Assert a method returns the same value on both objects. + + Handles the case where a method raises NotImplementedError or + AttributeError: both objects must raise for the check to pass. 
+ """ + ref_raised = False + obj_raised = False + ref_val = obj_val = None + try: + ref_val = getattr(ref_obj, method_name)() + except (NotImplementedError, AttributeError): + ref_raised = True + try: + obj_val = getattr(obj, method_name)() + except (NotImplementedError, AttributeError): + obj_raised = True + self.assertEqual( + ref_raised, + obj_raised, + msg=f"{err_msg}: {method_name} raised/not-raised mismatch", + ) + if not ref_raised: + self.assertEqual(ref_val, obj_val, msg=f"{err_msg}: {method_name}") + + def _assert_descriptor_api_equal(self, ref_obj, obj, err_msg=""): + """Assert that all BaseDescriptor API methods return consistent values.""" + for method_name in [ + "get_rcut", + "get_rcut_smth", + "get_sel", + "get_ntypes", + "get_type_map", + "get_dim_out", + "get_dim_emb", + "mixed_types", + "has_message_passing", + "need_sorted_nlist_for_lower", + "get_env_protection", + "get_nsel", + "get_nnei", + ]: + self._assert_method_equal(ref_obj, obj, method_name, err_msg=err_msg) + + def test_dp_pt_api(self) -> None: + """Test whether DP and PT descriptor APIs are consistent.""" + if self.skip_pt: + self.skipTest("PT not installed or skipped") + dp_obj = self._init_descriptor(self.dp_class) + pt_obj = self._init_descriptor(self.pt_class) + self._assert_descriptor_api_equal(dp_obj, pt_obj, err_msg="DP vs PT") + + def test_dp_pt_expt_api(self) -> None: + """Test whether DP and PT exportable descriptor APIs are consistent.""" + if self.skip_pt_expt: + self.skipTest("PT_EXPT not installed or skipped") + dp_obj = self._init_descriptor(self.dp_class) + pt_expt_obj = self._init_descriptor(self.pt_expt_class) + self._assert_descriptor_api_equal(dp_obj, pt_expt_obj, err_msg="DP vs PT_EXPT") diff --git a/source/tests/consistent/descriptor/test_dpa1.py b/source/tests/consistent/descriptor/test_dpa1.py index 0f14288386..60748bd17b 100644 --- a/source/tests/consistent/descriptor/test_dpa1.py +++ b/source/tests/consistent/descriptor/test_dpa1.py @@ -19,11 +19,13 @@ 
INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, parameterized, ) from .common import ( + DescriptorAPITest, DescriptorTest, ) @@ -43,6 +45,10 @@ from deepmd.pd.model.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1PD else: DescrptDPA1PD = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.dpa1 import DescrptDPA1 as DescrptDPA1PTExpt +else: + DescrptDPA1PTExpt = None if INSTALLED_ARRAY_API_STRICT: from ...array_api_strict.descriptor.dpa1 import DescrptDPA1 as DescriptorDPA1Strict else: @@ -280,6 +286,37 @@ def skip_array_api_strict(self) -> bool: ) ) + @property + def skip_pt_expt(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return ( + CommonTest.skip_pt_expt + or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + temperature, + ) + ) + @property def skip_tf(self) -> bool: ( @@ -321,6 +358,7 @@ def skip_tf(self) -> bool: tf_class = DescrptDPA1TF dp_class = DescrptDPA1DP pt_class = DescrptDPA1PT + pt_expt_class = DescrptDPA1PTExpt pd_class = DescrptDPA1PD jax_class = DescriptorDPA1JAX array_api_strict_class = DescriptorDPA1Strict @@ -432,6 +470,16 @@ def eval_pd(self, pd_obj: Any) -> Any: mixed_types=True, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: return self.eval_array_api_strict_descriptor( array_api_strict_obj, @@ -506,3 +554,152 @@ def atol(self) -> float: return 1e-4 else: raise ValueError(f"Unknown precision: {precision}") + + +@parameterized( + (4,), # tebd_dim + ("concat", "strip"), # 
tebd_input_mode + (True,), # resnet_dt + (True,), # type_one_side + (20,), # attn + (0, 2), # attn_layer + (True,), # attn_dotr + ([], [[0, 1]]), # excluded_types + (0.0,), # env_protection + (True, False), # set_davg_zero + (1.0,), # scaling_factor + (True,), # normalize + (None, 1.0), # temperature + (1e-5,), # ln_eps + (True,), # smooth_type_embedding + (True,), # concat_output_tebd + ("float64",), # precision + (True, False), # use_econf_tebd + (False,), # use_tebd_bias +) +class TestDPA1DescriptorAPI(DescriptorAPITest, unittest.TestCase): + """Test consistency of BaseDescriptor API methods across backends.""" + + dp_class = DescrptDPA1DP + pt_class = DescrptDPA1PT + pt_expt_class = DescrptDPA1PTExpt + args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False)) + + def is_meaningless_zero_attention_layer_tests( + self, + attn_layer: int, + temperature: float | None, + ) -> bool: + return attn_layer == 0 and (temperature is not None) + + @property + def data(self) -> dict: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return { + "sel": [10], + "rcut_smth": 5.80, + "rcut": 6.00, + "neuron": [6, 12, 24], + "ntypes": self.ntypes, + "axis_neuron": 3, + "tebd_dim": tebd_dim, + "tebd_input_mode": tebd_input_mode, + "attn": attn, + "attn_layer": attn_layer, + "attn_dotr": attn_dotr, + "attn_mask": False, + "scaling_factor": scaling_factor, + "normalize": normalize, + "temperature": temperature, + "ln_eps": ln_eps, + "concat_output_tebd": concat_output_tebd, + "resnet_dt": resnet_dt, + "type_one_side": type_one_side, + "exclude_types": excluded_types, + "env_protection": env_protection, + "precision": precision, + "set_davg_zero": set_davg_zero, + "smooth_type_embedding": 
smooth_type_embedding, + "use_econf_tebd": use_econf_tebd, + "use_tebd_bias": use_tebd_bias, + "type_map": ["O", "H"] if use_econf_tebd else None, + "seed": 1145141919810, + "trainable": False, + "activation_function": "relu", + } + + @property + def skip_pt(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return not INSTALLED_PT or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + temperature, + ) + + @property + def skip_pt_expt(self) -> bool: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + smooth_type_embedding, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return not INSTALLED_PT_EXPT or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + temperature, + ) diff --git a/source/tests/consistent/descriptor/test_dpa2.py b/source/tests/consistent/descriptor/test_dpa2.py index 25ab23d1a7..9a48b84d9d 100644 --- a/source/tests/consistent/descriptor/test_dpa2.py +++ b/source/tests/consistent/descriptor/test_dpa2.py @@ -19,10 +19,12 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, CommonTest, parameterized, ) from .common import ( + DescriptorAPITest, DescriptorTest, ) @@ -41,6 +43,10 @@ else: DescrptDPA2PD = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2PTExpt +else: + DescrptDPA2PTExpt = None if INSTALLED_ARRAY_API_STRICT: from ...array_api_strict.descriptor.dpa2 import DescrptDPA2 as DescrptDPA2Strict else: @@ -322,10 +328,12 @@ def skip_tf(self) -> bool: skip_jax = not INSTALLED_JAX 
skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + skip_pt_expt = not INSTALLED_PT_EXPT tf_class = DescrptDPA2TF dp_class = DescrptDPA2DP pt_class = DescrptDPA2PT + pt_expt_class = DescrptDPA2PTExpt pd_class = DescrptDPA2PD jax_class = DescrptDPA2JAX array_api_strict_class = DescrptDPA2Strict @@ -444,6 +452,16 @@ def eval_jax(self, jax_obj: Any) -> Any: mixed_types=True, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: return self.eval_array_api_strict_descriptor( array_api_strict_obj, @@ -534,3 +552,141 @@ def atol(self) -> float: return 1e-4 else: raise ValueError(f"Unknown precision: {precision}") + + +@parameterized( + ("concat", "strip"), # repinit_tebd_input_mode + (True,), # repinit_set_davg_zero + (False,), # repinit_type_one_side + (True, False), # repinit_use_three_body + (True, False), # repformer_direct_dist + (True,), # repformer_update_g1_has_conv + (True,), # repformer_update_g1_has_drrd + (True,), # repformer_update_g1_has_grrg + (True,), # repformer_update_g1_has_attn + (True,), # repformer_update_g2_has_g1g1 + (True,), # repformer_update_g2_has_attn + (False,), # repformer_update_h2 + (True,), # repformer_attn2_has_gate + ("res_avg", "res_residual"), # repformer_update_style + ("norm", "const"), # repformer_update_residual_init + (True,), # repformer_set_davg_zero + (True,), # repformer_trainable_ln + (1e-5,), # repformer_ln_eps + (True,), # repformer_use_sqrt_nnei + (True,), # repformer_g1_out_conv + (True,), # repformer_g1_out_mlp + (True, False), # smooth + ([], [[0, 1]]), # exclude_types + ("float64",), # precision + (True, False), # add_tebd_to_repinit_out + (True, False), # use_econf_tebd + (False,), # use_tebd_bias +) +class TestDPA2DescriptorAPI(DescriptorAPITest, unittest.TestCase): + """Test consistency of BaseDescriptor API 
methods across backends.""" + + dp_class = DescrptDPA2DP + pt_class = DescrptDPA2PT + pt_expt_class = DescrptDPA2PTExpt + args = descrpt_dpa2_args().append(Argument("ntypes", int, optional=False)) + + @property + def data(self) -> dict: + ( + repinit_tebd_input_mode, + repinit_set_davg_zero, + repinit_type_one_side, + repinit_use_three_body, + repformer_direct_dist, + repformer_update_g1_has_conv, + repformer_update_g1_has_drrd, + repformer_update_g1_has_grrg, + repformer_update_g1_has_attn, + repformer_update_g2_has_g1g1, + repformer_update_g2_has_attn, + repformer_update_h2, + repformer_attn2_has_gate, + repformer_update_style, + repformer_update_residual_init, + repformer_set_davg_zero, + repformer_trainable_ln, + repformer_ln_eps, + repformer_use_sqrt_nnei, + repformer_g1_out_conv, + repformer_g1_out_mlp, + smooth, + exclude_types, + precision, + add_tebd_to_repinit_out, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return { + "ntypes": self.ntypes, + # kwargs for repinit + "repinit": RepinitArgs( + **{ + "rcut": 6.00, + "rcut_smth": 5.80, + "nsel": 10, + "neuron": [6, 12, 24], + "axis_neuron": 3, + "tebd_dim": 4, + "tebd_input_mode": repinit_tebd_input_mode, + "set_davg_zero": repinit_set_davg_zero, + "activation_function": "relu", + "type_one_side": repinit_type_one_side, + "use_three_body": repinit_use_three_body, + "three_body_sel": 8, + "three_body_rcut": 4.0, + "three_body_rcut_smth": 3.5, + } + ), + # kwargs for repformer + "repformer": RepformerArgs( + **{ + "rcut": 4.00, + "rcut_smth": 3.50, + "nsel": 8, + "nlayers": 3, + "g1_dim": 20, + "g2_dim": 10, + "axis_neuron": 3, + "direct_dist": repformer_direct_dist, + "update_g1_has_conv": repformer_update_g1_has_conv, + "update_g1_has_drrd": repformer_update_g1_has_drrd, + "update_g1_has_grrg": repformer_update_g1_has_grrg, + "update_g1_has_attn": repformer_update_g1_has_attn, + "update_g2_has_g1g1": repformer_update_g2_has_g1g1, + "update_g2_has_attn": repformer_update_g2_has_attn, + "update_h2": 
repformer_update_h2, + "attn1_hidden": 12, + "attn1_nhead": 2, + "attn2_hidden": 10, + "attn2_nhead": 2, + "attn2_has_gate": repformer_attn2_has_gate, + "activation_function": "relu", + "update_style": repformer_update_style, + "update_residual": 0.001, + "update_residual_init": repformer_update_residual_init, + "set_davg_zero": True, + "trainable_ln": repformer_trainable_ln, + "ln_eps": repformer_ln_eps, + "use_sqrt_nnei": repformer_use_sqrt_nnei, + "g1_out_conv": repformer_g1_out_conv, + "g1_out_mlp": repformer_g1_out_mlp, + } + ), + # kwargs for descriptor + "concat_output_tebd": True, + "precision": precision, + "smooth": smooth, + "exclude_types": exclude_types, + "env_protection": 0.0, + "trainable": False, + "use_econf_tebd": use_econf_tebd, + "use_tebd_bias": use_tebd_bias, + "type_map": ["O", "H"] if use_econf_tebd else None, + "add_tebd_to_repinit_out": add_tebd_to_repinit_out, + } diff --git a/source/tests/consistent/descriptor/test_dpa3.py b/source/tests/consistent/descriptor/test_dpa3.py index 253d58d4a7..65d471fb99 100644 --- a/source/tests/consistent/descriptor/test_dpa3.py +++ b/source/tests/consistent/descriptor/test_dpa3.py @@ -19,10 +19,12 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, CommonTest, parameterized, ) from .common import ( + DescriptorAPITest, DescriptorTest, ) @@ -41,6 +43,10 @@ else: DescrptDPA3PD = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3PTExpt +else: + DescrptDPA3PTExpt = None if INSTALLED_ARRAY_API_STRICT: from ...array_api_strict.descriptor.dpa3 import DescrptDPA3 as DescrptDPA3Strict else: @@ -215,10 +221,12 @@ def skip_tf(self) -> bool: skip_jax = not INSTALLED_JAX skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + skip_pt_expt = not INSTALLED_PT_EXPT tf_class = DescrptDPA3TF dp_class = DescrptDPA3DP pt_class = DescrptDPA3PT + pt_expt_class = DescrptDPA3PTExpt pd_class = DescrptDPA3PD jax_class = DescrptDPA3JAX array_api_strict_class = 
DescrptDPA3Strict @@ -324,6 +332,16 @@ def eval_jax(self, jax_obj: Any) -> Any: mixed_types=True, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: return self.eval_array_api_strict_descriptor( array_api_strict_obj, @@ -388,3 +406,87 @@ def atol(self) -> float: return 1e-4 else: raise ValueError(f"Unknown precision: {precision}") + + +@parameterized( + ("const",), # update_residual_init + ([], [[0, 1]]), # exclude_types + (True,), # update_angle + (0, 1), # a_compress_rate + (1, 2), # a_compress_e_rate + (True,), # a_compress_use_split + (True, False), # optim_update + (True, False), # edge_init_use_dist + (True, False), # use_exp_switch + (True, False), # use_dynamic_sel + (True, False), # use_loc_mapping + (0.3, 0.0), # fix_stat_std + (1,), # n_multi_edge_message + ("float64",), # precision +) +class TestDPA3DescriptorAPI(DescriptorAPITest, unittest.TestCase): + """Test consistency of BaseDescriptor API methods across backends.""" + + dp_class = DescrptDPA3DP + pt_class = DescrptDPA3PT + pt_expt_class = DescrptDPA3PTExpt + args = descrpt_dpa3_args().append(Argument("ntypes", int, optional=False)) + + @property + def data(self) -> dict: + ( + update_residual_init, + exclude_types, + update_angle, + a_compress_rate, + a_compress_e_rate, + a_compress_use_split, + optim_update, + edge_init_use_dist, + use_exp_switch, + use_dynamic_sel, + use_loc_mapping, + fix_stat_std, + n_multi_edge_message, + precision, + ) = self.param + return { + "ntypes": self.ntypes, + # kwargs for repinit + "repflow": RepFlowArgs( + **{ + "n_dim": 20, + "e_dim": 10, + "a_dim": 8, + "nlayers": 3, + "e_rcut": 6.0, + "e_rcut_smth": 5.0, + "e_sel": 10, + "a_rcut": 4.0, + "a_rcut_smth": 3.5, + "a_sel": 8, + "a_compress_rate": a_compress_rate, + "a_compress_e_rate": a_compress_e_rate, + 
"a_compress_use_split": a_compress_use_split, + "optim_update": optim_update, + "edge_init_use_dist": edge_init_use_dist, + "use_exp_switch": use_exp_switch, + "use_dynamic_sel": use_dynamic_sel, + "smooth_edge_update": True, + "fix_stat_std": fix_stat_std, + "n_multi_edge_message": n_multi_edge_message, + "axis_neuron": 4, + "update_angle": update_angle, + "update_style": "res_residual", + "update_residual": 0.1, + "update_residual_init": update_residual_init, + } + ), + # kwargs for descriptor + "activation_function": "relu", + "precision": precision, + "exclude_types": exclude_types, + "env_protection": 0.0, + "use_loc_mapping": use_loc_mapping, + "trainable": False, + } diff --git a/source/tests/consistent/descriptor/test_hybrid.py b/source/tests/consistent/descriptor/test_hybrid.py index 2fcd606615..6557deb9a4 100644 --- a/source/tests/consistent/descriptor/test_hybrid.py +++ b/source/tests/consistent/descriptor/test_hybrid.py @@ -15,10 +15,12 @@ INSTALLED_ARRAY_API_STRICT, INSTALLED_JAX, INSTALLED_PT, + INSTALLED_PT_EXPT, INSTALLED_TF, CommonTest, ) from .common import ( + DescriptorAPITest, DescriptorTest, ) @@ -34,6 +36,10 @@ from deepmd.jax.descriptor.hybrid import DescrptHybrid as DescrptHybridJAX else: DescrptHybridJAX = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.hybrid import DescrptHybrid as DescrptHybridPTExpt +else: + DescrptHybridPTExpt = None if INSTALLED_ARRAY_API_STRICT: from ...array_api_strict.descriptor.hybrid import ( DescrptHybrid as DescrptHybridStrict, @@ -82,12 +88,14 @@ def data(self) -> dict: tf_class = DescrptHybridTF dp_class = DescrptHybridDP pt_class = DescrptHybridPT + pt_expt_class = DescrptHybridPTExpt jax_class = DescrptHybridJAX array_api_strict_class = DescrptHybridStrict args = descrpt_hybrid_args() skip_jax = not INSTALLED_JAX skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT + skip_pt_expt = not INSTALLED_PT_EXPT def setUp(self) -> None: CommonTest.setUp(self) @@ -160,6 +168,15 @@ def 
eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: self.box, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + ) + def eval_jax(self, jax_obj: Any) -> Any: return self.eval_jax_descriptor( jax_obj, @@ -171,3 +188,44 @@ def eval_jax(self, jax_obj: Any) -> Any: def extract_ret(self, ret: Any, backend) -> tuple[np.ndarray, ...]: return (ret[0], ret[1]) + + +class TestHybridDescriptorAPI(DescriptorAPITest, unittest.TestCase): + """Test consistency of BaseDescriptor API methods across backends.""" + + dp_class = DescrptHybridDP + pt_class = DescrptHybridPT + pt_expt_class = DescrptHybridPTExpt + args = descrpt_hybrid_args() + + @property + def data(self) -> dict: + return { + "list": [ + { + "type": "se_e2_r", + "sel": [10, 10], + "rcut_smth": 5.80, + "rcut": 6.00, + "neuron": [6, 12, 24], + "resnet_dt": False, + "type_one_side": True, + "precision": "float64", + "seed": 20240229, + "activation_function": "relu", + }, + { + "type": "se_e2_a", + "sel": [9, 11], + "rcut_smth": 2.80, + "rcut": 3.00, + "neuron": [6, 12, 24], + "axis_neuron": 3, + "resnet_dt": True, + "type_one_side": True, + "precision": "float64", + "seed": 20240229, + "activation_function": "relu", + }, + ] + } diff --git a/source/tests/consistent/descriptor/test_se_atten_v2.py b/source/tests/consistent/descriptor/test_se_atten_v2.py index 1ab71eafbf..1cc644c73c 100644 --- a/source/tests/consistent/descriptor/test_se_atten_v2.py +++ b/source/tests/consistent/descriptor/test_se_atten_v2.py @@ -19,10 +19,12 @@ INSTALLED_JAX, INSTALLED_PD, INSTALLED_PT, + INSTALLED_PT_EXPT, CommonTest, parameterized, ) from .common import ( + DescriptorAPITest, DescriptorTest, ) @@ -44,6 +46,12 @@ ) else: DescrptSeAttenV2Strict = None +if INSTALLED_PT_EXPT: + from deepmd.pt_expt.descriptor.se_atten_v2 import ( + DescrptSeAttenV2 as DescrptSeAttenV2PTExpt, + ) +else: + DescrptSeAttenV2PTExpt = None if 
INSTALLED_PD: from deepmd.pd.model.descriptor.se_atten_v2 import ( DescrptSeAttenV2 as DescrptSeAttenV2PD, @@ -255,6 +263,37 @@ def skip_array_api_strict(self) -> bool: ) ) + @property + def skip_pt_expt(self) -> bool: + ( + tebd_dim, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return ( + CommonTest.skip_pt_expt + or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) + ) + @property def skip_pd(self) -> bool: ( @@ -286,6 +325,7 @@ def skip_pd(self) -> bool: tf_class = DescrptSeAttenV2TF dp_class = DescrptSeAttenV2DP pt_class = DescrptSeAttenV2PT + pt_expt_class = DescrptSeAttenV2PTExpt jax_class = DescrptSeAttenV2JAX array_api_strict_class = DescrptSeAttenV2Strict pd_class = DescrptSeAttenV2PD @@ -375,6 +415,16 @@ def eval_array_api_strict(self, array_api_strict_obj: Any) -> Any: mixed_types=True, ) + def eval_pt_expt(self, pt_expt_obj: Any) -> Any: + return self.eval_pt_expt_descriptor( + pt_expt_obj, + self.natoms, + self.coords, + self.atype, + self.box, + mixed_types=True, + ) + def eval_pd(self, pd_obj: Any) -> Any: return self.eval_pd_descriptor( pd_obj, @@ -445,3 +495,147 @@ def atol(self) -> float: return 1e-4 else: raise ValueError(f"Unknown precision: {precision}") + + +@parameterized( + (4,), # tebd_dim + (True,), # resnet_dt + (True, False), # type_one_side + (20,), # attn + (0, 2), # attn_layer + (True, False), # attn_dotr + ([], [[0, 1]]), # excluded_types + (0.0,), # env_protection + (True, False), # set_davg_zero + (1.0,), # scaling_factor + (True, False), # normalize + (None, 1.0), # temperature + (1e-5,), # ln_eps + (True,), # concat_output_tebd + ("float64",), # precision + (True, False), # use_econf_tebd + (False,), # use_tebd_bias +) +class 
TestSeAttenV2DescriptorAPI(DescriptorAPITest, unittest.TestCase): + """Test consistency of BaseDescriptor API methods across backends.""" + + dp_class = DescrptSeAttenV2DP + pt_class = DescrptSeAttenV2PT + pt_expt_class = DescrptSeAttenV2PTExpt + args = descrpt_se_atten_args().append(Argument("ntypes", int, optional=False)) + + def is_meaningless_zero_attention_layer_tests( + self, + attn_layer: int, + attn_dotr: bool, + normalize: bool, + temperature: float | None, + ) -> bool: + return attn_layer == 0 and (attn_dotr or normalize or temperature is not None) + + @property + def data(self) -> dict: + ( + tebd_dim, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return { + "sel": [10], + "rcut_smth": 5.80, + "rcut": 6.00, + "neuron": [6, 12, 24], + "ntypes": self.ntypes, + "axis_neuron": 3, + "tebd_dim": tebd_dim, + "attn": attn, + "attn_layer": attn_layer, + "attn_dotr": attn_dotr, + "attn_mask": False, + "scaling_factor": scaling_factor, + "normalize": normalize, + "temperature": temperature, + "ln_eps": ln_eps, + "concat_output_tebd": concat_output_tebd, + "resnet_dt": resnet_dt, + "type_one_side": type_one_side, + "exclude_types": excluded_types, + "env_protection": env_protection, + "precision": precision, + "set_davg_zero": set_davg_zero, + "use_econf_tebd": use_econf_tebd, + "use_tebd_bias": use_tebd_bias, + "type_map": ["O", "H"] if use_econf_tebd else None, + "seed": 1145141919810, + "activation_function": "relu", + } + + @property + def skip_pt(self) -> bool: + ( + tebd_dim, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return not 
INSTALLED_PT or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) + + @property + def skip_pt_expt(self) -> bool: + ( + tebd_dim, + resnet_dt, + type_one_side, + attn, + attn_layer, + attn_dotr, + excluded_types, + env_protection, + set_davg_zero, + scaling_factor, + normalize, + temperature, + ln_eps, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return not INSTALLED_PT_EXPT or self.is_meaningless_zero_attention_layer_tests( + attn_layer, + attn_dotr, + normalize, + temperature, + ) diff --git a/source/tests/consistent/descriptor/test_se_e2_a.py b/source/tests/consistent/descriptor/test_se_e2_a.py index f1736f7538..efa5c4bc6d 100644 --- a/source/tests/consistent/descriptor/test_se_e2_a.py +++ b/source/tests/consistent/descriptor/test_se_e2_a.py @@ -22,6 +22,7 @@ parameterized, ) from .common import ( + DescriptorAPITest, DescriptorTest, ) @@ -144,7 +145,7 @@ def skip_tf(self) -> bool: precision, env_protection, ) = self.param - return env_protection != 0.0 + return env_protection != 0.0 or CommonTest.skip_tf @property def skip_jax(self) -> bool: @@ -666,3 +667,54 @@ def atol(self) -> float: return 1e-4 else: raise ValueError(f"Unknown precision: {precision}") + + +@parameterized( + (True, False), # resnet_dt + (True, False), # type_one_side + ([], [[0, 1]]), # excluded_types + ("float64",), # precision + (0.0, 1e-8, 1e-2), # env_protection +) +class TestSeADescriptorAPI(DescriptorAPITest, unittest.TestCase): + """Test consistency of BaseDescriptor API methods across backends.""" + + dp_class = DescrptSeADP + pt_class = DescrptSeAPT + pt_expt_class = DescrptSeAPTExpt + args = descrpt_se_a_args() + + @property + def data(self) -> dict: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + env_protection, + ) = self.param + return { + "sel": [9, 10], + "rcut_smth": 5.80, + "rcut": 6.00, + "neuron": [6, 12, 24], + "axis_neuron": 3, + "resnet_dt": resnet_dt, 
+ "type_one_side": type_one_side, + "exclude_types": excluded_types, + "env_protection": env_protection, + "precision": precision, + "seed": 1145141919810, + "activation_function": "relu", + } + + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + env_protection, + ) = self.param + return (not type_one_side) or not INSTALLED_PT_EXPT diff --git a/source/tests/consistent/descriptor/test_se_r.py b/source/tests/consistent/descriptor/test_se_r.py index 8c8680755e..826eaf2145 100644 --- a/source/tests/consistent/descriptor/test_se_r.py +++ b/source/tests/consistent/descriptor/test_se_r.py @@ -21,13 +21,14 @@ parameterized, ) from .common import ( + DescriptorAPITest, DescriptorTest, ) if INSTALLED_PT: from deepmd.pt.model.descriptor.se_r import DescrptSeR as DescrptSeRPT else: - DescrptSeAPT = None + DescrptSeRPT = None if INSTALLED_PT_EXPT: from deepmd.pt_expt.descriptor.se_r import DescrptSeR as DescrptSeRPTExpt else: @@ -35,7 +36,7 @@ if INSTALLED_TF: from deepmd.tf.descriptor.se_r import DescrptSeR as DescrptSeRTF else: - DescrptSeATF = None + DescrptSeRTF = None from deepmd.utils.argcheck import ( descrpt_se_r_args, ) @@ -261,3 +262,59 @@ def atol(self) -> float: return 1e-4 else: raise ValueError(f"Unknown precision: {precision}") + + +@parameterized( + (True, False), # resnet_dt + (True, False), # type_one_side + ([], [[0, 1]]), # excluded_types + ("float64",), # precision +) +class TestSeRDescriptorAPI(DescriptorAPITest, unittest.TestCase): + """Test consistency of BaseDescriptor API methods across backends.""" + + dp_class = DescrptSeRDP + pt_class = DescrptSeRPT + pt_expt_class = DescrptSeRPTExpt + args = descrpt_se_r_args() + + @property + def data(self) -> dict: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + ) = self.param + return { + "sel": [9, 10], + "rcut_smth": 5.80, + "rcut": 6.00, + "neuron": [6, 12, 24], + "resnet_dt": resnet_dt, + "type_one_side": type_one_side, + 
"exclude_types": excluded_types, + "precision": precision, + "seed": 1145141919810, + "activation_function": "relu", + } + + @property + def skip_pt(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + ) = self.param + return not type_one_side or not INSTALLED_PT + + @property + def skip_pt_expt(self) -> bool: + ( + resnet_dt, + type_one_side, + excluded_types, + precision, + ) = self.param + return not type_one_side or not INSTALLED_PT_EXPT diff --git a/source/tests/consistent/descriptor/test_se_t.py b/source/tests/consistent/descriptor/test_se_t.py index df03f270f5..7d2a33aba9 100644 --- a/source/tests/consistent/descriptor/test_se_t.py +++ b/source/tests/consistent/descriptor/test_se_t.py @@ -21,6 +21,7 @@ parameterized, ) from .common import ( + DescriptorAPITest, DescriptorTest, ) @@ -115,7 +116,7 @@ def skip_tf(self) -> bool: precision, env_protection, ) = self.param - return env_protection != 0.0 or excluded_types + return env_protection != 0.0 or excluded_types or CommonTest.skip_tf skip_array_api_strict = not INSTALLED_ARRAY_API_STRICT skip_jax = not INSTALLED_JAX @@ -261,3 +262,39 @@ def atol(self) -> float: return 1e-4 else: raise ValueError(f"Unknown precision: {precision}") + + +@parameterized( + (True, False), # resnet_dt + ([], [[0, 1]]), # excluded_types + ("float64",), # precision + (0.0, 1e-8, 1e-2), # env_protection +) +class TestSeTDescriptorAPI(DescriptorAPITest, unittest.TestCase): + """Test consistency of BaseDescriptor API methods across backends.""" + + dp_class = DescrptSeTDP + pt_class = DescrptSeTPT + pt_expt_class = DescrptSeTPTExpt + args = descrpt_se_t_args() + + @property + def data(self) -> dict: + ( + resnet_dt, + excluded_types, + precision, + env_protection, + ) = self.param + return { + "sel": [9, 10], + "rcut_smth": 5.80, + "rcut": 6.00, + "neuron": [6, 12, 24], + "resnet_dt": resnet_dt, + "exclude_types": excluded_types, + "env_protection": env_protection, + "precision": precision, + "seed": 
1145141919810, + "activation_function": "relu", + } diff --git a/source/tests/consistent/descriptor/test_se_t_tebd.py b/source/tests/consistent/descriptor/test_se_t_tebd.py index 7d33679e69..4017e059f5 100644 --- a/source/tests/consistent/descriptor/test_se_t_tebd.py +++ b/source/tests/consistent/descriptor/test_se_t_tebd.py @@ -24,6 +24,7 @@ parameterized, ) from .common import ( + DescriptorAPITest, DescriptorTest, ) @@ -354,3 +355,62 @@ def atol(self) -> float: return 1e-4 else: raise ValueError(f"Unknown precision: {precision}") + + +@parameterized( + (4,), # tebd_dim + ("strip",), # tebd_input_mode + (True,), # resnet_dt + ([], [[0, 1]]), # excluded_types + (0.0,), # env_protection + (True, False), # set_davg_zero + (True, False), # smooth + (True,), # concat_output_tebd + ("float64",), # precision + (True, False), # use_econf_tebd + (False, True), # use_tebd_bias +) +class TestSeTTebdDescriptorAPI(DescriptorAPITest, unittest.TestCase): + """Test consistency of BaseDescriptor API methods across backends.""" + + dp_class = DescrptSeTTebdDP + pt_class = DescrptSeTTebdPT + pt_expt_class = DescrptSeTTebdPTExpt + args = descrpt_se_e3_tebd_args().append(Argument("ntypes", int, optional=False)) + + @property + def data(self) -> dict: + ( + tebd_dim, + tebd_input_mode, + resnet_dt, + excluded_types, + env_protection, + set_davg_zero, + smooth, + concat_output_tebd, + precision, + use_econf_tebd, + use_tebd_bias, + ) = self.param + return { + "sel": [10], + "rcut_smth": 3.50, + "rcut": 4.00, + "neuron": [2, 4, 8], + "ntypes": self.ntypes, + "tebd_dim": tebd_dim, + "tebd_input_mode": tebd_input_mode, + "concat_output_tebd": concat_output_tebd, + "resnet_dt": resnet_dt, + "exclude_types": excluded_types, + "env_protection": env_protection, + "precision": precision, + "set_davg_zero": set_davg_zero, + "smooth": smooth, + "use_econf_tebd": use_econf_tebd, + "use_tebd_bias": use_tebd_bias, + "type_map": ["O", "H"] if use_econf_tebd else None, + "seed": 1145141919810, + 
"activation_function": "relu", + } diff --git a/source/tests/pt_expt/descriptor/test_dpa1.py b/source/tests/pt_expt/descriptor/test_dpa1.py new file mode 100644 index 0000000000..04f2e8a983 --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_dpa1.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import numpy as np +import pytest +import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) + +from deepmd.dpmodel.descriptor.dpa1 import DescrptDPA1 as DPDescrptDPA1 +from deepmd.pt_expt.descriptor.dpa1 import ( + DescrptDPA1, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDescrptDPA1(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + @pytest.mark.parametrize("idt", [False, True]) # resnet_dt + @pytest.mark.parametrize("sm", [False, True]) # smooth_type_embedding + @pytest.mark.parametrize("to", [False, True]) # type_one_side + @pytest.mark.parametrize("tm", ["concat", "strip"]) # tebd_input_mode + @pytest.mark.parametrize("prec", ["float64"]) # precision + @pytest.mark.parametrize("ect", [False, True]) # use_econf_tebd + def test_consistency(self, idt, sm, to, tm, prec, ect) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} sm={sm} to={to} tm={tm} prec={prec} ect={ect}" + + dd0 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + precision=prec, + resnet_dt=idt, + smooth_type_embedding=sm, + type_one_side=to, + 
tebd_input_mode=tm, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + # serialization round-trip + dd1 = DescrptDPA1.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + # dp impl + dd2 = DPDescrptDPA1.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + @pytest.mark.parametrize("idt", [False, True]) # resnet_dt + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + def test_exportable(self, idt, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + dd0 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, 
device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) + + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_make_fx(self, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + dd0 = DescrptDPA1( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + precision=prec, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist) + traced = make_fx(fn)(coord_ext, atype_ext, nlist) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_dpa2.py b/source/tests/pt_expt/descriptor/test_dpa2.py new file mode 100644 index 0000000000..6e2651641c --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_dpa2.py @@ -0,0 +1,326 @@ +# 
SPDX-License-Identifier: LGPL-3.0-or-later + +import numpy as np +import pytest +import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) + +from deepmd.dpmodel.descriptor.dpa2 import DescrptDPA2 as DPDescrptDPA2 +from deepmd.dpmodel.descriptor.dpa2 import ( + RepformerArgs, + RepinitArgs, +) +from deepmd.pt_expt.descriptor.dpa2 import ( + DescrptDPA2, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDescrptDPA2(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + @pytest.mark.parametrize("riti", ["concat", "strip"]) # repinit_tebd_input_mode + @pytest.mark.parametrize("rp1c", [True, False]) # repformer_update_g1_has_conv + @pytest.mark.parametrize("rp1d", [True, False]) # repformer_update_g1_has_drrd + @pytest.mark.parametrize("rp1g", [True, False]) # repformer_update_g1_has_grrg + @pytest.mark.parametrize("rp2a", [True, False]) # repformer_update_g2_has_attn + @pytest.mark.parametrize( + "rus", ["res_avg", "res_residual"] + ) # repformer_update_style + @pytest.mark.parametrize("prec", ["float64"]) # precision + @pytest.mark.parametrize("ect", [False, True]) # use_econf_tebd + @pytest.mark.parametrize("ns", [False, True]) # new sub-structures + def test_consistency( + self, riti, rp1c, rp1d, rp1g, rp2a, rus, prec, ect, ns + ) -> None: + if ns and not rp1d and not rp1g: + pytest.skip("invalid parameter combination") + + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd = 0.1 + np.abs(dstd) + 
dstd_2 = 0.1 + np.abs(dstd_2) + + # fixed parameters + riz = True # repinit_set_davg_zero + rp1a = False # repformer_update_g1_has_attn + rp2g = False # repformer_update_g2_has_g1g1 + rph = False # repformer_update_h2 + rp2gate = True # repformer_attn2_has_gate + rpz = True # repformer_set_davg_zero + sm = True # smooth + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + if prec == "float64": + atol = 1e-8 # marginal test cases + + repinit = RepinitArgs( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + nsel=self.sel_mix, + tebd_input_mode=riti, + set_davg_zero=riz, + ) + repformer = RepformerArgs( + rcut=self.rcut / 2, + rcut_smth=self.rcut_smth, + nsel=nnei // 2, + nlayers=3, + g1_dim=20, + g2_dim=10, + axis_neuron=4, + update_g1_has_conv=rp1c, + update_g1_has_drrd=rp1d, + update_g1_has_grrg=rp1g, + update_g1_has_attn=rp1a, + update_g2_has_g1g1=rp2g, + update_g2_has_attn=rp2a, + update_h2=rph, + attn1_hidden=20, + attn1_nhead=2, + attn2_hidden=10, + attn2_nhead=2, + attn2_has_gate=rp2gate, + update_style=rus, + set_davg_zero=rpz, + use_sqrt_nnei=ns, + g1_out_conv=ns, + g1_out_mlp=ns, + ) + + dd0 = DescrptDPA2( + self.nt, + repinit=repinit, + repformer=repformer, + smooth=sm, + exclude_types=[], + add_tebd_to_repinit_out=False, + precision=prec, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(self.device) + + dd0.repinit.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.repinit.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0.repformers.mean = torch.tensor(davg_2, dtype=dtype, device=self.device) + dd0.repformers.stddev = torch.tensor(dstd_2, dtype=dtype, device=self.device) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + torch.tensor(self.mapping, dtype=int, device=self.device), + ) + # serialization round-trip + dd1 
= DescrptDPA2.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + torch.tensor(self.mapping, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + # dp impl + dd2 = DPDescrptDPA2.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, self.atype_ext, self.nlist, self.mapping + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + ) + + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + def test_exportable(self, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd = 0.1 + np.abs(dstd) + dstd_2 = 0.1 + np.abs(dstd_2) + + dtype = PRECISION_DICT[prec] + + repinit = RepinitArgs( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + nsel=self.sel_mix, + tebd_input_mode="concat", + set_davg_zero=True, + ) + repformer = RepformerArgs( + rcut=self.rcut / 2, + rcut_smth=self.rcut_smth, + nsel=nnei // 2, + nlayers=3, + g1_dim=20, + g2_dim=10, + axis_neuron=4, + update_g1_has_conv=True, + update_g1_has_drrd=True, + update_g1_has_grrg=True, + update_g1_has_attn=False, + update_g2_has_g1g1=False, + update_g2_has_attn=True, + update_h2=False, + attn1_hidden=20, + attn1_nhead=2, + attn2_hidden=10, + attn2_nhead=2, + attn2_has_gate=True, + update_style="res_avg", + set_davg_zero=True, + use_sqrt_nnei=True, + g1_out_conv=True, + g1_out_mlp=True, + ) + + dd0 = DescrptDPA2( + self.nt, + repinit=repinit, + repformer=repformer, + smooth=True, + exclude_types=[], + 
add_tebd_to_repinit_out=False, + precision=prec, + seed=GLOBAL_SEED, + ).to(self.device) + + dd0.repinit.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.repinit.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0.repformers.mean = torch.tensor(davg_2, dtype=dtype, device=self.device) + dd0.repformers.stddev = torch.tensor(dstd_2, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + torch.tensor(self.mapping, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) + + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_make_fx(self, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + davg_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd_2 = rng.normal(size=(self.nt, nnei // 2, 4)) + dstd = 0.1 + np.abs(dstd) + dstd_2 = 0.1 + np.abs(dstd_2) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + if prec == "float64": + atol = 1e-8 + + repinit = RepinitArgs( + rcut=self.rcut, + rcut_smth=self.rcut_smth, + nsel=self.sel_mix, + tebd_input_mode="concat", + set_davg_zero=True, + ) + repformer = RepformerArgs( + rcut=self.rcut / 2, + rcut_smth=self.rcut_smth, + nsel=nnei // 2, + nlayers=3, + g1_dim=20, + g2_dim=10, + axis_neuron=4, + update_g1_has_conv=True, + update_g1_has_drrd=True, + update_g1_has_grrg=True, + update_g1_has_attn=False, + update_g2_has_g1g1=False, + update_g2_has_attn=True, + update_h2=False, + attn1_hidden=20, + attn1_nhead=2, + attn2_hidden=10, + attn2_nhead=2, + attn2_has_gate=True, + update_style="res_avg", + set_davg_zero=True, + use_sqrt_nnei=True, + g1_out_conv=True, + g1_out_mlp=True, + ) + + dd0 = DescrptDPA2( + self.nt, + repinit=repinit, + 
repformer=repformer, + smooth=True, + exclude_types=[], + add_tebd_to_repinit_out=False, + precision=prec, + seed=GLOBAL_SEED, + ).to(self.device) + + dd0.repinit.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.repinit.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0.repformers.mean = torch.tensor(davg_2, dtype=dtype, device=self.device) + dd0.repformers.stddev = torch.tensor(dstd_2, dtype=dtype, device=self.device) + dd0 = dd0.eval() + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + mapping = torch.tensor(self.mapping, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist, mapping): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist, mapping)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist, mapping) + traced = make_fx(fn)(coord_ext, atype_ext, nlist, mapping) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist, mapping) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_dpa3.py b/source/tests/pt_expt/descriptor/test_dpa3.py new file mode 100644 index 0000000000..4aeec0dbad --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_dpa3.py @@ -0,0 +1,248 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import numpy as np +import pytest +import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) + +from deepmd.dpmodel.descriptor.dpa3 import DescrptDPA3 as DPDescrptDPA3 +from deepmd.dpmodel.descriptor.dpa3 import ( + 
RepFlowArgs, +) +from deepmd.pt_expt.descriptor.dpa3 import ( + DescrptDPA3, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDescrptDPA3(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + @pytest.mark.parametrize("ua", [True, False]) # update_angle + @pytest.mark.parametrize("ruri", ["norm", "const"]) # update_residual_init + @pytest.mark.parametrize("acr", [0, 1]) # a_compress_rate + @pytest.mark.parametrize("acer", [1, 2]) # a_compress_e_rate + @pytest.mark.parametrize("acus", [True, False]) # a_compress_use_split + @pytest.mark.parametrize("nme", [1, 2]) # n_multi_edge_message + def test_consistency(self, ua, ruri, acr, acer, acus, nme) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + # fixed parameters + rus = "res_residual" # update_style + prec = "float64" # precision + ect = False # use_econf_tebd + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + if prec == "float64": + atol = 1e-8 # marginal test cases + + repflow = RepFlowArgs( + n_dim=20, + e_dim=10, + a_dim=8, + nlayers=3, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei - 1, + a_compress_rate=acr, + a_compress_e_rate=acer, + a_compress_use_split=acus, + n_multi_edge_message=nme, + axis_neuron=4, + update_angle=ua, + update_style=rus, + update_residual_init=ruri, + smooth_edge_update=True, + ) + + dd0 = DescrptDPA3( + self.nt, + repflow=repflow, + exclude_types=[], + precision=prec, + use_econf_tebd=ect, + type_map=["O", 
"H"] if ect else None, + seed=GLOBAL_SEED, + ).to(self.device) + + dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + torch.tensor(self.mapping, dtype=int, device=self.device), + ) + # serialization round-trip + dd1 = DescrptDPA3.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + torch.tensor(self.mapping, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + # dp impl + dd2 = DPDescrptDPA3.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, self.atype_ext, self.nlist, self.mapping + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + ) + + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + def test_exportable(self, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + + repflow = RepFlowArgs( + n_dim=20, + e_dim=10, + a_dim=8, + nlayers=3, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei - 1, + axis_neuron=4, + update_angle=True, + update_style="res_residual", + update_residual_init="const", + smooth_edge_update=True, + ) + + dd0 = DescrptDPA3( + self.nt, + repflow=repflow, + exclude_types=[], + 
precision=prec, + seed=GLOBAL_SEED, + ).to(self.device) + + dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + torch.tensor(self.mapping, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) + + @pytest.mark.parametrize("ruri", ["norm", "const"]) # update_residual_init + @pytest.mark.parametrize("acus", [True, False]) # a_compress_use_split + @pytest.mark.parametrize("nme", [1, 2]) # n_multi_edge_message + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_make_fx(self, ruri, acus, nme, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + nf, nloc, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + if prec == "float64": + atol = 1e-8 + + repflow = RepFlowArgs( + n_dim=20, + e_dim=10, + a_dim=8, + nlayers=3, + e_rcut=self.rcut, + e_rcut_smth=self.rcut_smth, + e_sel=nnei, + a_rcut=self.rcut - 0.1, + a_rcut_smth=self.rcut_smth, + a_sel=nnei - 1, + a_compress_use_split=acus, + n_multi_edge_message=nme, + axis_neuron=4, + update_angle=True, + update_style="res_residual", + update_residual_init=ruri, + smooth_edge_update=True, + ) + + dd0 = DescrptDPA3( + self.nt, + repflow=repflow, + exclude_types=[], + precision=prec, + seed=GLOBAL_SEED, + ).to(self.device) + + dd0.repflows.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.repflows.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, 
device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + mapping = torch.tensor(self.mapping, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist, mapping): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist, mapping)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist, mapping) + traced = make_fx(fn)(coord_ext, atype_ext, nlist, mapping) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist, mapping) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_hybrid.py b/source/tests/pt_expt/descriptor/test_hybrid.py new file mode 100644 index 0000000000..41a87273d5 --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_hybrid.py @@ -0,0 +1,219 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import numpy as np +import pytest +import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) + +from deepmd.dpmodel.descriptor.hybrid import DescrptHybrid as DPDescrptHybrid +from deepmd.pt_expt.descriptor.hybrid import ( + DescrptHybrid, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.descriptor.se_r import ( + DescrptSeR, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDescrptHybrid(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + 
@pytest.mark.parametrize("prec", ["float64"]) # precision + def test_consistency(self, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"prec={prec}" + + ddsub0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + seed=GLOBAL_SEED, + ) + ddsub1 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + seed=GLOBAL_SEED, + ) + dd0 = DescrptHybrid( + list=[ddsub0, ddsub1], + ).to(self.device) + # set davg/dstd on sub-descriptors + dd0.descrpt_list[0].davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.descrpt_list[0].dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0.descrpt_list[1].davg = torch.tensor( + davg[..., :1], dtype=dtype, device=self.device + ) + dd0.descrpt_list[1].dstd = torch.tensor( + dstd[..., :1], dtype=dtype, device=self.device + ) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + # serialization round-trip + dd1 = DescrptHybrid.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + # dp impl + dd2 = DPDescrptHybrid.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + 
@pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + def test_exportable(self, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + + ddsub0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + seed=GLOBAL_SEED, + ) + ddsub1 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + seed=GLOBAL_SEED, + ) + dd0 = DescrptHybrid( + list=[ddsub0, ddsub1], + ).to(self.device) + dd0.descrpt_list[0].davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.descrpt_list[0].dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0.descrpt_list[1].davg = torch.tensor( + davg[..., :1], dtype=dtype, device=self.device + ) + dd0.descrpt_list[1].dstd = torch.tensor( + dstd[..., :1], dtype=dtype, device=self.device + ) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) + + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_make_fx(self, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + + ddsub0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + seed=GLOBAL_SEED, + ) + ddsub1 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + seed=GLOBAL_SEED, + ) + dd0 = DescrptHybrid( + list=[ddsub0, ddsub1], + ).to(self.device) + dd0.descrpt_list[0].davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.descrpt_list[0].dstd = 
torch.tensor(dstd, dtype=dtype, device=self.device) + dd0.descrpt_list[1].davg = torch.tensor( + davg[..., :1], dtype=dtype, device=self.device + ) + dd0.descrpt_list[1].dstd = torch.tensor( + dstd[..., :1], dtype=dtype, device=self.device + ) + dd0 = dd0.eval() + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist) + traced = make_fx(fn)(coord_ext, atype_ext, nlist) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_se_atten_v2.py b/source/tests/pt_expt/descriptor/test_se_atten_v2.py new file mode 100644 index 0000000000..9d33e83e68 --- /dev/null +++ b/source/tests/pt_expt/descriptor/test_se_atten_v2.py @@ -0,0 +1,177 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later + +import numpy as np +import pytest +import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) + +from deepmd.dpmodel.descriptor.se_atten_v2 import DescrptSeAttenV2 as DPDescrptSeAttenV2 +from deepmd.pt_expt.descriptor.se_atten_v2 import ( + DescrptSeAttenV2, +) +from deepmd.pt_expt.utils import ( + env, +) +from deepmd.pt_expt.utils.env import ( + PRECISION_DICT, +) + +from ...pt.model.test_env_mat import ( + TestCaseSingleFrameWithNlist, +) +from ...pt.model.test_mlp import ( + get_tols, +) +from ...seed import ( + 
GLOBAL_SEED, +) + + +class TestDescrptSeAttenV2(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: + TestCaseSingleFrameWithNlist.setUp(self) + self.device = env.DEVICE + + @pytest.mark.parametrize("idt", [False, True]) # resnet_dt + @pytest.mark.parametrize("to", [False, True]) # type_one_side + @pytest.mark.parametrize("prec", ["float64"]) # precision + @pytest.mark.parametrize("ect", [False, True]) # use_econf_tebd + def test_consistency(self, idt, to, prec, ect) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} to={to} prec={prec} ect={ect}" + + dd0 = DescrptSeAttenV2( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + precision=prec, + resnet_dt=idt, + type_one_side=to, + use_econf_tebd=ect, + type_map=["O", "H"] if ect else None, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + # serialization round-trip + dd1 = DescrptSeAttenV2.deserialize(dd0.serialize()) + rd1, _, _, _, _ = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + # dp impl + dd2 = DPDescrptSeAttenV2.deserialize(dd0.serialize()) + rd2, _, _, _, _ = dd2.call( + self.coord_ext, + 
self.atype_ext, + self.nlist, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + @pytest.mark.parametrize("idt", [False, True]) # resnet_dt + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + def test_exportable(self, idt, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeAttenV2( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) + + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_make_fx(self, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + dd0 = DescrptSeAttenV2( + self.rcut, + self.rcut_smth, + self.sel_mix, + self.nt, + attn_layer=2, + precision=prec, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.se_atten.mean = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.se_atten.stddev = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = 
torch.tensor(self.nlist, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist) + traced = make_fx(fn)(coord_ext, atype_ext, nlist) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_se_e2_a.py b/source/tests/pt_expt/descriptor/test_se_e2_a.py index e63138e43b..2c1121cc8b 100644 --- a/source/tests/pt_expt/descriptor/test_se_e2_a.py +++ b/source/tests/pt_expt/descriptor/test_se_e2_a.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import itertools -import unittest import numpy as np +import pytest import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) from deepmd.dpmodel.descriptor import DescrptSeA as DPDescrptSeA from deepmd.pt_expt.descriptor.se_e2_a import ( @@ -30,106 +32,148 @@ ) -class TestDescrptSeA(unittest.TestCase, TestCaseSingleFrameWithNlist): - def setUp(self) -> None: +class TestDescrptSeA(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: TestCaseSingleFrameWithNlist.setUp(self) self.device = env.DEVICE - def test_consistency(self) -> None: + @pytest.mark.parametrize("idt", [False, True]) # resnet_dt + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + @pytest.mark.parametrize("em", [[], [[0, 1]], [[1, 1]]]) # exclude_types + def test_consistency(self, idt, prec, em) -> None: rng = np.random.default_rng(GLOBAL_SEED) _, _, nnei = self.nlist.shape davg = rng.normal(size=(self.nt, nnei, 4)) dstd = 
rng.normal(size=(self.nt, nnei, 4)) dstd = 0.1 + np.abs(dstd) - for idt, prec, em in itertools.product( - [False, True], - ["float64", "float32"], - [[], [[0, 1]], [[1, 1]]], - ): - dtype = PRECISION_DICT[prec] - rtol, atol = get_tols(prec) - err_msg = f"idt={idt} prec={prec}" - dd0 = DescrptSeA( - self.rcut, - self.rcut_smth, - self.sel, - precision=prec, - resnet_dt=idt, - exclude_types=em, - seed=GLOBAL_SEED, - ).to(self.device) - dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) - dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) - rd0, _, _, _, _ = dd0( - torch.tensor(self.coord_ext, dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, device=self.device), - ) - dd1 = DescrptSeA.deserialize(dd0.serialize()) - rd1, gr1, _, _, sw1 = dd1( - torch.tensor(self.coord_ext, dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, device=self.device), - ) - np.testing.assert_allclose( - rd0.detach().cpu().numpy(), - rd1.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + exclude_types=em, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + dd1 = DescrptSeA.deserialize(dd0.serialize()) + rd1, gr1, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + 
torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeA.deserialize(dd0.serialize()) + rd2, gr2, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + for aa, bb in zip([rd1, gr1, sw1], [rd2, gr2, sw2], strict=True): np.testing.assert_allclose( - rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], - rd0.detach().cpu().numpy()[1], + aa.detach().cpu().numpy(), + bb, rtol=rtol, atol=atol, err_msg=err_msg, ) - dd2 = DPDescrptSeA.deserialize(dd0.serialize()) - rd2, gr2, _, _, sw2 = dd2.call( - self.coord_ext, - self.atype_ext, - self.nlist, - ) - for aa, bb in zip([rd1, gr1, sw1], [rd2, gr2, sw2], strict=True): - np.testing.assert_allclose( - aa.detach().cpu().numpy(), - bb, - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) - if em: - dd1.reinit_exclude([tuple(x) for x in em]) - self.assertIsInstance(dd1.emask, PairExcludeMask) + if em: + dd1.reinit_exclude([tuple(x) for x in em]) + assert isinstance(dd1.emask, PairExcludeMask) - def test_exportable(self) -> None: + @pytest.mark.parametrize("idt", [False, True]) # resnet_dt + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + def test_exportable(self, idt, prec) -> None: rng = np.random.default_rng(GLOBAL_SEED) _, _, nnei = self.nlist.shape davg = rng.normal(size=(self.nt, nnei, 4)) dstd = rng.normal(size=(self.nt, nnei, 4)) dstd = 0.1 + np.abs(dstd) - for idt, prec in itertools.product( - [False, True], - ["float64", "float32"], - ): - dtype = PRECISION_DICT[prec] - dd0 = DescrptSeA( - self.rcut, - self.rcut_smth, - self.sel, - precision=prec, - resnet_dt=idt, - seed=GLOBAL_SEED, - ).to(self.device) - dd0.davg = torch.tensor(davg, dtype=dtype, 
device=self.device) - dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) - dd0 = dd0.eval() - inputs = ( - torch.tensor(self.coord_ext, dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, device=self.device), - ) - torch.export.export(dd0, inputs) + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) + + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_make_fx(self, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + dd0 = DescrptSeA( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = 
fn(coord_ext, atype_ext, nlist) + traced = make_fx(fn)(coord_ext, atype_ext, nlist) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_se_r.py b/source/tests/pt_expt/descriptor/test_se_r.py index c789b13652..2c494d9a53 100644 --- a/source/tests/pt_expt/descriptor/test_se_r.py +++ b/source/tests/pt_expt/descriptor/test_se_r.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import itertools -import unittest import numpy as np +import pytest import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) from deepmd.dpmodel.descriptor import DescrptSeR as DPDescrptSeR from deepmd.pt_expt.descriptor.se_r import ( @@ -27,104 +29,146 @@ ) -class TestDescrptSeR(unittest.TestCase, TestCaseSingleFrameWithNlist): - def setUp(self) -> None: +class TestDescrptSeR(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: TestCaseSingleFrameWithNlist.setUp(self) self.device = env.DEVICE - def test_consistency(self) -> None: + @pytest.mark.parametrize("idt", [False, True]) # resnet_dt + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + @pytest.mark.parametrize("em", [[], [[0, 1]], [[1, 1]]]) # exclude_types + def test_consistency(self, idt, prec, em) -> None: rng = np.random.default_rng(GLOBAL_SEED) _, _, nnei = self.nlist.shape davg = rng.normal(size=(self.nt, nnei, 1)) dstd = rng.normal(size=(self.nt, nnei, 1)) dstd = 0.1 + np.abs(dstd) - for idt, prec, em in itertools.product( - [False, True], - ["float64", "float32"], - [[], [[0, 1]], [[1, 1]]], - ): - dtype = PRECISION_DICT[prec] - rtol, atol = get_tols(prec) - err_msg = f"idt={idt} prec={prec}" - dd0 = DescrptSeR( - self.rcut, - self.rcut_smth, - 
self.sel, - precision=prec, - resnet_dt=idt, - exclude_types=em, - seed=GLOBAL_SEED, - ).to(self.device) - dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) - dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + exclude_types=em, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) - rd0, _, _, _, _ = dd0( - torch.tensor(self.coord_ext, dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, device=self.device), - ) - dd1 = DescrptSeR.deserialize(dd0.serialize()) - rd1, _, _, _, sw1 = dd1( - torch.tensor(self.coord_ext, dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, device=self.device), - ) - np.testing.assert_allclose( - rd0.detach().cpu().numpy(), - rd1.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + dd1 = DescrptSeR.deserialize(dd0.serialize()) + rd1, _, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + 
atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeR.deserialize(dd0.serialize()) + rd2, _, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + for aa, bb in zip([rd1, sw1], [rd2, sw2], strict=True): np.testing.assert_allclose( - rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], - rd0.detach().cpu().numpy()[1], + aa.detach().cpu().numpy(), + bb, rtol=rtol, atol=atol, err_msg=err_msg, ) - dd2 = DPDescrptSeR.deserialize(dd0.serialize()) - rd2, _, _, _, sw2 = dd2.call( - self.coord_ext, - self.atype_ext, - self.nlist, - ) - for aa, bb in zip([rd1, sw1], [rd2, sw2], strict=True): - np.testing.assert_allclose( - aa.detach().cpu().numpy(), - bb, - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) - def test_exportable(self) -> None: + @pytest.mark.parametrize("idt", [False, True]) # resnet_dt + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + def test_exportable(self, idt, prec) -> None: rng = np.random.default_rng(GLOBAL_SEED) _, _, nnei = self.nlist.shape davg = rng.normal(size=(self.nt, nnei, 1)) dstd = rng.normal(size=(self.nt, nnei, 1)) dstd = 0.1 + np.abs(dstd) - for idt, prec in itertools.product( - [False, True], - ["float64", "float32"], - ): - dtype = PRECISION_DICT[prec] - dd0 = DescrptSeR( - self.rcut, - self.rcut_smth, - self.sel, - precision=prec, - resnet_dt=idt, - seed=GLOBAL_SEED, - ).to(self.device) - dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) - dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) - dd0 = dd0.eval() - inputs = ( - torch.tensor(self.coord_ext, dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, device=self.device), - ) - torch.export.export(dd0, inputs) + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, 
device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) + + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_make_fx(self, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 1)) + dstd = rng.normal(size=(self.nt, nnei, 1)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + dd0 = DescrptSeR( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist) + traced = make_fx(fn)(coord_ext, atype_ext, nlist) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_se_t.py b/source/tests/pt_expt/descriptor/test_se_t.py index 921f10a54a..37ed787bbd 100644 
--- a/source/tests/pt_expt/descriptor/test_se_t.py +++ b/source/tests/pt_expt/descriptor/test_se_t.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import itertools -import unittest import numpy as np +import pytest import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) from deepmd.dpmodel.descriptor import DescrptSeT as DPDescrptSeT from deepmd.pt_expt.descriptor.se_t import ( @@ -27,108 +29,150 @@ ) -class TestDescrptSeT(unittest.TestCase, TestCaseSingleFrameWithNlist): - def setUp(self) -> None: +class TestDescrptSeT(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: TestCaseSingleFrameWithNlist.setUp(self) self.device = env.DEVICE - def test_consistency(self) -> None: + @pytest.mark.parametrize("idt", [False, True]) # resnet_dt + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + def test_consistency(self, idt, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeT( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + dd1 = DescrptSeT.deserialize(dd0.serialize()) + rd1, gr1, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + 
rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeT.deserialize(dd0.serialize()) + rd2, gr2, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + # se_t returns None for gr/g2/h2, only compare rd and sw + np.testing.assert_allclose( + rd1.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + sw1.detach().cpu().numpy(), + sw2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + + @pytest.mark.parametrize("idt", [False, True]) # resnet_dt + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + def test_exportable(self, idt, prec) -> None: rng = np.random.default_rng(GLOBAL_SEED) _, _, nnei = self.nlist.shape davg = rng.normal(size=(self.nt, nnei, 4)) dstd = rng.normal(size=(self.nt, nnei, 4)) dstd = 0.1 + np.abs(dstd) - for idt, prec in itertools.product( - [False, True], - ["float64", "float32"], - ): - dtype = PRECISION_DICT[prec] - rtol, atol = get_tols(prec) - err_msg = f"idt={idt} prec={prec}" - dd0 = DescrptSeT( - self.rcut, - self.rcut_smth, - self.sel, - precision=prec, - resnet_dt=idt, - seed=GLOBAL_SEED, - ).to(self.device) - dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) - dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) - rd0, _, _, _, _ = dd0( - torch.tensor(self.coord_ext, dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, device=self.device), - ) - dd1 = DescrptSeT.deserialize(dd0.serialize()) - rd1, gr1, _, _, sw1 = dd1( - torch.tensor(self.coord_ext, dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, 
device=self.device), - ) - np.testing.assert_allclose( - rd0.detach().cpu().numpy(), - rd1.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) - np.testing.assert_allclose( - rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], - rd0.detach().cpu().numpy()[1], - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) - dd2 = DPDescrptSeT.deserialize(dd0.serialize()) - rd2, gr2, _, _, sw2 = dd2.call( - self.coord_ext, - self.atype_ext, - self.nlist, - ) - # se_t returns None for gr/g2/h2, only compare rd and sw - np.testing.assert_allclose( - rd1.detach().cpu().numpy(), - rd2, - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) - np.testing.assert_allclose( - sw1.detach().cpu().numpy(), - sw2, - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeT( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + torch.export.export(dd0, inputs) - def test_exportable(self) -> None: + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_make_fx(self, prec) -> None: rng = np.random.default_rng(GLOBAL_SEED) _, _, nnei = self.nlist.shape davg = rng.normal(size=(self.nt, nnei, 4)) dstd = rng.normal(size=(self.nt, nnei, 4)) dstd = 0.1 + np.abs(dstd) - for idt, prec in itertools.product( - [False, True], - ["float64", "float32"], - ): - dtype = PRECISION_DICT[prec] - dd0 = DescrptSeT( - self.rcut, - self.rcut_smth, - self.sel, - precision=prec, - resnet_dt=idt, - seed=GLOBAL_SEED, - ).to(self.device) - dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) - dd0.dstd = torch.tensor(dstd, 
dtype=dtype, device=self.device) - dd0 = dd0.eval() - inputs = ( - torch.tensor(self.coord_ext, dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, device=self.device), - ) - torch.export.export(dd0, inputs) + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + dd0 = DescrptSeT( + self.rcut, + self.rcut_smth, + self.sel, + precision=prec, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist) + traced = make_fx(fn)(coord_ext, atype_ext, nlist) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) diff --git a/source/tests/pt_expt/descriptor/test_se_t_tebd.py b/source/tests/pt_expt/descriptor/test_se_t_tebd.py index e84080882a..09184325f5 100644 --- a/source/tests/pt_expt/descriptor/test_se_t_tebd.py +++ b/source/tests/pt_expt/descriptor/test_se_t_tebd.py @@ -1,9 +1,11 @@ # SPDX-License-Identifier: LGPL-3.0-or-later -import itertools -import unittest import numpy as np +import pytest import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) from 
deepmd.dpmodel.descriptor import DescrptSeTTebd as DPDescrptSeTTebd from deepmd.pt_expt.descriptor.se_t_tebd import ( @@ -27,121 +29,166 @@ ) -class TestDescrptSeTTebd(unittest.TestCase, TestCaseSingleFrameWithNlist): - def setUp(self) -> None: +class TestDescrptSeTTebd(TestCaseSingleFrameWithNlist): + def setup_method(self) -> None: TestCaseSingleFrameWithNlist.setUp(self) self.device = env.DEVICE - def test_consistency(self) -> None: + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + def test_consistency(self, prec) -> None: rng = np.random.default_rng(GLOBAL_SEED) _, _, nnei = self.nlist.shape davg = rng.normal(size=(self.nt, nnei, 4)) dstd = rng.normal(size=(self.nt, nnei, 4)) dstd = 0.1 + np.abs(dstd) - for idt, prec in itertools.product( - [True], # SeTTebd typically uses resnet_dt=True - ["float64", "float32"], - ): - dtype = PRECISION_DICT[prec] - rtol, atol = get_tols(prec) - err_msg = f"idt={idt} prec={prec}" - dd0 = DescrptSeTTebd( - self.rcut, - self.rcut_smth, - self.sel, - self.nt, - precision=prec, - resnet_dt=idt, - seed=GLOBAL_SEED, - ).to(self.device) - dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) - dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + idt = True # SeTTebd typically uses resnet_dt=True + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + err_msg = f"idt={idt} prec={prec}" + dd0 = DescrptSeTTebd( + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) - rd0, _, _, _, _ = dd0( - torch.tensor(self.coord_ext, dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, device=self.device), - ) - dd1 = DescrptSeTTebd.deserialize(dd0.serialize()) - rd1, gr1, _, _, sw1 = dd1( - torch.tensor(self.coord_ext, 
dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, device=self.device), - ) - np.testing.assert_allclose( - rd0.detach().cpu().numpy(), - rd1.detach().cpu().numpy(), - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) + rd0, _, _, _, _ = dd0( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + dd1 = DescrptSeTTebd.deserialize(dd0.serialize()) + rd1, gr1, _, _, sw1 = dd1( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, device=self.device), + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy(), + rd1.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + np.testing.assert_allclose( + rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], + rd0.detach().cpu().numpy()[1], + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + dd2 = DPDescrptSeTTebd.deserialize(dd0.serialize()) + rd2, gr2, _, _, sw2 = dd2.call( + self.coord_ext, + self.atype_ext, + self.nlist, + ) + # se_t_tebd should return gr and sw, compare only descriptor and sw for now + # TODO: investigate why gr is None + np.testing.assert_allclose( + rd1.detach().cpu().numpy(), + rd2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) + if gr1 is not None and gr2 is not None: np.testing.assert_allclose( - rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]], - rd0.detach().cpu().numpy()[1], - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) - dd2 = DPDescrptSeTTebd.deserialize(dd0.serialize()) - rd2, gr2, _, _, sw2 = dd2.call( - self.coord_ext, - self.atype_ext, - self.nlist, - ) - # se_t_tebd should return gr and sw, compare only descriptor and sw for now - # TODO: investigate why gr is None - np.testing.assert_allclose( - rd1.detach().cpu().numpy(), - 
rd2, - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) - if gr1 is not None and gr2 is not None: - np.testing.assert_allclose( - gr1.detach().cpu().numpy(), - gr2, - rtol=rtol, - atol=atol, - err_msg=err_msg, - ) - np.testing.assert_allclose( - sw1.detach().cpu().numpy(), - sw2, + gr1.detach().cpu().numpy(), + gr2, rtol=rtol, atol=atol, err_msg=err_msg, ) + np.testing.assert_allclose( + sw1.detach().cpu().numpy(), + sw2, + rtol=rtol, + atol=atol, + err_msg=err_msg, + ) - def test_exportable(self) -> None: + @pytest.mark.parametrize("prec", ["float64", "float32"]) # precision + def test_exportable(self, prec) -> None: rng = np.random.default_rng(GLOBAL_SEED) _, _, nnei = self.nlist.shape davg = rng.normal(size=(self.nt, nnei, 4)) dstd = rng.normal(size=(self.nt, nnei, 4)) dstd = 0.1 + np.abs(dstd) - for idt, prec in itertools.product( - [True], - ["float64", "float32"], - ): - dtype = PRECISION_DICT[prec] - dd0 = DescrptSeTTebd( - self.rcut, - self.rcut_smth, - self.sel, - self.nt, - precision=prec, - resnet_dt=idt, - seed=GLOBAL_SEED, - ).to(self.device) - dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) - dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) - dd0 = dd0.eval() + idt = True + dtype = PRECISION_DICT[prec] + dd0 = DescrptSeTTebd( + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() - inputs = ( - torch.tensor(self.coord_ext, dtype=dtype, device=self.device), - torch.tensor(self.atype_ext, dtype=int, device=self.device), - torch.tensor(self.nlist, dtype=int, device=self.device), - ) - torch.export.export(dd0, inputs) + inputs = ( + torch.tensor(self.coord_ext, dtype=dtype, device=self.device), + torch.tensor(self.atype_ext, dtype=int, device=self.device), + torch.tensor(self.nlist, dtype=int, 
device=self.device), + ) + torch.export.export(dd0, inputs) + + @pytest.mark.parametrize("prec", ["float64"]) # precision + def test_make_fx(self, prec) -> None: + rng = np.random.default_rng(GLOBAL_SEED) + _, _, nnei = self.nlist.shape + davg = rng.normal(size=(self.nt, nnei, 4)) + dstd = rng.normal(size=(self.nt, nnei, 4)) + dstd = 0.1 + np.abs(dstd) + + idt = True + dtype = PRECISION_DICT[prec] + rtol, atol = get_tols(prec) + dd0 = DescrptSeTTebd( + self.rcut, + self.rcut_smth, + self.sel, + self.nt, + precision=prec, + resnet_dt=idt, + seed=GLOBAL_SEED, + ).to(self.device) + dd0.davg = torch.tensor(davg, dtype=dtype, device=self.device) + dd0.dstd = torch.tensor(dstd, dtype=dtype, device=self.device) + dd0 = dd0.eval() + coord_ext = torch.tensor(self.coord_ext, dtype=dtype, device=self.device) + atype_ext = torch.tensor(self.atype_ext, dtype=int, device=self.device) + nlist = torch.tensor(self.nlist, dtype=int, device=self.device) + + def fn(coord_ext, atype_ext, nlist): + coord_ext = coord_ext.detach().requires_grad_(True) + rd = dd0(coord_ext, atype_ext, nlist)[0] + grad = torch.autograd.grad(rd.sum(), coord_ext, create_graph=False)[0] + return rd, grad + + rd_eager, grad_eager = fn(coord_ext, atype_ext, nlist) + traced = make_fx(fn)(coord_ext, atype_ext, nlist) + rd_traced, grad_traced = traced(coord_ext, atype_ext, nlist) + np.testing.assert_allclose( + rd_eager.detach().cpu().numpy(), + rd_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) + np.testing.assert_allclose( + grad_eager.detach().cpu().numpy(), + grad_traced.detach().cpu().numpy(), + rtol=rtol, + atol=atol, + ) From 115ec93c6c574553bddf85c82dc6a977232d7a2a Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 18 Feb 2026 21:56:50 +0800 Subject: [PATCH 2/9] fix issue of require grad --- deepmd/pt_expt/common.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index a29f61f3f8..bc58ca4a5f 100644 --- 
a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -189,18 +189,19 @@ def _try_convert_list(name: str, value: list) -> torch.nn.Module | None: ) converted.append(c) return torch.nn.ModuleList(converted) - # List of numpy arrays → ParameterList (non-trainable) + # List of numpy arrays → ParameterList if all(isinstance(v, np.ndarray) for v in value): from deepmd.pt_expt.utils import env # deferred - avoids circular import - return torch.nn.ParameterList( - [ + params = [] + for v in value: + t = torch.as_tensor(v, device=env.DEVICE) + params.append( torch.nn.Parameter( - torch.as_tensor(v, device=env.DEVICE), requires_grad=False + t, requires_grad=t.is_floating_point() or t.is_complex() ) - for v in value - ] - ) + ) + return torch.nn.ParameterList(params) return None From 6b4748d703084e0fe25f2d4f893e7c9182bd7385 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 18 Feb 2026 22:16:10 +0800 Subject: [PATCH 3/9] fix performance issue when type_one_side == False --- deepmd/dpmodel/descriptor/se_e2_a.py | 63 ++++++++++++++++++---------- 1 file changed, 40 insertions(+), 23 deletions(-) diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index e949a4946b..4b35cf0ac9 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -453,30 +453,47 @@ def call( rr = xp.reshape(rr, (nf * nloc, nnei, 4)) rr = xp.astype(rr, self.dstd.dtype) - if not self.type_one_side: - # nf x nloc -> (nf * nloc) + if self.type_one_side: + for tt in range(self.ntypes): + mm = exclude_mask[:, sec[tt] : sec[tt + 1]] + tr = rr[:, sec[tt] : sec[tt + 1], :] + tr = tr * xp.astype(mm[:, :, None], tr.dtype) + ss = tr[..., 0:1] + gg = self.cal_g(ss, (tt,)) + gr += xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1) + else: + # Sort atoms by center type so each type forms a contiguous block. 
+ # Slice indexing (arr[s:e]) is array-api compatible and lets us + # run cal_g only on atoms of the matching center type, keeping the + # same O(nf*nloc) total embedding cost as the original numpy code. atype_loc = xp.reshape(atype_ext[:, :nloc], (nf * nloc,)) - - for embedding_idx in itertools.product( - range(self.ntypes), repeat=self.embeddings.ndim - ): - if self.type_one_side: - (tt,) = embedding_idx - else: - ti, tt = embedding_idx - mm = exclude_mask[:, sec[tt] : sec[tt + 1]] - tr = rr[:, sec[tt] : sec[tt + 1], :] - tr = tr * xp.astype(mm[:, :, None], tr.dtype) - ss = tr[..., 0:1] - gg = self.cal_g(ss, embedding_idx) - gr_tmp = xp.sum(gg[:, :, :, None] * tr[:, :, None, :], axis=1) - if not self.type_one_side: - # (nf * nloc) x 1 x 1 - ti_mask = xp.astype( - xp.reshape(atype_loc == ti, (nf * nloc, 1, 1)), gr_tmp.dtype - ) - gr_tmp = gr_tmp * ti_mask - gr += gr_tmp + sort_idx = xp.argsort(atype_loc) + unsort_idx = xp.argsort(sort_idx) + rr_s = xp.take(rr, sort_idx, axis=0) + mask_s = xp.take(exclude_mask, sort_idx, axis=0) + dev = array_api_compat.device(coord_ext) + gr_s = xp.zeros([nf * nloc, ng, 4], dtype=input_dtype, device=dev) + # Per-type boundaries in sorted order + type_ends = [] + offset = 0 + for ti in range(self.ntypes): + offset += int(xp.sum(xp.astype(atype_loc == ti, xp.int32))) + type_ends.append(offset) + type_starts = [0, *type_ends[:-1]] + for ti in range(self.ntypes): + s, e = type_starts[ti], type_ends[ti] + if s == e: + continue + for tt in range(self.ntypes): + mm = mask_s[s:e, sec[tt] : sec[tt + 1]] + tr = rr_s[s:e, sec[tt] : sec[tt + 1], :] + tr = tr * xp.astype(mm[:, :, None], tr.dtype) + ss = tr[..., 0:1] + gg = self.cal_g(ss, (ti, tt)) + gr_s[s:e] = gr_s[s:e] + xp.sum( + gg[:, :, :, None] * tr[:, :, None, :], axis=1 + ) + gr = xp.take(gr_s, unsort_idx, axis=0) gr = xp.reshape(gr, (nf, nloc, ng, 4)) # nf x nloc x ng x 4 gr /= self.nnei From a3054f2e9f09ccdfee44bd300948354efe401cea Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 
18 Feb 2026 22:53:32 +0800 Subject: [PATCH 4/9] fix --- deepmd/pt_expt/common.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/deepmd/pt_expt/common.py b/deepmd/pt_expt/common.py index bc58ca4a5f..d6c2a69603 100644 --- a/deepmd/pt_expt/common.py +++ b/deepmd/pt_expt/common.py @@ -160,7 +160,14 @@ def _auto_wrap_native_op(value: NativeOP) -> torch.nn.Module: {"forward": lambda self, *args, **kwargs: self.call(*args, **kwargs)}, ) _AUTO_WRAPPED_CLASSES[cls] = torch_module(wrapped) - return _AUTO_WRAPPED_CLASSES[cls].deserialize(value.serialize()) + wrapped_cls = _AUTO_WRAPPED_CLASSES[cls] + if not (hasattr(value, "serialize") and hasattr(wrapped_cls, "deserialize")): + raise TypeError( + f"Cannot auto-wrap {cls.__name__}: " + "it must implement serialize()/deserialize() or be explicitly " + "registered via register_dpmodel_mapping()." + ) + return wrapped_cls.deserialize(value.serialize()) def _try_convert_list(name: str, value: list) -> torch.nn.Module | None: From d29cee83531e60cf3bd3e37b4a4f202147beb880 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 18 Feb 2026 23:21:29 +0800 Subject: [PATCH 5/9] fix docstr --- deepmd/dpmodel/descriptor/se_e2_a.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 4b35cf0ac9..4710987f54 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -138,9 +138,8 @@ class DescrptSeA(NativeOP, BaseDescriptor): ----------- The currently implementation does not support the following features - 1. type_one_side == False - 2. exclude_types != [] - 3. spin is not None + 1. exclude_types != [] + 2. 
spin is not None References ---------- From 0357f4e468bc2f34e8219fdd30f747a721b3dd4d Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 19 Feb 2026 14:25:38 +0800 Subject: [PATCH 6/9] fix device --- deepmd/dpmodel/utils/network.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index b385ce6005..28682bf3b8 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -1332,7 +1332,8 @@ def get_graph_index( # noqa: ANN201 # 1. atom graph # node(i) to edge(ij) index_select; edge(ij) to node aggregate - nlist_loc_index = xp.arange(nf * nloc, dtype=nlist.dtype) + dev = array_api_compat.device(nlist) + nlist_loc_index = xp.arange(nf * nloc, dtype=nlist.dtype, device=dev) # nf x nloc x nnei n2e_index = xp.broadcast_to( xp.reshape(nlist_loc_index, (nf, nloc, 1)), (nf, nloc, nnei) @@ -1341,7 +1342,7 @@ def get_graph_index( # noqa: ANN201 n2e_index = n2e_index[xp.astype(nlist_mask, xp.bool)] # node_ext(j) to edge(ij) index_select - frame_shift = xp.arange(nf, dtype=nlist.dtype) * ( + frame_shift = xp.arange(nf, dtype=nlist.dtype, device=dev) * ( nall if not use_loc_mapping else nloc ) shifted_nlist = nlist + frame_shift[:, xp.newaxis, xp.newaxis] @@ -1357,8 +1358,8 @@ def get_graph_index( # noqa: ANN201 n2a_index = n2a_index[a_nlist_mask_3d] # edge(ij) to angle(ijk) index_select; angle(ijk) to edge(ij) aggregate - edge_id = xp.arange(n_edge, dtype=nlist.dtype) - edge_index = xp.zeros((nf, nloc, nnei), dtype=nlist.dtype) + edge_id = xp.arange(n_edge, dtype=nlist.dtype, device=dev) + edge_index = xp.zeros((nf, nloc, nnei), dtype=nlist.dtype, device=dev) edge_index = xp_setitem_at(edge_index, xp.astype(nlist_mask, xp.bool), edge_id) # only cut a_nnei neighbors, to avoid nnei x nnei edge_index = edge_index[:, :, :a_nnei] From 850996a7c4dec8aafeca7656d806f2597a97314e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 19 Feb 2026 14:29:15 +0800 Subject: [PATCH 7/9] 
fix --- deepmd/dpmodel/utils/network.py | 2 +- deepmd/pt_expt/utils/network.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index 28682bf3b8..dc79fd21cb 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -60,7 +60,7 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "Identity": - return Identity() + return cls() class NativeLayer(NativeOP): diff --git a/deepmd/pt_expt/utils/network.py b/deepmd/pt_expt/utils/network.py index 04cfaad441..1629ecb83a 100644 --- a/deepmd/pt_expt/utils/network.py +++ b/deepmd/pt_expt/utils/network.py @@ -295,5 +295,5 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: register_dpmodel_mapping( IdentityDP, - lambda v: Identity.deserialize(v.serialize()), + lambda v: Identity(), ) From 14f3d16d21c09a05a689596acf88fbf54e2b26ae Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 19 Feb 2026 16:06:39 +0800 Subject: [PATCH 8/9] fix --- deepmd/dpmodel/utils/network.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/utils/network.py b/deepmd/dpmodel/utils/network.py index dc79fd21cb..da6a305b9b 100644 --- a/deepmd/dpmodel/utils/network.py +++ b/deepmd/dpmodel/utils/network.py @@ -1265,11 +1265,14 @@ def aggregate( # noqa: ANN201 bin_count = xp_bincount(owners) bin_count = xp.where(bin_count == 0, xp.ones_like(bin_count), bin_count) + dev = array_api_compat.device(data) if num_owner is not None and bin_count.shape[0] != num_owner: difference = num_owner - bin_count.shape[0] - bin_count = xp.concat([bin_count, xp.ones(difference, dtype=bin_count.dtype)]) + bin_count = xp.concat( + [bin_count, xp.ones(difference, dtype=bin_count.dtype, device=dev)] + ) - output = xp.zeros((bin_count.shape[0], data.shape[1]), dtype=data.dtype) + output = xp.zeros((bin_count.shape[0], data.shape[1]), dtype=data.dtype, device=dev) output = xp_add_at(output, owners, data) if 
average: From 8e433461303057fd85c139198ca9effa84dd6302 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 19 Feb 2026 17:04:29 +0800 Subject: [PATCH 9/9] fix similar issues in models --- deepmd/dpmodel/array_api.py | 5 ++++- deepmd/dpmodel/atomic_model/linear_atomic_model.py | 4 +++- deepmd/dpmodel/atomic_model/pairtab_atomic_model.py | 10 ++++++++-- deepmd/dpmodel/model/make_model.py | 7 ++++++- 4 files changed, 21 insertions(+), 5 deletions(-) diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index 4a7fa9a45c..cfdcfdca96 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -51,7 +51,10 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array: else: indices = xp.reshape(indices, (0, 0)) - offset = (xp.arange(indices.shape[0], dtype=indices.dtype) * m)[:, xp.newaxis] + dev = array_api_compat.device(indices) + offset = (xp.arange(indices.shape[0], dtype=indices.dtype, device=dev) * m)[ + :, xp.newaxis + ] indices = xp.reshape(offset + indices, (-1,)) out = xp.take(arr, indices) diff --git a/deepmd/dpmodel/atomic_model/linear_atomic_model.py b/deepmd/dpmodel/atomic_model/linear_atomic_model.py index 380c300216..b73dcb77fb 100644 --- a/deepmd/dpmodel/atomic_model/linear_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/linear_atomic_model.py @@ -337,9 +337,11 @@ def _compute_weight( xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlists_) nmodels = len(self.models) nframes, nloc, _ = nlists_[0].shape + dev = array_api_compat.device(extended_coord) # the dtype of weights is the interface data type. 
return [ - xp.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION) / nmodels + xp.ones((nframes, nloc, 1), dtype=GLOBAL_NP_FLOAT_PRECISION, device=dev) + / nmodels for _ in range(nmodels) ] diff --git a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py index 5385d4c56c..6212696ddc 100644 --- a/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/pairtab_atomic_model.py @@ -220,8 +220,11 @@ def forward_atomic( ) # (nframes, nloc, nnei) # (nframes, nloc, nnei), index type is int64. + dev = array_api_compat.device(extended_atype) j_type = extended_atype[ - xp.arange(extended_atype.shape[0], dtype=xp.int64)[:, None, None], + xp.arange(extended_atype.shape[0], dtype=xp.int64, device=dev)[ + :, None, None + ], masked_nlist, ] @@ -327,8 +330,11 @@ def _get_pairwise_dist(coords: Array, nlist: Array) -> Array: The pairwise distance between the atoms (nframes, nloc, nnei). """ xp = array_api_compat.array_namespace(coords, nlist) + dev = array_api_compat.device(nlist) # index type is int64 - batch_indices = xp.arange(nlist.shape[0], dtype=xp.int64)[:, None, None] + batch_indices = xp.arange(nlist.shape[0], dtype=xp.int64, device=dev)[ + :, None, None + ] neighbor_atoms = coords[batch_indices, nlist] loc_atoms = coords[:, : nlist.shape[1], :] pairwise_dr = loc_atoms[:, :, None, :] - neighbor_atoms diff --git a/deepmd/dpmodel/model/make_model.py b/deepmd/dpmodel/model/make_model.py index e115478df5..fea86f3557 100644 --- a/deepmd/dpmodel/model/make_model.py +++ b/deepmd/dpmodel/model/make_model.py @@ -540,7 +540,12 @@ def _format_nlist( ret = xp.concat( [ nlist, - -1 * xp.ones([n_nf, n_nloc, nnei - n_nnei], dtype=nlist.dtype), + -1 + * xp.ones( + [n_nf, n_nloc, nnei - n_nnei], + dtype=nlist.dtype, + device=array_api_compat.device(nlist), + ), ], axis=-1, )