77
88import torch
99
10- from e3nn .o3 import Irreps
10+ from e3nn .o3 import Irreps , Linear
11+ from e3nn .o3 import FullyConnectedTensorProduct as FCTP
1112import sevenn .util as util
1213from sevenn .nn .edge_embedding import (
1314 SphericalEncoding ,
@@ -185,6 +186,9 @@ def __init__(
185186 use_e3nn_denominator : bool = False ,
186187 e3nn_conv_l_max : int = 3 ,
187188 e3nn_use_edge_feat_weights : bool = False ,
189+ use_e3nn_angle_conv : bool = False ,
190+ e3nn_angle_conv_l_max : int = 2 ,
191+ e3nn_angle_conv_pattern : str = "64x0e+32x1e+32x2e" ,
188192 seed : Optional [Union [int , list [int ]]] = None ,
189193 ) -> None :
190194 r"""
@@ -484,6 +488,9 @@ def __init__(
484488 self .use_e3nn_denominator = use_e3nn_denominator
485489 self .e3nn_conv_l_max = e3nn_conv_l_max
486490 self .e3nn_use_edge_feat_weights = e3nn_use_edge_feat_weights
491+ self .use_e3nn_angle_conv = use_e3nn_angle_conv
492+ self .e3nn_angle_conv_l_max = e3nn_angle_conv_l_max
493+ self .e3nn_angle_conv_pattern = e3nn_angle_conv_pattern
487494
488495 if not self .edge_use_esen_rbf :
489496 self .edge_embd = MLPLayer (
@@ -540,6 +547,7 @@ def __init__(
540547 self .force_embedding_linear = None
541548
542549 layers = []
550+ # for node edge e3nn conv
543551 irreps_x = Irreps (f'{ self .n_dim } x0e' )
544552 self .lmax = e3nn_conv_l_max
545553 self .e3nn_conv_pattern = Irreps ("+" .join (e3nn_conv_pattern .split ("+" )[:self .lmax + 1 ]))
@@ -554,7 +562,36 @@ def __init__(
554562 self .edge_spherical_embd = None
555563 self .irreps_filter = "0e"
556564
565+ # for edge angle e3nn conv
566+ irreps_edge = Irreps (f'{ self .e_dim } x0e' )
567+ self .angle_lmax = e3nn_angle_conv_l_max
568+ self .e3nn_angle_conv_pattern = Irreps ("+" .join (e3nn_angle_conv_pattern .split ("+" )[:self .angle_lmax + 1 ]))
569+ if self .use_e3nn_angle_conv :
570+ self .edge_spherical_embd_for_angle = SphericalEncoding (self .angle_lmax , parity = 1 , normalize = True )
571+ self .irreps_angle_filter = self .edge_spherical_embd_for_angle .irreps_out
572+ self .angle_edge_filter_out = util .infer_irreps_out (
573+ self .irreps_angle_filter , # type: ignore
574+ self .irreps_angle_filter ,
575+ self .angle_lmax , # type: ignore
576+ 'full' ,
577+ False ,
578+ )
579+ # edge_i x edge_j --> linear --> angle_filter
580+ self .edge_to_angle_filter_prod = FCTP (self .irreps_angle_filter , self .irreps_angle_filter , self .angle_edge_filter_out )
581+ self .edge_to_angle_filter_linear = Linear (
582+ irreps_in = self .angle_edge_filter_out ,
583+ irreps_out = self .irreps_angle_filter ,
584+ biases = False ,
585+ )
586+ else :
587+ self .edge_spherical_embd_for_angle = None
588+ self .irreps_angle_filter = "0e"
589+ self .edge_to_angle_filter_prod = None
590+ self .edge_to_angle_filter_linear = None
591+
592+
557593 for ii in range (nlayers ):
594+ # for node edge e3nn conv
558595 irreps_out = Irreps (self .e3nn_conv_pattern )
559596 irreps_out_tp = util .infer_irreps_out (
560597 irreps_x , # type: ignore
@@ -573,6 +610,27 @@ def __init__(
573610 "weight_layer_input_to_hidden" : [8 , 64 , 64 ] if not self .e3nn_use_edge_feat_weights else [self .e_dim ],
574611 }
575612 irreps_x = irreps_out
613+
614+ # for edge angle e3nn conv
615+ irreps_edge_out = Irreps (self .e3nn_angle_conv_pattern )
616+ irreps_out_tp = util .infer_irreps_out (
617+ irreps_edge , # type: ignore
618+ self .irreps_angle_filter ,
619+ irreps_edge_out .lmax , # type: ignore
620+ 'full' ,
621+ False ,
622+ )
623+ e3nn_angle_conv_args = {
624+ "irreps_x" : irreps_edge ,
625+ "irreps_filter" : self .irreps_angle_filter ,
626+ "irreps_out_tp" : irreps_out_tp ,
627+ "irreps_out" : irreps_edge_out ,
628+ "denominator" : 1.0 if not self .use_e3nn_denominator else self .dynamic_a_sel / 4 ,
629+ "train_denominator" : True ,
630+ "weight_layer_input_to_hidden" : [self .a_dim ],
631+ }
632+ irreps_edge = irreps_edge_out
633+
576634 layers .append (
577635 RepFlowLayer (
578636 e_rcut = self .e_rcut ,
@@ -638,6 +696,9 @@ def __init__(
638696 e3nn_conv_pattern = self .e3nn_conv_pattern ,
639697 e3nn_use_edge_feat_weights = self .e3nn_use_edge_feat_weights ,
640698 e3nn_conv_args = e3nn_conv_args ,
699+ use_e3nn_angle_conv = self .use_e3nn_angle_conv ,
700+ e3nn_angle_conv_args = e3nn_angle_conv_args ,
701+ e3nn_angle_conv_pattern = self .e3nn_angle_conv_pattern ,
641702 seed = child_seed (child_seed (seed , 1 ), ii ),
642703 )
643704 )
@@ -1183,6 +1244,21 @@ def forward(
11831244 edge_sph = None
11841245 node_sph_embed = None
11851246
1247+ if self .use_e3nn_angle_conv :
1248+ assert self .use_dynamic_sel , "e3nn conv must use dynamic sel"
1249+ assert self .edge_spherical_embd_for_angle is not None
1250+ assert self .edge_to_angle_filter_prod is not None
1251+ assert self .edge_to_angle_filter_linear is not None
1252+ edge_sph_for_angle = self .edge_spherical_embd_for_angle (diff )
1253+ edge_i_index = angle_index [:, 1 ]
1254+ edge_j_index = angle_index [:, 2 ]
1255+ edge_angle_filter_tp = self .edge_to_angle_filter_prod (edge_sph_for_angle [edge_i_index ], edge_sph_for_angle [edge_j_index ])
1256+ edge_angle_filter = self .edge_to_angle_filter_linear (edge_angle_filter_tp )
1257+ edge_sph_embed = edge_ebd
1258+ else :
1259+ edge_angle_filter = None
1260+ edge_sph_embed = None
1261+
11861262 for idx , ll in enumerate (self .layers ):
11871263 # node_ebd: nb x nloc x n_dim
11881264 # node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parrallel_mode
@@ -1256,7 +1332,7 @@ def forward(
12561332 # for jit
12571333 assert not self .use_rk_update
12581334 assert dihedral_ebd is not None
1259- node_ebd , edge_ebd , angle_ebd , dihedral_ebd , ___ = ll .forward (
1335+ node_ebd , edge_ebd , angle_ebd , dihedral_ebd , ___ , ___ , = ll .forward (
12601336 node_ebd_ext ,
12611337 edge_ebd ,
12621338 h2 ,
@@ -1279,7 +1355,7 @@ def forward(
12791355 else :
12801356 assert dihedral_ebd is None
12811357 if not self .use_rk_update :
1282- node_ebd , edge_ebd , angle_ebd , ___ , node_sph_embed = ll .forward (
1358+ node_ebd , edge_ebd , angle_ebd , ___ , node_sph_embed , edge_sph_embed = ll .forward (
12831359 node_ebd_ext ,
12841360 edge_ebd ,
12851361 h2 ,
@@ -1301,6 +1377,8 @@ def forward(
13011377 edge_rbf_ebd = edge_rbf_ebd ,
13021378 edge_sph = edge_sph ,
13031379 node_sph_embed = node_sph_embed ,
1380+ edge_angle_filter = edge_angle_filter ,
1381+ edge_sph_embed = edge_sph_embed ,
13041382 )
13051383 # may cause jit slow, todo fix
13061384 elif not self .rk_update_diff_layer :
@@ -1321,6 +1399,7 @@ def forward(
13211399 angle_ebd_k1 ,
13221400 ___ ,
13231401 ___ ,
1402+ ___ ,
13241403 ) = ll .forward (
13251404 node_ebd_k_in ,
13261405 edge_ebd_k_in ,
0 commit comments