diff --git a/deepmd/dpmodel/loss/ener.py b/deepmd/dpmodel/loss/ener.py index 2bf103c249..7ac81d8a3d 100644 --- a/deepmd/dpmodel/loss/ener.py +++ b/deepmd/dpmodel/loss/ener.py @@ -179,7 +179,9 @@ def call( delta=self.huber_delta, ) loss += pref_e * l_huber_loss - more_loss["rmse_e"] = self.display_if_exist(l2_ener_loss, find_energy) + more_loss["rmse_e"] = self.display_if_exist( + xp.sqrt(l2_ener_loss), find_energy + ) if self.has_f: l2_force_loss = xp.mean(xp.square(diff_f)) if not self.use_huber: @@ -191,7 +193,9 @@ def call( delta=self.huber_delta, ) loss += pref_f * l_huber_loss - more_loss["rmse_f"] = self.display_if_exist(l2_force_loss, find_force) + more_loss["rmse_f"] = self.display_if_exist( + xp.sqrt(l2_force_loss), find_force + ) if self.has_v: virial_reshape = xp.reshape(virial, (-1,)) virial_hat_reshape = xp.reshape(virial_hat, (-1,)) @@ -207,7 +211,9 @@ def call( delta=self.huber_delta, ) loss += pref_v * l_huber_loss - more_loss["rmse_v"] = self.display_if_exist(l2_virial_loss, find_virial) + more_loss["rmse_v"] = self.display_if_exist( + xp.sqrt(l2_virial_loss), find_virial + ) if self.has_ae: atom_ener_reshape = xp.reshape(atom_ener, (-1,)) atom_ener_hat_reshape = xp.reshape(atom_ener_hat, (-1,)) @@ -224,7 +230,7 @@ def call( ) loss += pref_ae * l_huber_loss more_loss["rmse_ae"] = self.display_if_exist( - l2_atom_ener_loss, find_atom_ener + xp.sqrt(l2_atom_ener_loss), find_atom_ener ) if self.has_pf: atom_pref_reshape = xp.reshape(atom_pref, (-1,)) @@ -233,7 +239,7 @@ def call( ) loss += pref_pf * l2_pref_force_loss more_loss["rmse_pf"] = self.display_if_exist( - l2_pref_force_loss, find_atom_pref + xp.sqrt(l2_pref_force_loss), find_atom_pref ) if self.has_gf: find_drdq = label_dict["find_drdq"] @@ -254,7 +260,9 @@ def call( + (self.start_pref_gf - self.limit_pref_gf) * lr_ratio ) loss += pref_gf * l2_gen_force_loss - more_loss["rmse_gf"] = self.display_if_exist(l2_gen_force_loss, find_drdq) + more_loss["rmse_gf"] = self.display_if_exist( + xp.sqrt(l2_gen_force_loss), find_drdq + ) self.l2_l = loss more_loss["rmse"] = xp.sqrt(loss)