@@ -187,6 +187,7 @@ def __init__(
187187 e3nn_conv_use_edge_sh_feat : bool = False ,
188188 edge_sh_feat_use_rbf_weights : bool = False ,
189189 e3nn_conv_use_vi : bool = False ,
190+ e3nn_conv_weights_use_tebd : bool = False ,
190191 e3nn_conv_l_max : int = 3 ,
191192 e3nn_use_edge_feat_weights : bool = False ,
192193 use_e3nn_angle_conv : bool = False ,
@@ -493,6 +494,7 @@ def __init__(
493494 self .use_e3nn_denominator = use_e3nn_denominator
494495 self .e3nn_conv_use_edge_sh_feat = e3nn_conv_use_edge_sh_feat
495496 self .e3nn_conv_use_vi = e3nn_conv_use_vi
497+ self .e3nn_conv_weights_use_tebd = e3nn_conv_weights_use_tebd
496498 if self .e3nn_conv_use_vi :
497499 assert e3nn_conv_use_edge_sh_feat , "e3nn_conv_use_edge_sh_feat must be True when e3nn_conv_use_vi is True"
498500 self .e3nn_conv_l_max = e3nn_conv_l_max
@@ -651,7 +653,7 @@ def __init__(
651653 "irreps_out" : irreps_out ,
652654 "denominator" : 1.0 if not self .use_e3nn_denominator else self .dynamic_e_sel / 4 ,
653655 "train_denominator" : True ,
654- "weight_layer_input_to_hidden" : [8 , 64 , 64 ] if not self .e3nn_use_edge_feat_weights else [self .e_dim , 64 , 64 ],
656+ "weight_layer_input_to_hidden" : [8 if not self . e3nn_conv_weights_use_tebd else 8 + 2 * self . n_dim , 64 , 64 ] if not self .e3nn_use_edge_feat_weights else [self .e_dim , 64 , 64 ],
655657 }
656658 irreps_x = irreps_out
657659
@@ -1284,7 +1286,13 @@ def forward(
12841286 assert edge_dist is not None
12851287 # n_edge x rbf
12861288 edge_env = self .edge_env (edge_dist / self .e_rcut )
1287- edge_rbf_ebd = self .edge_rbf_embed (edge_dist ) * edge_env
1289+ if not self .e3nn_conv_weights_use_tebd :
1290+ edge_rbf_ebd = self .edge_rbf_embed (edge_dist ) * edge_env
1291+ else :
1292+ edge_src = edge_index [:, 1 ]
1293+ edge_dst = edge_index [:, 0 ]
1294+ atype_embd_reshape = atype_embd .reshape (nframes * nloc , - 1 )
1295+ edge_rbf_ebd = torch .cat ([self .edge_rbf_embed (edge_dist ), atype_embd_reshape [edge_dst ], atype_embd_reshape [edge_src ]], dim = - 1 ) * edge_env
12881296 # n_edge x num_sph(16)
12891297 edge_sph = self .edge_spherical_embd (diff )
12901298 node_sph_embed = node_ebd
0 commit comments