@@ -93,35 +93,72 @@ def build_neighbor_list(
9393
9494 """
9595 batch_size = coord .shape [0 ]
96- coord = coord .view (batch_size , - 1 )
96+ # coord is expected to be [batch_size, nall * 3]
97+ # The original line `coord = coord.view(batch_size, -1)` is a no-op if input is already 2D,
98+ # and input from `extend_input_and_build_neighbor_list` is `[nf, nall * 3]`.
99+ # So, it can be removed.
97100 nall = coord .shape [1 ] // 3
98101 # fill virtual atoms with large coords so they are not neighbors of any
99102 # real atom.
100103 if coord .numel () > 0 :
101- xmax = torch .max (coord ) + 2.0 * rcut
104+ xmax = torch .max (coord ) + 2.0 * rcut # coord is [batch_size, nall*3]
102105 else :
103106 xmax = torch .zeros (1 , dtype = coord .dtype , device = coord .device ) + 2.0 * rcut
104- # nf x nall
107+ # nf x nall (comment refers to batch_size x nall)
105108 is_vir = atype < 0
106- coord1 = torch .where (
107- is_vir [:, :, None ], xmax , coord .view (batch_size , nall , 3 )
108- ).view (batch_size , nall * 3 )
109+
110+ # Reshape coord to [batch_size, nall, 3] for easier manipulation
111+ coord_xyz = coord .view (batch_size , nall , 3 )
112+
113+ # Create a version of coordinates where virtual atoms are replaced by xmax
114+ # This tensor will have shape [batch_size, nall, 3]
115+ vcoord_xyz = torch .where (
116+ is_vir [:, :, None ], xmax , coord_xyz
117+ )
118+ # Original coord1 was:
119+ # coord1 = torch.where(
120+ # is_vir[:, :, None], xmax, coord.view(batch_size, nall, 3)
121+ # ).view(batch_size, nall * 3)
122+
109123 if isinstance (sel , int ):
110124 sel = [sel ]
111- # nloc x 3
112- coord0 = coord1 [:, : nloc * 3 ]
113- # nloc x nall x 3
114- diff = coord1 .view ([batch_size , - 1 , 3 ]).unsqueeze (1 ) - coord0 .view (
115- [batch_size , - 1 , 3 ]
116- ).unsqueeze (2 )
125+
126+ # Get the coordinates for the local atoms (first nloc atoms)
127+ # Shape: [batch_size, nloc, 3]
128+ vcoord_local_xyz = vcoord_xyz [:, :nloc , :]
129+ # Original coord0 was:
130+ # coord0 = coord1[:, : nloc * 3] # where coord1 was [batch_size, nall*3]
131+
132+ # Calculate displacement vectors.
133+ # vcoord_xyz.unsqueeze(1) gives [batch_size, 1, nall, 3]
134+ # vcoord_local_xyz.unsqueeze(2) gives [batch_size, nloc, 1, 3]
135+ # Broadcasting results in diff tensor of shape [batch_size, nloc, nall, 3]
136+ diff = vcoord_xyz .unsqueeze (1 ) - vcoord_local_xyz .unsqueeze (2 )
137+ # Original diff calculation that used views:
138+ # diff = coord1.view([batch_size, -1, 3]).unsqueeze(1) - coord0.view(
139+ # [batch_size, -1, 3]
140+ # ).unsqueeze(2)
117141 assert list (diff .shape ) == [batch_size , nloc , nall , 3 ]
118142 # nloc x nall
119143 rr = torch .linalg .norm (diff , dim = - 1 )
120144 # if central atom has two zero distances, sorting sometimes can not exclude itself
121- rr -= torch .eye (nloc , nall , dtype = rr .dtype , device = rr .device ).unsqueeze (0 )
145+ # The following operation makes rr[b, i, i] = -1.0 (assuming original self-distance is 0)
146+ # so that self-atom is sorted first.
147+ # Original line: rr -= torch.eye(nloc, nall, dtype=rr.dtype, device=rr.device).unsqueeze(0)
148+ # Efficiently subtract 1 from diagonal elements rr[b, i, i] for i < min(nloc, nall).
149+ # nall is rr.shape[2] here.
150+ diag_len = min (nloc , nall )
151+ if diag_len > 0 : # Ensure idx is not empty if using older PyTorch versions or for clarity
152+ idx = torch .arange (diag_len , device = rr .device )
153+ rr [:, idx , idx ] -= 1.0
122154 nsel = sum (sel )
123155 nnei = rr .shape [- 1 ]
124- rr , nlist = torch .topk (rr , min (nsel , nnei ), largest = False )
156+ # print(f"{nsel=}, {nnei=}")
157+ top_k = nsel if nsel <= nnei else nnei
158+ rr , nlist = torch .topk (rr , top_k + 1 , largest = False )
159+ # rr, nlist = torch.sort(rr, dim=-1) # FIXME
160+ # assert torch.allclose(rr, other=rr2[..., :top_k], atol=0)
161+
125162 # nloc x (nall-1)
126163 rr = rr [:, :, 1 :]
127164 nlist = nlist [:, :, 1 :]
0 commit comments