@@ -189,6 +189,7 @@ def __init__(
189189 use_e3nn_angle_conv : bool = False ,
190190 e3nn_angle_conv_l_max : int = 2 ,
191191 e3nn_angle_conv_pattern : str = "64x0e+32x1e+32x2e" ,
192+ e3nn_angle_use_cross : bool = False ,
192193 seed : Optional [Union [int , list [int ]]] = None ,
193194 ) -> None :
194195 r"""
@@ -491,6 +492,7 @@ def __init__(
491492 self .use_e3nn_angle_conv = use_e3nn_angle_conv
492493 self .e3nn_angle_conv_l_max = e3nn_angle_conv_l_max
493494 self .e3nn_angle_conv_pattern = e3nn_angle_conv_pattern
495+ self .e3nn_angle_use_cross = e3nn_angle_use_cross
494496
495497 if not self .edge_use_esen_rbf :
496498 self .edge_embd = MLPLayer (
@@ -607,7 +609,7 @@ def __init__(
607609 "irreps_out" : irreps_out ,
608610 "denominator" : 1.0 if not self .use_e3nn_denominator else self .dynamic_e_sel / 4 ,
609611 "train_denominator" : True ,
610- "weight_layer_input_to_hidden" : [8 , 64 , 64 ] if not self .e3nn_use_edge_feat_weights else [self .e_dim ],
612+ "weight_layer_input_to_hidden" : [8 , 64 , 64 ] if not self .e3nn_use_edge_feat_weights else [self .e_dim , 64 , 64 ],
611613 }
612614 irreps_x = irreps_out
613615
@@ -627,7 +629,7 @@ def __init__(
627629 "irreps_out" : irreps_edge_out ,
628630 "denominator" : 1.0 if not self .use_e3nn_denominator else self .dynamic_a_sel / 4 ,
629631 "train_denominator" : True ,
630- "weight_layer_input_to_hidden" : [self .a_dim ],
632+ "weight_layer_input_to_hidden" : [8 , 64 , 64 ] if self . e3nn_angle_use_cross else [ self .a_dim ],
631633 }
632634 irreps_edge = irreps_edge_out
633635
@@ -699,6 +701,7 @@ def __init__(
699701 use_e3nn_angle_conv = self .use_e3nn_angle_conv ,
700702 e3nn_angle_conv_args = e3nn_angle_conv_args ,
701703 e3nn_angle_conv_pattern = self .e3nn_angle_conv_pattern ,
704+ e3nn_angle_use_cross = self .e3nn_angle_use_cross ,
702705 seed = child_seed (child_seed (seed , 1 ), ii ),
703706 )
704707 )
@@ -1249,15 +1252,28 @@ def forward(
12491252 assert self .edge_spherical_embd_for_angle is not None
12501253 assert self .edge_to_angle_filter_prod is not None
12511254 assert self .edge_to_angle_filter_linear is not None
1252- edge_sph_for_angle = self . edge_spherical_embd_for_angle ( diff )
1255+ edge_sph_embed = edge_ebd
12531256 edge_i_index = angle_index [:, 1 ]
12541257 edge_j_index = angle_index [:, 2 ]
1255- edge_angle_filter_tp = self .edge_to_angle_filter_prod (edge_sph_for_angle [edge_i_index ], edge_sph_for_angle [edge_j_index ])
1256- edge_angle_filter = self .edge_to_angle_filter_linear (edge_angle_filter_tp )
1257- edge_sph_embed = edge_ebd
1258+ if not self .e3nn_angle_use_cross :
1259+ edge_sph_for_angle = self .edge_spherical_embd_for_angle (diff )
1260+ edge_angle_filter_tp = self .edge_to_angle_filter_prod (edge_sph_for_angle [edge_i_index ], edge_sph_for_angle [edge_j_index ])
1261+ edge_angle_filter = self .edge_to_angle_filter_linear (edge_angle_filter_tp )
1262+ angle_weights = None
1263+ else :
1264+ angle_cross_vec = torch .cross (diff [edge_i_index ], diff [edge_j_index ], dim = - 1 )
1265+ edge_angle_filter = self .edge_spherical_embd_for_angle (angle_cross_vec )
1266+ # angle basis as weights
1267+ # 1 - 1e-6 for torch.acos stability
1268+ cosine_ij = cosine_ij [a_nlist_mask ]
1269+ sine_ij = torch .sqrt (1 - cosine_ij ** 2 )
1270+ theta = torch .acos (cosine_ij ).unsqueeze (- 1 )
1271+ theta_list = torch .cat ([theta * 2 , theta * 4 , theta * 8 ], dim = - 1 )
1272+ angle_weights = torch .cat ([cosine_ij .unsqueeze (- 1 ), torch .cos (theta_list ), sine_ij .unsqueeze (- 1 ), torch .sin (theta_list )], dim = - 1 )
12581273 else :
12591274 edge_angle_filter = None
12601275 edge_sph_embed = None
1276+ angle_weights = None
12611277
12621278 for idx , ll in enumerate (self .layers ):
12631279 # node_ebd: nb x nloc x n_dim
@@ -1379,6 +1395,7 @@ def forward(
13791395 node_sph_embed = node_sph_embed ,
13801396 edge_angle_filter = edge_angle_filter ,
13811397 edge_sph_embed = edge_sph_embed ,
1398+ angle_weights = angle_weights ,
13821399 )
13831400 # may cause jit slow, todo fix
13841401 elif not self .rk_update_diff_layer :
0 commit comments