diff --git a/deepmd/backend/pt_expt.py b/deepmd/backend/pt_expt.py index 61a7151208..2380ab0271 100644 --- a/deepmd/backend/pt_expt.py +++ b/deepmd/backend/pt_expt.py @@ -76,7 +76,11 @@ def deep_eval(self) -> type["DeepEvalBackend"]: type[DeepEvalBackend] The Deep Eval backend of the backend. """ - raise NotImplementedError + from deepmd.pt_expt.infer.deep_eval import ( + DeepEval, + ) + + return DeepEval @property def neighbor_stat(self) -> type["NeighborStat"]: @@ -87,7 +91,11 @@ def neighbor_stat(self) -> type["NeighborStat"]: type[NeighborStat] The neighbor statistics of the backend. """ - raise NotImplementedError + from deepmd.pt_expt.utils.neighbor_stat import ( + NeighborStat, + ) + + return NeighborStat @property def serialize_hook(self) -> Callable[[str], dict]: @@ -98,7 +106,11 @@ def serialize_hook(self) -> Callable[[str], dict]: Callable[[str], dict] The serialize hook of the backend. """ - raise NotImplementedError + from deepmd.pt_expt.utils.serialization import ( + serialize_from_file, + ) + + return serialize_from_file @property def deserialize_hook(self) -> Callable[[str, dict], None]: @@ -109,4 +121,8 @@ def deserialize_hook(self) -> Callable[[str, dict], None]: Callable[[str, dict], None] The deserialize hook of the backend. """ - raise NotImplementedError + from deepmd.pt_expt.utils.serialization import ( + deserialize_to_file, + ) + + return deserialize_to_file diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index e745b28f94..0fed9813dd 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -32,6 +32,15 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array: # torch.take_along_dim requires int64 indices if array_api_compat.is_torch_array(indices): indices = xp.astype(indices, xp.int64) + if array_api_compat.is_torch_array(arr): + # Use torch.gather directly for torch.export dynamic shape compatibility. 
+ # array_api_compat's take_along_axis / torch.take_along_dim specializes + # the source dimension size to a constant during torch.export tracing, + # breaking dynamic shape export. torch.gather is the underlying + # primitive and handles symbolic shapes correctly. + import torch + + return torch.gather(arr, axis, indices) if Version(xp.__array_api_version__) >= Version("2024.12"): # see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39 return xp.take_along_axis(arr, indices, axis=axis) @@ -62,6 +71,24 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array: return xp_swapaxes(out, axis, -1) +def xp_take_first_n(arr: Array, dim: int, n: int) -> Array: + """Take the first *n* elements along *dim*. + + For torch tensors, uses ``torch.index_select`` so that + ``torch.export`` does not emit a contiguity guard that would + prevent the ``nall == nloc`` (no-PBC) case from working. + For numpy / jax, uses regular slicing. + """ + if array_api_compat.is_torch_array(arr): + import torch + + indices = torch.arange(n, dtype=torch.int64, device=arr.device) + return torch.index_select(arr, dim, indices) + slices = [slice(None)] * arr.ndim + slices[dim] = slice(0, n) + return arr[tuple(slices)] + + def xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array: """Reduces all values from the src tensor to the indices specified in the index tensor. 
diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index b87b9a3a6d..1d3ff5aa4a 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -13,6 +13,7 @@ from deepmd.dpmodel.array_api import ( Array, + xp_take_first_n, ) from deepmd.dpmodel.common import ( NativeOP, @@ -250,7 +251,7 @@ def forward_common_atomic( """ xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist) _, nloc, _ = nlist.shape - atype = extended_atype[:, :nloc] + atype = xp_take_first_n(extended_atype, 1, nloc) if self.pair_excl is not None: pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype) # exclude neighbors in the nlist @@ -268,7 +269,7 @@ def forward_common_atomic( ret_dict = self.apply_out_stat(ret_dict, atype) # nf x nloc - atom_mask = ext_atom_mask[:, :nloc] + atom_mask = xp_take_first_n(ext_atom_mask, 1, nloc) if self.atom_excl is not None: atom_mask = xp.logical_and( atom_mask, self.atom_excl.build_type_exclude_mask(atype) diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 5d81ed0538..21245abaa8 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -8,6 +8,7 @@ from deepmd.dpmodel.array_api import ( Array, + xp_take_first_n, ) from deepmd.dpmodel.descriptor.base_descriptor import ( BaseDescriptor, @@ -178,7 +179,7 @@ def forward_atomic( """ nframes, nloc, nnei = nlist.shape - atype = extended_atype[:, :nloc] + atype = xp_take_first_n(extended_atype, 1, nloc) descriptor, rot_mat, g2, h2, sw = self.descriptor( extended_coord, extended_atype, diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index e9407a435b..9856741317 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -11,6 +11,7 @@ from deepmd.dpmodel.array_api import ( Array, 
xp_take_along_axis, + xp_take_first_n, ) from deepmd.dpmodel.utils.safe_gradient import ( safe_for_vector_norm, @@ -76,7 +77,7 @@ def _make_env_mat( # nf x nloc x nnei x 3 coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3)) # nf x nloc x 1 x 3 - coord_l = xp.reshape(coord[:, :nloc, ...], (nf, -1, 1, 3)) + coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, -1, 1, 3)) # nf x nloc x nnei x 3 diff = coord_r - coord_l # nf x nloc x nnei @@ -153,7 +154,7 @@ def call( xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) em, diff, sw = self._call(nlist, coord_ext, radial_only) nf, nloc, nnei = nlist.shape - atype = atype_ext[:, :nloc] + atype = xp_take_first_n(atype_ext, 1, nloc) if davg is not None: em -= xp.reshape(xp.take(davg, xp.reshape(atype, (-1,)), axis=0), em.shape) if dstd is not None: diff --git a/deepmd/pt_expt/infer/__init__.py b/deepmd/pt_expt/infer/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt_expt/infer/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py new file mode 100644 index 0000000000..a6e1e1e540 --- /dev/null +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -0,0 +1,676 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +from collections.abc import ( + Callable, +) +from typing import ( + TYPE_CHECKING, + Any, + Optional, +) + +import numpy as np +import torch + +from deepmd.dpmodel.model.transform_output import ( + communicate_extended_output, +) +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableCategory, + OutputVariableDef, +) +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, + nlist_distinguish_types, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.infer.deep_dipole import ( + DeepDipole, +) +from 
deepmd.infer.deep_dos import ( + DeepDOS, +) +from deepmd.infer.deep_eval import DeepEval as DeepEvalWrapper +from deepmd.infer.deep_eval import ( + DeepEvalBackend, +) +from deepmd.infer.deep_polar import ( + DeepPolar, +) +from deepmd.infer.deep_pot import ( + DeepPot, +) +from deepmd.infer.deep_wfc import ( + DeepWFC, +) +from deepmd.pt.utils.auto_batch_size import ( + AutoBatchSize, +) + +if TYPE_CHECKING: + import ase.neighborlist + + +def _reconstruct_model_output_def(metadata: dict) -> ModelOutputDef: + """Reconstruct ModelOutputDef from stored fitting_output_defs metadata.""" + var_defs = [] + for vd in metadata["fitting_output_defs"]: + var_defs.append( + OutputVariableDef( + name=vd["name"], + shape=vd["shape"], + reducible=vd["reducible"], + r_differentiable=vd["r_differentiable"], + c_differentiable=vd["c_differentiable"], + atomic=vd["atomic"], + category=vd["category"], + r_hessian=vd["r_hessian"], + magnetic=vd["magnetic"], + intensive=vd["intensive"], + ) + ) + fitting_output_def = FittingOutputDef(var_defs) + return ModelOutputDef(fitting_output_def) + + +class DeepEval(DeepEvalBackend): + """PyTorch Exportable backend implementation of DeepEval. + + Loads a .pte file containing a torch.export-ed model and evaluates + it using pre-built neighbor lists. + + Parameters + ---------- + model_file : Path + The name of the .pte model file. + output_def : ModelOutputDef + The output definition of the model. + *args : list + Positional arguments. + auto_batch_size : bool or int or AutoBatchSize, default: True + If True, automatic batch size will be used. If int, it will be used + as the initial batch size. + neighbor_list : ase.neighborlist.NewPrimitiveNeighborList, optional + The ASE neighbor list class to produce the neighbor list. If None, the + neighbor list will be built natively in the model. + **kwargs : dict + Keyword arguments. 
+ """ + + def __init__( + self, + model_file: str, + output_def: ModelOutputDef, + *args: Any, + auto_batch_size: bool | int | AutoBatchSize = True, + neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None, + **kwargs: Any, + ) -> None: + self.output_def = output_def + self.model_path = model_file + self.neighbor_list = neighbor_list + + # Load the exported program with metadata + extra_files = {"model_def_script.json": ""} + exported = torch.export.load(model_file, extra_files=extra_files) + self.exported_module = exported.module() + + # Parse metadata + self.metadata = json.loads(extra_files["model_def_script.json"]) + self.rcut = self.metadata["rcut"] + self.type_map = self.metadata["type_map"] + + # Reconstruct the model output def from stored fitting output defs + self._model_output_def = _reconstruct_model_output_def(self.metadata) + + if isinstance(auto_batch_size, bool): + if auto_batch_size: + self.auto_batch_size = AutoBatchSize() + else: + self.auto_batch_size = None + elif isinstance(auto_batch_size, int): + self.auto_batch_size = AutoBatchSize(auto_batch_size) + elif isinstance(auto_batch_size, AutoBatchSize): + self.auto_batch_size = auto_batch_size + else: + raise TypeError("auto_batch_size should be bool, int, or AutoBatchSize") + + def get_rcut(self) -> float: + """Get the cutoff radius of this model.""" + return self.rcut + + def get_ntypes(self) -> int: + """Get the number of atom types of this model.""" + return len(self.type_map) + + def get_type_map(self) -> list[str]: + """Get the type map (element name of the atom types) of this model.""" + return self.type_map + + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this DP.""" + return self.metadata["dim_fparam"] + + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this DP.""" + return self.metadata["dim_aparam"] + + @property + def model_type(self) -> type["DeepEvalWrapper"]: + """The 
evaluator of the model type.""" + model_output_type = self.metadata["model_output_type"] + if "energy" in model_output_type: + return DeepPot + elif "dos" in model_output_type: + return DeepDOS + elif "dipole" in model_output_type: + return DeepDipole + elif "polar" in model_output_type or "polarizability" in model_output_type: + return DeepPolar + elif "wfc" in model_output_type: + return DeepWFC + else: + raise RuntimeError("Unknown model type") + + def get_sel_type(self) -> list[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return self.metadata["sel_type"] + + def get_numb_dos(self) -> int: + """Get the number of DOS.""" + return 0 + + def get_has_efield(self) -> bool: + """Check if the model has efield.""" + return False + + def get_ntypes_spin(self) -> int: + """Get the number of spin atom types of this model.""" + return 0 + + def eval( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + atomic: bool = False, + fparam: np.ndarray | None = None, + aparam: np.ndarray | None = None, + **kwargs: Any, + ) -> dict[str, np.ndarray]: + """Evaluate the energy, force and virial by using this DP. + + Parameters + ---------- + coords + The coordinates of atoms. + The array should be of size nframes x natoms x 3 + cells + The cell of the region. + If None then non-PBC is assumed, otherwise using PBC. + The array should be of size nframes x 9 + atom_types + The atom types + The list should contain natoms ints + atomic + Calculate the atomic energy and virial + fparam + The frame parameter. + The array should be of size nframes x dim_fparam. + aparam + The atomic parameter. + The array should be of size nframes x natoms x dim_aparam. + **kwargs + Other parameters + + Returns + ------- + output_dict : dict + The output of the evaluation. 
The keys are the names of the output + variables, and the values are the corresponding output arrays. + """ + atom_types = np.array(atom_types, dtype=np.int32) + coords = np.array(coords) + if cells is not None: + cells = np.array(cells) + natoms, numb_test = self._get_natoms_and_nframes( + coords, atom_types, len(atom_types.shape) > 1 + ) + request_defs = self._get_request_defs(atomic) + out = self._eval_func(self._eval_model, numb_test, natoms)( + coords, cells, atom_types, fparam, aparam, request_defs + ) + return dict( + zip( + [x.name for x in request_defs], + out, + strict=True, + ) + ) + + def _get_request_defs(self, atomic: bool) -> list[OutputVariableDef]: + """Get the requested output definitions.""" + if atomic: + return list(self.output_def.var_defs.values()) + else: + return [ + x + for x in self.output_def.var_defs.values() + if x.category + in ( + OutputVariableCategory.REDU, + OutputVariableCategory.DERV_R, + OutputVariableCategory.DERV_C_REDU, + ) + ] + + def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Callable: + """Wrapper method with auto batch size.""" + if self.auto_batch_size is not None: + + def eval_func(*args: Any, **kwargs: Any) -> Any: + return self.auto_batch_size.execute_all( + inner_func, numb_test, natoms, *args, **kwargs + ) + + else: + eval_func = inner_func + return eval_func + + def _get_natoms_and_nframes( + self, + coords: np.ndarray, + atom_types: np.ndarray, + mixed_type: bool = False, + ) -> tuple[int, int]: + if mixed_type: + natoms = len(atom_types[0]) + else: + natoms = len(atom_types) + if natoms == 0: + assert coords.size == 0 + else: + coords = np.reshape(np.array(coords), [-1, natoms * 3]) + nframes = coords.shape[0] + return natoms, nframes + + def _build_nlist_native( + self, + coords: torch.Tensor, + cells: torch.Tensor | None, + atom_types: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Build extended coords, atype, nlist, mapping using native 
nlist. + + Parameters + ---------- + coords : torch.Tensor + Coordinates, shape (nframes, natoms, 3). + cells : torch.Tensor or None + Cell vectors, shape (nframes, 9). None for non-PBC. + atom_types : torch.Tensor + Atom types, shape (nframes, natoms). + + Returns + ------- + extended_coord, extended_atype, nlist, mapping + All as torch.Tensor on the same device as inputs. + """ + nframes = coords.shape[0] + natoms = coords.shape[1] + rcut = self.rcut + sel = self.metadata["sel"] + mixed_types = self.metadata["mixed_types"] + + if cells is not None: + box_input = cells.reshape(nframes, 3, 3) + coord_normalized = normalize_coord(coords, box_input) + else: + coord_normalized = coords + + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, + atom_types, + cells, + rcut, + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + natoms, + rcut, + sel, + distinguish_types=not mixed_types, + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + return extended_coord, extended_atype, nlist, mapping + + def _build_nlist_ase( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Build extended coords, atype, nlist, mapping using ASE neighbor list. + + Handles multiple frames by building per frame and padding to + a common nall. + + Parameters + ---------- + coords : np.ndarray + Coordinates, shape (nframes, natoms, 3). + cells : np.ndarray or None + Cell vectors, shape (nframes, 9). None for non-PBC. + atom_types : np.ndarray + Atom types, shape (nframes, natoms). 
+ + Returns + ------- + extended_coord, extended_atype, nlist, mapping + """ + nframes = coords.shape[0] + frame_results = [] + for ff in range(nframes): + ec, ea, nl, mp = self._build_nlist_ase_single( + coords[ff], + cells[ff] if cells is not None else None, + atom_types[ff], + ) + frame_results.append((ec, ea, nl, mp)) + # Pad to max nall across frames + max_nall = max(ec.shape[0] for ec, _, _, _ in frame_results) + ext_coords, ext_atypes, nlists, mappings = [], [], [], [] + for ec, ea, nl, mp in frame_results: + pad = max_nall - ec.shape[0] + if pad > 0: + ec = np.concatenate( + [ec, np.zeros((pad, 3), dtype=ec.dtype)], + axis=0, + ) + ea = np.concatenate( + [ea, np.full(pad, -1, dtype=ea.dtype)], + axis=0, + ) + mp = np.concatenate( + [mp, np.zeros(pad, dtype=mp.dtype)], + axis=0, + ) + ext_coords.append(ec) + ext_atypes.append(ea) + nlists.append(nl) + mappings.append(mp) + return ( + np.stack(ext_coords, axis=0), + np.stack(ext_atypes, axis=0), + np.stack(nlists, axis=0), + np.stack(mappings, axis=0), + ) + + def _build_nlist_ase_single( + self, + positions: np.ndarray, + cell: np.ndarray | None, + atype: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Build extended coords, atype, nlist, mapping for a single frame. + + Parameters + ---------- + positions : np.ndarray + Atom positions, shape (natoms, 3). + cell : np.ndarray or None + Cell vector, shape (9,). None for non-PBC. + atype : np.ndarray + Atom types, shape (natoms,). 
+ + Returns + ------- + extended_coord : np.ndarray, shape (nall, 3) + extended_atype : np.ndarray, shape (nall,) + nlist : np.ndarray, shape (nloc, nsel) + mapping : np.ndarray, shape (nall,) + """ + sel = self.metadata["sel"] + mixed_types = self.metadata["mixed_types"] + nsel = sum(sel) + + natoms = positions.shape[0] + cell_3x3 = ( + cell.reshape(3, 3) + if cell is not None + else np.zeros((3, 3), dtype=np.float64) + ) + pbc = np.repeat(cell is not None, 3) + + nl = self.neighbor_list + nl.bothways = True + nl.self_interaction = False + if nl.update(pbc, cell_3x3, positions): + nl.build(pbc, cell_3x3, positions) + + first_neigh = nl.first_neigh.copy() + pair_second = nl.pair_second.copy() + offset_vec = nl.offset_vec.copy() + + # Identify ghost atoms (out-of-box neighbors) + out_mask = np.any(offset_vec != 0, axis=1) + out_idx = pair_second[out_mask] + out_offset = offset_vec[out_mask] + out_coords = positions[out_idx] + out_offset.dot(cell_3x3) + out_atype = atype[out_idx] + + nloc = natoms + nghost = out_idx.size + + # Extended arrays (no leading frame dimension) + extended_coord = np.concatenate((positions, out_coords), axis=0) + extended_atype = np.concatenate((atype, out_atype)) + mapping = np.concatenate( + (np.arange(nloc, dtype=np.int32), out_idx.astype(np.int32)) + ) + + # Remap neighbor indices: ghost atoms get new indices [nloc, nloc+nghost) + ghost_remap = pair_second.copy() + ghost_remap[out_mask] = np.arange(nloc, nloc + nghost, dtype=np.int64) + + # Build nlist: vectorized CSR-to-dense conversion + rcut = self.rcut + counts = np.diff(first_neigh) + max_nn = int(counts.max()) if counts.size > 0 else 0 + + # CSR to dense: (nloc, max_nn) neighbor index array, padded with -1 + col_idx = np.arange(len(ghost_remap), dtype=np.int64) - np.repeat( + first_neigh[:-1], counts + ) + row_idx = np.repeat(np.arange(nloc, dtype=np.int64), counts) + dense_idx = np.full((nloc, max_nn), -1, dtype=np.int64) + dense_idx[row_idx, col_idx] = ghost_remap + + # Compute 
all distances at once + valid = dense_idx >= 0 + lookup = np.where(valid, dense_idx, 0) + neigh_coords = extended_coord[lookup] # (nloc, max_nn, 3) + dists = np.linalg.norm( + neigh_coords - positions[:, None, :], axis=-1 + ) # (nloc, max_nn) + + # Mask invalid and out-of-range, sort by distance + valid &= dists <= rcut + dists = np.where(valid, dists, np.inf) + order = np.argsort(dists, axis=-1) + sorted_idx = np.take_along_axis(dense_idx, order, axis=-1) + sorted_valid = np.take_along_axis(valid, order, axis=-1) + + # Take first nsel neighbors, pad if fewer than nsel + if max_nn >= nsel: + nlist = sorted_idx[:, :nsel] + nlist = np.where(sorted_valid[:, :nsel], nlist, -1) + else: + nlist = np.full((nloc, nsel), -1, dtype=np.int64) + nlist[:, :max_nn] = np.where(sorted_valid, sorted_idx, -1) + + if not mixed_types: + # nlist_distinguish_types expects (nframes, nloc, nsel) + nlist = nlist_distinguish_types( + nlist[None], + extended_atype[None], + sel, + )[0] + + return extended_coord, extended_atype, nlist, mapping + + def _eval_model( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + fparam: np.ndarray | None, + aparam: np.ndarray | None, + request_defs: list[OutputVariableDef], + ) -> tuple[np.ndarray, ...]: + nframes = coords.shape[0] + if len(atom_types.shape) == 1: + natoms = len(atom_types) + atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) + else: + natoms = len(atom_types[0]) + + from deepmd.pt_expt.utils.env import ( + DEVICE, + ) + + coord_input = coords.reshape(nframes, natoms, 3) + if self.neighbor_list is not None: + # ASE path: builds nlist in numpy, then convert to tensors + extended_coord, extended_atype, nlist, mapping = self._build_nlist_ase( + coord_input, + cells, + atom_types, + ) + ext_coord_t = torch.tensor( + extended_coord, dtype=torch.float64, device=DEVICE + ) + ext_atype_t = torch.tensor(extended_atype, dtype=torch.int64, device=DEVICE) + nlist_t = torch.tensor(nlist, dtype=torch.int64, 
device=DEVICE) + mapping_t = torch.tensor(mapping, dtype=torch.int64, device=DEVICE) + else: + # Native path: convert to tensors first so array-API functions + # use the torch backend (runs on DEVICE). + coord_t = torch.tensor(coord_input, dtype=torch.float64, device=DEVICE) + atype_t = torch.tensor(atom_types, dtype=torch.int64, device=DEVICE) + cells_t = ( + torch.tensor(cells, dtype=torch.float64, device=DEVICE) + if cells is not None + else None + ) + ext_coord_t, ext_atype_t, nlist_t, mapping_t = self._build_nlist_native( + coord_t, + cells_t, + atype_t, + ) + + if fparam is not None: + fparam_t = torch.tensor( + fparam.reshape(nframes, self.get_dim_fparam()), + dtype=torch.float64, + device=DEVICE, + ) + else: + fparam_t = None + + if aparam is not None: + aparam_t = torch.tensor( + aparam.reshape(nframes, natoms, self.get_dim_aparam()), + dtype=torch.float64, + device=DEVICE, + ) + else: + aparam_t = None + + # Call the exported module (forward_common_lower interface, internal keys) + model_ret = self.exported_module( + ext_coord_t, ext_atype_t, nlist_t, mapping_t, fparam_t, aparam_t + ) + + # Apply communicate_extended_output to map extended atoms → local atoms + do_atomic_virial = any( + x.category == OutputVariableCategory.DERV_C for x in request_defs + ) + model_predict = communicate_extended_output( + model_ret, + self._model_output_def, + mapping_t, + do_atomic_virial=do_atomic_virial, + ) + + # Translate internal keys to backend names and collect results + results = [] + for odef in request_defs: + # odef.name is the internal key (e.g. "energy_derv_r") + # _OUTDEF_DP2BACKEND maps it to backend name (e.g. 
"force") + # but model_predict uses internal keys from communicate_extended_output + if odef.name in model_predict: + shape = self._get_output_shape(odef, nframes, natoms) + if model_predict[odef.name] is not None: + out = model_predict[odef.name].detach().cpu().numpy().reshape(shape) + else: + out = np.full(shape, np.nan, dtype=GLOBAL_NP_FLOAT_PRECISION) + results.append(out) + else: + shape = self._get_output_shape(odef, nframes, natoms) + results.append( + np.full(np.abs(shape), np.nan, dtype=GLOBAL_NP_FLOAT_PRECISION) + ) + return tuple(results) + + def _get_output_shape( + self, odef: OutputVariableDef, nframes: int, natoms: int + ) -> list[int]: + if odef.category == OutputVariableCategory.DERV_C_REDU: + # virial + return [nframes, *odef.shape[:-1], 9] + elif odef.category == OutputVariableCategory.REDU: + # energy + return [nframes, *odef.shape, 1] + elif odef.category == OutputVariableCategory.DERV_C: + # atom_virial + return [nframes, *odef.shape[:-1], natoms, 9] + elif odef.category == OutputVariableCategory.DERV_R: + # force + return [nframes, *odef.shape[:-1], natoms, 3] + elif odef.category == OutputVariableCategory.OUT: + # atom_energy, atom_tensor + return [nframes, natoms, *odef.shape, 1] + elif odef.category == OutputVariableCategory.DERV_R_DERV_R: + # hessian + return [nframes, 3 * natoms, 3 * natoms] + else: + raise RuntimeError("unknown category") + + def get_model_def_script(self) -> dict: + """Get model definition script.""" + return self.metadata + + def get_model(self) -> torch.nn.Module: + """Get the exported model module. + + Returns + ------- + torch.nn.Module + The exported model module. 
+ """ + return self.exported_module diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py index 271028d2ff..53b22ab705 100644 --- a/deepmd/pt_expt/model/ener_model.py +++ b/deepmd/pt_expt/model/ener_model.py @@ -126,17 +126,21 @@ def forward_lower_exportable( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, + **make_fx_kwargs: Any, ) -> torch.nn.Module: """Trace ``forward_lower`` into an exportable module. - Uses ``make_fx`` to trace through ``torch.autograd.grad``, - decomposing the backward pass into primitive ops. The returned - module can be passed directly to ``torch.export.export``. + Delegates to ``forward_common_lower_exportable`` for tracing, + then translates the internal keys to the ``forward_lower`` + convention. Parameters ---------- extended_coord, extended_atype, nlist, mapping, fparam, aparam, do_atomic_virial Sample inputs with representative shapes (used for tracing). + **make_fx_kwargs + Extra keyword arguments forwarded to ``make_fx`` + (e.g. ``tracing_mode="symbolic"``). Returns ------- @@ -145,7 +149,21 @@ def forward_lower_exportable( ``(extended_coord, extended_atype, nlist, mapping, fparam, aparam)`` and returns a dict with the same keys as ``forward_lower``. """ - model = self + traced = self.forward_common_lower_exportable( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + **make_fx_kwargs, + ) + + # Translate internal keys to forward_lower convention. + # Capture model config at trace time via closures. 
+ do_grad_r = self.do_grad_r("energy") + do_grad_c = self.do_grad_c("energy") def fn( extended_coord: torch.Tensor, @@ -155,17 +173,24 @@ def fn( fparam: torch.Tensor | None, aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: - extended_coord = extended_coord.detach().requires_grad_(True) - return model.forward_lower( - extended_coord, - extended_atype, - nlist, - mapping, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, + model_ret = traced( + extended_coord, extended_atype, nlist, mapping, fparam, aparam ) + model_predict: dict[str, torch.Tensor] = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if do_grad_r: + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + if do_grad_c: + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret[ + "energy_derv_c" + ].squeeze(-2) + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict - return make_fx(fn)( + return make_fx(fn, **make_fx_kwargs)( extended_coord, extended_atype, nlist, mapping, fparam, aparam ) diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py index 56cabafe81..4baf3f5c7a 100644 --- a/deepmd/pt_expt/model/make_model.py +++ b/deepmd/pt_expt/model/make_model.py @@ -4,6 +4,9 @@ ) import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) from deepmd.dpmodel.atomic_model.base_atomic_model import ( BaseAtomicModel, @@ -90,4 +93,72 @@ def forward_common_atomic( mask=atomic_ret.get("mask"), ) + def forward_common_lower_exportable( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + **make_fx_kwargs: Any, + ) -> torch.nn.Module: + """Trace 
``forward_common_lower`` into an exportable module. + + Uses ``make_fx`` to trace through ``torch.autograd.grad``, + decomposing the backward pass into primitive ops. The returned + module can be passed directly to ``torch.export.export``. + + The output uses internal key names (e.g. ``energy``, + ``energy_redu``, ``energy_derv_r``) so that + ``communicate_extended_output`` can be applied at inference + time. + + Parameters + ---------- + extended_coord, extended_atype, nlist, mapping, fparam, aparam, do_atomic_virial + Sample inputs with representative shapes (used for tracing). + **make_fx_kwargs + Extra keyword arguments forwarded to ``make_fx`` + (e.g. ``tracing_mode="symbolic"``). + + Returns + ------- + torch.nn.Module + A traced module whose ``forward`` accepts + ``(extended_coord, extended_atype, nlist, mapping, + fparam, aparam)`` and returns a dict with the same keys + as ``call_common_lower``. + """ + model = self + + def fn( + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + extended_coord = extended_coord.detach().requires_grad_(True) + return model.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + return make_fx(fn, **make_fx_kwargs)( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + ) + return CM diff --git a/deepmd/pt_expt/model/transform_output.py b/deepmd/pt_expt/model/transform_output.py index 5fb1ac4e46..265a5a268d 100644 --- a/deepmd/pt_expt/model/transform_output.py +++ b/deepmd/pt_expt/model/transform_output.py @@ -17,20 +17,20 @@ def atomic_virial_corr( extended_coord: torch.Tensor, atom_energy: torch.Tensor, ) -> torch.Tensor: - nall = extended_coord.shape[1] - nf = extended_coord.shape[0] nloc = atom_energy.shape[1] - coord, _ = 
torch.split(extended_coord, [nloc, nall - nloc], dim=1) + indices = torch.arange(nloc, dtype=torch.int64, device=extended_coord.device) + coord = torch.index_select(extended_coord, 1, indices) # no derivative with respect to the loc coord. coord = coord.detach() ce = coord * atom_energy sumce = torch.sum(ce, dim=1) # [nf, 3] - # Use vmap to batch the 3 backward passes (one per spatial component) - basis = torch.eye(3, dtype=sumce.dtype, device=sumce.device) # [3, 3] - basis = basis.unsqueeze(1).expand(3, nf, 3) # [3, nf, 3] - - def grad_fn(grad_output: torch.Tensor) -> torch.Tensor: + # Explicitly loop over the 3 spatial components instead of vmap, + # so that make_fx(symbolic) and torch.export can trace through. + results = [] + for i in range(3): + grad_output = torch.zeros_like(sumce) + grad_output[:, i] = 1.0 result = torch.autograd.grad( [sumce], [extended_coord], @@ -39,11 +39,10 @@ def grad_fn(grad_output: torch.Tensor) -> torch.Tensor: retain_graph=True, )[0] assert result is not None - return result + results.append(result) - # [3, nf, nall, 3] — batched over the 3 spatial components - extended_virial_corr = torch.vmap(grad_fn)(basis) # [3, nf, nall, 3] -> [nf, nall, 3, 3] + extended_virial_corr = torch.stack(results, dim=0) return extended_virial_corr.permute(1, 2, 3, 0) diff --git a/deepmd/pt_expt/utils/neighbor_stat.py b/deepmd/pt_expt/utils/neighbor_stat.py index cf9d9f3c18..850bf86def 100644 --- a/deepmd/pt_expt/utils/neighbor_stat.py +++ b/deepmd/pt_expt/utils/neighbor_stat.py @@ -12,6 +12,7 @@ ) from deepmd.pt_expt.utils.env import ( DEVICE, + GLOBAL_PT_FLOAT_PRECISION, ) from deepmd.utils.data_system import ( DeepmdDataSystem, @@ -79,9 +80,15 @@ def _execute( ) -> tuple[np.ndarray, np.ndarray]: """Execute the operation on DEVICE.""" minrr2, max_nnei = self.op( - torch.from_numpy(coord).to(DEVICE), - torch.from_numpy(atype).to(DEVICE), - torch.from_numpy(cell).to(DEVICE) if cell is not None else None, + torch.from_numpy(coord).to(device=DEVICE, 
dtype=GLOBAL_PT_FLOAT_PRECISION), + torch.from_numpy(atype).to(device=DEVICE, dtype=torch.long), + ( + torch.from_numpy(cell).to( + device=DEVICE, dtype=GLOBAL_PT_FLOAT_PRECISION + ) + if cell is not None + else None + ), ) minrr2 = minrr2.detach().cpu().numpy() max_nnei = max_nnei.detach().cpu().numpy() diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py new file mode 100644 index 0000000000..516eac62cd --- /dev/null +++ b/deepmd/pt_expt/utils/serialization.py @@ -0,0 +1,295 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json + +import numpy as np +import torch + +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) +from deepmd.dpmodel.utils.serialization import ( + traverse_model_dict, +) + + +def _numpy_to_json_serializable(model_obj: dict) -> dict: + """Convert numpy arrays in a model dict to JSON-serializable lists.""" + return traverse_model_dict( + model_obj, + lambda x: ( + { + "@class": "np.ndarray", + "@is_variable": True, + "dtype": x.dtype.name, + "value": x.tolist(), + } + if isinstance(x, np.ndarray) + else x + ), + ) + + +def _json_to_numpy(model_obj: dict) -> dict: + """Convert JSON-serialized numpy arrays back to np.ndarray.""" + return traverse_model_dict( + model_obj, + lambda x: ( + np.asarray(x["value"], dtype=np.dtype(x["dtype"])) + if isinstance(x, dict) and x.get("@class") == "np.ndarray" + else x + ), + ) + + +def _make_sample_inputs( + model: torch.nn.Module, + nframes: int = 1, + nloc: int = 2, +) -> tuple[torch.Tensor, ...]: + """Create sample inputs for tracing forward_lower. + + Parameters + ---------- + model : torch.nn.Module + The pt_expt model (must have get_rcut, get_sel, get_type_map, etc.). + nframes : int + Number of frames. + nloc : int + Number of local atoms. 
+ + Returns + ------- + tuple + (ext_coord, ext_atype, nlist, mapping, fparam, aparam) + """ + rcut = model.get_rcut() + sel = model.get_sel() + ntypes = len(model.get_type_map()) + dim_fparam = model.get_dim_fparam() + dim_aparam = model.get_dim_aparam() + mixed_types = model.mixed_types() + + # Create a simple box large enough to avoid PBC issues + box_size = rcut * 3.0 + box = np.eye(3, dtype=np.float64) * box_size + box_np = box.reshape(1, 9) + + # Random coords inside the box + rng = np.random.default_rng(42) + coord_np = rng.random((nframes, nloc, 3), dtype=np.float64) * box_size * 0.5 + coord_np += box_size * 0.25 # center in box + + # Assign atom types: distribute across types + atype_np = np.zeros((nframes, nloc), dtype=np.int32) + for i in range(nloc): + atype_np[:, i] = i % ntypes + + # Normalize and extend + coord_normalized = normalize_coord( + coord_np.reshape(nframes, nloc, 3), + np.tile(box.reshape(1, 3, 3), (nframes, 1, 1)), + ) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, + atype_np, + np.tile(box_np, (nframes, 1)), + rcut, + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + distinguish_types=not mixed_types, + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + + # Convert to torch tensors + from deepmd.pt_expt.utils.env import ( + DEVICE, + ) + + ext_coord = torch.tensor(extended_coord, dtype=torch.float64, device=DEVICE) + ext_atype = torch.tensor(extended_atype, dtype=torch.int64, device=DEVICE) + nlist_t = torch.tensor(nlist, dtype=torch.int64, device=DEVICE) + mapping_t = torch.tensor(mapping, dtype=torch.int64, device=DEVICE) + + if dim_fparam > 0: + fparam = torch.zeros(nframes, dim_fparam, dtype=torch.float64, device=DEVICE) + else: + fparam = None + + if dim_aparam > 0: + aparam = torch.zeros( + nframes, nloc, dim_aparam, dtype=torch.float64, device=DEVICE + ) + else: + aparam = None + + return ext_coord, ext_atype, nlist_t, mapping_t, 
fparam, aparam + + +def _build_dynamic_shapes( + _ext_coord: torch.Tensor, + _ext_atype: torch.Tensor, + _nlist: torch.Tensor, + _mapping: torch.Tensor, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, +) -> tuple: + """Build dynamic shape specifications for torch.export. + + Marks nframes, nloc and nall as dynamic dimensions so the exported + program handles arbitrary frame and atom counts. + + Returns a tuple (not dict) to match positional args of the make_fx + traced module, whose arg names may have suffixes like ``_1``. + """ + nframes_dim = torch.export.Dim("nframes", min=1) + nall_dim = torch.export.Dim("nall", min=1) + nloc_dim = torch.export.Dim("nloc", min=1) + + return ( + {0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3) + {0: nframes_dim, 1: nall_dim}, # extended_atype: (nframes, nall) + {0: nframes_dim, 1: nloc_dim}, # nlist: (nframes, nloc, nnei) + {0: nframes_dim, 1: nall_dim}, # mapping: (nframes, nall) + {0: nframes_dim} if fparam is not None else None, # fparam + {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, # aparam + ) + + +def _collect_metadata(model: torch.nn.Module) -> dict: + """Collect metadata from the model for storage in .pte extra_files.""" + # Serialize the fitting output definitions so that ModelOutputDef + # can be reconstructed at inference time without loading the full model. 
+ fitting_output_def = model.atomic_output_def() + fitting_output_defs = [] + for vdef in fitting_output_def.get_data().values(): + fitting_output_defs.append( + { + "name": vdef.name, + "shape": list(vdef.shape), + "reducible": vdef.reducible, + "r_differentiable": vdef.r_differentiable, + "c_differentiable": vdef.c_differentiable, + "atomic": vdef.atomic, + "category": vdef.category, + "r_hessian": vdef.r_hessian, + "magnetic": vdef.magnetic, + "intensive": vdef.intensive, + } + ) + return { + "type_map": model.get_type_map(), + "rcut": model.get_rcut(), + "sel": model.get_sel(), + "model_output_type": model.model_output_type(), + "dim_fparam": model.get_dim_fparam(), + "dim_aparam": model.get_dim_aparam(), + "mixed_types": model.mixed_types(), + "sel_type": model.get_sel_type(), + "fitting_output_defs": fitting_output_defs, + } + + +def serialize_from_file(model_file: str) -> dict: + """Serialize a .pte model file to a dictionary. + + Reads the model dict stored in the extra_files of the .pte archive. + + Parameters + ---------- + model_file : str + The .pte model file to be serialized. + + Returns + ------- + dict + The serialized model data. + """ + extra_files = {"model.json": ""} + torch.export.load(model_file, extra_files=extra_files) + model_dict = json.loads(extra_files["model.json"]) + model_dict = _json_to_numpy(model_dict) + return model_dict + + +def deserialize_to_file(model_file: str, data: dict) -> None: + """Deserialize a dictionary to a .pte model file. + + Builds a pt_expt model from the dict, traces it via make_fx, + exports with dynamic shapes, and saves using torch.export.save. + + Parameters + ---------- + model_file : str + The .pte model file to be saved. + data : dict + The dictionary to be deserialized (same format as dpmodel's + serialize output, with "model" and optionally "model_def_script" keys). + """ + from deepmd.pt_expt.model.model import ( + BaseModel, + ) + + # 1. 
Deserialize into a pt_expt model + model = BaseModel.deserialize(data["model"]) + model.eval() + + # 2. Collect metadata + metadata = _collect_metadata(model) + + # 3. Create sample inputs for tracing + # Use nframes=2 so make_fx doesn't specialize on nframes=1 + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = _make_sample_inputs( + model, nframes=2 + ) + + # 4. Trace via forward_common_lower_exportable (make_fx) + # Uses internal keys (energy, energy_redu, energy_derv_r, etc.) + # so that communicate_extended_output can be applied at inference time. + traced = model.forward_common_lower_exportable( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, + ) + + # 5. Build dynamic shapes and export + dynamic_shapes = _build_dynamic_shapes( + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam + ) + exported = torch.export.export( + traced, + (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam), + dynamic_shapes=dynamic_shapes, + strict=False, + ) + + # 6. Prepare extra files + # Serialize the full model dict for cross-backend conversion + from copy import ( + deepcopy, + ) + + data_for_json = deepcopy(data) + data_for_json = _numpy_to_json_serializable(data_for_json) + + extra_files = { + "model_def_script.json": json.dumps(metadata), + "model.json": json.dumps(data_for_json, separators=(",", ":")), + } + + # 7. 
Save + torch.export.save(exported, model_file, extra_files=extra_files) diff --git a/source/tests/pt_expt/infer/__init__.py b/source/tests/pt_expt/infer/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/infer/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py new file mode 100644 index 0000000000..ef38e1d36f --- /dev/null +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -0,0 +1,490 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for pt_expt inference via the DeepPot / DeepEval interface. + +Verifies the full pipeline: + model.serialize() → deserialize_to_file(.pte) → DeepPot(.pte) → eval() +""" + +import importlib +import tempfile +import unittest + +import numpy as np +import torch + +from deepmd.infer import ( + DeepPot, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + EnergyFittingNet, +) +from deepmd.pt_expt.model import ( + EnergyModel, +) +from deepmd.pt_expt.utils.env import ( + DEVICE, +) +from deepmd.pt_expt.utils.serialization import ( + _make_sample_inputs, + deserialize_to_file, + serialize_from_file, +) + +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDeepEvalEner(unittest.TestCase): + """Test pt_expt inference for energy models.""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + + # Build pt_expt model + ds = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft = EnergyFittingNet( + cls.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ) + cls.model = EnergyModel(ds, ft, type_map=cls.type_map) + cls.model = cls.model.to(torch.float64) + cls.model.eval() + + # Serialize and save to .pte + cls.model_data = {"model": cls.model.serialize()} + cls.tmpfile = 
tempfile.NamedTemporaryFile(suffix=".pte", delete=False) + cls.tmpfile.close() + deserialize_to_file(cls.tmpfile.name, cls.model_data) + + # Create DeepPot for testing + cls.dp = DeepPot(cls.tmpfile.name) + + @classmethod + def tearDownClass(cls) -> None: + import os + + os.unlink(cls.tmpfile.name) + + def test_get_rcut(self) -> None: + self.assertAlmostEqual(self.dp.deep_eval.get_rcut(), self.rcut) + + def test_get_ntypes(self) -> None: + self.assertEqual(self.dp.deep_eval.get_ntypes(), self.nt) + + def test_get_type_map(self) -> None: + self.assertEqual(self.dp.deep_eval.get_type_map(), self.type_map) + + def test_get_dim_fparam(self) -> None: + self.assertEqual(self.dp.deep_eval.get_dim_fparam(), 0) + + def test_get_dim_aparam(self) -> None: + self.assertEqual(self.dp.deep_eval.get_dim_aparam(), 0) + + def test_get_sel_type(self) -> None: + sel_type = self.dp.deep_eval.get_sel_type() + self.assertEqual(sel_type, self.model.get_sel_type()) + + def test_model_type(self) -> None: + self.assertIs(self.dp.deep_eval.model_type, DeepPot) + + def test_get_model(self) -> None: + mod = self.dp.deep_eval.get_model() + self.assertIsInstance(mod, torch.nn.Module) + + def test_get_model_def_script(self) -> None: + mds = self.dp.deep_eval.get_model_def_script() + self.assertIsInstance(mds, dict) + self.assertEqual(mds["type_map"], self.type_map) + self.assertAlmostEqual(mds["rcut"], self.rcut) + self.assertEqual(mds["sel"], list(self.sel)) + + def test_eval_consistency(self) -> None: + """Test that DeepPot.eval gives same results as direct model forward.""" + rng = np.random.default_rng(GLOBAL_SEED) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + # .pte inference + e, f, v, ae, av = self.dp.eval(coords, cells, atom_types, atomic=True) + + # Direct model forward + coord_t = torch.tensor( + coords, dtype=torch.float64, device=DEVICE + 
).requires_grad_(True) + atype_t = torch.tensor( + atom_types.reshape(1, -1), dtype=torch.int64, device=DEVICE + ) + cell_t = torch.tensor(cells, dtype=torch.float64, device=DEVICE) + ref = self.model.forward(coord_t, atype_t, cell_t, do_atomic_virial=True) + + np.testing.assert_allclose( + e, ref["energy"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + f, ref["force"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + v, ref["virial"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + ae, ref["atom_energy"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + av, ref["atom_virial"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + + def test_multiple_frames(self) -> None: + """Test evaluation with multiple frames.""" + rng = np.random.default_rng(GLOBAL_SEED + 7) + natoms = 4 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + for nframes in [2, 5]: + coords = rng.random((nframes, natoms, 3)) * 8.0 + cells = np.tile(np.eye(3).reshape(1, 9) * 10.0, (nframes, 1)) + + e, f, v, ae, av = self.dp.eval(coords, cells, atom_types, atomic=True) + + coord_t = torch.tensor( + coords, dtype=torch.float64, device=DEVICE + ).requires_grad_(True) + atype_t = torch.tensor( + np.tile(atom_types, (nframes, 1)), dtype=torch.int64, device=DEVICE + ) + cell_t = torch.tensor(cells, dtype=torch.float64, device=DEVICE) + ref = self.model.forward(coord_t, atype_t, cell_t, do_atomic_virial=True) + + np.testing.assert_allclose( + e, + ref["energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"nframes={nframes}, energy", + ) + np.testing.assert_allclose( + f, + ref["force"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"nframes={nframes}, force", + ) + np.testing.assert_allclose( + v, + ref["virial"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"nframes={nframes}, virial", + ) + 
np.testing.assert_allclose( + ae, + ref["atom_energy"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"nframes={nframes}, atom_energy", + ) + np.testing.assert_allclose( + av, + ref["atom_virial"].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"nframes={nframes}, atom_virial", + ) + + def test_dynamic_shapes(self) -> None: + """Test that the exported model handles different atom counts. + + Compares exported module output against direct forward_common_lower + for multiple nloc values. + """ + extra_files = {"model_def_script.json": ""} + exported = torch.export.load(self.tmpfile.name, extra_files=extra_files) + exported_mod = exported.module() + + for nloc in [2, 5, 10]: + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = ( + _make_sample_inputs(self.model, nloc=nloc) + ) + + pte_ret = exported_mod( + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam + ) + + ec = ext_coord.detach().requires_grad_(True) + ref_ret = self.model.forward_common_lower( + ec, + ext_atype, + nlist_t, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + ) + + for key in ("energy", "energy_redu", "energy_derv_r", "energy_derv_c"): + if ref_ret[key] is not None and key in pte_ret: + np.testing.assert_allclose( + ref_ret[key].detach().cpu().numpy(), + pte_ret[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"nloc={nloc}, key={key}", + ) + + def test_serialize_round_trip(self) -> None: + """Test .pte → serialize_from_file → deserialize → model gives same outputs.""" + loaded_data = serialize_from_file(self.tmpfile.name) + + model2 = EnergyModel.deserialize(loaded_data["model"]) + model2 = model2.to(torch.float64) + model2.eval() + + for nloc in [3, 7]: + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = ( + _make_sample_inputs(self.model, nloc=nloc) + ) + ec1 = ext_coord.detach().requires_grad_(True) + ec2 = ext_coord.detach().requires_grad_(True) + + ret1 = self.model.forward_common_lower( + ec1, + 
ext_atype, + nlist_t, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + ) + ret2 = model2.forward_common_lower( + ec2, + ext_atype, + nlist_t, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + ) + + for key in ("energy", "energy_redu", "energy_derv_r", "energy_derv_c"): + if ret1[key] is not None: + np.testing.assert_allclose( + ret1[key].detach().cpu().numpy(), + ret2[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"round-trip nloc={nloc}, key={key}", + ) + + def test_no_pbc(self) -> None: + """Test evaluation without periodic boundary conditions.""" + rng = np.random.default_rng(GLOBAL_SEED + 3) + natoms = 3 + coords = rng.random((1, natoms, 3)) * 5.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + e, f, v = self.dp.eval(coords, None, atom_types) + + coord_t = torch.tensor( + coords, dtype=torch.float64, device=DEVICE + ).requires_grad_(True) + atype_t = torch.tensor( + atom_types.reshape(1, -1), dtype=torch.int64, device=DEVICE + ) + ref = self.model.forward(coord_t, atype_t, box=None) + + np.testing.assert_allclose( + e, ref["energy"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + f, ref["force"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + v, ref["virial"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 + ) + + @unittest.skipUnless( + importlib.util.find_spec("ase") is not None, "ase not installed" + ) + def test_ase_neighbor_list_consistency(self) -> None: + """Test that ASE neighbor list gives same results as native nlist.""" + import ase.neighborlist + + rng = np.random.default_rng(GLOBAL_SEED + 11) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + # Eval without ASE neighbor list (native) + e1, f1, v1, ae1, av1 = self.dp.eval( + coords, + cells, + 
atom_types, + atomic=True, + ) + + # Eval with ASE neighbor list + dp_ase = DeepPot( + self.tmpfile.name, + neighbor_list=ase.neighborlist.NewPrimitiveNeighborList( + cutoffs=self.rcut, + bothways=True, + ), + ) + e2, f2, v2, ae2, av2 = dp_ase.eval( + coords, + cells, + atom_types, + atomic=True, + ) + + np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10, err_msg="energy") + np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10, err_msg="force") + np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10, err_msg="virial") + np.testing.assert_allclose( + ae1, + ae2, + rtol=1e-10, + atol=1e-10, + err_msg="atom_energy", + ) + np.testing.assert_allclose( + av1, + av2, + rtol=1e-10, + atol=1e-10, + err_msg="atom_virial", + ) + + @unittest.skipUnless( + importlib.util.find_spec("ase") is not None, "ase not installed" + ) + def test_build_nlist_ase(self) -> None: + """Test _build_nlist_ase produces the same neighbor sets as native.""" + import ase.neighborlist + + from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, + ) + from deepmd.dpmodel.utils.region import ( + normalize_coord, + ) + + rng = np.random.default_rng(GLOBAL_SEED + 13) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + atom_types_2d = atom_types.reshape(1, -1) + + dp_ase = DeepPot( + self.tmpfile.name, + neighbor_list=ase.neighborlist.NewPrimitiveNeighborList( + cutoffs=self.rcut, + bothways=True, + ), + ) + deep_eval = dp_ase.deep_eval + + # ASE path + ext_coord_ase, _ext_atype_ase, nlist_ase, _mapping_ase = ( + deep_eval._build_nlist_ase(coords, cells, atom_types_2d) + ) + + # Native path + box_input = cells.reshape(1, 3, 3) + coord_normalized = normalize_coord(coords, box_input) + ext_coord_nat, ext_atype_nat, _mapping_nat = extend_coord_with_ghosts( + coord_normalized, + atom_types_2d, + cells, + self.rcut, + ) + sel = self.sel 
+ nlist_nat = build_neighbor_list( + ext_coord_nat, + ext_atype_nat, + natoms, + self.rcut, + sel, + distinguish_types=not self.model.mixed_types(), + ) + ext_coord_nat = ext_coord_nat.reshape(1, -1, 3) + + # Compare: for each local atom, the set of neighbor relative + # coordinates should match (ghost ordering may differ). + for ii in range(natoms): + # ASE neighbors + nn_ase = nlist_ase[0, ii] + mask_ase = nn_ase >= 0 + rel_ase = ext_coord_ase[0, nn_ase[mask_ase]] - coords[0, ii] + + # Native neighbors + nn_nat = nlist_nat[0, ii] + mask_nat = nn_nat >= 0 + rel_nat = ext_coord_nat[0, nn_nat[mask_nat]] - coords[0, ii] + + # Sort by distance then by coordinates for deterministic order + def _sort_key(rel: np.ndarray) -> np.ndarray: + dist = np.linalg.norm(rel, axis=-1, keepdims=True) + return np.concatenate([dist, rel], axis=-1) + + order_ase = np.lexsort(_sort_key(rel_ase).T) + order_nat = np.lexsort(_sort_key(rel_nat).T) + + np.testing.assert_allclose( + rel_ase[order_ase], + rel_nat[order_nat], + rtol=1e-10, + atol=1e-10, + err_msg=f"atom {ii}: neighbor relative coords differ", + ) + + @unittest.skipUnless( + importlib.util.find_spec("ase") is not None, "ase not installed" + ) + def test_ase_nlist_multiple_frames(self) -> None: + """Test ASE neighbor list with multiple frames and auto_batch_size=False.""" + import ase.neighborlist + + rng = np.random.default_rng(GLOBAL_SEED + 17) + natoms = 4 + nframes = 3 + coords = rng.random((nframes, natoms, 3)) * 8.0 + cells = np.tile(np.eye(3).reshape(1, 9) * 10.0, (nframes, 1)) + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + # Native eval (no ASE nlist) + e1, f1, v1 = self.dp.eval(coords, cells, atom_types) + + # ASE nlist with auto_batch_size=False to exercise multi-frame path + dp_ase = DeepPot( + self.tmpfile.name, + neighbor_list=ase.neighborlist.NewPrimitiveNeighborList( + cutoffs=self.rcut, + bothways=True, + ), + auto_batch_size=False, + ) + e2, f2, v2 = dp_ase.eval(coords, cells, 
atom_types) + + np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10, err_msg="energy") + np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10, err_msg="force") + np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10, err_msg="virial") + + +if __name__ == "__main__": + unittest.main()