Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion deepmd/dpmodel/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def call(
)
assert list(diff.shape) == [nframes, nloc, nall, 3]
# remove the diagonal elements
mask = xp.eye(nloc, nall, dtype=xp.bool)
mask = xp.eye(nloc, nall, dtype=xp.bool, device=array_api_compat.device(diff))
mask = xp.tile(mask[None, :, :, None], (nframes, 1, 1, 3))
diff = xp.where(mask, xp.full_like(diff, xp.inf), diff)
rr2 = xp.sum(xp.square(diff), axis=-1)
Expand Down
72 changes: 61 additions & 11 deletions deepmd/dpmodel/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,9 @@ def build_neighbor_list(
assert list(diff.shape) == [batch_size, nloc, nall, 3]
rr = xp.linalg.vector_norm(diff, axis=-1)
# if central atom has two zero distances, sorting sometimes can not exclude itself
rr -= xp.eye(nloc, nall, dtype=diff.dtype)[xp.newaxis, :, :]
rr -= xp.eye(nloc, nall, dtype=diff.dtype, device=array_api_compat.device(diff))[
xp.newaxis, :, :
]
nlist = xp.argsort(rr, axis=-1)
rr = xp.sort(rr, axis=-1)
rr = rr[:, :, 1:]
Expand All @@ -128,11 +130,26 @@ def build_neighbor_list(
nlist = nlist[:, :, :nsel]
else:
rr = xp.concatenate(
[rr, xp.ones([batch_size, nloc, nsel - nnei], dtype=rr.dtype) + rcut],
[
rr,
xp.ones(
[batch_size, nloc, nsel - nnei],
dtype=rr.dtype,
device=array_api_compat.device(rr),
)
+ rcut,
],
axis=-1,
)
nlist = xp.concatenate(
[nlist, xp.ones([batch_size, nloc, nsel - nnei], dtype=nlist.dtype)],
[
nlist,
xp.ones(
[batch_size, nloc, nsel - nnei],
dtype=nlist.dtype,
device=array_api_compat.device(nlist),
),
],
axis=-1,
)
assert list(nlist.shape) == [batch_size, nloc, nsel]
Expand Down Expand Up @@ -218,7 +235,11 @@ def build_multiple_neighbor_list(
return {}
nb, nloc, nsel = nlist.shape
if nsel < nsels[-1]:
pad = -1 * xp.ones((nb, nloc, nsels[-1] - nsel), dtype=nlist.dtype)
pad = -1 * xp.ones(
(nb, nloc, nsels[-1] - nsel),
dtype=nlist.dtype,
device=array_api_compat.device(nlist),
)
nlist = xp.concat([nlist, pad], axis=-1)
nsel = nsels[-1]
coord1 = xp.reshape(coord, (nb, -1, 3))
Expand Down Expand Up @@ -276,7 +297,12 @@ def extend_coord_with_ghosts(
xp = array_api_compat.array_namespace(coord, atype)
nf, nloc = atype.shape
# int64 for index
aidx = xp.tile(xp.arange(nloc, dtype=xp.int64)[xp.newaxis, :], (nf, 1))
aidx = xp.tile(
xp.arange(nloc, dtype=xp.int64, device=array_api_compat.device(atype))[
xp.newaxis, :
],
(nf, 1),
)
if cell is None:
nall = nloc
extend_coord = coord
Expand All @@ -288,17 +314,41 @@ def extend_coord_with_ghosts(
to_face = to_face_distance(cell)
nbuff = xp.astype(xp.ceil(rcut / to_face), xp.int64)
nbuff = xp.max(nbuff, axis=0)
xi = xp.arange(-int(nbuff[0]), int(nbuff[0]) + 1, 1, dtype=xp.int64)
yi = xp.arange(-int(nbuff[1]), int(nbuff[1]) + 1, 1, dtype=xp.int64)
zi = xp.arange(-int(nbuff[2]), int(nbuff[2]) + 1, 1, dtype=xp.int64)
xyz = xp.linalg.outer(xi, xp.asarray([1, 0, 0]))[:, xp.newaxis, xp.newaxis, :]
xi = xp.arange(
-int(nbuff[0]),
int(nbuff[0]) + 1,
1,
dtype=xp.int64,
device=array_api_compat.device(coord),
)
yi = xp.arange(
-int(nbuff[1]),
int(nbuff[1]) + 1,
1,
dtype=xp.int64,
device=array_api_compat.device(coord),
)
zi = xp.arange(
-int(nbuff[2]),
int(nbuff[2]) + 1,
1,
dtype=xp.int64,
device=array_api_compat.device(coord),
)
xyz = xp.linalg.outer(
xi, xp.asarray([1, 0, 0], device=array_api_compat.device(xi))
)[:, xp.newaxis, xp.newaxis, :]
xyz = (
xyz
+ xp.linalg.outer(yi, xp.asarray([0, 1, 0]))[xp.newaxis, :, xp.newaxis, :]
+ xp.linalg.outer(
yi, xp.asarray([0, 1, 0], device=array_api_compat.device(yi))
)[xp.newaxis, :, xp.newaxis, :]
)
xyz = (
xyz
+ xp.linalg.outer(zi, xp.asarray([0, 0, 1]))[xp.newaxis, xp.newaxis, :, :]
+ xp.linalg.outer(
zi, xp.asarray([0, 0, 1], device=array_api_compat.device(zi))
)[xp.newaxis, xp.newaxis, :, :]
)
xyz = xp.reshape(xyz, (-1, 3))
xyz = xp.astype(xyz, coord.dtype)
Expand Down
99 changes: 4 additions & 95 deletions deepmd/pt/utils/neighbor_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,103 +12,13 @@
from deepmd.pt.utils.env import (
DEVICE,
)
from deepmd.pt.utils.nlist import (
extend_coord_with_ghosts,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
from deepmd.utils.neighbor_stat import NeighborStat as BaseNeighborStat


class NeighborStatOP(torch.nn.Module):
"""Class for getting neighbor statistics data information.

Parameters
----------
ntypes
The num of atom types
rcut
The cut-off radius
mixed_types : bool, optional
If True, treat neighbors of all types as a single type.
"""

def __init__(
self,
ntypes: int,
rcut: float,
mixed_types: bool,
) -> None:
super().__init__()
self.rcut = float(rcut)
self.ntypes = ntypes
self.mixed_types = mixed_types

def forward(
self,
coord: torch.Tensor,
atype: torch.Tensor,
cell: torch.Tensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Calculate the neareest neighbor distance between atoms, maximum nbor size of
atoms and the output data range of the environment matrix.

Parameters
----------
coord
The coordinates of atoms.
atype
The atom types.
cell
The cell.

Returns
-------
torch.Tensor
The minimal squared distance between two atoms, in the shape of (nframes,)
torch.Tensor
The maximal number of neighbors
"""
nframes = coord.shape[0]
coord = coord.view(nframes, -1, 3)
nloc = coord.shape[1]
coord = coord.view(nframes, nloc * 3)
extend_coord, extend_atype, _ = extend_coord_with_ghosts(
coord, atype, cell, self.rcut
)

coord1 = extend_coord.reshape(nframes, -1)
nall = coord1.shape[1] // 3
coord0 = coord1[:, : nloc * 3]
diff = (
coord1.reshape([nframes, -1, 3])[:, None, :, :]
- coord0.reshape([nframes, -1, 3])[:, :, None, :]
)
assert list(diff.shape) == [nframes, nloc, nall, 3]
# remove the diagonal elements
mask = torch.eye(nloc, nall, dtype=torch.bool, device=diff.device)
diff[:, mask] = torch.inf
rr2 = torch.sum(torch.square(diff), dim=-1)
min_rr2, _ = torch.min(rr2, dim=-1)
# count the number of neighbors
if not self.mixed_types:
mask = rr2 < self.rcut**2
nnei = torch.zeros(
(nframes, nloc, self.ntypes), dtype=torch.int32, device=mask.device
)
for ii in range(self.ntypes):
nnei[:, :, ii] = torch.sum(
mask & extend_atype.eq(ii)[:, None, :], dim=-1
)
else:
mask = rr2 < self.rcut**2
# virtual types (<0) are not counted
nnei = torch.sum(mask & extend_atype.ge(0)[:, None, :], dim=-1).view(
nframes, nloc, 1
)
max_nnei, _ = torch.max(nnei, dim=1)
return min_rr2, max_nnei
from deepmd.dpmodel.utils.neighbor_stat import (
NeighborStatOP,
)


class NeighborStat(BaseNeighborStat):
Expand All @@ -131,8 +41,7 @@ def __init__(
mixed_type: bool = False,
) -> None:
super().__init__(ntypes, rcut, mixed_type)
op = NeighborStatOP(ntypes, rcut, mixed_type)
self.op = torch.jit.script(op)
self.op = NeighborStatOP(ntypes, rcut, mixed_type)
self.auto_batch_size = AutoBatchSize()

def iterator(
Expand Down
Loading