@@ -103,10 +103,7 @@ def __init__(
103103 env_protection : float = 0.0 ,
104104 precision : str = "float64" ,
105105 skip_stat : bool = True ,
106- smooth_angle_init : bool = False ,
107- angle_init_use_sin : bool = False ,
108106 smooth_edge_update : bool = False ,
109- angle_multi_freq : Optional [str ] = None ,
110107 use_dynamic_sel : bool = False ,
111108 sel_reduce_factor : float = 10.0 ,
112109 optim_update : bool = True ,
@@ -211,29 +208,9 @@ def __init__(
211208 self .skip_stat = skip_stat
212209 self .a_compress_use_split = a_compress_use_split
213210 self .optim_update = optim_update
214- self .smooth_angle_init = smooth_angle_init
215- self .angle_init_use_sin = angle_init_use_sin
216211 self .smooth_edge_update = smooth_edge_update
217212 self .use_dynamic_sel = use_dynamic_sel
218213 self .sel_reduce_factor = sel_reduce_factor
219- self .angle_multi_freq = angle_multi_freq
220- self .angle_use_multi_freq = angle_multi_freq is not None
221- self .angle_multi_freq_list_float = (
222- [float (freq ) for freq in angle_multi_freq .split (":" )]
223- if self .angle_use_multi_freq
224- else []
225- )
226- if self .angle_use_multi_freq :
227- self .register_buffer (
228- "angle_multi_freq_list" ,
229- torch .tensor (
230- self .angle_multi_freq_list_float ,
231- dtype = torch .float ,
232- device = env .DEVICE ,
233- ),
234- )
235- else :
236- self .angle_multi_freq_list = None
237214
238215 self .n_dim = n_dim
239216 self .e_dim = e_dim
@@ -258,9 +235,7 @@ def __init__(
258235 1 , self .e_dim , precision = precision , seed = child_seed (seed , 0 )
259236 )
260237 self .angle_embd = MLPLayer (
261- len (self .angle_multi_freq_list_float ) + 1
262- if not self .angle_init_use_sin
263- else 2 * (len (self .angle_multi_freq_list_float ) + 1 ),
238+ 1 ,
264239 self .a_dim ,
265240 precision = precision ,
266241 bias = False ,
@@ -478,27 +453,7 @@ def forward(
478453 # nf x nloc x a_nnei x a_nnei
479454 # 1 - 1e-6 for torch.acos stability
480455 cosine_ij = torch .matmul (normalized_diff_i , normalized_diff_j ) * (1 - 1e-6 )
481- sine_ij = torch .sqrt (1 - cosine_ij ** 2 )
482- if self .smooth_angle_init :
483- cosine_ij = cosine_ij * a_sw .unsqueeze (- 1 ) * a_sw .unsqueeze (- 2 )
484- sine_ij = sine_ij * a_sw .unsqueeze (- 1 ) * a_sw .unsqueeze (- 2 )
485-
486- if self .angle_use_multi_freq :
487- assert self .angle_multi_freq_list is not None
488- theta = torch .acos (cosine_ij )
489- theta_list = theta [..., None ] * self .angle_multi_freq_list
490- else :
491- theta_list = None
492-
493- # nf x nloc x a_nnei x a_nnei x 1, nf x nloc x a_nnei x a_nnei x n_freq
494- angle_input_list = [cosine_ij .unsqueeze (- 1 )] + (
495- [torch .cos (theta_list )] if theta_list is not None else []
496- )
497- if self .angle_init_use_sin :
498- angle_input_list += [sine_ij .unsqueeze (- 1 )] + (
499- [torch .sin (theta_list )] if theta_list is not None else []
500- )
501- angle_input = torch .cat (angle_input_list , dim = - 1 ) / (torch .pi ** 0.5 )
456+ angle_input = cosine_ij .unsqueeze (- 1 ) / (torch .pi ** 0.5 )
502457
503458 if self .use_dynamic_sel :
504459 # get graph index
0 commit comments