From a31fd18ca716ecf4154ef50a29e6bcd181711fde Mon Sep 17 00:00:00 2001 From: Han Wang Date: Tue, 3 Mar 2026 17:37:34 +0800 Subject: [PATCH 01/15] feat(pt_expt): implement .pte inference pipeline with dynamic shapes Implement the full pt_expt inference pipeline: serialize models to .pte files via torch.export, and load them for inference via DeepPot/DeepEval. Key changes: - Add DeepEval backend for .pte files (deepmd/pt_expt/infer/deep_eval.py) - Add serialize/deserialize hooks (deepmd/pt_expt/utils/serialization.py) - Wire up backend hooks in deepmd/backend/pt_expt.py - Add forward_common_lower_exportable using make_fx + torch.export - Support dynamic nframes, nloc, and nall dimensions - Fix atomic_virial_corr to use explicit loop instead of vmap Add xp_take_first_n helper to avoid torch.export contiguity guards on [:, :nloc] slices. When torch.export traces tensor[:, :nloc] on a tensor of size nall, it records a Ne(nall, nloc) guard from the view's contiguity check, which fails when nall == nloc (no PBC). Using torch.index_select instead creates a new tensor, avoiding the guard. 
--- deepmd/backend/pt_expt.py | 24 +- deepmd/dpmodel/array_api.py | 27 + .../dpmodel/atomic_model/base_atomic_model.py | 5 +- .../dpmodel/atomic_model/dp_atomic_model.py | 3 +- deepmd/dpmodel/utils/env_mat.py | 5 +- deepmd/pt_expt/infer/__init__.py | 1 + deepmd/pt_expt/infer/deep_eval.py | 460 ++++++++++++++++++ deepmd/pt_expt/model/ener_model.py | 48 +- deepmd/pt_expt/model/make_model.py | 69 +++ deepmd/pt_expt/model/transform_output.py | 21 +- deepmd/pt_expt/utils/serialization.py | 293 +++++++++++ source/tests/pt_expt/infer/__init__.py | 1 + source/tests/pt_expt/infer/test_deep_eval.py | 303 ++++++++++++ 13 files changed, 1226 insertions(+), 34 deletions(-) create mode 100644 deepmd/pt_expt/infer/__init__.py create mode 100644 deepmd/pt_expt/infer/deep_eval.py create mode 100644 deepmd/pt_expt/utils/serialization.py create mode 100644 source/tests/pt_expt/infer/__init__.py create mode 100644 source/tests/pt_expt/infer/test_deep_eval.py diff --git a/deepmd/backend/pt_expt.py b/deepmd/backend/pt_expt.py index ade9eb51f3..4b92d7551a 100644 --- a/deepmd/backend/pt_expt.py +++ b/deepmd/backend/pt_expt.py @@ -76,7 +76,11 @@ def deep_eval(self) -> type["DeepEvalBackend"]: type[DeepEvalBackend] The Deep Eval backend of the backend. """ - raise NotImplementedError + from deepmd.pt_expt.infer.deep_eval import ( + DeepEval, + ) + + return DeepEval @property def neighbor_stat(self) -> type["NeighborStat"]: @@ -87,7 +91,11 @@ def neighbor_stat(self) -> type["NeighborStat"]: type[NeighborStat] The neighbor statistics of the backend. """ - raise NotImplementedError + from deepmd.dpmodel.utils.neighbor_stat import ( + NeighborStat, + ) + + return NeighborStat @property def serialize_hook(self) -> Callable[[str], dict]: @@ -98,7 +106,11 @@ def serialize_hook(self) -> Callable[[str], dict]: Callable[[str], dict] The serialize hook of the backend. 
""" - raise NotImplementedError + from deepmd.pt_expt.utils.serialization import ( + serialize_from_file, + ) + + return serialize_from_file @property def deserialize_hook(self) -> Callable[[str, dict], None]: @@ -109,4 +121,8 @@ def deserialize_hook(self) -> Callable[[str, dict], None]: Callable[[str, dict], None] The deserialize hook of the backend. """ - raise NotImplementedError + from deepmd.pt_expt.utils.serialization import ( + deserialize_to_file, + ) + + return deserialize_to_file diff --git a/deepmd/dpmodel/array_api.py b/deepmd/dpmodel/array_api.py index e745b28f94..0fed9813dd 100644 --- a/deepmd/dpmodel/array_api.py +++ b/deepmd/dpmodel/array_api.py @@ -32,6 +32,15 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array: # torch.take_along_dim requires int64 indices if array_api_compat.is_torch_array(indices): indices = xp.astype(indices, xp.int64) + if array_api_compat.is_torch_array(arr): + # Use torch.gather directly for torch.export dynamic shape compatibility. + # array_api_compat's take_along_axis / torch.take_along_dim specializes + # the source dimension size to a constant during torch.export tracing, + # breaking dynamic shape export. torch.gather is the underlying + # primitive and handles symbolic shapes correctly. + import torch + + return torch.gather(arr, axis, indices) if Version(xp.__array_api_version__) >= Version("2024.12"): # see: https://github.com/data-apis/array-api-strict/blob/d086c619a58f35c38240592ef994aa19ca7beebc/array_api_strict/_indexing_functions.py#L30-L39 return xp.take_along_axis(arr, indices, axis=axis) @@ -62,6 +71,24 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array: return xp_swapaxes(out, axis, -1) +def xp_take_first_n(arr: Array, dim: int, n: int) -> Array: + """Take the first *n* elements along *dim*. 
+ + For torch tensors, uses ``torch.index_select`` so that + ``torch.export`` does not emit a contiguity guard that would + prevent the ``nall == nloc`` (no-PBC) case from working. + For numpy / jax, uses regular slicing. + """ + if array_api_compat.is_torch_array(arr): + import torch + + indices = torch.arange(n, dtype=torch.int64, device=arr.device) + return torch.index_select(arr, dim, indices) + slices = [slice(None)] * arr.ndim + slices[dim] = slice(0, n) + return arr[tuple(slices)] + + def xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array: """Reduces all values from the src tensor to the indices specified in the index tensor. diff --git a/deepmd/dpmodel/atomic_model/base_atomic_model.py b/deepmd/dpmodel/atomic_model/base_atomic_model.py index ecfd08b61a..d8e68e5189 100644 --- a/deepmd/dpmodel/atomic_model/base_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/base_atomic_model.py @@ -13,6 +13,7 @@ from deepmd.dpmodel.array_api import ( Array, + xp_take_first_n, ) from deepmd.dpmodel.common import ( NativeOP, @@ -211,7 +212,7 @@ def forward_common_atomic( """ xp = array_api_compat.array_namespace(extended_coord, extended_atype, nlist) _, nloc, _ = nlist.shape - atype = extended_atype[:, :nloc] + atype = xp_take_first_n(extended_atype, 1, nloc) if self.pair_excl is not None: pair_mask = self.pair_excl.build_type_exclude_mask(nlist, extended_atype) # exclude neighbors in the nlist @@ -229,7 +230,7 @@ def forward_common_atomic( ret_dict = self.apply_out_stat(ret_dict, atype) # nf x nloc - atom_mask = ext_atom_mask[:, :nloc] + atom_mask = xp_take_first_n(ext_atom_mask, 1, nloc) if self.atom_excl is not None: atom_mask = xp.logical_and( atom_mask, self.atom_excl.build_type_exclude_mask(atype) diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 73447de955..329b9f69bc 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ 
-8,6 +8,7 @@ from deepmd.dpmodel.array_api import ( Array, + xp_take_first_n, ) from deepmd.dpmodel.descriptor.base_descriptor import ( BaseDescriptor, @@ -178,7 +179,7 @@ def forward_atomic( """ nframes, nloc, nnei = nlist.shape - atype = extended_atype[:, :nloc] + atype = xp_take_first_n(extended_atype, 1, nloc) descriptor, rot_mat, g2, h2, sw = self.descriptor( extended_coord, extended_atype, diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 5af7a9fc3c..69b39b862e 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -11,6 +11,7 @@ from deepmd.dpmodel.array_api import ( Array, xp_take_along_axis, + xp_take_first_n, ) from deepmd.dpmodel.utils.safe_gradient import ( safe_for_vector_norm, @@ -72,7 +73,7 @@ def _make_env_mat( # nf x nloc x nnei x 3 coord_r = xp.reshape(coord_r, (nf, nloc, nnei, 3)) # nf x nloc x 1 x 3 - coord_l = xp.reshape(coord[:, :nloc, ...], (nf, -1, 1, 3)) + coord_l = xp.reshape(xp_take_first_n(coord, 1, nloc), (nf, -1, 1, 3)) # nf x nloc x nnei x 3 diff = coord_r - coord_l # nf x nloc x nnei @@ -149,7 +150,7 @@ def call( xp = array_api_compat.array_namespace(coord_ext, atype_ext, nlist) em, diff, sw = self._call(nlist, coord_ext, radial_only) nf, nloc, nnei = nlist.shape - atype = atype_ext[:, :nloc] + atype = xp_take_first_n(atype_ext, 1, nloc) if davg is not None: em -= xp.reshape(xp.take(davg, xp.reshape(atype, (-1,)), axis=0), em.shape) if dstd is not None: diff --git a/deepmd/pt_expt/infer/__init__.py b/deepmd/pt_expt/infer/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/deepmd/pt_expt/infer/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py new file mode 100644 index 0000000000..18807c3471 --- /dev/null +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -0,0 +1,460 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json +from 
collections.abc import ( + Callable, +) +from typing import ( + TYPE_CHECKING, + Any, + Optional, +) + +import numpy as np +import torch + +from deepmd.dpmodel.model.transform_output import ( + communicate_extended_output, +) +from deepmd.dpmodel.output_def import ( + FittingOutputDef, + ModelOutputDef, + OutputVariableCategory, + OutputVariableDef, +) +from deepmd.dpmodel.utils.batch_size import ( + AutoBatchSize, +) +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) +from deepmd.env import ( + GLOBAL_NP_FLOAT_PRECISION, +) +from deepmd.infer.deep_dipole import ( + DeepDipole, +) +from deepmd.infer.deep_dos import ( + DeepDOS, +) +from deepmd.infer.deep_eval import DeepEval as DeepEvalWrapper +from deepmd.infer.deep_eval import ( + DeepEvalBackend, +) +from deepmd.infer.deep_polar import ( + DeepPolar, +) +from deepmd.infer.deep_pot import ( + DeepPot, +) +from deepmd.infer.deep_wfc import ( + DeepWFC, +) + +if TYPE_CHECKING: + import ase.neighborlist + + +def _reconstruct_model_output_def(metadata: dict) -> ModelOutputDef: + """Reconstruct ModelOutputDef from stored fitting_output_defs metadata.""" + var_defs = [] + for vd in metadata["fitting_output_defs"]: + var_defs.append( + OutputVariableDef( + name=vd["name"], + shape=vd["shape"], + reducible=vd["reducible"], + r_differentiable=vd["r_differentiable"], + c_differentiable=vd["c_differentiable"], + atomic=vd["atomic"], + category=vd["category"], + r_hessian=vd["r_hessian"], + magnetic=vd["magnetic"], + intensive=vd["intensive"], + ) + ) + fitting_output_def = FittingOutputDef(var_defs) + return ModelOutputDef(fitting_output_def) + + +class DeepEval(DeepEvalBackend): + """PyTorch Exportable backend implementation of DeepEval. + + Loads a .pte file containing a torch.export-ed model and evaluates + it using pre-built neighbor lists. 
+ + Parameters + ---------- + model_file : Path + The name of the .pte model file. + output_def : ModelOutputDef + The output definition of the model. + *args : list + Positional arguments. + auto_batch_size : bool or int or AutoBatchSize, default: True + If True, automatic batch size will be used. If int, it will be used + as the initial batch size. + neighbor_list : ase.neighborlist.NewPrimitiveNeighborList, optional + The ASE neighbor list class to produce the neighbor list. If None, the + neighbor list will be built natively in the model. + **kwargs : dict + Keyword arguments. + """ + + def __init__( + self, + model_file: str, + output_def: ModelOutputDef, + *args: Any, + auto_batch_size: bool | int | AutoBatchSize = True, + neighbor_list: Optional["ase.neighborlist.NewPrimitiveNeighborList"] = None, + **kwargs: Any, + ) -> None: + self.output_def = output_def + self.model_path = model_file + + # Load the exported program with metadata + extra_files = {"model_def_script.json": ""} + exported = torch.export.load(model_file, extra_files=extra_files) + self.exported_module = exported.module() + + # Parse metadata + self.metadata = json.loads(extra_files["model_def_script.json"]) + self.rcut = self.metadata["rcut"] + self.type_map = self.metadata["type_map"] + + # Reconstruct the model output def from stored fitting output defs + self._model_output_def = _reconstruct_model_output_def(self.metadata) + + if isinstance(auto_batch_size, bool): + if auto_batch_size: + self.auto_batch_size = AutoBatchSize() + else: + self.auto_batch_size = None + elif isinstance(auto_batch_size, int): + self.auto_batch_size = AutoBatchSize(auto_batch_size) + elif isinstance(auto_batch_size, AutoBatchSize): + self.auto_batch_size = auto_batch_size + else: + raise TypeError("auto_batch_size should be bool, int, or AutoBatchSize") + + def get_rcut(self) -> float: + """Get the cutoff radius of this model.""" + return self.rcut + + def get_ntypes(self) -> int: + """Get the number of atom 
types of this model.""" + return len(self.type_map) + + def get_type_map(self) -> list[str]: + """Get the type map (element name of the atom types) of this model.""" + return self.type_map + + def get_dim_fparam(self) -> int: + """Get the number (dimension) of frame parameters of this DP.""" + return self.metadata["dim_fparam"] + + def get_dim_aparam(self) -> int: + """Get the number (dimension) of atomic parameters of this DP.""" + return self.metadata["dim_aparam"] + + @property + def model_type(self) -> type["DeepEvalWrapper"]: + """The evaluator of the model type.""" + model_output_type = self.metadata["model_output_type"] + if "energy" in model_output_type: + return DeepPot + elif "dos" in model_output_type: + return DeepDOS + elif "dipole" in model_output_type: + return DeepDipole + elif "polar" in model_output_type or "polarizability" in model_output_type: + return DeepPolar + elif "wfc" in model_output_type: + return DeepWFC + else: + raise RuntimeError("Unknown model type") + + def get_sel_type(self) -> list[int]: + """Get the selected atom types of this model. + + Only atoms with selected atom types have atomic contribution + to the result of the model. + If returning an empty list, all atom types are selected. + """ + return self.metadata["sel_type"] + + def get_numb_dos(self) -> int: + """Get the number of DOS.""" + return 0 + + def get_has_efield(self) -> bool: + """Check if the model has efield.""" + return False + + def get_ntypes_spin(self) -> int: + """Get the number of spin atom types of this model.""" + return 0 + + def eval( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + atomic: bool = False, + fparam: np.ndarray | None = None, + aparam: np.ndarray | None = None, + **kwargs: Any, + ) -> dict[str, np.ndarray]: + """Evaluate the energy, force and virial by using this DP. + + Parameters + ---------- + coords + The coordinates of atoms. 
+ The array should be of size nframes x natoms x 3 + cells + The cell of the region. + If None then non-PBC is assumed, otherwise using PBC. + The array should be of size nframes x 9 + atom_types + The atom types + The list should contain natoms ints + atomic + Calculate the atomic energy and virial + fparam + The frame parameter. + The array can be of size : + - nframes x dim_fparam. + - dim_fparam. Then all frames are assumed to be provided with the same fparam. + aparam + The atomic parameter + The array can be of size : + - nframes x natoms x dim_aparam. + - natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam. + - dim_aparam. Then all frames and atoms are provided with the same aparam. + **kwargs + Other parameters + + Returns + ------- + output_dict : dict + The output of the evaluation. The keys are the names of the output + variables, and the values are the corresponding output arrays. + """ + atom_types = np.array(atom_types, dtype=np.int32) + coords = np.array(coords) + if cells is not None: + cells = np.array(cells) + natoms, numb_test = self._get_natoms_and_nframes( + coords, atom_types, len(atom_types.shape) > 1 + ) + request_defs = self._get_request_defs(atomic) + out = self._eval_func(self._eval_model, numb_test, natoms)( + coords, cells, atom_types, fparam, aparam, request_defs + ) + return dict( + zip( + [x.name for x in request_defs], + out, + strict=True, + ) + ) + + def _get_request_defs(self, atomic: bool) -> list[OutputVariableDef]: + """Get the requested output definitions.""" + if atomic: + return list(self.output_def.var_defs.values()) + else: + return [ + x + for x in self.output_def.var_defs.values() + if x.category + in ( + OutputVariableCategory.REDU, + OutputVariableCategory.DERV_R, + OutputVariableCategory.DERV_C_REDU, + ) + ] + + def _eval_func(self, inner_func: Callable, numb_test: int, natoms: int) -> Callable: + """Wrapper method with auto batch size.""" + if self.auto_batch_size is not None: + + 
def eval_func(*args: Any, **kwargs: Any) -> Any: + return self.auto_batch_size.execute_all( + inner_func, numb_test, natoms, *args, **kwargs + ) + + else: + eval_func = inner_func + return eval_func + + def _get_natoms_and_nframes( + self, + coords: np.ndarray, + atom_types: np.ndarray, + mixed_type: bool = False, + ) -> tuple[int, int]: + if mixed_type: + natoms = len(atom_types[0]) + else: + natoms = len(atom_types) + if natoms == 0: + assert coords.size == 0 + else: + coords = np.reshape(np.array(coords), [-1, natoms * 3]) + nframes = coords.shape[0] + return natoms, nframes + + def _eval_model( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + fparam: np.ndarray | None, + aparam: np.ndarray | None, + request_defs: list[OutputVariableDef], + ) -> tuple[np.ndarray, ...]: + nframes = coords.shape[0] + if len(atom_types.shape) == 1: + natoms = len(atom_types) + atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) + else: + natoms = len(atom_types[0]) + + rcut = self.rcut + sel = self.metadata["sel"] + mixed_types = self.metadata["mixed_types"] + + coord_input = coords.reshape(nframes, natoms, 3) + if cells is not None: + box_input = cells.reshape(nframes, 3, 3) + else: + box_input = None + + if box_input is not None: + coord_normalized = normalize_coord(coord_input, box_input) + else: + coord_normalized = coord_input + + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, + atom_types, + cells, + rcut, + ) + nlist = build_neighbor_list( + extended_coord, + extended_atype, + natoms, + rcut, + sel, + distinguish_types=not mixed_types, + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + + # Convert to torch tensors + from deepmd.pt_expt.utils.env import ( + DEVICE, + ) + + ext_coord_t = torch.tensor(extended_coord, dtype=torch.float64, device=DEVICE) + ext_atype_t = torch.tensor(extended_atype, dtype=torch.int64, device=DEVICE) + nlist_t = torch.tensor(nlist, 
dtype=torch.int64, device=DEVICE) + mapping_t = torch.tensor(mapping, dtype=torch.int64, device=DEVICE) + + if fparam is not None: + fparam_t = torch.tensor( + fparam.reshape(nframes, self.get_dim_fparam()), + dtype=torch.float64, + device=DEVICE, + ) + else: + fparam_t = None + + if aparam is not None: + aparam_t = torch.tensor( + aparam.reshape(nframes, natoms, self.get_dim_aparam()), + dtype=torch.float64, + device=DEVICE, + ) + else: + aparam_t = None + + # Call the exported module (forward_common_lower interface, internal keys) + model_ret = self.exported_module( + ext_coord_t, ext_atype_t, nlist_t, mapping_t, fparam_t, aparam_t + ) + + # Apply communicate_extended_output to map extended atoms → local atoms + do_atomic_virial = any( + x.category == OutputVariableCategory.DERV_C for x in request_defs + ) + model_predict = communicate_extended_output( + model_ret, + self._model_output_def, + mapping_t, + do_atomic_virial=do_atomic_virial, + ) + + # Translate internal keys to backend names and collect results + results = [] + for odef in request_defs: + # odef.name is the internal key (e.g. "energy_derv_r") + # _OUTDEF_DP2BACKEND maps it to backend name (e.g. 
"force") + # but model_predict uses internal keys from communicate_extended_output + if odef.name in model_predict: + shape = self._get_output_shape(odef, nframes, natoms) + if model_predict[odef.name] is not None: + out = model_predict[odef.name].detach().numpy().reshape(shape) + else: + out = np.full(shape, np.nan, dtype=GLOBAL_NP_FLOAT_PRECISION) + results.append(out) + else: + shape = self._get_output_shape(odef, nframes, natoms) + results.append( + np.full(np.abs(shape), np.nan, dtype=GLOBAL_NP_FLOAT_PRECISION) + ) + return tuple(results) + + def _get_output_shape( + self, odef: OutputVariableDef, nframes: int, natoms: int + ) -> list[int]: + if odef.category == OutputVariableCategory.DERV_C_REDU: + # virial + return [nframes, *odef.shape[:-1], 9] + elif odef.category == OutputVariableCategory.REDU: + # energy + return [nframes, *odef.shape, 1] + elif odef.category == OutputVariableCategory.DERV_C: + # atom_virial + return [nframes, *odef.shape[:-1], natoms, 9] + elif odef.category == OutputVariableCategory.DERV_R: + # force + return [nframes, *odef.shape[:-1], natoms, 3] + elif odef.category == OutputVariableCategory.OUT: + # atom_energy, atom_tensor + return [nframes, natoms, *odef.shape, 1] + elif odef.category == OutputVariableCategory.DERV_R_DERV_R: + # hessian + return [nframes, 3 * natoms, 3 * natoms] + else: + raise RuntimeError("unknown category") + + def get_model_def_script(self) -> dict: + """Get model definition script.""" + return self.metadata + + def get_model(self) -> torch.nn.Module: + """Get the exported model module. + + Returns + ------- + torch.nn.Module + The exported model module. 
+ """ + return self.exported_module diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py index 271028d2ff..7a15c2d91d 100644 --- a/deepmd/pt_expt/model/ener_model.py +++ b/deepmd/pt_expt/model/ener_model.py @@ -129,9 +129,9 @@ def forward_lower_exportable( ) -> torch.nn.Module: """Trace ``forward_lower`` into an exportable module. - Uses ``make_fx`` to trace through ``torch.autograd.grad``, - decomposing the backward pass into primitive ops. The returned - module can be passed directly to ``torch.export.export``. + Delegates to ``forward_common_lower_exportable`` for tracing, + then translates the internal keys to the ``forward_lower`` + convention. Parameters ---------- @@ -145,7 +145,20 @@ def forward_lower_exportable( ``(extended_coord, extended_atype, nlist, mapping, fparam, aparam)`` and returns a dict with the same keys as ``forward_lower``. """ - model = self + traced = self.forward_common_lower_exportable( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + # Translate internal keys to forward_lower convention. + # Capture model config at trace time via closures. 
+ do_grad_r = self.do_grad_r("energy") + do_grad_c = self.do_grad_c("energy") def fn( extended_coord: torch.Tensor, @@ -155,17 +168,24 @@ def fn( fparam: torch.Tensor | None, aparam: torch.Tensor | None, ) -> dict[str, torch.Tensor]: - extended_coord = extended_coord.detach().requires_grad_(True) - return model.forward_lower( - extended_coord, - extended_atype, - nlist, - mapping, - fparam=fparam, - aparam=aparam, - do_atomic_virial=do_atomic_virial, + model_ret = traced( + extended_coord, extended_atype, nlist, mapping, fparam, aparam ) + model_predict: dict[str, torch.Tensor] = {} + model_predict["atom_energy"] = model_ret["energy"] + model_predict["energy"] = model_ret["energy_redu"] + if do_grad_r: + model_predict["extended_force"] = model_ret["energy_derv_r"].squeeze(-2) + if do_grad_c: + model_predict["virial"] = model_ret["energy_derv_c_redu"].squeeze(-2) + if do_atomic_virial: + model_predict["extended_virial"] = model_ret[ + "energy_derv_c" + ].squeeze(-2) + if "mask" in model_ret: + model_predict["mask"] = model_ret["mask"] + return model_predict - return make_fx(fn)( + return make_fx(fn, tracing_mode="symbolic", _allow_non_fake_inputs=True)( extended_coord, extended_atype, nlist, mapping, fparam, aparam ) diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py index 56cabafe81..f569b0021e 100644 --- a/deepmd/pt_expt/model/make_model.py +++ b/deepmd/pt_expt/model/make_model.py @@ -4,6 +4,9 @@ ) import torch +from torch.fx.experimental.proxy_tensor import ( + make_fx, +) from deepmd.dpmodel.atomic_model.base_atomic_model import ( BaseAtomicModel, @@ -90,4 +93,70 @@ def forward_common_atomic( mask=atomic_ret.get("mask"), ) + def forward_common_lower_exportable( + self, + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None = None, + fparam: torch.Tensor | None = None, + aparam: torch.Tensor | None = None, + do_atomic_virial: bool = False, + ) -> torch.nn.Module: + 
"""Trace ``forward_common_lower`` into an exportable module. + + Uses ``make_fx`` with symbolic tracing to trace through + ``torch.autograd.grad``, decomposing the backward pass into + primitive ops while preserving dynamic shapes. The returned + module can be passed directly to ``torch.export.export`` with + dynamic shape specifications. + + The output uses internal key names (e.g. ``energy``, + ``energy_redu``, ``energy_derv_r``) so that + ``communicate_extended_output`` can be applied at inference + time. + + Parameters + ---------- + extended_coord, extended_atype, nlist, mapping, fparam, aparam, do_atomic_virial + Sample inputs with representative shapes (used for tracing). + + Returns + ------- + torch.nn.Module + A traced module whose ``forward`` accepts + ``(extended_coord, extended_atype, nlist, mapping, + fparam, aparam)`` and returns a dict with the same keys + as ``call_common_lower``. + """ + model = self + + def fn( + extended_coord: torch.Tensor, + extended_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor | None, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, + ) -> dict[str, torch.Tensor]: + extended_coord = extended_coord.detach().requires_grad_(True) + return model.forward_common_lower( + extended_coord, + extended_atype, + nlist, + mapping, + fparam=fparam, + aparam=aparam, + do_atomic_virial=do_atomic_virial, + ) + + return make_fx(fn, tracing_mode="symbolic", _allow_non_fake_inputs=True)( + extended_coord, + extended_atype, + nlist, + mapping, + fparam, + aparam, + ) + return CM diff --git a/deepmd/pt_expt/model/transform_output.py b/deepmd/pt_expt/model/transform_output.py index 5fb1ac4e46..265a5a268d 100644 --- a/deepmd/pt_expt/model/transform_output.py +++ b/deepmd/pt_expt/model/transform_output.py @@ -17,20 +17,20 @@ def atomic_virial_corr( extended_coord: torch.Tensor, atom_energy: torch.Tensor, ) -> torch.Tensor: - nall = extended_coord.shape[1] - nf = extended_coord.shape[0] nloc = atom_energy.shape[1] - 
coord, _ = torch.split(extended_coord, [nloc, nall - nloc], dim=1) + indices = torch.arange(nloc, dtype=torch.int64, device=extended_coord.device) + coord = torch.index_select(extended_coord, 1, indices) # no derivative with respect to the loc coord. coord = coord.detach() ce = coord * atom_energy sumce = torch.sum(ce, dim=1) # [nf, 3] - # Use vmap to batch the 3 backward passes (one per spatial component) - basis = torch.eye(3, dtype=sumce.dtype, device=sumce.device) # [3, 3] - basis = basis.unsqueeze(1).expand(3, nf, 3) # [3, nf, 3] - - def grad_fn(grad_output: torch.Tensor) -> torch.Tensor: + # Explicitly loop over the 3 spatial components instead of vmap, + # so that make_fx(symbolic) and torch.export can trace through. + results = [] + for i in range(3): + grad_output = torch.zeros_like(sumce) + grad_output[:, i] = 1.0 result = torch.autograd.grad( [sumce], [extended_coord], @@ -39,11 +39,10 @@ def grad_fn(grad_output: torch.Tensor) -> torch.Tensor: retain_graph=True, )[0] assert result is not None - return result + results.append(result) - # [3, nf, nall, 3] — batched over the 3 spatial components - extended_virial_corr = torch.vmap(grad_fn)(basis) # [3, nf, nall, 3] -> [nf, nall, 3, 3] + extended_virial_corr = torch.stack(results, dim=0) return extended_virial_corr.permute(1, 2, 3, 0) diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py new file mode 100644 index 0000000000..313d0666f6 --- /dev/null +++ b/deepmd/pt_expt/utils/serialization.py @@ -0,0 +1,293 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import json + +import numpy as np +import torch + +from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, +) +from deepmd.dpmodel.utils.region import ( + normalize_coord, +) +from deepmd.dpmodel.utils.serialization import ( + traverse_model_dict, +) + + +def _numpy_to_json_serializable(model_obj: dict) -> dict: + """Convert numpy arrays in a model dict to JSON-serializable lists.""" 
+ return traverse_model_dict( + model_obj, + lambda x: ( + { + "@class": "np.ndarray", + "@is_variable": True, + "dtype": x.dtype.name, + "value": x.tolist(), + } + if isinstance(x, np.ndarray) + else x + ), + ) + + +def _json_to_numpy(model_obj: dict) -> dict: + """Convert JSON-serialized numpy arrays back to np.ndarray.""" + return traverse_model_dict( + model_obj, + lambda x: ( + np.asarray(x["value"], dtype=np.dtype(x["dtype"])) + if isinstance(x, dict) and x.get("@class") == "np.ndarray" + else x + ), + ) + + +def _make_sample_inputs( + model: torch.nn.Module, + nframes: int = 1, + nloc: int = 2, +) -> tuple[torch.Tensor, ...]: + """Create sample inputs for tracing forward_lower. + + Parameters + ---------- + model : torch.nn.Module + The pt_expt model (must have get_rcut, get_sel, get_type_map, etc.). + nframes : int + Number of frames. + nloc : int + Number of local atoms. + + Returns + ------- + tuple + (ext_coord, ext_atype, nlist, mapping, fparam, aparam) + """ + rcut = model.get_rcut() + sel = model.get_sel() + ntypes = len(model.get_type_map()) + dim_fparam = model.get_dim_fparam() + dim_aparam = model.get_dim_aparam() + mixed_types = model.mixed_types() + + # Create a simple box large enough to avoid PBC issues + box_size = rcut * 3.0 + box = np.eye(3, dtype=np.float64) * box_size + box_np = box.reshape(1, 9) + + # Random coords inside the box + rng = np.random.default_rng(42) + coord_np = rng.random((nframes, nloc, 3), dtype=np.float64) * box_size * 0.5 + coord_np += box_size * 0.25 # center in box + + # Assign atom types: distribute across types + atype_np = np.zeros((nframes, nloc), dtype=np.int32) + for i in range(nloc): + atype_np[:, i] = i % ntypes + + # Normalize and extend + coord_normalized = normalize_coord( + coord_np.reshape(nframes, nloc, 3), + np.tile(box.reshape(1, 3, 3), (nframes, 1, 1)), + ) + extended_coord, extended_atype, mapping = extend_coord_with_ghosts( + coord_normalized, + atype_np, + np.tile(box_np, (nframes, 1)), + rcut, + ) 
+ nlist = build_neighbor_list( + extended_coord, + extended_atype, + nloc, + rcut, + sel, + distinguish_types=not mixed_types, + ) + extended_coord = extended_coord.reshape(nframes, -1, 3) + + # Convert to torch tensors + from deepmd.pt_expt.utils.env import ( + DEVICE, + ) + + ext_coord = torch.tensor(extended_coord, dtype=torch.float64, device=DEVICE) + ext_atype = torch.tensor(extended_atype, dtype=torch.int64, device=DEVICE) + nlist_t = torch.tensor(nlist, dtype=torch.int64, device=DEVICE) + mapping_t = torch.tensor(mapping, dtype=torch.int64, device=DEVICE) + + if dim_fparam > 0: + fparam = torch.zeros(nframes, dim_fparam, dtype=torch.float64, device=DEVICE) + else: + fparam = None + + if dim_aparam > 0: + aparam = torch.zeros( + nframes, nloc, dim_aparam, dtype=torch.float64, device=DEVICE + ) + else: + aparam = None + + return ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam + + +def _build_dynamic_shapes( + ext_coord: torch.Tensor, + ext_atype: torch.Tensor, + nlist: torch.Tensor, + mapping: torch.Tensor, + fparam: torch.Tensor | None, + aparam: torch.Tensor | None, +) -> tuple: + """Build dynamic shape specifications for torch.export. + + Marks nframes, nloc and nall as dynamic dimensions so the exported + program handles arbitrary frame and atom counts. + + Returns a tuple (not dict) to match positional args of the make_fx + traced module, whose arg names may have suffixes like ``_1``. 
+ """ + nframes_dim = torch.export.Dim("nframes", min=1) + nall_dim = torch.export.Dim("nall", min=1) + nloc_dim = torch.export.Dim("nloc", min=1) + + return ( + {0: nframes_dim, 1: nall_dim}, # extended_coord: (nframes, nall, 3) + {0: nframes_dim, 1: nall_dim}, # extended_atype: (nframes, nall) + {0: nframes_dim, 1: nloc_dim}, # nlist: (nframes, nloc, nnei) + {0: nframes_dim, 1: nall_dim}, # mapping: (nframes, nall) + {0: nframes_dim} if fparam is not None else None, # fparam + {0: nframes_dim, 1: nloc_dim} if aparam is not None else None, # aparam + ) + + +def _collect_metadata(model: torch.nn.Module) -> dict: + """Collect metadata from the model for storage in .pte extra_files.""" + # Serialize the fitting output definitions so that ModelOutputDef + # can be reconstructed at inference time without loading the full model. + fitting_output_def = model.atomic_output_def() + fitting_output_defs = [] + for vdef in fitting_output_def.get_data().values(): + fitting_output_defs.append( + { + "name": vdef.name, + "shape": list(vdef.shape), + "reducible": vdef.reducible, + "r_differentiable": vdef.r_differentiable, + "c_differentiable": vdef.c_differentiable, + "atomic": vdef.atomic, + "category": vdef.category, + "r_hessian": vdef.r_hessian, + "magnetic": vdef.magnetic, + "intensive": vdef.intensive, + } + ) + return { + "type_map": model.get_type_map(), + "rcut": model.get_rcut(), + "sel": model.get_sel(), + "model_output_type": model.model_output_type(), + "dim_fparam": model.get_dim_fparam(), + "dim_aparam": model.get_dim_aparam(), + "mixed_types": model.mixed_types(), + "sel_type": model.get_sel_type(), + "fitting_output_defs": fitting_output_defs, + } + + +def serialize_from_file(model_file: str) -> dict: + """Serialize a .pte model file to a dictionary. + + Reads the model dict stored in the extra_files of the .pte archive. + + Parameters + ---------- + model_file : str + The .pte model file to be serialized. 
+ + Returns + ------- + dict + The serialized model data. + """ + extra_files = {"model.json": ""} + torch.export.load(model_file, extra_files=extra_files) + model_dict = json.loads(extra_files["model.json"]) + model_dict = _json_to_numpy(model_dict) + return model_dict + + +def deserialize_to_file(model_file: str, data: dict) -> None: + """Deserialize a dictionary to a .pte model file. + + Builds a pt_expt model from the dict, traces it via make_fx, + exports with dynamic shapes, and saves using torch.export.save. + + Parameters + ---------- + model_file : str + The .pte model file to be saved. + data : dict + The dictionary to be deserialized (same format as dpmodel's + serialize output, with "model" and optionally "model_def_script" keys). + """ + from deepmd.pt_expt.model.model import ( + BaseModel, + ) + + # 1. Deserialize into a pt_expt model + model = BaseModel.deserialize(data["model"]) + model.eval() + + # 2. Collect metadata + metadata = _collect_metadata(model) + + # 3. Create sample inputs for tracing + # Use nframes=2 so make_fx doesn't specialize on nframes=1 + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = _make_sample_inputs( + model, nframes=2 + ) + + # 4. Trace via forward_common_lower_exportable (make_fx) + # Uses internal keys (energy, energy_redu, energy_derv_r, etc.) + # so that communicate_extended_output can be applied at inference time. + traced = model.forward_common_lower_exportable( + ext_coord, + ext_atype, + nlist_t, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + ) + + # 5. Build dynamic shapes and export + dynamic_shapes = _build_dynamic_shapes( + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam + ) + exported = torch.export.export( + traced, + (ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam), + dynamic_shapes=dynamic_shapes, + strict=False, + ) + + # 6. 
Prepare extra files + # Serialize the full model dict for cross-backend conversion + from copy import ( + deepcopy, + ) + + data_for_json = deepcopy(data) + data_for_json = _numpy_to_json_serializable(data_for_json) + + extra_files = { + "model_def_script.json": json.dumps(metadata), + "model.json": json.dumps(data_for_json, separators=(",", ":")), + } + + # 7. Save + torch.export.save(exported, model_file, extra_files=extra_files) diff --git a/source/tests/pt_expt/infer/__init__.py b/source/tests/pt_expt/infer/__init__.py new file mode 100644 index 0000000000..6ceb116d85 --- /dev/null +++ b/source/tests/pt_expt/infer/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py new file mode 100644 index 0000000000..54e05a93b4 --- /dev/null +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -0,0 +1,303 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Tests for pt_expt inference via the DeepPot / DeepEval interface. 
+ +Verifies the full pipeline: + model.serialize() → deserialize_to_file(.pte) → DeepPot(.pte) → eval() +""" + +import tempfile +import unittest + +import numpy as np +import torch + +from deepmd.infer import ( + DeepPot, +) +from deepmd.pt_expt.descriptor.se_e2_a import ( + DescrptSeA, +) +from deepmd.pt_expt.fitting import ( + EnergyFittingNet, +) +from deepmd.pt_expt.model import ( + EnergyModel, +) +from deepmd.pt_expt.utils.serialization import ( + _make_sample_inputs, + deserialize_to_file, + serialize_from_file, +) + +from ...seed import ( + GLOBAL_SEED, +) + + +class TestDeepEvalEner(unittest.TestCase): + """Test pt_expt inference for energy models.""" + + @classmethod + def setUpClass(cls) -> None: + cls.rcut = 4.0 + cls.rcut_smth = 0.5 + cls.sel = [8, 6] + cls.nt = 2 + cls.type_map = ["foo", "bar"] + + # Build pt_expt model + ds = DescrptSeA(cls.rcut, cls.rcut_smth, cls.sel) + ft = EnergyFittingNet( + cls.nt, + ds.get_dim_out(), + mixed_types=ds.mixed_types(), + seed=GLOBAL_SEED, + ) + cls.model = EnergyModel(ds, ft, type_map=cls.type_map) + cls.model = cls.model.to(torch.float64) + cls.model.eval() + + # Serialize and save to .pte + cls.model_data = {"model": cls.model.serialize()} + cls.tmpfile = tempfile.NamedTemporaryFile(suffix=".pte", delete=False) + cls.tmpfile.close() + deserialize_to_file(cls.tmpfile.name, cls.model_data) + + # Create DeepPot for testing + cls.dp = DeepPot(cls.tmpfile.name) + + @classmethod + def tearDownClass(cls) -> None: + import os + + os.unlink(cls.tmpfile.name) + + def test_get_rcut(self) -> None: + self.assertAlmostEqual(self.dp.deep_eval.get_rcut(), self.rcut) + + def test_get_ntypes(self) -> None: + self.assertEqual(self.dp.deep_eval.get_ntypes(), self.nt) + + def test_get_type_map(self) -> None: + self.assertEqual(self.dp.deep_eval.get_type_map(), self.type_map) + + def test_get_dim_fparam(self) -> None: + self.assertEqual(self.dp.deep_eval.get_dim_fparam(), 0) + + def test_get_dim_aparam(self) -> None: + 
self.assertEqual(self.dp.deep_eval.get_dim_aparam(), 0) + + def test_get_sel_type(self) -> None: + sel_type = self.dp.deep_eval.get_sel_type() + self.assertEqual(sel_type, self.model.get_sel_type()) + + def test_model_type(self) -> None: + self.assertIs(self.dp.deep_eval.model_type, DeepPot) + + def test_get_model(self) -> None: + mod = self.dp.deep_eval.get_model() + self.assertIsInstance(mod, torch.nn.Module) + + def test_get_model_def_script(self) -> None: + mds = self.dp.deep_eval.get_model_def_script() + self.assertIsInstance(mds, dict) + self.assertEqual(mds["type_map"], self.type_map) + self.assertAlmostEqual(mds["rcut"], self.rcut) + self.assertEqual(mds["sel"], list(self.sel)) + + def test_eval_consistency(self) -> None: + """Test that DeepPot.eval gives same results as direct model forward.""" + rng = np.random.default_rng(GLOBAL_SEED) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 15.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + # .pte inference + e, f, v, ae, av = self.dp.eval(coords, cells, atom_types, atomic=True) + + # Direct model forward + coord_t = torch.tensor(coords, dtype=torch.float64).requires_grad_(True) + atype_t = torch.tensor(atom_types.reshape(1, -1), dtype=torch.int64) + cell_t = torch.tensor(cells, dtype=torch.float64) + ref = self.model.forward(coord_t, atype_t, cell_t, do_atomic_virial=True) + + np.testing.assert_allclose( + e, ref["energy"].detach().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + f, ref["force"].detach().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + v, ref["virial"].detach().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + ae, ref["atom_energy"].detach().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + av, ref["atom_virial"].detach().numpy(), rtol=1e-10, atol=1e-10 + ) + + def test_multiple_frames(self) -> None: + """Test evaluation with multiple 
frames.""" + rng = np.random.default_rng(GLOBAL_SEED + 7) + natoms = 4 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + for nframes in [2, 5]: + coords = rng.random((nframes, natoms, 3)) * 8.0 + cells = np.tile(np.eye(3).reshape(1, 9) * 15.0, (nframes, 1)) + + e, f, v, ae, av = self.dp.eval(coords, cells, atom_types, atomic=True) + + coord_t = torch.tensor(coords, dtype=torch.float64).requires_grad_(True) + atype_t = torch.tensor(np.tile(atom_types, (nframes, 1)), dtype=torch.int64) + cell_t = torch.tensor(cells, dtype=torch.float64) + ref = self.model.forward(coord_t, atype_t, cell_t, do_atomic_virial=True) + + np.testing.assert_allclose( + e, + ref["energy"].detach().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"nframes={nframes}, energy", + ) + np.testing.assert_allclose( + f, + ref["force"].detach().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"nframes={nframes}, force", + ) + np.testing.assert_allclose( + v, + ref["virial"].detach().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"nframes={nframes}, virial", + ) + np.testing.assert_allclose( + ae, + ref["atom_energy"].detach().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"nframes={nframes}, atom_energy", + ) + np.testing.assert_allclose( + av, + ref["atom_virial"].detach().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"nframes={nframes}, atom_virial", + ) + + def test_dynamic_shapes(self) -> None: + """Test that the exported model handles different atom counts. + + Compares exported module output against direct forward_common_lower + for multiple nloc values. 
+ """ + extra_files = {"model_def_script.json": ""} + exported = torch.export.load(self.tmpfile.name, extra_files=extra_files) + exported_mod = exported.module() + + for nloc in [2, 5, 10]: + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = ( + _make_sample_inputs(self.model, nloc=nloc) + ) + + pte_ret = exported_mod( + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam + ) + + ec = ext_coord.detach().requires_grad_(True) + ref_ret = self.model.forward_common_lower( + ec, + ext_atype, + nlist_t, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + ) + + for key in ("energy", "energy_redu", "energy_derv_r", "energy_derv_c"): + if ref_ret[key] is not None and key in pte_ret: + np.testing.assert_allclose( + ref_ret[key].detach().cpu().numpy(), + pte_ret[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"nloc={nloc}, key={key}", + ) + + def test_serialize_round_trip(self) -> None: + """Test .pte → serialize_from_file → deserialize → model gives same outputs.""" + loaded_data = serialize_from_file(self.tmpfile.name) + + model2 = EnergyModel.deserialize(loaded_data["model"]) + model2 = model2.to(torch.float64) + model2.eval() + + for nloc in [3, 7]: + ext_coord, ext_atype, nlist_t, mapping_t, fparam, aparam = ( + _make_sample_inputs(self.model, nloc=nloc) + ) + ec1 = ext_coord.detach().requires_grad_(True) + ec2 = ext_coord.detach().requires_grad_(True) + + ret1 = self.model.forward_common_lower( + ec1, + ext_atype, + nlist_t, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + ) + ret2 = model2.forward_common_lower( + ec2, + ext_atype, + nlist_t, + mapping_t, + fparam=fparam, + aparam=aparam, + do_atomic_virial=True, + ) + + for key in ("energy", "energy_redu", "energy_derv_r", "energy_derv_c"): + if ret1[key] is not None: + np.testing.assert_allclose( + ret1[key].detach().cpu().numpy(), + ret2[key].detach().cpu().numpy(), + rtol=1e-10, + atol=1e-10, + err_msg=f"round-trip nloc={nloc}, 
key={key}", + ) + + def test_no_pbc(self) -> None: + """Test evaluation without periodic boundary conditions.""" + rng = np.random.default_rng(GLOBAL_SEED + 3) + natoms = 3 + coords = rng.random((1, natoms, 3)) * 5.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + e, f, v = self.dp.eval(coords, None, atom_types) + + coord_t = torch.tensor(coords, dtype=torch.float64).requires_grad_(True) + atype_t = torch.tensor(atom_types.reshape(1, -1), dtype=torch.int64) + ref = self.model.forward(coord_t, atype_t, box=None) + + np.testing.assert_allclose( + e, ref["energy"].detach().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + f, ref["force"].detach().numpy(), rtol=1e-10, atol=1e-10 + ) + np.testing.assert_allclose( + v, ref["virial"].detach().numpy(), rtol=1e-10, atol=1e-10 + ) + + +if __name__ == "__main__": + unittest.main() From ebf90e353f99cf92ddc07a657c73e7d8e5dd5d63 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 4 Mar 2026 15:22:12 +0800 Subject: [PATCH 02/15] support ase nlist --- deepmd/pt_expt/infer/deep_eval.py | 246 +++++++++++++++++-- source/tests/pt_expt/infer/test_deep_eval.py | 176 ++++++++++++- 2 files changed, 402 insertions(+), 20 deletions(-) diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index 18807c3471..628c9fed8e 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -27,6 +27,7 @@ from deepmd.dpmodel.utils.nlist import ( build_neighbor_list, extend_coord_with_ghosts, + nlist_distinguish_types, ) from deepmd.dpmodel.utils.region import ( normalize_coord, @@ -115,6 +116,7 @@ def __init__( ) -> None: self.output_def = output_def self.model_path = model_file + self.neighbor_list = neighbor_list # Load the exported program with metadata extra_files = {"model_def_script.json": ""} @@ -310,36 +312,38 @@ def _get_natoms_and_nframes( nframes = coords.shape[0] return natoms, nframes - def _eval_model( + def _build_nlist_native( self, 
coords: np.ndarray, cells: np.ndarray | None, atom_types: np.ndarray, - fparam: np.ndarray | None, - aparam: np.ndarray | None, - request_defs: list[OutputVariableDef], - ) -> tuple[np.ndarray, ...]: - nframes = coords.shape[0] - if len(atom_types.shape) == 1: - natoms = len(atom_types) - atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) - else: - natoms = len(atom_types[0]) + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Build extended coords, atype, nlist, mapping using native nlist. + Parameters + ---------- + coords : np.ndarray + Coordinates, shape (nframes, natoms, 3). + cells : np.ndarray or None + Cell vectors, shape (nframes, 9). None for non-PBC. + atom_types : np.ndarray + Atom types, shape (nframes, natoms). + + Returns + ------- + extended_coord, extended_atype, nlist, mapping + """ + nframes = coords.shape[0] + natoms = coords.shape[1] rcut = self.rcut sel = self.metadata["sel"] mixed_types = self.metadata["mixed_types"] - coord_input = coords.reshape(nframes, natoms, 3) if cells is not None: box_input = cells.reshape(nframes, 3, 3) + coord_normalized = normalize_coord(coords, box_input) else: - box_input = None - - if box_input is not None: - coord_normalized = normalize_coord(coord_input, box_input) - else: - coord_normalized = coord_input + coord_normalized = coords extended_coord, extended_atype, mapping = extend_coord_with_ghosts( coord_normalized, @@ -356,6 +360,212 @@ def _eval_model( distinguish_types=not mixed_types, ) extended_coord = extended_coord.reshape(nframes, -1, 3) + return extended_coord, extended_atype, nlist, mapping + + def _build_nlist_ase( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Build extended coords, atype, nlist, mapping using ASE neighbor list. + + Handles multiple frames by building per frame and padding to + a common nall. 
+ + Parameters + ---------- + coords : np.ndarray + Coordinates, shape (nframes, natoms, 3). + cells : np.ndarray or None + Cell vectors, shape (nframes, 9). None for non-PBC. + atom_types : np.ndarray + Atom types, shape (nframes, natoms). + + Returns + ------- + extended_coord, extended_atype, nlist, mapping + """ + nframes = coords.shape[0] + frame_results = [] + for ff in range(nframes): + ec, ea, nl, mp = self._build_nlist_ase_single( + coords[ff], + cells[ff] if cells is not None else None, + atom_types[ff], + ) + frame_results.append((ec, ea, nl, mp)) + # Pad to max nall across frames + max_nall = max(ec.shape[0] for ec, _, _, _ in frame_results) + ext_coords, ext_atypes, nlists, mappings = [], [], [], [] + for ec, ea, nl, mp in frame_results: + pad = max_nall - ec.shape[0] + if pad > 0: + ec = np.concatenate( + [ec, np.zeros((pad, 3), dtype=ec.dtype)], + axis=0, + ) + ea = np.concatenate( + [ea, np.full(pad, -1, dtype=ea.dtype)], + axis=0, + ) + mp = np.concatenate( + [mp, np.zeros(pad, dtype=mp.dtype)], + axis=0, + ) + ext_coords.append(ec) + ext_atypes.append(ea) + nlists.append(nl) + mappings.append(mp) + return ( + np.stack(ext_coords, axis=0), + np.stack(ext_atypes, axis=0), + np.stack(nlists, axis=0), + np.stack(mappings, axis=0), + ) + + def _build_nlist_ase_single( + self, + positions: np.ndarray, + cell: np.ndarray | None, + atype: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """Build extended coords, atype, nlist, mapping for a single frame. + + Parameters + ---------- + positions : np.ndarray + Atom positions, shape (natoms, 3). + cell : np.ndarray or None + Cell vector, shape (9,). None for non-PBC. + atype : np.ndarray + Atom types, shape (natoms,). 
+ + Returns + ------- + extended_coord : np.ndarray, shape (nall, 3) + extended_atype : np.ndarray, shape (nall,) + nlist : np.ndarray, shape (nloc, nsel) + mapping : np.ndarray, shape (nall,) + """ + sel = self.metadata["sel"] + mixed_types = self.metadata["mixed_types"] + nsel = sum(sel) + + natoms = positions.shape[0] + cell_3x3 = ( + cell.reshape(3, 3) + if cell is not None + else np.zeros((3, 3), dtype=np.float64) + ) + pbc = np.repeat(cell is not None, 3) + + nl = self.neighbor_list + nl.bothways = True + nl.self_interaction = False + if nl.update(pbc, cell_3x3, positions): + nl.build(pbc, cell_3x3, positions) + + first_neigh = nl.first_neigh.copy() + pair_second = nl.pair_second.copy() + offset_vec = nl.offset_vec.copy() + + # Identify ghost atoms (out-of-box neighbors) + out_mask = np.any(offset_vec != 0, axis=1) + out_idx = pair_second[out_mask] + out_offset = offset_vec[out_mask] + out_coords = positions[out_idx] + out_offset.dot(cell_3x3) + out_atype = atype[out_idx] + + nloc = natoms + nghost = out_idx.size + + # Extended arrays (no leading frame dimension) + extended_coord = np.concatenate((positions, out_coords), axis=0) + extended_atype = np.concatenate((atype, out_atype)) + mapping = np.concatenate( + (np.arange(nloc, dtype=np.int32), out_idx.astype(np.int32)) + ) + + # Remap neighbor indices: ghost atoms get new indices [nloc, nloc+nghost) + ghost_remap = pair_second.copy() + ghost_remap[out_mask] = np.arange(nloc, nloc + nghost, dtype=np.int64) + + # Build nlist: vectorized CSR-to-dense conversion + rcut = self.rcut + counts = np.diff(first_neigh) + max_nn = int(counts.max()) if counts.size > 0 else 0 + + # CSR to dense: (nloc, max_nn) neighbor index array, padded with -1 + col_idx = np.arange(len(ghost_remap), dtype=np.int64) - np.repeat( + first_neigh[:-1], counts + ) + row_idx = np.repeat(np.arange(nloc, dtype=np.int64), counts) + dense_idx = np.full((nloc, max_nn), -1, dtype=np.int64) + dense_idx[row_idx, col_idx] = ghost_remap + + # Compute 
all distances at once + valid = dense_idx >= 0 + lookup = np.where(valid, dense_idx, 0) + neigh_coords = extended_coord[lookup] # (nloc, max_nn, 3) + dists = np.linalg.norm( + neigh_coords - positions[:, None, :], axis=-1 + ) # (nloc, max_nn) + + # Mask invalid and out-of-range, sort by distance + valid &= dists <= rcut + dists = np.where(valid, dists, np.inf) + order = np.argsort(dists, axis=-1) + sorted_idx = np.take_along_axis(dense_idx, order, axis=-1) + sorted_valid = np.take_along_axis(valid, order, axis=-1) + + # Take first nsel neighbors, pad if fewer than nsel + if max_nn >= nsel: + nlist = sorted_idx[:, :nsel] + nlist = np.where(sorted_valid[:, :nsel], nlist, -1) + else: + nlist = np.full((nloc, nsel), -1, dtype=np.int64) + nlist[:, :max_nn] = np.where(sorted_valid, sorted_idx, -1) + + if not mixed_types: + # nlist_distinguish_types expects (nframes, nloc, nsel) + nlist = nlist_distinguish_types( + nlist[None], + extended_atype[None], + sel, + )[0] + + return extended_coord, extended_atype, nlist, mapping + + def _eval_model( + self, + coords: np.ndarray, + cells: np.ndarray | None, + atom_types: np.ndarray, + fparam: np.ndarray | None, + aparam: np.ndarray | None, + request_defs: list[OutputVariableDef], + ) -> tuple[np.ndarray, ...]: + nframes = coords.shape[0] + if len(atom_types.shape) == 1: + natoms = len(atom_types) + atom_types = np.tile(atom_types, nframes).reshape(nframes, -1) + else: + natoms = len(atom_types[0]) + + coord_input = coords.reshape(nframes, natoms, 3) + if self.neighbor_list is not None: + extended_coord, extended_atype, nlist, mapping = self._build_nlist_ase( + coord_input, + cells, + atom_types, + ) + else: + extended_coord, extended_atype, nlist, mapping = self._build_nlist_native( + coord_input, + cells, + atom_types, + ) # Convert to torch tensors from deepmd.pt_expt.utils.env import ( diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index 54e05a93b4..1fdb9d75ce 100644 
--- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -5,6 +5,7 @@ model.serialize() → deserialize_to_file(.pte) → DeepPot(.pte) → eval() """ +import importlib import tempfile import unittest @@ -110,7 +111,7 @@ def test_eval_consistency(self) -> None: rng = np.random.default_rng(GLOBAL_SEED) natoms = 5 coords = rng.random((1, natoms, 3)) * 8.0 - cells = np.eye(3).reshape(1, 9) * 15.0 + cells = np.eye(3).reshape(1, 9) * 10.0 atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) # .pte inference @@ -146,7 +147,7 @@ def test_multiple_frames(self) -> None: for nframes in [2, 5]: coords = rng.random((nframes, natoms, 3)) * 8.0 - cells = np.tile(np.eye(3).reshape(1, 9) * 15.0, (nframes, 1)) + cells = np.tile(np.eye(3).reshape(1, 9) * 10.0, (nframes, 1)) e, f, v, ae, av = self.dp.eval(coords, cells, atom_types, atomic=True) @@ -298,6 +299,177 @@ def test_no_pbc(self) -> None: v, ref["virial"].detach().numpy(), rtol=1e-10, atol=1e-10 ) + @unittest.skipUnless( + importlib.util.find_spec("ase") is not None, "ase not installed" + ) + def test_ase_neighbor_list_consistency(self) -> None: + """Test that ASE neighbor list gives same results as native nlist.""" + import ase.neighborlist + + rng = np.random.default_rng(GLOBAL_SEED + 11) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + # Eval without ASE neighbor list (native) + e1, f1, v1, ae1, av1 = self.dp.eval( + coords, + cells, + atom_types, + atomic=True, + ) + + # Eval with ASE neighbor list + dp_ase = DeepPot( + self.tmpfile.name, + neighbor_list=ase.neighborlist.NewPrimitiveNeighborList( + cutoffs=self.rcut, + bothways=True, + ), + ) + e2, f2, v2, ae2, av2 = dp_ase.eval( + coords, + cells, + atom_types, + atomic=True, + ) + + np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10, err_msg="energy") + 
np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10, err_msg="force") + np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10, err_msg="virial") + np.testing.assert_allclose( + ae1, + ae2, + rtol=1e-10, + atol=1e-10, + err_msg="atom_energy", + ) + np.testing.assert_allclose( + av1, + av2, + rtol=1e-10, + atol=1e-10, + err_msg="atom_virial", + ) + + @unittest.skipUnless( + importlib.util.find_spec("ase") is not None, "ase not installed" + ) + def test_build_nlist_ase(self) -> None: + """Test _build_nlist_ase produces the same neighbor sets as native.""" + import ase.neighborlist + + from deepmd.dpmodel.utils.nlist import ( + build_neighbor_list, + extend_coord_with_ghosts, + ) + from deepmd.dpmodel.utils.region import ( + normalize_coord, + ) + + rng = np.random.default_rng(GLOBAL_SEED + 13) + natoms = 5 + coords = rng.random((1, natoms, 3)) * 8.0 + cells = np.eye(3).reshape(1, 9) * 10.0 + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + atom_types_2d = atom_types.reshape(1, -1) + + dp_ase = DeepPot( + self.tmpfile.name, + neighbor_list=ase.neighborlist.NewPrimitiveNeighborList( + cutoffs=self.rcut, + bothways=True, + ), + ) + deep_eval = dp_ase.deep_eval + + # ASE path + ext_coord_ase, ext_atype_ase, nlist_ase, mapping_ase = ( + deep_eval._build_nlist_ase(coords, cells, atom_types_2d) + ) + + # Native path + box_input = cells.reshape(1, 3, 3) + coord_normalized = normalize_coord(coords, box_input) + ext_coord_nat, ext_atype_nat, mapping_nat = extend_coord_with_ghosts( + coord_normalized, + atom_types_2d, + cells, + self.rcut, + ) + sel = self.sel + nlist_nat = build_neighbor_list( + ext_coord_nat, + ext_atype_nat, + natoms, + self.rcut, + sel, + distinguish_types=not self.model.mixed_types(), + ) + ext_coord_nat = ext_coord_nat.reshape(1, -1, 3) + + # Compare: for each local atom, the set of neighbor relative + # coordinates should match (ghost ordering may differ). 
+ for ii in range(natoms): + # ASE neighbors + nn_ase = nlist_ase[0, ii] + mask_ase = nn_ase >= 0 + rel_ase = ext_coord_ase[0, nn_ase[mask_ase]] - coords[0, ii] + + # Native neighbors + nn_nat = nlist_nat[0, ii] + mask_nat = nn_nat >= 0 + rel_nat = ext_coord_nat[0, nn_nat[mask_nat]] - coords[0, ii] + + # Sort by distance then by coordinates for deterministic order + def _sort_key(rel: np.ndarray) -> np.ndarray: + dist = np.linalg.norm(rel, axis=-1, keepdims=True) + return np.concatenate([dist, rel], axis=-1) + + order_ase = np.lexsort(_sort_key(rel_ase).T) + order_nat = np.lexsort(_sort_key(rel_nat).T) + + np.testing.assert_allclose( + rel_ase[order_ase], + rel_nat[order_nat], + rtol=1e-10, + atol=1e-10, + err_msg=f"atom {ii}: neighbor relative coords differ", + ) + + @unittest.skipUnless( + importlib.util.find_spec("ase") is not None, "ase not installed" + ) + def test_ase_nlist_multiple_frames(self) -> None: + """Test ASE neighbor list with multiple frames and auto_batch_size=False.""" + import ase.neighborlist + + rng = np.random.default_rng(GLOBAL_SEED + 17) + natoms = 4 + nframes = 3 + coords = rng.random((nframes, natoms, 3)) * 8.0 + cells = np.tile(np.eye(3).reshape(1, 9) * 10.0, (nframes, 1)) + atom_types = np.array([i % self.nt for i in range(natoms)], dtype=np.int32) + + # Native eval (no ASE nlist) + e1, f1, v1 = self.dp.eval(coords, cells, atom_types) + + # ASE nlist with auto_batch_size=False to exercise multi-frame path + dp_ase = DeepPot( + self.tmpfile.name, + neighbor_list=ase.neighborlist.NewPrimitiveNeighborList( + cutoffs=self.rcut, + bothways=True, + ), + auto_batch_size=False, + ) + e2, f2, v2 = dp_ase.eval(coords, cells, atom_types) + + np.testing.assert_allclose(e1, e2, rtol=1e-10, atol=1e-10, err_msg="energy") + np.testing.assert_allclose(f1, f2, rtol=1e-10, atol=1e-10, err_msg="force") + np.testing.assert_allclose(v1, v2, rtol=1e-10, atol=1e-10, err_msg="virial") + if __name__ == "__main__": unittest.main() From 
eed7c2545ee9956e3c8318d3bfd1e5e3625add6e Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 4 Mar 2026 15:25:04 +0800 Subject: [PATCH 03/15] update doc --- deepmd/pt_expt/infer/deep_eval.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index 628c9fed8e..11eb88bb5d 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -229,15 +229,10 @@ def eval( Calculate the atomic energy and virial fparam The frame parameter. - The array can be of size : - - nframes x dim_fparam. - - dim_fparam. Then all frames are assumed to be provided with the same fparam. + The array should be of size nframes x dim_fparam. aparam - The atomic parameter - The array can be of size : - - nframes x natoms x dim_aparam. - - natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam. - - dim_aparam. Then all frames and atoms are provided with the same aparam. + The atomic parameter. + The array should be of size nframes x natoms x dim_aparam. 
**kwargs Other parameters From ae86b1127f40eb16866011ba924bed9266d09fe1 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 4 Mar 2026 15:26:37 +0800 Subject: [PATCH 04/15] fix device conversion --- deepmd/pt_expt/infer/deep_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index 11eb88bb5d..bcd5f4b653 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -615,7 +615,7 @@ def _eval_model( if odef.name in model_predict: shape = self._get_output_shape(odef, nframes, natoms) if model_predict[odef.name] is not None: - out = model_predict[odef.name].detach().numpy().reshape(shape) + out = model_predict[odef.name].detach().cpu().numpy().reshape(shape) else: out = np.full(shape, np.nan, dtype=GLOBAL_NP_FLOAT_PRECISION) results.append(out) From 8fd1153fe3d909299e9cd1151735d4cab31ff822 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 4 Mar 2026 15:30:12 +0800 Subject: [PATCH 05/15] Resolve Ruff ARG001 in _build_dynamic_shapes --- deepmd/pt_expt/utils/serialization.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index 313d0666f6..593e38ed61 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -135,10 +135,10 @@ def _make_sample_inputs( def _build_dynamic_shapes( - ext_coord: torch.Tensor, - ext_atype: torch.Tensor, - nlist: torch.Tensor, - mapping: torch.Tensor, + _ext_coord: torch.Tensor, + _ext_atype: torch.Tensor, + _nlist: torch.Tensor, + _mapping: torch.Tensor, fparam: torch.Tensor | None, aparam: torch.Tensor | None, ) -> tuple: From e065c229309c5cb8a0cbbdf1ae2974c19033fb39 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 4 Mar 2026 16:11:43 +0800 Subject: [PATCH 06/15] Updated the backend to use the GPU-aware NeighborStat from deepmd.pt_expt.utils.neighbor_stat (ported from PR #5270) --- 
deepmd/backend/pt_expt.py | 2 +- deepmd/pt_expt/utils/neighbor_stat.py | 88 +++++++++++++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 deepmd/pt_expt/utils/neighbor_stat.py diff --git a/deepmd/backend/pt_expt.py b/deepmd/backend/pt_expt.py index 4b92d7551a..69f8cb0af8 100644 --- a/deepmd/backend/pt_expt.py +++ b/deepmd/backend/pt_expt.py @@ -91,7 +91,7 @@ def neighbor_stat(self) -> type["NeighborStat"]: type[NeighborStat] The neighbor statistics of the backend. """ - from deepmd.dpmodel.utils.neighbor_stat import ( + from deepmd.pt_expt.utils.neighbor_stat import ( NeighborStat, ) diff --git a/deepmd/pt_expt/utils/neighbor_stat.py b/deepmd/pt_expt/utils/neighbor_stat.py new file mode 100644 index 0000000000..cf9d9f3c18 --- /dev/null +++ b/deepmd/pt_expt/utils/neighbor_stat.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +from collections.abc import ( + Iterator, +) + +import numpy as np +import torch + +from deepmd.dpmodel.utils.neighbor_stat import NeighborStatOP as NeighborStatOPDP +from deepmd.pt_expt.common import ( + torch_module, +) +from deepmd.pt_expt.utils.env import ( + DEVICE, +) +from deepmd.utils.data_system import ( + DeepmdDataSystem, +) +from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat + + +@torch_module +class NeighborStatOP(NeighborStatOPDP): + pass + + +class NeighborStat(BaseNeighborStat): + """Neighbor statistics using torch on DEVICE. + + Parameters + ---------- + ntypes : int + The num of atom types + rcut : float + The cut-off radius + mixed_type : bool, optional, default=False + Treat all types as a single type. + """ + + def __init__( + self, + ntypes: int, + rcut: float, + mixed_type: bool = False, + ) -> None: + super().__init__(ntypes, rcut, mixed_type) + self.op = NeighborStatOP(ntypes, rcut, mixed_type) + + def iterator( + self, data: DeepmdDataSystem + ) -> Iterator[tuple[np.ndarray, float, str]]: + """Produce neighbor statistics for each data set. 
+ + Yields + ------ + np.ndarray + The maximal number of neighbors + float + The squared minimal distance between two atoms + str + The directory of the data system + """ + for ii in range(len(data.system_dirs)): + for jj in data.data_systems[ii].dirs: + data_set = data.data_systems[ii] + data_set_data = data_set._load_set(jj) + minrr2, max_nnei = self._execute( + data_set_data["coord"], + data_set_data["type"], + data_set_data["box"] if data_set.pbc else None, + ) + yield np.max(max_nnei, axis=0), np.min(minrr2), jj + + def _execute( + self, + coord: np.ndarray, + atype: np.ndarray, + cell: np.ndarray | None, + ) -> tuple[np.ndarray, np.ndarray]: + """Execute the operation on DEVICE.""" + minrr2, max_nnei = self.op( + torch.from_numpy(coord).to(DEVICE), + torch.from_numpy(atype).to(DEVICE), + torch.from_numpy(cell).to(DEVICE) if cell is not None else None, + ) + minrr2 = minrr2.detach().cpu().numpy() + max_nnei = max_nnei.detach().cpu().numpy() + return minrr2, max_nnei From 8da4064e5e6ae89903a301838d2d8c551599d9db Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 4 Mar 2026 16:15:48 +0800 Subject: [PATCH 07/15] use PT's AutoBatchSize to catch OOM --- deepmd/pt_expt/infer/deep_eval.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index bcd5f4b653..c5d6767bd0 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -21,9 +21,6 @@ OutputVariableCategory, OutputVariableDef, ) -from deepmd.dpmodel.utils.batch_size import ( - AutoBatchSize, -) from deepmd.dpmodel.utils.nlist import ( build_neighbor_list, extend_coord_with_ghosts, @@ -54,6 +51,9 @@ from deepmd.infer.deep_wfc import ( DeepWFC, ) +from deepmd.pt.utils.auto_batch_size import ( + AutoBatchSize, +) if TYPE_CHECKING: import ase.neighborlist From 291a14bf7c71cc262e1ce8fadd5a0bc240e6129b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 4 Mar 2026 17:28:55 +0800 Subject: [PATCH 
08/15] build and extend coords on cuda device --- deepmd/pt_expt/infer/deep_eval.py | 53 +++++++++++++++++++------------ 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/deepmd/pt_expt/infer/deep_eval.py b/deepmd/pt_expt/infer/deep_eval.py index c5d6767bd0..a6e1e1e540 100644 --- a/deepmd/pt_expt/infer/deep_eval.py +++ b/deepmd/pt_expt/infer/deep_eval.py @@ -309,24 +309,25 @@ def _get_natoms_and_nframes( def _build_nlist_native( self, - coords: np.ndarray, - cells: np.ndarray | None, - atom_types: np.ndarray, - ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + coords: torch.Tensor, + cells: torch.Tensor | None, + atom_types: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Build extended coords, atype, nlist, mapping using native nlist. Parameters ---------- - coords : np.ndarray + coords : torch.Tensor Coordinates, shape (nframes, natoms, 3). - cells : np.ndarray or None + cells : torch.Tensor or None Cell vectors, shape (nframes, 9). None for non-PBC. - atom_types : np.ndarray + atom_types : torch.Tensor Atom types, shape (nframes, natoms). Returns ------- extended_coord, extended_atype, nlist, mapping + All as torch.Tensor on the same device as inputs. 
""" nframes = coords.shape[0] natoms = coords.shape[1] @@ -548,29 +549,39 @@ def _eval_model( else: natoms = len(atom_types[0]) + from deepmd.pt_expt.utils.env import ( + DEVICE, + ) + coord_input = coords.reshape(nframes, natoms, 3) if self.neighbor_list is not None: + # ASE path: builds nlist in numpy, then convert to tensors extended_coord, extended_atype, nlist, mapping = self._build_nlist_ase( coord_input, cells, atom_types, ) + ext_coord_t = torch.tensor( + extended_coord, dtype=torch.float64, device=DEVICE + ) + ext_atype_t = torch.tensor(extended_atype, dtype=torch.int64, device=DEVICE) + nlist_t = torch.tensor(nlist, dtype=torch.int64, device=DEVICE) + mapping_t = torch.tensor(mapping, dtype=torch.int64, device=DEVICE) else: - extended_coord, extended_atype, nlist, mapping = self._build_nlist_native( - coord_input, - cells, - atom_types, + # Native path: convert to tensors first so array-API functions + # use the torch backend (runs on DEVICE). + coord_t = torch.tensor(coord_input, dtype=torch.float64, device=DEVICE) + atype_t = torch.tensor(atom_types, dtype=torch.int64, device=DEVICE) + cells_t = ( + torch.tensor(cells, dtype=torch.float64, device=DEVICE) + if cells is not None + else None + ) + ext_coord_t, ext_atype_t, nlist_t, mapping_t = self._build_nlist_native( + coord_t, + cells_t, + atype_t, ) - - # Convert to torch tensors - from deepmd.pt_expt.utils.env import ( - DEVICE, - ) - - ext_coord_t = torch.tensor(extended_coord, dtype=torch.float64, device=DEVICE) - ext_atype_t = torch.tensor(extended_atype, dtype=torch.int64, device=DEVICE) - nlist_t = torch.tensor(nlist, dtype=torch.int64, device=DEVICE) - mapping_t = torch.tensor(mapping, dtype=torch.int64, device=DEVICE) if fparam is not None: fparam_t = torch.tensor( From 3304fe8866e93bba59ed0f0c391166ee8433dae2 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 4 Mar 2026 17:51:07 +0800 Subject: [PATCH 09/15] fix --- source/tests/pt_expt/infer/test_deep_eval.py | 4 ++-- 1 file changed, 2 
insertions(+), 2 deletions(-) diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index 1fdb9d75ce..8080fae9df 100644 --- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -385,14 +385,14 @@ def test_build_nlist_ase(self) -> None: deep_eval = dp_ase.deep_eval # ASE path - ext_coord_ase, ext_atype_ase, nlist_ase, mapping_ase = ( + ext_coord_ase, _ext_atype_ase, nlist_ase, _mapping_ase = ( deep_eval._build_nlist_ase(coords, cells, atom_types_2d) ) # Native path box_input = cells.reshape(1, 3, 3) coord_normalized = normalize_coord(coords, box_input) - ext_coord_nat, ext_atype_nat, mapping_nat = extend_coord_with_ghosts( + ext_coord_nat, ext_atype_nat, _mapping_nat = extend_coord_with_ghosts( coord_normalized, atom_types_2d, cells, From 57d2c8aa96c8e6f4b58611ebb2fa45c8cafca316 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 4 Mar 2026 17:59:02 +0800 Subject: [PATCH 10/15] assign dtype for torch tensors input to neighbor stat --- deepmd/pt_expt/utils/neighbor_stat.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/deepmd/pt_expt/utils/neighbor_stat.py b/deepmd/pt_expt/utils/neighbor_stat.py index cf9d9f3c18..850bf86def 100644 --- a/deepmd/pt_expt/utils/neighbor_stat.py +++ b/deepmd/pt_expt/utils/neighbor_stat.py @@ -12,6 +12,7 @@ ) from deepmd.pt_expt.utils.env import ( DEVICE, + GLOBAL_PT_FLOAT_PRECISION, ) from deepmd.utils.data_system import ( DeepmdDataSystem, @@ -79,9 +80,15 @@ def _execute( ) -> tuple[np.ndarray, np.ndarray]: """Execute the operation on DEVICE.""" minrr2, max_nnei = self.op( - torch.from_numpy(coord).to(DEVICE), - torch.from_numpy(atype).to(DEVICE), - torch.from_numpy(cell).to(DEVICE) if cell is not None else None, + torch.from_numpy(coord).to(device=DEVICE, dtype=GLOBAL_PT_FLOAT_PRECISION), + torch.from_numpy(atype).to(device=DEVICE, dtype=torch.long), + ( + torch.from_numpy(cell).to( + device=DEVICE, 
dtype=GLOBAL_PT_FLOAT_PRECISION + ) + if cell is not None + else None + ), ) minrr2 = minrr2.detach().cpu().numpy() max_nnei = max_nnei.detach().cpu().numpy() From 1a2624e79a28668eb8e2408ae8f2e7695592e9f4 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 4 Mar 2026 18:12:45 +0800 Subject: [PATCH 11/15] provide options for make_fx rather than hard coding --- deepmd/pt_expt/model/ener_model.py | 7 ++++++- deepmd/pt_expt/model/make_model.py | 14 ++++++++------ deepmd/pt_expt/utils/serialization.py | 2 ++ 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/deepmd/pt_expt/model/ener_model.py b/deepmd/pt_expt/model/ener_model.py index 7a15c2d91d..53b22ab705 100644 --- a/deepmd/pt_expt/model/ener_model.py +++ b/deepmd/pt_expt/model/ener_model.py @@ -126,6 +126,7 @@ def forward_lower_exportable( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, + **make_fx_kwargs: Any, ) -> torch.nn.Module: """Trace ``forward_lower`` into an exportable module. @@ -137,6 +138,9 @@ def forward_lower_exportable( ---------- extended_coord, extended_atype, nlist, mapping, fparam, aparam, do_atomic_virial Sample inputs with representative shapes (used for tracing). + **make_fx_kwargs + Extra keyword arguments forwarded to ``make_fx`` + (e.g. ``tracing_mode="symbolic"``). Returns ------- @@ -153,6 +157,7 @@ def forward_lower_exportable( fparam=fparam, aparam=aparam, do_atomic_virial=do_atomic_virial, + **make_fx_kwargs, ) # Translate internal keys to forward_lower convention. 
@@ -186,6 +191,6 @@ def fn( model_predict["mask"] = model_ret["mask"] return model_predict - return make_fx(fn, tracing_mode="symbolic", _allow_non_fake_inputs=True)( + return make_fx(fn, **make_fx_kwargs)( extended_coord, extended_atype, nlist, mapping, fparam, aparam ) diff --git a/deepmd/pt_expt/model/make_model.py b/deepmd/pt_expt/model/make_model.py index f569b0021e..4baf3f5c7a 100644 --- a/deepmd/pt_expt/model/make_model.py +++ b/deepmd/pt_expt/model/make_model.py @@ -102,14 +102,13 @@ def forward_common_lower_exportable( fparam: torch.Tensor | None = None, aparam: torch.Tensor | None = None, do_atomic_virial: bool = False, + **make_fx_kwargs: Any, ) -> torch.nn.Module: """Trace ``forward_common_lower`` into an exportable module. - Uses ``make_fx`` with symbolic tracing to trace through - ``torch.autograd.grad``, decomposing the backward pass into - primitive ops while preserving dynamic shapes. The returned - module can be passed directly to ``torch.export.export`` with - dynamic shape specifications. + Uses ``make_fx`` to trace through ``torch.autograd.grad``, + decomposing the backward pass into primitive ops. The returned + module can be passed directly to ``torch.export.export``. The output uses internal key names (e.g. ``energy``, ``energy_redu``, ``energy_derv_r``) so that @@ -120,6 +119,9 @@ def forward_common_lower_exportable( ---------- extended_coord, extended_atype, nlist, mapping, fparam, aparam, do_atomic_virial Sample inputs with representative shapes (used for tracing). + **make_fx_kwargs + Extra keyword arguments forwarded to ``make_fx`` + (e.g. ``tracing_mode="symbolic"``). 
Returns ------- @@ -150,7 +152,7 @@ def fn( do_atomic_virial=do_atomic_virial, ) - return make_fx(fn, tracing_mode="symbolic", _allow_non_fake_inputs=True)( + return make_fx(fn, **make_fx_kwargs)( extended_coord, extended_atype, nlist, diff --git a/deepmd/pt_expt/utils/serialization.py b/deepmd/pt_expt/utils/serialization.py index 593e38ed61..516eac62cd 100644 --- a/deepmd/pt_expt/utils/serialization.py +++ b/deepmd/pt_expt/utils/serialization.py @@ -262,6 +262,8 @@ def deserialize_to_file(model_file: str, data: dict) -> None: fparam=fparam, aparam=aparam, do_atomic_virial=True, + tracing_mode="symbolic", + _allow_non_fake_inputs=True, ) # 5. Build dynamic shapes and export From c415ec3f74354c3ba1a1e6631140956435115a0f Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 4 Mar 2026 18:20:36 +0800 Subject: [PATCH 12/15] align the implementation of env_mat with the training branch --- deepmd/dpmodel/utils/env_mat.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/utils/env_mat.py b/deepmd/dpmodel/utils/env_mat.py index 69b39b862e..9856741317 100644 --- a/deepmd/dpmodel/utils/env_mat.py +++ b/deepmd/dpmodel/utils/env_mat.py @@ -27,7 +27,10 @@ def compute_smooth_weight( if rmin >= rmax: raise ValueError("rmin should be less than rmax.") xp = array_api_compat.array_namespace(distance) - distance = xp.clip(distance, min=rmin, max=rmax) + # Use where instead of clip so that make_fx tracing does not + # decompose it into boolean-indexed ops with data-dependent sizes. 
+ distance = xp.where(distance < rmin, xp.full_like(distance, rmin), distance) + distance = xp.where(distance > rmax, xp.full_like(distance, rmax), distance) uu = (distance - rmin) / (rmax - rmin) uu2 = uu * uu vv = uu2 * uu * (-6.0 * uu2 + 15.0 * uu - 10.0) + 1.0 @@ -43,7 +46,8 @@ def compute_exp_sw( if rmin >= rmax: raise ValueError("rmin should be less than rmax.") xp = array_api_compat.array_namespace(distance) - distance = xp.clip(distance, min=0.0, max=rmax) + distance = xp.where(distance < 0.0, xp.zeros_like(distance), distance) + distance = xp.where(distance > rmax, xp.full_like(distance, rmax), distance) C = 20 a = C / rmin b = rmin From 40de8cac641124af8ec5e2249a6d2e0d8e7b8e41 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Wed, 4 Mar 2026 18:28:18 +0800 Subject: [PATCH 13/15] borrow conftest.py --- source/tests/pt_expt/conftest.py | 50 ++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/source/tests/pt_expt/conftest.py b/source/tests/pt_expt/conftest.py index ec025c2202..21ca13e419 100644 --- a/source/tests/pt_expt/conftest.py +++ b/source/tests/pt_expt/conftest.py @@ -1,4 +1,54 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +"""Conftest for pt_expt tests. + +Clears any leaked ``torch.utils._device.DeviceContext`` modes that may +have been left on the torch function mode stack by ``make_fx`` or other +tracing utilities during test collection. A stale ``DeviceContext`` +silently reroutes ``torch.tensor(...)`` calls (without an explicit +``device=``) to a fake CUDA device, causing spurious "no NVIDIA driver" +errors on CPU-only machines. + +The leak is triggered when pytest collects descriptor test modules that +import ``make_fx``. A ``DeviceContext(cuda:127)`` ends up on the +``torch.overrides`` function mode stack and is never popped. + +Our own code (``display_if_exist`` in ``deepmd/dpmodel/loss/loss.py``) +is already fixed to pass ``device=`` explicitly. 
However, PyTorch's +``Adam._init_group`` (``torch/optim/adam.py``) contains:: + + torch.tensor(0.0, dtype=_get_scalar_dtype()) # no device= + +on the ``capturable=False, fused=False`` path (the default). This is +a PyTorch bug — the ``capturable=True`` branch correctly uses +``device=p.device`` but the default branch omits it. We cannot fix +PyTorch internals, so this fixture works around the issue by popping +leaked ``DeviceContext`` modes before each test. +""" + import pytest +import torch.utils._device as _device +from torch.overrides import ( + _get_current_function_mode_stack, +) pytest.importorskip("torch") + + +@pytest.fixture(autouse=True) +def _clear_leaked_device_context(): + """Pop any stale ``DeviceContext`` before each test, restore after.""" + popped = [] + while True: + modes = _get_current_function_mode_stack() + if not modes: + break + top = modes[-1] + if isinstance(top, _device.DeviceContext): + top.__exit__(None, None, None) + popped.append(top) + else: + break + yield + # Restore in reverse order so the stack is back to its original state. 
+ for ctx in reversed(popped): + ctx.__enter__() From 497f26ae9a52a1534571ee4527612de861b0708b Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 5 Mar 2026 12:16:50 +0800 Subject: [PATCH 14/15] fix : test device --- source/tests/pt_expt/infer/test_deep_eval.py | 57 ++++++++++++-------- 1 file changed, 36 insertions(+), 21 deletions(-) diff --git a/source/tests/pt_expt/infer/test_deep_eval.py b/source/tests/pt_expt/infer/test_deep_eval.py index 8080fae9df..ef38e1d36f 100644 --- a/source/tests/pt_expt/infer/test_deep_eval.py +++ b/source/tests/pt_expt/infer/test_deep_eval.py @@ -24,6 +24,9 @@ from deepmd.pt_expt.model import ( EnergyModel, ) +from deepmd.pt_expt.utils.env import ( + DEVICE, +) from deepmd.pt_expt.utils.serialization import ( _make_sample_inputs, deserialize_to_file, @@ -118,25 +121,29 @@ def test_eval_consistency(self) -> None: e, f, v, ae, av = self.dp.eval(coords, cells, atom_types, atomic=True) # Direct model forward - coord_t = torch.tensor(coords, dtype=torch.float64).requires_grad_(True) - atype_t = torch.tensor(atom_types.reshape(1, -1), dtype=torch.int64) - cell_t = torch.tensor(cells, dtype=torch.float64) + coord_t = torch.tensor( + coords, dtype=torch.float64, device=DEVICE + ).requires_grad_(True) + atype_t = torch.tensor( + atom_types.reshape(1, -1), dtype=torch.int64, device=DEVICE + ) + cell_t = torch.tensor(cells, dtype=torch.float64, device=DEVICE) ref = self.model.forward(coord_t, atype_t, cell_t, do_atomic_virial=True) np.testing.assert_allclose( - e, ref["energy"].detach().numpy(), rtol=1e-10, atol=1e-10 + e, ref["energy"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 ) np.testing.assert_allclose( - f, ref["force"].detach().numpy(), rtol=1e-10, atol=1e-10 + f, ref["force"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 ) np.testing.assert_allclose( - v, ref["virial"].detach().numpy(), rtol=1e-10, atol=1e-10 + v, ref["virial"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 ) np.testing.assert_allclose( - ae, 
ref["atom_energy"].detach().numpy(), rtol=1e-10, atol=1e-10 + ae, ref["atom_energy"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 ) np.testing.assert_allclose( - av, ref["atom_virial"].detach().numpy(), rtol=1e-10, atol=1e-10 + av, ref["atom_virial"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 ) def test_multiple_frames(self) -> None: @@ -151,42 +158,46 @@ def test_multiple_frames(self) -> None: e, f, v, ae, av = self.dp.eval(coords, cells, atom_types, atomic=True) - coord_t = torch.tensor(coords, dtype=torch.float64).requires_grad_(True) - atype_t = torch.tensor(np.tile(atom_types, (nframes, 1)), dtype=torch.int64) - cell_t = torch.tensor(cells, dtype=torch.float64) + coord_t = torch.tensor( + coords, dtype=torch.float64, device=DEVICE + ).requires_grad_(True) + atype_t = torch.tensor( + np.tile(atom_types, (nframes, 1)), dtype=torch.int64, device=DEVICE + ) + cell_t = torch.tensor(cells, dtype=torch.float64, device=DEVICE) ref = self.model.forward(coord_t, atype_t, cell_t, do_atomic_virial=True) np.testing.assert_allclose( e, - ref["energy"].detach().numpy(), + ref["energy"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10, err_msg=f"nframes={nframes}, energy", ) np.testing.assert_allclose( f, - ref["force"].detach().numpy(), + ref["force"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10, err_msg=f"nframes={nframes}, force", ) np.testing.assert_allclose( v, - ref["virial"].detach().numpy(), + ref["virial"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10, err_msg=f"nframes={nframes}, virial", ) np.testing.assert_allclose( ae, - ref["atom_energy"].detach().numpy(), + ref["atom_energy"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10, err_msg=f"nframes={nframes}, atom_energy", ) np.testing.assert_allclose( av, - ref["atom_virial"].detach().numpy(), + ref["atom_virial"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10, err_msg=f"nframes={nframes}, atom_virial", @@ -285,18 +296,22 @@ def test_no_pbc(self) -> None: e, f, v = self.dp.eval(coords, None, atom_types) - 
coord_t = torch.tensor(coords, dtype=torch.float64).requires_grad_(True) - atype_t = torch.tensor(atom_types.reshape(1, -1), dtype=torch.int64) + coord_t = torch.tensor( + coords, dtype=torch.float64, device=DEVICE + ).requires_grad_(True) + atype_t = torch.tensor( + atom_types.reshape(1, -1), dtype=torch.int64, device=DEVICE + ) ref = self.model.forward(coord_t, atype_t, box=None) np.testing.assert_allclose( - e, ref["energy"].detach().numpy(), rtol=1e-10, atol=1e-10 + e, ref["energy"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 ) np.testing.assert_allclose( - f, ref["force"].detach().numpy(), rtol=1e-10, atol=1e-10 + f, ref["force"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 ) np.testing.assert_allclose( - v, ref["virial"].detach().numpy(), rtol=1e-10, atol=1e-10 + v, ref["virial"].detach().cpu().numpy(), rtol=1e-10, atol=1e-10 ) @unittest.skipUnless( From 143d5c4b4de96403375b78a73f7e768e904dc435 Mon Sep 17 00:00:00 2001 From: Han Wang Date: Thu, 5 Mar 2026 16:09:29 +0800 Subject: [PATCH 15/15] revert change on inputs --- examples/water/dpa3/input_torch.json | 6 ++++++ examples/water/dpa3/input_torch_dynamic.json | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/examples/water/dpa3/input_torch.json b/examples/water/dpa3/input_torch.json index 7b67e8cf55..6aaa6e2776 100644 --- a/examples/water/dpa3/input_torch.json +++ b/examples/water/dpa3/input_torch.json @@ -68,6 +68,12 @@ "limit_pref_v": 1, "_comment": " that's all" }, + "optimizer": { + "type": "AdamW", + "adam_beta1": 0.9, + "adam_beta2": 0.999, + "weight_decay": 0.001 + }, "training": { "stat_file": "./dpa3.hdf5", "training_data": { diff --git a/examples/water/dpa3/input_torch_dynamic.json b/examples/water/dpa3/input_torch_dynamic.json index edb7e53414..dbd93cd4e7 100644 --- a/examples/water/dpa3/input_torch_dynamic.json +++ b/examples/water/dpa3/input_torch_dynamic.json @@ -70,6 +70,12 @@ "limit_pref_v": 1, "_comment": " that's all" }, + "optimizer": { + "type": "AdamW", + "adam_beta1": 0.9, 
+ "adam_beta2": 0.999, + "weight_decay": 0.001 + }, "training": { "stat_file": "./dpa3.hdf5", "training_data": {