|
| 1 | +# SPDX-License-Identifier: LGPL-3.0-or-later |
| 2 | +import torch |
| 3 | + |
| 4 | +from deepmd.pt.loss.loss import ( |
| 5 | + TaskLoss, |
| 6 | +) |
| 7 | +from deepmd.pt.utils import ( |
| 8 | + env, |
| 9 | +) |
| 10 | +from deepmd.pt.utils.env import ( |
| 11 | + GLOBAL_PT_FLOAT_PRECISION, |
| 12 | +) |
| 13 | +from deepmd.utils.data import ( |
| 14 | + DataRequirementItem, |
| 15 | +) |
| 16 | + |
| 17 | + |
| 18 | +class GridDensityLoss(TaskLoss): |
| 19 | + def __init__( |
| 20 | + self, |
| 21 | + starter_learning_rate=1.0, |
| 22 | + start_pref_d=0.0, |
| 23 | + limit_pref_d=0.0, |
| 24 | + inference=False, |
| 25 | + **kwargs, |
| 26 | + ): |
| 27 | + r"""Construct a layer to compute loss on grid density. |
| 28 | +
|
| 29 | + Parameters |
| 30 | + ---------- |
| 31 | + starter_learning_rate : float |
| 32 | + The learning rate at the start of the training. |
| 33 | + start_pref_d : float |
| 34 | + The prefactor of charge density loss at the start of the training. |
| 35 | + limit_pref_d : float |
| 36 | + The prefactor of charge density loss at the end of the training. |
| 37 | + inference : bool |
| 38 | + If true, it will output all losses found in output, ignoring the pre-factors. |
| 39 | + **kwargs |
| 40 | + Other keyword arguments. |
| 41 | + """ |
| 42 | + super().__init__() |
| 43 | + self.starter_learning_rate = starter_learning_rate |
| 44 | + self.has_d = (start_pref_d != 0.0 and limit_pref_d != 0.0) or inference |
| 45 | + |
| 46 | + self.start_pref_d = start_pref_d |
| 47 | + self.limit_pref_d = limit_pref_d |
| 48 | + self.inference = inference |
| 49 | + |
| 50 | + def forward(self, input_dict, model, label, natoms, learning_rate, mae=False): |
| 51 | + """Return loss on energy and force. |
| 52 | +
|
| 53 | + Parameters |
| 54 | + ---------- |
| 55 | + input_dict : dict[str, torch.Tensor] |
| 56 | + Model inputs. |
| 57 | + model : torch.nn.Module |
| 58 | + Model to be used to output the predictions. |
| 59 | + label : dict[str, torch.Tensor] |
| 60 | + Labels. |
| 61 | + natoms : int |
| 62 | + The local atom number. |
| 63 | +
|
| 64 | + Returns |
| 65 | + ------- |
| 66 | + model_pred: dict[str, torch.Tensor] |
| 67 | + Model predictions. |
| 68 | + loss: torch.Tensor |
| 69 | + Loss for model to minimize. |
| 70 | + more_loss: dict[str, torch.Tensor] |
| 71 | + Other losses for display. |
| 72 | + """ |
| 73 | + model_pred = model(**input_dict) |
| 74 | + coef = learning_rate / self.starter_learning_rate |
| 75 | + pref_d = self.limit_pref_d + (self.start_pref_d - self.limit_pref_d) * coef |
| 76 | + |
| 77 | + loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0] |
| 78 | + more_loss = {} |
| 79 | + # more_loss['log_keys'] = [] # showed when validation on the fly |
| 80 | + # more_loss['test_keys'] = [] # showed when doing dp test |
| 81 | + atom_norm = 1.0 / natoms |
| 82 | + if self.has_d and "density" in model_pred and "density" in label: |
| 83 | + density_pred = model_pred["density"] |
| 84 | + density_label = label["density"] |
| 85 | + find_density = label.get("find_density", 0.0) |
| 86 | + pref_d = pref_d * find_density |
| 87 | + density_pred_reshape = density_pred.reshape(-1) |
| 88 | + density_label_reshape = density_label.reshape(-1) |
| 89 | + l2_density_loss = torch.square( |
| 90 | + density_label_reshape - density_pred_reshape |
| 91 | + ).mean() |
| 92 | + rmse_d = l2_density_loss.sqrt() |
| 93 | + more_loss["rmse_d"] = self.display_if_exist(rmse_d.detach(), find_density) |
| 94 | + l1_density_loss = torch.abs( |
| 95 | + density_label_reshape - density_pred_reshape |
| 96 | + ).mean() |
| 97 | + loss += (pref_d * l1_density_loss).to(GLOBAL_PT_FLOAT_PRECISION) |
| 98 | + mae_d = l1_density_loss |
| 99 | + more_loss["mae_d"] = self.display_if_exist(mae_d.detach(), find_density) |
| 100 | + return model_pred, loss, more_loss |
| 101 | + |
| 102 | + @property |
| 103 | + def label_requirement(self) -> list[DataRequirementItem]: |
| 104 | + """Return data label requirements needed for this loss calculation.""" |
| 105 | + label_requirement = [] |
| 106 | + label_requirement.append( |
| 107 | + DataRequirementItem( |
| 108 | + "grid", |
| 109 | + ndof=3, |
| 110 | + atomic=True, # the grid is defined for each atom, so it is atomic |
| 111 | + must=True, |
| 112 | + high_prec=True, |
| 113 | + ) |
| 114 | + ) |
| 115 | + if self.has_d: |
| 116 | + label_requirement.append( |
| 117 | + DataRequirementItem( |
| 118 | + "density", |
| 119 | + ndof=1, |
| 120 | + atomic=True, |
| 121 | + must=False, |
| 122 | + high_prec=True, |
| 123 | + ) |
| 124 | + ) |
| 125 | + return label_requirement |
0 commit comments