@@ -112,6 +112,9 @@ def __init__(
112112 precision : str = "float64" ,
113113 skip_stat : bool = True ,
114114 no_sym : bool = False ,
115+ smooth_angle_init : bool = False ,
116+ angle_init_use_sin : bool = False ,
117+ smooth_edge_update : bool = False ,
115118 pre_ln : bool = False ,
116119 only_e_ln : bool = False ,
117120 pre_bn : bool = False ,
@@ -239,6 +242,9 @@ def __init__(
239242 self .auto_batchsize = auto_batchsize
240243 self .optim_update = optim_update
241244 self .no_sym = no_sym
245+ self .smooth_angle_init = smooth_angle_init
246+ self .angle_init_use_sin = angle_init_use_sin
247+ self .smooth_edge_update = smooth_edge_update
242248
243249 self .n_dim = n_dim
244250 self .e_dim = e_dim
@@ -299,7 +305,11 @@ def __init__(
299305 1 , self .e_dim , precision = precision , seed = child_seed (seed , 0 )
300306 )
301307 self .angle_embd = MLPLayer (
302- 1 , self .a_dim , precision = precision , bias = False , seed = child_seed (seed , 1 )
308+ 1 if not self .angle_init_use_sin else 2 ,
309+ self .a_dim ,
310+ precision = precision ,
311+ bias = False ,
312+ seed = child_seed (seed , 1 ),
303313 )
304314 self .has_h1 = self .update_n_has_h1 or self .update_e_has_h1
305315 if self .has_h1 :
@@ -452,6 +462,7 @@ def __init__(
452462 bn_moment = self .bn_moment ,
453463 optim_update = self .optim_update ,
454464 no_sym = self .no_sym ,
465+ smooth_edge_update = self .smooth_edge_update ,
455466 seed = child_seed (child_seed (seed , 1 ), ii ),
456467 )
457468 )
@@ -632,10 +643,21 @@ def forward(
632643 # nf x nloc x a_nnei x a_nnei
633644 # 1 - 1e-6 for torch.acos stability
634645 cosine_ij = torch .matmul (normalized_diff_i , normalized_diff_j ) * (1 - 1e-6 )
635- # nf x nloc x a_nnei x a_nnei x 1
636- cosine_ij = cosine_ij .unsqueeze (- 1 ) / (torch .pi ** 0.5 )
646+ sine_ij = torch .sqrt (1 - cosine_ij ** 2 )
647+ if self .smooth_angle_init :
648+ cosine_ij = cosine_ij * a_sw .unsqueeze (- 1 ) * a_sw .unsqueeze (- 2 )
649+ sine_ij = sine_ij * a_sw .unsqueeze (- 1 ) * a_sw .unsqueeze (- 2 )
650+
651+ if not self .angle_init_use_sin :
652+ # nf x nloc x a_nnei x a_nnei x 1
653+ angle_input = cosine_ij .unsqueeze (- 1 ) / (torch .pi ** 0.5 )
654+ else :
655+ angle_input = torch .cat (
656+ [cosine_ij .unsqueeze (- 1 ), sine_ij .unsqueeze (- 1 )], dim = - 1
657+ ) / (torch .pi ** 0.5 )
658+
637659 # nf x nloc x a_nnei x a_nnei x a_dim
638- angle_ebd = self .angle_embd (cosine_ij ).reshape (
660+ angle_ebd = self .angle_embd (angle_input ).reshape (
639661 nframes , nloc , self .a_sel , self .a_sel , self .a_dim
640662 )
641663 if self .has_h1 :
0 commit comments