Skip to content

Commit 5f73113

Browse files
OutisLiCopilot
andauthored
fix(pt): pairtab (#5119)
<!-- This is an auto-generated comment: release notes by coderabbit.ai --> ## Summary by CodeRabbit * **Bug Fixes** * Improved numerical stability of pairwise distance computations to prevent NaN gradients for zero differences and padded/masked entries; outputs remain unchanged while gradient behavior is now robust for those edge cases. <sub>✏️ Tip: You can customize this high-level summary in your review settings.</sub> <!-- end of auto-generated comment: release notes by coderabbit.ai --> --------- Signed-off-by: OutisLi <137472077+OutisLi@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent f736ab2 commit 5f73113

1 file changed

Lines changed: 16 additions & 1 deletion

File tree

deepmd/pt/model/atomic_model/pairtab_atomic_model.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,14 +392,29 @@ def _get_pairwise_dist(coords: torch.Tensor, nlist: torch.Tensor) -> torch.Tenso
392392
-------
393393
torch.Tensor
394394
The pairwise distance between the atoms (nframes, nloc, nnei).
395+
396+
Notes
397+
-----
398+
Safe gradient implementation: when diff is zero (padding entries),
399+
both distance and gradient are zero.
395400
"""
396401
nframes, nloc, nnei = nlist.shape
397402
coord_l = coords[:, :nloc].view(nframes, -1, 1, 3)
398403
index = nlist.view(nframes, -1).unsqueeze(-1).expand(-1, -1, 3)
399404
coord_r = torch.gather(coords, 1, index)
400405
coord_r = coord_r.view(nframes, nloc, nnei, 3)
401406
diff = coord_r - coord_l
402-
pairwise_rr = torch.linalg.norm(diff, dim=-1, keepdim=True).squeeze(-1)
407+
diff_sq = torch.sum(diff * diff, dim=-1, keepdim=True)
408+
409+
# When diff is zero, output is zero and gradient is also zero
410+
mask = diff_sq.squeeze(-1) > 0
411+
pairwise_rr = torch.where(
412+
mask.unsqueeze(-1),
413+
torch.sqrt(
414+
torch.where(mask.unsqueeze(-1), diff_sq, torch.ones_like(diff_sq))
415+
),
416+
torch.zeros_like(diff_sq),
417+
).squeeze(-1)
403418
return pairwise_rr
404419

405420
@staticmethod

0 commit comments

Comments
 (0)