
Commit 1c83637

Authored by Copilot and njzjz
refactor: consolidate backend logic into array_api.py with generic implementations (#5202)
## Refactor backend logic into array_api.py

### Analysis Complete

- [x] Identify all JAX-specific code in deepmd/dpmodel/
- [x] Review existing array_api.py structure
- [x] Identify test infrastructure

### Implementation Complete

- [x] Add `xp_sigmoid` function to array_api.py for backend-specific sigmoid handling
- [x] Add `xp_setitem_at` function to array_api.py for backend-specific array assignment
- [x] Implement generic `xp_scatter_sum` using array_api operations
- [x] Update network.py to use the new `xp_sigmoid` function
- [x] Update network.py to use the new `xp_setitem_at` function
- [x] Update transform_output.py to use `xp_scatter_sum` from array_api.py
- [x] Add comprehensive tests for all functions
- [x] Add array-api-strict tests for complete backend coverage
- [x] Remove the unused JAX-specific scatter_sum implementation
- [x] Make xp_setitem_at non-mutating for PyTorch (consistent with the other xp_* helpers)
- [x] Run tests to validate changes (20 passed, 7 skipped)
- [x] Run linters to ensure code quality (all checks passed)

### Changes Summary

1. **deepmd/dpmodel/array_api.py** - Centralized backend implementations:
   - **xp_sigmoid**: Added PyTorch-specific `torch.sigmoid()` and JAX-specific `jax.nn.sigmoid()` implementations with a generic fallback
   - **xp_setitem_at**: Handles JAX's functional `.at[].set()` syntax and PyTorch's `clone()` for non-mutating behavior, while NumPy uses in-place assignment
   - **xp_scatter_sum**: Implemented a generic array_api version using the `xp_take_along_axis` and `xp_add_at` helper functions (merging in the JAX implementation logic), with the PyTorch optimization retained
   - Now supports the NumPy, JAX, PyTorch, and array-api-strict backends
   - **All xp_* helpers are now consistently non-mutating for PyTorch and JAX**
2. **deepmd/dpmodel/utils/network.py** - Refactored to use the centralized functions:
   - Replaced the direct JAX conditional in the `sigmoid_t` function with a call to `xp_sigmoid`
   - Replaced the JAX conditional for array assignment with a call to `xp_setitem_at`
3. **deepmd/dpmodel/model/transform_output.py** - Simplified scatter_sum usage:
   - Removed the JAX conditional and the direct import of `deepmd.jax.common.scatter_sum`
   - Now uses `xp_scatter_sum` from array_api.py consistently
4. **deepmd/jax/common.py** - Removed unused code:
   - Removed the `scatter_sum` function (replaced by the generic implementation in array_api.py)
   - The function was not imported or used anywhere in the codebase
5. **source/tests/consistent/test_array_api.py** - Complete test coverage across all backends:
   - Added array-api-strict tests for `TestXpSigmoidConsistent`
   - Added array-api-strict tests for `TestXpSetitemAtConsistent`
   - Added array-api-strict tests for `TestXpScatterSumConsistent`
   - Added array-api-strict tests for `TestXpBincountConsistent`
   - Added array-api-strict tests for `TestXpBincountWithWeightsConsistent`
   - Added array-api-strict tests for `TestXpBincountWithMinlengthConsistent`
   - **Added non-mutating verification for PyTorch in `TestXpSetitemAtConsistent`** (consistent with the other tests)
   - All test classes now have consistent coverage across NumPy, PyTorch, JAX, and array-api-strict

### Test Results

- All 27 array_api tests pass (20 passed, 7 skipped because JAX is not installed)
- Complete array-api-strict test coverage for all xp_* functions
- All PyTorch tests verify non-mutating behavior
- All network tests pass
- Ruff linting: all checks passed

### Benefits

- ✅ All backend-specific conditionals (`if is_jax_array`, `if is_torch_array`) are centralized in array_api.py
- ✅ No direct backend imports (`from deepmd.jax`, `from torch`) outside of array_api.py
- ✅ Generic array_api implementations support all backends (NumPy, JAX, PyTorch, array-api-strict)
- ✅ Backend-specific optimizations retained where beneficial (e.g., `torch.sigmoid`, `torch.scatter_add`)
- ✅ **All xp_* helpers are consistently non-mutating for PyTorch and JAX**, preventing autograd issues
- ✅ Removed duplicate/unused code from the JAX backend
- ✅ Improved maintainability: future backend changes only need to touch array_api.py
- ✅ Consistent API for backend-specific operations across the codebase

<details>
<summary>Original prompt</summary>

> *This section details the original issue to resolve.*
>
> <issue_title>[Feature Request] Refactor JAX backend logic in deepmd/dpmodel into array_api.py</issue_title>
> <issue_description>
> ### Summary
>
> Refactor all JAX-specific backend code currently spread across files in deepmd/dpmodel/ (except for array_api.py) by consolidating these implementations into deepmd/dpmodel/array_api.py. This will reduce code duplication, improve maintainability, and centralize NumPy, JAX, and Torch backend logic in one place.
>
> ### Detailed Description
>
> Many files in deepmd/dpmodel/ contain implementation blocks like `if is_jax_array` for JAX backend support. To streamline future backend improvements and reduce maintenance overhead, all such JAX-specific logic should be migrated into array_api.py. This centralization will make backend wrappers easier to maintain and update.
>
> Example JAX references found:
> - deepmd/dpmodel/utils/network.py, ~line 1149: JAX-specific array assignment logic
> - deepmd/dpmodel/model/transform_output.py, ~line 217: JAX scatter_sum usage
>
> After the refactoring, other files should invoke unified backend functions from array_api.py instead of handling JAX conditionals locally.
>
> (Search results are limited; please see [GitHub search](https://github.com/deepmodeling/deepmd-kit/search?q=is_jax_array+path%3Adeepmd%2Fdpmodel%2F) for full context.)
>
> ### Further Information, Files, and Links
>
> _No response_
> </issue_description>
</details>

- Fixes #5201

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
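The non-mutating contract stressed in the summary above can be sketched in plain NumPy; the copy-then-assign shape below mirrors what the PyTorch branch of `xp_setitem_at` does with `clone()` (an illustrative sketch only, not the shipped helper):

```python
import numpy as np


def setitem_at_copy(x: np.ndarray, mask: np.ndarray, values: np.ndarray) -> np.ndarray:
    """Return a new array with x[mask] replaced by values; x is left untouched."""
    result = x.copy()  # copy first, so the caller's array is never mutated
    result[mask] = values
    return result


x = np.zeros((5, 3))
mask = np.array([True, False, True, False, True])
out = setitem_at_copy(x, mask, np.ones((3, 3)))
print(x.sum(), out.sum())  # original unchanged: 0.0; result has 9 ones: 9.0
```

The actual helper keeps NumPy in-place (assignment on `x` directly) and only pays the copy cost for backends where mutation would break functional semantics or autograd.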
1 parent 4457061 · commit 1c83637

5 files changed

Lines changed: 227 additions & 55 deletions


deepmd/dpmodel/array_api.py

Lines changed: 83 additions & 15 deletions
```diff
@@ -57,25 +57,38 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:


 def xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array:
-    """Reduces all values from the src tensor to the indices specified in the index tensor."""
-    if array_api_compat.is_jax_array(input):
-        from deepmd.jax.common import (
-            scatter_sum,
-        )
-
-        return scatter_sum(
-            input,
-            dim,
-            index,
-            src,
-        )
-    elif array_api_compat.is_torch_array(input):
-        # PyTorch: use scatter_add (non-mutating version)
+    """Reduces all values from the src tensor to the indices specified in the index tensor.
+
+    This function is similar to PyTorch's scatter_add and JAX's scatter_sum.
+    It adds values from src to input at positions specified by index along the given dimension.
+    """
+    if array_api_compat.is_torch_array(input):
+        # PyTorch: use scatter_add (non-mutating version) for better performance
         import torch

         return torch.scatter_add(input, dim, index, src)
-    else:
-        raise NotImplementedError("Only JAX and PyTorch arrays are supported.")
+
+    # Generic array_api implementation (works for JAX, NumPy, array-api-strict, etc.)
+    xp = array_api_compat.array_namespace(input)
+
+    # Create flat index array matching input shape
+    idx = xp.arange(input.size, dtype=xp.int64)
+    idx = xp.reshape(idx, input.shape)
+
+    # Get flat indices where we want to add values
+    new_idx = xp_take_along_axis(idx, index, axis=dim)
+    new_idx = xp.reshape(new_idx, (-1,))
+
+    # Flatten arrays
+    shape = input.shape
+    input_flat = xp.reshape(input, (-1,))
+    src_flat = xp.reshape(src, (-1,))
+
+    # Add values at the specified indices
+    result = xp_add_at(input_flat, new_idx, src_flat)
+
+    # Reshape back to original shape
+    return xp.reshape(result, shape)


 def xp_add_at(x: Array, indices: Array, values: Array) -> Array:
@@ -104,6 +117,61 @@ def xp_add_at(x: Array, indices: Array, values: Array) -> Array:
     return x


+def xp_sigmoid(x: Array) -> Array:
+    """Compute the sigmoid function.
+
+    JAX and PyTorch have optimized sigmoid implementations.
+    See https://github.com/jax-ml/jax/discussions/15617
+    """
+    if array_api_compat.is_jax_array(x):
+        from deepmd.jax.env import (
+            jax,
+        )
+
+        return jax.nn.sigmoid(x)
+    elif array_api_compat.is_torch_array(x):
+        import torch
+
+        return torch.sigmoid(x)
+    xp = array_api_compat.array_namespace(x)
+    return 1 / (1 + xp.exp(-x))
+
+
+def xp_setitem_at(x: Array, mask: Array, values: Array) -> Array:
+    """Set items at boolean mask indices.
+
+    For JAX and PyTorch arrays, returns a new array (non-mutating).
+    For NumPy arrays, modifies in-place and returns the same array.
+
+    Parameters
+    ----------
+    x : Array
+        The array to modify
+    mask : Array
+        Boolean mask indicating positions to set
+    values : Array
+        Values to set at masked positions
+
+    Returns
+    -------
+    Array
+        Modified array (new array for JAX/PyTorch, same array for NumPy)
+    """
+    if array_api_compat.is_jax_array(x):
+        # JAX doesn't support in-place item assignment
+        return x.at[mask].set(values)
+    elif array_api_compat.is_torch_array(x):
+        # PyTorch: clone to avoid mutating the input (non-mutating version)
+        import torch
+
+        result = torch.clone(x)
+        result[mask] = values
+        return result
+    # Standard item assignment for NumPy, array-api-strict, etc.
+    x[mask] = values
+    return x
+
+
 def xp_bincount(x: Array, weights: Array | None = None, minlength: int = 0) -> Array:
     """Counts the number of occurrences of each value in x."""
     xp = array_api_compat.array_namespace(x)
```
deepmd/dpmodel/model/transform_output.py

Lines changed: 6 additions & 14 deletions
```diff
@@ -219,20 +219,12 @@ def communicate_extended_output(
                     vldims + derv_c_ext_dims,
                     dtype=vv.dtype,
                 )
-                # jax only
-                if array_api_compat.is_jax_array(virial):
-                    from deepmd.jax.common import (
-                        scatter_sum,
-                    )
-
-                    virial = scatter_sum(
-                        virial,
-                        1,
-                        mapping,
-                        model_ret[kk_derv_c],
-                    )
-                else:
-                    raise NotImplementedError("Only JAX arrays are supported.")
+                virial = xp_scatter_sum(
+                    virial,
+                    1,
+                    mapping,
+                    model_ret[kk_derv_c],
+                )
                 new_ret[kk_derv_c] = virial
                 new_ret[kk_derv_c + "_redu"] = xp.sum(new_ret[kk_derv_c], axis=1)
             else:
```

deepmd/dpmodel/utils/network.py

Lines changed: 4 additions & 14 deletions
```diff
@@ -25,6 +25,8 @@
     Array,
     xp_add_at,
     xp_bincount,
+    xp_setitem_at,
+    xp_sigmoid,
 )
 from deepmd.dpmodel.common import (
     to_numpy_array,
@@ -39,15 +41,7 @@

 def sigmoid_t(x):  # noqa: ANN001, ANN201
     """Sigmoid."""
-    if array_api_compat.is_jax_array(x):
-        from deepmd.jax.env import (
-            jax,
-        )
-
-        # see https://github.com/jax-ml/jax/discussions/15617
-        return jax.nn.sigmoid(x)
-    xp = array_api_compat.array_namespace(x)
-    return 1 / (1 + xp.exp(-x))
+    return xp_sigmoid(x)


 class Identity(NativeOP):
@@ -1251,11 +1245,7 @@ def get_graph_index(  # noqa: ANN201
     # edge(ij) to angle(ijk) index_select; angle(ijk) to edge(ij) aggregate
     edge_id = xp.arange(n_edge, dtype=nlist.dtype)
     edge_index = xp.zeros((nf, nloc, nnei), dtype=nlist.dtype)
-    if array_api_compat.is_jax_array(nlist):
-        # JAX doesn't support in-place item assignment
-        edge_index = edge_index.at[xp.astype(nlist_mask, xp.bool)].set(edge_id)
-    else:
-        edge_index[xp.astype(nlist_mask, xp.bool)] = edge_id
+    edge_index = xp_setitem_at(edge_index, xp.astype(nlist_mask, xp.bool), edge_id)
     # only cut a_nnei neighbors, to avoid nnei x nnei
     edge_index = edge_index[:, :, :a_nnei]
     edge_index_ij = xp.broadcast_to(
```
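The `get_graph_index` hunk above collapses a JAX conditional into one `xp_setitem_at` call; the masked assignment it performs can be seen in a toy NumPy version (the shapes `nf=1, nloc=2, nnei=3` and the mask values are made up for illustration):

```python
import numpy as np

# One frame, two local atoms, up to three neighbors each; True marks a real neighbor
nf, nloc, nnei = 1, 2, 3
nlist_mask = np.array([[[True, True, False], [True, False, False]]])

n_edge = int(nlist_mask.sum())  # 3 real edges in this toy frame
edge_id = np.arange(n_edge)     # consecutive edge ids [0, 1, 2]

edge_index = np.zeros((nf, nloc, nnei), dtype=np.int64)
# Boolean-mask assignment: real neighbor slots receive edge ids in C order
edge_index[nlist_mask] = edge_id
print(edge_index)  # [[[0 1 0] [2 0 0]]]
```

Because boolean indexing selects exactly `n_edge` positions in row-major order, `edge_id` lines up one-to-one with the real neighbor slots; padded slots keep their zero placeholder.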

deepmd/jax/common.py

Lines changed: 0 additions & 12 deletions
```diff
@@ -94,15 +94,3 @@ def __dlpack__(self, *args: Any, **kwargs: Any) -> Any:

     def __dlpack_device__(self, *args: Any, **kwargs: Any) -> Any:
         return self.value.__dlpack_device__(*args, **kwargs)
-
-
-def scatter_sum(
-    input: jnp.ndarray, dim: int, index: jnp.ndarray, src: jnp.ndarray
-) -> jnp.ndarray:
-    """Reduces all values from the src tensor to the indices specified in the index tensor."""
-    idx = jnp.arange(input.size, dtype=jnp.int64).reshape(input.shape)
-    new_idx = jnp.take_along_axis(idx, index, axis=dim).ravel()
-    shape = input.shape
-    input = input.ravel()
-    input = input.at[new_idx].add(src.ravel())
-    return input.reshape(shape)
```

source/tests/consistent/test_array_api.py

Lines changed: 134 additions & 0 deletions
```diff
@@ -8,6 +8,8 @@
     xp_add_at,
     xp_bincount,
     xp_scatter_sum,
+    xp_setitem_at,
+    xp_sigmoid,
 )
 from deepmd.dpmodel.common import (
     to_numpy_array,
@@ -66,6 +68,19 @@ def test_jax_consistent_with_ref(self) -> None:
         result = xp_scatter_sum(input_jax, self.dim, index_jax, src_jax)
         np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)

+    @unittest.skipUnless(
+        INSTALLED_ARRAY_API_STRICT, "array_api_strict is not installed"
+    )
+    @unittest.skipUnless(
+        sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8"
+    )
+    def test_array_api_strict_consistent_with_ref(self) -> None:
+        input_xp = xp.asarray(self.input_np)
+        index_xp = xp.asarray(self.index_np)
+        src_xp = xp.asarray(self.src_np)
+        result = xp_scatter_sum(input_xp, self.dim, index_xp, src_xp)
+        np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
+

 class TestXpAddAtConsistent(unittest.TestCase):
     """Test xp_add_at consistency across backends."""
@@ -139,6 +154,17 @@ def test_jax_consistent_with_ref(self) -> None:
         result = xp_bincount(x_jax)
         np.testing.assert_equal(self.ref, to_numpy_array(result))

+    @unittest.skipUnless(
+        INSTALLED_ARRAY_API_STRICT, "array_api_strict is not installed"
+    )
+    @unittest.skipUnless(
+        sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8"
+    )
+    def test_array_api_strict_consistent_with_ref(self) -> None:
+        x_xp = xp.asarray(self.x_np)
+        result = xp_bincount(x_xp)
+        np.testing.assert_equal(self.ref, to_numpy_array(result))
+

 class TestXpBincountWithWeightsConsistent(unittest.TestCase):
     """Test xp_bincount with weights consistency across backends."""
@@ -166,6 +192,18 @@ def test_jax_consistent_with_ref(self) -> None:
         result = xp_bincount(x_jax, weights=weights_jax)
         np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)

+    @unittest.skipUnless(
+        INSTALLED_ARRAY_API_STRICT, "array_api_strict is not installed"
+    )
+    @unittest.skipUnless(
+        sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8"
+    )
+    def test_array_api_strict_consistent_with_ref(self) -> None:
+        x_xp = xp.asarray(self.x_np)
+        weights_xp = xp.asarray(self.weights_np)
+        result = xp_bincount(x_xp, weights=weights_xp)
+        np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
+

 class TestXpBincountWithMinlengthConsistent(unittest.TestCase):
     """Test xp_bincount with minlength consistency across backends."""
@@ -190,3 +228,99 @@ def test_jax_consistent_with_ref(self) -> None:
         x_jax = jnp.array(self.x_np)
         result = xp_bincount(x_jax, minlength=self.minlength)
         np.testing.assert_equal(self.ref, to_numpy_array(result))
+
+    @unittest.skipUnless(
+        INSTALLED_ARRAY_API_STRICT, "array_api_strict is not installed"
+    )
+    @unittest.skipUnless(
+        sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8"
+    )
+    def test_array_api_strict_consistent_with_ref(self) -> None:
+        x_xp = xp.asarray(self.x_np)
+        result = xp_bincount(x_xp, minlength=self.minlength)
+        np.testing.assert_equal(self.ref, to_numpy_array(result))
+
+
+class TestXpSigmoidConsistent(unittest.TestCase):
+    """Test xp_sigmoid consistency across backends."""
+
+    def setUp(self) -> None:
+        self.x_np = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
+        # Reference using NumPy sigmoid
+        self.ref = 1 / (1 + np.exp(-self.x_np))
+
+    def test_numpy_consistent_with_ref(self) -> None:
+        result = xp_sigmoid(self.x_np)
+        np.testing.assert_allclose(self.ref, result, atol=1e-10)
+
+    @unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
+    def test_pt_consistent_with_ref(self) -> None:
+        x_pt = torch.from_numpy(self.x_np)
+        result = xp_sigmoid(x_pt)
+        np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
+
+    @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed")
+    def test_jax_consistent_with_ref(self) -> None:
+        x_jax = jnp.array(self.x_np)
+        result = xp_sigmoid(x_jax)
+        np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
+
+    @unittest.skipUnless(
+        INSTALLED_ARRAY_API_STRICT, "array_api_strict is not installed"
+    )
+    @unittest.skipUnless(
+        sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8"
+    )
+    def test_array_api_strict_consistent_with_ref(self) -> None:
+        x_xp = xp.asarray(self.x_np)
+        result = xp_sigmoid(x_xp)
+        np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
+
+
+class TestXpSetitemAtConsistent(unittest.TestCase):
+    """Test xp_setitem_at consistency across backends."""
+
+    def setUp(self) -> None:
+        self.x_np = np.zeros((5, 3))
+        self.mask_np = np.array([True, False, True, False, True])
+        self.values_np = np.ones((3, 3))
+        # Reference using NumPy
+        self.ref = self.x_np.copy()
+        self.ref[self.mask_np] = self.values_np
+
+    def test_numpy_consistent_with_ref(self) -> None:
+        x = self.x_np.copy()
+        result = xp_setitem_at(x, self.mask_np, self.values_np)
+        np.testing.assert_allclose(self.ref, result, atol=1e-10)
+
+    @unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
+    def test_pt_consistent_with_ref(self) -> None:
+        x_pt = torch.from_numpy(self.x_np)
+        mask_pt = torch.from_numpy(self.mask_np)
+        values_pt = torch.from_numpy(self.values_np)
+        result = xp_setitem_at(x_pt, mask_pt, values_pt)
+        # Verify original tensor is unchanged (non-mutating)
+        np.testing.assert_allclose(self.x_np, to_numpy_array(x_pt), atol=1e-10)
+        # Verify result matches reference
+        np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
+
+    @unittest.skipUnless(INSTALLED_JAX, "JAX is not installed")
+    def test_jax_consistent_with_ref(self) -> None:
+        x_jax = jnp.array(self.x_np)
+        mask_jax = jnp.array(self.mask_np)
+        values_jax = jnp.array(self.values_np)
+        result = xp_setitem_at(x_jax, mask_jax, values_jax)
+        np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
+
+    @unittest.skipUnless(
+        INSTALLED_ARRAY_API_STRICT, "array_api_strict is not installed"
+    )
+    @unittest.skipUnless(
+        sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8"
+    )
+    def test_array_api_strict_consistent_with_ref(self) -> None:
+        x_xp = xp.asarray(self.x_np)
+        mask_xp = xp.asarray(self.mask_np)
+        values_xp = xp.asarray(self.values_np)
+        result = xp_setitem_at(x_xp, mask_xp, values_xp)
+        np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
```
