@@ -465,8 +465,9 @@ def forward(
465465 # beyond the cutoff sw should be 0.0
466466 sw = sw .masked_fill (~ nlist_mask , 0.0 )
467467
468+ eps = 1e-6 # for numerical stability; see #4809
468469 # get angle nlist (maybe smaller)
469- a_dist_mask = (torch .linalg .norm (diff , dim = - 1 ) < self .a_rcut )[
470+ a_dist_mask = (torch .linalg .norm (diff + eps , dim = - 1 ) < self .a_rcut )[
470471 :, :, : self .a_sel
471472 ]
472473 a_nlist = nlist [:, :, : self .a_sel ]
@@ -505,17 +506,16 @@ def forward(
505506 edge_input , h2 = torch .split (dmatrix , [1 , 3 ], dim = - 1 )
506507 if self .edge_init_use_dist :
507508 # nb x nloc x nnei x 1
508- edge_input = torch .linalg .norm (diff , dim = - 1 , keepdim = True )
509+ edge_input = torch .linalg .norm (diff + eps , dim = - 1 , keepdim = True )
509510
510511 # nf x nloc x a_nnei x 3
511512 normalized_diff_i = a_diff / (
512- torch .linalg .norm (a_diff , dim = - 1 , keepdim = True ) + 1e-6
513+ torch .linalg .norm (a_diff + eps , dim = - 1 , keepdim = True ) + eps
513514 )
514515 # nf x nloc x 3 x a_nnei
515516 normalized_diff_j = torch .transpose (normalized_diff_i , 2 , 3 )
516517 # nf x nloc x a_nnei x a_nnei
517- # 1 - 1e-6 for torch.acos stability
518- cosine_ij = torch .matmul (normalized_diff_i , normalized_diff_j ) * (1 - 1e-6 )
518+ cosine_ij = torch .matmul (normalized_diff_i , normalized_diff_j )
519519 angle_input = cosine_ij .unsqueeze (- 1 ) / (torch .pi ** 0.5 )
520520
521521 if not parallel_mode and self .use_loc_mapping :
0 commit comments