Skip to content

Commit 1dc1248

Browse files
feat(pt): add Mean absolute percentage error (MAPE) loss for prop. pred. (#4854)
feat: add Mean absolute percentage error (MAPE) loss for property prediction <!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **New Features** * Added support for the Mean Absolute Percentage Error (MAPE) as a selectable loss function and metric. * **Documentation** * Updated user-facing documentation to include "mape" as an option for loss function selection. <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent dbc6d7b commit 1dc1248

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

deepmd/pt/loss/property.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(
4242
var_name : str
4343
The atomic property to fit, 'energy', 'dipole', and 'polar'.
4444
loss_func : str
45-
The loss function, such as "smooth_mae", "mae", "rmse".
45+
The loss function, such as "smooth_mae", "mae", "rmse", "mape".
4646
metric : list
4747
The metric such as mae, rmse which will be printed.
4848
beta : float
@@ -151,6 +151,12 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
151151
reduction="mean",
152152
)
153153
)
154+
elif self.loss_func == "mape":
155+
loss += torch.mean(
156+
torch.abs(
157+
(label[var_name] - model_pred[var_name]) / (label[var_name] + 1e-3)
158+
)
159+
)
154160
else:
155161
raise RuntimeError(f"Unknown loss function : {self.loss_func}")
156162

@@ -182,6 +188,12 @@ def forward(self, input_dict, model, label, natoms, learning_rate=0.0, mae=False
182188
reduction="mean",
183189
)
184190
).detach()
191+
if "mape" in self.metric:
192+
more_loss["mape"] = torch.mean(
193+
torch.abs(
194+
(label[var_name] - model_pred[var_name]) / (label[var_name] + 1e-3)
195+
)
196+
).detach()
185197

186198
return model_pred, loss, more_loss
187199

0 commit comments

Comments
 (0)