Skip to content

Commit ef0be63

Browse files
Copilotnjzjz
andauthored
feat(dpmodel): add PyTorch support to array_api utilities (#5198)
The `array_api.py` module only supported JAX and NumPy arrays. PyTorch tensors now work with all array API utilities. ## Changes - **`xp_scatter_sum`**: Added PyTorch path using `torch.scatter_add()` - **`xp_add_at`**: Added PyTorch path using `torch.index_add()` - **`xp_bincount`**: Extended backend check to include PyTorch arrays All PyTorch operations use non-mutating variants (no trailing `_`) to maintain functional semantics consistent with JAX. ## Testing - Created comprehensive consistency tests in `source/tests/consistent/test_array_api.py` - Tests verify consistency across NumPy, PyTorch, JAX, and array_api_strict backends - Tests verify non-mutating behavior for PyTorch and JAX - All 10 tests pass successfully (5 JAX tests skipped when JAX not installed) ## Usage ```python import torch from deepmd.dpmodel.array_api import xp_bincount, xp_add_at # PyTorch tensors now work seamlessly x = torch.tensor([0, 1, 1, 3, 2, 1, 7]) counts = xp_bincount(x) # tensor([1, 3, 1, 1, 0, 0, 0, 1]) # Operations are non-mutating x = torch.zeros(5, 3) result = xp_add_at(x, torch.tensor([0, 1]), torch.ones(2, 3)) # x remains unchanged, result contains the update ``` <!-- START COPILOT CODING AGENT SUFFIX --> <!-- START COPILOT ORIGINAL PROMPT --> <details> <summary>Original prompt</summary> > The current `deepmd/dpmodel/array_api.py` only considers JAX and NumPy. Please also support pytorch. </details> <!-- START COPILOT CODING AGENT TIPS --> --- 💡 You can make Copilot smarter by setting up custom instructions, customizing its development environment and configuring Model Context Protocol (MCP) servers. Learn more [Copilot coding agent tips](https://gh.io/copilot-coding-agent-tips) in the docs. --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: njzjz <9496702+njzjz@users.noreply.github.com>
1 parent 6112545 commit ef0be63

2 files changed

Lines changed: 208 additions & 3 deletions

File tree

deepmd/dpmodel/array_api.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@ def xp_take_along_axis(arr: Array, indices: Array, axis: int) -> Array:
8888

8989
def xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array:
9090
"""Reduces all values from the src tensor to the indices specified in the index tensor."""
91-
# jax only
9291
if array_api_compat.is_jax_array(input):
9392
from deepmd.jax.common import (
9493
scatter_sum,
@@ -100,8 +99,13 @@ def xp_scatter_sum(input: Array, dim: int, index: Array, src: Array) -> Array:
10099
index,
101100
src,
102101
)
102+
elif array_api_compat.is_torch_array(input):
103+
# PyTorch: use scatter_add (non-mutating version)
104+
import torch
105+
106+
return torch.scatter_add(input, dim, index, src)
103107
else:
104-
raise NotImplementedError("Only JAX arrays are supported.")
108+
raise NotImplementedError("Only JAX and PyTorch arrays are supported.")
105109

106110

107111
def xp_add_at(x: Array, indices: Array, values: Array) -> Array:
@@ -115,6 +119,11 @@ def xp_add_at(x: Array, indices: Array, values: Array) -> Array:
115119
elif array_api_compat.is_jax_array(x):
116120
# JAX: functional update, not in-place
117121
return x.at[indices].add(values)
122+
elif array_api_compat.is_torch_array(x):
123+
# PyTorch: use index_add (non-mutating version)
124+
import torch
125+
126+
return torch.index_add(x, 0, indices, values)
118127
else:
119128
# Fallback for array_api_strict: use basic indexing only
120129
# may need a more efficient way to do this
@@ -128,7 +137,11 @@ def xp_add_at(x: Array, indices: Array, values: Array) -> Array:
128137
def xp_bincount(x: Array, weights: Array | None = None, minlength: int = 0) -> Array:
129138
"""Counts the number of occurrences of each value in x."""
130139
xp = array_api_compat.array_namespace(x)
131-
if array_api_compat.is_numpy_array(x) or array_api_compat.is_jax_array(x):
140+
if (
141+
array_api_compat.is_numpy_array(x)
142+
or array_api_compat.is_jax_array(x)
143+
or array_api_compat.is_torch_array(x)
144+
):
132145
result = xp.bincount(x, weights=weights, minlength=minlength)
133146
else:
134147
if weights is None:
Lines changed: 192 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,192 @@
1+
# SPDX-License-Identifier: LGPL-3.0-or-later
2+
import sys
3+
import unittest
4+
5+
import numpy as np
6+
7+
from deepmd.dpmodel.array_api import (
8+
xp_add_at,
9+
xp_bincount,
10+
xp_scatter_sum,
11+
)
12+
from deepmd.dpmodel.common import (
13+
to_numpy_array,
14+
)
15+
16+
from .common import (
17+
INSTALLED_ARRAY_API_STRICT,
18+
INSTALLED_JAX,
19+
INSTALLED_PT,
20+
)
21+
22+
if INSTALLED_PT:
23+
import torch
24+
25+
if INSTALLED_JAX:
26+
from deepmd.jax.env import (
27+
jnp,
28+
)
29+
30+
if INSTALLED_ARRAY_API_STRICT:
31+
import array_api_strict as xp
32+
33+
34+
class TestXpScatterSumConsistent(unittest.TestCase):
35+
"""Test xp_scatter_sum consistency across backends."""
36+
37+
def setUp(self) -> None:
38+
# Reference using NumPy (via clone and scatter_add simulation)
39+
self.input_np = np.zeros((3, 5))
40+
self.dim = 0
41+
self.index_np = np.array([[0, 1, 2, 0, 0]])
42+
self.src_np = np.ones((1, 5))
43+
# Manually compute reference for scatter_sum
44+
self.ref = self.input_np.copy()
45+
for i in range(self.index_np.shape[0]):
46+
for j in range(self.index_np.shape[1]):
47+
idx = self.index_np[i, j]
48+
self.ref[idx, j] += self.src_np[i, j]
49+
50+
@unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
51+
def test_pt_consistent_with_ref(self) -> None:
52+
input_pt = torch.from_numpy(self.input_np)
53+
index_pt = torch.from_numpy(self.index_np).long()
54+
src_pt = torch.from_numpy(self.src_np)
55+
result = xp_scatter_sum(input_pt, self.dim, index_pt, src_pt)
56+
# Verify original tensor is unchanged (non-mutating)
57+
np.testing.assert_allclose(self.input_np, to_numpy_array(input_pt), atol=1e-10)
58+
# Verify result matches reference
59+
np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
60+
61+
@unittest.skipUnless(INSTALLED_JAX, "JAX is not installed")
62+
def test_jax_consistent_with_ref(self) -> None:
63+
input_jax = jnp.array(self.input_np)
64+
index_jax = jnp.array(self.index_np)
65+
src_jax = jnp.array(self.src_np)
66+
result = xp_scatter_sum(input_jax, self.dim, index_jax, src_jax)
67+
np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
68+
69+
70+
class TestXpAddAtConsistent(unittest.TestCase):
71+
"""Test xp_add_at consistency across backends."""
72+
73+
def setUp(self) -> None:
74+
self.x_np = np.zeros((5, 3))
75+
self.indices_np = np.array([0, 1, 1, 3])
76+
self.values_np = np.ones((4, 3))
77+
# Reference using NumPy
78+
self.ref = self.x_np.copy()
79+
np.add.at(self.ref, self.indices_np, self.values_np)
80+
81+
def test_numpy_consistent_with_ref(self) -> None:
82+
x = self.x_np.copy()
83+
result = xp_add_at(x, self.indices_np, self.values_np)
84+
np.testing.assert_allclose(self.ref, result, atol=1e-10)
85+
86+
@unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
87+
def test_pt_consistent_with_ref(self) -> None:
88+
x_pt = torch.from_numpy(self.x_np)
89+
indices_pt = torch.from_numpy(self.indices_np).long()
90+
values_pt = torch.from_numpy(self.values_np)
91+
result = xp_add_at(x_pt, indices_pt, values_pt)
92+
# Verify original tensor is unchanged (non-mutating)
93+
np.testing.assert_allclose(self.x_np, to_numpy_array(x_pt), atol=1e-10)
94+
# Verify result matches reference
95+
np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
96+
97+
@unittest.skipUnless(INSTALLED_JAX, "JAX is not installed")
98+
def test_jax_consistent_with_ref(self) -> None:
99+
x_jax = jnp.array(self.x_np)
100+
indices_jax = jnp.array(self.indices_np)
101+
values_jax = jnp.array(self.values_np)
102+
result = xp_add_at(x_jax, indices_jax, values_jax)
103+
np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
104+
105+
@unittest.skipUnless(
106+
INSTALLED_ARRAY_API_STRICT, "array_api_strict is not installed"
107+
)
108+
@unittest.skipUnless(
109+
sys.version_info >= (3, 9), "array_api_strict doesn't support Python<=3.8"
110+
)
111+
def test_array_api_strict_consistent_with_ref(self) -> None:
112+
x_xp = xp.asarray(self.x_np)
113+
indices_xp = xp.asarray(self.indices_np)
114+
values_xp = xp.asarray(self.values_np)
115+
result = xp_add_at(x_xp, indices_xp, values_xp)
116+
np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
117+
118+
119+
class TestXpBincountConsistent(unittest.TestCase):
120+
"""Test xp_bincount consistency across backends."""
121+
122+
def setUp(self) -> None:
123+
self.x_np = np.array([0, 1, 1, 3, 2, 1, 7])
124+
self.ref = np.bincount(self.x_np)
125+
126+
def test_numpy_consistent_with_ref(self) -> None:
127+
result = xp_bincount(self.x_np)
128+
np.testing.assert_equal(self.ref, result)
129+
130+
@unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
131+
def test_pt_consistent_with_ref(self) -> None:
132+
x_pt = torch.from_numpy(self.x_np)
133+
result = xp_bincount(x_pt)
134+
np.testing.assert_equal(self.ref, to_numpy_array(result))
135+
136+
@unittest.skipUnless(INSTALLED_JAX, "JAX is not installed")
137+
def test_jax_consistent_with_ref(self) -> None:
138+
x_jax = jnp.array(self.x_np)
139+
result = xp_bincount(x_jax)
140+
np.testing.assert_equal(self.ref, to_numpy_array(result))
141+
142+
143+
class TestXpBincountWithWeightsConsistent(unittest.TestCase):
144+
"""Test xp_bincount with weights consistency across backends."""
145+
146+
def setUp(self) -> None:
147+
self.x_np = np.array([0, 1, 1, 3, 2, 1, 7])
148+
self.weights_np = np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7])
149+
self.ref = np.bincount(self.x_np, weights=self.weights_np)
150+
151+
def test_numpy_consistent_with_ref(self) -> None:
152+
result = xp_bincount(self.x_np, weights=self.weights_np)
153+
np.testing.assert_allclose(self.ref, result, atol=1e-10)
154+
155+
@unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
156+
def test_pt_consistent_with_ref(self) -> None:
157+
x_pt = torch.from_numpy(self.x_np)
158+
weights_pt = torch.from_numpy(self.weights_np)
159+
result = xp_bincount(x_pt, weights=weights_pt)
160+
np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
161+
162+
@unittest.skipUnless(INSTALLED_JAX, "JAX is not installed")
163+
def test_jax_consistent_with_ref(self) -> None:
164+
x_jax = jnp.array(self.x_np)
165+
weights_jax = jnp.array(self.weights_np)
166+
result = xp_bincount(x_jax, weights=weights_jax)
167+
np.testing.assert_allclose(self.ref, to_numpy_array(result), atol=1e-10)
168+
169+
170+
class TestXpBincountWithMinlengthConsistent(unittest.TestCase):
171+
"""Test xp_bincount with minlength consistency across backends."""
172+
173+
def setUp(self) -> None:
174+
self.x_np = np.array([0, 1, 1, 3])
175+
self.minlength = 10
176+
self.ref = np.bincount(self.x_np, minlength=self.minlength)
177+
178+
def test_numpy_consistent_with_ref(self) -> None:
179+
result = xp_bincount(self.x_np, minlength=self.minlength)
180+
np.testing.assert_equal(self.ref, result)
181+
182+
@unittest.skipUnless(INSTALLED_PT, "PyTorch is not installed")
183+
def test_pt_consistent_with_ref(self) -> None:
184+
x_pt = torch.from_numpy(self.x_np)
185+
result = xp_bincount(x_pt, minlength=self.minlength)
186+
np.testing.assert_equal(self.ref, to_numpy_array(result))
187+
188+
@unittest.skipUnless(INSTALLED_JAX, "JAX is not installed")
189+
def test_jax_consistent_with_ref(self) -> None:
190+
x_jax = jnp.array(self.x_np)
191+
result = xp_bincount(x_jax, minlength=self.minlength)
192+
np.testing.assert_equal(self.ref, to_numpy_array(result))

0 commit comments

Comments
 (0)