@@ -186,6 +186,7 @@ def __init__(
186186 precision : str = "float64" ,
187187 fix_stat_std : float = 0.3 ,
188188 smooth_edge_update : bool = False ,
189+ edge_init_use_dist : bool = False ,
189190 use_exp_switch : bool = False ,
190191 use_dynamic_sel : bool = False ,
191192 sel_reduce_factor : float = 10.0 ,
@@ -221,6 +222,7 @@ def __init__(
221222 self .a_compress_use_split = a_compress_use_split
222223 self .optim_update = optim_update
223224 self .smooth_edge_update = smooth_edge_update
225+ self .edge_init_use_dist = edge_init_use_dist
224226 self .use_exp_switch = use_exp_switch
225227 self .use_dynamic_sel = use_dynamic_sel
226228 self .sel_reduce_factor = sel_reduce_factor
@@ -450,6 +452,10 @@ def forward(
450452 # get edge and angle embedding input
451453 # nb x nloc x nnei x 1, nb x nloc x nnei x 3
452454 edge_input , h2 = torch .split (dmatrix , [1 , 3 ], dim = - 1 )
455+ if self .edge_init_use_dist :
456+ # nb x nloc x nnei x 1
457+ edge_input = torch .linalg .norm (diff , dim = - 1 , keepdim = True )
458+
453459 # nf x nloc x a_nnei x 3
454460 normalized_diff_i = a_diff / (
455461 torch .linalg .norm (a_diff , dim = - 1 , keepdim = True ) + 1e-6
@@ -486,7 +492,10 @@ def forward(
486492 )
487493 # get edge and angle embedding
488494 # nb x nloc x nnei x e_dim [OR] n_edge x e_dim
489- edge_ebd = self .act (self .edge_embd (edge_input ))
495+ if not self .edge_init_use_dist :
496+ edge_ebd = self .act (self .edge_embd (edge_input ))
497+ else :
498+ edge_ebd = self .edge_embd (edge_input )
490499 # nf x nloc x a_nnei x a_nnei x a_dim [OR] n_angle x a_dim
491500 angle_ebd = self .angle_embd (angle_input )
492501
0 commit comments