diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py
index 433897860f..a9d2326b93 100644
--- a/deepmd/pt/model/descriptor/repflows.py
+++ b/deepmd/pt/model/descriptor/repflows.py
@@ -35,6 +35,9 @@
 from deepmd.pt.utils.exclude_mask import (
     PairExcludeMask,
 )
+from deepmd.pt.utils.safe_gradient import (
+    safe_for_norm,
+)
 from deepmd.pt.utils.spin import (
     concat_switch_virtual,
 )
@@ -473,9 +476,7 @@ def forward(
         sw = sw.masked_fill(~nlist_mask, 0.0)
         # get angle nlist (maybe smaller)
-        a_dist_mask = (torch.linalg.norm(diff, dim=-1) < self.a_rcut)[
-            :, :, : self.a_sel
-        ]
+        a_dist_mask = (safe_for_norm(diff, dim=-1) < self.a_rcut)[:, :, : self.a_sel]
         a_nlist = nlist[:, :, : self.a_sel]
         a_nlist = torch.where(a_dist_mask, a_nlist, -1)
         _, a_diff, a_sw = prod_env_mat(
@@ -512,11 +513,11 @@ def forward(
         edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1)
         if self.edge_init_use_dist:
             # nb x nloc x nnei x 1
-            edge_input = torch.linalg.norm(diff, dim=-1, keepdim=True)
+            edge_input = safe_for_norm(diff, dim=-1, keepdim=True)
         # nf x nloc x a_nnei x 3
         normalized_diff_i = a_diff / (
-            torch.linalg.norm(a_diff, dim=-1, keepdim=True) + 1e-6
+            safe_for_norm(a_diff, dim=-1, keepdim=True) + 1e-6
         )
         # nf x nloc x 3 x a_nnei
         normalized_diff_j = torch.transpose(normalized_diff_i, 2, 3)
diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py
index 69f2cc4eaa..75e2f97576 100644
--- a/deepmd/pt/model/descriptor/repformers.py
+++ b/deepmd/pt/model/descriptor/repformers.py
@@ -32,6 +32,9 @@
 from deepmd.pt.utils.exclude_mask import (
     PairExcludeMask,
 )
+from deepmd.pt.utils.safe_gradient import (
+    safe_for_norm,
+)
 from deepmd.pt.utils.spin import (
     concat_switch_virtual,
 )
@@ -446,7 +449,7 @@ def forward(
         if not self.direct_dist:
             g2, h2 = torch.split(dmatrix, [1, 3], dim=-1)
         else:
-            g2, h2 = torch.linalg.norm(diff, dim=-1, keepdim=True), diff
+            g2, h2 = safe_for_norm(diff, dim=-1, keepdim=True), diff
         g2 = g2 / self.rcut
         h2 = h2 / self.rcut
         # nb x nloc x nnei x ng2
diff --git a/deepmd/pt/utils/safe_gradient.py b/deepmd/pt/utils/safe_gradient.py
new file mode 100644
index 0000000000..e0bb76a8cf
--- /dev/null
+++ b/deepmd/pt/utils/safe_gradient.py
@@ -0,0 +1,39 @@
+# SPDX-License-Identifier: LGPL-3.0-or-later
+"""Safe versions of some functions that have problematic gradients.
+
+Check https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
+for more information.
+"""
+
+import torch
+
+
+def safe_for_sqrt(x: torch.Tensor) -> torch.Tensor:
+    """Safe version of sqrt that has a gradient of 0 at x = 0."""
+    mask = x > 0.0
+    x_safe = torch.where(mask, x, torch.ones_like(x))
+    return torch.where(mask, torch.sqrt(x_safe), torch.zeros_like(x))
+
+
+def safe_for_norm(
+    x: torch.Tensor,
+    dim: int | None = None,
+    keepdim: bool = False,
+    ord: float = 2.0,
+) -> torch.Tensor:
+    """Safe version of torch.linalg.norm that has a gradient of 0 at x = 0.
+
+    This helper is currently used for vector-norm cases in PT descriptors.
+    """
+    if dim is None:
+        mask = torch.sum(torch.square(x)) > 0
+        x_safe = torch.where(mask, x, torch.ones_like(x))
+        norm = torch.linalg.norm(x_safe, ord=ord)
+        return torch.where(mask, norm, torch.zeros_like(norm))
+
+    mask = torch.sum(torch.square(x), dim=(dim,), keepdim=True) > 0
+    mask_out = mask if keepdim else mask.squeeze(dim)
+
+    x_safe = torch.where(mask, x, torch.ones_like(x))
+    norm = torch.linalg.norm(x_safe, ord=ord, dim=dim, keepdim=keepdim)
+    return torch.where(mask_out, norm, torch.zeros_like(norm))
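Not part of the patch, but a minimal sketch of the double-`where` trick the new module implements (following the JAX FAQ entry linked in its docstring): a single `torch.where` around an unsafe op is not enough, because autograd still differentiates the discarded branch and multiplies its inf/NaN cotangent by zero; feeding the unsafe op a dummy safe input is what keeps the gradient finite. The expected outputs in the comments are my reading of the behavior, not captured logs.

```python
import torch

from deepmd.pt.utils.safe_gradient import safe_for_sqrt

x = torch.zeros(3, dtype=torch.float64, requires_grad=True)

# Plain sqrt: d sqrt(x)/dx = 1 / (2 sqrt(x)) diverges at x = 0, so the
# gradient is inf (and turns into NaN as soon as it meets a 0 upstream).
torch.sqrt(x).sum().backward()
print(x.grad)  # expected: tensor([inf, inf, inf])

x.grad = None
# safe_for_sqrt feeds sqrt a dummy input of 1 at the masked positions,
# so both where() branches stay finite and the gradient at x = 0 is 0.
safe_for_sqrt(x).sum().backward()
print(x.grad)  # expected: tensor([0., 0., 0.])
```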
+ """ + if dim is None: + mask = torch.sum(torch.square(x)) > 0 + x_safe = torch.where(mask, x, torch.ones_like(x)) + norm = torch.linalg.norm(x_safe, ord=ord) + return torch.where(mask, norm, torch.zeros_like(norm)) + + mask = torch.sum(torch.square(x), dim=(dim,), keepdim=True) > 0 + mask_out = mask if keepdim else mask.squeeze(dim) + + x_safe = torch.where(mask, x, torch.ones_like(x)) + norm = torch.linalg.norm(x_safe, ord=ord, dim=dim, keepdim=keepdim) + return torch.where(mask_out, norm, torch.zeros_like(norm)) diff --git a/source/tests/pt/model/test_dpa_hessian_finite.py b/source/tests/pt/model/test_dpa_hessian_finite.py new file mode 100644 index 0000000000..286a565710 --- /dev/null +++ b/source/tests/pt/model/test_dpa_hessian_finite.py @@ -0,0 +1,62 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import numpy as np +import torch + +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) + +from ...seed import ( + GLOBAL_SEED, +) +from .test_permutation import ( + model_dpa2, + model_dpa3, +) + +dtype = torch.float64 + + +class TestDPAHessianFinite(unittest.TestCase): + def _build_inputs(self): + natoms = 5 + cell = 4.0 * torch.eye(3, dtype=dtype, device=env.DEVICE) + generator = torch.Generator(device=env.DEVICE).manual_seed(GLOBAL_SEED) + coord = 3.0 * torch.rand( + [1, natoms, 3], dtype=dtype, device=env.DEVICE, generator=generator + ) + atype = torch.tensor([[0, 0, 0, 1, 1]], dtype=torch.int64, device=env.DEVICE) + return coord.view(1, natoms * 3), atype, cell.view(1, 9) + + def _assert_hessian_finite(self, model_params): + model = get_model(copy.deepcopy(model_params)).to(env.DEVICE) + model.enable_hessian() + model.requires_hessian("energy") + coord, atype, cell = self._build_inputs() + ret = model.forward_common(coord, atype, box=cell) + hessian = to_numpy_array(ret["energy_derv_r_derv_r"]) + self.assertTrue(np.isfinite(hessian).all()) + + def test_dpa2_direct_dist_hessian_is_finite(self): + model_params = copy.deepcopy(model_dpa2) + model_params["descriptor"]["repformer"]["direct_dist"] = True + self._assert_hessian_finite(model_params) + + def test_dpa3_hessian_is_finite(self): + model_params = copy.deepcopy(model_dpa3) + model_params["descriptor"]["precision"] = "float64" + model_params["fitting_net"]["precision"] = "float64" + self._assert_hessian_finite(model_params) + + +if __name__ == "__main__": + unittest.main()