@@ -90,6 +90,8 @@ def __init__(
         only_angle_gated_mlp: bool = False,
         node_use_rmsnorm: bool = False,
         angle_use_node: bool = True,
+        angle_self_attention: bool = False,
+        angle_self_attention_gate: str = "none",
         activation_function: str = "silu",
         update_style: str = "res_residual",
         update_residual: float = 0.1,
@@ -188,6 +190,8 @@ def __init__(
         self.node_rmsnorm = None

         self.angle_use_node = angle_use_node
+        self.angle_self_attention = angle_self_attention
+        self.angle_self_attention_gate = angle_self_attention_gate

         if self.edge_rbf_dot_self or self.edge_rbf_dot_message:
             self.rbf_mlp = MLPLayer(
@@ -501,6 +505,23 @@ def __init__(
                 )
             )

+        if self.angle_self_attention:
+            self.angle_attention_mlp_in = MLPLayer(
+                self.a_dim,
+                self.a_dim * 3,  # query, key, value
+                precision=precision,
+                seed=child_seed(seed, 21),
+            )
+            self.angle_attention_mlp_out = MLPLayer(
+                self.a_dim,
+                self.a_dim,
+                precision=precision,
+                seed=child_seed(seed, 22),
+            )
+        else:
+            self.angle_attention_mlp_in = None
+            self.angle_attention_mlp_out = None
+
         if self.update_dihedral:
             self.dihedral_dim = self.d_dim + 2 * self.a_dim
             # angle dihedral message
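
A minimal sketch of the shape bookkeeping behind `angle_attention_mlp_in`: it widens the angle embedding to three times `a_dim` so a single projection yields query, key, and value, which `torch.chunk` later splits. `torch.nn.Linear` stands in for `MLPLayer`, and all sizes below are made up:

```python
import torch

a_dim = 32  # assumed angle-embedding width
proj_in = torch.nn.Linear(a_dim, a_dim * 3)   # plays the role of angle_attention_mlp_in
proj_out = torch.nn.Linear(a_dim, a_dim)      # plays the role of angle_attention_mlp_out

# dummy angle embedding: nb x nloc x a_nnei x a_nnei x a_dim
angle_ebd = torch.randn(2, 5, 4, 4, a_dim)
query, key, value = torch.chunk(proj_in(angle_ebd), 3, dim=-1)
print(query.shape)  # torch.Size([2, 5, 4, 4, 32]) -- same width as the input embedding
```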
@@ -1581,6 +1602,63 @@ def forward(
             )
             a_update_list.append(angle_self_update)

+            if self.angle_self_attention:
+                # self-attention on angle_ebd (nb x nloc x a_nnei x a_nnei x a_dim) over the last two neighbor dimensions
+                assert self.angle_attention_mlp_in is not None
+                assert self.angle_attention_mlp_out is not None
+                # nb x nloc x a_nnei x a_nnei x (3 * a_dim)
+                attention_output = self.angle_attention_mlp_in(angle_ebd)
+                # each nb x nloc x a_nnei x a_nnei x a_dim
+                query, key, value = torch.chunk(
+                    attention_output, 3, dim=-1
+                )  # split into query, key, value
+                # nb x nloc x a_nnei x a_nnei x a_nnei
+                attention_scores = torch.matmul(query, key.transpose(-2, -1)) / (
+                    query.size(-1) ** 0.5
+                )  # scaled dot-product attention
+                # smooth: gate scores by the angular switch; out-of-cutoff entries decay to -20
+                attention_scores = (attention_scores + 20.0) * a_sw[
+                    :, :, None, :, None
+                ] * a_sw[:, :, None, None, :] - 20.0
+                # nb x nloc x a_nnei x a_nnei x a_nnei
+                attention_weights = torch.softmax(
+                    attention_scores, dim=-1
+                )  # normalize scores
+                # smooth: gate the normalized weights by the switch as well
+                attention_weights = (
+                    attention_weights
+                    * a_sw[:, :, None, :, None]
+                    * a_sw[:, :, None, None, :]
+                )
+                # optional gates
+                if self.angle_self_attention_gate == "edge":
+                    # nb x nloc x a_nnei x 3
+                    h2_angle = h2[..., : self.a_sel, :]
+                    # normalize
+                    h2_angle = h2_angle / torch.linalg.norm(
+                        h2_angle, dim=-1, keepdim=True
+                    )
+                    # nb x nloc x a_nnei x 3
+                    h2_angle = torch.where(
+                        a_nlist_mask.unsqueeze(-1).expand([-1, -1, -1, 3]),
+                        h2_angle,
+                        0.0,
+                    )
+                    # nb x nloc x a_nnei x a_nnei
+                    h2h2t = torch.matmul(h2_angle, torch.transpose(h2_angle, -1, -2))
+                    # nb x nloc x a_nnei x a_nnei x a_nnei
+                    attention_weights = attention_weights * h2h2t[:, :, None, :, :]
+
+                # nb x nloc x a_nnei x a_nnei x a_dim
+                angle_ebd_attended = torch.matmul(
+                    attention_weights, value
+                )  # apply attention weights to value
+                # nb x nloc x a_nnei x a_nnei x a_dim
+                angle_attention_updated = self.act(
+                    self.angle_attention_mlp_out(angle_ebd_attended)
+                )  # apply attention output layer
+                a_update_list.append(angle_attention_updated)
+
             # dihedral update with fixed sel
             if self.update_dihedral and not self.use_dynamic_sel:
                 assert d_nlist is not None
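
A self-contained sketch of the smoothed angle self-attention in the forward pass, assuming `a_sw` is the per-neighbor smooth switch in [0, 1] and using `torch.nn.Linear` and `silu` as stand-ins for `MLPLayer` and `self.act`; the ±20 shift pushes out-of-cutoff scores to a large negative value so their softmax weights vanish smoothly:

```python
import torch


def angle_self_attention_sketch(
    angle_ebd: torch.Tensor,    # nb x nloc x a_nnei x a_nnei x a_dim
    a_sw: torch.Tensor,         # nb x nloc x a_nnei, smooth switch in [0, 1]
    proj_in: torch.nn.Linear,   # a_dim -> 3 * a_dim, stand-in for angle_attention_mlp_in
    proj_out: torch.nn.Linear,  # a_dim -> a_dim, stand-in for angle_attention_mlp_out
) -> torch.Tensor:
    # one projection yields query, key, value; each nb x nloc x a_nnei x a_nnei x a_dim
    query, key, value = torch.chunk(proj_in(angle_ebd), 3, dim=-1)
    # scaled dot-product scores: nb x nloc x a_nnei x a_nnei x a_nnei
    scores = torch.matmul(query, key.transpose(-2, -1)) / query.size(-1) ** 0.5
    # gate on both neighbor axes; out-of-cutoff entries end up at -20
    gate = a_sw[:, :, None, :, None] * a_sw[:, :, None, None, :]
    scores = (scores + 20.0) * gate - 20.0
    # normalize, then gate the weights again so the update vanishes at the cutoff
    weights = torch.softmax(scores, dim=-1) * gate
    attended = torch.matmul(weights, value)  # nb x nloc x a_nnei x a_nnei x a_dim
    return torch.nn.functional.silu(proj_out(attended))


# toy sizes: nb=2, nloc=3, a_nnei=4, a_dim=8
a_dim = 8
out = angle_self_attention_sketch(
    torch.randn(2, 3, 4, 4, a_dim),
    torch.rand(2, 3, 4),
    torch.nn.Linear(a_dim, 3 * a_dim),
    torch.nn.Linear(a_dim, a_dim),
)
print(out.shape)  # torch.Size([2, 3, 4, 4, 8])
```

The output keeps the nb x nloc x a_nnei x a_nnei x a_dim shape of the input embedding, which is why it can be appended to `a_update_list` alongside the other angle updates.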