@@ -122,6 +122,7 @@ def __init__(
122122 use_ffn_edge_angle_message : bool = False ,
123123 use_ffn_angle_angle_message : bool = False ,
124124 ffn_hidden_dim : int = 1024 ,
125+ edge_use_concat_rbf : bool = False ,
125126 edge_use_rbf : bool = False ,
126127 edge_use_dist : bool = False ,
127128 embed_use_bias : bool = True ,
@@ -262,11 +263,15 @@ def __init__(
262263 self .use_ffn_edge_angle_message = use_ffn_edge_angle_message
263264 self .use_ffn_angle_angle_message = use_ffn_angle_angle_message
264265 self .ffn_hidden_dim = ffn_hidden_dim
266+ self .edge_use_concat_rbf = edge_use_concat_rbf
265267 self .edge_use_rbf = edge_use_rbf
266268 self .edge_use_dist = edge_use_dist
267269 self .embed_use_bias = embed_use_bias
268270 self .edge_embed_input_dim = 1
269- if self .edge_use_rbf :
271+ if self .edge_use_concat_rbf :
272+ self .rbf = BesselBasis (self .e_rcut )
273+ self .edge_embed_input_dim = 1 + self .rbf .num_basis
274+ elif self .edge_use_rbf :
270275 self .rbf = BesselBasis (self .e_rcut )
271276 self .edge_embed_input_dim = self .rbf .num_basis
272277 else :
@@ -536,7 +541,7 @@ def forward(
536541 # get edge and angle embedding input
537542 # nb x nloc x nnei x 1, nb x nloc x nnei x 3
538543 edge_input , h2 = torch .split (dmatrix , [1 , 3 ], dim = - 1 )
539- if self .edge_use_rbf or self .edge_use_dist :
544+ if self .edge_use_concat_rbf or self . edge_use_rbf or self .edge_use_dist :
540545 # nb x nloc x nnei x 1
541546 edge_input = torch .linalg .norm (diff , dim = - 1 , keepdim = True )
542547 # nf x nloc x a_nnei x 3
@@ -668,6 +673,11 @@ def forward(
668673 # nb x nloc x nnei x e_dim [OR] n_edge x e_dim
669674 if self .edge_use_dist :
670675 edge_ebd = self .edge_embd (edge_input )
676+ elif self .edge_use_concat_rbf :
677+ assert self .rbf is not None
678+ edge_ebd = self .edge_embd (
679+ torch .cat ([dmatrix [..., :1 ], self .rbf (edge_input )], dim = - 1 )
680+ )
671681 elif self .edge_use_rbf :
672682 assert self .rbf is not None
673683 edge_ebd = self .edge_embd (self .rbf (edge_input ))
0 commit comments