feat(dpmodel): add PyTorch support to array_api utilities#5198
feat(dpmodel): add PyTorch support to array_api utilities#5198wanghan-iapcm merged 4 commits intomasterfrom
Conversation
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #5198 +/- ##
=======================================
Coverage 81.96% 81.96%
=======================================
Files 714 714
Lines 73502 73508 +6
Branches 3615 3616 +1
=======================================
+ Hits 60247 60252 +5
Misses 12092 12092
- Partials 1163 1164 +1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Pull request overview
Adds PyTorch support to deepmd/dpmodel/array_api.py utility functions so DPModel’s array-API helpers can operate on torch tensors in addition to existing NumPy/JAX support, with new consistency tests to validate behavior across backends.
Changes:
- Added PyTorch implementations for
xp_scatter_sum(torch.scatter_add) andxp_add_at(torch.index_add). - Extended
xp_bincountbackend handling to include PyTorch arrays. - Introduced new consistency tests covering PyTorch/JAX/NumPy (and
array_api_strictwhere applicable).
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated no comments.
| File | Description |
|---|---|
deepmd/dpmodel/array_api.py |
Adds torch branches for scatter/add-at utilities and enables torch in xp_bincount. |
source/tests/consistent/test_array_api.py |
New backend-consistency tests for xp_scatter_sum, xp_add_at, and xp_bincount including non-mutating checks for torch. |
Comments suppressed due to low confidence (1)
deepmd/dpmodel/array_api.py:113
xp_add_atdocstring mentions returning a new array only for JAX, but the new PyTorch implementation is also out-of-place (non-mutating). Update the docstring to reflect that behavior for PyTorch as well, so callers don’t assume PyTorch will mutate like NumPy.
def xp_add_at(x: Array, indices: Array, values: Array) -> Array:
"""Adds values to the specified indices of x in place or returns new x (for JAX)."""
xp = array_api_compat.array_namespace(x, indices, values)
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
The
array_api.pymodule only supported JAX and NumPy arrays. PyTorch tensors now work with all array API utilities.Changes
xp_scatter_sum: Added PyTorch path usingtorch.scatter_add()xp_add_at: Added PyTorch path usingtorch.index_add()xp_bincount: Extended backend check to include PyTorch arraysAll PyTorch operations use non-mutating variants (no trailing
_) to maintain functional semantics consistent with JAX.Testing
source/tests/consistent/test_array_api.pyUsage
Original prompt
💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more Copilot coding agent tips in the docs.