fix(pt): fix NaN Hessian in DPA2 and DPA3 (#5351)
Conversation
📝 Walkthrough: Adds a safe-norm utility and replaces direct norm calls in the PT descriptor paths.
Sequence Diagram(s): omitted (changes are not a new multi-component control flow)
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Possibly related PRs
Suggested reviewers
🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
🧹 Nitpick comments (2)
deepmd/pt/utils/safe_gradient.py (2)
31-33: Minor efficiency note: `mask_out` recomputes the squared sum.

When `keepdim=False`, line 33 recomputes `torch.sum(torch.square(x), dim=dim_list)` instead of squeezing the already-computed `mask`. This is a minor inefficiency but unlikely to matter in practice.

♻️ Optional: Reuse computed mask
```diff
 dim_list = [dim]
 mask = torch.sum(torch.square(x), dim=dim_list, keepdim=True) > 0
-mask_out = mask if keepdim else (torch.sum(torch.square(x), dim=dim_list) > 0)
+mask_out = mask if keepdim else mask.squeeze(dim)
 x_safe = torch.where(mask, x, torch.ones_like(x))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/utils/safe_gradient.py` around lines 31 - 33, The code recomputes torch.sum(torch.square(x), dim=dim_list) when building mask_out; instead reuse the already-computed mask by squeezing it when keepdim is False. Update the mask_out assignment to use mask (and .squeeze or equivalent) instead of recomputing the squared-sum so mask_out = mask if keepdim else mask.squeeze(dim_list) (adjust squeeze dims to match dim_list) referring to the existing variables mask, mask_out, dim_list, x, and keepdim.
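The equivalence this refactor relies on (squeezing the `keepdim=True` mask versus recomputing the reduction with `keepdim=False`) can be sanity-checked with a tiny sketch; the tensor values below are arbitrary illustrations, not repository code:

```python
import torch

# Hypothetical check, not repository code: squeezing the keepdim=True mask
# yields the same boolean tensor as recomputing the reduction with
# keepdim=False, so the suggested refactor is behavior-preserving.
x = torch.tensor([[0.0, 0.0, 0.0], [3.0, 4.0, 0.0]])
dim_list = [1]
mask = torch.sum(torch.square(x), dim=dim_list, keepdim=True) > 0  # shape (2, 1)
recomputed = torch.sum(torch.square(x), dim=dim_list) > 0          # shape (2,)
squeezed = mask.squeeze(1)                                         # shape (2,)
print(torch.equal(squeezed, recomputed))
```

The first row is all-zero and the second is not, so both tensors come out as `[False, True]`.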
18-37: Consider renaming the `ord` parameter to avoid shadowing the Python builtin.

The static analysis tool flags that `ord` shadows the Python builtin. While this pattern is common in numerical code (mirroring `torch.linalg.norm`'s API), renaming to `p` or `order` would address the linter warning.

♻️ Optional: Rename parameter to avoid shadowing
```diff
 def safe_for_norm(
     x: torch.Tensor,
     dim: int | None = None,
     keepdim: bool = False,
-    ord: float = 2.0,
+    order: float = 2.0,
 ) -> torch.Tensor:
     """Safe version of vector_norm that has a gradient of 0 at x = 0."""
     if dim is None:
         mask = torch.sum(torch.square(x)) > 0
         x_safe = torch.where(mask, x, torch.ones_like(x))
-        norm = torch.linalg.norm(x_safe, ord=ord)
+        norm = torch.linalg.norm(x_safe, ord=order)
         return torch.where(mask, norm, torch.zeros_like(norm))
     dim_list = [dim]
     mask = torch.sum(torch.square(x), dim=dim_list, keepdim=True) > 0
     mask_out = mask if keepdim else (torch.sum(torch.square(x), dim=dim_list) > 0)
     x_safe = torch.where(mask, x, torch.ones_like(x))
-    norm = torch.linalg.norm(x_safe, ord=ord, dim=dim_list, keepdim=keepdim)
+    norm = torch.linalg.norm(x_safe, ord=order, dim=dim_list, keepdim=keepdim)
     return torch.where(mask_out, norm, torch.zeros_like(norm))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/utils/safe_gradient.py` around lines 18 - 37, The function safe_for_norm shadows the built-in name "ord"; rename the parameter (e.g., to "order" or "p") in the safe_for_norm signature and update all internal uses (the torch.linalg.norm calls and any references inside the function) to the new name, and ensure any call sites within the repo that call safe_for_norm are updated to the new parameter name if they use keyword arguments; keep behavior identical aside from the parameter name change.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@deepmd/pt/utils/safe_gradient.py`:
- Around line 31-33: The code recomputes torch.sum(torch.square(x),
dim=dim_list) when building mask_out; instead reuse the already-computed mask by
squeezing it when keepdim is False. Update the mask_out assignment to use mask
(and .squeeze or equivalent) instead of recomputing the squared-sum so mask_out
= mask if keepdim else mask.squeeze(dim_list) (adjust squeeze dims to match
dim_list) referring to the existing variables mask, mask_out, dim_list, x, and
keepdim.
- Around line 18-37: The function safe_for_norm shadows the built-in name "ord";
rename the parameter (e.g., to "order" or "p") in the safe_for_norm signature
and update all internal uses (the torch.linalg.norm calls and any references
inside the function) to the new name, and ensure any call sites within the repo
that call safe_for_norm are updated to the new parameter name if they use
keyword arguments; keep behavior identical aside from the parameter name change.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 83461c4f-b5ca-46b8-8cd5-ef95fa85df08
📒 Files selected for processing (3)
- deepmd/pt/model/descriptor/repflows.py
- deepmd/pt/model/descriptor/repformers.py
- deepmd/pt/utils/safe_gradient.py
Pull request overview
Fixes NaN issues when computing Hessians for PyTorch DPA2/DPA3 descriptors by introducing “safe gradient” norm/sqrt helpers and using them in Repformer/Repflow distance computations (similar to prior JAX-side fixes in #4668/#4809).
Changes:
- Add `deepmd/pt/utils/safe_gradient.py` providing `safe_for_sqrt` and `safe_for_norm`.
- Replace `torch.linalg.norm(...)` with `safe_for_norm(...)` in DPA2 Repformer and DPA3 Repflow descriptor paths where zero-distance edge cases can occur.
- Use the safe norm in Repflow's angular neighbor mask and direction normalization.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| deepmd/pt/utils/safe_gradient.py | Introduces safe-gradient helpers intended to avoid NaNs in higher-order derivatives at zero. |
| deepmd/pt/model/descriptor/repformers.py | Uses safe norm for the direct_dist distance feature path in Repformer. |
| deepmd/pt/model/descriptor/repflows.py | Uses safe norm for angle neighbor selection and vector normalization in Repflow. |
Codecov Report
❌ Patch coverage is
Additional details and impacted files

```
@@            Coverage Diff             @@
##           master    #5351      +/-   ##
==========================================
+ Coverage   82.26%   82.32%   +0.05%
==========================================
  Files         799      811      +12
  Lines       82563    83236     +673
  Branches     4066     4066
==========================================
+ Hits        67924    68524     +600
- Misses      13424    13497      +73
  Partials     1215     1215
```

☔ View full report in Codecov by Sentry.
@copilot apply changes based on the comments in this thread
Use vector_norm semantics in safe_for_norm and add focused regression tests to verify DPA2/DPA3 Hessians stay finite. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4)
Keep the helper aligned with torch.linalg.norm semantics while retaining the zero-gradient masking cleanup. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4)
🧹 Nitpick comments (2)
deepmd/pt/utils/safe_gradient.py (1)
18-23: Consider renaming `ord` to avoid shadowing the Python builtin.

Ruff flags this as A002. While `ord` mirrors `torch.linalg.norm`'s signature, renaming to `order` or `norm_ord` would satisfy the linter and avoid potential confusion.

Suggested fix
```diff
 def safe_for_norm(
     x: torch.Tensor,
     dim: int | None = None,
     keepdim: bool = False,
-    ord: float = 2.0,
+    norm_ord: float = 2.0,
 ) -> torch.Tensor:
     """Safe version of torch.linalg.norm that has a gradient of 0 at x = 0.

     This helper is currently used for vector-norm cases in PT descriptors.
     """
     if dim is None:
         mask = torch.sum(torch.square(x)) > 0
         x_safe = torch.where(mask, x, torch.ones_like(x))
-        norm = torch.linalg.norm(x_safe, ord=ord)
+        norm = torch.linalg.norm(x_safe, ord=norm_ord)
         return torch.where(mask, norm, torch.zeros_like(norm))
     mask = torch.sum(torch.square(x), dim=(dim,), keepdim=True) > 0
     mask_out = mask if keepdim else mask.squeeze(dim)
     x_safe = torch.where(mask, x, torch.ones_like(x))
-    norm = torch.linalg.norm(x_safe, ord=ord, dim=dim, keepdim=keepdim)
+    norm = torch.linalg.norm(x_safe, ord=norm_ord, dim=dim, keepdim=keepdim)
     return torch.where(mask_out, norm, torch.zeros_like(norm))
```

Alternatively, if this is intentional to match PyTorch's API, add `# noqa: A002` to suppress.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt/utils/safe_gradient.py` around lines 18 - 23, The parameter name ord in the safe_for_norm function shadows a Python builtin and triggers linter A002; rename it to a non-shadowing identifier (e.g., order or norm_ord) across the function signature and all internal references in safe_for_norm to fix the linter, or if you intentionally matched torch.linalg.norm, add a "# noqa: A002" comment next to the parameter definition to suppress the warning; ensure all callers or tests are updated to use the new parameter name (or accept the noqa approach) to keep behavior consistent.

source/tests/pt/model/test_dpa_hessian_finite.py (1)
40-47: Redundant hessian enablement calls.

When `hessian_mode=True` is set in `model_params`, `get_model()` already calls `enable_hessian()` (see `get_standard_model()` in `model/__init__.py:286`), which in turn calls `requires_hessian("energy")` (see `ener_model.py:44`). Lines 42-43 are therefore redundant for these tests.

This is harmless (likely idempotent), but could be cleaned up for clarity.
Suggested simplification
```diff
 def _assert_hessian_finite(self, model_params):
     model = get_model(copy.deepcopy(model_params)).to(env.DEVICE)
-    model.enable_hessian()
-    model.requires_hessian("energy")
     coord, atype, cell = self._build_inputs()
     ret = model.forward_common(coord, atype, box=cell)
     hessian = to_numpy_array(ret["energy_derv_r_derv_r"])
     self.assertTrue(np.isfinite(hessian).all())
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt/model/test_dpa_hessian_finite.py` around lines 40 - 47, The test _assert_hessian_finite contains redundant calls to model.enable_hessian() and model.requires_hessian("energy") because get_model(...) already enables hessian when model_params contains hessian_mode=True; remove those two calls from _assert_hessian_finite so the test relies on get_model's setup (references: _assert_hessian_finite, get_model, enable_hessian, requires_hessian, hessian_mode).
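The failure mode these regression tests guard against can be reproduced outside DeePMD with a toy pairwise-distance energy; everything below (`safe_norm`, `toy_energy`) is an illustrative sketch of the same masking idea, not repository code:

```python
import torch

def safe_norm(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Illustrative version of the PR's masking trick: exactly-zero vectors
    # get a norm of 0 with no gradient path, so higher derivatives stay finite.
    mask = torch.sum(torch.square(x), dim=dim, keepdim=True) > 0
    x_safe = torch.where(mask, x, torch.ones_like(x))
    norm = torch.linalg.norm(x_safe, dim=dim)
    return torch.where(mask.squeeze(dim), norm, torch.zeros_like(norm))

def toy_energy(coord: torch.Tensor) -> torch.Tensor:
    # Sum of all pairwise "distances"; the i == j differences are exactly
    # zero, which is the configuration that produces NaN Hessians with a
    # plain norm.
    diff = coord.unsqueeze(0) - coord.unsqueeze(1)  # (n, n, 3), zero diagonal
    return safe_norm(diff, dim=-1).sum()

coord = torch.tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], dtype=torch.float64)
hess = torch.autograd.functional.hessian(toy_energy, coord)
print(torch.isfinite(hess).all())
```

This mirrors the structure of `_assert_hessian_finite` (compute a Hessian, assert finiteness) without needing a full DeePMD model.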
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@deepmd/pt/utils/safe_gradient.py`:
- Around line 18-23: The parameter name ord in the safe_for_norm function
shadows a Python builtin and triggers linter A002; rename it to a non-shadowing
identifier (e.g., order or norm_ord) across the function signature and all
internal references in safe_for_norm to fix the linter, or if you intentionally
matched torch.linalg.norm, add a "# noqa: A002" comment next to the parameter
definition to suppress the warning; ensure all callers or tests are updated to
use the new parameter name (or accept the noqa approach) to keep behavior
consistent.
In `@source/tests/pt/model/test_dpa_hessian_finite.py`:
- Around line 40-47: The test _assert_hessian_finite contains redundant calls to
model.enable_hessian() and model.requires_hessian("energy") because
get_model(...) already enables hessian when model_params contains
hessian_mode=True; remove those two calls from _assert_hessian_finite so the
test relies on get_model's setup (references: _assert_hessian_finite, get_model,
enable_hessian, requires_hessian, hessian_mode).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 127b78bd-ffb1-4712-a7e2-bff07e9d6728
📒 Files selected for processing (2)
- deepmd/pt/utils/safe_gradient.py
- source/tests/pt/model/test_dpa_hessian_finite.py
The model returned by get_model should remain in normal mode here; calling enable_hessian once inside the test helper is enough. Authored by OpenClaw (model: custom-chat-jinzhezeng-group/gpt-5.4)
test(pt): fix hessian regression follow-up
🧹 Nitpick comments (1)
source/tests/pt/model/test_dpa_hessian_finite.py (1)
40-43: Remove duplicate `requires_hessian()` call and set eval mode.

Line 42 already enables the `"energy"` Hessian via `enable_hessian()`, which internally calls `requires_hessian("energy")`, making line 43 redundant. Additionally, setting `model.eval()` avoids unnecessary graph construction in downstream Hessian code paths since this helper only reads the computed Hessian without backpropagation.

♻️ Proposed cleanup
```diff
 def _assert_hessian_finite(self, model_params):
     model = get_model(copy.deepcopy(model_params)).to(env.DEVICE)
     model.enable_hessian()
-    model.requires_hessian("energy")
+    model.eval()
     coord, atype, cell = self._build_inputs()
     ret = model.forward_common(coord, atype, box=cell)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt/model/test_dpa_hessian_finite.py` around lines 40 - 43, Remove the redundant requires_hessian("energy") call and set the model to eval mode in the helper: inside _assert_hessian_finite, after creating the model with get_model(...), call model.enable_hessian() (which already marks the "energy" Hessian) and then call model.eval() to avoid unnecessary graph construction before running Hessian-related reads; remove the explicit model.requires_hessian("energy") line.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@source/tests/pt/model/test_dpa_hessian_finite.py`:
- Around line 40-43: Remove the redundant requires_hessian("energy") call and
set the model to eval mode in the helper: inside _assert_hessian_finite, after
creating the model with get_model(...), call model.enable_hessian() (which
already marks the "energy" Hessian) and then call model.eval() to avoid
unnecessary graph construction before running Hessian-related reads; remove the
explicit model.requires_hessian("energy") line.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: 11243f60-6f82-4135-9654-7438eebc9df3
📒 Files selected for processing (1)
source/tests/pt/model/test_dpa_hessian_finite.py
This fix is similar to #4668 and #4809.
Summary by CodeRabbit

New Features
- Safe-gradient helpers (`safe_for_sqrt`, `safe_for_norm`) for the PyTorch backend.

Bug Fixes
- Fixed NaN Hessians in the DPA2 and DPA3 descriptors at zero-distance edge cases.

Tests
- Added regression tests verifying that DPA2/DPA3 Hessians stay finite.