From 7c1c1b2769e3f8ca1c54c5be2b9b727c52dd2b61 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 29 Mar 2026 01:16:10 +0800 Subject: [PATCH 1/4] fix(pt): fix NaN Hessian in DPA2 and DPA3 --- deepmd/pt/model/descriptor/repflows.py | 11 +++---- deepmd/pt/model/descriptor/repformers.py | 5 +++- deepmd/pt/utils/safe_gradient.py | 37 ++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 6 deletions(-) create mode 100644 deepmd/pt/utils/safe_gradient.py diff --git a/deepmd/pt/model/descriptor/repflows.py b/deepmd/pt/model/descriptor/repflows.py index 433897860f..a9d2326b93 100644 --- a/deepmd/pt/model/descriptor/repflows.py +++ b/deepmd/pt/model/descriptor/repflows.py @@ -35,6 +35,9 @@ from deepmd.pt.utils.exclude_mask import ( PairExcludeMask, ) +from deepmd.pt.utils.safe_gradient import ( + safe_for_norm, +) from deepmd.pt.utils.spin import ( concat_switch_virtual, ) @@ -473,9 +476,7 @@ def forward( sw = sw.masked_fill(~nlist_mask, 0.0) # get angle nlist (maybe smaller) - a_dist_mask = (torch.linalg.norm(diff, dim=-1) < self.a_rcut)[ - :, :, : self.a_sel - ] + a_dist_mask = (safe_for_norm(diff, dim=-1) < self.a_rcut)[:, :, : self.a_sel] a_nlist = nlist[:, :, : self.a_sel] a_nlist = torch.where(a_dist_mask, a_nlist, -1) _, a_diff, a_sw = prod_env_mat( @@ -512,11 +513,11 @@ def forward( edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1) if self.edge_init_use_dist: # nb x nloc x nnei x 1 - edge_input = torch.linalg.norm(diff, dim=-1, keepdim=True) + edge_input = safe_for_norm(diff, dim=-1, keepdim=True) # nf x nloc x a_nnei x 3 normalized_diff_i = a_diff / ( - torch.linalg.norm(a_diff, dim=-1, keepdim=True) + 1e-6 + safe_for_norm(a_diff, dim=-1, keepdim=True) + 1e-6 ) # nf x nloc x 3 x a_nnei normalized_diff_j = torch.transpose(normalized_diff_i, 2, 3) diff --git a/deepmd/pt/model/descriptor/repformers.py b/deepmd/pt/model/descriptor/repformers.py index 69f2cc4eaa..75e2f97576 100644 --- a/deepmd/pt/model/descriptor/repformers.py +++ b/deepmd/pt/model/descriptor/repformers.py @@ -32,6 +32,9 @@ from deepmd.pt.utils.exclude_mask import ( PairExcludeMask, ) +from deepmd.pt.utils.safe_gradient import ( + safe_for_norm, +) from deepmd.pt.utils.spin import ( concat_switch_virtual, ) @@ -446,7 +449,7 @@ def forward( if not self.direct_dist: g2, h2 = torch.split(dmatrix, [1, 3], dim=-1) else: - g2, h2 = torch.linalg.norm(diff, dim=-1, keepdim=True), diff + g2, h2 = safe_for_norm(diff, dim=-1, keepdim=True), diff g2 = g2 / self.rcut h2 = h2 / self.rcut # nb x nloc x nnei x ng2 diff --git a/deepmd/pt/utils/safe_gradient.py b/deepmd/pt/utils/safe_gradient.py new file mode 100644 index 0000000000..37310d1af4 --- /dev/null +++ b/deepmd/pt/utils/safe_gradient.py @@ -0,0 +1,37 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +"""Safe versions of some functions that have problematic gradients. + +Check https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where +for more information. +""" + +import torch + + +def safe_for_sqrt(x: torch.Tensor) -> torch.Tensor: + """Safe version of sqrt that has a gradient of 0 at x = 0.""" + mask = x > 0.0 + x_safe = torch.where(mask, x, torch.ones_like(x)) + return torch.where(mask, torch.sqrt(x_safe), torch.zeros_like(x)) + + +def safe_for_norm( + x: torch.Tensor, + dim: int | None = None, + keepdim: bool = False, + ord: float = 2.0, +) -> torch.Tensor: + """Safe version of vector_norm that has a gradient of 0 at x = 0.""" + if dim is None: + mask = torch.sum(torch.square(x)) > 0 + x_safe = torch.where(mask, x, torch.ones_like(x)) + norm = torch.linalg.norm(x_safe, ord=ord) + return torch.where(mask, norm, torch.zeros_like(norm)) + + dim_list = [dim] + mask = torch.sum(torch.square(x), dim=dim_list, keepdim=True) > 0 + mask_out = mask if keepdim else (torch.sum(torch.square(x), dim=dim_list) > 0) + + x_safe = torch.where(mask, x, torch.ones_like(x)) + norm = torch.linalg.norm(x_safe, ord=ord, dim=dim_list, keepdim=keepdim) + return torch.where(mask_out, norm, torch.zeros_like(norm)) From 5f0ae435fb3b9e0d2682cc8e38096cda054075be Mon Sep 17 00:00:00 2001 From: "njzjz-bot (driven by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4))[bot]" <48687836+njzjz-bot@users.noreply.github.com> Date: Mon, 30 Mar 2026 22:55:49 +0000 Subject: [PATCH 2/4] fix(pt): address hessian review comments Use vector_norm semantics in safe_for_norm and add focused regression tests to verify DPA2/DPA3 Hessians stay finite. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4) --- deepmd/pt/utils/safe_gradient.py | 9 ++- .../tests/pt/model/test_dpa_hessian_finite.py | 64 +++++++++++++++++++ 2 files changed, 68 insertions(+), 5 deletions(-) create mode 100644 source/tests/pt/model/test_dpa_hessian_finite.py diff --git a/deepmd/pt/utils/safe_gradient.py b/deepmd/pt/utils/safe_gradient.py index 37310d1af4..f8deb63b4f 100644 --- a/deepmd/pt/utils/safe_gradient.py +++ b/deepmd/pt/utils/safe_gradient.py @@ -25,13 +25,12 @@ def safe_for_norm( if dim is None: mask = torch.sum(torch.square(x)) > 0 x_safe = torch.where(mask, x, torch.ones_like(x)) - norm = torch.linalg.norm(x_safe, ord=ord) + norm = torch.linalg.vector_norm(x_safe, ord=ord) return torch.where(mask, norm, torch.zeros_like(norm)) - dim_list = [dim] - mask = torch.sum(torch.square(x), dim=dim_list, keepdim=True) > 0 - mask_out = mask if keepdim else (torch.sum(torch.square(x), dim=dim_list) > 0) + mask = torch.sum(torch.square(x), dim=(dim,), keepdim=True) > 0 + mask_out = mask if keepdim else mask.squeeze(dim) x_safe = torch.where(mask, x, torch.ones_like(x)) - norm = torch.linalg.norm(x_safe, ord=ord, dim=dim_list, keepdim=keepdim) + norm = torch.linalg.vector_norm(x_safe, ord=ord, dim=dim, keepdim=keepdim) return torch.where(mask_out, norm, torch.zeros_like(norm)) diff --git a/source/tests/pt/model/test_dpa_hessian_finite.py b/source/tests/pt/model/test_dpa_hessian_finite.py new file mode 100644 index 0000000000..d5a0be30e8 --- /dev/null +++ b/source/tests/pt/model/test_dpa_hessian_finite.py @@ -0,0 +1,64 @@ +# SPDX-License-Identifier: LGPL-3.0-or-later +import copy +import unittest + +import numpy as np +import torch + +from deepmd.pt.model.model import ( + get_model, +) +from deepmd.pt.utils import ( + env, +) +from deepmd.pt.utils.utils import ( + to_numpy_array, +) + +from ...seed import ( + GLOBAL_SEED, +) +from .test_permutation import ( + model_dpa2, + model_dpa3, +) + +dtype = torch.float64 + + +class TestDPAHessianFinite(unittest.TestCase): + def _build_inputs(self): + natoms = 5 + cell = 4.0 * torch.eye(3, dtype=dtype, device=env.DEVICE) + generator = torch.Generator(device=env.DEVICE).manual_seed(GLOBAL_SEED) + coord = 3.0 * torch.rand( + [1, natoms, 3], dtype=dtype, device=env.DEVICE, generator=generator + ) + atype = torch.tensor([[0, 0, 0, 1, 1]], dtype=torch.int64, device=env.DEVICE) + return coord.view(1, natoms * 3), atype, cell.view(1, 9) + + def _assert_hessian_finite(self, model_params): + model = get_model(copy.deepcopy(model_params)).to(env.DEVICE) + model.enable_hessian() + model.requires_hessian("energy") + coord, atype, cell = self._build_inputs() + ret = model.forward_common(coord, atype, box=cell) + hessian = to_numpy_array(ret["energy_derv_r_derv_r"]) + self.assertTrue(np.isfinite(hessian).all()) + + def test_dpa2_direct_dist_hessian_is_finite(self): + model_params = copy.deepcopy(model_dpa2) + model_params["descriptor"]["repformer"]["direct_dist"] = True + model_params["hessian_mode"] = True + self._assert_hessian_finite(model_params) + + def test_dpa3_hessian_is_finite(self): + model_params = copy.deepcopy(model_dpa3) + model_params["descriptor"]["precision"] = "float64" + model_params["fitting_net"]["precision"] = "float64" + model_params["hessian_mode"] = True + self._assert_hessian_finite(model_params) + + +if __name__ == "__main__": + unittest.main() From 32620db96716f1265495c63fa5303b4877ea8946 Mon Sep 17 00:00:00 2001 From: "njzjz-bot (driven by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4))[bot]" <48687836+njzjz-bot@users.noreply.github.com> Date: Mon, 30 Mar 2026 23:47:41 +0000 Subject: [PATCH 3/4] fix(pt): keep safe_for_norm aligned with torch.linalg.norm Keep the helper aligned with torch.linalg.norm semantics while retaining the zero-gradient masking cleanup. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4) --- deepmd/pt/utils/safe_gradient.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/deepmd/pt/utils/safe_gradient.py b/deepmd/pt/utils/safe_gradient.py index f8deb63b4f..e0bb76a8cf 100644 --- a/deepmd/pt/utils/safe_gradient.py +++ b/deepmd/pt/utils/safe_gradient.py @@ -21,16 +21,19 @@ def safe_for_norm( keepdim: bool = False, ord: float = 2.0, ) -> torch.Tensor: - """Safe version of vector_norm that has a gradient of 0 at x = 0.""" + """Safe version of torch.linalg.norm that has a gradient of 0 at x = 0. + + This helper is currently used for vector-norm cases in PT descriptors. + """ if dim is None: mask = torch.sum(torch.square(x)) > 0 x_safe = torch.where(mask, x, torch.ones_like(x)) - norm = torch.linalg.vector_norm(x_safe, ord=ord) + norm = torch.linalg.norm(x_safe, ord=ord) return torch.where(mask, norm, torch.zeros_like(norm)) mask = torch.sum(torch.square(x), dim=(dim,), keepdim=True) > 0 mask_out = mask if keepdim else mask.squeeze(dim) x_safe = torch.where(mask, x, torch.ones_like(x)) - norm = torch.linalg.vector_norm(x_safe, ord=ord, dim=dim, keepdim=keepdim) + norm = torch.linalg.norm(x_safe, ord=ord, dim=dim, keepdim=keepdim) return torch.where(mask_out, norm, torch.zeros_like(norm)) From 8e7b03828c77c7ed13c8241993b5a2dcce1c0300 Mon Sep 17 00:00:00 2001 From: "njzjz-bot (driven by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4))[bot]" <48687836+njzjz-bot@users.noreply.github.com> Date: Tue, 31 Mar 2026 01:30:11 +0000 Subject: [PATCH 4/4] test(pt): avoid double-enabling hessian in regression test The model returned by get_model should remain in normal mode here; calling enable_hessian once inside the test helper is enough. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4) --- source/tests/pt/model/test_dpa_hessian_finite.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/source/tests/pt/model/test_dpa_hessian_finite.py b/source/tests/pt/model/test_dpa_hessian_finite.py index d5a0be30e8..286a565710 100644 --- a/source/tests/pt/model/test_dpa_hessian_finite.py +++ b/source/tests/pt/model/test_dpa_hessian_finite.py @@ -49,14 +49,12 @@ def _assert_hessian_finite(self, model_params): def test_dpa2_direct_dist_hessian_is_finite(self): model_params = copy.deepcopy(model_dpa2) model_params["descriptor"]["repformer"]["direct_dist"] = True - model_params["hessian_mode"] = True self._assert_hessian_finite(model_params) def test_dpa3_hessian_is_finite(self): model_params = copy.deepcopy(model_dpa3) model_params["descriptor"]["precision"] = "float64" model_params["fitting_net"]["precision"] = "float64" - model_params["hessian_mode"] = True self._assert_hessian_finite(model_params)