Skip to content

fix(pt): fix NaN Hessian in DPA2 and DPA3#5351

Merged
iProzd merged 6 commits intodeepmodeling:masterfrom
njzjz:fix-hessian-nan-dpa3
Mar 31, 2026
Merged

fix(pt): fix NaN Hessian in DPA2 and DPA3#5351
iProzd merged 6 commits intodeepmodeling:masterfrom
njzjz:fix-hessian-nan-dpa3

Conversation

@njzjz
Copy link
Copy Markdown
Member

@njzjz njzjz commented Mar 28, 2026

This fix is similar to #4668 and #4809.

Summary by CodeRabbit

  • New Features

    • Added safe numerical helpers that return zero outputs and zero gradients for problematic zero-valued inputs.
  • Bug Fixes

    • Replaced unstable norm/sqrt usages in descriptor computations to improve stability at zeros and boundaries, reducing spurious gradients and NaNs.
  • Tests

    • Added unit tests that verify Hessians remain finite across representative model configurations.

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai bot commented Mar 28, 2026

📝 Walkthrough

Walkthrough

Adds a safe-norm utility and replaces direct torch.linalg.norm calls in descriptor code with safe_for_norm to avoid zero-value gradient issues; adds tests that assert Hessian finiteness for affected models.

Changes

Cohort / File(s) Summary
Safe gradient utilities
deepmd/pt/utils/safe_gradient.py
New module adding safe_for_sqrt and safe_for_norm to mask zero-valued reductions so outputs and gradients are zero where the squared-sum is zero.
Repflows & Repformers descriptors
deepmd/pt/model/descriptor/repflows.py, deepmd/pt/model/descriptor/repformers.py
Replaced torch.linalg.norm(...) usages with safe_for_norm(...) for angle cutoff masking, edge distance initialization, angular-difference normalization, and direct-distance computation while preserving dim/keepdim semantics.
Tests
source/tests/pt/model/test_dpa_hessian_finite.py
New unit test class TestDPAHessianFinite that builds deterministic inputs, enables Hessian computation, and asserts energy Hessians are finite for configured models.

Sequence Diagram(s)

(omitted — changes are not a new multi-component control flow)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

Possibly related PRs

Suggested reviewers

  • wanghan-iapcm
  • iProzd
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 28.57% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title accurately describes the main purpose of the PR - fixing NaN Hessian computation issues in DPA2 and DPA3 models.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (2)
deepmd/pt/utils/safe_gradient.py (2)

31-33: Minor efficiency note: mask_out recomputes the squared sum.

When keepdim=False, line 33 recomputes torch.sum(torch.square(x), dim=dim_list) instead of squeezing the already-computed mask. This is a minor inefficiency but unlikely to matter in practice.

♻️ Optional: Reuse computed mask
     dim_list = [dim]
     mask = torch.sum(torch.square(x), dim=dim_list, keepdim=True) > 0
-    mask_out = mask if keepdim else (torch.sum(torch.square(x), dim=dim_list) > 0)
+    mask_out = mask if keepdim else mask.squeeze(dim)

     x_safe = torch.where(mask, x, torch.ones_like(x))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/utils/safe_gradient.py` around lines 31 - 33, The code recomputes
torch.sum(torch.square(x), dim=dim_list) when building mask_out; instead reuse
the already-computed mask by squeezing it when keepdim is False. Update the
mask_out assignment to use mask (and .squeeze or equivalent) instead of
recomputing the squared-sum so mask_out = mask if keepdim else
mask.squeeze(dim_list) (adjust squeeze dims to match dim_list) referring to the
existing variables mask, mask_out, dim_list, x, and keepdim.

18-37: Consider renaming ord parameter to avoid shadowing the Python builtin.

The static analysis tool flags that ord shadows the Python builtin. While this pattern is common in numerical code (mirroring torch.linalg.norm's API), renaming to p or order would address the linter warning.

♻️ Optional: Rename parameter to avoid shadowing
 def safe_for_norm(
     x: torch.Tensor,
     dim: int | None = None,
     keepdim: bool = False,
-    ord: float = 2.0,
+    order: float = 2.0,
 ) -> torch.Tensor:
     """Safe version of vector_norm that has a gradient of 0 at x = 0."""
     if dim is None:
         mask = torch.sum(torch.square(x)) > 0
         x_safe = torch.where(mask, x, torch.ones_like(x))
-        norm = torch.linalg.norm(x_safe, ord=ord)
+        norm = torch.linalg.norm(x_safe, ord=order)
         return torch.where(mask, norm, torch.zeros_like(norm))

     dim_list = [dim]
     mask = torch.sum(torch.square(x), dim=dim_list, keepdim=True) > 0
     mask_out = mask if keepdim else (torch.sum(torch.square(x), dim=dim_list) > 0)

     x_safe = torch.where(mask, x, torch.ones_like(x))
-    norm = torch.linalg.norm(x_safe, ord=ord, dim=dim_list, keepdim=keepdim)
+    norm = torch.linalg.norm(x_safe, ord=order, dim=dim_list, keepdim=keepdim)
     return torch.where(mask_out, norm, torch.zeros_like(norm))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/utils/safe_gradient.py` around lines 18 - 37, The function
