Skip to content

Commit a349de1

Browse files
iProzdcaic99
authored andcommitted
feat(pt): add use_loc_mapping
1 parent ac6677e commit a349de1

6 files changed

Lines changed: 566 additions & 1 deletion

File tree

deepmd/pt/model/descriptor/dpa3.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ def init_subclass_params(sub_data, sub_class):
162162
)
163163

164164
self.use_econf_tebd = use_econf_tebd
165+
self.use_loc_mapping = use_loc_mapping
165166
self.use_tebd_bias = use_tebd_bias
166167
self.type_map = type_map
167168
self.tebd_dim = self.repflow_args.n_dim
@@ -472,6 +473,7 @@ def forward(
472473
The smooth switch function. shape: nf x nloc x nnei
473474
474475
"""
476+
parrallel_mode = comm_dict is not None
475477
# cast the input to internal precsion
476478
extended_coord = extended_coord.to(dtype=self.prec)
477479
nframes, nloc, nnei = nlist.shape

deepmd/pt/model/descriptor/repflows.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,7 @@ def forward(
476476
else:
477477
node_ebd_ext = None
478478
else:
479+
assert comm_dict is not None
479480
has_spin = "has_spin" in comm_dict
480481
if not has_spin:
481482
n_padding = nall - nloc

deepmd/pt/model/network/utils.py

Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import math
3+
from typing import (
4+
Optional,
5+
)
6+
7+
import torch
8+
9+
from deepmd.pt.utils import (
10+
env,
11+
)
12+
13+
14+
@torch.jit.export
15+
def aggregate(
16+
data: torch.Tensor,
17+
owners: torch.Tensor,
18+
average: bool = True,
19+
num_owner: Optional[int] = None,
20+
) -> torch.Tensor:
21+
"""
22+
Aggregate rows in data by specifying the owners.
23+
24+
Parameters
25+
----------
26+
data : data tensor to aggregate [n_row, feature_dim]
27+
owners : specify the owner of each row [n_row, 1]
28+
average : if True, average the rows, if False, sum the rows.
29+
Default = True
30+
num_owner : the number of owners, this is needed if the
31+
max idx of owner is not presented in owners tensor
32+
Default = None
33+
34+
Returns
35+
-------
36+
output: [num_owner, feature_dim]
37+
"""
38+
bin_count = torch.bincount(owners)
39+
bin_count = bin_count.where(bin_count != 0, bin_count.new_ones(1))
40+
41+
if (num_owner is not None) and (bin_count.shape[0] != num_owner):
42+
difference = num_owner - bin_count.shape[0]
43+
bin_count = torch.cat([bin_count, bin_count.new_ones(difference)])
44+
45+
# make sure this operation is done on the same device of data and owners
46+
output = data.new_zeros([bin_count.shape[0], data.shape[1]])
47+
output = output.index_add_(0, owners, data)
48+
if average:
49+
output = (output.T / bin_count).T
50+
return output
51+
52+
53+
@torch.jit.export
54+
def get_graph_index(
55+
nlist: torch.Tensor,
56+
nlist_mask: torch.Tensor,
57+
a_nlist_mask: torch.Tensor,
58+
d_nlist_mask: torch.Tensor,
59+
nall: int,
60+
calculate_dihedral: bool = False,
61+
use_loc_mapping: bool = True,
62+
):
63+
"""
64+
Get the index mapping for edge graph and angle graph, ready in `aggregate` or `index_select`.
65+
66+
Parameters
67+
----------
68+
nlist : nf x nloc x nnei
69+
Neighbor list. (padded neis are set to 0)
70+
nlist_mask : nf x nloc x nnei
71+
Masks of the neighbor list. real nei 1 otherwise 0
72+
a_nlist_mask : nf x nloc x a_nnei
73+
Masks of the neighbor list for angle. real nei 1 otherwise 0
74+
nall
75+
The number of extended atoms.
76+
77+
Returns
78+
-------
79+
edge_index : n_edge x 2
80+
n2e_index : n_edge
81+
Broadcast indices from node(i) to edge(ij), or reduction indices from edge(ij) to node(i).
82+
n_ext2e_index : n_edge
83+
Broadcast indices from extended node(j) to edge(ij).
84+
angle_index : n_angle x 3
85+
n2a_index : n_angle
86+
Broadcast indices from extended node(j) to angle(ijk).
87+
eij2a_index : n_angle
88+
Broadcast indices from edge(ij) to angle(ijk), or reduction indices from angle(ijk) to edge(ij).
89+
eik2a_index : n_angle
90+
Broadcast indices from edge(ik) to angle(ijk).
91+
dihedral_index : n_dihedral x 2
92+
aijk2d_index : n_dihedral
93+
Broadcast indices from angle(ijk) to dihedral(ijkl), or reduction indices from dihedral(ijkl) to angle(ijk).
94+
aijl2d_index : n_dihedral
95+
Broadcast indices from angle(ijl) to dihedral(ijkl).
96+
"""
97+
nf, nloc, nnei = nlist.shape
98+
_, _, a_nnei = a_nlist_mask.shape
99+
# nf x nloc x nnei x nnei
100+
# nlist_mask_3d = nlist_mask[:, :, :, None] & nlist_mask[:, :, None, :]
101+
a_nlist_mask_3d = a_nlist_mask[:, :, :, None] & a_nlist_mask[:, :, None, :]
102+
n_edge = nlist_mask.sum().item()
103+
104+
# following: get n2e_index, n_ext2e_index, n2a_index, eij2a_index, eik2a_index
105+
106+
# 1. atom graph
107+
# node(i) to edge(ij) index_select; edge(ij) to node aggregate
108+
nlist_loc_index = torch.arange(0, nf * nloc, dtype=nlist.dtype, device=nlist.device)
109+
# nf x nloc x nnei
110+
n2e_index = nlist_loc_index.reshape(nf, nloc, 1).expand(-1, -1, nnei)
111+
# n_edge
112+
n2e_index = n2e_index[nlist_mask] # graph node index, atom_graph[:, 0]
113+
114+
# node_ext(j) to edge(ij) index_select
115+
frame_shift = torch.arange(0, nf, dtype=nlist.dtype, device=nlist.device) * (
116+
nall if not use_loc_mapping else nloc
117+
)
118+
shifted_nlist = nlist + frame_shift[:, None, None]
119+
# n_edge
120+
n_ext2e_index = shifted_nlist[nlist_mask] # graph neighbor index, atom_graph[:, 1]
121+
122+
# 2. edge graph
123+
# node(i) to angle(ijk) index_select
124+
n2a_index = nlist_loc_index.reshape(nf, nloc, 1, 1).expand(-1, -1, a_nnei, a_nnei)
125+
# n_angle
126+
n2a_index = n2a_index[a_nlist_mask_3d]
127+
128+
# edge(ij) to angle(ijk) index_select; angle(ijk) to edge(ij) aggregate
129+
edge_id = torch.arange(0, n_edge, dtype=nlist.dtype, device=nlist.device)
130+
# nf x nloc x nnei
131+
edge_index = torch.zeros([nf, nloc, nnei], dtype=nlist.dtype, device=nlist.device)
132+
edge_index[nlist_mask] = edge_id
133+
# only cut a_nnei neighbors, to avoid nnei x nnei
134+
edge_index = edge_index[:, :, :a_nnei]
135+
edge_index_ij = edge_index.unsqueeze(-1).expand(-1, -1, -1, a_nnei)
136+
# n_angle
137+
eij2a_index = edge_index_ij[a_nlist_mask_3d]
138+
139+
# edge(ik) to angle(ijk) index_select
140+
edge_index_ik = edge_index.unsqueeze(-2).expand(-1, -1, a_nnei, -1)
141+
# n_angle
142+
eik2a_index = edge_index_ik[a_nlist_mask_3d]
143+
144+
if calculate_dihedral:
145+
# 3. angle graph
146+
n_angle = a_nlist_mask_3d.sum().item()
147+
_, _, d_nnei = d_nlist_mask.shape
148+
149+
# nf x nloc x d_nnei x d_nnei x d_nnei
150+
# should expel same j k l
151+
d_nlist_mask_4d = (
152+
d_nlist_mask[:, :, :, None, None]
153+
& d_nlist_mask[:, :, None, :, None]
154+
& d_nlist_mask[:, :, None, None, :]
155+
)
156+
# d_nnei x d_nnei
157+
d_eye = torch.eye(d_nnei, dtype=d_nlist_mask.dtype, device=d_nlist_mask.device)
158+
d_eye = d_eye[:, :, None] | d_eye[:, None, :] | d_eye[None, :, :]
159+
d_nlist_mask_4d = d_nlist_mask_4d & ~d_eye[None, None, ...]
160+
161+
# angle(ijk) to dihedral(ijkl) index_select; dihedral(ijkl) to angle(ijk) aggregate
162+
angle_id = torch.arange(0, n_angle, dtype=nlist.dtype, device=nlist.device)
163+
# nf x nloc x a_nnei x a_nnei
164+
angle_index = torch.zeros(
165+
[nf, nloc, a_nnei, a_nnei], dtype=nlist.dtype, device=nlist.device
166+
)
167+
angle_index[a_nlist_mask_3d] = angle_id
168+
169+
# only cut d_nnei neighbors, to avoid a_nnei x a_nnei x a_nnei
170+
angle_index = angle_index[:, :, :d_nnei, :d_nnei]
171+
angle_index_ijk = angle_index.unsqueeze(-1).expand(-1, -1, -1, -1, d_nnei)
172+
# n_dihedral
173+
aijk2d_index = angle_index_ijk[d_nlist_mask_4d]
174+
175+
# angle(ijl) to dihedral(ijkl) index_select;
176+
angle_index_ijl = angle_index.unsqueeze(-2).expand(-1, -1, -1, d_nnei, -1)
177+
# n_dihedral
178+
aijl2d_index = angle_index_ijl[d_nlist_mask_4d]
179+
180+
dihedral_index = torch.cat(
181+
[aijk2d_index.unsqueeze(-1), aijl2d_index.unsqueeze(-1)], dim=-1
182+
)
183+
else:
184+
dihedral_index = None
185+
d_nlist_mask_4d = None
186+
187+
return (
188+
torch.cat([n2e_index.unsqueeze(-1), n_ext2e_index.unsqueeze(-1)], dim=-1),
189+
torch.cat(
190+
[
191+
n2a_index.unsqueeze(-1),
192+
eij2a_index.unsqueeze(-1),
193+
eik2a_index.unsqueeze(-1),
194+
],
195+
dim=-1,
196+
),
197+
dihedral_index,
198+
a_nlist_mask_3d,
199+
d_nlist_mask_4d,
200+
)
201+
202+
203+
class BesselBasis(torch.nn.Module):
204+
"""f : (*, 1) -> (*, bessel_basis_num)."""
205+
206+
def __init__(
207+
self,
208+
cutoff_length: float,
209+
bessel_basis_num: int = 8,
210+
trainable_coeff: bool = True,
211+
):
212+
super().__init__()
213+
self.num_basis = bessel_basis_num
214+
self.prefactor = 2.0 / cutoff_length
215+
self.coeffs = torch.FloatTensor(
216+
[n * math.pi / cutoff_length for n in range(1, bessel_basis_num + 1)]
217+
)
218+
if trainable_coeff:
219+
self.coeffs = torch.nn.Parameter(self.coeffs)
220+
221+
def forward(self, r: torch.Tensor) -> torch.Tensor:
222+
return self.prefactor * torch.sin(self.coeffs * r) / (r + 1e-8)
223+
224+
225+
class GaussianSmearing(torch.nn.Module):
226+
def __init__(
227+
self,
228+
start: float = -5.0,
229+
stop: float = 5.0,
230+
num_gaussians: int = 50,
231+
basis_width_scalar: float = 1.0,
232+
) -> None:
233+
super().__init__()
234+
self.num_output = num_gaussians
235+
offset = torch.linspace(
236+
start, stop, num_gaussians, device=env.DEVICE, dtype=torch.float32
237+
)
238+
self.coeff = -0.5 / (basis_width_scalar * (offset[1] - offset[0])).item() ** 2
239+
self.register_buffer("offset", offset)
240+
241+
def forward(self, dist) -> torch.Tensor:
242+
dist = dist - self.offset
243+
return torch.exp(self.coeff * torch.pow(dist, 2))
244+
245+
246+
class RadialMLP(torch.nn.Module):
247+
"""Contruct a radial function (linear layers + layer normalization + SiLU) given a list of channels."""
248+
249+
def __init__(self, channels_list) -> None:
250+
super().__init__()
251+
modules = []
252+
input_channels = channels_list[0]
253+
for i in range(len(channels_list)):
254+
if i == 0:
255+
continue
256+
257+
modules.append(torch.nn.Linear(input_channels, channels_list[i], bias=True))
258+
input_channels = channels_list[i]
259+
260+
if i == len(channels_list) - 1:
261+
break
262+
263+
modules.append(torch.nn.LayerNorm(channels_list[i]))
264+
modules.append(torch.nn.SiLU())
265+
266+
self.net = torch.nn.Sequential(*modules)
267+
268+
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
269+
return self.net(inputs)
270+
271+
272+
class PolynomialEnvelope(torch.nn.Module):
273+
"""Polynomial envelope function that ensures a smooth cutoff."""
274+
275+
def __init__(self, exponent: int = 5) -> None:
276+
super().__init__()
277+
assert exponent > 0
278+
self.p: float = float(exponent)
279+
self.a: float = -(self.p + 1) * (self.p + 2) / 2
280+
self.b: float = self.p * (self.p + 2)
281+
self.c: float = -self.p * (self.p + 1) / 2
282+
283+
def forward(self, d_scaled: torch.Tensor) -> torch.Tensor:
284+
env_val = (
285+
1
286+
+ self.a * d_scaled**self.p
287+
+ self.b * d_scaled ** (self.p + 1)
288+
+ self.c * d_scaled ** (self.p + 2)
289+
)
290+
return torch.where(d_scaled < 1, env_val, torch.zeros_like(d_scaled))

deepmd/utils/argcheck.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1421,7 +1421,18 @@ def descrpt_dpa3_args():
14211421
default=False,
14221422
doc=doc_use_tebd_bias,
14231423
),
1424-
Argument("use_ext_ebd", bool, optional=True, default=False),
1424+
Argument(
1425+
"use_torch_embed",
1426+
bool,
1427+
optional=True,
1428+
default=False,
1429+
),
1430+
Argument(
1431+
"use_loc_mapping",
1432+
bool,
1433+
optional=True,
1434+
default=True,
1435+
),
14251436
]
14261437

14271438

0 commit comments

Comments
 (0)