@@ -536,6 +536,8 @@ def __init__(
536536 add_angle_readout : bool = False ,
537537 slim_edge_readout : bool = False ,
538538 slim_angle_readout : bool = False ,
539+ edge_extra_fact : float = 1.0 ,
540+ angle_extra_fact : float = 1.0 ,
539541 ** kwargs ,
540542 ) -> None :
541543 """Construct a fitting net for energy.
@@ -549,6 +551,8 @@ def __init__(
549551 """
550552 self .add_edge_readout = add_edge_readout
551553 self .add_angle_readout = add_angle_readout
554+ self .edge_extra_fact = edge_extra_fact
555+ self .angle_extra_fact = angle_extra_fact
552556 super ().__init__ (
553557 "energy" ,
554558 ntypes ,
@@ -714,7 +718,7 @@ def forward(
714718 # nf x nloc x 1
715719 edge_energy = torch .sum (edge_atomic_contrib , dim = - 2 )
716720 # energy
717- out = out + edge_energy / self .norm_e_fact
721+ out = out + ( edge_energy * self . edge_extra_fact ) / self .norm_e_fact
718722
719723 if self .add_angle_readout :
720724 assert angle_embd is not None
@@ -747,5 +751,5 @@ def forward(
747751 )
748752 # energy
749753 # self.norm_a_fact ** 2
750- out = out + angle_energy / (self .norm_a_fact ** 2 )
754+ out = out + ( angle_energy * self . angle_extra_fact ) / (self .norm_a_fact ** 2 )
751755 return {self .var_name : out .to (env .GLOBAL_PT_FLOAT_PRECISION )}
0 commit comments