@@ -447,6 +447,7 @@ def __init__(
447447 )
448448 )
449449 self .layers = torch .nn .ModuleList (layers )
450+ self .additional_output_for_fitting : dict [str , Optional [torch .Tensor ]] = {}
450451
451452 wanted_shape = (self .ntypes , self .nnei , 4 )
452453 mean = torch .zeros (wanted_shape , dtype = self .prec , device = env .DEVICE )
@@ -461,6 +462,8 @@ def get_rcut(self) -> float:
461462 """Returns the cut-off radius."""
462463 return self .e_rcut
463464
465+ additional_output_for_fitting : dict [str , Optional [torch .Tensor ]]
466+
464467 def get_rcut_smth (self ) -> float :
465468 """Returns the radius where the neighbor information starts to smoothly decay to 0."""
466469 return self .e_rcut_smth
@@ -548,6 +551,9 @@ def reinit_exclude(
548551 self .exclude_types = exclude_types
549552 self .emask = PairExcludeMask (self .ntypes , exclude_types = exclude_types )
550553
554+ def get_additional_output_for_fitting (self ):
555+ return self .additional_output_for_fitting
556+
551557 def forward (
552558 self ,
553559 nlist : torch .Tensor ,
@@ -782,6 +788,8 @@ def forward(
782788 sw = sw [nlist_mask ]
783789 # n_edge x 4
784790 dmatrix = dmatrix [nlist_mask ]
791+ # n_edge x 3
792+ diff = diff [nlist_mask ]
785793
786794 if self .edge_use_esen_atom_ebd :
787795 assert source_type is not None
@@ -809,12 +817,16 @@ def forward(
809817 * d_sw [:, :, None , :, None ]
810818 * d_sw [:, :, None , None , :]
811819 )[d_nlist_mask ]
820+ self .additional_output_for_fitting ["edge_index" ] = edge_index
812821 else :
813822 # avoid jit assertion
814823 edge_index = angle_index = torch .zeros (
815824 [1 , 3 ], device = nlist .device , dtype = nlist .dtype
816825 )
817826 dihedral_index = None
827+ self .additional_output_for_fitting ["edge_index" ] = None
828+ self .additional_output_for_fitting ["diff" ] = diff
829+ self .additional_output_for_fitting ["sw" ] = sw
818830 # get edge and angle embedding
819831 # nb x nloc x nnei x e_dim [OR] n_edge x e_dim
820832 if self .edge_use_esen_rbf :
0 commit comments