Skip to content

Commit 88f090b

Browse files
committed
use index to replace torch.eye
1 parent 9ef3723 commit 88f090b

File tree

1 file changed

+51
-14
lines changed

1 file changed

+51
-14
lines changed

deepmd/pt/utils/nlist.py

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -93,35 +93,72 @@ def build_neighbor_list(
9393
9494
"""
9595
batch_size = coord.shape[0]
96-
coord = coord.view(batch_size, -1)
96+
# coord is expected to be [batch_size, nall * 3]
97+
# The original line `coord = coord.view(batch_size, -1)` is a no-op if input is already 2D,
98+
# and input from `extend_input_and_build_neighbor_list` is `[nf, nall * 3]`.
99+
# So, it can be removed.
97100
nall = coord.shape[1] // 3
98101
# fill virtual atoms with large coords so they are not neighbors of any
99102
# real atom.
100103
if coord.numel() > 0:
101-
xmax = torch.max(coord) + 2.0 * rcut
104+
xmax = torch.max(coord) + 2.0 * rcut # coord is [batch_size, nall*3]
102105
else:
103106
xmax = torch.zeros(1, dtype=coord.dtype, device=coord.device) + 2.0 * rcut
104-
# nf x nall
107+
# nf x nall (comment refers to batch_size x nall)
105108
is_vir = atype < 0
106-
coord1 = torch.where(
107-
is_vir[:, :, None], xmax, coord.view(batch_size, nall, 3)
108-
).view(batch_size, nall * 3)
109+
110+
# Reshape coord to [batch_size, nall, 3] for easier manipulation
111+
coord_xyz = coord.view(batch_size, nall, 3)
112+
113+
# Create a version of coordinates where virtual atoms are replaced by xmax
114+
# This tensor will have shape [batch_size, nall, 3]
115+
vcoord_xyz = torch.where(
116+
is_vir[:, :, None], xmax, coord_xyz
117+
)
118+
# Original coord1 was:
119+
# coord1 = torch.where(
120+
# is_vir[:, :, None], xmax, coord.view(batch_size, nall, 3)
121+
# ).view(batch_size, nall * 3)
122+
109123
if isinstance(sel, int):
110124
sel = [sel]
111-
# nloc x 3
112-
coord0 = coord1[:, : nloc * 3]
113-
# nloc x nall x 3
114-
diff = coord1.view([batch_size, -1, 3]).unsqueeze(1) - coord0.view(
115-
[batch_size, -1, 3]
116-
).unsqueeze(2)
125+
126+
# Get the coordinates for the local atoms (first nloc atoms)
127+
# Shape: [batch_size, nloc, 3]
128+
vcoord_local_xyz = vcoord_xyz[:, :nloc, :]
129+
# Original coord0 was:
130+
# coord0 = coord1[:, : nloc * 3] # where coord1 was [batch_size, nall*3]
131+
132+
# Calculate displacement vectors.
133+
# vcoord_xyz.unsqueeze(1) gives [batch_size, 1, nall, 3]
134+
# vcoord_local_xyz.unsqueeze(2) gives [batch_size, nloc, 1, 3]
135+
# Broadcasting results in diff tensor of shape [batch_size, nloc, nall, 3]
136+
diff = vcoord_xyz.unsqueeze(1) - vcoord_local_xyz.unsqueeze(2)
137+
# Original diff calculation that used views:
138+
# diff = coord1.view([batch_size, -1, 3]).unsqueeze(1) - coord0.view(
139+
# [batch_size, -1, 3]
140+
# ).unsqueeze(2)
117141
assert list(diff.shape) == [batch_size, nloc, nall, 3]
118142
# nloc x nall
119143
rr = torch.linalg.norm(diff, dim=-1)
120144
# if central atom has two zero distances, sorting sometimes can not exclude itself
121-
rr -= torch.eye(nloc, nall, dtype=rr.dtype, device=rr.device).unsqueeze(0)
145+
# The following operation makes rr[b, i, i] = -1.0 (assuming original self-distance is 0)
146+
# so that self-atom is sorted first.
147+
# Original line: rr -= torch.eye(nloc, nall, dtype=rr.dtype, device=rr.device).unsqueeze(0)
148+
# Efficiently subtract 1 from diagonal elements rr[b, i, i] for i < min(nloc, nall).
149+
# nall is rr.shape[2] here.
150+
diag_len = min(nloc, nall)
151+
if diag_len > 0: # Ensure idx is not empty if using older PyTorch versions or for clarity
152+
idx = torch.arange(diag_len, device=rr.device)
153+
rr[:, idx, idx] -= 1.0
122154
nsel = sum(sel)
123155
nnei = rr.shape[-1]
124-
rr, nlist = torch.topk(rr, min(nsel, nnei), largest=False)
156+
# print(f"{nsel=}, {nnei=}")
157+
top_k = nsel if nsel <= nnei else nnei
158+
rr, nlist = torch.topk(rr, top_k+1, largest=False)
159+
# rr, nlist = torch.sort(rr, dim=-1) # FIXME
160+
# assert torch.allclose(rr, other=rr2[..., :top_k], atol=0)
161+
125162
# nloc x (nall-1)
126163
rr = rr[:, :, 1:]
127164
nlist = nlist[:, :, 1:]

0 commit comments

Comments
 (0)