|
21 | 21 | NativeOP, |
22 | 22 | ) |
23 | 23 | from deepmd.dpmodel.array_api import ( |
| 24 | + add_at, |
| 25 | + bincount, |
24 | 26 | support_array_api, |
25 | 27 | ) |
26 | 28 | from deepmd.dpmodel.common import ( |
@@ -983,42 +985,15 @@ def aggregate( |
983 | 985 | output: [num_owner, feature_dim] |
984 | 986 | """ |
985 | 987 | xp = array_api_compat.array_namespace(data, owners) |
986 | | - |
987 | | - def add_at(x, indices, values): |
988 | | - unique_ids = xp.unique(indices) |
989 | | - for i in unique_ids: |
990 | | - mask = xp.where(indices == i, 1, 0) |
991 | | - mask = xp.expand_dims(mask, axis=1) if len(values.shape) != 1 else mask |
992 | | - selected = values * mask |
993 | | - summed = xp.sum(selected, axis=0) |
994 | | - x[i] = x[i] + summed |
995 | | - |
996 | | - return x |
997 | | - |
998 | | - def bincount(x, weights=None, minlength=0): |
999 | | - if weights is None: |
1000 | | - weights = xp.ones_like(x) |
1001 | | - result = xp.zeros((max(minlength, int(xp.max(x)) + 1),), dtype=weights.dtype) |
1002 | | - result = add_at(result, x, weights) |
1003 | | - return result |
1004 | | - |
1005 | | - if hasattr(xp, "bincount"): |
1006 | | - bin_count = xp.bincount(owners) |
1007 | | - else: |
1008 | | - # for array_api_strict |
1009 | | - bin_count = bincount(owners) |
| 988 | + bin_count = bincount(owners) |
1010 | 989 | bin_count = xp.where(bin_count == 0, xp.ones_like(bin_count), bin_count) |
1011 | 990 |
|
1012 | 991 | if num_owner is not None and bin_count.shape[0] != num_owner: |
1013 | 992 | difference = num_owner - bin_count.shape[0] |
1014 | 993 | bin_count = xp.concat([bin_count, xp.ones(difference, dtype=bin_count.dtype)]) |
1015 | 994 |
|
1016 | 995 | output = xp.zeros((bin_count.shape[0], data.shape[1]), dtype=data.dtype) |
1017 | | - if hasattr(xp, "add") and hasattr(xp.add, "at"): |
1018 | | - xp.add.at(output, owners, data) |
1019 | | - else: |
1020 | | - # for array_api_strict |
1021 | | - output = add_at(output, owners, data) |
| 996 | + output = add_at(output, owners, data) |
1022 | 997 |
|
1023 | 998 | if average: |
1024 | 999 | output = xp.transpose(xp.transpose(output) / bin_count) |
|
0 commit comments