Skip to content

Commit 735c578

Browse files
author
liwentao
authored
fix dihedral selection (#46)
1 parent c8b28b0 commit 735c578

2 files changed

Lines changed: 124 additions & 6 deletions

File tree

deepmd/pt/model/descriptor/repflow_layer.py

Lines changed: 120 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,6 @@ def __init__(
152152
):
153153
assert not self.optim_update, "FFN does not support optim update!"
154154

155-
if self.update_dihedral:
156-
assert self.use_dynamic_sel, "Dihedral update requires dynamic selection!"
157155
self.edge_use_attn = edge_use_attn
158156
self.edge_attn_hidden = edge_attn_hidden
159157
self.edge_attn_head = edge_attn_head
@@ -725,7 +723,47 @@ def symmetrization_op_dynamic(
725723
# nb x nloc x (axis x e_dim)
726724
grrg = self._cal_grrg(h2g2, axis_neuron)
727725
return grrg
726+
727+
def optim_dihedral_update(
728+
self,
729+
dihedral_ebd: torch.Tensor,
730+
angle_ebd: torch.Tensor,
731+
feat: str = "angle",
732+
) -> torch.Tensor:
733+
angle_dim = angle_ebd.shape[-1]
734+
dihedral_dim = dihedral_ebd.shape[-1]
735+
sub_dihedral_idx = (0, angle_dim)
736+
sub_angle_idx_ijk = (angle_dim, angle_dim + angle_dim)
737+
sub_edge_idx_ijl = (angle_dim + angle_dim, angle_dim + angle_dim + angle_dim)
738+
739+
if feat == "angle":
740+
matrix, bias = self.angle_dihedral_linear.matrix, self.angle_dihedral_linear.bias
741+
elif feat == "dihedral":
742+
matrix, bias = self.dihedral_self_linear.matrix, self.dihedral_self_linear.bias
743+
else:
744+
raise NotImplementedError
745+
assert dihedral_dim + 2 * angle_dim == matrix.size()[0]
746+
747+
sub_dihedral_update = torch.matmul(
748+
dihedral_ebd, matrix[sub_dihedral_idx[0] : sub_dihedral_idx[1]]
749+
)
750+
751+
sub_angle_update_ijk = torch.matmul(
752+
angle_ebd, matrix[sub_angle_idx_ijk[0] : sub_angle_idx_ijk[1]]
753+
)
754+
755+
sub_angle_update_ijl = torch.matmul(
756+
angle_ebd, matrix[sub_edge_idx_ijl[0] : sub_edge_idx_ijl[1]]
757+
)
758+
result_update = (
759+
sub_dihedral_update
760+
+ sub_angle_update_ijk[:, :, :, :, None, :]
761+
+ sub_angle_update_ijl[:, :, :, None, :, :]
762+
) + bias
763+
return result_update
764+
728765

766+
729767
def optim_angle_update(
730768
self,
731769
angle_ebd: torch.Tensor,
@@ -945,6 +983,8 @@ def forward(
945983
a_sw: torch.Tensor, # switch func, nf x nloc x a_nnei
946984
edge_index: torch.Tensor, # n_edge x 2
947985
angle_index: torch.Tensor, # n_angle x 3
986+
d_nlist: Optional[torch.Tensor] = None, # nf x nloc x d_nnei
987+
d_nlist_mask: Optional[torch.Tensor] = None, # nf x nloc x d_nnei
948988
dihedral_index: Optional[torch.Tensor] = None, # n_dihedral x 2
949989
dihedral_ebd: Optional[torch.Tensor] = None, # n_dihedral x d_dim
950990
d_sw: Optional[torch.Tensor] = None, # n_dihedral
@@ -1239,7 +1279,7 @@ def forward(
12391279
if self.edge_attn_use_ln:
12401280
edge_attention_update = self.edge_lm(edge_attention_update)
12411281
e_update_list.append(edge_attention_update)
1242-
1282+
12431283
if self.update_angle:
12441284
assert self.angle_self_linear is not None
12451285
assert self.edge_angle_linear1 is not None
@@ -1329,6 +1369,7 @@ def forward(
13291369
self.edge_angle_linear1(angle_info_ffn)
13301370
)
13311371
else:
1372+
13321373
edge_angle_update = self.act(
13331374
self.optim_angle_update(
13341375
angle_ebd,
@@ -1445,9 +1486,83 @@ def forward(
14451486
)
14461487
)
14471488
a_update_list.append(angle_self_update)
1448-
if self.update_dihedral:
1489+
1490+
# dihedral update with fixed sel
1491+
if self.update_dihedral and not self.use_dynamic_sel:
1492+
assert d_nlist is not None
1493+
assert d_nlist_mask is not None
1494+
assert dihedral_ebd is not None
1495+
assert d_sw is not None
1496+
assert self.angle_dihedral_linear is not None
1497+
1498+
# nb x nloc x d_sel x d_sel x e_dim
1499+
angle_ebd_for_dihedral = angle_ebd[:, :, :self.d_sel, :self.d_sel, :]
1500+
# nb x nloc x d_sel x d_sel x e_dim
1501+
d_nlist_mask = d_nlist_mask[:,:,:,None] * d_nlist_mask[:,:,None,:]
1502+
angle_ebd_for_dihedral = torch.where(
1503+
d_nlist_mask.unsqueeze(-1), angle_ebd_for_dihedral, 0.0
1504+
)
1505+
1506+
# nb x nloc x d_sel x d_sel x d_sel x a_dim
1507+
angle_dihedral_update = self.act(
1508+
self.optim_dihedral_update(
1509+
dihedral_ebd,
1510+
angle_ebd_for_dihedral,
1511+
"angle",
1512+
)
1513+
)
1514+
# nb x nloc x d_sel x d_sel x d_sel x a_dim
1515+
weighted_angle_dihedral_update = (
1516+
angle_dihedral_update
1517+
* d_sw[:, :, :, None, None, None]
1518+
* d_sw[:, :, None, :, None, None]
1519+
* d_sw[:, :, None, None, :, None]
1520+
)
1521+
# nb x nloc x d_sel x d_sel x a_dim
1522+
reduced_angle_dihedral_update = torch.sum(
1523+
weighted_angle_dihedral_update, dim=-2
1524+
) / (self.d_sel**0.5)
1525+
1526+
# Need two dimensional padding
1527+
# nb x nloc x a_sel x a_sel x a_dim
1528+
padding_angle_dihedral_update = torch.concat(
1529+
[
1530+
reduced_angle_dihedral_update,
1531+
torch.zeros(
1532+
[nb, nloc, self.d_sel, self.a_sel - self.d_sel, self.a_dim],
1533+
dtype=edge_ebd.dtype,
1534+
device=edge_ebd.device,
1535+
),
1536+
],
1537+
dim=-2,
1538+
)
1539+
padding_angle_dihedral_update = torch.concat(
1540+
[
1541+
padding_angle_dihedral_update,
1542+
torch.zeros(
1543+
[nb, nloc, self.a_sel-self.d_sel, self.a_sel, self.a_dim],
1544+
dtype=edge_ebd.dtype,
1545+
device=edge_ebd.device,
1546+
),
1547+
],
1548+
dim=-3,
1549+
)
1550+
a_update_list.append(padding_angle_dihedral_update)
1551+
1552+
dihedral_self_update = self.act(
1553+
self.optim_dihedral_update(
1554+
dihedral_ebd,
1555+
angle_ebd_for_dihedral,
1556+
"dihedral",
1557+
)
1558+
)
1559+
1560+
d_update_list: list[torch.Tensor] = [dihedral_ebd, dihedral_self_update]
1561+
d_updated = self.list_update(d_update_list, "dihedral")
1562+
1563+
# dihedral update with dynamic sel
1564+
elif self.update_dihedral and self.use_dynamic_sel:
14491565
n_angle = int(a_nlist_mask.sum().item())
1450-
assert self.use_dynamic_sel, "dihedral update only support dynamic sel"
14511566
assert dihedral_ebd is not None
14521567
assert d_sw is not None
14531568
assert dihedral_index is not None

deepmd/pt/model/descriptor/repflows.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ def forward(
739739
else:
740740
d_sw = None
741741
dihedral_input = None
742-
742+
743743
if self.edge_use_esen_atom_ebd:
744744
# nf x (nl x nnei)
745745
nlist_index = nlist.reshape(nframes, nloc * nnei)
@@ -915,6 +915,7 @@ def forward(
915915
node_ebd_ext = concat_switch_virtual(
916916
node_ebd_real_ext, node_ebd_virtual_ext, real_nloc
917917
)
918+
918919
node_ebd, edge_ebd, angle_ebd, dihedral_ebd = ll.forward(
919920
node_ebd_ext,
920921
edge_ebd,
@@ -926,6 +927,8 @@ def forward(
926927
a_nlist,
927928
a_nlist_mask,
928929
a_sw,
930+
d_nlist=d_nlist,
931+
d_nlist_mask = d_nlist_mask,
929932
edge_index=edge_index,
930933
angle_index=angle_index,
931934
dihedral_index=dihedral_index,

0 commit comments

Comments
 (0)