Skip to content

Commit 8f021aa

Browse files
refactor(pt): reuse dpmodel NeighborStatOP (#5137)
This class does not need TorchScript, so we can reuse the dpmodel code to reduce redundancy.

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

* **Bug Fixes**
  * Ensure tensors and masks are allocated on the same device as input data, fixing device-consistency issues across neighbor-list and neighbor-stat computations.
* **Refactor**
  * Consolidated neighbor-stat computation into a shared implementation and simplified local orchestration while preserving public interfaces and behaviors.

<sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub>

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6012b4d commit 8f021aa

File tree

3 files changed

+66
-107
lines changed

3 files changed

+66
-107
lines changed

deepmd/dpmodel/utils/neighbor_stat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def call(
8787
)
8888
assert list(diff.shape) == [nframes, nloc, nall, 3]
8989
# remove the diagonal elements
90-
mask = xp.eye(nloc, nall, dtype=xp.bool)
90+
mask = xp.eye(nloc, nall, dtype=xp.bool, device=array_api_compat.device(diff))
9191
mask = xp.tile(mask[None, :, :, None], (nframes, 1, 1, 3))
9292
diff = xp.where(mask, xp.full_like(diff, xp.inf), diff)
9393
rr2 = xp.sum(xp.square(diff), axis=-1)

deepmd/dpmodel/utils/nlist.py

Lines changed: 61 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,9 @@ def build_neighbor_list(
117117
assert list(diff.shape) == [batch_size, nloc, nall, 3]
118118
rr = xp.linalg.vector_norm(diff, axis=-1)
119119
# if central atom has two zero distances, sorting sometimes can not exclude itself
120-
rr -= xp.eye(nloc, nall, dtype=diff.dtype)[xp.newaxis, :, :]
120+
rr -= xp.eye(nloc, nall, dtype=diff.dtype, device=array_api_compat.device(diff))[
121+
xp.newaxis, :, :
122+
]
121123
nlist = xp.argsort(rr, axis=-1)
122124
rr = xp.sort(rr, axis=-1)
123125
rr = rr[:, :, 1:]
@@ -128,11 +130,26 @@ def build_neighbor_list(
128130
nlist = nlist[:, :, :nsel]
129131
else:
130132
rr = xp.concatenate(
131-
[rr, xp.ones([batch_size, nloc, nsel - nnei], dtype=rr.dtype) + rcut],
133+
[
134+
rr,
135+
xp.ones(
136+
[batch_size, nloc, nsel - nnei],
137+
dtype=rr.dtype,
138+
device=array_api_compat.device(rr),
139+
)
140+
+ rcut,
141+
],
132142
axis=-1,
133143
)
134144
nlist = xp.concatenate(
135-
[nlist, xp.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)],
145+
[
146+
nlist,
147+
xp.ones(
148+
[batch_size, nloc, nsel - nnei],
149+
dtype=nlist.dtype,
150+
device=array_api_compat.device(nlist),
151+
),
152+
],
136153
axis=-1,
137154
)
138155
assert list(nlist.shape) == [batch_size, nloc, nsel]
@@ -218,7 +235,11 @@ def build_multiple_neighbor_list(
218235
return {}
219236
nb, nloc, nsel = nlist.shape
220237
if nsel < nsels[-1]:
221-
pad = -1 * xp.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype)
238+
pad = -1 * xp.ones(
239+
(nb, nloc, nsels[-1] - nsel),
240+
dtype=nlist.dtype,
241+
device=array_api_compat.device(nlist),
242+
)
222243
nlist = xp.concat([nlist, pad], axis=-1)
223244
nsel = nsels[-1]
224245
coord1 = xp.reshape(coord, (nb, -1, 3))
@@ -276,7 +297,12 @@ def extend_coord_with_ghosts(
276297
xp = array_api_compat.array_namespace(coord, atype)
277298
nf, nloc = atype.shape
278299
# int64 for index
279-
aidx = xp.tile(xp.arange(nloc, dtype=xp.int64)[xp.newaxis, :], (nf, 1))
300+
aidx = xp.tile(
301+
xp.arange(nloc, dtype=xp.int64, device=array_api_compat.device(atype))[
302+
xp.newaxis, :
303+
],
304+
(nf, 1),
305+
)
280306
if cell is None:
281307
nall = nloc
282308
extend_coord = coord
@@ -288,17 +314,41 @@ def extend_coord_with_ghosts(
288314
to_face = to_face_distance(cell)
289315
nbuff = xp.astype(xp.ceil(rcut / to_face), xp.int64)
290316
nbuff = xp.max(nbuff, axis=0)
291-
xi = xp.arange(-int(nbuff[0]), int(nbuff[0]) + 1, 1, dtype=xp.int64)
292-
yi = xp.arange(-int(nbuff[1]), int(nbuff[1]) + 1, 1, dtype=xp.int64)
293-
zi = xp.arange(-int(nbuff[2]), int(nbuff[2]) + 1, 1, dtype=xp.int64)
294-
xyz = xp.linalg.outer(xi, xp.asarray([1, 0, 0]))[:, xp.newaxis, xp.newaxis, :]
317+
xi = xp.arange(
318+
-int(nbuff[0]),
319+
int(nbuff[0]) + 1,
320+
1,
321+
dtype=xp.int64,
322+
device=array_api_compat.device(coord),
323+
)
324+
yi = xp.arange(
325+
-int(nbuff[1]),
326+
int(nbuff[1]) + 1,
327+
1,
328+
dtype=xp.int64,
329+
device=array_api_compat.device(coord),
330+
)
331+
zi = xp.arange(
332+
-int(nbuff[2]),
333+
int(nbuff[2]) + 1,
334+
1,
335+
dtype=xp.int64,
336+
device=array_api_compat.device(coord),
337+
)
338+
xyz = xp.linalg.outer(
339+
xi, xp.asarray([1, 0, 0], device=array_api_compat.device(xi))
340+
)[:, xp.newaxis, xp.newaxis, :]
295341
xyz = (
296342
xyz
297-
+ xp.linalg.outer(yi, xp.asarray([0, 1, 0]))[xp.newaxis, :, xp.newaxis, :]
343+
+ xp.linalg.outer(
344+
yi, xp.asarray([0, 1, 0], device=array_api_compat.device(yi))
345+
)[xp.newaxis, :, xp.newaxis, :]
298346
)
299347
xyz = (
300348
xyz
301-
+ xp.linalg.outer(zi, xp.asarray([0, 0, 1]))[xp.newaxis, xp.newaxis, :, :]
349+
+ xp.linalg.outer(
350+
zi, xp.asarray([0, 0, 1], device=array_api_compat.device(zi))
351+
)[xp.newaxis, xp.newaxis, :, :]
302352
)
303353
xyz = xp.reshape(xyz, (-1, 3))
304354
xyz = xp.astype(xyz, coord.dtype)

deepmd/pt/utils/neighbor_stat.py

Lines changed: 4 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -6,111 +6,21 @@
66
import numpy as np
77
import torch
88

9+
from deepmd.dpmodel.utils.neighbor_stat import (
10+
NeighborStatOP,
11+
)
912
from deepmd.pt.utils.auto_batch_size import (
1013
AutoBatchSize,
1114
)
1215
from deepmd.pt.utils.env import (
1316
DEVICE,
1417
)
15-
from deepmd.pt.utils.nlist import (
16-
extend_coord_with_ghosts,
17-
)
1818
from deepmd.utils.data_system import (
1919
DeepmdDataSystem,
2020
)
2121
from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat
2222

2323

24-
class NeighborStatOP(torch.nn.Module):
25-
"""Class for getting neighbor statistics data information.
26-
27-
Parameters
28-
----------
29-
ntypes
30-
The num of atom types
31-
rcut
32-
The cut-off radius
33-
mixed_types : bool, optional
34-
If True, treat neighbors of all types as a single type.
35-
"""
36-
37-
def __init__(
38-
self,
39-
ntypes: int,
40-
rcut: float,
41-
mixed_types: bool,
42-
) -> None:
43-
super().__init__()
44-
self.rcut = float(rcut)
45-
self.ntypes = ntypes
46-
self.mixed_types = mixed_types
47-
48-
def forward(
49-
self,
50-
coord: torch.Tensor,
51-
atype: torch.Tensor,
52-
cell: torch.Tensor | None,
53-
) -> tuple[torch.Tensor, torch.Tensor]:
54-
"""Calculate the nearest neighbor distance between atoms, maximum nbor size of
55-
atoms and the output data range of the environment matrix.
56-
57-
Parameters
58-
----------
59-
coord
60-
The coordinates of atoms.
61-
atype
62-
The atom types.
63-
cell
64-
The cell.
65-
66-
Returns
67-
-------
68-
torch.Tensor
69-
The minimal squared distance between two atoms, in the shape of (nframes,)
70-
torch.Tensor
71-
The maximal number of neighbors
72-
"""
73-
nframes = coord.shape[0]
74-
coord = coord.view(nframes, -1, 3)
75-
nloc = coord.shape[1]
76-
coord = coord.view(nframes, nloc * 3)
77-
extend_coord, extend_atype, _ = extend_coord_with_ghosts(
78-
coord, atype, cell, self.rcut
79-
)
80-
81-
coord1 = extend_coord.reshape(nframes, -1)
82-
nall = coord1.shape[1] // 3
83-
coord0 = coord1[:, : nloc * 3]
84-
diff = (
85-
coord1.reshape([nframes, -1, 3])[:, None, :, :]
86-
- coord0.reshape([nframes, -1, 3])[:, :, None, :]
87-
)
88-
assert list(diff.shape) == [nframes, nloc, nall, 3]
89-
# remove the diagonal elements
90-
mask = torch.eye(nloc, nall, dtype=torch.bool, device=diff.device)
91-
diff[:, mask] = torch.inf
92-
rr2 = torch.sum(torch.square(diff), dim=-1)
93-
min_rr2, _ = torch.min(rr2, dim=-1)
94-
# count the number of neighbors
95-
if not self.mixed_types:
96-
mask = rr2 < self.rcut**2
97-
nnei = torch.zeros(
98-
(nframes, nloc, self.ntypes), dtype=torch.int32, device=mask.device
99-
)
100-
for ii in range(self.ntypes):
101-
nnei[:, :, ii] = torch.sum(
102-
mask & extend_atype.eq(ii)[:, None, :], dim=-1
103-
)
104-
else:
105-
mask = rr2 < self.rcut**2
106-
# virtual types (<0) are not counted
107-
nnei = torch.sum(mask & extend_atype.ge(0)[:, None, :], dim=-1).view(
108-
nframes, nloc, 1
109-
)
110-
max_nnei, _ = torch.max(nnei, dim=1)
111-
return min_rr2, max_nnei
112-
113-
11424
class NeighborStat(BaseNeighborStat):
11525
"""Neighbor statistics using pure NumPy.
11626
@@ -131,8 +41,7 @@ def __init__(
13141
mixed_type: bool = False,
13242
) -> None:
13343
super().__init__(ntypes, rcut, mixed_type)
134-
op = NeighborStatOP(ntypes, rcut, mixed_type)
135-
self.op = torch.jit.script(op)
44+
self.op = NeighborStatOP(ntypes, rcut, mixed_type)
13645
self.auto_batch_size = AutoBatchSize()
13746

13847
def iterator(

0 commit comments

Comments
 (0)