@@ -100,6 +100,9 @@ def __init__(
100100 env_protection : float = 0.0 ,
101101 precision : str = "float64" ,
102102 skip_stat : bool = True ,
103+ smooth_angle_init : bool = False ,
104+ angle_init_use_sin : bool = False ,
105+ smooth_edge_update : bool = False ,
103106 optim_update : bool = True ,
104107 seed : Optional [Union [int , list [int ]]] = None ,
105108 ) -> None :
@@ -202,6 +205,9 @@ def __init__(
202205 self .skip_stat = skip_stat
203206 self .a_compress_use_split = a_compress_use_split
204207 self .optim_update = optim_update
208+ self .smooth_angle_init = smooth_angle_init
209+ self .angle_init_use_sin = angle_init_use_sin
210+ self .smooth_edge_update = smooth_edge_update
205211
206212 self .n_dim = n_dim
207213 self .e_dim = e_dim
@@ -226,7 +232,11 @@ def __init__(
226232 1 , self .e_dim , precision = precision , seed = child_seed (seed , 0 )
227233 )
228234 self .angle_embd = MLPLayer (
229- 1 , self .a_dim , precision = precision , bias = False , seed = child_seed (seed , 1 )
235+ 1 if not self .angle_init_use_sin else 2 ,
236+ self .a_dim ,
237+ precision = precision ,
238+ bias = False ,
239+ seed = child_seed (seed , 1 ),
230240 )
231241 layers = []
232242 for ii in range (nlayers ):
@@ -254,6 +264,7 @@ def __init__(
254264 update_residual_init = self .update_residual_init ,
255265 precision = precision ,
256266 optim_update = self .optim_update ,
267+ smooth_edge_update = self .smooth_edge_update ,
257268 seed = child_seed (child_seed (seed , 1 ), ii ),
258269 )
259270 )
@@ -434,10 +445,21 @@ def forward(
434445 # nf x nloc x a_nnei x a_nnei
435446 # 1 - 1e-6 for torch.acos stability
436447 cosine_ij = torch .matmul (normalized_diff_i , normalized_diff_j ) * (1 - 1e-6 )
437- # nf x nloc x a_nnei x a_nnei x 1
438- cosine_ij = cosine_ij .unsqueeze (- 1 ) / (torch .pi ** 0.5 )
448+ sine_ij = torch .sqrt (1 - cosine_ij ** 2 )
449+ if self .smooth_angle_init :
450+ cosine_ij = cosine_ij * a_sw .unsqueeze (- 1 ) * a_sw .unsqueeze (- 2 )
451+ sine_ij = sine_ij * a_sw .unsqueeze (- 1 ) * a_sw .unsqueeze (- 2 )
452+
453+ if not self .angle_init_use_sin :
454+ # nf x nloc x a_nnei x a_nnei x 1
455+ angle_input = cosine_ij .unsqueeze (- 1 ) / (torch .pi ** 0.5 )
456+ else :
457+ angle_input = torch .cat (
458+ [cosine_ij .unsqueeze (- 1 ), sine_ij .unsqueeze (- 1 )], dim = - 1
459+ ) / (torch .pi ** 0.5 )
460+
439461 # nf x nloc x a_nnei x a_nnei x a_dim
440- angle_ebd = self .angle_embd (cosine_ij ).reshape (
462+ angle_ebd = self .angle_embd (angle_input ).reshape (
441463 nframes , nloc , self .a_sel , self .a_sel , self .a_dim
442464 )
443465
0 commit comments