@@ -94,7 +94,7 @@ def xp_scatter_sum(input, dim, index: np.ndarray, src: np.ndarray) -> np.ndarray
9494 raise NotImplementedError ("Only JAX arrays are supported." )
9595
9696
97- def add_at (x , indices , values ):
97+ def xp_add_at (x , indices , values ):
9898 """Adds values to the specified indices of x in place or returns new x (for JAX)."""
9999 xp = array_api_compat .array_namespace (x , indices , values )
100100 if array_api_compat .is_numpy_array (x ):
@@ -115,7 +115,7 @@ def add_at(x, indices, values):
115115 return x
116116
117117
118- def bincount (x , weights = None , minlength = 0 ):
118+ def xp_bincount (x , weights = None , minlength = 0 ):
119119 """Counts the number of occurrences of each value in x."""
120120 xp = array_api_compat .array_namespace (x )
121121 if array_api_compat .is_numpy_array (x ) or array_api_compat .is_jax_array (x ):
@@ -124,5 +124,5 @@ def bincount(x, weights=None, minlength=0):
124124 if weights is None :
125125 weights = xp .ones_like (x )
126126 result = xp .zeros ((max (minlength , int (xp .max (x )) + 1 ),), dtype = weights .dtype )
127- result = add_at (result , x , weights )
127+ result = xp_add_at (result , x , weights )
128128 return result
0 commit comments