@@ -50,6 +50,13 @@ class EnergySpinLoss(Loss):
5050 if true, the energy will be computed as \sum_i c_i E_i
5151 loss_func : str
5252 Loss function type: 'mse' or 'mae'.
53+ intensive : bool
54+ If true, the energy and virial terms of the 'mse' loss are treated as intensive
55+ quantities: the mean squared error is scaled by 1/N^2, where N is the number of atoms.
56+ This makes the loss value independent of system size and consistent with per-atom RMSE reporting.
57+ If false (default), the legacy 1/N scaling is used, which may cause the loss to scale
58+ with system size. The default is false for backward compatibility with models trained
59+ using deepmd-kit <= 3.0.1.
5360 **kwargs
5461 Other keyword arguments.
5562 """
@@ -69,6 +76,7 @@ def __init__(
6976 limit_pref_ae : float = 0.0 ,
7077 enable_atom_ener_coeff : bool = False ,
7178 loss_func : str = "mse" ,
79+ intensive : bool = False ,
7280 ** kwargs : Any ,
7381 ) -> None :
7482 valid_loss_funcs = ["mse" , "mae" ]
@@ -89,6 +97,7 @@ def __init__(
8997 self .start_pref_ae = start_pref_ae
9098 self .limit_pref_ae = limit_pref_ae
9199 self .enable_atom_ener_coeff = enable_atom_ener_coeff
100+ self .intensive = intensive
92101 self .has_e = self .start_pref_e != 0.0 or self .limit_pref_e != 0.0
93102 self .has_fr = self .start_pref_fr != 0.0 or self .limit_pref_fr != 0.0
94103 self .has_fm = self .start_pref_fm != 0.0 or self .limit_pref_fm != 0.0
@@ -117,6 +126,10 @@ def call(
117126 loss = 0
118127 more_loss = {}
119128 atom_norm = 1.0 / natoms
129+ # Exponent applied to atom_norm (= 1/N) in the MSE energy and virial loss terms:
130+ # - norm_exp=2 (intensive=True): 1/N^2 scaling, making the loss independent of system size
131+ # - norm_exp=1 (intensive=False, legacy): 1/N scaling, so the loss varies with system size
132+ norm_exp = 2 if self .intensive else 1
120133
121134 if self .has_e :
122135 energy_pred = model_dict ["energy" ]
@@ -130,7 +143,7 @@ def call(
130143 energy_pred = xp .sum (atom_ener_coeff * atom_ener_pred , axis = 1 )
131144 if self .loss_func == "mse" :
132145 l2_ener_loss = xp .mean (xp .square (energy_pred - energy_label ))
133- loss += atom_norm * (pref_e * l2_ener_loss )
146+ loss += atom_norm ** norm_exp * (pref_e * l2_ener_loss )
134147 more_loss ["rmse_e" ] = self .display_if_exist (
135148 xp .sqrt (l2_ener_loss ) * atom_norm , find_energy
136149 )
@@ -238,7 +251,7 @@ def call(
238251 diff_v = virial_label - virial_pred
239252 if self .loss_func == "mse" :
240253 l2_virial_loss = xp .mean (xp .square (diff_v ))
241- loss += atom_norm * (pref_v * l2_virial_loss )
254+ loss += atom_norm ** norm_exp * (pref_v * l2_virial_loss )
242255 more_loss ["rmse_v" ] = self .display_if_exist (
243256 xp .sqrt (l2_virial_loss ) * atom_norm , find_virial
244257 )
@@ -326,7 +339,7 @@ def serialize(self) -> dict:
326339 """Serialize the loss module."""
327340 return {
328341 "@class" : "EnergySpinLoss" ,
329- "@version" : 1 ,
342+ "@version" : 2 ,
330343 "starter_learning_rate" : self .starter_learning_rate ,
331344 "start_pref_e" : self .start_pref_e ,
332345 "limit_pref_e" : self .limit_pref_e ,
@@ -340,12 +353,17 @@ def serialize(self) -> dict:
340353 "limit_pref_ae" : self .limit_pref_ae ,
341354 "enable_atom_ener_coeff" : self .enable_atom_ener_coeff ,
342355 "loss_func" : self .loss_func ,
356+ "intensive" : self .intensive ,
343357 }
344358
345359 @classmethod
346360 def deserialize (cls , data : dict ) -> "EnergySpinLoss" :
347361 """Deserialize the loss module."""
348362 data = data .copy ()
349- check_version_compatibility (data .pop ("@version" ), 1 , 1 )
363+ version = data .pop ("@version" )
364+ check_version_compatibility (version , 2 , 1 )
350365 data .pop ("@class" )
366+ # Backward compatibility: version 1 used legacy normalization
367+ if version < 2 :
368+ data .setdefault ("intensive" , False )
351369 return cls (** data )
0 commit comments