Skip to content

Commit e065c22

Browse files
author
Han Wang
committed
Updated the backend to use the GPU-aware NeighborStat from deepmd.pt_expt.utils.neighbor_stat (ported from PR deepmodeling#5270)
1 parent 8fd1153 commit e065c22

2 files changed

Lines changed: 89 additions & 1 deletion

File tree

deepmd/backend/pt_expt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def neighbor_stat(self) -> type["NeighborStat"]:
9191
type[NeighborStat]
9292
The neighbor statistics of the backend.
9393
"""
94-
from deepmd.dpmodel.utils.neighbor_stat import (
94+
from deepmd.pt_expt.utils.neighbor_stat import (
9595
NeighborStat,
9696
)
9797

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
from collections.abc import (
3+
Iterator,
4+
)
5+
6+
import numpy as np
7+
import torch
8+
9+
from deepmd.dpmodel.utils.neighbor_stat import NeighborStatOP as NeighborStatOPDP
10+
from deepmd.pt_expt.common import (
11+
torch_module,
12+
)
13+
from deepmd.pt_expt.utils.env import (
14+
DEVICE,
15+
)
16+
from deepmd.utils.data_system import (
17+
DeepmdDataSystem,
18+
)
19+
from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat
20+
21+
22+
@torch_module
23+
class NeighborStatOP(NeighborStatOPDP):
24+
pass
25+
26+
27+
class NeighborStat(BaseNeighborStat):
28+
"""Neighbor statistics using torch on DEVICE.
29+
30+
Parameters
31+
----------
32+
ntypes : int
33+
The num of atom types
34+
rcut : float
35+
The cut-off radius
36+
mixed_type : bool, optional, default=False
37+
Treat all types as a single type.
38+
"""
39+
40+
def __init__(
41+
self,
42+
ntypes: int,
43+
rcut: float,
44+
mixed_type: bool = False,
45+
) -> None:
46+
super().__init__(ntypes, rcut, mixed_type)
47+
self.op = NeighborStatOP(ntypes, rcut, mixed_type)
48+
49+
def iterator(
50+
self, data: DeepmdDataSystem
51+
) -> Iterator[tuple[np.ndarray, float, str]]:
52+
"""Produce neighbor statistics for each data set.
53+
54+
Yields
55+
------
56+
np.ndarray
57+
The maximal number of neighbors
58+
float
59+
The squared minimal distance between two atoms
60+
str
61+
The directory of the data system
62+
"""
63+
for ii in range(len(data.system_dirs)):
64+
for jj in data.data_systems[ii].dirs:
65+
data_set = data.data_systems[ii]
66+
data_set_data = data_set._load_set(jj)
67+
minrr2, max_nnei = self._execute(
68+
data_set_data["coord"],
69+
data_set_data["type"],
70+
data_set_data["box"] if data_set.pbc else None,
71+
)
72+
yield np.max(max_nnei, axis=0), np.min(minrr2), jj
73+
74+
def _execute(
75+
self,
76+
coord: np.ndarray,
77+
atype: np.ndarray,
78+
cell: np.ndarray | None,
79+
) -> tuple[np.ndarray, np.ndarray]:
80+
"""Execute the operation on DEVICE."""
81+
minrr2, max_nnei = self.op(
82+
torch.from_numpy(coord).to(DEVICE),
83+
torch.from_numpy(atype).to(DEVICE),
84+
torch.from_numpy(cell).to(DEVICE) if cell is not None else None,
85+
)
86+
minrr2 = minrr2.detach().cpu().numpy()
87+
max_nnei = max_nnei.detach().cpu().numpy()
88+
return minrr2, max_nnei

0 commit comments

Comments
 (0)