@@ -148,6 +148,8 @@ def __init__(
148148 force_embedding_on_edge : bool = False ,
149149 use_gated_mlp : bool = False ,
150150 gated_mlp_norm : str = "none" ,
151+ use_res_gnn : bool = False ,
152+ res_gnn_layer : int = 6 ,
151153 use_loc_mapping : bool = True ,
152154 optim_update : bool = True ,
153155 seed : Optional [Union [int , list [int ]]] = None ,
@@ -337,6 +339,12 @@ def __init__(
337339 self .use_loc_mapping = use_loc_mapping
338340 self .use_gated_mlp = use_gated_mlp
339341 self .gated_mlp_norm = gated_mlp_norm
342+ self .use_res_gnn = use_res_gnn
343+ self .res_gnn_layer = res_gnn_layer
344+ if self .use_res_gnn :
345+ assert (
346+ self .nlayers % self .res_gnn_layer == 0
347+ ), "nlayers must be divisible by res_gnn_layer"
340348 assert not (
341349 self .message_use_self_concat and self .use_slim_message
342350 ), "only one of message_use_self_concat and use_slim_message can be True"
@@ -953,6 +961,7 @@ def forward(
953961 mapping = (
954962 mapping .view (nframes , nall ).unsqueeze (- 1 ).expand (- 1 , - 1 , self .n_dim )
955963 )
964+ res_node_list = []
956965 for idx , ll in enumerate (self .layers ):
957966 # node_ebd: nb x nloc x n_dim
958967 # node_ebd_ext: nb x nall x n_dim [OR] nb x nloc x n_dim when not parrallel_mode
@@ -1042,6 +1051,11 @@ def forward(
10421051 d_sw = d_sw ,
10431052 rbf_ebd = rbf_ebd ,
10441053 )
1054+ if self .use_res_gnn and (idx + 1 ) % self .res_gnn_layer == 0 :
1055+ res_node_list .append (node_ebd .unsqueeze (- 1 ))
1056+
1057+ if self .use_res_gnn :
1058+ node_ebd = torch .concat (res_node_list , dim = - 1 ).mean (dim = - 1 )
10451059
10461060 if self .use_combined_output :
10471061 concat_list = [node_ebd ]
0 commit comments