Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 80 additions & 36 deletions deepmd/dpmodel/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def __init__(
numb_generalized_coord: int = 0,
use_huber: bool = False,
huber_delta: float = 0.01,
use_mae_loss: bool = False,
f_use_norm: bool = False,
**kwargs: Any,
) -> None:
self.starter_learning_rate = starter_learning_rate
Expand Down Expand Up @@ -80,6 +82,12 @@ def __init__(
)
self.use_huber = use_huber
self.huber_delta = huber_delta
self.use_mae_loss = use_mae_loss
self.f_use_norm = f_use_norm
if self.f_use_norm and not (self.use_huber or self.use_mae_loss):
raise RuntimeError(
"f_use_norm can only be True when use_huber or use_mae_loss is True."
)
if self.use_huber and (
self.has_pf or self.has_gf or self.relative_f is not None
):
Expand Down Expand Up @@ -169,51 +177,85 @@ def call(
loss = 0
more_loss = {}
if self.has_e:
l2_ener_loss = xp.mean(xp.square(energy - energy_hat))
if not self.use_huber:
loss += atom_norm_ener * (pref_e * l2_ener_loss)
if not self.use_mae_loss:
l2_ener_loss = xp.mean(xp.square(energy - energy_hat))
if not self.use_huber:
loss += atom_norm_ener * (pref_e * l2_ener_loss)
else:
l_huber_loss = custom_huber_loss(
atom_norm_ener * energy,
atom_norm_ener * energy_hat,
delta=self.huber_delta,
)
loss += pref_e * l_huber_loss
more_loss["rmse_e"] = self.display_if_exist(
xp.sqrt(l2_ener_loss) * atom_norm_ener, find_energy
)
else:
l_huber_loss = custom_huber_loss(
atom_norm_ener * energy,
atom_norm_ener * energy_hat,
delta=self.huber_delta,
l1_ener_loss = xp.mean(xp.abs(energy - energy_hat))
loss += atom_norm_ener * (pref_e * l1_ener_loss)
more_loss["mae_e"] = self.display_if_exist(
l1_ener_loss * atom_norm_ener, find_energy
)
loss += pref_e * l_huber_loss
more_loss["rmse_e"] = self.display_if_exist(
xp.sqrt(l2_ener_loss) * atom_norm_ener, find_energy
)
if self.has_f:
l2_force_loss = xp.mean(xp.square(diff_f))
if not self.use_huber:
loss += pref_f * l2_force_loss
else:
l_huber_loss = custom_huber_loss(
xp.reshape(force, (-1,)),
xp.reshape(force_hat, (-1,)),
delta=self.huber_delta,
if not self.use_mae_loss:
l2_force_loss = xp.mean(xp.square(diff_f))
if not self.use_huber:
loss += pref_f * l2_force_loss
else:
if not self.f_use_norm:
l_huber_loss = custom_huber_loss(
xp.reshape(force, (-1,)),
xp.reshape(force_hat, (-1,)),
delta=self.huber_delta,
)
else:
force_diff_3 = xp.reshape(force_hat - force, (-1, 3))
force_diff_norm = xp.reshape(
xp.linalg.vector_norm(force_diff_3, axis=1), (-1, 1)
)
l_huber_loss = custom_huber_loss(
force_diff_norm,
xp.zeros_like(force_diff_norm),
delta=self.huber_delta,
)
loss += pref_f * l_huber_loss
more_loss["rmse_f"] = self.display_if_exist(
xp.sqrt(l2_force_loss), find_force
)
loss += pref_f * l_huber_loss
more_loss["rmse_f"] = self.display_if_exist(
xp.sqrt(l2_force_loss), find_force
)
else:
if not self.f_use_norm:
l1_force_loss = xp.mean(xp.abs(diff_f))
else:
force_diff_3 = xp.reshape(force_hat - force, (-1, 3))
l1_force_loss = xp.mean(xp.linalg.vector_norm(force_diff_3, axis=1))
loss += pref_f * l1_force_loss
more_loss["mae_f"] = self.display_if_exist(l1_force_loss, find_force)
if self.has_v:
virial_reshape = xp.reshape(virial, (-1,))
virial_hat_reshape = xp.reshape(virial_hat, (-1,))
l2_virial_loss = xp.mean(
xp.square(virial_hat_reshape - virial_reshape),
)
if not self.use_huber:
loss += atom_norm * (pref_v * l2_virial_loss)
if not self.use_mae_loss:
l2_virial_loss = xp.mean(
xp.square(virial_hat_reshape - virial_reshape),
)
if not self.use_huber:
loss += atom_norm * (pref_v * l2_virial_loss)
else:
l_huber_loss = custom_huber_loss(
atom_norm * virial_reshape,
atom_norm * virial_hat_reshape,
delta=self.huber_delta,
)
loss += pref_v * l_huber_loss
more_loss["rmse_v"] = self.display_if_exist(
xp.sqrt(l2_virial_loss) * atom_norm, find_virial
)
else:
l_huber_loss = custom_huber_loss(
atom_norm * virial_reshape,
atom_norm * virial_hat_reshape,
delta=self.huber_delta,
l1_virial_loss = xp.mean(xp.abs(virial_hat_reshape - virial_reshape))
loss += atom_norm * (pref_v * l1_virial_loss)
more_loss["mae_v"] = self.display_if_exist(
l1_virial_loss * atom_norm, find_virial
)
loss += pref_v * l_huber_loss
more_loss["rmse_v"] = self.display_if_exist(
xp.sqrt(l2_virial_loss) * atom_norm, find_virial
)
if self.has_ae:
atom_ener_reshape = xp.reshape(atom_ener, (-1,))
atom_ener_hat_reshape = xp.reshape(atom_ener_hat, (-1,))
Expand Down Expand Up @@ -371,6 +413,8 @@ def serialize(self) -> dict:
"numb_generalized_coord": self.numb_generalized_coord,
"use_huber": self.use_huber,
"huber_delta": self.huber_delta,
"use_mae_loss": self.use_mae_loss,
"f_use_norm": self.f_use_norm,
}

@classmethod
Expand Down
113 changes: 78 additions & 35 deletions deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,10 @@ def __init__(
start_pref_gf: float = 0.0,
limit_pref_gf: float = 0.0,
numb_generalized_coord: int = 0,
use_l1_all: bool = False,
use_mae_loss: bool = False,
inference: bool = False,
use_huber: bool = False,
f_use_norm: bool = False,
huber_delta: float = 0.01,
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -97,8 +98,8 @@ def __init__(
The prefactor of generalized force loss at the end of the training.
numb_generalized_coord : int
The dimension of generalized coordinates.
use_l1_all : bool
Whether to use L1 loss, if False (default), it will use L2 loss.
use_mae_loss : bool
Whether to use MAE (L1) loss for all terms (energy, force, virial), if False (default), it will use L2 loss.
inference : bool
If true, it will output all losses found in output, ignoring the pre-factors.
use_huber : bool
Expand All @@ -107,6 +108,9 @@ def __init__(
- For absolute prediction errors within D: quadratic loss (0.5 * (error**2))
- For absolute errors exceeding D: linear loss (D * (|error| - 0.5 * D))
Formula: loss = 0.5 * (error**2) if |error| <= D else D * (|error| - 0.5 * D).
f_use_norm : bool
If true, use the L2 norm of per-atom force-difference vectors in the loss calculation when use_mae_loss or use_huber is True.
Instead of computing the loss on individual force components, it is computed on ||F_pred - F_label||_2 for each atom.
huber_delta : float
The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss.
**kwargs
Expand Down Expand Up @@ -140,9 +144,14 @@ def __init__(
raise RuntimeError(
"When generalized force loss is used, the dimension of generalized coordinates should be larger than 0"
)
self.use_l1_all = use_l1_all
self.use_mae_loss = use_mae_loss
self.inference = inference
self.use_huber = use_huber
self.f_use_norm = f_use_norm
if self.f_use_norm and not (self.use_huber or self.use_mae_loss):
raise RuntimeError(
Comment thread
iProzd marked this conversation as resolved.
"f_use_norm can only be True when use_huber or use_mae_loss is True."
)
self.huber_delta = huber_delta
if self.use_huber and (
self.has_pf or self.has_gf or self.relative_f is not None
Expand Down Expand Up @@ -214,7 +223,7 @@ def forward(
energy_pred = torch.sum(atom_ener_coeff * atom_ener_pred, dim=1)
find_energy = label.get("find_energy", 0.0)
pref_e = pref_e * find_energy
if not self.use_l1_all:
if not self.use_mae_loss:
l2_ener_loss = torch.mean(torch.square(energy_pred - energy_label))
if not self.inference:
more_loss["l2_ener_loss"] = self.display_if_exist(
Expand All @@ -234,19 +243,15 @@ def forward(
rmse_e.detach(), find_energy
)
# more_loss['log_keys'].append('rmse_e')
else: # use l1 and for all atoms
else:
l1_ener_loss = F.l1_loss(
energy_pred.reshape(-1),
energy_label.reshape(-1),
reduction="sum",
reduction="mean",
)
loss += pref_e * l1_ener_loss
loss += atom_norm * (pref_e * l1_ener_loss)
more_loss["mae_e"] = self.display_if_exist(
F.l1_loss(
energy_pred.reshape(-1),
energy_label.reshape(-1),
reduction="mean",
).detach(),
l1_ener_loss.detach() * atom_norm,
find_energy,
)
# more_loss['log_keys'].append('rmse_e')
Expand Down Expand Up @@ -277,7 +282,7 @@ def forward(
diff_f = diff_f_3.reshape(-1)

if self.has_f:
if not self.use_l1_all:
if not self.use_mae_loss:
l2_force_loss = torch.mean(torch.square(diff_f))
if not self.inference:
more_loss["l2_force_loss"] = self.display_if_exist(
Expand All @@ -286,22 +291,46 @@ def forward(
if not self.use_huber:
loss += (pref_f * l2_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
else:
l_huber_loss = custom_huber_loss(
force_pred.reshape(-1),
force_label.reshape(-1),
delta=self.huber_delta,
)
if not self.f_use_norm:
l_huber_loss = custom_huber_loss(
force_pred.reshape(-1),
force_label.reshape(-1),
delta=self.huber_delta,
)
else:
force_diff_norm = torch.linalg.vector_norm(
(force_label - force_pred).reshape(-1, 3),
ord=2,
dim=1,
keepdim=True,
)
l_huber_loss = custom_huber_loss(
force_diff_norm,
torch.zeros_like(force_diff_norm),
delta=self.huber_delta,
)
loss += pref_f * l_huber_loss
rmse_f = l2_force_loss.sqrt()
more_loss["rmse_f"] = self.display_if_exist(
rmse_f.detach(), find_force
)
else:
l1_force_loss = F.l1_loss(force_label, force_pred, reduction="none")
if not self.f_use_norm:
l1_force_loss = F.l1_loss(
force_label.reshape(-1),
force_pred.reshape(-1),
reduction="mean",
)
else:
l1_force_loss = torch.linalg.vector_norm(
(force_label - force_pred).reshape(-1, 3),
ord=2,
dim=1,
keepdim=True,
).mean()
more_loss["mae_f"] = self.display_if_exist(
l1_force_loss.mean().detach(), find_force
l1_force_loss.detach(), find_force
)
l1_force_loss = l1_force_loss.sum(-1).mean(-1).sum()
loss += (pref_f * l1_force_loss).to(GLOBAL_PT_FLOAT_PRECISION)
if mae:
mae_f = torch.mean(torch.abs(diff_f))
Expand Down Expand Up @@ -354,22 +383,36 @@ def forward(
find_virial = label.get("find_virial", 0.0)
pref_v = pref_v * find_virial
diff_v = label["virial"] - model_pred["virial"].reshape(-1, 9)
l2_virial_loss = torch.mean(torch.square(diff_v))
if not self.inference:
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
if not self.use_mae_loss:
l2_virial_loss = torch.mean(torch.square(diff_v))
if not self.inference:
more_loss["l2_virial_loss"] = self.display_if_exist(
l2_virial_loss.detach(), find_virial
)
if not self.use_huber:
loss += atom_norm * (pref_v * l2_virial_loss)
else:
l_huber_loss = custom_huber_loss(
atom_norm * model_pred["virial"].reshape(-1),
atom_norm * label["virial"].reshape(-1),
delta=self.huber_delta,
)
loss += pref_v * l_huber_loss
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(
rmse_v.detach(), find_virial
)
if not self.use_huber:
loss += atom_norm * (pref_v * l2_virial_loss)
else:
l_huber_loss = custom_huber_loss(
atom_norm * model_pred["virial"].reshape(-1),
atom_norm * label["virial"].reshape(-1),
delta=self.huber_delta,
l1_virial_loss = F.l1_loss(
label["virial"].reshape(-1),
model_pred["virial"].reshape(-1),
reduction="mean",
)
loss += atom_norm * (pref_v * l1_virial_loss)
more_loss["mae_v"] = self.display_if_exist(
l1_virial_loss.detach() * atom_norm,
find_virial,
)
loss += pref_v * l_huber_loss
rmse_v = l2_virial_loss.sqrt() * atom_norm
more_loss["rmse_v"] = self.display_if_exist(rmse_v.detach(), find_virial)
if mae:
mae_v = torch.mean(torch.abs(diff_v)) * atom_norm
more_loss["mae_v"] = self.display_if_exist(mae_v.detach(), find_virial)
Expand Down
28 changes: 28 additions & 0 deletions deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -3084,6 +3084,17 @@ def loss_ener() -> list[Argument]:
"Formula: loss = 0.5 * (error**2) if \\|error\\| <= D else D * (\\|error\\| - 0.5 * D). "
)
doc_huber_delta = "The threshold delta (D) used for Huber loss, controlling transition between L2 and L1 loss. "
doc_use_mae_loss = (
"If true, use MAE (Mean Absolute Error, L1 loss) for all terms (energy, force, virial). "
"If false (default), use MSE (Mean Squared Error, L2 loss). "
"MAE loss is less sensitive to outliers compared to MSE loss."
)
doc_f_use_norm = (
"If true, use L2 norm of force vectors for loss calculation when use_mae_loss or use_huber is True. "
"Instead of computing loss on individual force components, computes loss on ||F_pred - F_label||_2 for each atom. "
"This treats the force vector as a whole rather than three independent components. "
"Only effective when use_mae_loss=True or use_huber=True."
)
return [
Argument(
"start_pref_e",
Expand Down Expand Up @@ -3205,6 +3216,23 @@ def loss_ener() -> list[Argument]:
default=False,
doc=doc_use_huber,
),
Argument(
"use_mae_loss",
bool,
optional=True,
default=False,
doc=doc_use_mae_loss,
alias=[
"use_l1_all",
],
),
Argument(
"f_use_norm",
bool,
optional=True,
default=False,
doc=doc_f_use_norm,
),
Argument(
"huber_delta",
float,
Expand Down
Loading
Loading