safe_for_norm shadows the built-in name "ord"; rename the parameter (e.g., to
"order" or "p") in the safe_for_norm signature and update all internal uses (the
torch.linalg.norm calls and any references inside the function) to the new name,
and ensure any call sites within the repo that call safe_for_norm are updated to
the new parameter name if they use keyword arguments; keep behavior identical
aside from the parameter name change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@deepmd/pt/utils/safe_gradient.py`:
- Around line 31-33: The code recomputes torch.sum(torch.square(x),
dim=dim_list) when building mask_out; instead reuse the already-computed mask by
squeezing it when keepdim is False. Update the mask_out assignment to use mask
(and .squeeze or equivalent) instead of recomputing the squared-sum so mask_out
= mask if keepdim else mask.squeeze(dim_list) (adjust squeeze dims to match
dim_list) referring to the existing variables mask, mask_out, dim_list, x, and
keepdim.
- Around line 18-37: The function safe_for_norm shadows the built-in name "ord";
rename the parameter (e.g., to "order" or "p") in the safe_for_norm signature
and update all internal uses (the torch.linalg.norm calls and any references
inside the function) to the new name, and ensure any call sites within the repo
that call safe_for_norm are updated to the new parameter name if they use
keyword arguments; keep behavior identical aside from the parameter name change.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 83461c4f-b5ca-46b8-8cd5-ef95fa85df08

📥 Commits

Reviewing files that changed from the base of the PR and between 2a82988 and 7c1c1b2.

📒 Files selected for processing (3)
  • deepmd/pt/model/descriptor/repflows.py
  • deepmd/pt/model/descriptor/repformers.py
  • deepmd/pt/utils/safe_gradient.py

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Fixes NaN issues when computing Hessians for PyTorch DPA2/DPA3 descriptors by introducing “safe gradient” norm/sqrt helpers and using them in Repformer/Repflow distance computations (similar to prior JAX-side fixes in #4668/#4809).

Changes:

  • Add deepmd/pt/utils/safe_gradient.py providing safe_for_sqrt and safe_for_norm.
  • Replace torch.linalg.norm(...) with safe_for_norm(...) in DPA2 Repformer and DPA3 Repflow descriptor paths where zero-distance edge cases can occur.
  • Use the safe norm in Repflow’s angular neighbor mask and direction normalization.

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.

File Description
deepmd/pt/utils/safe_gradient.py Introduces safe-gradient helpers intended to avoid NaNs in higher-order derivatives at zero.
deepmd/pt/model/descriptor/repformers.py Uses safe norm for the direct_dist distance feature path in Repformer.
deepmd/pt/model/descriptor/repflows.py Uses safe norm for angle neighbor selection and vector normalization in Repflow.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread deepmd/pt/utils/safe_gradient.py Outdated
Comment thread deepmd/pt/model/descriptor/repformers.py
Comment thread deepmd/pt/model/descriptor/repflows.py
@codecov
Copy link
Copy Markdown

codecov bot commented Mar 28, 2026

Codecov Report

❌ Patch coverage is 85.71429% with 3 lines in your changes missing coverage. Please review.
✅ Project coverage is 82.32%. Comparing base (2a82988) to head (f9e6a5b).
⚠️ Report is 8 commits behind head on master.

Files with missing lines Patch % Lines
deepmd/pt/utils/safe_gradient.py 81.25% 3 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #5351      +/-   ##
==========================================
+ Coverage   82.26%   82.32%   +0.05%     
==========================================
  Files         799      811      +12     
  Lines       82563    83236     +673     
  Branches     4066     4066              
==========================================
+ Hits        67924    68524     +600     
- Misses      13424    13497      +73     
  Partials     1215     1215              

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@njzjz
Copy link
Copy Markdown
Member Author

njzjz commented Mar 29, 2026

@copilot apply changes based on the comments in this thread

@iProzd iProzd enabled auto-merge March 30, 2026 03:00
Use vector_norm semantics in safe_for_norm and add focused
regression tests to verify DPA2/DPA3 Hessians stay finite.

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4)
njzjz-bot and others added 2 commits March 30, 2026 23:47
Keep the helper aligned with torch.linalg.norm semantics while
retaining the zero-gradient masking cleanup.

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4)
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (2)
deepmd/pt/utils/safe_gradient.py (1)

18-23: Consider renaming ord to avoid shadowing Python builtin.

Ruff flags this as A002. While ord mirrors torch.linalg.norm's signature, renaming to order or norm_ord would satisfy the linter and avoid potential confusion.

Suggested fix
 def safe_for_norm(
     x: torch.Tensor,
     dim: int | None = None,
     keepdim: bool = False,
-    ord: float = 2.0,
+    norm_ord: float = 2.0,
 ) -> torch.Tensor:
     """Safe version of torch.linalg.norm that has a gradient of 0 at x = 0.

     This helper is currently used for vector-norm cases in PT descriptors.
     """
     if dim is None:
         mask = torch.sum(torch.square(x)) > 0
         x_safe = torch.where(mask, x, torch.ones_like(x))
-        norm = torch.linalg.norm(x_safe, ord=ord)
+        norm = torch.linalg.norm(x_safe, ord=norm_ord)
         return torch.where(mask, norm, torch.zeros_like(norm))

     mask = torch.sum(torch.square(x), dim=(dim,), keepdim=True) > 0
     mask_out = mask if keepdim else mask.squeeze(dim)

     x_safe = torch.where(mask, x, torch.ones_like(x))
-    norm = torch.linalg.norm(x_safe, ord=ord, dim=dim, keepdim=keepdim)
+    norm = torch.linalg.norm(x_safe, ord=norm_ord, dim=dim, keepdim=keepdim)
     return torch.where(mask_out, norm, torch.zeros_like(norm))

Alternatively, if this is intentional to match PyTorch's API, add # noqa: A002 to suppress.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@deepmd/pt/utils/safe_gradient.py` around lines 18 - 23, The parameter name
ord in the safe_for_norm function shadows a Python builtin and triggers linter
A002; rename it to a non-shadowing identifier (e.g., order or norm_ord) across
the function signature and all internal references in safe_for_norm to fix the
linter, or if you intentionally matched torch.linalg.norm, add a "# noqa: A002"
comment next to the parameter definition to suppress the warning; ensure all
callers or tests are updated to use the new parameter name (or accept the noqa
approach) to keep behavior consistent.
source/tests/pt/model/test_dpa_hessian_finite.py (1)

