Skip to content

Commit a491ef5

Browse files
committed
mv methods to array api
1 parent e3ea98b commit a491ef5

2 files changed

Lines changed: 38 additions & 29 deletions

File tree

deepmd/dpmodel/array_api.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,37 @@ def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray
9292
)
9393
else:
9494
raise NotImplementedError("Only JAX arrays are supported.")
95+
96+
97+
def add_at(x, indices, values):
98+
"""Adds values to the specified indices of x in place or returns new x (for JAX)."""
99+
xp = array_api_compat.array_namespace(x, indices, values)
100+
if array_api_compat.is_numpy_array(x):
101+
# NumPy: supports np.add.at (in-place)
102+
xp.add.at(x, indices, values)
103+
return x
104+
105+
elif array_api_compat.is_jax_array(x):
106+
# JAX: functional update, not in-place
107+
return x.at[indices].add(values)
108+
else:
109+
# Fallback for array_api_strict: use basic indexing only
110+
# may need a more efficient way to do this
111+
n = indices.shape[0]
112+
for i in range(n):
113+
idx = indices[i]
114+
x[idx] = x[idx] + values[i]
115+
return x
116+
117+
118+
def bincount(x, weights=None, minlength=0):
119+
"""Counts the number of occurrences of each value in x."""
120+
xp = array_api_compat.array_namespace(x)
121+
if array_api_compat.is_numpy_array(x) or array_api_compat.is_jax_array(x):
122+
result = xp.bincount(x, weights=weights, minlength=minlength)
123+
else:
124+
if weights is None:
125+
weights = xp.ones_like(x)
126+
result = xp.zeros((max(minlength, int(xp.max(x)) + 1),), dtype=weights.dtype)
127+
result = add_at(result, x, weights)
128+
return result

deepmd/dpmodel/utils/network.py

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
NativeOP,
2222
)
2323
from deepmd.dpmodel.array_api import (
24+
add_at,
25+
bincount,
2426
support_array_api,
2527
)
2628
from deepmd.dpmodel.common import (
@@ -983,42 +985,15 @@ def aggregate(
983985
output: [num_owner, feature_dim]
984986
"""
985987
xp = array_api_compat.array_namespace(data, owners)
986-
987-
def add_at(x, indices, values):
988-
unique_ids = xp.unique(indices)
989-
for i in unique_ids:
990-
mask = xp.where(indices == i, 1, 0)
991-
mask = xp.expand_dims(mask, axis=1) if len(values.shape) != 1 else mask
992-
selected = values * mask
993-
summed = xp.sum(selected, axis=0)
994-
x[i] = x[i] + summed
995-
996-
return x
997-
998-
def bincount(x, weights=None, minlength=0):
999-
if weights is None:
1000-
weights = xp.ones_like(x)
1001-
result = xp.zeros((max(minlength, int(xp.max(x)) + 1),), dtype=weights.dtype)
1002-
result = add_at(result, x, weights)
1003-
return result
1004-
1005-
if hasattr(xp, "bincount"):
1006-
bin_count = xp.bincount(owners)
1007-
else:
1008-
# for array_api_strict
1009-
bin_count = bincount(owners)
988+
bin_count = bincount(owners)
1010989
bin_count = xp.where(bin_count == 0, xp.ones_like(bin_count), bin_count)
1011990

1012991
if num_owner is not None and bin_count.shape[0] != num_owner:
1013992
difference = num_owner - bin_count.shape[0]
1014993
bin_count = xp.concat([bin_count, xp.ones(difference, dtype=bin_count.dtype)])
1015994

1016995
output = xp.zeros((bin_count.shape[0], data.shape[1]), dtype=data.dtype)
1017-
if hasattr(xp, "add") and hasattr(xp.add, "at"):
1018-
xp.add.at(output, owners, data)
1019-
else:
1020-
# for array_api_strict
1021-
output = add_at(output, owners, data)
996+
output = add_at(output, owners, data)
1022997

1023998
if average:
1024999
output = xp.transpose(xp.transpose(output) / bin_count)

0 commit comments

Comments
 (0)