Skip to content

Commit 6b77202

Browse files
committed
rename
1 parent a491ef5 commit 6b77202

2 files changed

Lines changed: 7 additions & 7 deletions

File tree

deepmd/dpmodel/array_api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray
9494
raise NotImplementedError("Only JAX arrays are supported.")
9595

9696

97-
def add_at(x, indices, values):
97+
def xp_add_at(x, indices, values):
9898
"""Adds values to the specified indices of x in place or returns new x (for JAX)."""
9999
xp = array_api_compat.array_namespace(x, indices, values)
100100
if array_api_compat.is_numpy_array(x):
@@ -115,7 +115,7 @@ def add_at(x, indices, values):
115115
return x
116116

117117

118-
def bincount(x, weights=None, minlength=0):
118+
def xp_bincount(x, weights=None, minlength=0):
119119
"""Counts the number of occurrences of each value in x."""
120120
xp = array_api_compat.array_namespace(x)
121121
if array_api_compat.is_numpy_array(x) or array_api_compat.is_jax_array(x):
@@ -124,5 +124,5 @@ def bincount(x, weights=None, minlength=0):
124124
if weights is None:
125125
weights = xp.ones_like(x)
126126
result = xp.zeros((max(minlength, int(xp.max(x)) + 1),), dtype=weights.dtype)
127-
result = add_at(result, x, weights)
127+
result = xp_add_at(result, x, weights)
128128
return result

deepmd/dpmodel/utils/network.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
NativeOP,
2222
)
2323
from deepmd.dpmodel.array_api import (
24-
add_at,
25-
bincount,
2624
support_array_api,
25+
xp_add_at,
26+
xp_bincount,
2727
)
2828
from deepmd.dpmodel.common import (
2929
to_numpy_array,
@@ -985,15 +985,15 @@ def aggregate(
985985
output: [num_owner, feature_dim]
986986
"""
987987
xp = array_api_compat.array_namespace(data, owners)
988-
bin_count = bincount(owners)
988+
bin_count = xp_bincount(owners)
989989
bin_count = xp.where(bin_count == 0, xp.ones_like(bin_count), bin_count)
990990

991991
if num_owner is not None and bin_count.shape[0] != num_owner:
992992
difference = num_owner - bin_count.shape[0]
993993
bin_count = xp.concat([bin_count, xp.ones(difference, dtype=bin_count.dtype)])
994994

995995
output = xp.zeros((bin_count.shape[0], data.shape[1]), dtype=data.dtype)
996-
output = add_at(output, owners, data)
996+
output = xp_add_at(output, owners, data)
997997

998998
if average:
999999
output = xp.transpose(xp.transpose(output) / bin_count)

0 commit comments

Comments
 (0)