Skip to content

Commit 51702ea

Browse files
committed
mimic bincount
1 parent fc84503 commit 51702ea

1 file changed

Lines changed: 8 additions & 1 deletion

File tree

deepmd/dpmodel/utils/network.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -984,7 +984,14 @@ def aggregate(
984984
"""
985985
xp = array_api_compat.array_namespace(data, owners)
986986

987-
bin_count = xp.bincount(owners)
987+
def bincount(x, weights=None, minlength=0):
988+
if weights is None:
989+
weights = xp.ones_like(x)
990+
result = xp.zeros((max(minlength, int(x.max()) + 1),), dtype=weights.dtype)
991+
xp.add.at(result, x, weights)
992+
return result
993+
994+
bin_count = bincount(owners)
988995
bin_count = xp.where(bin_count == 0, xp.ones_like(bin_count), bin_count)
989996

990997
if num_owner is not None and bin_count.shape[0] != num_owner:

0 commit comments

Comments
 (0)