Skip to content

Commit 89a3180

Browse files
authored
fix(dpmodel): fix rmse_* in more_loss (#5105)
The previous value was MSE. <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Fixed RMSE metric calculations to correctly display square-root values for energy, force, virial, and atom-related loss measurements, ensuring accurate performance metrics reporting. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai -->
1 parent bbe66a3 commit 89a3180

1 file changed

Lines changed: 14 additions & 6 deletions

File tree

deepmd/dpmodel/loss/ener.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,9 @@ def call(
179179
delta=self.huber_delta,
180180
)
181181
loss += pref_e * l_huber_loss
182-
more_loss["rmse_e"] = self.display_if_exist(l2_ener_loss, find_energy)
182+
more_loss["rmse_e"] = self.display_if_exist(
183+
xp.sqrt(l2_ener_loss), find_energy
184+
)
183185
if self.has_f:
184186
l2_force_loss = xp.mean(xp.square(diff_f))
185187
if not self.use_huber:
@@ -191,7 +193,9 @@ def call(
191193
delta=self.huber_delta,
192194
)
193195
loss += pref_f * l_huber_loss
194-
more_loss["rmse_f"] = self.display_if_exist(l2_force_loss, find_force)
196+
more_loss["rmse_f"] = self.display_if_exist(
197+
xp.sqrt(l2_force_loss), find_force
198+
)
195199
if self.has_v:
196200
virial_reshape = xp.reshape(virial, (-1,))
197201
virial_hat_reshape = xp.reshape(virial_hat, (-1,))
@@ -207,7 +211,9 @@ def call(
207211
delta=self.huber_delta,
208212
)
209213
loss += pref_v * l_huber_loss
210-
more_loss["rmse_v"] = self.display_if_exist(l2_virial_loss, find_virial)
214+
more_loss["rmse_v"] = self.display_if_exist(
215+
xp.sqrt(l2_virial_loss), find_virial
216+
)
211217
if self.has_ae:
212218
atom_ener_reshape = xp.reshape(atom_ener, (-1,))
213219
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, (-1,))
@@ -224,7 +230,7 @@ def call(
224230
)
225231
loss += pref_ae * l_huber_loss
226232
more_loss["rmse_ae"] = self.display_if_exist(
227-
l2_atom_ener_loss, find_atom_ener
233+
xp.sqrt(l2_atom_ener_loss), find_atom_ener
228234
)
229235
if self.has_pf:
230236
atom_pref_reshape = xp.reshape(atom_pref, (-1,))
@@ -233,7 +239,7 @@ def call(
233239
)
234240
loss += pref_pf * l2_pref_force_loss
235241
more_loss["rmse_pf"] = self.display_if_exist(
236-
l2_pref_force_loss, find_atom_pref
242+
xp.sqrt(l2_pref_force_loss), find_atom_pref
237243
)
238244
if self.has_gf:
239245
find_drdq = label_dict["find_drdq"]
@@ -254,7 +260,9 @@ def call(
254260
+ (self.start_pref_gf - self.limit_pref_gf) * lr_ratio
255261
)
256262
loss += pref_gf * l2_gen_force_loss
257-
more_loss["rmse_gf"] = self.display_if_exist(l2_gen_force_loss, find_drdq)
263+
more_loss["rmse_gf"] = self.display_if_exist(
264+
xp.sqrt(l2_gen_force_loss), find_drdq
265+
)
258266

259267
self.l2_l = loss
260268
more_loss["rmse"] = xp.sqrt(loss)

0 commit comments

Comments
 (0)