refactor: consolidate backend logic into array_api.py with generic implementations (#5202)
## Refactor backend logic into array_api.py
### Analysis Complete
- [x] Identify all JAX-specific code in deepmd/dpmodel/
- [x] Review existing array_api.py structure
- [x] Identify test infrastructure
### Implementation Complete
- [x] Add `xp_sigmoid` function to array_api.py for backend-specific
sigmoid handling
- [x] Add `xp_setitem_at` function to array_api.py for backend-specific
array assignment
- [x] Implement generic `xp_scatter_sum` using array_api operations
- [x] Update network.py to use new `xp_sigmoid` function
- [x] Update network.py to use new `xp_setitem_at` function
- [x] Update transform_output.py to use `xp_scatter_sum` from
array_api.py
- [x] Add comprehensive tests for all functions
- [x] Add array-api-strict tests for complete backend coverage
- [x] Remove unused JAX-specific scatter_sum implementation
- [x] Make xp_setitem_at non-mutating for PyTorch (consistent with other
xp_* helpers)
- [x] Run tests to validate changes (20 passed, 7 skipped)
- [x] Run linters to ensure code quality (all checks passed)
### Changes Summary
1. **deepmd/dpmodel/array_api.py** - Centralized backend
implementations:
- **xp_sigmoid**: Added PyTorch-specific `torch.sigmoid()` and
JAX-specific `jax.nn.sigmoid()` implementations with a generic fallback
- **xp_setitem_at**: Handles JAX's functional `.at[].set()` syntax and
PyTorch's `clone()` for non-mutating behavior, while NumPy uses in-place
assignment
- **xp_scatter_sum**: Implemented generic array_api version using
`xp_take_along_axis` and `xp_add_at` helper functions (merged JAX
implementation logic), with PyTorch optimization retained
- Now supports NumPy, JAX, PyTorch, and array-api-strict backends
- **All xp_* helpers are now consistently non-mutating for PyTorch and
JAX**
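
The dispatch pattern these helpers follow might look like the sketch below. This is a simplified stand-in, not the actual deepmd implementation: the real helpers use `is_jax_array`/`is_torch_array` checks and the array-API namespace, whereas this sketch dispatches on the array's module name and falls back to NumPy.

```python
import numpy as np


def xp_sigmoid(x):
    """Backend-dispatched sigmoid (illustrative sketch)."""
    mod = type(x).__module__
    if mod.startswith("torch"):
        import torch
        return torch.sigmoid(x)  # fused PyTorch kernel, keeps autograd
    if mod.startswith("jax"):
        import jax
        return jax.nn.sigmoid(x)  # numerically stable JAX primitive
    return 1.0 / (1.0 + np.exp(-np.asarray(x)))  # generic fallback


def xp_setitem_at(x, index, value):
    """Non-mutating item assignment across backends (illustrative sketch)."""
    mod = type(x).__module__
    if mod.startswith("jax"):
        return x.at[index].set(value)  # JAX's functional update API
    if mod.startswith("torch"):
        out = x.clone()  # clone so the caller's tensor is untouched
        out[index] = value
        return out
    x[index] = value  # NumPy: plain in-place assignment
    return x
```

Centralizing this dispatch means call sites stay backend-agnostic: they call one helper and never import `torch` or `jax` directly.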
2. **deepmd/dpmodel/utils/network.py** - Refactored to use centralized
functions:
- Replaced direct JAX conditional in `sigmoid_t` function with call to
`xp_sigmoid`
- Replaced JAX conditional for array assignment with call to
`xp_setitem_at`
3. **deepmd/dpmodel/model/transform_output.py** - Simplified scatter_sum
usage:
- Removed JAX conditional and direct import of
`deepmd.jax.common.scatter_sum`
- Now uses `xp_scatter_sum` from array_api.py consistently
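
A generic scatter-sum can be expressed with unbuffered accumulation, as in the NumPy-only sketch below. This is an illustration of the technique, not the PR's code: the real `xp_scatter_sum` is written against the array-API namespace with `xp_take_along_axis`/`xp_add_at` helpers, while this stand-in uses `np.add.at` (NumPy's analogue of an add-at helper) and follows `torch.scatter_add` semantics.

```python
import numpy as np


def xp_scatter_sum(input, dim, index, src):
    """Scatter-add along `dim` (sketch): positions named by `index`
    accumulate the matching elements of `src`; duplicates sum up."""
    out = np.array(input, copy=True)
    # Build fancy indices that are the identity on every axis except
    # `dim`, where `index` selects the destination positions.
    idx = list(np.indices(index.shape))
    idx[dim] = index
    np.add.at(out, tuple(idx), src)  # unbuffered: repeated indices accumulate
    return out
```

For example, scattering `[1.0, 2.0, 3.0]` into a length-4 zero vector with indices `[0, 0, 3]` sums the two contributions at position 0.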
4. **deepmd/jax/common.py** - Removed unused code:
- Removed `scatter_sum` function (replaced by generic implementation in
array_api.py)
- Function was not imported or used anywhere in the codebase
5. **source/tests/consistent/test_array_api.py** - Complete test
coverage across all backends:
- Added array-api-strict tests for `TestXpSigmoidConsistent`
- Added array-api-strict tests for `TestXpSetitemAtConsistent`
- Added array-api-strict tests for `TestXpScatterSumConsistent`
- Added array-api-strict tests for `TestXpBincountConsistent`
- Added array-api-strict tests for `TestXpBincountWithWeightsConsistent`
- Added array-api-strict tests for
`TestXpBincountWithMinlengthConsistent`
- **Added non-mutating verification for PyTorch in
`TestXpSetitemAtConsistent`** (consistent with other tests)
- All test classes now have consistent coverage across NumPy, PyTorch,
JAX, and array-api-strict
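
A consistency test in the style listed above would, hypothetically, compare each backend's helper against a NumPy reference. The minimal NumPy-only sketch below (function names are illustrative, not the test suite's actual names) shows the shape of such a check:

```python
import numpy as np


def reference_setitem_at(x, index, value):
    """NumPy reference: copy, assign, return."""
    out = x.copy()
    out[index] = value
    return out


def check_setitem_consistency(setitem_impl):
    """Compare a backend implementation against the NumPy reference (sketch)."""
    x = np.arange(5, dtype=float)
    before = x.copy()
    got = setitem_impl(x, 1, -1.0)
    expected = reference_setitem_at(before, 1, -1.0)
    np.testing.assert_allclose(np.asarray(got), expected)
    return got
```

For PyTorch and JAX the real tests additionally assert that the original array is left unmodified, which is the non-mutating guarantee described above.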
### Test Results
- All 27 array_api tests succeed (20 passed, 7 skipped because JAX is
not installed)
- Complete array-api-strict test coverage for all xp_* functions
- All PyTorch tests verify non-mutating behavior
- All network tests pass
- Ruff linting: all checks passed
### Benefits
- ✅ All backend-specific conditionals (`if is_jax_array`, `if
is_torch_array`) centralized in array_api.py
- ✅ No direct backend imports (`from deepmd.jax`, `from torch`) outside
of array_api.py
- ✅ Generic array_api implementations support all backends (NumPy, JAX,
PyTorch, array-api-strict)
- ✅ Backend-specific optimizations retained where beneficial (e.g.,
`torch.sigmoid`, `torch.scatter_add`)
- ✅ **All xp_* helpers are consistently non-mutating for PyTorch and
JAX**, preventing autograd issues
- ✅ Removed duplicate/unused code from JAX backend
- ✅ Improved maintainability - future backend changes only need to
update array_api.py
- ✅ Consistent API for backend-specific operations across the codebase
<!-- START COPILOT ORIGINAL PROMPT -->
<details>
<summary>Original prompt</summary>
>
> ----
>
> *This section gives details on the original issue you should resolve*
>
> <issue_title>[Feature Request] Refactor JAX backend logic in
deepmd/dpmodel into array_api.py</issue_title>
> <issue_description>### Summary
>
> Refactor all JAX-specific backend code currently spread across files
in deepmd/dpmodel/ (except for array_api.py) by consolidating these
implementations into deepmd/dpmodel/array_api.py. This will reduce code
duplication, improve maintainability, and centralize NumPy, JAX, and
PyTorch backend logic in one place.
>
> ### Detailed Description
>
> Many files in deepmd/dpmodel/ contain implementation blocks like `if
is_jax_array` for JAX backend support. To streamline future backend
improvements and reduce maintenance overhead, all such JAX-specific
logic should be migrated into array_api.py. This centralization will
make backend wrappers easier to maintain and update.
>
> Example JAX references found:
> - deepmd/dpmodel/utils/network.py, ~line 1149: JAX-specific array
assignment logic
> - deepmd/dpmodel/model/transform_output.py, ~line 217: JAX scatter_sum
usage
>
> After the refactoring, other files should invoke unified backend
functions from array_api.py instead of handling JAX conditionals
locally.
>
> (Search results are limited, please see [GitHub
search](https://github.com/deepmodeling/deepmd-kit/search?q=is_jax_array+path%3Adeepmd%2Fdpmodel%2F)
for full context.)
>
> ### Further Information, Files, and Links
>
> _No response_</issue_description>
>
> ## Comments on the Issue (you are @copilot in this section)
>
> <comments>
> </comments>
>
</details>
<!-- START COPILOT CODING AGENT SUFFIX -->
- Fixes #5201
<!-- START COPILOT CODING AGENT TIPS -->
---
💡 You can make Copilot smarter by setting up custom instructions,
customizing its development environment, and configuring Model Context
Protocol (MCP) servers. Learn more in the [Copilot coding agent
tips](https://gh.io/copilot-coding-agent-tips) docs.
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>