Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions deepmd/pt/model/descriptor/repflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.safe_gradient import (
safe_for_norm,
)
from deepmd.pt.utils.spin import (
concat_switch_virtual,
)
Expand Down Expand Up @@ -473,9 +476,7 @@ def forward(
sw = sw.masked_fill(~nlist_mask, 0.0)

# get angle nlist (maybe smaller)
a_dist_mask = (torch.linalg.norm(diff, dim=-1) < self.a_rcut)[
:, :, : self.a_sel
]
a_dist_mask = (safe_for_norm(diff, dim=-1) < self.a_rcut)[:, :, : self.a_sel]
a_nlist = nlist[:, :, : self.a_sel]
a_nlist = torch.where(a_dist_mask, a_nlist, -1)
_, a_diff, a_sw = prod_env_mat(
Expand Down Expand Up @@ -512,11 +513,11 @@ def forward(
edge_input, h2 = torch.split(dmatrix, [1, 3], dim=-1)
if self.edge_init_use_dist:
# nb x nloc x nnei x 1
edge_input = torch.linalg.norm(diff, dim=-1, keepdim=True)
edge_input = safe_for_norm(diff, dim=-1, keepdim=True)

# nf x nloc x a_nnei x 3
normalized_diff_i = a_diff / (
torch.linalg.norm(a_diff, dim=-1, keepdim=True) + 1e-6
safe_for_norm(a_diff, dim=-1, keepdim=True) + 1e-6
Comment thread
njzjz marked this conversation as resolved.
)
# nf x nloc x 3 x a_nnei
normalized_diff_j = torch.transpose(normalized_diff_i, 2, 3)
Expand Down
5 changes: 4 additions & 1 deletion deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@
from deepmd.pt.utils.exclude_mask import (
PairExcludeMask,
)
from deepmd.pt.utils.safe_gradient import (
safe_for_norm,
)
from deepmd.pt.utils.spin import (
concat_switch_virtual,
)
Expand Down Expand Up @@ -446,7 +449,7 @@ def forward(
if not self.direct_dist:
g2, h2 = torch.split(dmatrix, [1, 3], dim=-1)
else:
g2, h2 = torch.linalg.norm(diff, dim=-1, keepdim=True), diff
g2, h2 = safe_for_norm(diff, dim=-1, keepdim=True), diff
g2 = g2 / self.rcut
Comment thread
njzjz marked this conversation as resolved.
h2 = h2 / self.rcut
# nb x nloc x nnei x ng2
Expand Down
39 changes: 39 additions & 0 deletions deepmd/pt/utils/safe_gradient.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Safe versions of some functions that have problematic gradients.

Check https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
for more information.
"""

import torch


def safe_for_sqrt(x: torch.Tensor) -> torch.Tensor:
    """Safe version of sqrt that has a gradient of 0 at x = 0.

    ``torch.sqrt`` has an infinite gradient at zero; feeding the masked-out
    entries a dummy value of 1.0 keeps the backward pass finite (the
    "double where" trick), after which those entries are zeroed again.
    """
    positive = x > 0.0
    # sqrt only ever sees strictly positive (or dummy 1.0) inputs.
    guarded = torch.where(positive, x, torch.ones_like(x))
    return torch.sqrt(guarded).masked_fill(~positive, 0.0)


def safe_for_norm(
x: torch.Tensor,
dim: int | None = None,
keepdim: bool = False,
ord: float = 2.0,
) -> torch.Tensor:
"""Safe version of torch.linalg.norm that has a gradient of 0 at x = 0.

This helper is currently used for vector-norm cases in PT descriptors.
"""
if dim is None:
mask = torch.sum(torch.square(x)) > 0
x_safe = torch.where(mask, x, torch.ones_like(x))
norm = torch.linalg.norm(x_safe, ord=ord)
return torch.where(mask, norm, torch.zeros_like(norm))

mask = torch.sum(torch.square(x), dim=(dim,), keepdim=True) > 0
mask_out = mask if keepdim else mask.squeeze(dim)

x_safe = torch.where(mask, x, torch.ones_like(x))
norm = torch.linalg.norm(x_safe, ord=ord, dim=dim, keepdim=keepdim)
return torch.where(mask_out, norm, torch.zeros_like(norm))
62 changes: 62 additions & 0 deletions source/tests/pt/model/test_dpa_hessian_finite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import unittest

import numpy as np
import torch

from deepmd.pt.model.model import (
get_model,
)
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.utils import (
to_numpy_array,
)

from ...seed import (
GLOBAL_SEED,
)
from .test_permutation import (
model_dpa2,
model_dpa3,
)

dtype = torch.float64


class TestDPAHessianFinite(unittest.TestCase):
    """Check that DPA model energy Hessians contain only finite entries."""

    def _build_inputs(self):
        """Build a small periodic 5-atom system: (coord, atype, cell)."""
        n_atoms = 5
        box = 4.0 * torch.eye(3, dtype=dtype, device=env.DEVICE)
        rng = torch.Generator(device=env.DEVICE).manual_seed(GLOBAL_SEED)
        positions = 3.0 * torch.rand(
            [1, n_atoms, 3], dtype=dtype, device=env.DEVICE, generator=rng
        )
        types = torch.tensor([[0, 0, 0, 1, 1]], dtype=torch.int64, device=env.DEVICE)
        return positions.view(1, n_atoms * 3), types, box.view(1, 9)

    def _assert_hessian_finite(self, model_params):
        """Run one forward pass with Hessian output enabled and check it."""
        model = get_model(copy.deepcopy(model_params)).to(env.DEVICE)
        model.enable_hessian()
        model.requires_hessian("energy")
        coord, atype, cell = self._build_inputs()
        result = model.forward_common(coord, atype, box=cell)
        hessian = to_numpy_array(result["energy_derv_r_derv_r"])
        self.assertTrue(np.isfinite(hessian).all())

    def test_dpa2_direct_dist_hessian_is_finite(self):
        params = copy.deepcopy(model_dpa2)
        params["descriptor"]["repformer"]["direct_dist"] = True
        self._assert_hessian_finite(params)

    def test_dpa3_hessian_is_finite(self):
        params = copy.deepcopy(model_dpa3)
        params["descriptor"]["precision"] = "float64"
        params["fitting_net"]["precision"] = "float64"
        self._assert_hessian_finite(params)


# Allow running this test module directly (outside the pytest runner).
if __name__ == "__main__":
    unittest.main()
Loading