|
| 1 | +# SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +from collections.abc import ( |
| 3 | + Iterator, |
| 4 | +) |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +import torch |
| 8 | + |
| 9 | +from deepmd.dpmodel.utils.neighbor_stat import NeighborStatOP as NeighborStatOPDP |
| 10 | +from deepmd.pt_expt.common import ( |
| 11 | + torch_module, |
| 12 | +) |
| 13 | +from deepmd.pt_expt.utils.env import ( |
| 14 | + DEVICE, |
| 15 | +) |
| 16 | +from deepmd.utils.data_system import ( |
| 17 | + DeepmdDataSystem, |
| 18 | +) |
| 19 | +from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat |
| 20 | + |
| 21 | + |
| 22 | +@torch_module |
| 23 | +class NeighborStatOP(NeighborStatOPDP): |
| 24 | + pass |
| 25 | + |
| 26 | + |
| 27 | +class NeighborStat(BaseNeighborStat): |
| 28 | + """Neighbor statistics using torch on DEVICE. |
| 29 | +
|
| 30 | + Parameters |
| 31 | + ---------- |
| 32 | + ntypes : int |
| 33 | + The num of atom types |
| 34 | + rcut : float |
| 35 | + The cut-off radius |
| 36 | + mixed_type : bool, optional, default=False |
| 37 | + Treat all types as a single type. |
| 38 | + """ |
| 39 | + |
| 40 | + def __init__( |
| 41 | + self, |
| 42 | + ntypes: int, |
| 43 | + rcut: float, |
| 44 | + mixed_type: bool = False, |
| 45 | + ) -> None: |
| 46 | + super().__init__(ntypes, rcut, mixed_type) |
| 47 | + self.op = NeighborStatOP(ntypes, rcut, mixed_type) |
| 48 | + |
| 49 | + def iterator( |
| 50 | + self, data: DeepmdDataSystem |
| 51 | + ) -> Iterator[tuple[np.ndarray, float, str]]: |
| 52 | + """Produce neighbor statistics for each data set. |
| 53 | +
|
| 54 | + Yields |
| 55 | + ------ |
| 56 | + np.ndarray |
| 57 | + The maximal number of neighbors |
| 58 | + float |
| 59 | + The squared minimal distance between two atoms |
| 60 | + str |
| 61 | + The directory of the data system |
| 62 | + """ |
| 63 | + for ii in range(len(data.system_dirs)): |
| 64 | + for jj in data.data_systems[ii].dirs: |
| 65 | + data_set = data.data_systems[ii] |
| 66 | + data_set_data = data_set._load_set(jj) |
| 67 | + minrr2, max_nnei = self._execute( |
| 68 | + data_set_data["coord"], |
| 69 | + data_set_data["type"], |
| 70 | + data_set_data["box"] if data_set.pbc else None, |
| 71 | + ) |
| 72 | + yield np.max(max_nnei, axis=0), np.min(minrr2), jj |
| 73 | + |
| 74 | + def _execute( |
| 75 | + self, |
| 76 | + coord: np.ndarray, |
| 77 | + atype: np.ndarray, |
| 78 | + cell: np.ndarray | None, |
| 79 | + ) -> tuple[np.ndarray, np.ndarray]: |
| 80 | + """Execute the operation on DEVICE.""" |
| 81 | + minrr2, max_nnei = self.op( |
| 82 | + torch.from_numpy(coord).to(DEVICE), |
| 83 | + torch.from_numpy(atype).to(DEVICE), |
| 84 | + torch.from_numpy(cell).to(DEVICE) if cell is not None else None, |
| 85 | + ) |
| 86 | + minrr2 = minrr2.detach().cpu().numpy() |
| 87 | + max_nnei = max_nnei.detach().cpu().numpy() |
| 88 | + return minrr2, max_nnei |
0 commit comments