77
88import torch
99
10- from e3nn .o3 import Irreps , Linear
10+ from e3nn .o3 import Irreps , Linear , TensorProduct
1111from e3nn .o3 import FullyConnectedTensorProduct as FCTP
1212import sevenn .util as util
1313from sevenn .nn .edge_embedding import (
@@ -184,6 +184,7 @@ def __init__(
184184 use_e3nn_conv : bool = False ,
185185 e3nn_conv_pattern : str = "128x0e+64x1e+32x2e+32x3e" ,
186186 use_e3nn_denominator : bool = False ,
187+ e3nn_conv_use_edge_sh_feat : bool = False ,
187188 e3nn_conv_l_max : int = 3 ,
188189 e3nn_use_edge_feat_weights : bool = False ,
189190 use_e3nn_angle_conv : bool = False ,
@@ -488,6 +489,7 @@ def __init__(
488489 self .use_e3nn_conv = use_e3nn_conv
489490 self .e3nn_conv_pattern = e3nn_conv_pattern
490491 self .use_e3nn_denominator = use_e3nn_denominator
492+ self .e3nn_conv_use_edge_sh_feat = e3nn_conv_use_edge_sh_feat
491493 self .e3nn_conv_l_max = e3nn_conv_l_max
492494 self .e3nn_use_edge_feat_weights = e3nn_use_edge_feat_weights
493495 self .use_e3nn_angle_conv = use_e3nn_angle_conv
@@ -564,10 +566,42 @@ def __init__(
564566 self .edge_rbf_embed = None
565567 self .edge_env = None
566568 self .edge_spherical_embd = None
567- self .irreps_filter = "0e"
569+ self .irreps_filter = Irreps ("0e" )
570+
571+ if self .e3nn_conv_use_edge_sh_feat :
572+ irreps_edge = Irreps (f"{ self .e_dim } x0e" )
573+ irreps_filter = self .irreps_filter
574+ irreps_out = ((irreps_filter * self .e_dim ).sort ()[0 ]).simplify ()
575+ self .edge_sph_embd_ir = Irreps (self .e3nn_angle_conv_pattern )
576+ instructions = []
577+ irreps_mid = []
578+ weight_numel = 0
579+ for i , (mul_x , ir_x ) in enumerate (irreps_edge ):
580+ for j , (_ , ir_filter ) in enumerate (irreps_filter ):
581+ for ir_out in ir_x * ir_filter :
582+ if ir_out in irreps_out : # here we drop l > lmax
583+ k = len (irreps_mid )
584+ weight_numel += mul_x * 1 # path shape
585+ irreps_mid .append ((mul_x , ir_out ))
586+ instructions .append ((i , j , k , 'uvu' , False ))
587+
588+ irreps_mid = Irreps (irreps_mid )
589+ self .edge_sph_embd_init_tp = TensorProduct (
590+ irreps_in1 = irreps_edge ,
591+ irreps_in2 = irreps_filter ,
592+ irreps_out = irreps_mid ,
593+ instructions = instructions ,
594+ shared_weights = False ,
595+ internal_weights = False ,
596+ )
597+ self .edge_sph_embd_init_linear = Linear (irreps_mid , self .edge_sph_embd_ir )
598+ else :
599+ self .edge_sph_embd_init_tp = None
600+ self .edge_sph_embd_ir = Irreps (f'{ self .e_dim } x0e' )
601+ self .edge_sph_embd_init_linear = None
568602
569603 # for edge angle e3nn conv
570- irreps_edge = Irreps ( f' { self .e_dim } x0e' )
604+ irreps_edge = self .edge_sph_embd_ir
571605 self .angle_lmax = e3nn_angle_conv_l_max
572606 self .e3nn_angle_conv_pattern = Irreps ("+" .join (e3nn_angle_conv_pattern .split ("+" )[:self .angle_lmax + 1 ]))
573607 if self .use_e3nn_angle_conv :
@@ -699,6 +733,7 @@ def __init__(
699733 use_e3nn_conv = self .use_e3nn_conv ,
700734 e3nn_conv_pattern = self .e3nn_conv_pattern ,
701735 e3nn_use_edge_feat_weights = self .e3nn_use_edge_feat_weights ,
736+ e3nn_conv_use_edge_sh_feat = self .e3nn_conv_use_edge_sh_feat ,
702737 e3nn_conv_args = e3nn_conv_args ,
703738 use_e3nn_angle_conv = self .use_e3nn_angle_conv ,
704739 e3nn_angle_conv_args = e3nn_angle_conv_args ,
@@ -1249,12 +1284,23 @@ def forward(
12491284 edge_sph = None
12501285 node_sph_embed = None
12511286
1287+ if not self .e3nn_conv_use_edge_sh_feat :
1288+ if self .use_e3nn_angle_conv :
1289+ edge_sph_embed = edge_ebd
1290+ else :
1291+ edge_sph_embed = None
1292+ else :
1293+ assert edge_sph is not None
1294+ assert self .edge_sph_embd_init_tp is not None
1295+ assert self .edge_sph_embd_init_linear is not None
1296+ edge_sph_embed = self .edge_sph_embd_init_tp (edge_ebd , edge_sph )
1297+ edge_sph_embed = self .edge_sph_embd_init_linear (edge_sph_embed )
1298+
12521299 if self .use_e3nn_angle_conv :
12531300 assert self .use_dynamic_sel , "e3nn conv must use dynamic sel"
12541301 assert self .edge_spherical_embd_for_angle is not None
12551302 assert self .edge_to_angle_filter_prod is not None
12561303 assert self .edge_to_angle_filter_linear is not None
1257- edge_sph_embed = edge_ebd
12581304 edge_i_index = angle_index [:, 1 ]
12591305 edge_j_index = angle_index [:, 2 ]
12601306 if not self .e3nn_angle_use_cross :
@@ -1277,7 +1323,6 @@ def forward(
12771323 angle_weights = torch .cat ([cosine_ij .unsqueeze (- 1 ), sine_ij .unsqueeze (- 1 )], dim = - 1 )
12781324 else :
12791325 edge_angle_filter = None
1280- edge_sph_embed = None
12811326 angle_weights = None
12821327
12831328 for idx , ll in enumerate (self .layers ):
0 commit comments