Skip to content
36 changes: 23 additions & 13 deletions deepmd/pt/utils/nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,33 +93,43 @@ def build_neighbor_list(

"""
batch_size = coord.shape[0]
coord = coord.view(batch_size, -1)
nall = coord.shape[1] // 3
# fill virtual atoms with large coords so they are not neighbors of any
# real atom.
if coord.numel() > 0:
xmax = torch.max(coord) + 2.0 * rcut
else:
xmax = torch.zeros(1, dtype=coord.dtype, device=coord.device) + 2.0 * rcut

coord_xyz = coord.view(batch_size, nall, 3)
# nf x nall
is_vir = atype < 0
coord1 = torch.where(
is_vir[:, :, None], xmax, coord.view(batch_size, nall, 3)
).view(batch_size, nall * 3)
# batch_size x nall x 3
vcoord_xyz = torch.where(is_vir[:, :, None], xmax, coord_xyz)
if isinstance(sel, int):
sel = [sel]
# nloc x 3
coord0 = coord1[:, : nloc * 3]
# nloc x nall x 3
diff = coord1.view([batch_size, -1, 3]).unsqueeze(1) - coord0.view(
[batch_size, -1, 3]
).unsqueeze(2)
assert list(diff.shape) == [batch_size, nloc, nall, 3]

# Get the coordinates for the local atoms (first nloc atoms)
# batch_size x nloc x 3
vcoord_local_xyz = vcoord_xyz[:, :nloc, :]

# Calculate displacement vectors.
diff = vcoord_xyz.unsqueeze(1) - vcoord_local_xyz.unsqueeze(2)
assert diff.shape == (batch_size, nloc, nall, 3)
# nloc x nall
rr = torch.linalg.norm(diff, dim=-1)
# if central atom has two zero distances, sorting sometimes can not exclude itself
rr -= torch.eye(nloc, nall, dtype=rr.dtype, device=rr.device).unsqueeze(0)
rr, nlist = torch.sort(rr, dim=-1)
# The following operation makes rr[b, i, i] = -1.0 (assuming original self-distance is 0)
# so that self-atom is sorted first.
diag_len = min(nloc, nall)
idx = torch.arange(diag_len, device=rr.device, dtype=torch.int)
rr[:, idx, idx] -= 1.0

nsel = sum(sel)
nnei = rr.shape[-1]
top_k = min(nsel + 1, nnei)
rr, nlist = torch.topk(rr, top_k, largest=False)

# nloc x (nall-1)
rr = rr[:, :, 1:]
nlist = nlist[:, :, 1:]
Expand Down
12 changes: 8 additions & 4 deletions source/tests/pt/model/test_nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,13 +152,17 @@ def test_build_multiple_nlist(self) -> None:
nlists[get_multiple_nlist_key(rcuts[dd], nsels[dd])].shape[-1],
nsels[dd],
)

# Since the nlist is created using unstable sort,
# we check if the set of indices in the nlist matches,
# regardless of the order
torch.testing.assert_close(
nlists[get_multiple_nlist_key(rcuts[0], nsels[0])],
nlist0,
nlists[get_multiple_nlist_key(rcuts[0], nsels[0])].sort(dim=-1).values,
nlist0.sort(dim=-1).values,
)
torch.testing.assert_close(
nlists[get_multiple_nlist_key(rcuts[1], nsels[1])],
nlist2,
nlists[get_multiple_nlist_key(rcuts[1], nsels[1])].sort(dim=-1).values,
nlist2.sort(dim=-1).values,
)

def test_extend_coord(self) -> None:
Expand Down
Loading