40-47: Redundant hessian enablement calls.

When hessian_mode=True is set in model_params, get_model() already calls enable_hessian() (see get_standard_model() in model/__init__.py:286), which in turn calls requires_hessian("energy") (see ener_model.py:44). Lines 42-43 are therefore redundant for these tests.

This is harmless (likely idempotent), but could be cleaned up for clarity.

Suggested simplification
     def _assert_hessian_finite(self, model_params):
         model = get_model(copy.deepcopy(model_params)).to(env.DEVICE)
-        model.enable_hessian()
-        model.requires_hessian("energy")
         coord, atype, cell = self._build_inputs()
         ret = model.forward_common(coord, atype, box=cell)
         hessian = to_numpy_array(ret["energy_derv_r_derv_r"])
         self.assertTrue(np.isfinite(hessian).all())
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt/model/test_dpa_hessian_finite.py` around lines 40 - 47, The
test _assert_hessian_finite contains redundant calls to model.enable_hessian()
and model.requires_hessian("energy") because get_model(...) already enables
hessian when model_params contains hessian_mode=True; remove those two calls
from _assert_hessian_finite so the test relies on get_model's setup (references:
_assert_hessian_finite, get_model, enable_hessian, requires_hessian,
hessian_mode).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@deepmd/pt/utils/safe_gradient.py`:
- Around line 18-23: The parameter name ord in the safe_for_norm function
shadows a Python builtin and triggers linter A002; rename it to a non-shadowing
identifier (e.g., order or norm_ord) across the function signature and all
internal references in safe_for_norm to fix the linter, or if you intentionally
matched torch.linalg.norm, add a "# noqa: A002" comment next to the parameter
definition to suppress the warning; ensure all callers or tests are updated to
use the new parameter name (or accept the noqa approach) to keep behavior
consistent.

