@@ -184,6 +184,7 @@ def __init__(
184184 e3nn_conv_pattern : str = "128x0e+64x1e+32x2e+32x3e" ,
185185 use_e3nn_denominator : bool = False ,
186186 e3nn_conv_l_max : int = 3 ,
187+ e3nn_use_edge_feat_weights : bool = False ,
187188 seed : Optional [Union [int , list [int ]]] = None ,
188189 ) -> None :
189190 r"""
@@ -482,6 +483,7 @@ def __init__(
482483 self .e3nn_conv_pattern = e3nn_conv_pattern
483484 self .use_e3nn_denominator = use_e3nn_denominator
484485 self .e3nn_conv_l_max = e3nn_conv_l_max
486+ self .e3nn_use_edge_feat_weights = e3nn_use_edge_feat_weights
485487
486488 if not self .edge_use_esen_rbf :
487489 self .edge_embd = MLPLayer (
@@ -568,6 +570,7 @@ def __init__(
568570 "irreps_out" : irreps_out ,
569571 "denominator" : 1.0 if not self .use_e3nn_denominator else self .dynamic_e_sel / 4 ,
570572 "train_denominator" : True ,
573+ "weight_layer_input_to_hidden" : [8 , 64 , 64 ] if not self .e3nn_use_edge_feat_weights else [self .e_dim ],
571574 }
572575 irreps_x = irreps_out
573576 layers .append (
@@ -633,6 +636,7 @@ def __init__(
633636 dropout_rate = self .dropout_rate ,
634637 use_e3nn_conv = self .use_e3nn_conv ,
635638 e3nn_conv_pattern = self .e3nn_conv_pattern ,
639+ e3nn_use_edge_feat_weights = self .e3nn_use_edge_feat_weights ,
636640 e3nn_conv_args = e3nn_conv_args ,
637641 seed = child_seed (child_seed (seed , 1 ), ii ),
638642 )
0 commit comments