@@ -152,8 +152,6 @@ def __init__(
152152 ):
153153 assert not self .optim_update , "FFN does not support optim update!"
154154
155- if self .update_dihedral :
156- assert self .use_dynamic_sel , "Dihedral update requires dynamic selection!"
157155 self .edge_use_attn = edge_use_attn
158156 self .edge_attn_hidden = edge_attn_hidden
159157 self .edge_attn_head = edge_attn_head
@@ -725,7 +723,47 @@ def symmetrization_op_dynamic(
725723 # nb x nloc x (axis x e_dim)
726724 grrg = self ._cal_grrg (h2g2 , axis_neuron )
727725 return grrg
726+
727+ def optim_dihedral_update (
728+ self ,
729+ dihedral_ebd : torch .Tensor ,
730+ angle_ebd : torch .Tensor ,
731+ feat : str = "angle" ,
732+ ) -> torch .Tensor :
733+ angle_dim = angle_ebd .shape [- 1 ]
734+ dihedral_dim = dihedral_ebd .shape [- 1 ]
735+ sub_dihedral_idx = (0 , angle_dim )
736+ sub_angle_idx_ijk = (angle_dim , angle_dim + angle_dim )
737+ sub_edge_idx_ijl = (angle_dim + angle_dim , angle_dim + angle_dim + angle_dim )
738+
739+ if feat == "angle" :
740+ matrix , bias = self .angle_dihedral_linear .matrix , self .angle_dihedral_linear .bias
741+ elif feat == "dihedral" :
742+ matrix , bias = self .dihedral_self_linear .matrix , self .dihedral_self_linear .bias
743+ else :
744+ raise NotImplementedError
745+ assert dihedral_dim + 2 * angle_dim == matrix .size ()[0 ]
746+
747+ sub_dihedral_update = torch .matmul (
748+ dihedral_ebd , matrix [sub_dihedral_idx [0 ] : sub_dihedral_idx [1 ]]
749+ )
750+
751+ sub_angle_update_ijk = torch .matmul (
752+ angle_ebd , matrix [sub_angle_idx_ijk [0 ] : sub_angle_idx_ijk [1 ]]
753+ )
754+
755+ sub_angle_update_ijl = torch .matmul (
756+ angle_ebd , matrix [sub_edge_idx_ijl [0 ] : sub_edge_idx_ijl [1 ]]
757+ )
758+ result_update = (
759+ sub_dihedral_update
760+ + sub_angle_update_ijk [:, :, :, :, None , :]
761+ + sub_angle_update_ijl [:, :, :, None , :, :]
762+ ) + bias
763+ return result_update
764+
728765
766+
729767 def optim_angle_update (
730768 self ,
731769 angle_ebd : torch .Tensor ,
@@ -945,6 +983,8 @@ def forward(
945983 a_sw : torch .Tensor , # switch func, nf x nloc x a_nnei
946984 edge_index : torch .Tensor , # n_edge x 2
947985 angle_index : torch .Tensor , # n_angle x 3
986+ d_nlist : Optional [torch .Tensor ] = None , # nf x nloc x d_nnei
987+ d_nlist_mask : Optional [torch .Tensor ] = None , # nf x nloc x d_nnei
948988 dihedral_index : Optional [torch .Tensor ] = None , # n_dihedral x 2
949989 dihedral_ebd : Optional [torch .Tensor ] = None , # n_dihedral x d_dim
950990 d_sw : Optional [torch .Tensor ] = None , # n_dihedral
@@ -1239,7 +1279,7 @@ def forward(
12391279 if self .edge_attn_use_ln :
12401280 edge_attention_update = self .edge_lm (edge_attention_update )
12411281 e_update_list .append (edge_attention_update )
1242-
1282+
12431283 if self .update_angle :
12441284 assert self .angle_self_linear is not None
12451285 assert self .edge_angle_linear1 is not None
@@ -1329,6 +1369,7 @@ def forward(
13291369 self .edge_angle_linear1 (angle_info_ffn )
13301370 )
13311371 else :
1372+
13321373 edge_angle_update = self .act (
13331374 self .optim_angle_update (
13341375 angle_ebd ,
@@ -1445,9 +1486,83 @@ def forward(
14451486 )
14461487 )
14471488 a_update_list .append (angle_self_update )
1448- if self .update_dihedral :
1489+
1490+ # dihedral update with fixed sel
1491+ if self .update_dihedral and not self .use_dynamic_sel :
1492+ assert d_nlist is not None
1493+ assert d_nlist_mask is not None
1494+ assert dihedral_ebd is not None
1495+ assert d_sw is not None
1496+ assert self .angle_dihedral_linear is not None
1497+
1498+ # nb x nloc x d_sel x d_sel x e_dim
1499+ angle_ebd_for_dihedral = angle_ebd [:, :, :self .d_sel , :self .d_sel , :]
1500+ # nb x nloc x d_sel x d_sel x e_dim
1501+ d_nlist_mask = d_nlist_mask [:,:,:,None ] * d_nlist_mask [:,:,None ,:]
1502+ angle_ebd_for_dihedral = torch .where (
1503+ d_nlist_mask .unsqueeze (- 1 ), angle_ebd_for_dihedral , 0.0
1504+ )
1505+
1506+ # nb x nloc x d_sel x d_sel x d_sel x a_dim
1507+ angle_dihedral_update = self .act (
1508+ self .optim_dihedral_update (
1509+ dihedral_ebd ,
1510+ angle_ebd_for_dihedral ,
1511+ "angle" ,
1512+ )
1513+ )
1514+ # nb x nloc x d_sel x d_sel x d_sel x a_dim
1515+ weighted_angle_dihedral_update = (
1516+ angle_dihedral_update
1517+ * d_sw [:, :, :, None , None , None ]
1518+ * d_sw [:, :, None , :, None , None ]
1519+ * d_sw [:, :, None , None , :, None ]
1520+ )
1521+ # nb x nloc x d_sel x d_sel x a_dim
1522+ reduced_angle_dihedral_update = torch .sum (
1523+ weighted_angle_dihedral_update , dim = - 2
1524+ ) / (self .d_sel ** 0.5 )
1525+
1526+ # Need two dimensional padding
1527+ # nb x nloc x a_sel x a_sel x a_dim
1528+ padding_angle_dihedral_update = torch .concat (
1529+ [
1530+ reduced_angle_dihedral_update ,
1531+ torch .zeros (
1532+ [nb , nloc , self .d_sel , self .a_sel - self .d_sel , self .a_dim ],
1533+ dtype = edge_ebd .dtype ,
1534+ device = edge_ebd .device ,
1535+ ),
1536+ ],
1537+ dim = - 2 ,
1538+ )
1539+ padding_angle_dihedral_update = torch .concat (
1540+ [
1541+ padding_angle_dihedral_update ,
1542+ torch .zeros (
1543+ [nb , nloc , self .a_sel - self .d_sel , self .a_sel , self .a_dim ],
1544+ dtype = edge_ebd .dtype ,
1545+ device = edge_ebd .device ,
1546+ ),
1547+ ],
1548+ dim = - 3 ,
1549+ )
1550+ a_update_list .append (padding_angle_dihedral_update )
1551+
1552+ dihedral_self_update = self .act (
1553+ self .optim_dihedral_update (
1554+ dihedral_ebd ,
1555+ angle_ebd_for_dihedral ,
1556+ "dihedral" ,
1557+ )
1558+ )
1559+
1560+ d_update_list : list [torch .Tensor ] = [dihedral_ebd , dihedral_self_update ]
1561+ d_updated = self .list_update (d_update_list , "dihedral" )
1562+
1563+ # dihedral update with dynamic sel
1564+ elif self .update_dihedral and self .use_dynamic_sel :
14491565 n_angle = int (a_nlist_mask .sum ().item ())
1450- assert self .use_dynamic_sel , "dihedral update only support dynamic sel"
14511566 assert dihedral_ebd is not None
14521567 assert d_sw is not None
14531568 assert dihedral_index is not None
0 commit comments