Skip to content

Commit 9ef3723

Browse files
committed
perf: use torch.topk to construct nlist
1 parent 43e0288 commit 9ef3723

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

deepmd/pt/utils/nlist.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,9 @@ def build_neighbor_list(
119119
rr = torch.linalg.norm(diff, dim=-1)
120120
# if central atom has two zero distances, sorting sometimes can not exclude itself
121121
rr -= torch.eye(nloc, nall, dtype=rr.dtype, device=rr.device).unsqueeze(0)
122-
rr, nlist = torch.sort(rr, dim=-1)
122+
nsel = sum(sel)
123+
nnei = rr.shape[-1]
124+
rr, nlist = torch.topk(rr, min(nsel, nnei), largest=False)
123125
# nloc x (nall-1)
124126
rr = rr[:, :, 1:]
125127
nlist = nlist[:, :, 1:]

0 commit comments

Comments
 (0)