Skip to content

Commit bc8b0e0

Browse files
committed
mimic array_api_strict
1 parent 51702ea commit bc8b0e0

1 file changed

Lines changed: 15 additions & 3 deletions

File tree

deepmd/dpmodel/utils/network.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -984,22 +984,34 @@ def aggregate(
984984
"""
985985
xp = array_api_compat.array_namespace(data, owners)
986986

987+
def add_at(x, indices, values):
988+
for idx, val in zip(indices, values):
989+
x[idx] = x[idx] + val
990+
return x
991+
987992
def bincount(x, weights=None, minlength=0):
988993
if weights is None:
989994
weights = xp.ones_like(x)
990995
result = xp.zeros((max(minlength, int(x.max()) + 1),), dtype=weights.dtype)
991-
xp.add.at(result, x, weights)
996+
result = add_at(result, x, weights)
992997
return result
993998

994-
bin_count = bincount(owners)
999+
if hasattr(xp, "bincount"):
1000+
bin_count = xp.bincount(owners)
1001+
else:
1002+
# for array_api_strict
1003+
bin_count = bincount(owners)
9951004
bin_count = xp.where(bin_count == 0, xp.ones_like(bin_count), bin_count)
9961005

9971006
if num_owner is not None and bin_count.shape[0] != num_owner:
9981007
difference = num_owner - bin_count.shape[0]
9991008
bin_count = xp.concat([bin_count, xp.ones(difference, dtype=bin_count.dtype)])
10001009

10011010
output = xp.zeros((bin_count.shape[0], data.shape[1]), dtype=data.dtype)
1002-
xp.add.at(output, owners, data)
1011+
if hasattr(xp, "add") and hasattr(xp.add, "at"):
1012+
xp.add.at(output, owners, data)
1013+
else:
1014+
output = add_at(output, owners, data)
10031015

10041016
if average:
10051017
output = (output.T / bin_count).T

0 commit comments

Comments
 (0)