@@ -161,6 +161,7 @@ def __init__(
161161 angle_self_attention : bool = False ,
162162 angle_self_attention_gate : str = "none" ,
163163 rmsnorm_mode : str = "none" ,
164+ edge_rbf_cat_message : bool = False ,
164165 seed : Optional [Union [int , list [int ]]] = None ,
165166 ) -> None :
166167 r"""
@@ -314,6 +315,7 @@ def __init__(
314315 self .edge_use_esen_rbf = edge_use_esen_rbf
315316 self .edge_use_esen_atom_ebd = edge_use_esen_atom_ebd
316317 self .edge_use_esen_env = edge_use_esen_env
318+ self .edge_rbf_cat_message = edge_rbf_cat_message
317319 if self .edge_rbf_dot_self or self .edge_rbf_dot_message :
318320 assert self .edge_use_rbf or self .edge_use_concat_rbf , "rbf is not used"
319321 self .edge_embed_input_dim = 1
@@ -333,6 +335,9 @@ def __init__(
333335 elif self .edge_use_rbf :
334336 self .rbf = BesselBasis (self .e_rcut )
335337 self .edge_embed_input_dim = self .rbf .num_basis
338+ elif self .edge_rbf_cat_message :
339+ # edge can use dist itself
340+ self .rbf = BesselBasis (self .e_rcut )
336341 else :
337342 self .rbf = None
338343
@@ -379,6 +384,11 @@ def __init__(
379384 not self .optim_update
380385 ), "optim_update must be False when angle_use_node is False"
381386
387+ if self .edge_rbf_cat_message :
388+ assert (
389+ not self .optim_update
390+ ), "optim_update must be False when edge_rbf_cat_message is True"
391+
382392 if self .edge_use_esen_atom_ebd :
383393 self .source_embedding = torch .nn .Embedding (self .ntypes , self .e_dim )
384394 self .target_embedding = torch .nn .Embedding (self .ntypes , self .e_dim )
@@ -508,7 +518,9 @@ def __init__(
508518 edge_attn_use_ln = self .edge_attn_use_ln ,
509519 edge_rbf_dot_self = self .edge_rbf_dot_self ,
510520 edge_rbf_dot_message = self .edge_rbf_dot_message ,
511- rbf_dim = self .edge_embed_input_dim ,
521+ rbf_dim = self .edge_embed_input_dim
522+ if not self .edge_rbf_cat_message
523+ else self .rbf .num_basis ,
512524 residual_pref = self .residual_pref ,
513525 message_use_self_concat = self .message_use_self_concat ,
514526 use_slim_message = self .use_slim_message ,
@@ -520,6 +532,7 @@ def __init__(
520532 angle_self_attention = self .angle_self_attention ,
521533 angle_self_attention_gate = self .angle_self_attention_gate ,
522534 rmsnorm_mode = self .rmsnorm_mode ,
535+ edge_rbf_cat_message = self .edge_rbf_cat_message ,
523536 seed = child_seed (child_seed (seed , 1 ), ii ),
524537 )
525538 )
@@ -934,7 +947,11 @@ def forward(
934947 edge_ebd = self .edge_embd (rbf_input )
935948 elif self .edge_use_dist :
936949 edge_ebd = self .edge_embd (edge_input )
937- rbf_ebd = None
950+ if not self .edge_rbf_cat_message :
951+ rbf_ebd = None
952+ else :
953+ assert self .rbf is not None
954+ rbf_ebd = self .rbf (edge_input )
938955 elif self .edge_use_concat_rbf :
939956 assert self .rbf is not None
940957 rbf_ebd = torch .cat ([dmatrix [..., :1 ], self .rbf (edge_input )], dim = - 1 )
0 commit comments