In `@source/tests/pt/model/test_dpa_hessian_finite.py`:
- Around line 40-47: The test _assert_hessian_finite contains redundant calls to
model.enable_hessian() and model.requires_hessian("energy") because
get_model(...) already enables hessian when model_params contains
hessian_mode=True; remove those two calls from _assert_hessian_finite so the
test relies on get_model's setup (references: _assert_hessian_finite, get_model,
enable_hessian, requires_hessian, hessian_mode).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 127b78bd-ffb1-4712-a7e2-bff07e9d6728

📥 Commits

Reviewing files that changed from the base of the PR and between 7c1c1b2 and 9b215f7.

📒 Files selected for processing (2)
  • deepmd/pt/utils/safe_gradient.py
  • source/tests/pt/model/test_dpa_hessian_finite.py

njzjz-bot and others added 2 commits March 31, 2026 01:43
The model returned by get_model should remain in normal mode here;
calling enable_hessian once inside the test helper is enough.

Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4)
Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (1)
source/tests/pt/model/test_dpa_hessian_finite.py (1)

40-43: Remove duplicate requires_hessian() call and set eval mode.

Line 42 already enables "energy" Hessian via enable_hessian(), which internally calls requires_hessian("energy"), making line 43 redundant. Additionally, setting model.eval() avoids unnecessary graph construction in downstream Hessian code paths since this helper only reads the computed Hessian without backpropagation.

♻️ Proposed cleanup
 def _assert_hessian_finite(self, model_params):
     model = get_model(copy.deepcopy(model_params)).to(env.DEVICE)
     model.enable_hessian()
-    model.requires_hessian("energy")
+    model.eval()
     coord, atype, cell = self._build_inputs()
     ret = model.forward_common(coord, atype, box=cell)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@source/tests/pt/model/test_dpa_hessian_finite.py` around lines 40 - 43,
Remove the redundant requires_hessian("energy") call and set the model to eval
mode in the helper: inside _assert_hessian_finite, after creating the model with
get_model(...), call model.enable_hessian() (which already marks the "energy"
Hessian) and then call model.eval() to avoid unnecessary graph construction
before running Hessian-related reads; remove the explicit
model.requires_hessian("energy") line.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@source/tests/pt/model/test_dpa_hessian_finite.py`:
- Around line 40-43: Remove the redundant requires_hessian("energy") call and
set the model to eval mode in the helper: inside _assert_hessian_finite, after
creating the model with get_model(...), call model.enable_hessian() (which
already marks the "energy" Hessian) and then call model.eval() to avoid
unnecessary graph construction before running Hessian-related reads; remove the
explicit model.requires_hessian("energy") line.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Repository UI

Review profile: CHILL

Plan: Pro

Run ID: 11243f60-6f82-4135-9654-7438eebc9df3

📥 Commits

Reviewing files that changed from the base of the PR and between 9b215f7 and f9e6a5b.

📒 Files selected for processing (1)
  • source/tests/pt/model/test_dpa_hessian_finite.py

@iProzd iProzd added this pull request to the merge queue Mar 31, 2026
Merged via the queue into deepmodeling:master with commit afb97f6 Mar 31, 2026
70 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants