|
| 1 | +# SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +import math |
| 3 | +from typing import ( |
| 4 | + Optional, |
| 5 | +) |
| 6 | + |
| 7 | +import torch |
| 8 | + |
| 9 | +from deepmd.pt.utils import ( |
| 10 | + env, |
| 11 | +) |
| 12 | + |
| 13 | + |
| 14 | +@torch.jit.export |
| 15 | +def aggregate( |
| 16 | + data: torch.Tensor, |
| 17 | + owners: torch.Tensor, |
| 18 | + average: bool = True, |
| 19 | + num_owner: Optional[int] = None, |
| 20 | +) -> torch.Tensor: |
| 21 | + """ |
| 22 | + Aggregate rows in data by specifying the owners. |
| 23 | +
|
| 24 | + Parameters |
| 25 | + ---------- |
| 26 | + data : data tensor to aggregate [n_row, feature_dim] |
| 27 | + owners : specify the owner of each row [n_row, 1] |
| 28 | + average : if True, average the rows, if False, sum the rows. |
| 29 | + Default = True |
| 30 | + num_owner : the number of owners, this is needed if the |
| 31 | + max idx of owner is not presented in owners tensor |
| 32 | + Default = None |
| 33 | +
|
| 34 | + Returns |
| 35 | + ------- |
| 36 | + output: [num_owner, feature_dim] |
| 37 | + """ |
| 38 | + bin_count = torch.bincount(owners) |
| 39 | + bin_count = bin_count.where(bin_count != 0, bin_count.new_ones(1)) |
| 40 | + |
| 41 | + if (num_owner is not None) and (bin_count.shape[0] != num_owner): |
| 42 | + difference = num_owner - bin_count.shape[0] |
| 43 | + bin_count = torch.cat([bin_count, bin_count.new_ones(difference)]) |
| 44 | + |
| 45 | + # make sure this operation is done on the same device of data and owners |
| 46 | + output = data.new_zeros([bin_count.shape[0], data.shape[1]]) |
| 47 | + output = output.index_add_(0, owners, data) |
| 48 | + if average: |
| 49 | + output = (output.T / bin_count).T |
| 50 | + return output |
| 51 | + |
| 52 | + |
| 53 | +@torch.jit.export |
| 54 | +def get_graph_index( |
| 55 | + nlist: torch.Tensor, |
| 56 | + nlist_mask: torch.Tensor, |
| 57 | + a_nlist_mask: torch.Tensor, |
| 58 | + d_nlist_mask: torch.Tensor, |
| 59 | + nall: int, |
| 60 | + calculate_dihedral: bool = False, |
| 61 | + use_loc_mapping: bool = True, |
| 62 | +): |
| 63 | + """ |
| 64 | + Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`. |
| 65 | +
|
| 66 | + Parameters |
| 67 | + ---------- |
| 68 | + nlist : nf x nloc x nnei |
| 69 | + Neighbor list. (padded neis are set to 0) |
| 70 | + nlist_mask : nf x nloc x nnei |
| 71 | + Masks of the neighbor list. real nei 1 otherwise 0 |
| 72 | + a_nlist_mask : nf x nloc x a_nnei |
| 73 | + Masks of the neighbor list for angle. real nei 1 otherwise 0 |
| 74 | + nall |
| 75 | + The number of extended atoms. |
| 76 | +
|
| 77 | + Returns |
| 78 | + ------- |
| 79 | + edge_index : n_edge x 2 |
| 80 | + n2e_index : n_edge |
| 81 | + Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i). |
| 82 | + n_ext2e_index : n_edge |
| 83 | + Broadcast indices from extended node(j) to edge(ij). |
| 84 | + angle_index : n_angle x 3 |
| 85 | + n2a_index : n_angle |
| 86 | + Broadcast indices from extended node(j) to angle(ijk). |
| 87 | + eij2a_index : n_angle |
| 88 | + Broadcast indices from edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij). |
| 89 | + eik2a_index : n_angle |
| 90 | + Broadcast indices from edge(ik) to angle(ijk). |
| 91 | + dihedral_index : n_dihedral x 2 |
| 92 | + aijk2d_index : n_dihedral |
| 93 | + Broadcast indices from angle(ijk) to dihedral(ijkl), or reduction indices from dihedral(ijkl) to angle(ijk). |
| 94 | + aijl2d_index : n_dihedral |
| 95 | + Broadcast indices from angle(ijl) to dihedral(ijkl). |
| 96 | + """ |
| 97 | + nf, nloc, nnei = nlist.shape |
| 98 | + _, _, a_nnei = a_nlist_mask.shape |
| 99 | + # nf x nloc x nnei x nnei |
| 100 | + # nlist_mask_3d = nlist_mask[:, :, :, None] & nlist_mask[:, :, None, :] |
| 101 | + a_nlist_mask_3d = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :] |
| 102 | + n_edge = nlist_mask.sum().item() |
| 103 | + |
| 104 | + # following: get n2e_index, n_ext2e_index, n2a_index, eij2a_index, eik2a_index |
| 105 | + |
| 106 | + # 1. atom graph |
| 107 | + # node(i) to edge(ij) index_select; edge(ij) to node aggregate |
| 108 | + nlist_loc_index = torch.arange(0, nf * nloc, dtype=nlist.dtype, device=nlist.device) |
| 109 | + # nf x nloc x nnei |
| 110 | + n2e_index = nlist_loc_index.reshape(nf, nloc, 1).expand(-1, -1, nnei) |
| 111 | + # n_edge |
| 112 | + n2e_index = n2e_index[nlist_mask] # graph node index, atom_graph[:, 0] |
| 113 | + |
| 114 | + # node_ext(j) to edge(ij) index_select |
| 115 | + frame_shift = torch.arange(0, nf, dtype=nlist.dtype, device=nlist.device) * ( |
| 116 | + nall if not use_loc_mapping else nloc |
| 117 | + ) |
| 118 | + shifted_nlist = nlist + frame_shift[:, None, None] |
| 119 | + # n_edge |
| 120 | + n_ext2e_index = shifted_nlist[nlist_mask] # graph neighbor index, atom_graph[:, 1] |
| 121 | + |
| 122 | + # 2. edge graph |
| 123 | + # node(i) to angle(ijk) index_select |
| 124 | + n2a_index = nlist_loc_index.reshape(nf, nloc, 1, 1).expand(-1, -1, a_nnei, a_nnei) |
| 125 | + # n_angle |
| 126 | + n2a_index = n2a_index[a_nlist_mask_3d] |
| 127 | + |
| 128 | + # edge(ij) to angle(ijk) index_select; angle(ijk) to edge(ij) aggregate |
| 129 | + edge_id = torch.arange(0, n_edge, dtype=nlist.dtype, device=nlist.device) |
| 130 | + # nf x nloc x nnei |
| 131 | + edge_index = torch.zeros([nf, nloc, nnei], dtype=nlist.dtype, device=nlist.device) |
| 132 | + edge_index[nlist_mask] = edge_id |
| 133 | + # only cut a_nnei neighbors, to avoid nnei x nnei |
| 134 | + edge_index = edge_index[:, :, :a_nnei] |
| 135 | + edge_index_ij = edge_index.unsqueeze(-1).expand(-1, -1, -1, a_nnei) |
| 136 | + # n_angle |
| 137 | + eij2a_index = edge_index_ij[a_nlist_mask_3d] |
| 138 | + |
| 139 | + # edge(ik) to angle(ijk) index_select |
| 140 | + edge_index_ik = edge_index.unsqueeze(-2).expand(-1, -1, a_nnei, -1) |
| 141 | + # n_angle |
| 142 | + eik2a_index = edge_index_ik[a_nlist_mask_3d] |
| 143 | + |
| 144 | + if calculate_dihedral: |
| 145 | + # 3. angle graph |
| 146 | + n_angle = a_nlist_mask_3d.sum().item() |
| 147 | + _, _, d_nnei = d_nlist_mask.shape |
| 148 | + |
| 149 | + # nf x nloc x d_nnei x d_nnei x d_nnei |
| 150 | + # should expel same j k l |
| 151 | + d_nlist_mask_4d = ( |
| 152 | + d_nlist_mask[:, :, :, None, None] |
| 153 | + & d_nlist_mask[:, :, None, :, None] |
| 154 | + & d_nlist_mask[:, :, None, None, :] |
| 155 | + ) |
| 156 | + # d_nnei x d_nnei |
| 157 | + d_eye = torch.eye(d_nnei, dtype=d_nlist_mask.dtype, device=d_nlist_mask.device) |
| 158 | + d_eye = d_eye[:, :, None] | d_eye[:, None, :] | d_eye[None, :, :] |
| 159 | + d_nlist_mask_4d = d_nlist_mask_4d & ~d_eye[None, None, ...] |
| 160 | + |
| 161 | + # angle(ijk) to dihedral(ijkl) index_select; dihedral(ijkl) to angle(ijk) aggregate |
| 162 | + angle_id = torch.arange(0, n_angle, dtype=nlist.dtype, device=nlist.device) |
| 163 | + # nf x nloc x a_nnei x a_nnei |
| 164 | + angle_index = torch.zeros( |
| 165 | + [nf, nloc, a_nnei, a_nnei], dtype=nlist.dtype, device=nlist.device |
| 166 | + ) |
| 167 | + angle_index[a_nlist_mask_3d] = angle_id |
| 168 | + |
| 169 | + # only cut d_nnei neighbors, to avoid a_nnei x a_nnei x a_nnei |
| 170 | + angle_index = angle_index[:, :, :d_nnei, :d_nnei] |
| 171 | + angle_index_ijk = angle_index.unsqueeze(-1).expand(-1, -1, -1, -1, d_nnei) |
| 172 | + # n_dihedral |
| 173 | + aijk2d_index = angle_index_ijk[d_nlist_mask_4d] |
| 174 | + |
| 175 | + # angle(ijl) to dihedral(ijkl) index_select; |
| 176 | + angle_index_ijl = angle_index.unsqueeze(-2).expand(-1, -1, -1, d_nnei, -1) |
| 177 | + # n_dihedral |
| 178 | + aijl2d_index = angle_index_ijl[d_nlist_mask_4d] |
| 179 | + |
| 180 | + dihedral_index = torch.cat( |
| 181 | + [aijk2d_index.unsqueeze(-1), aijl2d_index.unsqueeze(-1)], dim=-1 |
| 182 | + ) |
| 183 | + else: |
| 184 | + dihedral_index = None |
| 185 | + d_nlist_mask_4d = None |
| 186 | + |
| 187 | + return ( |
| 188 | + torch.cat([n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], dim=-1), |
| 189 | + torch.cat( |
| 190 | + [ |
| 191 | + n2a_index.unsqueeze(-1), |
| 192 | + eij2a_index.unsqueeze(-1), |
| 193 | + eik2a_index.unsqueeze(-1), |
| 194 | + ], |
| 195 | + dim=-1, |
| 196 | + ), |
| 197 | + dihedral_index, |
| 198 | + a_nlist_mask_3d, |
| 199 | + d_nlist_mask_4d, |
| 200 | + ) |
| 201 | + |
| 202 | + |
| 203 | +class BesselBasis(torch.nn.Module): |
| 204 | + """f : (*, 1) -> (*, bessel_basis_num).""" |
| 205 | + |
| 206 | + def __init__( |
| 207 | + self, |
| 208 | + cutoff_length: float, |
| 209 | + bessel_basis_num: int = 8, |
| 210 | + trainable_coeff: bool = True, |
| 211 | + ): |
| 212 | + super().__init__() |
| 213 | + self.num_basis = bessel_basis_num |
| 214 | + self.prefactor = 2.0 / cutoff_length |
| 215 | + self.coeffs = torch.FloatTensor( |
| 216 | + [n * math.pi / cutoff_length for n in range(1, bessel_basis_num + 1)] |
| 217 | + ) |
| 218 | + if trainable_coeff: |
| 219 | + self.coeffs = torch.nn.Parameter(self.coeffs) |
| 220 | + |
| 221 | + def forward(self, r: torch.Tensor) -> torch.Tensor: |
| 222 | + return self.prefactor * torch.sin(self.coeffs * r) / (r + 1e-8) |
| 223 | + |
| 224 | + |
| 225 | +class GaussianSmearing(torch.nn.Module): |
| 226 | + def __init__( |
| 227 | + self, |
| 228 | + start: float = -5.0, |
| 229 | + stop: float = 5.0, |
| 230 | + num_gaussians: int = 50, |
| 231 | + basis_width_scalar: float = 1.0, |
| 232 | + ) -> None: |
| 233 | + super().__init__() |
| 234 | + self.num_output = num_gaussians |
| 235 | + offset = torch.linspace( |
| 236 | + start, stop, num_gaussians, device=env.DEVICE, dtype=torch.float32 |
| 237 | + ) |
| 238 | + self.coeff = -0.5 / (basis_width_scalar * (offset[1] - offset[0])).item() ** 2 |
| 239 | + self.register_buffer("offset", offset) |
| 240 | + |
| 241 | + def forward(self, dist) -> torch.Tensor: |
| 242 | + dist = dist - self.offset |
| 243 | + return torch.exp(self.coeff * torch.pow(dist, 2)) |
| 244 | + |
| 245 | + |
| 246 | +class RadialMLP(torch.nn.Module): |
| 247 | + """Contruct a radial function (linear layers + layer normalization + SiLU) given a list of channels.""" |
| 248 | + |
| 249 | + def __init__(self, channels_list) -> None: |
| 250 | + super().__init__() |
| 251 | + modules = [] |
| 252 | + input_channels = channels_list[0] |
| 253 | + for i in range(len(channels_list)): |
| 254 | + if i == 0: |
| 255 | + continue |
| 256 | + |
| 257 | + modules.append(torch.nn.Linear(input_channels, channels_list[i], bias=True)) |
| 258 | + input_channels = channels_list[i] |
| 259 | + |
| 260 | + if i == len(channels_list) - 1: |
| 261 | + break |
| 262 | + |
| 263 | + modules.append(torch.nn.LayerNorm(channels_list[i])) |
| 264 | + modules.append(torch.nn.SiLU()) |
| 265 | + |
| 266 | + self.net = torch.nn.Sequential(*modules) |
| 267 | + |
| 268 | + def forward(self, inputs: torch.Tensor) -> torch.Tensor: |
| 269 | + return self.net(inputs) |
| 270 | + |
| 271 | + |
| 272 | +class PolynomialEnvelope(torch.nn.Module): |
| 273 | + """Polynomial envelope function that ensures a smooth cutoff.""" |
| 274 | + |
| 275 | + def __init__(self, exponent: int = 5) -> None: |
| 276 | + super().__init__() |
| 277 | + assert exponent > 0 |
| 278 | + self.p: float = float(exponent) |
| 279 | + self.a: float = -(self.p + 1) * (self.p + 2) / 2 |
| 280 | + self.b: float = self.p * (self.p + 2) |
| 281 | + self.c: float = -self.p * (self.p + 1) / 2 |
| 282 | + |
| 283 | + def forward(self, d_scaled: torch.Tensor) -> torch.Tensor: |
| 284 | + env_val = ( |
| 285 | + 1 |
| 286 | + + self.a * d_scaled**self.p |
| 287 | + + self.b * d_scaled ** (self.p + 1) |
| 288 | + + self.c * d_scaled ** (self.p + 2) |
| 289 | + ) |
| 290 | + return torch.where(d_scaled < 1, env_val, torch.zeros_like(d_scaled)) |
0 commit comments